diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..24aa54d81a0af0eeffa744a6c81e8a79f5cb76a3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+*.pyc
+__pycache__
+test.py
+flagged
+output
+gradio_cached*
+dist*
+*egg-info
+build*
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d56ccfbb7496de8592c9745d0d9a5e390af75fd6
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,437 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+     Considerations for licensors: Our public licenses are
+     intended for use by those authorized to give the public
+     permission to use material in ways otherwise restricted by
+     copyright and certain other rights. Our licenses are
+     irrevocable. Licensors should read and understand the terms
+     and conditions of the license they choose before applying it.
+     Licensors should also secure all rights necessary before
+     applying our licenses so that the public can reuse the
+     material as expected. Licensors should clearly mark any
+     material not subject to the license. This includes other CC-
+     licensed material, or material used under an exception or
+     limitation to copyright. More considerations for licensors:
+	wiki.creativecommons.org/Considerations_for_licensors
+
+     Considerations for the public: By using one of our public
+     licenses, a licensor grants the public permission to use the
+     licensed material under specified terms and conditions. If
+     the licensor's permission is not necessary for any reason--for
+     example, because of any applicable exception or limitation to
+     copyright--then that use is not regulated by the license. Our
+     licenses grant only permissions under copyright and certain
+     other rights that a licensor has authority to grant. Use of
+     the licensed material may still be restricted for other
+     reasons, including because others have copyright or other
+     rights in the material. A licensor may make special requests,
+     such as asking that all changes be marked or described.
+     Although not required by our licenses, you are encouraged to
+     respect those requests where reasonable. More_considerations
+     for the public: 
+	wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+  a. Adapted Material means material subject to Copyright and Similar
+     Rights that is derived from or based upon the Licensed Material
+     and in which the Licensed Material is translated, altered,
+     arranged, transformed, or otherwise modified in a manner requiring
+     permission under the Copyright and Similar Rights held by the
+     Licensor. For purposes of this Public License, where the Licensed
+     Material is a musical work, performance, or sound recording,
+     Adapted Material is always produced where the Licensed Material is
+     synched in timed relation with a moving image.
+
+  b. Adapter's License means the license You apply to Your Copyright
+     and Similar Rights in Your contributions to Adapted Material in
+     accordance with the terms and conditions of this Public License.
+
+  c. BY-NC-SA Compatible License means a license listed at
+     creativecommons.org/compatiblelicenses, approved by Creative
+     Commons as essentially the equivalent of this Public License.
+
+  d. Copyright and Similar Rights means copyright and/or similar rights
+     closely related to copyright including, without limitation,
+     performance, broadcast, sound recording, and Sui Generis Database
+     Rights, without regard to how the rights are labeled or
+     categorized. For purposes of this Public License, the rights
+     specified in Section 2(b)(1)-(2) are not Copyright and Similar
+     Rights.
+
+  e. Effective Technological Measures means those measures that, in the
+     absence of proper authority, may not be circumvented under laws
+     fulfilling obligations under Article 11 of the WIPO Copyright
+     Treaty adopted on December 20, 1996, and/or similar international
+     agreements.
+
+  f. Exceptions and Limitations means fair use, fair dealing, and/or
+     any other exception or limitation to Copyright and Similar Rights
+     that applies to Your use of the Licensed Material.
+
+  g. License Elements means the license attributes listed in the name
+     of a Creative Commons Public License. The License Elements of this
+     Public License are Attribution, NonCommercial, and ShareAlike.
+
+  h. Licensed Material means the artistic or literary work, database,
+     or other material to which the Licensor applied this Public
+     License.
+
+  i. Licensed Rights means the rights granted to You subject to the
+     terms and conditions of this Public License, which are limited to
+     all Copyright and Similar Rights that apply to Your use of the
+     Licensed Material and that the Licensor has authority to license.
+
+  j. Licensor means the individual(s) or entity(ies) granting rights
+     under this Public License.
+
+  k. NonCommercial means not primarily intended for or directed towards
+     commercial advantage or monetary compensation. For purposes of
+     this Public License, the exchange of the Licensed Material for
+     other material subject to Copyright and Similar Rights by digital
+     file-sharing or similar means is NonCommercial provided there is
+     no payment of monetary compensation in connection with the
+     exchange.
+
+  l. Share means to provide material to the public by any means or
+     process that requires permission under the Licensed Rights, such
+     as reproduction, public display, public performance, distribution,
+     dissemination, communication, or importation, and to make material
+     available to the public including in ways that members of the
+     public may access the material from a place and at a time
+     individually chosen by them.
+
+  m. Sui Generis Database Rights means rights other than copyright
+     resulting from Directive 96/9/EC of the European Parliament and of
+     the Council of 11 March 1996 on the legal protection of databases,
+     as amended and/or succeeded, as well as other essentially
+     equivalent rights anywhere in the world.
+
+  n. You means the individual or entity exercising the Licensed Rights
+     under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+  a. License grant.
+
+       1. Subject to the terms and conditions of this Public License,
+          the Licensor hereby grants You a worldwide, royalty-free,
+          non-sublicensable, non-exclusive, irrevocable license to
+          exercise the Licensed Rights in the Licensed Material to:
+
+            a. reproduce and Share the Licensed Material, in whole or
+               in part, for NonCommercial purposes only; and
+
+            b. produce, reproduce, and Share Adapted Material for
+               NonCommercial purposes only.
+
+       2. Exceptions and Limitations. For the avoidance of doubt, where
+          Exceptions and Limitations apply to Your use, this Public
+          License does not apply, and You do not need to comply with
+          its terms and conditions.
+
+       3. Term. The term of this Public License is specified in Section
+          6(a).
+
+       4. Media and formats; technical modifications allowed. The
+          Licensor authorizes You to exercise the Licensed Rights in
+          all media and formats whether now known or hereafter created,
+          and to make technical modifications necessary to do so. The
+          Licensor waives and/or agrees not to assert any right or
+          authority to forbid You from making technical modifications
+          necessary to exercise the Licensed Rights, including
+          technical modifications necessary to circumvent Effective
+          Technological Measures. For purposes of this Public License,
+          simply making modifications authorized by this Section 2(a)
+          (4) never produces Adapted Material.
+
+       5. Downstream recipients.
+
+            a. Offer from the Licensor -- Licensed Material. Every
+               recipient of the Licensed Material automatically
+               receives an offer from the Licensor to exercise the
+               Licensed Rights under the terms and conditions of this
+               Public License.
+
+            b. Additional offer from the Licensor -- Adapted Material.
+               Every recipient of Adapted Material from You
+               automatically receives an offer from the Licensor to
+               exercise the Licensed Rights in the Adapted Material
+               under the conditions of the Adapter's License You apply.
+
+            c. No downstream restrictions. You may not offer or impose
+               any additional or different terms or conditions on, or
+               apply any Effective Technological Measures to, the
+               Licensed Material if doing so restricts exercise of the
+               Licensed Rights by any recipient of the Licensed
+               Material.
+
+       6. No endorsement. Nothing in this Public License constitutes or
+          may be construed as permission to assert or imply that You
+          are, or that Your use of the Licensed Material is, connected
+          with, or sponsored, endorsed, or granted official status by,
+          the Licensor or others designated to receive attribution as
+          provided in Section 3(a)(1)(A)(i).
+
+  b. Other rights.
+
+       1. Moral rights, such as the right of integrity, are not
+          licensed under this Public License, nor are publicity,
+          privacy, and/or other similar personality rights; however, to
+          the extent possible, the Licensor waives and/or agrees not to
+          assert any such rights held by the Licensor to the limited
+          extent necessary to allow You to exercise the Licensed
+          Rights, but not otherwise.
+
+       2. Patent and trademark rights are not licensed under this
+          Public License.
+
+       3. To the extent possible, the Licensor waives any right to
+          collect royalties from You for the exercise of the Licensed
+          Rights, whether directly or through a collecting society
+          under any voluntary or waivable statutory or compulsory
+          licensing scheme. In all other cases the Licensor expressly
+          reserves any right to collect such royalties, including when
+          the Licensed Material is used other than for NonCommercial
+          purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+  a. Attribution.
+
+       1. If You Share the Licensed Material (including in modified
+          form), You must:
+
+            a. retain the following if it is supplied by the Licensor
+               with the Licensed Material:
+
+                 i. identification of the creator(s) of the Licensed
+                    Material and any others designated to receive
+                    attribution, in any reasonable manner requested by
+                    the Licensor (including by pseudonym if
+                    designated);
+
+                ii. a copyright notice;
+
+               iii. a notice that refers to this Public License;
+
+                iv. a notice that refers to the disclaimer of
+                    warranties;
+
+                 v. a URI or hyperlink to the Licensed Material to the
+                    extent reasonably practicable;
+
+            b. indicate if You modified the Licensed Material and
+               retain an indication of any previous modifications; and
+
+            c. indicate the Licensed Material is licensed under this
+               Public License, and include the text of, or the URI or
+               hyperlink to, this Public License.
+
+       2. You may satisfy the conditions in Section 3(a)(1) in any
+          reasonable manner based on the medium, means, and context in
+          which You Share the Licensed Material. For example, it may be
+          reasonable to satisfy the conditions by providing a URI or
+          hyperlink to a resource that includes the required
+          information.
+       3. If requested by the Licensor, You must remove any of the
+          information required by Section 3(a)(1)(A) to the extent
+          reasonably practicable.
+
+  b. ShareAlike.
+
+     In addition to the conditions in Section 3(a), if You Share
+     Adapted Material You produce, the following conditions also apply.
+
+       1. The Adapter's License You apply must be a Creative Commons
+          license with the same License Elements, this version or
+          later, or a BY-NC-SA Compatible License.
+
+       2. You must include the text of, or the URI or hyperlink to, the
+          Adapter's License You apply. You may satisfy this condition
+          in any reasonable manner based on the medium, means, and
+          context in which You Share Adapted Material.
+
+       3. You may not offer or impose any additional or different terms
+          or conditions on, or apply any Effective Technological
+          Measures to, Adapted Material that restrict exercise of the
+          rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+     to extract, reuse, reproduce, and Share all or a substantial
+     portion of the contents of the database for NonCommercial purposes
+     only;
+
+  b. if You include all or a substantial portion of the database
+     contents in a database in which You have Sui Generis Database
+     Rights, then the database in which You have Sui Generis Database
+     Rights (but not its individual contents) is Adapted Material,
+     including for purposes of Section 3(b); and
+
+  c. You must comply with the conditions in Section 3(a) if You Share
+     all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+  c. The disclaimer of warranties and limitation of liability provided
+     above shall be interpreted in a manner that, to the extent
+     possible, most closely approximates an absolute disclaimer and
+     waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+  a. This Public License applies for the term of the Copyright and
+     Similar Rights licensed here. However, if You fail to comply with
+     this Public License, then Your rights under this Public License
+     terminate automatically.
+
+  b. Where Your right to use the Licensed Material has terminated under
+     Section 6(a), it reinstates:
+
+       1. automatically as of the date the violation is cured, provided
+          it is cured within 30 days of Your discovery of the
+          violation; or
+
+       2. upon express reinstatement by the Licensor.
+
+     For the avoidance of doubt, this Section 6(b) does not affect any
+     right the Licensor may have to seek remedies for Your violations
+     of this Public License.
+
+  c. For the avoidance of doubt, the Licensor may also offer the
+     Licensed Material under separate terms or conditions or stop
+     distributing the Licensed Material at any time; however, doing so
+     will not terminate this Public License.
+
+  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+     License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+  a. The Licensor shall not be bound by any additional or different
+     terms or conditions communicated by You unless expressly agreed.
+
+  b. Any arrangements, understandings, or agreements regarding the
+     Licensed Material not stated herein are separate from and
+     independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+  a. For the avoidance of doubt, this Public License does not, and
+     shall not be interpreted to, reduce, limit, restrict, or impose
+     conditions on any use of the Licensed Material that could lawfully
+     be made without permission under this Public License.
+
+  b. To the extent possible, if any provision of this Public License is
+     deemed unenforceable, it shall be automatically reformed to the
+     minimum extent necessary to make it enforceable. If the provision
+     cannot be reformed, it shall be severed from this Public License
+     without affecting the enforceability of the remaining terms and
+     conditions.
+
+  c. No term or condition of this Public License will be waived and no
+     failure to comply consented to unless expressly agreed to by the
+     Licensor.
+
+  d. Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..5dcbc139b80a520e73d210bb3236cb0a25a31129
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include *.py LICENSE README.md
+recursive-include audioldm2 *.txt *.py *.gz *.npy *.json
\ No newline at end of file
diff --git a/README.md b/README.md
index d8cd2b7e733713d278e699014b616f91add1cfaa..4d55fdfc128beb20be6f2512363a1c8663bff795 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,23 @@
 ---
-title: Audioldm2 Text2audio Text2music
-emoji: 👁
-colorFrom: gray
-colorTo: green
+title: AudioLDM2 Text2Audio Text2Music Generation
+emoji: 🔊
+colorFrom: indigo
+colorTo: red
 sdk: gradio
-sdk_version: 3.39.0
+sdk_version: 3.27.0
 app_file: app.py
 pinned: false
-license: cc-by-nc-nd-4.0
+license: bigscience-openrail-m
+duplicated_from: haoheliu/audioldm2-text2audio-text2music
+
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Reference
+Part of the code in this repo is borrowed from the following repos. We would like to thank their authors for their contributions.
+
+> https://github.com/LAION-AI/CLAP
+> https://github.com/CompVis/stable-diffusion
+> https://github.com/v-iashin/SpecVQGAN 
+> https://github.com/toshas/torch-fidelity
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2bba800105bda2f7fd171a68f28c31b00a87bd9
--- /dev/null
+++ b/app.py
@@ -0,0 +1,361 @@
+from huggingface_hub import hf_hub_download
+import torch
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+import gradio as gr
+from audioldm2 import text_to_audio, build_model
+from share_btn import community_icon_html, loading_icon_html, share_js
+
+model_id = "haoheliu/audioldm2-full"
+# Pre-fetch the checkpoint so the first generation request does not block on the download.
+hf_hub_download(repo_id=model_id, filename="audioldm2-full.pth")
+
+audioldm = None
+current_model_name = None
+
+def text2audio(
+    text,
+    guidance_scale,
+    random_seed,
+    n_candidates,
+    model_name="audioldm2-full",
+):
+    global audioldm, current_model_name
+    torch.set_float32_matmul_precision("high")
+
+    if audioldm is None or model_name != current_model_name:
+        audioldm = build_model(model_name=model_name)
+        current_model_name = model_name
+        audioldm = torch.compile(audioldm)
+
+    # print(text, length, guidance_scale)
+    waveform = text_to_audio(
+        latent_diffusion=audioldm,
+        text=text,
+        seed=random_seed,
+        duration=10,
+        guidance_scale=guidance_scale,
+        n_candidate_gen_per_text=int(n_candidates),
+    )  # [bs, 1, samples]
+    waveform = [
+        gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
+    ]
+    # waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
+    if len(waveform) == 1:
+        waveform = waveform[0]
+    return waveform
+
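+# Rough standalone usage sketch (kept as a comment so that importing this file
+# does not trigger a model download): `text2audio` lazily builds and caches the
+# model on the first call, then returns the waveform video (or a list of them)
+# rendered by gr.make_waveform. The prompt below is only an illustrative example.
+#
+# video = text2audio(
+#     "A calm piano melody accompanied by soft rain.",
+#     guidance_scale=3.5,
+#     random_seed=45,
+#     n_candidates=3,
+# )
+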
+css = """
+        a {
+            color: inherit;
+            text-decoration: underline;
+        }
+        .gradio-container {
+            font-family: 'IBM Plex Sans', sans-serif;
+        }
+        .gr-button {
+            color: white;
+            border-color: #000000;
+            background: #000000;
+        }
+        input[type='range'] {
+            accent-color: #000000;
+        }
+        .dark input[type='range'] {
+            accent-color: #dfdfdf;
+        }
+        .container {
+            max-width: 730px;
+            margin: auto;
+            padding-top: 1.5rem;
+        }
+        #gallery {
+            min-height: 22rem;
+            margin-bottom: 15px;
+            margin-left: auto;
+            margin-right: auto;
+            border-bottom-right-radius: .5rem !important;
+            border-bottom-left-radius: .5rem !important;
+        }
+        #gallery>div>.h-full {
+            min-height: 20rem;
+        }
+        .details:hover {
+            text-decoration: underline;
+        }
+        .gr-button {
+            white-space: nowrap;
+        }
+        .gr-button:focus {
+            border-color: rgb(147 197 253 / var(--tw-border-opacity));
+            outline: none;
+            box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+            --tw-border-opacity: 1;
+            --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+            --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
+            --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+            --tw-ring-opacity: .5;
+        }
+        #advanced-btn {
+            font-size: .7rem !important;
+            line-height: 19px;
+            margin-top: 12px;
+            margin-bottom: 12px;
+            padding: 2px 8px;
+            border-radius: 14px !important;
+        }
+        #advanced-options {
+            margin-bottom: 20px;
+        }
+        .footer {
+            margin-bottom: 45px;
+            margin-top: 35px;
+            text-align: center;
+            border-bottom: 1px solid #e5e5e5;
+        }
+        .footer>p {
+            font-size: .8rem;
+            display: inline-block;
+            padding: 0 10px;
+            transform: translateY(10px);
+            background: white;
+        }
+        .dark .footer {
+            border-color: #303030;
+        }
+        .dark .footer>p {
+            background: #0b0f19;
+        }
+        .acknowledgments h4{
+            margin: 1.25em 0 .25em 0;
+            font-weight: bold;
+            font-size: 115%;
+        }
+        #container-advanced-btns{
+            display: flex;
+            flex-wrap: wrap;
+            justify-content: space-between;
+            align-items: center;
+        }
+        .animate-spin {
+            animation: spin 1s linear infinite;
+        }
+        @keyframes spin {
+            from {
+                transform: rotate(0deg);
+            }
+            to {
+                transform: rotate(360deg);
+            }
+        }
+        #share-btn-container {
+            display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
+            margin-top: 10px;
+            margin-left: auto;
+        }
+        #share-btn {
+            all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
+        }
+        #share-btn * {
+            all: unset;
+        }
+        #share-btn-container div:nth-child(-n+2){
+            width: auto !important;
+            min-height: 0px !important;
+        }
+        #share-btn-container .wrap {
+            display: none !important;
+        }
+        .gr-form{
+            flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
+        }
+        #prompt-container{
+            gap: 0;
+        }
+        #generated_id{
+            min-height: 700px
+        }
+        #setting_id{
+          margin-bottom: 12px;
+          text-align: center;
+          font-weight: 900;
+        }
+"""
+iface = gr.Blocks(css=css)
+
+with iface:
+    gr.HTML(
+        """
+            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+              <div
+                style="
+                  display: inline-flex;
+                  align-items: center;
+                  gap: 0.8rem;
+                  font-size: 1.75rem;
+                "
+              >
+                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+                  AudioLDM 2: A General Framework for Audio, Music, and Speech Generation
+                </h1>
+              </div>
+              <p style="margin-bottom: 10px; font-size: 94%">
+                <a href="https://arxiv.org/abs/2301.12503">[Paper]</a>  <a href="https://audioldm.github.io/">[Project page]</a>
+              </p>
+            </div>
+        """
+    )
+    gr.HTML(
+        """
+        <h1 style="font-weight: 900; margin-bottom: 7px;">
+        AudioLDM 2: A General Framework for Audio, Music, and Speech Generation
+        </h1>
+        <p>For faster inference without waiting in the queue, you may duplicate this space and upgrade to a GPU in the settings.
+        <br/>
+        <a href="https://huggingface.co/spaces/haoheliu/audioldm2-text2audio-text2music?duplicate=true">
+        <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+        <p/>
+    """
+    )
+    with gr.Group():
+        with gr.Box():
+            ############# Input
+            textbox = gr.Textbox(
+                value="A forest of wind chimes singing a soothing melody in the breeze.",
+                max_lines=1,
+                label="Input your text here. Your text is important for the audio quality. Please ensure it is descriptive by using more adjectives.",
+                elem_id="prompt-in",
+            )
+
+            with gr.Accordion("Click to modify detailed configurations", open=False):
+                seed = gr.Number(
+                    value=45,
+                    label="Change this value (any integer number) will lead to a different generation result.",
+                )
+                # duration = gr.Slider(
+                #     10, 10, value=10, step=2.5, label="Duration (seconds)"
+                # )
+                guidance_scale = gr.Slider(
+                    0,
+                    6,
+                    value=3.5,
+                    step=0.5,
+                    label="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
+                )
+                n_candidates = gr.Slider(
+                    1,
+                    3,
+                    value=3,
+                    step=1,
+                    label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+                )
+                # model_name = gr.Dropdown(
+                #       ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
+                #   )
+            ############# Output
+            # outputs=gr.Audio(label="Output", type="numpy")
+            outputs = gr.Video(label="Output", elem_id="output-video")
+
+            # with gr.Group(elem_id="container-advanced-btns"):
+            #   # advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+            #   with gr.Group(elem_id="share-btn-container"):
+            #     community_icon = gr.HTML(community_icon_html, visible=False)
+            #     loading_icon = gr.HTML(loading_icon_html, visible=False)
+            #     share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
+            # outputs=[gr.Audio(label="Output", type="numpy"), gr.Audio(label="Output", type="numpy")]
+            btn = gr.Button("Submit").style(full_width=True)
+
+        with gr.Group(elem_id="share-btn-container", visible=False):
+            community_icon = gr.HTML(community_icon_html)
+            loading_icon = gr.HTML(loading_icon_html)
+            share_button = gr.Button("Share to community", elem_id="share-btn")
+
+        # btn.click(text2audio, inputs=[
+        #           textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs])
+        btn.click(
+            text2audio,
+            inputs=[textbox, guidance_scale, seed, n_candidates],
+            outputs=[outputs],
+        )
+
+        share_button.click(None, [], [], _js=share_js)
+        gr.HTML(
+            """
+        <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
+                    <p>Follow the latest updates on AudioLDM in our <a href="https://github.com/haoheliu/AudioLDM" style="text-decoration: underline;" target="_blank">GitHub repo</a>
+                    </p>
+                    <br>
+                    <p>Model by <a href="https://twitter.com/LiuHaohe" style="text-decoration: underline;" target="_blank">Haohe Liu</a></p>
+                    <br>
+        </div>
+        """
+        )
+        gr.Examples(
+            [
+                [
+                    "An excited crowd cheering at a sports game.",
+                    3.5,
+                    45,
+                    3,
+                    "audioldm2-full",
+                ],
+                [
+                    "A cat is meowing for attention.",
+                    3.5,
+                    45,
+                    3,
+                    "audioldm2-full",
+                ],
+                [
+                    "Birds singing sweetly in a blooming garden.",
+                    3.5,
+                    45,
+                    3,
+                    "audioldm2-full",
+                ],
+                [
+                    "A modern synthesizer creating futuristic soundscapes.",
+                    3.5,
+                    45,
+                    3,
+                    "audioldm2-full",
+                ],
+                [
+                    "The vibrant beat of Brazilian samba drums.",
+                    3.5,
+                    45,
+                    3,
+                    "audioldm2-full",
+                ],
+            ],
+            fn=text2audio,
+            # inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
+            inputs=[textbox, guidance_scale, seed, n_candidates],
+            outputs=[outputs],
+            cache_examples=True,
+        )
+        gr.HTML(
+            """
+                <div class="acknowledgements">
+                <p>Essential Tricks for Enhancing the Quality of Your Generated Audio</p>
+                <p>1. Try to use more adjectives to describe your sound. For example: "A man is speaking clearly and slowly in a large room" is better than "A man is speaking". This can make sure AudioLDM understands what you want.</p>
+                <p>2. Try to use different random seeds, which can affect the generation quality significantly sometimes.</p>
+                <p>3. It's better to use general terms like 'man' or 'woman' instead of specific names for individuals or abstract objects that humans may not be familiar with, such as 'mummy'.</p>
+                </div>
+                """
+        )
+
+        with gr.Accordion("Additional information", open=False):
+            gr.HTML(
+                """
+                <div class="acknowledgments">
+                    <p>We built the model with data from <a href="http://research.google.com/audioset/">AudioSet</a>, <a href="https://freesound.org/">Freesound</a> and the <a href="https://sound-effects.bbcrewind.co.uk/">BBC Sound Effects library</a>. We share this demo based on the <a href="https://assets.publishing.service.gov.uk/government/uploads/system/uploads/attachment_data/file/375954/Research.pdf">UK copyright exception</a> for data used in academic research.</p>
+                            </div>
+                        """
+            )
+# <p>This demo is strictly for research demo purpose only. For commercial use please <a href="haoheliu@gmail.com">contact us</a>.</p>
+
+iface.queue(concurrency_count=3)
+# iface.launch(debug=True)
+iface.launch(debug=True, share=True)
diff --git a/audioldm2/__init__.py b/audioldm2/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..91befda907125b4772601b1df2c9a8a52b733735
--- /dev/null
+++ b/audioldm2/__init__.py
@@ -0,0 +1,2 @@
+from .utils import seed_everything, save_wave, get_time, get_duration, read_list
+from .pipeline import *
diff --git a/audioldm2/audiomae_gen/__init__.py b/audioldm2/audiomae_gen/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7202889ac7aeb7e5b344da994206715cbbb3891e
--- /dev/null
+++ b/audioldm2/audiomae_gen/__init__.py
@@ -0,0 +1 @@
+from .sequence_input import Sequence2AudioMAE
diff --git a/audioldm2/audiomae_gen/sequence_input.py b/audioldm2/audiomae_gen/sequence_input.py
new file mode 100755
index 0000000000000000000000000000000000000000..4d961a0dd7157689fab6291bb3c40d9bd656b5f1
--- /dev/null
+++ b/audioldm2/audiomae_gen/sequence_input.py
@@ -0,0 +1,429 @@
+import torch
+import torch.nn as nn
+from audioldm2.latent_diffusion.util import (
+    instantiate_from_config,
+)
+
+# from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2
+from transformers import GPT2Config, GPT2Model
+import torch.optim.lr_scheduler as lr_scheduler
+
+class Sequence2AudioMAE(nn.Module):
+    def __init__(
+        self,
+        base_learning_rate,
+        sequence_gen_length,
+        sequence_input_key,
+        sequence_input_embed_dim,
+        cond_stage_config,
+        optimizer_type="AdamW",
+        use_warmup=True,
+        use_ar_gen_loss=False,
+        use_audiomae_linear=False,
+        target_tokens_mask_ratio=0.0,
+        random_mask_ratio=False,
+        **kwargs
+    ):
+        super().__init__()
+        assert use_audiomae_linear == False
+        self.random_mask_ratio = random_mask_ratio
+        self.learning_rate = base_learning_rate
+        self.cond_stage_config = cond_stage_config
+        self.use_audiomae_linear = use_audiomae_linear
+        self.optimizer_type = optimizer_type
+        self.use_warmup = use_warmup
+        self.use_ar_gen_loss = use_ar_gen_loss
+        # Even though the LDM can be conditioned on multiple pooling rates,
+        # our model always predicts the highest pooling rate.
+
+        # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"])
+        # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"])
+        # self.mae_token_num = int(512/(self.time_pool*self.freq_pool))
+
+        self.mae_token_num = sequence_gen_length
+        self.sequence_input_key = sequence_input_key
+        self.sequence_input_embed_dim = sequence_input_embed_dim
+        self.target_tokens_mask_ratio = target_tokens_mask_ratio
+
+        self.start_of_sequence_tokens = nn.Embedding(32, 768)
+        self.end_of_sequence_tokens = nn.Embedding(32, 768)
+
+        self.input_sequence_embed_linear = nn.ModuleList([])
+        self.initial_learning_rate = None
+
+        for dim in self.sequence_input_embed_dim:
+            self.input_sequence_embed_linear.append(nn.Linear(dim, 768))
+
+        self.cond_stage_models = nn.ModuleList([])
+        self.instantiate_cond_stage(cond_stage_config)
+        self.initialize_param_check_toolkit()
+
+        # configuration = GPT2Config(n_layer=1) # TODO
+        # self.model=GPT2Model(configuration)
+        ###################
+        # self.model=nn.Linear(768,768, bias=False) # TODO change the model
+        # with torch.no_grad():
+        #     self.model.weight.copy_(torch.eye(768))
+        ###################
+        self.model = GPT2Model(GPT2Config.from_pretrained("gpt2"))
+        ###################
+        # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False) # TODO
+
+        # self.loss_fn = nn.MSELoss()
+        self.loss_fn = nn.L1Loss()
+
+        self.logger_save_dir = None
+        self.logger_exp_name = None
+        self.logger_exp_group_name = None
+        self.logger_version = None
+
+    def set_log_dir(self, save_dir, exp_group_name, exp_name):
+        self.logger_save_dir = save_dir
+        self.logger_exp_group_name = exp_group_name
+        self.logger_exp_name = exp_name
+
+    def cfg_uncond(self, batch_size):
+        unconditional_conditioning = {}
+        for key in self.cond_stage_model_metadata:
+            model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+            unconditional_conditioning[key] = self.cond_stage_models[
+                model_idx
+            ].get_unconditional_condition(batch_size)
+        assert (
+            "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
+        ), "The module is not initialized with AudioMAE"
+        unconditional_conditioning[
+            "crossattn_clap_to_audiomae_feature"
+        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
+        return unconditional_conditioning
+
+    def configure_optimizers(self):
+        lr = float(self.learning_rate)
+        # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters())
+        params = list(self.parameters())
+
+        # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
+        # Resolve the optimizer class from its bare name on torch.optim (e.g. "AdamW");
+        # fall back to eval() for fully qualified names.
+        opt_cls = getattr(torch.optim, self.optimizer_type, None) or eval(self.optimizer_type)
+        opt = opt_cls(params, lr=lr)
+        scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
+        return [opt], [scheduler]
+
+    def add_sos_eos_tokens(self, _id, sequence, attn_mask):
+        batchsize = sequence.size(0)
+
+        new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device)
+        key_id = torch.tensor([_id]).to(sequence.device)
+
+        # Add two more steps to attn mask
+        new_attn_mask = torch.cat(
+            [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
+        )
+
+        # Add two more tokens in the sequence
+        sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
+        eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
+        new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
+        return new_sequence, new_attn_mask
+
+    def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
+        if sequence.size(1) > max_len:
+            print(
+                "The input sequence length to GPT-2 model is too long:",
+                sequence.size(1),
+            )
+            return sequence[:, :max_len], mask[:, :max_len]
+        else:
+            return sequence, mask
+
+    def get_input_sequence_and_mask(self, cond_dict):
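+        # Build the conditioning prefix fed to GPT-2: each sequence listed in
+        # self.sequence_input_key is projected into the 768-d GPT-2 embedding
+        # space, wrapped in learned per-key SOS/EOS tokens, and concatenated
+        # along the time axis; the result is then truncated so that the
+        # self.mae_token_num generated tokens still fit in the 1024-step
+        # context window.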
+        input_embeds = None
+        input_embeds_attn_mask = None
+        for _id, sequence_key in enumerate(self.sequence_input_key):
+            assert sequence_key in cond_dict.keys(), (
+                "Invalid sequence key %s" % sequence_key
+            )
+            cond_embed = cond_dict[sequence_key]
+            if isinstance(cond_embed, list):
+                assert (
+                    len(cond_embed) == 2
+                ), "The crossattn returned list should have length 2, including embed and attn_mask"
+                item_input_embeds, item_attn_mask = cond_embed
+
+                item_input_embeds = self.input_sequence_embed_linear[_id](
+                    item_input_embeds
+                )
+
+                item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
+                    _id, item_input_embeds, item_attn_mask
+                )
+
+                if input_embeds is None and input_embeds_attn_mask is None:
+                    input_embeds, input_embeds_attn_mask = (
+                        item_input_embeds,
+                        item_attn_mask,
+                    )
+                else:
+                    input_embeds = torch.cat(
+                        [input_embeds, item_input_embeds], dim=1
+                    )  # dim=1 is the time-step dimension
+                    input_embeds_attn_mask = torch.cat(
+                        [input_embeds_attn_mask, item_attn_mask], dim=1
+                    )  # dim=1 is the time-step dimension
+            else:
+                assert isinstance(cond_embed, torch.Tensor)
+                cond_embed = self.input_sequence_embed_linear[_id](cond_embed)
+                attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to(
+                    cond_embed.device
+                )
+
+                item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
+                    _id, cond_embed, attn_mask
+                )
+
+                if input_embeds is None and input_embeds_attn_mask is None:
+                    input_embeds, input_embeds_attn_mask = (
+                        item_input_embeds,
+                        item_attn_mask,
+                    )
+                else:
+                    input_embeds, input_embeds_attn_mask = torch.cat(
+                        [input_embeds, item_input_embeds], dim=1
+                    ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1)
+
+        assert input_embeds is not None and input_embeds_attn_mask is not None
+
+        input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
+            input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
+        )
+        cond_sequence_end_time_idx = input_embeds.size(
+            1
+        )  # The index from which we start collecting the output embeds
+
+        return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx
+
+    def warmup_step(self):
+        if self.initial_learning_rate is None:
+            self.initial_learning_rate = float(self.learning_rate)
+
+        # Only the first parameter group
+        if self.global_step <= 1000:
+            if self.global_step == 0:
+                print(
+                    "Warming up learning rate start with %s"
+                    % self.initial_learning_rate
+                )
+            self.trainer.optimizers[0].param_groups[0]["lr"] = (
+                self.global_step / 1000
+            ) * self.initial_learning_rate
+        else:
+            # TODO set learning rate here
+            self.trainer.optimizers[0].param_groups[0][
+                "lr"
+            ] = self.initial_learning_rate
+
+    def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
+        time_seq_mask = None
+        if self.target_tokens_mask_ratio > 1e-4:
+            batchsize, time_seq_len, embed_dim = target_embeds.size()
+            _, time_seq_len = target_embeds_attn_mask.size()
+            # Generate random mask
+            if self.random_mask_ratio:
+                mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
+            else:
+                mask_ratio = self.target_tokens_mask_ratio
+
+            time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
+                target_embeds.device
+            )
+            # Mask the target embedding
+            target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
+            target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
+        return target_embeds, target_embeds_attn_mask, time_seq_mask
+
+    def generate_partial(self, batch, cond_dict=None, no_grad=False):
+        if cond_dict is None:
+            cond_dict = self.get_input(batch)
+
+        print("Generate partially prompted audio with in-context learning")
+        # self.model.train()
+        # assert self.model.training==True
+
+        target_embeds, target_embeds_attn_mask = (
+            cond_dict["crossattn_audiomae_pooled"][0],
+            cond_dict["crossattn_audiomae_pooled"][1],
+        )
+
+        target_time_steps = target_embeds.size(1)
+
+        (
+            input_embeds,
+            input_embeds_attn_mask,
+            cond_sequence_end_time_idx,
+        ) = self.get_input_sequence_and_mask(cond_dict)
+
+        model_input = torch.cat(
+            [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1
+        )
+        model_input_mask = torch.cat(
+            [
+                input_embeds_attn_mask,
+                target_embeds_attn_mask[:, : target_time_steps // 4],
+            ],
+            dim=1,
+        )
+
+        steps = self.mae_token_num
+
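+        # In-context continuation: the prompt already contains the conditioning
+        # sequence plus the first quarter of the target AudioMAE tokens, so the
+        # loop below only generates the remaining three quarters, appending one
+        # predicted token (the last hidden state) per step.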
+        for _ in range(3 * steps // 4):
+            output = self.model(
+                inputs_embeds=model_input, attention_mask=model_input_mask
+            )["last_hidden_state"]
+            # Update the model input
+            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
+            # Update the attention mask
+            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
+                model_input.device
+            )
+            model_input_mask = torch.cat(
+                [model_input_mask, attention_mask_new_step], dim=1
+            )
+
+        output = model_input[:, cond_sequence_end_time_idx:]
+
+        return output, cond_dict
+
+    def generate(self, batch, cond_dict=None, no_grad=False):
+        if cond_dict is None:
+            cond_dict = self.get_input(batch)
+
+        # self.model.train()
+        # print("!!!!!!!!!!!!!train")
+
+        (
+            input_embeds,
+            input_embeds_attn_mask,
+            cond_sequence_end_time_idx,
+        ) = self.get_input_sequence_and_mask(cond_dict)
+        model_input = input_embeds
+        model_input_mask = input_embeds_attn_mask
+
+        steps = self.mae_token_num
+
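+        # Autoregressive decoding without a KV cache: at every step the whole
+        # prefix is re-encoded by GPT-2, and the last hidden state is appended
+        # as the next input embedding, until self.mae_token_num new tokens (the
+        # predicted pooled AudioMAE features) have been generated.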
+        for _ in range(steps):
+            output = self.model(
+                inputs_embeds=model_input, attention_mask=model_input_mask
+            )["last_hidden_state"]
+            # Update the model input
+            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
+            # Update the attention mask
+            attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
+                model_input.device
+            )
+            model_input_mask = torch.cat(
+                [model_input_mask, attention_mask_new_step], dim=1
+            )
+
+        return model_input[:, cond_sequence_end_time_idx:], cond_dict
+
+    def get_input_item(self, batch, k):
+        fname, text, waveform, stft, fbank = (
+            batch["fname"],
+            batch["text"],
+            batch["waveform"],
+            batch["stft"],
+            batch["log_mel_spec"],
+        )
+        ret = {}
+
+        ret["fbank"] = (
+            fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
+        )
+        ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
+        # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
+        ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
+        ret["text"] = list(text)
+        ret["fname"] = fname
+
+        for key in batch.keys():
+            if key not in ret.keys():
+                ret[key] = batch[key]
+
+        return ret[k]
+
+    def get_input(self, batch):
+        cond_dict = {}
+        if len(self.cond_stage_model_metadata.keys()) > 0:
+            unconditional_cfg = False
+
+            for cond_model_key in self.cond_stage_model_metadata.keys():
+                cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
+                    "cond_stage_key"
+                ]
+
+                # if(not self.training):
+                #     if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)):
+                #         assert cond_stage_key == "text" # CLAP model should use text for evaluation
+
+                # The original data for conditioning
+                xc = self.get_input_item(batch, cond_stage_key)
+                if type(xc) == torch.Tensor:
+                    xc = xc.to(self.device)
+
+                c = self.get_learned_conditioning(
+                    xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
+                )
+                cond_dict[cond_model_key] = c
+
+        return cond_dict
+
+    def instantiate_cond_stage(self, config):
+        self.cond_stage_model_metadata = {}
+
+        for i, cond_model_key in enumerate(config.keys()):
+            model = instantiate_from_config(config[cond_model_key])
+            self.cond_stage_models.append(model)
+            self.cond_stage_model_metadata[cond_model_key] = {
+                "model_idx": i,
+                "cond_stage_key": config[cond_model_key]["cond_stage_key"],
+                "conditioning_key": config[cond_model_key]["conditioning_key"],
+            }
+
+    def get_learned_conditioning(self, c, key, unconditional_cfg):
+        assert key in self.cond_stage_model_metadata.keys()
+
+        # Classifier-free guidance
+        if not unconditional_cfg:
+            c = self.cond_stage_models[
+                self.cond_stage_model_metadata[key]["model_idx"]
+            ](c)
+        else:
+            if isinstance(c, torch.Tensor):
+                batchsize = c.size(0)
+            elif isinstance(c, list):
+                batchsize = len(c)
+            else:
+                raise NotImplementedError()
+            c = self.cond_stage_models[
+                self.cond_stage_model_metadata[key]["model_idx"]
+            ].get_unconditional_condition(batchsize)
+
+        return c
+
+    def initialize_param_check_toolkit(self):
+        self.tracked_steps = 0
+        self.param_dict = {}
+
+    def statistic_require_grad_tensor_number(self, module, name=None):
+        requires_grad_num = 0
+        total_num = 0
+        require_grad_tensor = None
+        for p in module.parameters():
+            if p.requires_grad:
+                requires_grad_num += 1
+                if require_grad_tensor is None:
+                    require_grad_tensor = p
+            total_num += 1
+        print(
+            "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
+            % (name, requires_grad_num, total_num, requires_grad_num / total_num)
+        )
+        return require_grad_tensor
diff --git a/audioldm2/audiomae_gen/utils.py b/audioldm2/audiomae_gen/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..841d35adf338647bdf8bd1c31e9f33dee1252b6e
--- /dev/null
+++ b/audioldm2/audiomae_gen/utils.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+
+
+class Prenet(nn.Module):
+    def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5):
+        super(Prenet, self).__init__()
+        in_sizes = [in_dim] + sizes[:-1]
+        self.layers = nn.ModuleList(
+            [
+                nn.Linear(in_size, out_size)
+                for (in_size, out_size) in zip(in_sizes, sizes)
+            ]
+        )
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, inputs):
+        for linear in self.layers:
+            inputs = self.dropout(self.relu(linear(inputs)))
+        return inputs
+
+
+if __name__ == "__main__":
+    model = Prenet(in_dim=128, sizes=[256, 256, 128])
+    import ipdb
+
+    ipdb.set_trace()
diff --git a/audioldm2/clap/__init__.py b/audioldm2/clap/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/clap/open_clip/__init__.py b/audioldm2/clap/open_clip/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0
--- /dev/null
+++ b/audioldm2/clap/open_clip/__init__.py
@@ -0,0 +1,25 @@
+from .factory import (
+    list_models,
+    create_model,
+    create_model_and_transforms,
+    add_model_config,
+)
+from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
+from .model import (
+    CLAP,
+    CLAPTextCfg,
+    CLAPVisionCfg,
+    CLAPAudioCfp,
+    convert_weights_to_fp16,
+    trace_model,
+)
+from .openai import load_openai_model, list_openai_models
+from .pretrained import (
+    list_pretrained,
+    list_pretrained_tag_models,
+    list_pretrained_model_tags,
+    get_pretrained_url,
+    download_pretrained,
+)
+from .tokenizer import SimpleTokenizer, tokenize
+from .transform import image_transform
diff --git a/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100755
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/audioldm2/clap/open_clip/factory.py b/audioldm2/clap/open_clip/factory.py
new file mode 100755
index 0000000000000000000000000000000000000000..df0f4a194c2e7328f7b7d3fe11fa6801c6cc1a7c
--- /dev/null
+++ b/audioldm2/clap/open_clip/factory.py
@@ -0,0 +1,276 @@
+import json
+import logging
+import os
+import re
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+
+from .model import CLAP, convert_weights_to_fp16
+from .openai import load_openai_model
+from .pretrained import get_pretrained_url, download_pretrained
+from .transform import image_transform
+
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs"]
+_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def _rescan_model_configs():
+    global _MODEL_CONFIGS
+
+    config_ext = (".json",)
+    config_files = []
+    for config_path in _MODEL_CONFIG_PATHS:
+        if config_path.is_file() and config_path.suffix in config_ext:
+            config_files.append(config_path)
+        elif config_path.is_dir():
+            for ext in config_ext:
+                config_files.extend(config_path.glob(f"*{ext}"))
+
+    for cf in config_files:
+        if os.path.basename(cf)[0] == ".":
+            continue  # Ignore hidden files
+
+        with open(cf, "r") as f:
+            model_cfg = json.load(f)
+            if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
+                _MODEL_CONFIGS[cf.stem] = model_cfg
+
+    _MODEL_CONFIGS = {
+        k: v
+        for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
+    }
+
+
+_rescan_model_configs()  # initial populate of model config registry
+
+
+def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
+    checkpoint = torch.load(checkpoint_path, map_location=map_location)
+    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+        state_dict = checkpoint["state_dict"]
+    else:
+        state_dict = checkpoint
+    if skip_params:
+        if next(iter(state_dict.items()))[0].startswith("module"):
+            state_dict = {k[7:]: v for k, v in state_dict.items()}
+    # for k in state_dict:
+    #     if k.startswith('transformer'):
+    #         v = state_dict.pop(k)
+    #         state_dict['text_branch.' + k[12:]] = v
+    return state_dict
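+
+
+# For example, a checkpoint saved from a DataParallel/DDP-wrapped model, whose
+# keys start with "module.", is loaded here with that prefix stripped
+# (skip_params=True) so the keys match an unwrapped model.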
+
+
+def create_model(
+    amodel_name: str,
+    tmodel_name: str,
+    pretrained: str = "",
+    precision: str = "fp32",
+    device: torch.device = torch.device("cpu"),
+    jit: bool = False,
+    force_quick_gelu: bool = False,
+    openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
+    skip_params=True,
+    pretrained_audio: str = "",
+    pretrained_text: str = "",
+    enable_fusion: bool = False,
+    fusion_type: str = "None"
+    # pretrained_image: bool = False,
+):
+    amodel_name = amodel_name.replace(
+        "/", "-"
+    )  # for callers using old naming with / in ViT names
+    pretrained_orig = pretrained
+    pretrained = pretrained.lower()
+    if pretrained == "openai":
+        if amodel_name in _MODEL_CONFIGS:
+            logging.info(f"Loading {amodel_name} model config.")
+            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+        else:
+            logging.error(
+                f"Model config for {amodel_name} not found; available models {list_models()}."
+            )
+            raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+        logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
+        # Hard Code in model name
+        model_cfg["text_cfg"]["model_type"] = tmodel_name
+        model = load_openai_model(
+            "ViT-B-16",
+            model_cfg,
+            device=device,
+            jit=jit,
+            cache_dir=openai_model_cache_dir,
+            enable_fusion=enable_fusion,
+            fusion_type=fusion_type,
+        )
+        # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
+        if precision == "amp" or precision == "fp32":
+            model = model.float()
+    else:
+        if amodel_name in _MODEL_CONFIGS:
+            logging.info(f"Loading {amodel_name} model config.")
+            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+        else:
+            logging.error(
+                f"Model config for {amodel_name} not found; available models {list_models()}."
+            )
+            raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+        if force_quick_gelu:
+            # override for use of QuickGELU on non-OpenAI transformer models
+            model_cfg["quick_gelu"] = True
+
+        # if pretrained_image:
+        #     if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
+        #         # pretrained weight loading for timm models set via vision_cfg
+        #         model_cfg['vision_cfg']['timm_model_pretrained'] = True
+        #     else:
+        #         assert False, 'pretrained image towers currently only supported for timm models'
+        model_cfg["text_cfg"]["model_type"] = tmodel_name
+        model_cfg["enable_fusion"] = enable_fusion
+        model_cfg["fusion_type"] = fusion_type
+        model = CLAP(**model_cfg)
+
+        if pretrained:
+            checkpoint_path = ""
+            url = get_pretrained_url(amodel_name, pretrained)
+            if url:
+                checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
+            elif os.path.exists(pretrained_orig):
+                checkpoint_path = pretrained_orig
+            if checkpoint_path:
+                logging.info(
+                    f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
+                )
+                ckpt = load_state_dict(checkpoint_path, skip_params=True)
+                model.load_state_dict(ckpt)
+                param_names = [n for n, p in model.named_parameters()]
+                # for n in param_names:
+                #     print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
+            else:
+                logging.warning(
+                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+                )
+                raise RuntimeError(
+                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+                )
+
+        if pretrained_audio:
+            if amodel_name.startswith("PANN"):
+                if "Cnn14_mAP" in pretrained_audio:  # official checkpoint
+                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+                    audio_ckpt = audio_ckpt["model"]
+                    keys = list(audio_ckpt.keys())
+                    for key in keys:
+                        if (
+                            "spectrogram_extractor" not in key
+                            and "logmel_extractor" not in key
+                        ):
+                            v = audio_ckpt.pop(key)
+                            audio_ckpt["audio_branch." + key] = v
+                elif os.path.basename(pretrained_audio).startswith(
+                    "PANN"
+                ):  # checkpoint trained via HTSAT codebase
+                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+                    audio_ckpt = audio_ckpt["state_dict"]
+                    keys = list(audio_ckpt.keys())
+                    for key in keys:
+                        if key.startswith("sed_model"):
+                            v = audio_ckpt.pop(key)
+                            audio_ckpt["audio_branch." + key[10:]] = v
+                elif os.path.basename(pretrained_audio).startswith(
+                    "finetuned"
+                ):  # checkpoint trained via linear probe codebase
+                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+                else:
+                    raise ValueError("Unknown audio checkpoint")
+            elif amodel_name.startswith("HTSAT"):
+                if "HTSAT_AudioSet_Saved" in pretrained_audio:  # official checkpoint
+                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+                    audio_ckpt = audio_ckpt["state_dict"]
+                    keys = list(audio_ckpt.keys())
+                    for key in keys:
+                        if key.startswith("sed_model") and (
+                            "spectrogram_extractor" not in key
+                            and "logmel_extractor" not in key
+                        ):
+                            v = audio_ckpt.pop(key)
+                            audio_ckpt["audio_branch." + key[10:]] = v
+                elif os.path.basename(pretrained_audio).startswith(
+                    "HTSAT"
+                ):  # checkpoint trained via HTSAT codebase
+                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+                    audio_ckpt = audio_ckpt["state_dict"]
+                    keys = list(audio_ckpt.keys())
+                    for key in keys:
+                        if key.startswith("sed_model"):
+                            v = audio_ckpt.pop(key)
+                            audio_ckpt["audio_branch." + key[10:]] = v
+                elif os.path.basename(pretrained_audio).startswith(
+                    "finetuned"
+                ):  # checkpoint trained via linear probe codebase
+                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+                else:
+                    raise ValueError("Unknown audio checkpoint")
+            else:
+                raise f"this audio encoder pretrained checkpoint is not support"
+
+            model.load_state_dict(audio_ckpt, strict=False)
+            logging.info(
+                f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
+            )
+            param_names = [n for n, p in model.named_parameters()]
+            for n in param_names:
+                print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
+
+        model.to(device=device)
+        if precision == "fp16":
+            assert device.type != "cpu"
+            convert_weights_to_fp16(model)
+
+        if jit:
+            model = torch.jit.script(model)
+
+    return model, model_cfg
+
+
+def create_model_and_transforms(
+    model_name: str,
+    pretrained: str = "",
+    precision: str = "fp32",
+    device: torch.device = torch.device("cpu"),
+    jit: bool = False,
+    force_quick_gelu: bool = False,
+    # pretrained_image: bool = False,
+):
+    model, model_cfg = create_model(
+        model_name,
+        pretrained,
+        precision,
+        device,
+        jit,
+        force_quick_gelu=force_quick_gelu,
+        # pretrained_image=pretrained_image
+    )
+    preprocess_train = image_transform(model.visual.image_size, is_train=True)
+    preprocess_val = image_transform(model.visual.image_size, is_train=False)
+    return model, preprocess_train, preprocess_val
+
+
+def list_models():
+    """enumerate available model architectures based on config files"""
+    return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+    """add model config path or file and update registry"""
+    if not isinstance(path, Path):
+        path = Path(path)
+    _MODEL_CONFIG_PATHS.append(path)
+    _rescan_model_configs()
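+
+
+# Illustrative usage sketch (not called anywhere in this file); the config
+# name "HTSAT-tiny" and text model "roberta" are assumptions -- any JSON file
+# under model_configs/ with a matching text_cfg model_type would work:
+#
+#     model, model_cfg = create_model(
+#         "HTSAT-tiny", "roberta", precision="fp32",
+#         device=torch.device("cpu"), enable_fusion=False,
+#     )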
diff --git a/audioldm2/clap/open_clip/feature_fusion.py b/audioldm2/clap/open_clip/feature_fusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89
--- /dev/null
+++ b/audioldm2/clap/open_clip/feature_fusion.py
@@ -0,0 +1,192 @@
+"""
+Feature Fusion for Variable-Length Data Processing
+AFF/iAFF are adapted from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
+Reference: Yimian Dai et al., "Attentional Feature Fusion", IEEE Winter Conference on Applications of Computer Vision (WACV), 2021
+"""
+
+import torch
+import torch.nn as nn
+
+
+class DAF(nn.Module):
+    """
+    Direct addition (DirectAddFuse)
+    """
+
+    def __init__(self):
+        super(DAF, self).__init__()
+
+    def forward(self, x, residual):
+        return x + residual
+
+
+class iAFF(nn.Module):
+    """
+    Iterative attentional feature fusion (iAFF)
+    """
+
+    def __init__(self, channels=64, r=4, type="2D"):
+        super(iAFF, self).__init__()
+        inter_channels = int(channels // r)
+
+        if type == "1D":
+            # Local attention
+            self.local_att = nn.Sequential(
+                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(channels),
+            )
+
+            # Global attention
+            self.global_att = nn.Sequential(
+                nn.AdaptiveAvgPool1d(1),
+                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(channels),
+            )
+
+            # Second local attention
+            self.local_att2 = nn.Sequential(
+                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(channels),
+            )
+            # Second global attention
+            self.global_att2 = nn.Sequential(
+                nn.AdaptiveAvgPool1d(1),
+                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(channels),
+            )
+        elif type == "2D":
+            # Local attention
+            self.local_att = nn.Sequential(
+                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(channels),
+            )
+
+            # Global attention
+            self.global_att = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(channels),
+            )
+
+            # Second local attention
+            self.local_att2 = nn.Sequential(
+                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(channels),
+            )
+            # Second global attention
+            self.global_att2 = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(channels),
+            )
+        else:
+            raise f"the type is not supported"
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, residual):
+        flag = False
+        xa = x + residual
+        if xa.size(0) == 1:
+            xa = torch.cat([xa, xa], dim=0)
+            flag = True
+        xl = self.local_att(xa)
+        xg = self.global_att(xa)
+        xlg = xl + xg
+        wei = self.sigmoid(xlg)
+        xi = x * wei + residual * (1 - wei)
+
+        xl2 = self.local_att2(xi)
+        xg2 = self.global_att(xi)
+        xlg2 = xl2 + xg2
+        wei2 = self.sigmoid(xlg2)
+        xo = x * wei2 + residual * (1 - wei2)
+        if flag:
+            xo = xo[0].unsqueeze(0)
+        return xo
+
+
+class AFF(nn.Module):
+    """
+    Attentional feature fusion (AFF)
+    """
+
+    def __init__(self, channels=64, r=4, type="2D"):
+        super(AFF, self).__init__()
+        inter_channels = int(channels // r)
+
+        if type == "1D":
+            self.local_att = nn.Sequential(
+                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(channels),
+            )
+            self.global_att = nn.Sequential(
+                nn.AdaptiveAvgPool1d(1),
+                nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm1d(channels),
+            )
+        elif type == "2D":
+            self.local_att = nn.Sequential(
+                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(channels),
+            )
+            self.global_att = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(inter_channels),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+                nn.BatchNorm2d(channels),
+            )
+        else:
+            raise f"the type is not supported."
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, residual):
+        flag = False
+        xa = x + residual
+        if xa.size(0) == 1:
+            xa = torch.cat([xa, xa], dim=0)
+            flag = True
+        xl = self.local_att(xa)
+        xg = self.global_att(xa)
+        xlg = xl + xg
+        wei = self.sigmoid(xlg)
+        xo = 2 * x * wei + 2 * residual * (1 - wei)
+        if flag:
+            xo = xo[0].unsqueeze(0)
+        return xo
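+
+
+# Illustrative usage sketch (not exercised in this file): fuse two feature
+# maps of identical shape (B, C, H, W) with C equal to `channels`:
+#
+#     fusion = AFF(channels=64, type="2D")
+#     out = fusion(torch.randn(2, 64, 8, 8), torch.randn(2, 64, 8, 8))
+#     # out.shape == (2, 64, 8, 8)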
diff --git a/audioldm2/clap/open_clip/htsat.py b/audioldm2/clap/open_clip/htsat.py
new file mode 100755
index 0000000000000000000000000000000000000000..8bf4fceea2dfef953522c14a3a39a417658f2257
--- /dev/null
+++ b/audioldm2/clap/open_clip/htsat.py
@@ -0,0 +1,1304 @@
+# Ke Chen
+# knutchen@ucsd.edu
+# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+# Some layers are designed based on this model.
+# The code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
+# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
+
+import torch
+import torch.nn as nn
+from itertools import repeat
+import collections.abc
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+import torch.utils.checkpoint as checkpoint
+
+import random
+
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import do_mixup, interpolate
+
+from .feature_fusion import iAFF, AFF, DAF
+
+
+# from PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
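+
+
+# For example, drop_path(x, drop_prob=0.1, training=True) zeroes the residual
+# branch for roughly 10% of the samples in the batch and rescales the
+# surviving ones by 1 / 0.9, so the expected value is unchanged.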
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class PatchEmbed(nn.Module):
+    """2D Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        norm_layer=None,
+        flatten=True,
+        patch_stride=16,
+        enable_fusion=False,
+        fusion_type="None",
+    ):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patch_stride = to_2tuple(patch_stride)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.grid_size = (
+            img_size[0] // patch_stride[0],
+            img_size[1] // patch_stride[1],
+        )
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.flatten = flatten
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+
+        padding = (
+            (patch_size[0] - patch_stride[0]) // 2,
+            (patch_size[1] - patch_stride[1]) // 2,
+        )
+
+        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
+            self.proj = nn.Conv2d(
+                in_chans * 4,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_stride,
+                padding=padding,
+            )
+        else:
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_stride,
+                padding=padding,
+            )
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+        if (self.enable_fusion) and (
+            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+        ):
+            self.mel_conv2d = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=(patch_size[0], patch_size[1] * 3),
+                stride=(patch_stride[0], patch_stride[1] * 3),
+                padding=padding,
+            )
+            if self.fusion_type == "daf_2d":
+                self.fusion_model = DAF()
+            elif self.fusion_type == "aff_2d":
+                self.fusion_model = AFF(channels=embed_dim, type="2D")
+            elif self.fusion_type == "iaff_2d":
+                self.fusion_model = iAFF(channels=embed_dim, type="2D")
+
+    def forward(self, x, longer_idx=None):
+        if (self.enable_fusion) and (
+            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+        ):
+            global_x = x[:, 0:1, :, :]
+
+            # global processing
+            B, C, H, W = global_x.shape
+            assert (
+                H == self.img_size[0] and W == self.img_size[1]
+            ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+            global_x = self.proj(global_x)
+            TW = global_x.size(-1)
+            if len(longer_idx) > 0:
+                # local processing
+                local_x = x[longer_idx, 1:, :, :].contiguous()
+                B, C, H, W = local_x.shape
+                local_x = local_x.view(B * C, 1, H, W)
+                local_x = self.mel_conv2d(local_x)
+                local_x = local_x.view(
+                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
+                )
+                local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
+                TB, TC, TH, _ = local_x.size()
+                if local_x.size(-1) < TW:
+                    local_x = torch.cat(
+                        [
+                            local_x,
+                            torch.zeros(
+                                (TB, TC, TH, TW - local_x.size(-1)),
+                                device=global_x.device,
+                            ),
+                        ],
+                        dim=-1,
+                    )
+                else:
+                    local_x = local_x[:, :, :, :TW]
+
+                global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
+            x = global_x
+        else:
+            B, C, H, W = x.shape
+            assert (
+                H == self.img_size[0] and W == self.img_size[1]
+            ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+            x = self.proj(x)
+
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        x = self.norm(x)
+        return x
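+
+    # Shape example (non-fusion path): an input of shape
+    # (B, in_chans, img_size[0], img_size[1]) is projected and, with
+    # flatten=True, reshaped to (B, grid_size[0] * grid_size[1], embed_dim).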
+
+
+class Mlp(nn.Module):
+    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+            "The distribution of values may be incorrect.",
+            stacklevel=2,
+        )
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.0))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == "fan_in":
+        denom = fan_in
+    elif mode == "fan_out":
+        denom = fan_out
+    elif mode == "fan_avg":
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(
+        B, H // window_size, W // window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
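+
+
+# Shape example: for x of shape (B, 32, 32, C) and window_size=8,
+# window_partition returns (B * 16, 8, 8, C); window_reverse with H = W = 32
+# recovers the original (B, 32, 32, C) layout.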
+
+
+class WindowAttention(nn.Module):
+    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        dim,
+        window_size,
+        num_heads,
+        qkv_bias=True,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+        )  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = (
+            coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        )  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(
+            1, 2, 0
+        ).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(
+            self.window_size[0] * self.window_size[1],
+            self.window_size[0] * self.window_size[1],
+            -1,
+        )  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(
+            2, 0, 1
+        ).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+                1
+            ).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn
+
+    def extra_repr(self):
+        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
+
+
+# The model is built on the Swin Transformer block, so Swin Transformer pretrained weights can be reused
+class SwinTransformerBlock(nn.Module):
+    r"""Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        num_heads,
+        window_size=7,
+        shift_size=0,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        norm_before_mlp="ln",
+    ):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        self.norm_before_mlp = norm_before_mlp
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert (
+            0 <= self.shift_size < self.window_size
+        ), "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        if self.norm_before_mlp == "ln":
+            self.norm2 = nn.LayerNorm(dim)
+        elif self.norm_before_mlp == "bn":
+            self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
+                1, 2
+            )
+        else:
+            raise NotImplementedError
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            h_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            w_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(
+                img_mask, self.window_size
+            )  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(
+                attn_mask != 0, float(-100.0)
+            ).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
+    def forward(self, x):
+        # pdb.set_trace()
+        H, W = self.input_resolution
+        # print("H: ", H)
+        # print("W: ", W)
+        # pdb.set_trace()
+        B, L, C = x.shape
+        # assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(
+                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+            )
+        else:
+            shifted_x = x
+
+        # partition windows
+        x_windows = window_partition(
+            shifted_x, self.window_size
+        )  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(
+            -1, self.window_size * self.window_size, C
+        )  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows, attn = self.attn(
+            x_windows, mask=self.attn_mask
+        )  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(
+                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
+            )
+        else:
+            x = shifted_x
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x, attn
+
+    def extra_repr(self):
+        return (
+            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+        )
+
+
+class PatchMerging(nn.Module):
+    r"""Patch Merging Layer.
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+        x = x.view(B, H, W, C)
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+    def extra_repr(self):
+        return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+
+class BasicLayer(nn.Module):
+    """A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        depth,
+        num_heads,
+        window_size,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        norm_layer=nn.LayerNorm,
+        downsample=None,
+        use_checkpoint=False,
+        norm_before_mlp="ln",
+    ):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList(
+            [
+                SwinTransformerBlock(
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    shift_size=0 if (i % 2 == 0) else window_size // 2,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop,
+                    attn_drop=attn_drop,
+                    drop_path=drop_path[i]
+                    if isinstance(drop_path, list)
+                    else drop_path,
+                    norm_layer=norm_layer,
+                    norm_before_mlp=norm_before_mlp,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(
+                input_resolution, dim=dim, norm_layer=norm_layer
+            )
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        attns = []
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x, attn = checkpoint.checkpoint(blk, x)
+            else:
+                x, attn = blk(x)
+                if not self.training:
+                    attns.append(attn.unsqueeze(0))
+        if self.downsample is not None:
+            x = self.downsample(x)
+        if not self.training:
+            attn = torch.cat(attns, dim=0)
+            attn = torch.mean(attn, dim=0)
+        return x, attn
+
+    def extra_repr(self):
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+# The Core of HTSAT
+class HTSAT_Swin_Transformer(nn.Module):
+    r"""HTSAT based on the Swin Transformer
+    Args:
+        spec_size (int | tuple(int)): Input Spectrogram size. Default 256
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        patch_stride (int | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
+        in_chans (int): Number of input image channels. Default: 1 (mono)
+        num_classes (int): Number of classes for classification head. Default: 527
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 8
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        config (module): The configuration Module from config.py
+    """
+
+    def __init__(
+        self,
+        spec_size=256,
+        patch_size=4,
+        patch_stride=(4, 4),
+        in_chans=1,
+        num_classes=527,
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[4, 8, 16, 32],
+        window_size=8,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.1,
+        norm_layer=nn.LayerNorm,
+        ape=False,
+        patch_norm=True,
+        use_checkpoint=False,
+        norm_before_mlp="ln",
+        config=None,
+        enable_fusion=False,
+        fusion_type="None",
+        **kwargs,
+    ):
+        super(HTSAT_Swin_Transformer, self).__init__()
+
+        self.config = config
+        self.spec_size = spec_size
+        self.patch_stride = patch_stride
+        self.patch_size = patch_size
+        self.window_size = window_size
+        self.embed_dim = embed_dim
+        self.depths = depths
+        self.ape = ape
+        self.in_chans = in_chans
+        self.num_classes = num_classes
+        self.num_heads = num_heads
+        self.num_layers = len(self.depths)
+        self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
+
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+
+        self.qkv_bias = qkv_bias
+        self.qk_scale = None
+
+        self.patch_norm = patch_norm
+        self.norm_layer = norm_layer if self.patch_norm else None
+        self.norm_before_mlp = norm_before_mlp
+        self.mlp_ratio = mlp_ratio
+
+        self.use_checkpoint = use_checkpoint
+
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+
+        # process mel-spec; used only once
+        self.freq_ratio = self.spec_size // self.config.mel_bins
+        window = "hann"
+        center = True
+        pad_mode = "reflect"
+        ref = 1.0
+        amin = 1e-10
+        top_db = None
+        self.interpolate_ratio = 32  # Downsampled ratio
+        # Spectrogram extractor
+        self.spectrogram_extractor = Spectrogram(
+            n_fft=config.window_size,
+            hop_length=config.hop_size,
+            win_length=config.window_size,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            freeze_parameters=True,
+        )
+        # Logmel feature extractor
+        self.logmel_extractor = LogmelFilterBank(
+            sr=config.sample_rate,
+            n_fft=config.window_size,
+            n_mels=config.mel_bins,
+            fmin=config.fmin,
+            fmax=config.fmax,
+            ref=ref,
+            amin=amin,
+            top_db=top_db,
+            freeze_parameters=True,
+        )
+        # Spec augmenter
+        self.spec_augmenter = SpecAugmentation(
+            time_drop_width=64,
+            time_stripes_num=2,
+            freq_drop_width=8,
+            freq_stripes_num=2,
+        )  # 2 2
+        self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
+
+        # split the spectrogram into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=self.spec_size,
+            patch_size=self.patch_size,
+            in_chans=self.in_chans,
+            embed_dim=self.embed_dim,
+            norm_layer=self.norm_layer,
+            patch_stride=patch_stride,
+            enable_fusion=self.enable_fusion,
+            fusion_type=self.fusion_type,
+        )
+
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.grid_size
+        self.patches_resolution = patches_resolution
+
+        # absolute position embedding
+        if self.ape:
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, num_patches, self.embed_dim)
+            )
+            trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+        self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+        # stochastic depth
+        dpr = [
+            x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
+        ]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(self.embed_dim * 2**i_layer),
+                input_resolution=(
+                    patches_resolution[0] // (2**i_layer),
+                    patches_resolution[1] // (2**i_layer),
+                ),
+                depth=self.depths[i_layer],
+                num_heads=self.num_heads[i_layer],
+                window_size=self.window_size,
+                mlp_ratio=self.mlp_ratio,
+                qkv_bias=self.qkv_bias,
+                qk_scale=self.qk_scale,
+                drop=self.drop_rate,
+                attn_drop=self.attn_drop_rate,
+                drop_path=dpr[
+                    sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
+                ],
+                norm_layer=self.norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint,
+                norm_before_mlp=self.norm_before_mlp,
+            )
+            self.layers.append(layer)
+
+        self.norm = self.norm_layer(self.num_features)
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        self.maxpool = nn.AdaptiveMaxPool1d(1)
+
+        SF = (
+            self.spec_size
+            // (2 ** (len(self.depths) - 1))
+            // self.patch_stride[0]
+            // self.freq_ratio
+        )
+        self.tscam_conv = nn.Conv2d(
+            in_channels=self.num_features,
+            out_channels=self.num_classes,
+            kernel_size=(SF, 3),
+            padding=(0, 1),
+        )
+        self.head = nn.Linear(num_classes, num_classes)
+
+        if (self.enable_fusion) and (
+            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
+        ):
+            self.mel_conv1d = nn.Sequential(
+                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+                nn.BatchNorm1d(64),
+            )
+            if self.fusion_type == "daf_1d":
+                self.fusion_model = DAF()
+            elif self.fusion_type == "aff_1d":
+                self.fusion_model = AFF(channels=64, type="1D")
+            elif self.fusion_type == "iaff_1d":
+                self.fusion_model = iAFF(channels=64, type="1D")
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"absolute_pos_embed"}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {"relative_position_bias_table"}
+
+    def forward_features(self, x, longer_idx=None):
+        # A deprecated optimization for using a hierarchical output from different blocks
+
+        frames_num = x.shape[2]
+        x = self.patch_embed(x, longer_idx=longer_idx)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+        for i, layer in enumerate(self.layers):
+            x, attn = layer(x)
+        # for x
+        x = self.norm(x)
+        B, N, C = x.shape
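+        # recover the 2D spatial grid (SF x ST) of the final-stage tokens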
+        SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
+        ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
+        x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
+        B, C, F, T = x.shape
+        # group 2D CNN
+        c_freq_bin = F // self.freq_ratio
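+        # fold the freq_ratio chunks (stacked onto the frequency axis by
+        # reshape_wav2img) back onto the time axis, keeping c_freq_bin frequency bins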
+        x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
+        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
+        # get latent_output
+        fine_grained_latent_output = torch.mean(x, dim=2)
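+        # upsample along time (ratio 8 * patch_stride[1]) to approximate framewise resolution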
+        fine_grained_latent_output = interpolate(
+            fine_grained_latent_output.permute(0, 2, 1).contiguous(),
+            8 * self.patch_stride[1],
+        )
+
+        latent_output = self.avgpool(torch.flatten(x, 2))
+        latent_output = torch.flatten(latent_output, 1)
+
+        # display the attention map, if needed
+
+        x = self.tscam_conv(x)
+        x = torch.flatten(x, 2)  # B, C, T
+
+        fpx = interpolate(
+            torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
+        )
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+
+        output_dict = {
+            "framewise_output": fpx,  # already sigmoided
+            "clipwise_output": torch.sigmoid(x),
+            "fine_grained_embedding": fine_grained_latent_output,
+            "embedding": latent_output,
+        }
+
+        return output_dict
+
+    def crop_wav(self, x, crop_size, spe_pos=None):
+        time_steps = x.shape[2]
+        tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
+        for i in range(len(x)):
+            if spe_pos is None:
+                crop_pos = random.randint(0, time_steps - crop_size - 1)
+            else:
+                crop_pos = spe_pos
+            tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
+        return tx
+
+    # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
+    def reshape_wav2img(self, x):
+        B, C, T, F = x.shape
+        target_T = int(self.spec_size * self.freq_ratio)
+        target_F = self.spec_size // self.freq_ratio
+        assert (
+            T <= target_T and F <= target_F
+        ), "the wav size should less than or equal to the swin input size"
+        # to avoid bicubic zero error
+        if T < target_T:
+            x = nn.functional.interpolate(
+                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
+            )
+        if F < target_F:
+            x = nn.functional.interpolate(
+                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
+            )
+        x = x.permute(0, 1, 3, 2).contiguous()
+        x = x.reshape(
+            x.shape[0],
+            x.shape[1],
+            x.shape[2],
+            self.freq_ratio,
+            x.shape[3] // self.freq_ratio,
+        )
+        # print(x.shape)
+        x = x.permute(0, 1, 3, 2, 4).contiguous()
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
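+        # e.g. with spec_size=256 and freq_ratio=4, a (B, 1, 1024, 64) mel spectrogram
+        # is folded into a square (B, 1, 256, 256) image for the Swin backbone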
+        return x
+
+    # Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
+    def repeat_wat2img(self, x, cur_pos):
+        B, C, T, F = x.shape
+        target_T = int(self.spec_size * self.freq_ratio)
+        target_F = self.spec_size // self.freq_ratio
+        assert (
+            T <= target_T and F <= target_F
+        ), "the wav size should less than or equal to the swin input size"
+        # to avoid bicubic zero error
+        if T < target_T:
+            x = nn.functional.interpolate(
+                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
+            )
+        if F < target_F:
+            x = nn.functional.interpolate(
+                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
+            )
+        x = x.permute(0, 1, 3, 2).contiguous()  # B C F T
+        x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
+        x = x.repeat(repeats=(1, 1, 4, 1))
+        return x
+
+    def forward(
+        self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
+    ):  # out_feat_keys: List[str] = None):
+        if self.enable_fusion and x["longer"].sum() == 0:
+            # if no audio is longer than 10s, then randomly select one audio to be longer
+            x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
+
+        if not self.enable_fusion:
+            x = x["waveform"].to(device=device, non_blocking=True)
+            x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)
+            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
+            x = x.transpose(1, 3)
+            x = self.bn0(x)
+            x = x.transpose(1, 3)
+            if self.training:
+                x = self.spec_augmenter(x)
+
+            if self.training and mixup_lambda is not None:
+                x = do_mixup(x, mixup_lambda)
+
+            x = self.reshape_wav2img(x)
+            output_dict = self.forward_features(x)
+        else:
+            longer_list = x["longer"].to(device=device, non_blocking=True)
+            x = x["mel_fusion"].to(device=device, non_blocking=True)
+            x = x.transpose(1, 3)
+            x = self.bn0(x)
+            x = x.transpose(1, 3)
+            longer_list_idx = torch.where(longer_list)[0]
+            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
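+                # channel 0 of mel_fusion holds the global mel; the remaining channels
+                # hold local mel chunks for audios flagged as "longer", which are
+                # compressed by a 1D conv and fused back into the global mel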
+                new_x = x[:, 0:1, :, :].clone().contiguous()
+                if len(longer_list_idx) > 0:
+                    # local processing
+                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
+                    FB, FC, FT, FF = fusion_x_local.size()
+                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+                    fusion_x_local = torch.permute(
+                        fusion_x_local, (0, 2, 1)
+                    ).contiguous()
+                    fusion_x_local = self.mel_conv1d(fusion_x_local)
+                    fusion_x_local = fusion_x_local.view(
+                        FB, FC, FF, fusion_x_local.size(-1)
+                    )
+                    fusion_x_local = (
+                        torch.permute(fusion_x_local, (0, 2, 1, 3))
+                        .contiguous()
+                        .flatten(2)
+                    )
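+                    # pad with zeros (or truncate) so the fused local features match
+                    # the global mel's FT time frames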
+                    if fusion_x_local.size(-1) < FT:
+                        fusion_x_local = torch.cat(
+                            [
+                                fusion_x_local,
+                                torch.zeros(
+                                    (FB, FF, FT - fusion_x_local.size(-1)),
+                                    device=device,
+                                ),
+                            ],
+                            dim=-1,
+                        )
+                    else:
+                        fusion_x_local = fusion_x_local[:, :, :FT]
+                    # 1D fusion
+                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
+                    new_x[longer_list_idx] = self.fusion_model(
+                        new_x[longer_list_idx], fusion_x_local
+                    )
+                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
+                else:
+                    x = new_x
+
+            elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
+                x = x  # no change
+
+            if self.training:
+                x = self.spec_augmenter(x)
+            if self.training and mixup_lambda is not None:
+                x = do_mixup(x, mixup_lambda)
+
+            x = self.reshape_wav2img(x)
+            output_dict = self.forward_features(x, longer_idx=longer_list_idx)
+
+        # if infer_mode:
+        #     # in infer mode. we need to handle different length audio input
+        #     frame_num = x.shape[2]
+        #     target_T = int(self.spec_size * self.freq_ratio)
+        #     repeat_ratio = math.floor(target_T / frame_num)
+        #     x = x.repeat(repeats=(1,1,repeat_ratio,1))
+        #     x = self.reshape_wav2img(x)
+        #     output_dict = self.forward_features(x)
+        # else:
+        #     if x.shape[2] > self.freq_ratio * self.spec_size:
+        #         if self.training:
+        #             x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
+        #             x = self.reshape_wav2img(x)
+        #             output_dict = self.forward_features(x)
+        #         else:
+        #             # Change: Hard code here
+        #             overlap_size = (x.shape[2] - 1) // 4
+        #             output_dicts = []
+        #             crop_size = (x.shape[2] - 1) // 2
+        #             for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
+        #                 tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
+        #                 tx = self.reshape_wav2img(tx)
+        #                 output_dicts.append(self.forward_features(tx))
+        #             clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
+        #             framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
+        #             for d in output_dicts:
+        #                 clipwise_output += d["clipwise_output"]
+        #                 framewise_output += d["framewise_output"]
+        #             clipwise_output  = clipwise_output / len(output_dicts)
+        #             framewise_output = framewise_output / len(output_dicts)
+        #             output_dict = {
+        #                 'framewise_output': framewise_output,
+        #                 'clipwise_output': clipwise_output
+        #             }
+        #     else: # this part is typically used, and most easy one
+        #         x = self.reshape_wav2img(x)
+        #         output_dict = self.forward_features(x)
+        # x = self.head(x)
+
+        # The data is already processed in the dataloader, so here we only consider the case input_T < fixed_T
+
+        return output_dict
+
+
+def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
+    try:
+        assert audio_cfg.model_name in [
+            "tiny",
+            "base",
+            "large",
+        ], "model name for HTS-AT is wrong!"
+        if audio_cfg.model_name == "tiny":
+            model = HTSAT_Swin_Transformer(
+                spec_size=256,
+                patch_size=4,
+                patch_stride=(4, 4),
+                num_classes=audio_cfg.class_num,
+                embed_dim=96,
+                depths=[2, 2, 6, 2],
+                num_heads=[4, 8, 16, 32],
+                window_size=8,
+                config=audio_cfg,
+                enable_fusion=enable_fusion,
+                fusion_type=fusion_type,
+            )
+        elif audio_cfg.model_name == "base":
+            model = HTSAT_Swin_Transformer(
+                spec_size=256,
+                patch_size=4,
+                patch_stride=(4, 4),
+                num_classes=audio_cfg.class_num,
+                embed_dim=128,
+                depths=[2, 2, 12, 2],
+                num_heads=[4, 8, 16, 32],
+                window_size=8,
+                config=audio_cfg,
+                enable_fusion=enable_fusion,
+                fusion_type=fusion_type,
+            )
+        elif audio_cfg.model_name == "large":
+            model = HTSAT_Swin_Transformer(
+                spec_size=256,
+                patch_size=4,
+                patch_stride=(4, 4),
+                num_classes=audio_cfg.class_num,
+                embed_dim=256,
+                depths=[2, 2, 12, 2],
+                num_heads=[4, 8, 16, 32],
+                window_size=8,
+                config=audio_cfg,
+                enable_fusion=enable_fusion,
+                fusion_type=fusion_type,
+            )
+
+        return model
+    except Exception as e:
+        raise RuntimeError(
+            f"Failed to create the HTS-AT model '{audio_cfg.model_name}': check the model name and the audio config parameters."
+        ) from e
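+
+
+# Illustrative usage sketch (an assumption for documentation): `audio_cfg` is a config
+# object such as CLAPAudioCfp in model.py, with model_name set to "tiny" / "base" /
+# "large" and the usual audio parameters (class_num, sample_rate, mel_bins, clip_samples, ...).
+#
+#   model = create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None")
+#   model.eval()
+#   with torch.no_grad():
+#       out = model({"waveform": torch.randn(2, audio_cfg.clip_samples)}, device="cpu")
+#   out["embedding"]        # (batch, num_features) clip-level embedding
+#   out["clipwise_output"]  # (batch, num_classes) sigmoid class probabilities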
diff --git a/audioldm2/clap/open_clip/loss.py b/audioldm2/clap/open_clip/loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..37faba58f3693d0659512ab1d6e19614fbda0675
--- /dev/null
+++ b/audioldm2/clap/open_clip/loss.py
@@ -0,0 +1,397 @@
+import torch
+import torch.distributed.nn
+from torch import distributed as dist, nn as nn
+from torch.nn import functional as F
+import numpy as np
+from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
+
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
+
+def gather_features(
+    audio_features,
+    text_features,
+    audio_features_mlp=None,
+    text_features_mlp=None,
+    local_loss=False,
+    gather_with_grad=False,
+    rank=0,
+    world_size=1,
+    use_horovod=False,
+    mlp_loss=False,
+):
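+    # All-gather audio/text embeddings from every rank so each process can build the
+    # full similarity matrix; gather_with_grad keeps gradients through the gather,
+    # and when local_loss is False the local shard is re-inserted so it keeps its grad.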
+    if use_horovod:
+        assert hvd is not None, "Please install horovod"
+        if gather_with_grad:
+            all_audio_features = hvd.allgather(audio_features)
+            all_text_features = hvd.allgather(text_features)
+            if mlp_loss:
+                all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+                all_text_features_mlp = hvd.allgather(text_features_mlp)
+        else:
+            with torch.no_grad():
+                all_audio_features = hvd.allgather(audio_features)
+                all_text_features = hvd.allgather(text_features)
+                if mlp_loss:
+                    all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+                    all_text_features_mlp = hvd.allgather(text_features_mlp)
+            if not local_loss:
+                # ensure grads for local rank when all_* features don't have a gradient
+                gathered_audio_features = list(
+                    all_audio_features.chunk(world_size, dim=0)
+                )
+                gathered_text_features = list(
+                    all_text_features.chunk(world_size, dim=0)
+                )
+                gathered_audio_features[rank] = audio_features
+                gathered_text_features[rank] = text_features
+                all_audio_features = torch.cat(gathered_audio_features, dim=0)
+                all_text_features = torch.cat(gathered_text_features, dim=0)
+                if mlp_loss:
+                    gathered_audio_features_mlp = list(
+                        all_audio_features_mlp.chunk(world_size, dim=0)
+                    )
+                    gathered_text_features_mlp = list(
+                        all_text_features_mlp.chunk(world_size, dim=0)
+                    )
+                    gathered_audio_features_mlp[rank] = audio_features_mlp
+                    gathered_text_features_mlp[rank] = text_features_mlp
+                    all_audio_features_mlp = torch.cat(
+                        gathered_audio_features_mlp, dim=0
+                    )
+                    all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+    else:
+        # We gather tensors from all gpus
+        if gather_with_grad:
+            all_audio_features = torch.cat(
+                torch.distributed.nn.all_gather(audio_features), dim=0
+            )
+            all_text_features = torch.cat(
+                torch.distributed.nn.all_gather(text_features), dim=0
+            )
+            if mlp_loss:
+                all_audio_features_mlp = torch.cat(
+                    torch.distributed.nn.all_gather(audio_features_mlp), dim=0
+                )
+                all_text_features_mlp = torch.cat(
+                    torch.distributed.nn.all_gather(text_features_mlp), dim=0
+                )
+        else:
+            gathered_audio_features = [
+                torch.zeros_like(audio_features) for _ in range(world_size)
+            ]
+            gathered_text_features = [
+                torch.zeros_like(text_features) for _ in range(world_size)
+            ]
+            dist.all_gather(gathered_audio_features, audio_features)
+            dist.all_gather(gathered_text_features, text_features)
+            if mlp_loss:
+                gathered_audio_features_mlp = [
+                    torch.zeros_like(audio_features_mlp) for _ in range(world_size)
+                ]
+                gathered_text_features_mlp = [
+                    torch.zeros_like(text_features_mlp) for _ in range(world_size)
+                ]
+                dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
+                dist.all_gather(gathered_text_features_mlp, text_features_mlp)
+            if not local_loss:
+                # ensure grads for local rank when all_* features don't have a gradient
+                gathered_audio_features[rank] = audio_features
+                gathered_text_features[rank] = text_features
+                if mlp_loss:
+                    gathered_audio_features_mlp[rank] = audio_features_mlp
+                    gathered_text_features_mlp[rank] = text_features_mlp
+
+            all_audio_features = torch.cat(gathered_audio_features, dim=0)
+            all_text_features = torch.cat(gathered_text_features, dim=0)
+            if mlp_loss:
+                all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
+                all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+    if mlp_loss:
+        return (
+            all_audio_features,
+            all_text_features,
+            all_audio_features_mlp,
+            all_text_features_mlp,
+        )
+    else:
+        return all_audio_features, all_text_features
+
+
+class ClipLoss(nn.Module):
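+    # Symmetric contrastive (InfoNCE-style) loss between audio and text embeddings.
+    # With mlp_loss=True the logits are computed between raw and MLP-projected features
+    # of the two modalities; weight_loss_kappa != 0 enables similarity-based weighting.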
+    def __init__(
+        self,
+        local_loss=False,
+        gather_with_grad=False,
+        cache_labels=False,
+        rank=0,
+        world_size=1,
+        use_horovod=False,
+        mlp_loss=False,
+        weight_loss_kappa=0,
+    ):
+        super().__init__()
+        self.local_loss = local_loss
+        self.gather_with_grad = gather_with_grad
+        self.cache_labels = cache_labels
+        self.rank = rank
+        self.world_size = world_size
+        self.use_horovod = use_horovod
+        self.mlp_loss = mlp_loss
+        self.weighted_loss = bool(weight_loss_kappa != 0)
+        self.weight_loss_kappa = weight_loss_kappa
+        # cache state
+        self.prev_num_logits = 0
+        self.labels = {}
+
+    def forward(
+        self,
+        audio_features,
+        text_features,
+        logit_scale_a,
+        logit_scale_t=None,
+        audio_features_mlp=None,
+        text_features_mlp=None,
+    ):
+        device = audio_features.device
+        if self.mlp_loss:
+            if self.world_size > 1:
+                (
+                    all_audio_features,
+                    all_text_features,
+                    all_audio_features_mlp,
+                    all_text_features_mlp,
+                ) = gather_features(
+                    audio_features=audio_features,
+                    text_features=text_features,
+                    audio_features_mlp=audio_features_mlp,
+                    text_features_mlp=text_features_mlp,
+                    local_loss=self.local_loss,
+                    gather_with_grad=self.gather_with_grad,
+                    rank=self.rank,
+                    world_size=self.world_size,
+                    use_horovod=self.use_horovod,
+                    mlp_loss=self.mlp_loss,
+                )
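+                # "a" logits score audio embeddings against MLP-projected text features;
+                # "t" logits score MLP-projected audio features against text embeddings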
+                if self.local_loss:
+                    a_logits_per_audio = (
+                        logit_scale_a * audio_features @ all_text_features_mlp.T
+                    )
+                    a_logits_per_text = (
+                        logit_scale_a * text_features_mlp @ all_audio_features.T
+                    )
+                    t_logits_per_audio = (
+                        logit_scale_t * audio_features_mlp @ all_text_features.T
+                    )
+                    t_logits_per_text = (
+                        logit_scale_t * text_features @ all_audio_features_mlp.T
+                    )
+                else:
+                    a_logits_per_audio = (
+                        logit_scale_a * all_audio_features @ all_text_features_mlp.T
+                    )
+                    a_logits_per_text = a_logits_per_audio.T
+                    t_logits_per_audio = (
+                        logit_scale_t * all_audio_features_mlp @ all_text_features.T
+                    )
+                    t_logits_per_text = t_logits_per_audio.T
+            else:
+                a_logits_per_audio = (
+                    logit_scale_a * audio_features @ text_features_mlp.T
+                )
+                a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
+                t_logits_per_audio = (
+                    logit_scale_t * audio_features_mlp @ text_features.T
+                )
+                t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
+
+            # calculated ground-truth and cache if enabled
+            num_logits = a_logits_per_audio.shape[0]
+            if self.prev_num_logits != num_logits or device not in self.labels:
+                labels = torch.arange(num_logits, device=device, dtype=torch.long)
+                if self.world_size > 1 and self.local_loss:
+                    labels = labels + num_logits * self.rank
+                if self.cache_labels:
+                    self.labels[device] = labels
+                    self.prev_num_logits = num_logits
+            else:
+                labels = self.labels[device]
+
+            if not self.weighted_loss:
+                total_loss = (
+                    F.cross_entropy(a_logits_per_audio, labels)
+                    + F.cross_entropy(a_logits_per_text, labels)
+                    + F.cross_entropy(t_logits_per_audio, labels)
+                    + F.cross_entropy(t_logits_per_text, labels)
+                ) / 4
+            else:
+                audio_weight = (audio_features @ audio_features.T).detach()
+                audio_weight = (
+                    torch.exp(
+                        torch.sum(audio_weight, axis=1)
+                        / (self.weight_loss_kappa * len(audio_weight))
+                    )
+                ).detach()
+                text_weight = (text_features @ text_features.T).detach()
+                text_weight = (
+                    torch.exp(
+                        torch.sum(text_weight, axis=1)
+                        / (self.weight_loss_kappa * len(text_features))
+                    )
+                ).detach()
+                total_loss = (
+                    F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
+                    + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
+                    + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
+                    + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
+                ) / 4
+        else:
+            if self.world_size > 1:
+                all_audio_features, all_text_features = gather_features(
+                    audio_features=audio_features,
+                    text_features=text_features,
+                    local_loss=self.local_loss,
+                    gather_with_grad=self.gather_with_grad,
+                    rank=self.rank,
+                    world_size=self.world_size,
+                    use_horovod=self.use_horovod,
+                    mlp_loss=self.mlp_loss,
+                )
+
+                if self.local_loss:
+                    logits_per_audio = (
+                        logit_scale_a * audio_features @ all_text_features.T
+                    )
+                    logits_per_text = (
+                        logit_scale_a * text_features @ all_audio_features.T
+                    )
+                else:
+                    logits_per_audio = (
+                        logit_scale_a * all_audio_features @ all_text_features.T
+                    )
+                    logits_per_text = logits_per_audio.T
+            else:
+                logits_per_audio = logit_scale_a * audio_features @ text_features.T
+                logits_per_text = logit_scale_a * text_features @ audio_features.T
+
+            # calculated ground-truth and cache if enabled
+            num_logits = logits_per_audio.shape[0]
+            if self.prev_num_logits != num_logits or device not in self.labels:
+                labels = torch.arange(num_logits, device=device, dtype=torch.long)
+                if self.world_size > 1 and self.local_loss:
+                    labels = labels + num_logits * self.rank
+                if self.cache_labels:
+                    self.labels[device] = labels
+                    self.prev_num_logits = num_logits
+            else:
+                labels = self.labels[device]
+            if not self.weighted_loss:
+                total_loss = (
+                    F.cross_entropy(logits_per_audio, labels)
+                    + F.cross_entropy(logits_per_text, labels)
+                ) / 2
+            else:
+                audio_weight = (all_audio_features @ all_audio_features.T).detach()
+                audio_weight = (
+                    torch.exp(
+                        torch.sum(audio_weight, axis=1)
+                        / (self.weight_loss_kappa * len(all_audio_features))
+                    )
+                ).detach()
+                text_weight = (all_text_features @ all_text_features.T).detach()
+                text_weight = (
+                    torch.exp(
+                        torch.sum(text_weight, axis=1)
+                        / (self.weight_loss_kappa * len(all_text_features))
+                    )
+                ).detach()
+                total_loss = (
+                    F.cross_entropy(logits_per_audio, labels, weight=text_weight)
+                    + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
+                ) / 2
+        return total_loss
+
+
+def lp_gather_features(pred, target, world_size=1, use_horovod=False):
+    if use_horovod:
+        assert hvd is not None, "Please install horovod"
+        with torch.no_grad():
+            all_preds = hvd.allgather(pred)
+            all_targets = hvd.allgather(target)
+    else:
+        gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
+        gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
+
+        dist.all_gather(gathered_preds, pred)
+        dist.all_gather(gathered_targets, target)
+        all_preds = torch.cat(gathered_preds, dim=0)
+        all_targets = torch.cat(gathered_targets, dim=0)
+
+    return all_preds, all_targets
+
+
+def get_map(pred, target):
+    pred = torch.sigmoid(pred).numpy()
+    target = target.numpy()
+    return np.mean(average_precision_score(target, pred, average=None))
+
+
+def get_acc(pred, target):
+    pred = torch.argmax(pred, 1).numpy()
+    target = torch.argmax(target, 1).numpy()
+    return accuracy_score(target, pred)
+
+
+def get_mauc(pred, target):
+    pred = torch.sigmoid(pred).numpy()
+    target = target.numpy()
+    return np.mean(roc_auc_score(target, pred, average=None))
+
+
+class LPMetrics(object):
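+    # Collects the metric functions used for linear-probe (LP) evaluation
+    # (mAP, accuracy, mean AUC) and evaluates them on (pred, target) pairs.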
+    def __init__(self, metric_names=["map", "acc", "mauc"]):
+        self.metrics = []
+        for name in metric_names:
+            self.metrics.append(self.get_metric(name))
+        self.metric_names = metric_names
+
+    def get_metric(self, name):
+        if name == "map":
+            return get_map
+        elif name == "acc":
+            return get_acc
+        elif name == "mauc":
+            return get_mauc
+        else:
+            raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
+
+    def evaluate_mertics(self, pred, target):
+        metric_dict = {}
+        for i in range(len(self.metric_names)):
+            metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
+        return metric_dict
+
+
+def calc_celoss(pred, target):
+    target = torch.argmax(target, 1).long()
+    return nn.CrossEntropyLoss()(pred, target)
+
+
+class LPLoss(nn.Module):
+    def __init__(self, loss_name):
+        super().__init__()
+        if loss_name == "bce":
+            self.loss_func = nn.BCEWithLogitsLoss()
+        elif loss_name == "ce":
+            self.loss_func = calc_celoss
+        elif loss_name == "mse":
+            self.loss_func = nn.MSELoss()
+        else:
+            raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
+
+    def forward(self, pred, target):
+        loss = self.loss_func(pred, target)
+        return loss
diff --git a/audioldm2/clap/open_clip/model.py b/audioldm2/clap/open_clip/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..130fb582d016868d478e2d10e90d7fc0e7999078
--- /dev/null
+++ b/audioldm2/clap/open_clip/model.py
@@ -0,0 +1,931 @@
+""" CLAP Model
+
+Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+Adapted to the Audio Task.
+"""
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Tuple, Union, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import logging
+from .utils import freeze_batch_norm_2d
+
+from .pann_model import create_pann_model
+from .htsat import create_htsat_model
+from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
+
+
+class MLPLayers(nn.Module):
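+    # Simple MLP head: Linear -> nonlinearity -> Dropout per hidden layer, with the
+    # trailing nonlinearity/dropout stripped after the final Linear.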
+    def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
+        super(MLPLayers, self).__init__()
+        self.nonlin = nonlin
+        self.dropout = dropout
+
+        sequence = []
+        for u0, u1 in zip(units[:-1], units[1:]):
+            sequence.append(nn.Linear(u0, u1))
+            sequence.append(self.nonlin)
+            sequence.append(nn.Dropout(self.dropout))
+        sequence = sequence[:-2]
+
+        self.sequential = nn.Sequential(*sequence)
+
+    def forward(self, X):
+        X = self.sequential(X)
+        return X
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(
+                OrderedDict(
+                    [
+                        ("-1", nn.AvgPool2d(stride)),
+                        (
+                            "0",
+                            nn.Conv2d(
+                                inplanes,
+                                planes * self.expansion,
+                                1,
+                                stride=1,
+                                bias=False,
+                            ),
+                        ),
+                        ("1", nn.BatchNorm2d(planes * self.expansion)),
+                    ]
+                )
+            )
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(
+        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
+        )
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
+            2, 0, 1
+        )  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x,
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat(
+                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
+            ),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+        )
+
+        return x[0]
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.image_size = image_size
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(
+            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.conv2 = nn.Conv2d(
+            width // 2, width // 2, kernel_size=3, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.avgpool = nn.AvgPool2d(2)
+        self.relu = nn.ReLU(inplace=True)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+        self.init_parameters()
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def init_parameters(self):
+        if self.attnpool is not None:
+            std = self.attnpool.c_proj.in_features**-0.5
+            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+            nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+            for name, param in resnet_block.named_parameters():
+                if name.endswith("bn3.weight"):
+                    nn.init.zeros_(param)
+
+    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+        assert (
+            unlocked_groups == 0
+        ), "partial locking not currently supported for this model"
+        for param in self.parameters():
+            param.requires_grad = False
+        if freeze_bn_stats:
+            freeze_batch_norm_2d(self)
+
+    def stem(self, x):
+        for conv, bn in [
+            (self.conv1, self.bn1),
+            (self.conv2, self.bn2),
+            (self.conv3, self.bn3),
+        ]:
+            x = self.relu(bn(conv(x)))
+        x = self.avgpool(x)
+        return x
+
+    def forward(self, x):
+        x = self.stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    ("c_fc", nn.Linear(d_model, d_model * 4)),
+                    ("gelu", act_layer()),
+                    ("c_proj", nn.Linear(d_model * 4, d_model)),
+                ]
+            )
+        )
+        self.ln_2 = LayerNorm(d_model)
+
+    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.ModuleList(
+            [
+                ResidualAttentionBlock(width, heads, act_layer=act_layer)
+                for _ in range(layers)
+            ]
+        )
+
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        for r in self.resblocks:
+            x = r(x, attn_mask=attn_mask)
+        return x
+
+
+class VisualTransformer(nn.Module):
+    def __init__(
+        self,
+        image_size: int,
+        patch_size: int,
+        width: int,
+        layers: int,
+        heads: int,
+        output_dim: int,
+        act_layer: Callable = nn.GELU,
+    ):
+        super().__init__()
+        self.image_size = image_size
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(
+            in_channels=3,
+            out_channels=width,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=False,
+        )
+
+        scale = width**-0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(
+            scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
+        )
+        self.ln_pre = LayerNorm(width)
+
+        self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+        assert (
+            unlocked_groups == 0
+        ), "partial locking not currently supported for this model"
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat(
+            [
+                self.class_embedding.to(x.dtype)
+                + torch.zeros(
+                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
+                ),
+                x,
+            ],
+            dim=1,
+        )  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.text_branch(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+@dataclass
+class CLAPVisionCfg:
+    layers: Union[Tuple[int, int, int, int], int] = 12
+    width: int = 768
+    patch_size: int = 16
+    image_size: Union[Tuple[int, int], int] = 224
+    timm_model_name: str = (
+        None  # a valid model name overrides layers, width, patch_size
+    )
+    timm_model_pretrained: bool = (
+        False  # use (imagenet) pretrained weights for named model
+    )
+    timm_pool: str = (
+        "avg"  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+    )
+    timm_proj: str = (
+        "linear"  # linear projection for timm model output ('linear', 'mlp', '')
+    )
+
+
+# Audio Config Class
+@dataclass
+class CLAPAudioCfp:
+    model_type: str = "PANN"
+    model_name: str = "Cnn14"
+    sample_rate: int = 48000
+    # Param
+    audio_length: int = 1024
+    window_size: int = 1024
+    hop_size: int = 1024
+    fmin: int = 50
+    fmax: int = 14000
+    class_num: int = 527
+    mel_bins: int = 64
+    clip_samples: int = 480000
+
+
+@dataclass
+class CLAPTextCfg:
+    context_length: int
+    vocab_size: int
+    width: int
+    heads: int
+    layers: int
+    model_type: str
+
+
+class CLAP(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        audio_cfg: CLAPAudioCfp,
+        text_cfg: CLAPTextCfg,
+        quick_gelu: bool = False,
+        enable_fusion: bool = False,
+        fusion_type: str = "None",
+        joint_embed_shape: int = 512,
+        mlp_act: str = "relu",
+    ):
+        super().__init__()
+        if isinstance(audio_cfg, dict):
+            audio_cfg = CLAPAudioCfp(**audio_cfg)
+        if isinstance(text_cfg, dict):
+            text_cfg = CLAPTextCfg(**text_cfg)
+
+        self.audio_cfg = audio_cfg
+        self.text_cfg = text_cfg
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+        self.joint_embed_shape = joint_embed_shape
+        self.mlp_act = mlp_act
+
+        self.context_length = text_cfg.context_length
+
+        # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+        # memory efficient in recent PyTorch releases (>= 1.10).
+        # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+        act_layer = QuickGELU if quick_gelu else nn.GELU
+
+        if mlp_act == "relu":
+            mlp_act_layer = nn.ReLU()
+        elif mlp_act == "gelu":
+            mlp_act_layer = nn.GELU()
+        else:
+            raise NotImplementedError
+
+        # audio branch
+        # audio branch parameters
+        if audio_cfg.model_type == "PANN":
+            self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
+        elif audio_cfg.model_type == "HTSAT":
+            self.audio_branch = create_htsat_model(
+                audio_cfg, enable_fusion, fusion_type
+            )
+        else:
+            logging.error(f"Model config for {audio_cfg.model_type} not found")
+            raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
+
+        # text branch
+        # text branch parameters
+        if text_cfg.model_type == "transformer":
+            self.text_branch = Transformer(
+                width=text_cfg.width,
+                layers=text_cfg.layers,
+                heads=text_cfg.heads,
+                act_layer=act_layer,
+            )
+            self.vocab_size = text_cfg.vocab_size
+            self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
+            self.positional_embedding = nn.Parameter(
+                torch.empty(self.context_length, text_cfg.width)
+            )
+            self.ln_final = LayerNorm(text_cfg.width)
+            self.text_transform = MLPLayers(
+                units=[
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                ],
+                dropout=0.1,
+            )
+            self.text_projection = nn.Sequential(
+                nn.Linear(text_cfg.width, self.joint_embed_shape),
+                mlp_act_layer,
+                nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+            )
+        elif text_cfg.model_type == "bert":
+            self.text_branch = BertModel.from_pretrained("bert-base-uncased")
+            self.text_transform = MLPLayers(
+                units=[
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                ],
+                dropout=0.1,
+            )
+            self.text_projection = nn.Sequential(
+                nn.Linear(768, self.joint_embed_shape),
+                mlp_act_layer,
+                nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+            )
+        elif text_cfg.model_type == "roberta":
+            self.text_branch = RobertaModel(
+                RobertaConfig.from_pretrained("roberta-base")
+            )
+            self.text_transform = MLPLayers(
+                units=[
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                ],
+                dropout=0.1,
+            )
+            self.text_projection = nn.Sequential(
+                nn.Linear(768, self.joint_embed_shape),
+                mlp_act_layer,
+                nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+            )
+        elif text_cfg.model_type == "bart":
+            self.text_branch = BartModel.from_pretrained("facebook/bart-base")
+            self.text_transform = MLPLayers(
+                units=[
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                    self.joint_embed_shape,
+                ],
+                dropout=0.1,
+            )
+            self.text_projection = nn.Sequential(
+                nn.Linear(768, self.joint_embed_shape),
+                mlp_act_layer,
+                nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+            )
+        else:
+            logging.error(f"Model config for {text_cfg.model_type} not found")
+            raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
+        self.text_branch_type = text_cfg.model_type
+        # text branch parameters
+
+        # audio branch parameters
+        self.audio_transform = MLPLayers(
+            units=[
+                self.joint_embed_shape,
+                self.joint_embed_shape,
+                self.joint_embed_shape,
+            ],
+            dropout=0.1,
+        )
+
+        # below here is text branch parameters
+
+        # ============================================================================================================
+        self.audio_projection = nn.Sequential(
+            nn.Linear(embed_dim, self.joint_embed_shape),
+            mlp_act_layer,
+            nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+        )
+
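+        # learnable temperatures for the audio and text logits, initialised to
+        # log(1 / 0.07) as in CLIP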
+        self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+        self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
+
+        self.init_text_branch_parameters()
+
+    def init_text_branch_parameters(self):
+        if self.text_branch_type == "transformer":
+            nn.init.normal_(self.token_embedding.weight, std=0.02)
+            nn.init.normal_(self.positional_embedding, std=0.01)
+            proj_std = (self.text_branch.width**-0.5) * (
+                (2 * self.text_branch.layers) ** -0.5
+            )
+            attn_std = self.text_branch.width**-0.5
+            fc_std = (2 * self.text_branch.width) ** -0.5
+            for block in self.text_branch.resblocks:
+                nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+                nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+                nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+                nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+        # embedding width of the text branch; kept for reference because the
+        # text_projection init that used it is commented out below
+        if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
+            width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
+        elif self.text_branch_type == "bart":
+            width = self.text_branch.shared.weight.shape[-1]
+        else:
+            width = self.text_branch.width
+        nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
+        nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
+
+        # deprecated
+        # if hasattr(self.visual, 'init_parameters'):
+        # self.visual.init_parameters()
+
+        # if self.text_projection is not None:
+        #     nn.init.normal_(self.text_projection, std=width**-0.5)
+
+    def build_attention_mask(self):
+        # lazily create a causal attention mask for the text transformer
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def encode_audio(self, audio, device):
+        return self.audio_branch(
+            audio, mixup_lambda=None, device=device
+        )  # mixup_lambda is not passed through here yet
+
+    # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
+    #     tmp = {}
+    #     for k in x[0].keys():
+    #         tmp[k] = []
+    #         for i in range(len(x)):
+    #             tmp[k].append(x[i][k][:77])
+    #     for k in x[0].keys():
+    #         tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
+    #     return tmp
+
+    def encode_text(self, text, device):
+        if self.text_branch_type == "transformer":
+            text = text.to(device=device, non_blocking=True)
+            x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
+
+            x = x + self.positional_embedding
+            x = x.permute(1, 0, 2)  # NLD -> LND
+            x = self.text_branch(x, attn_mask=self.attn_mask)
+            x = x.permute(1, 0, 2)  # LND -> NLD
+            x = self.ln_final(x)
+
+            # x.shape = [batch_size, n_ctx, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
+        elif self.text_branch_type == "bert":
+            # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
+            # text = BatchEncoding(text)
+            x = self.text_branch(
+                input_ids=text["input_ids"].to(device=device, non_blocking=True),
+                attention_mask=text["attention_mask"].to(
+                    device=device, non_blocking=True
+                ),
+                token_type_ids=text["token_type_ids"].to(
+                    device=device, non_blocking=True
+                ),
+            )["pooler_output"]
+            x = self.text_projection(x)
+        elif self.text_branch_type == "roberta":
+            x = self.text_branch(
+                input_ids=text["input_ids"].to(device=device, non_blocking=True),
+                attention_mask=text["attention_mask"].to(
+                    device=device, non_blocking=True
+                ),
+            )["pooler_output"]
+            x = self.text_projection(x)
+        elif self.text_branch_type == "bart":
+            x = torch.mean(
+                self.text_branch(
+                    input_ids=text["input_ids"].to(device=device, non_blocking=True),
+                    attention_mask=text["attention_mask"].to(
+                        device=device, non_blocking=True
+                    ),
+                )["encoder_last_hidden_state"],
+                axis=1,
+            )
+            x = self.text_projection(x)
+        else:
+            logging.error(f"Model type {self.text_branch_type} not found")
+            raise RuntimeError(f"Model type {self.text_branch_type} not found.")
+        return x
+
+    def forward(self, audio, text, device=None):
+        """Forward audio and text into the CLAP
+
+        Parameters
+        ----------
+        audio: torch.Tensor (batch_size, audio_length)
+            the time-domain audio input / the batch of mel_spec and longer list.
+        text: torch.Tensor or dict of torch.Tensor
+            the tokenized text input (a tensor for the plain transformer branch,
+            a tokenizer output dict for the BERT/RoBERTa/BART branches)
+        """
+        if device is None:
+            if audio is not None:
+                device = audio.device
+            elif text is not None:
+                device = text.device
+        if audio is None and text is None:
+            # a hack to get the logit scale
+            return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+        elif audio is None:
+            return self.encode_text(text, device=device)
+        elif text is None:
+            return self.audio_projection(
+                self.encode_audio(audio, device=device)["embedding"]
+            )
+        audio_features = self.audio_projection(
+            self.encode_audio(audio, device=device)["embedding"]
+        )
+        audio_features = F.normalize(audio_features, dim=-1)
+
+        text_features = self.encode_text(text, device=device)
+        # print("text_features", text_features)
+        # print("text_features.shape", text_features.shape)
+        # print("text_features.type", type(text_features))
+        text_features = F.normalize(text_features, dim=-1)
+
+        audio_features_mlp = self.audio_transform(audio_features)
+        text_features_mlp = self.text_transform(text_features)
+        # Four outputs: audio features (basic & MLP), text features (basic & MLP)
+        return (
+            audio_features,
+            text_features,
+            audio_features_mlp,
+            text_features_mlp,
+            self.logit_scale_a.exp(),
+            self.logit_scale_t.exp(),
+        )
+
+    def get_logit_scale(self):
+        return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+
+    def get_text_embedding(self, data):
+        """Get the text embedding from the model
+
+        Parameters
+        ----------
+        data: dict
+            a batch of tokenized text, i.e. a dict of tensors such as
+            input_ids / attention_mask
+
+        Returns
+        ----------
+        text_embed: torch.Tensor
+            a tensor of text_embeds (N, D)
+
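+        Examples
+        ----------
+        Illustrative sketch only; the tokenizer call below is an assumption
+        (e.g. a RoBERTa-style HuggingFace tokenizer, not part of this module)::
+
+            tokens = tokenizer(["a dog barking"], padding=True, return_tensors="pt")
+            text_embed = model.get_text_embedding(dict(tokens))  # (1, D), L2-normalized
+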
+        """
+        device = next(self.parameters()).device
+        for k in data:
+            data[k] = data[k].to(device)
+        text_embeds = self.encode_text(data, device=device)
+        text_embeds = F.normalize(text_embeds, dim=-1)
+
+        return text_embeds
+
+    def get_audio_embedding(self, data):
+        """Get the audio embedding from the model
+
+        Parameters
+        ----------
+        data: dict
+            the batched audio input dict (e.g. produced by the 'get_audio_feature'
+            preprocessing step), passed directly to the audio branch
+
+        Returns
+        ----------
+        audio_embed: torch.Tensor
+            a tensor of audio_embeds (N, D)
+
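+        Examples
+        ----------
+        Illustrative sketch only; the field names follow the audio branch's
+        expected input dict and the preprocessing is assumed to happen elsewhere::
+
+            data = {"waveform": wav_batch, "longer": longer_flags}
+            audio_embed = model.get_audio_embedding(data)  # (N, D), L2-normalized
+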
+        """
+        device = next(self.parameters()).device
+        # input_dict = {}
+        # keys = data[0].keys()
+        # for k in keys:
+        #     input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
+        #         device
+        #     )
+        audio_embeds = self.audio_projection(
+            self.encode_audio(data, device=device)["embedding"]
+        )
+        audio_embeds = F.normalize(audio_embeds, dim=-1)
+
+        return audio_embeds
+
+    def audio_infer(self, audio, hopsize=None, device=None):
+        """Forward one audio and produce the audio embedding
+
+        Parameters
+        ----------
+        audio: torch.Tensor (audio_length,)
+            a single time-domain audio clip (no batch dimension)
+        hopsize: int
+            the hop size of the sliding window used to split audio that is
+            longer than clip_samples
+
+        Returns
+        ----------
+        output_dict: dict
+            {key: [n, (embedding_shape)]} for "HTSAT" (one entry per window), or
+            {key: [(embedding_shape)]} for "PANN"
+
+        """
+
+        assert not self.training, "the inference mode must be run at eval stage"
+        output_dict = {}
+        # NOTE: `key` was left undefined in the original code; we assume the pooled
+        # "embedding" output of the audio branch is the one of interest here.
+        key = "embedding"
+        # PANN
+        if self.audio_cfg.model_type == "PANN":
+            audio_input = audio.unsqueeze(dim=0)
+            output_dict[key] = self.encode_audio(audio_input, device=device)[
+                key
+            ].squeeze(dim=0)
+        elif self.audio_cfg.model_type == "HTSAT":
+            # repeat
+            audio_len = len(audio)
+            k = self.audio_cfg.clip_samples // audio_len
+            if k > 1:
+                audio = audio.repeat(k)
+                audio_len = len(audio)
+
+            if hopsize is None:
+                # assumption: default to non-overlapping windows of clip_samples
+                hopsize = self.audio_cfg.clip_samples
+            hopsize = min(hopsize, audio_len)
+
+            if audio_len > self.audio_cfg.clip_samples:
+                audio_input = [
+                    audio[pos : pos + self.audio_cfg.clip_samples].clone()
+                    for pos in range(
+                        0, audio_len - self.audio_cfg.clip_samples, hopsize
+                    )
+                ]
+                audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
+                audio_input = torch.stack(audio_input)
+                output_dict[key] = self.encode_audio(audio_input, device=device)[key]
+            else:
+                audio_input = audio.unsqueeze(dim=0)
+                output_dict[key] = self.encode_audio(audio_input, device=device)[
+                    key
+                ].squeeze(dim=0)
+
+        return output_dict
+
+
+def convert_weights_to_fp16(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [
+                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+                "in_proj_bias",
+                "bias_k",
+                "bias_v",
+            ]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
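+
+# Hedged usage note: convert_weights_to_fp16(model) halves the weights of Linear,
+# Conv and MultiheadAttention layers in place; the caller is then responsible for
+# feeding half-precision inputs as well.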
+
+
+# Ignore the state dict of the vision part
+def build_model_from_openai_state_dict(
+    state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
+):
+    embed_dim = model_cfg["embed_dim"]
+    audio_cfg = model_cfg["audio_cfg"]
+    text_cfg = model_cfg["text_cfg"]
+    state_dict["positional_embedding"].shape[0]
+    state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_width // 64
+    transformer_layers = len(
+        set(
+            k.split(".")[2]
+            for k in state_dict
+            if k.startswith(f"transformer.resblocks")
+        )
+    )
+
+    audio_cfg = CLAPAudioCfp(**audio_cfg)
+    text_cfg = CLAPTextCfg(**text_cfg)
+
+    model = CLAP(
+        embed_dim,
+        audio_cfg=audio_cfg,
+        text_cfg=text_cfg,
+        quick_gelu=True,  # OpenAI models were trained with QuickGELU
+        enable_fusion=enable_fusion,
+        fusion_type=fusion_type,
+    )
+    state_dict["logit_scale_a"] = state_dict["logit_scale"]
+    state_dict["logit_scale_t"] = state_dict["logit_scale"]
+    pop_keys = list(state_dict.keys())
+    # pop the visual branch saved weights
+    for key in pop_keys:
+        if key.startswith("visual."):
+            state_dict.pop(key, None)
+
+    for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
+        state_dict.pop(key, None)
+
+    # not use fp16
+    # convert_weights_to_fp16(model)
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
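+
+# Illustrative (hedged) usage sketch, mirroring what load_openai_model does internally:
+#   sd = torch.jit.load("ViT-B-32.pt", map_location="cpu").state_dict()
+#   clap = build_model_from_openai_state_dict(sd, model_cfg)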
+
+
+def trace_model(model, batch_size=256, device=torch.device("cpu")):
+    model.eval()
+    audio_length = model.audio_cfg.audio_length
+    example_audio = torch.ones((batch_size, audio_length), device=device)
+    example_text = torch.zeros(
+        (batch_size, model.context_length), dtype=torch.int, device=device
+    )
+    model = torch.jit.trace_module(
+        model,
+        inputs=dict(
+            forward=(example_audio, example_text),
+            encode_text=(example_text,),
+            encode_image=(example_audio,),
+        ),
+    )
+    model.audio_cfg.audio_length = audio_length  # restore the attribute on the traced module
+    return model
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-base.json b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json
new file mode 100755
index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-base.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 1024,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "base"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-large.json b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json
new file mode 100755
index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-large.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "large"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
new file mode 100755
index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 768,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1536,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "tiny"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json
new file mode 100755
index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 768,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "HTSAT",
+        "model_name": "tiny"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-10.json b/audioldm2/clap/open_clip/model_configs/PANN-10.json
new file mode 100755
index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-10.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 1024,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn10"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json
new file mode 100755
index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 18000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
new file mode 100755
index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 960000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 360,
+        "fmin": 50,
+        "fmax": 8000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
new file mode 100755
index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 4
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json
new file mode 100755
index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1536,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-14.json b/audioldm2/clap/open_clip/model_configs/PANN-14.json
new file mode 100755
index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-14.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 2048,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn14"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/PANN-6.json b/audioldm2/clap/open_clip/model_configs/PANN-6.json
new file mode 100755
index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/PANN-6.json
@@ -0,0 +1,23 @@
+{
+    "embed_dim": 512,
+    "audio_cfg": {
+        "audio_length": 1024,
+        "clip_samples": 480000,
+        "mel_bins": 64,
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+        "class_num": 527,
+        "model_type": "PANN",
+        "model_name": "Cnn6"
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            23,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN101.json b/audioldm2/clap/open_clip/model_configs/RN101.json
new file mode 100755
index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN101.json
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            23,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 1024,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            6,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
diff --git a/audioldm2/clap/open_clip/model_configs/RN50.json b/audioldm2/clap/open_clip/model_configs/RN50.json
new file mode 100755
index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50.json
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": [
+            3,
+            4,
+            6,
+            3
+        ],
+        "width": 64,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN50x16.json b/audioldm2/clap/open_clip/model_configs/RN50x16.json
new file mode 100755
index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50x16.json
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 384,
+        "layers": [
+            6,
+            8,
+            18,
+            8
+        ],
+        "width": 96,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/RN50x4.json b/audioldm2/clap/open_clip/model_configs/RN50x4.json
new file mode 100755
index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/RN50x4.json
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 288,
+        "layers": [
+            4,
+            6,
+            10,
+            6
+        ],
+        "width": 80,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-16.json b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json
new file mode 100755
index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-B-16.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-B-32.json b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json
new file mode 100755
index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/model_configs/ViT-L-14.json b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json
new file mode 100755
index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241
--- /dev/null
+++ b/audioldm2/clap/open_clip/model_configs/ViT-L-14.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/audioldm2/clap/open_clip/openai.py b/audioldm2/clap/open_clip/openai.py
new file mode 100755
index 0000000000000000000000000000000000000000..3f4eb8b55fe960e1792b3da804b60b3d8f70fe26
--- /dev/null
+++ b/audioldm2/clap/open_clip/openai.py
@@ -0,0 +1,156 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import Union, List
+
+import torch
+
+from .model import build_model_from_openai_state_dict
+from .pretrained import (
+    get_pretrained_url,
+    list_pretrained_tag_models,
+    download_pretrained,
+)
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list_pretrained_tag_models("openai")
+
+
+def load_openai_model(
+    name: str,
+    model_cfg,
+    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+    jit=True,
+    cache_dir=os.path.expanduser("~/.cache/clip"),
+    enable_fusion: bool = False,
+    fusion_type: str = "None",
+):
+    """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+    device : Union[str, torch.device]
+        The device to put the loaded model
+    jit : bool
+        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLAP model, with the pretrained CLIP text weights loaded
+    """
+    if get_pretrained_url(name, "openai"):
+        model_path = download_pretrained(
+            get_pretrained_url(name, "openai"), root=cache_dir
+        )
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(
+            f"Model {name} not found; available models = {list_openai_models()}"
+        )
+
+    try:
+        # loading JIT archive
+        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+        state_dict = None
+    except RuntimeError:
+        # loading saved state dict
+        if jit:
+            warnings.warn(
+                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
+            )
+            jit = False
+        state_dict = torch.load(model_path, map_location="cpu")
+
+    if not jit:
+        try:
+            model = build_model_from_openai_state_dict(
+                state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
+            ).to(device)
+        except KeyError:
+            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+            model = build_model_from_openai_state_dict(
+                sd, model_cfg, enable_fusion, fusion_type
+            ).to(device)
+
+        if str(device) == "cpu":
+            model.float()
+        return model
+
+    # patch the device names
+    device_holder = torch.jit.trace(
+        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
+    )
+    device_node = [
+        n
+        for n in device_holder.graph.findAllNodes("prim::Constant")
+        if "Device" in repr(n)
+    ][-1]
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(node["value"]).startswith(
+                    "cuda"
+                ):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_audio)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 on CPU
+    if str(device) == "cpu":
+        float_holder = torch.jit.trace(
+            lambda: torch.ones([]).float(), example_inputs=[]
+        )
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [
+                        1,
+                        2,
+                    ]:  # dtype can be the second or third argument to aten::to()
+                        if inputs[i].node()["value"] == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_audio)
+        patch_float(model.encode_text)
+        model.float()
+
+    model.audio_branch.audio_length = model.audio_cfg.audio_length
+    return model
diff --git a/audioldm2/clap/open_clip/pann_model.py b/audioldm2/clap/open_clip/pann_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..e9fab8e03cdca370c141a9e321e98d256e79fb27
--- /dev/null
+++ b/audioldm2/clap/open_clip/pann_model.py
@@ -0,0 +1,697 @@
+# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
+# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
+# Some layers are re-designed for CLAP
+import os
+
+os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import do_mixup, interpolate
+from .feature_fusion import iAFF, AFF, DAF
+
+
+def init_layer(layer):
+    """Initialize a Linear or Convolutional layer."""
+    nn.init.xavier_uniform_(layer.weight)
+
+    if hasattr(layer, "bias"):
+        if layer.bias is not None:
+            layer.bias.data.fill_(0.0)
+
+
+def init_bn(bn):
+    """Initialize a Batchnorm layer."""
+    bn.bias.data.fill_(0.0)
+    bn.weight.data.fill_(1.0)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(1, 1),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_layer(self.conv1)
+        init_layer(self.conv2)
+        init_bn(self.bn1)
+        init_bn(self.bn2)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        x = F.relu_(self.bn2(self.conv2(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise ValueError(f"Incorrect pool_type argument: {pool_type}")
+
+        return x
+
+
+class ConvBlock5x5(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock5x5, self).__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(5, 5),
+            stride=(1, 1),
+            padding=(2, 2),
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(out_channels)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_layer(self.conv1)
+        init_bn(self.bn1)
+
+    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+        x = input
+        x = F.relu_(self.bn1(self.conv1(x)))
+        if pool_type == "max":
+            x = F.max_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg":
+            x = F.avg_pool2d(x, kernel_size=pool_size)
+        elif pool_type == "avg+max":
+            x1 = F.avg_pool2d(x, kernel_size=pool_size)
+            x2 = F.max_pool2d(x, kernel_size=pool_size)
+            x = x1 + x2
+        else:
+            raise ValueError(f"Incorrect pool_type argument: {pool_type}")
+
+        return x
+
+
+class AttBlock(nn.Module):
+    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
+        super(AttBlock, self).__init__()
+
+        self.activation = activation
+        self.temperature = temperature
+        self.att = nn.Conv1d(
+            in_channels=n_in,
+            out_channels=n_out,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+        self.cla = nn.Conv1d(
+            in_channels=n_in,
+            out_channels=n_out,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+        )
+
+        self.bn_att = nn.BatchNorm1d(n_out)
+        self.init_weights()
+
+    def init_weights(self):
+        init_layer(self.att)
+        init_layer(self.cla)
+        init_bn(self.bn_att)
+
+    def forward(self, x):
+        # x: (n_samples, n_in, n_time)
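+        # softmax over the time axis of the (clamped) attention logits, followed by
+        # an attention-weighted sum of the class activations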
+        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
+        cla = self.nonlinear_transform(self.cla(x))
+        x = torch.sum(norm_att * cla, dim=2)
+        return x, norm_att, cla
+
+    def nonlinear_transform(self, x):
+        if self.activation == "linear":
+            return x
+        elif self.activation == "sigmoid":
+            return torch.sigmoid(x)
+
+
+class Cnn14(nn.Module):
+    def __init__(
+        self,
+        sample_rate,
+        window_size,
+        hop_size,
+        mel_bins,
+        fmin,
+        fmax,
+        classes_num,
+        enable_fusion=False,
+        fusion_type="None",
+    ):
+        super(Cnn14, self).__init__()
+
+        window = "hann"
+        center = True
+        pad_mode = "reflect"
+        ref = 1.0
+        amin = 1e-10
+        top_db = None
+
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+
+        # Spectrogram extractor
+        self.spectrogram_extractor = Spectrogram(
+            n_fft=window_size,
+            hop_length=hop_size,
+            win_length=window_size,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            freeze_parameters=True,
+        )
+
+        # Logmel feature extractor
+        self.logmel_extractor = LogmelFilterBank(
+            sr=sample_rate,
+            n_fft=window_size,
+            n_mels=mel_bins,
+            fmin=fmin,
+            fmax=fmax,
+            ref=ref,
+            amin=amin,
+            top_db=top_db,
+            freeze_parameters=True,
+        )
+
+        # Spec augmenter
+        self.spec_augmenter = SpecAugmentation(
+            time_drop_width=64,
+            time_stripes_num=2,
+            freq_drop_width=8,
+            freq_stripes_num=2,
+        )
+
+        self.bn0 = nn.BatchNorm2d(64)
+
+        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
+            self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
+        else:
+            self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+        self.fc1 = nn.Linear(2048, 2048, bias=True)
+        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
+
+        if (self.enable_fusion) and (
+            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
+        ):
+            self.mel_conv1d = nn.Sequential(
+                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+                nn.BatchNorm1d(64),  # No Relu
+            )
+            if self.fusion_type == "daf_1d":
+                self.fusion_model = DAF()
+            elif self.fusion_type == "aff_1d":
+                self.fusion_model = AFF(channels=64, type="1D")
+            elif self.fusion_type == "iaff_1d":
+                self.fusion_model = iAFF(channels=64, type="1D")
+
+        if (self.enable_fusion) and (
+            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+        ):
+            self.mel_conv2d = nn.Sequential(
+                nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
+                nn.BatchNorm2d(64),
+                nn.ReLU(inplace=True),
+            )
+
+            if self.fusion_type == "daf_2d":
+                self.fusion_model = DAF()
+            elif self.fusion_type == "aff_2d":
+                self.fusion_model = AFF(channels=64, type="2D")
+            elif self.fusion_type == "iaff_2d":
+                self.fusion_model = iAFF(channels=64, type="2D")
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, input, mixup_lambda=None, device=None):
+        """
+        Input: (batch_size, data_length)"""
+
+        if self.enable_fusion and input["longer"].sum() == 0:
+            # if no audio is longer than 10s, then randomly select one audio to be longer
+            input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
+
+        if not self.enable_fusion:
+            x = self.spectrogram_extractor(
+                input["waveform"].to(device=device, non_blocking=True)
+            )  # (batch_size, 1, time_steps, freq_bins)
+            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
+
+            x = x.transpose(1, 3)
+            x = self.bn0(x)
+            x = x.transpose(1, 3)
+        else:
+            longer_list = input["longer"].to(device=device, non_blocking=True)
+            x = input["mel_fusion"].to(device=device, non_blocking=True)
+            longer_list_idx = torch.where(longer_list)[0]
+            x = x.transpose(1, 3)
+            x = self.bn0(x)
+            x = x.transpose(1, 3)
+            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
+                new_x = x[:, 0:1, :, :].clone().contiguous()
+                # local processing
+                if len(longer_list_idx) > 0:
+                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
+                    FB, FC, FT, FF = fusion_x_local.size()
+                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+                    fusion_x_local = torch.permute(
+                        fusion_x_local, (0, 2, 1)
+                    ).contiguous()
+                    fusion_x_local = self.mel_conv1d(fusion_x_local)
+                    fusion_x_local = fusion_x_local.view(
+                        FB, FC, FF, fusion_x_local.size(-1)
+                    )
+                    fusion_x_local = (
+                        torch.permute(fusion_x_local, (0, 2, 1, 3))
+                        .contiguous()
+                        .flatten(2)
+                    )
+                    if fusion_x_local.size(-1) < FT:
+                        fusion_x_local = torch.cat(
+                            [
+                                fusion_x_local,
+                                torch.zeros(
+                                    (FB, FF, FT - fusion_x_local.size(-1)),
+                                    device=device,
+                                ),
+                            ],
+                            dim=-1,
+                        )
+                    else:
+                        fusion_x_local = fusion_x_local[:, :, :FT]
+                    # 1D fusion
+                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
+                    new_x[longer_list_idx] = self.fusion_model(
+                        new_x[longer_list_idx], fusion_x_local
+                    )
+                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
+                else:
+                    x = new_x
+            elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
+                x = x  # no change
+
+        if self.training:
+            x = self.spec_augmenter(x)
+        # Mixup on spectrogram
+        if self.training and mixup_lambda is not None:
+            x = do_mixup(x, mixup_lambda)
+        if (self.enable_fusion) and (
+            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+        ):
+            global_x = x[:, 0:1, :, :]
+
+            # global processing
+            B, C, H, W = global_x.shape
+            global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
+            if len(longer_list_idx) > 0:
+                local_x = x[longer_list_idx, 1:, :, :].contiguous()
+                TH = global_x.size(-2)
+                # local processing
+                B, C, H, W = local_x.shape
+                local_x = local_x.view(B * C, 1, H, W)
+                local_x = self.mel_conv2d(local_x)
+                local_x = local_x.view(
+                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
+                )
+                local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
+                TB, TC, _, TW = local_x.size()
+                if local_x.size(-2) < TH:
+                    local_x = torch.cat(
+                        [
+                            local_x,
+                            torch.zeros(
+                                (TB, TC, TH - local_x.size(-2), TW),
+                                device=global_x.device,
+                            ),
+                        ],
+                        dim=-2,
+                    )
+                else:
+                    local_x = local_x[:, :, :TH, :]
+
+                global_x[longer_list_idx] = self.fusion_model(
+                    global_x[longer_list_idx], local_x
+                )
+            x = global_x
+        else:
+            x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+        latent_x = latent_x1 + latent_x2
+        latent_x = latent_x.transpose(1, 2)
+        latent_x = F.relu_(self.fc1(latent_x))
+        latent_output = interpolate(latent_x, 32)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        embedding = F.dropout(x, p=0.5, training=self.training)
+        clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+        output_dict = {
+            "clipwise_output": clipwise_output,
+            "embedding": embedding,
+            "fine_grained_embedding": latent_output,
+        }
+        return output_dict
+
+
+class Cnn6(nn.Module):
+    def __init__(
+        self,
+        sample_rate,
+        window_size,
+        hop_size,
+        mel_bins,
+        fmin,
+        fmax,
+        classes_num,
+        enable_fusion=False,
+        fusion_type="None",
+    ):
+        super(Cnn6, self).__init__()
+
+        window = "hann"
+        center = True
+        pad_mode = "reflect"
+        ref = 1.0
+        amin = 1e-10
+        top_db = None
+
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+
+        # Spectrogram extractor
+        self.spectrogram_extractor = Spectrogram(
+            n_fft=window_size,
+            hop_length=hop_size,
+            win_length=window_size,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            freeze_parameters=True,
+        )
+
+        # Logmel feature extractor
+        self.logmel_extractor = LogmelFilterBank(
+            sr=sample_rate,
+            n_fft=window_size,
+            n_mels=mel_bins,
+            fmin=fmin,
+            fmax=fmax,
+            ref=ref,
+            amin=amin,
+            top_db=top_db,
+            freeze_parameters=True,
+        )
+
+        # Spec augmenter
+        self.spec_augmenter = SpecAugmentation(
+            time_drop_width=64,
+            time_stripes_num=2,
+            freq_drop_width=8,
+            freq_stripes_num=2,
+        )
+
+        self.bn0 = nn.BatchNorm2d(64)
+
+        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
+
+        self.fc1 = nn.Linear(512, 512, bias=True)
+        self.fc_audioset = nn.Linear(512, classes_num, bias=True)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, input, mixup_lambda=None, device=None):
+        """
+        Input: (batch_size, data_length)"""
+
+        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
+        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
+
+        x = x.transpose(1, 3)
+        x = self.bn0(x)
+        x = x.transpose(1, 3)
+
+        if self.training:
+            x = self.spec_augmenter(x)
+
+        # Mixup on spectrogram
+        if self.training and mixup_lambda is not None:
+            x = do_mixup(x, mixup_lambda)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+        latent_x = latent_x1 + latent_x2
+        latent_x = latent_x.transpose(1, 2)
+        latent_x = F.relu_(self.fc1(latent_x))
+        latent_output = interpolate(latent_x, 16)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        embedding = F.dropout(x, p=0.5, training=self.training)
+        clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+        output_dict = {
+            "clipwise_output": clipwise_output,
+            "embedding": embedding,
+            "fine_grained_embedding": latent_output,
+        }
+
+        return output_dict
+
+
+class Cnn10(nn.Module):
+    def __init__(
+        self,
+        sample_rate,
+        window_size,
+        hop_size,
+        mel_bins,
+        fmin,
+        fmax,
+        classes_num,
+        enable_fusion=False,
+        fusion_type="None",
+    ):
+        super(Cnn10, self).__init__()
+
+        window = "hann"
+        center = True
+        pad_mode = "reflect"
+        ref = 1.0
+        amin = 1e-10
+        top_db = None
+
+        self.enable_fusion = enable_fusion
+        self.fusion_type = fusion_type
+
+        # Spectrogram extractor
+        self.spectrogram_extractor = Spectrogram(
+            n_fft=window_size,
+            hop_length=hop_size,
+            win_length=window_size,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            freeze_parameters=True,
+        )
+
+        # Logmel feature extractor
+        self.logmel_extractor = LogmelFilterBank(
+            sr=sample_rate,
+            n_fft=window_size,
+            n_mels=mel_bins,
+            fmin=fmin,
+            fmax=fmax,
+            ref=ref,
+            amin=amin,
+            top_db=top_db,
+            freeze_parameters=True,
+        )
+
+        # Spec augmenter
+        self.spec_augmenter = SpecAugmentation(
+            time_drop_width=64,
+            time_stripes_num=2,
+            freq_drop_width=8,
+            freq_stripes_num=2,
+        )
+
+        self.bn0 = nn.BatchNorm2d(64)
+
+        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+
+        self.fc1 = nn.Linear(1024, 1024, bias=True)
+        self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
+
+        self.init_weight()
+
+    def init_weight(self):
+        init_bn(self.bn0)
+        init_layer(self.fc1)
+        init_layer(self.fc_audioset)
+
+    def forward(self, input, mixup_lambda=None, device=None):
+        """
+        Input: (batch_size, data_length)"""
+
+        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
+        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
+
+        x = x.transpose(1, 3)
+        x = self.bn0(x)
+        x = x.transpose(1, 3)
+
+        if self.training:
+            x = self.spec_augmenter(x)
+
+        # Mixup on spectrogram
+        if self.training and mixup_lambda is not None:
+            x = do_mixup(x, mixup_lambda)
+
+        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+        x = F.dropout(x, p=0.2, training=self.training)
+        x = torch.mean(x, dim=3)
+
+        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+        latent_x = latent_x1 + latent_x2
+        latent_x = latent_x.transpose(1, 2)
+        latent_x = F.relu_(self.fc1(latent_x))
+        latent_output = interpolate(latent_x, 32)
+
+        (x1, _) = torch.max(x, dim=2)
+        x2 = torch.mean(x, dim=2)
+        x = x1 + x2
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu_(self.fc1(x))
+        embedding = F.dropout(x, p=0.5, training=self.training)
+        clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+        output_dict = {
+            "clipwise_output": clipwise_output,
+            "embedding": embedding,
+            "fine_grained_embedding": latent_output,
+        }
+
+        return output_dict
+
+
+def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
+    try:
+        ModelProto = eval(audio_cfg.model_name)
+        model = ModelProto(
+            sample_rate=audio_cfg.sample_rate,
+            window_size=audio_cfg.window_size,
+            hop_size=audio_cfg.hop_size,
+            mel_bins=audio_cfg.mel_bins,
+            fmin=audio_cfg.fmin,
+            fmax=audio_cfg.fmax,
+            classes_num=audio_cfg.class_num,
+            enable_fusion=enable_fusion,
+            fusion_type=fusion_type,
+        )
+        return model
+    except Exception:
+        raise RuntimeError(
+            f"Model {audio_cfg.model_name} not found, or the audio cfg is missing required parameters."
+        )
diff --git a/audioldm2/clap/open_clip/pretrained.py b/audioldm2/clap/open_clip/pretrained.py
new file mode 100755
index 0000000000000000000000000000000000000000..e211d8b5b59320a599e62605f1dee6199f317253
--- /dev/null
+++ b/audioldm2/clap/open_clip/pretrained.py
@@ -0,0 +1,167 @@
+import hashlib
+import os
+import urllib
+import warnings
+
+from tqdm import tqdm
+
+_RN50 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+    cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+)
+
+_RN50_quickgelu = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+    cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+)
+
+_RN101 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+)
+
+_RN101_quickgelu = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+)
+
+_RN50x4 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+)
+
+_RN50x16 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+)
+
+_RN50x64 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+)
+
+_VITB32 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+    laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+    laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB32_quickgelu = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+    laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+    laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB16 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+)
+
+_VITL14 = dict(
+    openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+)
+
+_PRETRAINED = {
+    "RN50": _RN50,
+    "RN50-quickgelu": _RN50_quickgelu,
+    "RN101": _RN101,
+    "RN101-quickgelu": _RN101_quickgelu,
+    "RN50x4": _RN50x4,
+    "RN50x16": _RN50x16,
+    "ViT-B-32": _VITB32,
+    "ViT-B-32-quickgelu": _VITB32_quickgelu,
+    "ViT-B-16": _VITB16,
+    "ViT-L-14": _VITL14,
+}
+
+
+def list_pretrained(as_str: bool = False):
+    """returns list of pretrained models
+    Returns a list of (model_name, pretrain_tag) tuples by default, or 'name:tag' strings if as_str == True.
+    """
+    return [
+        ":".join([k, t]) if as_str else (k, t)
+        for k in _PRETRAINED.keys()
+        for t in _PRETRAINED[k].keys()
+    ]
+
+
+def list_pretrained_tag_models(tag: str):
+    """return all models having the specified pretrain tag"""
+    models = []
+    for k in _PRETRAINED.keys():
+        if tag in _PRETRAINED[k]:
+            models.append(k)
+    return models
+
+
+def list_pretrained_model_tags(model: str):
+    """return all pretrain tags for the specified model architecture"""
+    tags = []
+    if model in _PRETRAINED:
+        tags.extend(_PRETRAINED[model].keys())
+    return tags
+
+
+def get_pretrained_url(model: str, tag: str):
+    if model not in _PRETRAINED:
+        return ""
+    model_pretrained = _PRETRAINED[model]
+    if tag not in model_pretrained:
+        return ""
+    return model_pretrained[tag]
+
+
+def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    if "openaipublic" in url:
+        expected_sha256 = url.split("/")[-2]
+    else:
+        expected_sha256 = ""
+
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if expected_sha256:
+            if (
+                hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+                == expected_sha256
+            ):
+                return download_target
+            else:
+                warnings.warn(
+                    f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+                )
+        else:
+            return download_target
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if (
+        expected_sha256
+        and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+        != expected_sha256
+    ):
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match"
+        )
+
+    return download_target
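+
+
+# --- Editor's note: illustrative usage sketch, not part of the original module. ---
+# Shows how the registry helpers above fit together; "ViT-B-32"/"openai" are just
+# example keys taken from _PRETRAINED, and the actual download is left commented
+# out because the checkpoint is several hundred MB.
+if __name__ == "__main__":
+    print(list_pretrained(as_str=True)[:3])        # ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
+    print(list_pretrained_model_tags("ViT-B-32"))  # ['openai', 'laion400m_e31', ...]
+    url = get_pretrained_url("ViT-B-32", "openai")
+    print("checkpoint URL:", url)
+    # path = download_pretrained(url)  # uncomment to fetch and checksum-verify the file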
diff --git a/audioldm2/clap/open_clip/timm_model.py b/audioldm2/clap/open_clip/timm_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..b8486b9e62580bb65f0f50a0a7000890cb7ee42d
--- /dev/null
+++ b/audioldm2/clap/open_clip/timm_model.py
@@ -0,0 +1,112 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+
+try:
+    import timm
+    from timm.models.layers import Mlp, to_2tuple
+    from timm.models.layers.attention_pool2d import RotAttentionPool2d
+    from timm.models.layers.attention_pool2d import (
+        AttentionPool2d as AbsAttentionPool2d,
+    )
+except ImportError:
+    timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+    """timm model adapter
+    # FIXME this adapter is a work in progress, may change in ways that break weight compat
+    """
+
+    def __init__(
+        self,
+        model_name,
+        embed_dim,
+        image_size=224,
+        pool="avg",
+        proj="linear",
+        drop=0.0,
+        pretrained=False,
+    ):
+        super().__init__()
+        if timm is None:
+            raise RuntimeError("Please `pip install timm` to use timm models.")
+
+        self.image_size = to_2tuple(image_size)
+        self.trunk = timm.create_model(model_name, pretrained=pretrained)
+        feat_size = self.trunk.default_cfg.get("pool_size", None)
+        feature_ndim = 1 if not feat_size else 2
+        if pool in ("abs_attn", "rot_attn"):
+            assert feature_ndim == 2
+            # if attn pooling used, remove both classifier and default pool
+            self.trunk.reset_classifier(0, global_pool="")
+        else:
+            # reset global pool if pool config set, otherwise leave as network default
+            reset_kwargs = dict(global_pool=pool) if pool else {}
+            self.trunk.reset_classifier(0, **reset_kwargs)
+        prev_chs = self.trunk.num_features
+
+        head_layers = OrderedDict()
+        if pool == "abs_attn":
+            head_layers["pool"] = AbsAttentionPool2d(
+                prev_chs, feat_size=feat_size, out_features=embed_dim
+            )
+            prev_chs = embed_dim
+        elif pool == "rot_attn":
+            head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+            prev_chs = embed_dim
+        else:
+            assert proj, "projection layer needed if non-attention pooling is used."
+
+        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+        if proj == "linear":
+            head_layers["drop"] = nn.Dropout(drop)
+            head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
+        elif proj == "mlp":
+            head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
+
+        self.head = nn.Sequential(head_layers)
+
+    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+        """lock modules
+        Args:
+            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+        """
+        if not unlocked_groups:
+            # lock full model
+            for param in self.trunk.parameters():
+                param.requires_grad = False
+            if freeze_bn_stats:
+                freeze_batch_norm_2d(self.trunk)
+        else:
+            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+            try:
+                # FIXME import here until API stable and in an official release
+                from timm.models.helpers import group_parameters, group_modules
+            except ImportError:
+                raise RuntimeError(
+                    "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
+                )
+            matcher = self.trunk.group_matcher()
+            gparams = group_parameters(self.trunk, matcher)
+            max_layer_id = max(gparams.keys())
+            max_layer_id = max_layer_id - unlocked_groups
+            for group_idx in range(max_layer_id + 1):
+                group = gparams[group_idx]
+                for param in group:
+                    self.trunk.get_parameter(param).requires_grad = False
+            if freeze_bn_stats:
+                gmodules = group_modules(self.trunk, matcher, reverse=True)
+                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+                freeze_batch_norm_2d(self.trunk, gmodules)
+
+    def forward(self, x):
+        x = self.trunk(x)
+        x = self.head(x)
+        return x
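+
+
+# --- Editor's note: illustrative usage sketch, not part of the original module. ---
+# Requires `pip install timm`; "resnet18" and the 512-dim projection are example
+# values only. Shows the adapter producing CLIP-style image embeddings.
+if __name__ == "__main__":
+    import torch
+
+    model = TimmModel("resnet18", embed_dim=512, image_size=224, pool="avg", proj="linear")
+    model.lock()  # freeze the trunk; only the projection head stays trainable
+    with torch.no_grad():
+        out = model(torch.randn(2, 3, 224, 224))
+    print(out.shape)  # torch.Size([2, 512])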
diff --git a/audioldm2/clap/open_clip/tokenizer.py b/audioldm2/clap/open_clip/tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5
--- /dev/null
+++ b/audioldm2/clap/open_clip/tokenizer.py
@@ -0,0 +1,197 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
+    )
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
+    which also avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r"\s+", " ", text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+        merges = merges[1 : 49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+        for merge in merges:
+            vocab.append("".join(merge))
+        if not special_tokens:
+            special_tokens = ["<start_of_text>", "<end_of_text>"]
+        else:
+            special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
+        vocab.extend(special_tokens)
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {t: t for t in special_tokens}
+        special = "|".join(special_tokens)
+        self.pat = re.compile(
+            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE,
+        )
+
+        self.vocab_size = len(self.encoder)
+        self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(
+                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+            )
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = "".join([self.decoder[token] for token in tokens])
+        text = (
+            bytearray([self.byte_decoder[c] for c in text])
+            .decode("utf-8", errors="replace")
+            .replace("</w>", " ")
+        )
+        return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def tokenize(
+    texts: Union[str, List[str]], context_length: int = 77
+) -> torch.LongTensor:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    sot_token = _tokenizer.encoder["<start_of_text>"]
+    eot_token = _tokenizer.encoder["<end_of_text>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    for i, tokens in enumerate(all_tokens):
+        if len(tokens) > context_length:
+            tokens = tokens[:context_length]  # Truncate
+        result[i, : len(tokens)] = torch.tensor(tokens)
+
+    return result
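+
+
+# --- Editor's note: illustrative usage sketch, not part of the original module. ---
+# Assumes bpe_simple_vocab_16e6.txt.gz sits next to this module (it is already
+# required at import time by the module-level SimpleTokenizer above).
+if __name__ == "__main__":
+    tokens = tokenize(["a dog barking in the rain", "acoustic guitar"])
+    print(tokens.shape)  # torch.Size([2, 77])
+    first = [t for t in tokens[0].tolist() if t != 0]
+    print(_tokenizer.decode(first))  # "<start_of_text>a dog barking in the rain <end_of_text>"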
diff --git a/audioldm2/clap/open_clip/transform.py b/audioldm2/clap/open_clip/transform.py
new file mode 100755
index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d
--- /dev/null
+++ b/audioldm2/clap/open_clip/transform.py
@@ -0,0 +1,45 @@
+from torchvision.transforms import (
+    Normalize,
+    Compose,
+    RandomResizedCrop,
+    InterpolationMode,
+    ToTensor,
+    Resize,
+    CenterCrop,
+)
+
+
+def _convert_to_rgb(image):
+    return image.convert("RGB")
+
+
+def image_transform(
+    image_size: int,
+    is_train: bool,
+    mean=(0.48145466, 0.4578275, 0.40821073),
+    std=(0.26862954, 0.26130258, 0.27577711),
+):
+    normalize = Normalize(mean=mean, std=std)
+    if is_train:
+        return Compose(
+            [
+                RandomResizedCrop(
+                    image_size,
+                    scale=(0.9, 1.0),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                _convert_to_rgb,
+                ToTensor(),
+                normalize,
+            ]
+        )
+    else:
+        return Compose(
+            [
+                Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+                CenterCrop(image_size),
+                _convert_to_rgb,
+                ToTensor(),
+                normalize,
+            ]
+        )
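+
+
+# --- Editor's note: illustrative usage sketch, not part of the original module. ---
+# Runs the eval-time pipeline on a synthetic PIL image; 224 is an example
+# resolution and the defaults above are the standard CLIP normalization stats.
+if __name__ == "__main__":
+    import numpy as np
+    from PIL import Image
+
+    preprocess = image_transform(224, is_train=False)
+    dummy = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))
+    tensor = preprocess(dummy)
+    print(tensor.shape, tensor.dtype)  # torch.Size([3, 224, 224]) torch.float32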
diff --git a/audioldm2/clap/open_clip/utils.py b/audioldm2/clap/open_clip/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..77875569ff4aff81bf9545ce6ec58e0326d49d0c
--- /dev/null
+++ b/audioldm2/clap/open_clip/utils.py
@@ -0,0 +1,356 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+import logging
+import h5py
+from tqdm import tqdm
+import random
+import json
+import os
+import pathlib
+
+# TODO: (yusong) this is not a good place to store this information and it does not scale. Needs to be fixed later.
+dataset_split = {
+    "audiocaps": ["train", "valid", "test"],
+    "audioset": ["balanced_train", "unbalanced_train", "eval"],
+    "BBCSoundEffects": ["train", "test"],
+    "Clotho": ["train", "test", "valid"],
+    "free_to_use_sounds": ["train", "test"],
+    "paramount_motion": ["train", "test"],
+    "sonniss_game_effects": ["train", "test"],
+    "wesoundeffects": ["train", "test"],
+    "MACS": ["train", "test"],
+    "freesound": ["train", "test"],
+    "FSD50K": ["train", "test", "valid"],
+    "fsd50k_class_label": ["train", "test", "valid"],
+    "esc50": ["train", "test"],
+    "audiostock": ["train", "test"],
+    "freesound_no_overlap_noesc50": ["train", "test"],
+    "epidemic_sound_effects": ["train", "test"],
+    "VGGSound": ["train", "test"],
+    "urbansound8k_class_label": ["train", "test"],
+    "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
+    "epidemic_sound_effects_t5": ["train", "test"],
+    "WavText5K": ["train", "test"],
+    "esc50_no_overlap": ["train", "test"],
+    "usd8k_no_overlap": ["train", "test"],
+    "fsd50k_200_class_label": ["train", "test", "valid"],
+}
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=""):
+    """
+    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+    returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+    Args:
+        module (torch.nn.Module): Any PyTorch module.
+        module_match (dict): Dictionary of full module names to freeze (all if empty)
+        name (str): Full module name (prefix)
+
+    Returns:
+        torch.nn.Module: Resulting module
+
+    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+    """
+    res = module
+    is_match = True
+    if module_match:
+        is_match = name in module_match
+    if is_match and isinstance(
+        module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
+    ):
+        res = FrozenBatchNorm2d(module.num_features)
+        res.num_features = module.num_features
+        res.affine = module.affine
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+    else:
+        for child_name, child in module.named_children():
+            full_child_name = ".".join([name, child_name]) if name else child_name
+            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+            if new_child is not child:
+                res.add_module(child_name, new_child)
+    return res
+
+
+def exist(dataset_name, dataset_type):
+    """
+    Check if the dataset split exists.
+    """
+    return dataset_type in dataset_split[dataset_name]
+
+
+def get_tar_path_from_dataset_name(
+    dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
+):
+    """
+    Get tar path from dataset name and type
+    """
+    output = []
+    for n in dataset_names:
+        if full_dataset is not None and n in full_dataset:
+            current_dataset_types = dataset_split[n]
+        else:
+            current_dataset_types = dataset_types
+        for s in current_dataset_types:
+            tmp = []
+            if islocal:
+                sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
+                if not os.path.exists(sizefilepath_):
+                    sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+            else:
+                sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+            if not os.path.exists(sizefilepath_):
+                continue
+            sizes = json.load(open(sizefilepath_, "r"))
+            for k in sizes.keys():
+                if islocal:
+                    tmp.append(f"{dataset_path}/{n}/{s}/{k}")
+                else:
+                    tmp.append(
+                        f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
+                    )
+            if proportion != 1:
+                tmp = random.sample(tmp, int(proportion * len(tmp)))
+            output.append(tmp)
+    return sum(output, [])
+
+
+def get_tar_path_from_txts(txt_path, islocal, proportion=1):
+    """
+    Get tar path from txt path
+    """
+    if isinstance(txt_path, (list, tuple)):
+        return sum(
+            [
+                get_tar_path_from_txts(
+                    txt_path[i], islocal=islocal, proportion=proportion
+                )
+                for i in range(len(txt_path))
+            ],
+            [],
+        )
+    if isinstance(txt_path, str):
+        with open(txt_path) as f:
+            lines = f.readlines()
+        if islocal:
+            lines = [
+                lines[i]
+                .split("\n")[0]
+                .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
+                for i in range(len(lines))
+            ]
+        else:
+            lines = [
+                lines[i].split("\n")[0].replace(".tar", ".tar -")
+                for i in range(len(lines))
+            ]
+        if proportion != 1:
+            print("Sampling tars with proportion of {}".format(proportion))
+            lines = random.sample(lines, int(proportion * len(lines)))
+        return lines
+
+
+def get_mix_lambda(mixup_alpha, batch_size):
+    mixup_lambdas = [
+        np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
+    ]
+    return np.array(mixup_lambdas).astype(np.float32)
+
+
+def do_mixup(x, mixup_lambda):
+    """
+    Args:
+      x: (batch_size , ...)
+      mixup_lambda: (batch_size,)
+    Returns:
+      out: (batch_size, ...)
+    """
+    out = (
+        x.transpose(0, -1) * mixup_lambda
+        + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
+    ).transpose(0, -1)
+    return out
+
+
+def interpolate(x, ratio):
+    """Interpolate data in time domain. This is used to compensate the
+    resolution reduction in downsampling of a CNN.
+
+    Args:
+      x: (batch_size, time_steps, classes_num)
+      ratio: int, ratio to interpolate
+    Returns:
+      upsampled: (batch_size, time_steps * ratio, classes_num)
+    """
+    (batch_size, time_steps, classes_num) = x.shape
+    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
+    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
+    return upsampled
+
+
+def pad_framewise_output(framewise_output, frames_num):
+    """Pad framewise_output to the same length as input frames. The pad value
+    is the same as the value of the last frame.
+    Args:
+      framewise_output: (batch_size, frames_num, classes_num)
+      frames_num: int, number of frames to pad
+    Outputs:
+      output: (batch_size, frames_num, classes_num)
+    """
+    pad = framewise_output[:, -1:, :].repeat(
+        1, frames_num - framewise_output.shape[1], 1
+    )  # tensor used for padding
+
+    output = torch.cat((framewise_output, pad), dim=1)  # (batch_size, frames_num, classes_num)
+    return output
+
+
+def process_ipc(index_path, classes_num, filename):
+    # load data
+    logging.info("Load Data...............")
+    ipc = [[] for _ in range(classes_num)]
+    with h5py.File(index_path, "r") as f:
+        for i in tqdm(range(len(f["target"]))):
+            t_class = np.where(f["target"][i])[0]
+            for t in t_class:
+                ipc[t].append(i)
+    print(ipc)
+    np.save(filename, ipc)
+    logging.info("Load Data Succeeded...............")
+
+
+def save_to_dict(s, o_=None):
+    # Avoid a mutable default argument; callers always pass a dict, but be safe.
+    if o_ is None:
+        o_ = {}
+    sp = s.split(": ")
+    o_.update({sp[0]: float(sp[1])})
+    return o_
+
+
+def get_data_from_log(txt_path):
+    """
+    Output dictionary from out.txt log file
+    """
+    with open(txt_path) as f:
+        lines = f.readlines()
+    val_data = {}
+    train_data = {}
+    train_losses = []
+    train_losses_epoch = []
+    for i in range(len(lines)):
+        if "| INFO |" in lines[i]:
+            if "Eval Epoch" in lines[i]:
+                if "val_loss" in lines[i]:
+                    # float(regex.sub("", lines[310].split("	")[-1]).replace(" ", ""))
+                    line = lines[i].split("Eval Epoch: ")[-1]
+                    num_epoch = int(line.split("	")[0].split(" ")[0])
+                    d = {
+                        line.split("	")[0]
+                        .split(" ")[1]
+                        .replace(":", ""): float(line.split("	")[0].split(" ")[-1])
+                    }
+                    for i in range(1, len(line.split("	"))):
+                        d = save_to_dict(line.split("	")[i], d)
+                    val_data[num_epoch] = d
+            elif "Train Epoch" in lines[i]:
+                num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
+                loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
+                train_losses.append(loss)
+                train_losses_epoch.append(num_epoch)
+    for i in range(len(train_losses)):
+        train_data[i] = {
+            "num_epoch": train_losses_epoch[i],
+            "train_loss": train_losses[i],
+        }
+    return train_data, val_data
+
+
+def save_p(obj, filename):
+    import pickle
+
+    try:
+        from deepdiff import DeepDiff
+    except ImportError:
+        os.system("pip install deepdiff")
+        from deepdiff import DeepDiff
+    with open(filename, "wb") as file:
+        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)  # highest protocol
+    with open(filename, "rb") as file:
+        z = pickle.load(file)
+    assert (
+        DeepDiff(obj, z, ignore_string_case=True) == {}
+    ), "there is something wrong with the saving process"
+    return
+
+
+def load_p(filename):
+    import pickle
+
+    with open(filename, "rb") as file:
+        z = pickle.load(file)
+    return z
+
+
+def save_json(data, name="data.json"):
+    import json
+
+    with open(name, "w") as fp:
+        json.dump(data, fp)
+    return
+
+
+def load_json(name):
+    import json
+
+    with open(name, "r") as fp:
+        data = json.load(fp)
+    return data
+
+
+def load_class_label(path):
+    # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+    # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+    out = None
+    if path is not None:
+        if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
+            out = load_p(path)
+        elif pathlib.Path(path).suffix in [".json", ".txt"]:
+            out = load_json(path)
+        elif pathlib.Path(path).suffix in [".npy", ".npz"]:
+            out = np.load(path)
+        elif pathlib.Path(path).suffix in [".csv"]:
+            import pandas as pd
+
+            out = pd.read_csv(path)
+    return out
+    # if out is None:
+    #     return None
+    # else:
+    #     key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
+    #     val = Array('i', out.values(), lock=False)
+    #     return (key, val)
+
+
+from torch import optim
+
+
+def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
+    if optimizer_name.lower() == "adamw":
+        optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
+    elif optimizer_name.lower() == "sgd":
+        optimizer = optim.SGD(params, lr=lr, momentum=momentum)
+    elif optimizer_name.lower() == "adam":
+        optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
+    else:
+        raise ValueError(f"Unsupported optimizer name: {optimizer_name}")
+    return optimizer
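+
+
+# --- Editor's note: illustrative usage sketch, not part of the original module. ---
+# A toy model is frozen with freeze_batch_norm_2d, a dummy batch is mixed with
+# get_mix_lambda/do_mixup, and an optimizer is built with get_optimizer; all
+# hyper-parameter values below are arbitrary examples.
+if __name__ == "__main__":
+    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
+    freeze_batch_norm_2d(toy)
+    print(type(toy[1]).__name__)  # FrozenBatchNorm2d
+
+    x = torch.randn(4, 16000)  # batch of 4 dummy waveforms
+    lam = torch.from_numpy(get_mix_lambda(mixup_alpha=0.5, batch_size=4))
+    print(do_mixup(x, lam).shape)  # torch.Size([4, 16000])
+
+    opt = get_optimizer(
+        toy.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, momentum=0.9,
+        optimizer_name="adamw",
+    )
+    print(type(opt).__name__)  # AdamW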
diff --git a/audioldm2/clap/training/__init__.py b/audioldm2/clap/training/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/clap/training/audioset_textmap.npy b/audioldm2/clap/training/audioset_textmap.npy
new file mode 100755
index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b
--- /dev/null
+++ b/audioldm2/clap/training/audioset_textmap.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
+size 84448
diff --git a/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/audioldm2/clap/training/data.py b/audioldm2/clap/training/data.py
new file mode 100755
index 0000000000000000000000000000000000000000..ae01406c63a9b1c678151f67dacd7ea192cb84f2
--- /dev/null
+++ b/audioldm2/clap/training/data.py
@@ -0,0 +1,865 @@
+import json
+import logging
+import os
+import random
+import h5py
+from dataclasses import dataclass
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.datasets as datasets
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
+from torch.utils.data.distributed import DistributedSampler
+import soundfile as sf
+import io
+from pathlib import Path
+# import wget
+
+from audioldm2.clap.open_clip.utils import get_tar_path_from_dataset_name
+from audioldm2.clap.open_clip.utils import load_class_label
+
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
+try:
+    import torchaudio
+except ImportError:
+    torchaudio = None
+
+from audioldm2.clap.open_clip import tokenize as clip_tokenize  # noqa: F401  (CLIP BPE tokenizer, unused below)
+
+from transformers import RobertaTokenizer
+
+# NOTE: `tokenize` is the RoBERTa tokenizer used for the text branch throughout
+# this file; the CLIP BPE tokenizer imported above is kept only for reference.
+tokenize = RobertaTokenizer.from_pretrained("roberta-base")
+
+
+def tokenizer(text):
+    result = tokenize(
+        text,
+        padding="max_length",
+        truncation=True,
+        max_length=77,
+        return_tensors="pt",
+    )
+    return {k: v.squeeze(0) for k, v in result.items()}
+
+
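+# --- Editor's note: illustrative sketch, not part of the original module. ---
+# The RoBERTa tokenizer above pads/truncates every caption to 77 tokens; this
+# small demo just prints the shapes of the tensors it returns.
+def _demo_tokenizer():
+    out = tokenizer("a dog barking in the distance")
+    print(sorted(out.keys()))      # ['attention_mask', 'input_ids']
+    print(out["input_ids"].shape)  # torch.Size([77])
+
+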
+# initialize the audioset text label map
+_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
+_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
+
+
+def int16_to_float32(x):
+    return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+    x = np.clip(x, a_min=-1.0, a_max=1.0)
+    return (x * 32767.0).astype(np.int16)
+
+
+# For Toy Dataset
+class ToyDataset(Dataset):
+    def __init__(self, index_path, ipc, config, eval_mode=False):
+        """Toy Dataset for testing the audioset input with text labels
+        Parameters
+        ----------
+            index_path: str
+                the path to the h5 file of each audio
+            ipc: str
+                the path to the npy file holding the sample indices of each class
+            config: dict
+                the audio cfg file
+            eval_mode: bool
+                whether the dataset is used as a testing dataset
+        """
+        self.audio_cfg = config["audio_cfg"]
+        self.text_cfg = config["text_cfg"]
+        self.fp = h5py.File(index_path, "r")
+        self.ipc = np.load(ipc, allow_pickle=True)
+        self.total_size = len(self.fp["audio_name"])
+        self.classes_num = self.audio_cfg["class_num"]
+        self.eval_mode = eval_mode
+
+        if not eval_mode:
+            self.generate_queue()
+        else:
+            self.queue = []
+            for i in range(self.total_size):
+                target = self.fp["target"][i]
+                if np.sum(target) > 0:
+                    self.queue.append(i)
+            self.total_size = len(self.queue)
+        logging.info("total dataset size: %d" % (self.total_size))
+        logging.info("class num: %d" % (self.classes_num))
+
+    def time_shifting(self, x):
+        frame_num = len(x)
+        shift_len = random.randint(0, frame_num - 1)
+        new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
+        return new_sample
+
+    def generate_queue(self):
+        self.queue = []
+        while len(self.queue) < self.total_size:
+            class_set = [*range(self.classes_num)]
+            random.shuffle(class_set)
+            self.queue += [
+                self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
+            ]
+        self.queue = self.queue[: self.total_size]
+
+        logging.info("queue regenerated:%s" % (self.queue[-5:]))
+
+    def crop_wav(self, x):
+        crop_size = self.audio_cfg["crop_size"]
+        crop_pos = random.randint(0, len(x) - crop_size - 1)
+        return x[crop_pos : crop_pos + crop_size]
+
+    def prompt_text(self, target):
+        events = _AUDIOSET_MAP[np.where(target > 0)]
+        event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
+        text = tokenize(event_text)[0]
+        return text
+
+    def __getitem__(self, index):
+        """Load waveform, text, and target of an audio clip
+
+        Parameters
+        ----------
+            index: int
+                the index number
+        Return
+        ------
+            output: dict {
+                "hdf5_path": str,
+                "index_in_hdf5": int,
+                "audio_name": str,
+                "waveform": list (audio_length,),
+                "target": list (class_num, ),
+                "text": torch.tensor (context_length,)
+            }
+                the output dictionary
+        """
+        s_index = self.queue[index]
+
+        audio_name = self.fp["audio_name"][s_index].decode()
+        # Hardcode here CHANGE
+        hdf5_path = (
+            self.fp["hdf5_path"][s_index]
+            .decode()
+            .replace(
+                "../workspace",
+                "/home/la/kechen/Research/ke_zsasp/workspace",
+            )
+        )
+        r_idx = self.fp["index_in_hdf5"][s_index]
+        target = self.fp["target"][s_index].astype(np.float32)
+        text = self.prompt_text(target)
+        with h5py.File(hdf5_path, "r") as f:
+            waveform = int16_to_float32(f["waveform"][r_idx])[
+                : self.audio_cfg["clip_samples"]
+            ]
+        assert (
+            len(waveform) == self.audio_cfg["clip_samples"]
+        ), "The sample length does not match"
+        # Time shift
+        # if (self.config.enable_time_shift) and (not self.eval_mode):
+        #     waveform = self.time_shifting(waveform)
+        # # Label Enhance
+        # if (self.config.crop_size is not None) and (not self.eval_mode):
+        #     waveform = self.crop_wav(waveform)
+        # # the label enhance rate is fixed 0.5
+        # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
+        #     kidx = np.where(target)[0]
+        #     for k in kidx:
+        #         for add_key in self.class_map[k][1]:
+        #             target[add_key] = 1.0
+        #         if len(self.class_map[k][2]) > 0:
+        #             add_key = random.choice(self.class_map[k][2])
+        #             target[add_key] = 1.0
+
+        # missing the text input
+        mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
+        mel_spec = (
+            torch.cat(
+                [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
+            )
+            .cpu()
+            .numpy()
+        )
+        longer = random.choice([True, False])
+        if not longer:
+            mel_spec[1:, :, :] = 0.0
+        data_dict = {
+            "hdf5_path": hdf5_path,
+            "index_in_hdf5": r_idx,
+            "audio_name": audio_name,
+            "waveform": waveform,
+            "class_label": target,
+            "text": text,
+            "longer": longer,
+            "mel_fusion": mel_spec,
+        }
+        return data_dict
+
+    def __len__(self):
+        return self.total_size
+
+
+class CsvDataset(Dataset):
+    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
+        logging.debug(f"Loading csv data from {input_filename}.")
+        df = pd.read_csv(input_filename, sep=sep)
+
+        self.images = df[img_key].tolist()
+        self.captions = df[caption_key].tolist()
+        self.transforms = transforms
+        logging.debug("Done loading data.")
+
+    def __len__(self):
+        return len(self.captions)
+
+    def __getitem__(self, idx):
+        images = self.transforms(Image.open(str(self.images[idx])))
+        texts = tokenize([str(self.captions[idx])])[0]
+        return images, texts
+
+
+@dataclass
+class DataInfo:
+    dataloader: DataLoader
+    sampler: DistributedSampler
+
+
+def preprocess_txt(text):
+    return tokenize([str(text)])[0]
+
+
+# def get_dataset_size(shards, sizefilepath_=None, is_local=True):
+#     if isinstance(shards, list):
+#         size_list = []
+#         for s in shards:
+#             size_list.append(
+#                 get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
+#             )
+#     else:
+#         if not is_local:
+#             for n in dataset_split.keys():
+#                 if n in shards.split("/"):
+#                     break
+#             for s in dataset_split[n]:
+#                 if s in shards.split("/"):
+#                     break
+#             sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+#         shards_list = list(braceexpand.braceexpand(shards))
+#         dir_path = os.path.dirname(shards)
+#         if sizefilepath_ is not None:
+#             sizes = json.load(open(sizefilepath_, "r"))
+#             total_size = sum(
+#                 [
+#                     int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
+#                     for shard in shards_list
+#                 ]
+#             )
+#         else:
+#             sizes_filename = os.path.join(dir_path, "sizes.json")
+#             len_filename = os.path.join(dir_path, "__len__")
+#             if os.path.exists(sizes_filename):
+#                 sizes = json.load(open(sizes_filename, "r"))
+#                 total_size = sum(
+#                     [int(sizes[os.path.basename(shard)]) for shard in shards_list]
+#                 )
+#             elif os.path.exists(len_filename):
+#                 # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+#                 total_size = ast.literal_eval(open(len_filename, "r").read())
+#             else:
+#                 raise Exception(
+#                     "Cannot find sizes file for dataset. Please specify the path to the file."
+#                 )
+#                 # total_size = None  # num samples undefined
+#                 # some common dataset sizes (at time of authors last download)
+#                 # cc3m-train: 2905954
+#                 # cc12m: 10968539
+#                 # LAION-400m: 407332084
+#         num_shards = len(shards_list)
+#     if isinstance(shards, list):
+#         return sum(size_list), len(shards)
+#     else:
+#         return total_size, num_shards
+
+
+def get_imagenet(args, preprocess_fns, split):
+    assert split in ["train", "val", "v2"]
+    is_train = split == "train"
+    preprocess_train, preprocess_val = preprocess_fns
+
+    if split == "v2":
+        from imagenetv2_pytorch import ImageNetV2Dataset
+
+        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
+    else:
+        if is_train:
+            data_path = args.imagenet_train
+            preprocess_fn = preprocess_train
+        else:
+            data_path = args.imagenet_val
+            preprocess_fn = preprocess_val
+        assert data_path
+
+        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
+
+    if is_train:
+        idxs = np.zeros(len(dataset.targets))
+        target_array = np.array(dataset.targets)
+        k = 50
+        for c in range(1000):
+            m = target_array == c
+            n = len(idxs[m])
+            arr = np.zeros(n)
+            arr[:k] = 1
+            np.random.shuffle(arr)
+            idxs[m] = arr
+
+        idxs = idxs.astype("int")
+        sampler = SubsetRandomSampler(np.where(idxs)[0])
+    else:
+        sampler = None
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        num_workers=args.workers,
+        sampler=sampler,
+    )
+
+    return DataInfo(dataloader, sampler)
+
+
+def count_samples(dataloader):
+    os.environ["WDS_EPOCH"] = "0"
+    n_elements, n_batches = 0, 0
+    for images, texts in dataloader:
+        n_batches += 1
+        n_elements += len(images)
+        assert len(images) == len(texts)
+    return n_elements, n_batches
+
+
+def filter_no_caption(sample):
+    return "txt" in sample
+
+
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+    return True
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+# def sample_prop(sizefile, inputs, proportion, is_local=True):
+#     """
+#     Sample a proportion of the data.
+#     """
+#     file_path_dict = {
+#         os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
+#         for i in range(len(inputs))
+#     }
+#     sampled_filepath_dict = {}
+#     sampled_size_dict = {}
+#     if not is_local:
+#         if os.path.exists("sizes.json"):
+#             os.remove("sizes.json")
+#         wget.download(sizefile, "sizes.json")
+#         sizefile = "sizes.json"
+#     with open(sizefile, "r", encoding="UTF-8") as f:
+#         load_dict = json.load(f)
+#     L = int(len(file_path_dict) * proportion)
+#     subkeys = random.sample(file_path_dict.keys(), L)
+#     for k in subkeys:
+#         sampled_size_dict[k] = load_dict[k]
+#         sampled_filepath_dict[k] = file_path_dict[k]
+#     return (
+#         sum(sampled_size_dict.values()),
+#         L,
+#         [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
+#         sampled_size_dict,
+#     )
+
+
+def get_mel(audio_data, audio_cfg):
+    # mel shape: (n_mels, T)
+    mel = torchaudio.transforms.MelSpectrogram(
+        sample_rate=audio_cfg["sample_rate"],
+        n_fft=audio_cfg["window_size"],
+        win_length=audio_cfg["window_size"],
+        hop_length=audio_cfg["hop_size"],
+        center=True,
+        pad_mode="reflect",
+        power=2.0,
+        norm=None,
+        onesided=True,
+        n_mels=64,
+        f_min=audio_cfg["fmin"],
+        f_max=audio_cfg["fmax"],
+    ).to(audio_data.device)
+    mel = mel(audio_data)
+    # we use log mel spectrogram as input
+    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
+    return mel.T  # (T, n_mels)
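+
+
+# --- Editor's note: illustrative sketch, not part of the original module. ---
+# Demonstrates get_mel on one second of silence (requires torchaudio). The cfg
+# keys mirror the audio_cfg dict used throughout this file; the values are examples.
+def _demo_get_mel():
+    cfg = {
+        "sample_rate": 48000,
+        "window_size": 1024,
+        "hop_size": 480,
+        "fmin": 50,
+        "fmax": 14000,
+    }
+    wav = torch.zeros(cfg["sample_rate"])  # 1 second of silence
+    mel = get_mel(wav, cfg)
+    print(mel.shape)  # torch.Size([101, 64]) -> (T, n_mels), T = sample_rate // hop_size + 1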
+
+
+def get_audio_features(
+    audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
+):
+    """
+    Calculate audio features and return them as a sample dict.
+    audio_data: a tensor of shape (T) containing audio data.
+    mel: the (log-)mel spectrogram of audio_data, as returned by get_mel.
+    max_len: the maximum length of audio data.
+    data_truncating: the method of truncating data.
+    data_filling: the method of filling data.
+    audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
+    """
+    sample = {}
+
+    # assert audio_data.size(-1) <= max_len, str(audio_data.size())
+
+    # split to three parts
+    chunk_frames = (
+        max_len // audio_cfg["hop_size"] + 1
+    )  # the +1 related to how the spectrogram is computed
+    mel = mel[:chunk_frames]
+
+    audio_data = audio_data[..., :max_len]
+    sample["mel_fusion"] = mel
+    longer = torch.tensor([True])
+
+    sample["longer"] = longer
+    sample["waveform"] = audio_data
+
+    return sample
+
+
+def preprocess(
+    sample,
+    audio_ext,
+    text_ext,
+    max_len,
+    audio_cfg,
+    class_index_dict=None,
+    data_filling="pad",
+    data_truncating="rand_trunc",
+    text_augment_selection=None,
+):
+    """
+    Preprocess a single sample for wdsdataloader.
+    """
+    audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+    audio_data = int16_to_float32(float32_to_int16(audio_data))
+    audio_data = torch.tensor(audio_data).float()
+
+    # TODO: (yusong) to be include in the future
+    # # if torchaudio not installed, use soundfile to load audio
+    # if torchaudio is None:
+    #     audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+    #     audio_data = torch.tensor(audio_data).float()
+    # else:
+    #     # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
+    #     with tempfile.TemporaryDirectory() as dirname:
+    #         os.makedirs(dirname, exist_ok=True)
+    #         fname = os.path.join(dirname, f"file.flac")
+    #         with open(fname, "wb") as stream:
+    #             stream.write(sample[audio_ext])
+    #         audio_data, orig_sr = torchaudio.load(fname)
+    #         audio_data = audio_data[0, :].float()
+
+    # Compute the mel spectrogram first and merge the audio features into the
+    # existing sample dict so the __key__/__url__/text fields are preserved.
+    mel = get_mel(audio_data, audio_cfg)
+    sample.update(
+        get_audio_features(
+            audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
+        )
+    )
+    del sample[audio_ext]
+
+    try:
+        json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
+    except Exception:
+        # Without a parsed caption the sample cannot be used; log the shard and re-raise.
+        print("sample[__url__]:", sample["__url__"])
+        raise
+
+    # For selecting augmented text from dataset
+    if text_augment_selection is None or text_augment_selection == "none":
+        texts = json_dict_raw["text"]
+    elif text_augment_selection == "all":
+        if "text_augment_all" in json_dict_raw.keys():
+            texts = json_dict_raw["text_augment_all"]
+        else:
+            texts = json_dict_raw["text"]
+    elif text_augment_selection == "augment_only":
+        if "text_augment_all" in json_dict_raw.keys():
+            if json_dict_raw["text_augment_t5"] is None:
+                texts = json_dict_raw["text"]
+            else:
+                texts = json_dict_raw["text_augment_t5"]
+        else:
+            texts = json_dict_raw["text"]
+    else:
+        raise NotImplementedError(
+            f"text_augment_selection {text_augment_selection} not implemented"
+        )
+    sample["full_text"] = texts
+
+    if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
+        texts = random.choice(texts)
+    sample["raw_text"] = texts
+    sample["text"] = tokenizer(texts)  # text shape: [num_token]
+    if class_index_dict is not None:
+        # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+        # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+        # key, val = class_index_dict
+        # key = key[:].split('\n')
+        # _dict = {k: v for k, v in zip(key, val)}
+        sample["class_label"] = np.zeros(len(class_index_dict.keys()))
+        for x in json_dict_raw["tag"]:
+            sample["class_label"][class_index_dict[x]] = 1
+        sample["class_label"] = torch.tensor(sample["class_label"]).float()
+    del sample[text_ext]
+    sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
+    sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
+    sample["audio_orig_sr"] = orig_sr
+    return sample
+
+
+def collate_fn(batch):
+    """
+    Collate function for wdsdataloader.
+    batch: a list of dict, each dict is a sample
+    """
+    # Merge values across the batch: nested dicts (tokenizer output) are vstacked
+    # per key, tensors are stacked, numpy arrays are converted and stacked, and
+    # everything else is collected into a plain list.
+    batch_dict = {}
+    for k in batch[0].keys():
+        if isinstance(batch[0][k], dict):  # deal with the BERT/RoBERTa tokenizer output
+            batch_dict[k] = {}
+            for kk in batch[0][k].keys():
+                tmp = []
+                for i in range(len(batch)):
+                    tmp.append(batch[i][k][kk])
+                batch_dict[k][kk] = torch.vstack(tmp)
+        elif isinstance(batch[0][k], torch.Tensor):
+            batch_dict[k] = torch.stack([sample[k] for sample in batch])
+        elif isinstance(batch[0][k], np.ndarray):
+            batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
+        else:
+            batch_dict[k] = [sample[k] for sample in batch]
+    return batch_dict
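+
+
+# --- Editor's note: illustrative sketch, not part of the original module. ---
+# Shows the three branches of collate_fn (nested dict, tensor, plain value) on
+# two hand-made samples; the field names are examples, not the real pipeline keys.
+def _demo_collate_fn():
+    samples = [
+        {
+            "text": {"input_ids": torch.ones(77, dtype=torch.long)},
+            "waveform": torch.zeros(480000),
+            "audio_name": f"clip_{i}.flac",
+        }
+        for i in range(2)
+    ]
+    batch = collate_fn(samples)
+    print(batch["text"]["input_ids"].shape)  # torch.Size([2, 77])
+    print(batch["waveform"].shape)           # torch.Size([2, 480000])
+    print(batch["audio_name"])               # ['clip_0.flac', 'clip_1.flac']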
+
+
+# def get_wds_dataset(
+#     args,
+#     model_cfg,
+#     is_train,
+#     audio_ext="flac",
+#     text_ext="json",
+#     max_len=480000,
+#     proportion=1.0,
+#     sizefilepath_=None,
+#     is_local=None,
+# ):
+#     """
+#     Get a dataset for wdsdataloader.
+#     """
+#     if is_local is None and (not args.remotedata is None):
+#         is_local = not args.remotedata
+
+#     input_shards = args.train_data if is_train else args.val_data
+#     assert input_shards is not None
+
+#     if not sizefilepath_ is None:
+#         sizefilepath = sizefilepath_
+#     else:
+#         sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
+
+#     if proportion != 1.0:
+#         num_samples, num_shards, input_shards, _ = sample_prop(
+#             sizefilepath, input_shards, proportion, is_local=is_local
+#         )
+#     else:
+#         num_samples, num_shards = get_dataset_size(
+#             input_shards, sizefilepath_=sizefilepath_, is_local=is_local
+#         )
+
+#     if not num_samples:
+#         if is_train:
+#             num_samples = args.train_num_samples
+#             if not num_samples:
+#                 raise RuntimeError(
+#                     "Currently, number of dataset samples must be specified for training dataset. "
+#                     "Please specify via `--train-num-samples` if no dataset length info present."
+#                 )
+#         else:
+#             num_samples = (
+#                 args.val_num_samples or 0
+#             )  # eval will just exhaust the iterator if not specified
+
+#     pipeline = [wds.SimpleShardList(input_shards)]
+#     # at this point we have an iterator over all the shards
+#     # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
+#     if is_train or args.parallel_eval:
+#         pipeline.extend(
+#             [
+#                 wds.detshuffle(
+#                     bufsize=_SHARD_SHUFFLE_SIZE,
+#                     initial=_SHARD_SHUFFLE_INITIAL,
+#                     seed=args.seed,
+#                 ),
+#                 wds.split_by_node,
+#                 wds.split_by_worker,
+#                 # at this point, we have an iterator over the shards assigned to each worker at each node
+#                 wds.tarfile_to_samples(handler=log_and_continue),
+#                 wds.shuffle(
+#                     bufsize=_SAMPLE_SHUFFLE_SIZE,
+#                     initial=_SAMPLE_SHUFFLE_INITIAL,
+#                     rng=random.Random(args.seed),
+#                 ),
+#                 # wds.repeatedly,  # FIXME determine if this is beneficial
+#             ]
+#         )
+#     else:
+#         pipeline.extend(
+#             [
+#                 wds.split_by_worker,
+#                 # at this point, we have an iterator over the shards assigned to each worker
+#                 wds.tarfile_to_samples(handler=log_and_continue),
+#             ]
+#         )
+#     pipeline.append(
+#         wds.map(
+#             partial(
+#                 preprocess,
+#                 audio_ext=audio_ext,
+#                 text_ext=text_ext,
+#                 max_len=max_len,
+#                 audio_cfg=model_cfg["audio_cfg"],
+#                 class_index_dict=copy.deepcopy(args.class_index_dict),
+#                 data_filling=args.data_filling,
+#                 data_truncating=args.data_truncating,
+#                 text_augment_selection=args.text_augment_selection,
+#             )
+#         ),
+#     )
+
+#     pipeline.append(
+#         wds.batched(
+#             args.batch_size,
+#             partial=not (is_train or args.parallel_eval),
+#             collation_fn=collate_fn,
+#         )
+#     )
+
+#     dataset = wds.DataPipeline(*pipeline)
+#     if is_train or args.parallel_eval:
+#         # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
+#         # (yusong): See comments below.
+#         # roll over and repeat a few samples to get same number of full batches on each node
+#         global_batch_size = args.batch_size * args.world_size
+#         num_batches = math.ceil(num_samples / global_batch_size)
+#         num_workers = max(1, args.workers)
+#         num_worker_batches = math.ceil(
+#             num_batches / num_workers
+#         )  # per dataloader worker
+#         num_batches = num_worker_batches * num_workers
+#         num_samples = num_batches * global_batch_size
+#         dataset = dataset.with_epoch(
+#             num_worker_batches
+#         )  # each worker is iterating over this
+#     else:
+#         # last batches are partial, eval is done on single (master) node
+#         num_batches = math.ceil(num_samples / args.batch_size)
+
+#     kwargs = {}
+#     if args.horovod:  # multi-node training on summit
+#         kwargs["multiprocessing_context"] = "forkserver"
+
+#     dataloader = wds.WebLoader(
+#         dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
+#     )
+
+#     # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+#     # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+#     # if is_train:
+#     #     # roll over and repeat a few samples to get same number of full batches on each node
+#     #     global_batch_size = args.batch_size * args.world_size
+#     #     num_batches = math.ceil(num_samples / global_batch_size)
+#     #     num_workers = max(1, args.workers)
+#     #     num_batches = math.ceil(num_batches / num_workers) * num_workers
+#     #     num_samples = num_batches * global_batch_size
+#     #     dataloader = dataloader.with_epoch(num_batches)
+#     # else:
+#     #     # last batches are partial, eval is done on single (master) node
+#     #     num_batches = math.ceil(num_samples / args.batch_size)
+
+#     # add meta-data to dataloader instance for convenience
+#     dataloader.num_batches = num_batches
+#     dataloader.num_samples = num_samples
+
+#     return DataInfo(dataloader, None)
+
+
+def wds_batch_list2dict(
+    batch,
+    keys=[
+        "__url__",
+        "__key__",
+        "waveform",
+        "text",
+        "raw_text",
+        "audio_name",
+        "text_name",
+        "audio_orig_sr",
+    ],
+):
+    """
+    Return a dictionary of the batch, with keys as the names of the fields.
+    """
+    assert len(keys) == len(
+        batch
+    ), "batch must have same number of keys as keys argument"
+    return {keys[i]: batch[i] for i in range(len(batch))}
+
+
+def get_csv_dataset(args, preprocess_fn, is_train):
+    input_filename = args.train_data if is_train else args.val_data
+    assert input_filename
+    dataset = CsvDataset(
+        input_filename,
+        preprocess_fn,
+        img_key=args.csv_img_key,
+        caption_key=args.csv_caption_key,
+        sep=args.csv_separator,
+    )
+    num_samples = len(dataset)
+    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+    shuffle = is_train and sampler is None
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=shuffle,
+        num_workers=args.workers,
+        pin_memory=True,
+        sampler=sampler,
+        drop_last=is_train,
+    )
+    dataloader.num_samples = num_samples
+    dataloader.num_batches = len(dataloader)
+
+    return DataInfo(dataloader, sampler)
+
+
+def get_toy_dataset(args, model_cfg, is_train):
+    index_path = args.train_data if is_train else args.val_data
+    ipc_path = args.train_ipc if is_train else args.val_ipc
+    assert index_path and ipc_path
+    eval_mode = not is_train
+    dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
+
+    num_samples = len(dataset)
+    sampler = (
+        DistributedSampler(dataset, shuffle=False)
+        if args.distributed and is_train
+        else None
+    )
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+        sampler=sampler,
+        drop_last=is_train,
+    )
+    dataloader.num_samples = num_samples
+    dataloader.num_batches = len(dataloader)
+
+    return DataInfo(dataloader, sampler)
+
+
+def get_dataset_fn(data_path, dataset_type):
+    if dataset_type == "webdataset":
+        return get_wds_dataset
+    elif dataset_type == "csv":
+        return get_csv_dataset
+    elif dataset_type == "auto":
+        ext = data_path.split(".")[-1]
+        if ext in ["csv", "tsv"]:
+            return get_csv_dataset
+        elif ext in ["tar"]:
+            return get_wds_dataset
+        else:
+            raise ValueError(
+                f"Tried to figure out dataset type, but failed for extention {ext}."
+            )
+    elif dataset_type == "toy":
+        return get_toy_dataset
+    else:
+        raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, model_cfg):
+    data = {}
+
+    args.class_index_dict = load_class_label(args.class_label_path)
+
+    if args.datasetinfos is None:
+        args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+    if args.dataset_type == "webdataset":
+        args.train_data = get_tar_path_from_dataset_name(
+            args.datasetnames,
+            args.datasetinfos,
+            islocal=not args.remotedata,
+            proportion=args.dataset_proportion,
+            dataset_path=args.datasetpath,
+            full_dataset=args.full_train_dataset,
+        )
+
+        if args.full_train_dataset is None:
+            args.full_train_dataset = []
+        if args.exclude_eval_dataset is None:
+            args.exclude_eval_dataset = []
+        excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
+
+        val_dataset_names = (
+            [n for n in args.datasetnames if n not in excluded_eval_datasets]
+            if excluded_eval_datasets
+            else args.datasetnames
+        )
+        args.val_dataset_names = val_dataset_names
+        args.val_data = get_tar_path_from_dataset_name(
+            val_dataset_names,
+            ["valid", "test", "eval"],
+            islocal=not args.remotedata,
+            proportion=1,
+            dataset_path=args.datasetpath,
+            full_dataset=None,
+        )
+
+    if args.train_data:
+        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+            args, model_cfg, is_train=True
+        )
+
+    if args.val_data:
+        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
+            args, model_cfg, is_train=False
+        )
+
+    return data
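+
+
+# Typical call pattern (a sketch; `args` comes from params.parse_args() and
+# `model_cfg` from the model config loaded by the training script, both of which
+# live outside this file):
+#     data = get_data(args, model_cfg)
+#     train_loader = data["train"].dataloader
+#     val_loader = data["val"].dataloader if "val" in data else None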
diff --git a/audioldm2/clap/training/params.py b/audioldm2/clap/training/params.py
new file mode 100755
index 0000000000000000000000000000000000000000..0cc1a0e2d982e900988cf5a4b24b2e59b093537b
--- /dev/null
+++ b/audioldm2/clap/training/params.py
@@ -0,0 +1,563 @@
+import argparse
+
+
+def get_default_params(model_name):
+    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
+    model_name = model_name.lower()
+    if "vit" in model_name:
+        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
+    else:
+        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--train-data",
+        type=str,
+        default=None,
+        help="Path to h5 filewith training data",
+    )
+    parser.add_argument(
+        "--val-data",
+        type=str,
+        default=None,
+        help="Path to h5 file with validation data",
+    )
+    parser.add_argument(
+        "--freeze-text",
+        default=False,
+        action="store_true",
+        help="if you need to freeze the text encoder, make this True",
+    )
+    parser.add_argument(
+        "--freeze-text-after",
+        type=int,
+        default=-1,
+        help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
+    )
+    parser.add_argument(
+        "--train-ipc",
+        type=str,
+        default=None,
+        help="Path to npy file of the number of instance per class in training data",
+    )
+    parser.add_argument(
+        "--val-ipc",
+        type=str,
+        default=None,
+        help="Path to npy file of the number of instance per class in validation data",
+    )
+    parser.add_argument(
+        "--train-num-samples",
+        type=int,
+        default=None,
+        help="Number of samples in dataset. Required for webdataset if not available in info file.",
+    )
+    parser.add_argument(
+        "--val-num-samples",
+        type=int,
+        default=None,
+        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+    )
+    parser.add_argument(
+        "--dataset-type",
+        choices=["webdataset", "csv", "auto", "toy"],
+        default="auto",
+        help="Which type of dataset to process.",
+    )
+    parser.add_argument(
+        "--csv-separator",
+        type=str,
+        default="\t",
+        help="For csv-like datasets, which separator to use.",
+    )
+    parser.add_argument(
+        "--csv-img-key",
+        type=str,
+        default="filepath",
+        help="For csv-like datasets, the name of the key for the image paths.",
+    )
+    parser.add_argument(
+        "--csv-caption-key",
+        type=str,
+        default="title",
+        help="For csv-like datasets, the name of the key for the captions.",
+    )
+    parser.add_argument(
+        "--imagenet-val",
+        type=str,
+        default=None,
+        help="Path to imagenet val set for conducting zero shot evaluation.",
+    )
+    parser.add_argument(
+        "--imagenet-v2",
+        type=str,
+        default=None,
+        help="Path to imagenet v2 for conducting zero shot evaluation.",
+    )
+    parser.add_argument(
+        "--datasetnames",
+        nargs="+",
+        default=None,
+        help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
+    )
+    parser.add_argument(
+        "--full-train-dataset",
+        nargs="+",
+        default=None,
+        help="Which dataset will be trained with all the subsets. (train+test)",
+    )
+    parser.add_argument(
+        "--exclude-eval-dataset",
+        nargs="+",
+        default=None,
+        help="Which dataset will be excluded with evaluation",
+    )
+    parser.add_argument(
+        "--datasetinfos",
+        nargs="+",
+        default=None,
+        help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
+    )
+    parser.add_argument(
+        "--dataset-proportion",
+        type=float,
+        default=1.0,
+        help="How much proportion of dataset we want to train.",
+    )
+    parser.add_argument(
+        "--remotedata",
+        default=False,
+        action="store_true",
+        help="if the dataset is remote, set this flag",
+    )
+    parser.add_argument(
+        "--class-label-path",
+        type=str,
+        default=None,
+        help="The path of the class label pickle or csv.",
+    )
+    parser.add_argument(
+        "--datasetpath",
+        type=str,
+        default="/mnt/audio_clip/webdataset_tar",
+        help="The path to the dataset",
+    )
+    parser.add_argument(
+        "--logs",
+        type=str,
+        default="./logs/",
+        help="Where to store tensorboard logs. Use None to avoid storing logs.",
+    )
+    parser.add_argument(
+        "--log-local",
+        action="store_true",
+        default=False,
+        help="log files on local master, otherwise global master only.",
+    )
+    parser.add_argument(
+        "--name",
+        type=str,
+        default=None,
+        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
+    )
+    parser.add_argument(
+        "--workers", type=int, default=1, help="Number of workers per GPU."
+    )
+    parser.add_argument(
+        "--batch-size", type=int, default=64, help="Batch size per GPU."
+    )
+    parser.add_argument(
+        "--epochs", type=int, default=32, help="Number of epochs to train for."
+    )
+    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
+    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
+    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
+    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+    parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.")
+    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
+
+    parser.add_argument(
+        "--split-opt",
+        action="store_true",
+        default=False,
+        help="Use this flag to skip the learning rate decay.",
+    )
+    parser.add_argument(
+        "--lr-pretrained", type=float, default=None, help="Learning rate for text."
+    )
+    parser.add_argument(
+        "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
+    )
+    parser.add_argument(
+        "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
+    )
+    parser.add_argument(
+        "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
+    )
+    parser.add_argument(
+        "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
+    )
+    parser.add_argument(
+        "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
+    )
+    parser.add_argument(
+        "--lr-new", type=float, default=None, help="Learning rate for audio."
+    )
+    parser.add_argument(
+        "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
+    )
+    parser.add_argument(
+        "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
+    )
+    parser.add_argument(
+        "--eps-new", type=float, default=None, help="Adam epsilon for audio."
+    )
+    parser.add_argument(
+        "--wd-new", type=float, default=0.2, help="Weight decay for audio."
+    )
+    parser.add_argument(
+        "--momentum-new", type=float, default=0.9, help="Momentum for audio."
+    )
+    parser.add_argument(
+        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
+    )
+    parser.add_argument(
+        "--use-bn-sync",
+        default=False,
+        action="store_true",
+        help="Whether to use batch norm sync.",
+    )
+    parser.add_argument(
+        "--skip-scheduler",
+        action="store_true",
+        default=False,
+        help="Use this flag to skip the learning rate decay.",
+    )
+    parser.add_argument(
+        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
+    )
+    parser.add_argument(
+        "--save-top-performance",
+        type=int,
+        default=0,
+        help="Save the top x performance weights if the value >0",
+    )
+    parser.add_argument(
+        "--save-most-recent",
+        action="store_true",
+        default=False,
+        help="Always save the most recent model trained to epoch_latest.pt.",
+    )
+    parser.add_argument(
+        "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
+    )
+    parser.add_argument(
+        "--val-frequency",
+        type=int,
+        default=1,
+        help="How often to run evaluation with val data.",
+    )
+    parser.add_argument(
+        "--resume",
+        default=None,
+        type=str,
+        help="path to latest checkpoint (default: none)",
+    )
+    parser.add_argument(
+        "--precision",
+        choices=["amp", "fp16", "fp32"],
+        default="amp",
+        help="Floating point precision.",
+    )
+    parser.add_argument(
+        "--amodel",
+        type=str,
+        default="RN50",
+        help="Name of the audio backbone to use.",
+    )
+    parser.add_argument(
+        "--tmodel",
+        type=str,
+        default="transformer",
+        help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
+    )
+    parser.add_argument(
+        "--pretrained-audio",
+        default="",
+        type=str,
+        help="Use a pretrained audio model weights for the audio encoder of CLAP",
+    )
+    parser.add_argument(
+        "--pretrained-text",
+        default="",
+        type=str,
+        help="Use a pretrained text model weights for the text encoder of CLAP",
+    )
+    parser.add_argument(
+        "--pretrained",
+        default="",
+        type=str,
+        help="Use a pretrained CLIP model weights with the specified tag or file path.",
+    )
+    parser.add_argument(
+        "--pretrained-image",
+        default=False,
+        action="store_true",
+        help="Load imagenet pretrained weights for image tower backbone if available.",
+    )
+    parser.add_argument(
+        "--lock-image",
+        default=False,
+        action="store_true",
+        help="Lock full image tower by disabling gradients.",
+    )
+    parser.add_argument(
+        "--lock-image-unlocked-groups",
+        type=int,
+        default=0,
+        help="Leave last n image tower layer groups unlocked.",
+    )
+    parser.add_argument(
+        "--lock-image-freeze-bn-stats",
+        default=False,
+        action="store_true",
+        help="Freeze BatchNorm running stats in image tower for any locked layers.",
+    )
+    parser.add_argument(
+        "--local-loss",
+        default=False,
+        action="store_true",
+        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
+    )
+    parser.add_argument(
+        "--gather-with-grad",
+        default=False,
+        action="store_true",
+        help="enable full distributed gradient for feature gather",
+    )
+    parser.add_argument(
+        "--force-quick-gelu",
+        default=False,
+        action="store_true",
+        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
+    )
+    parser.add_argument(
+        "--torchscript",
+        default=False,
+        action="store_true",
+        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
+    )
+    parser.add_argument(
+        "--trace",
+        default=False,
+        action="store_true",
+        help="torch.jit.trace the model for inference / eval only",
+    )
+    # arguments for distributed training
+    parser.add_argument(
+        "--dist-url",
+        default="env://",
+        type=str,
+        help="url used to set up distributed training",
+    )
+    parser.add_argument(
+        "--dist-backend", default="nccl", type=str, help="distributed backend"
+    )
+    parser.add_argument(
+        "--report-to",
+        default="",
+        type=str,
+        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
+    )
+    parser.add_argument(
+        "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
+    )
+    parser.add_argument(
+        "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
+    )
+    parser.add_argument(
+        "--debug",
+        default=False,
+        action="store_true",
+        help="If true, more information is logged.",
+    )
+    parser.add_argument(
+        "--copy-codebase",
+        default=False,
+        action="store_true",
+        help="If true, we copy the entire base on the log diretory, and execute from there.",
+    )
+    parser.add_argument(
+        "--horovod",
+        default=False,
+        action="store_true",
+        help="Use horovod for distributed training.",
+    )
+    parser.add_argument(
+        "--ddp-static-graph",
+        default=False,
+        action="store_true",
+        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
+    )
+    parser.add_argument(
+        "--no-set-device-rank",
+        default=False,
+        action="store_true",
+        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+    )
+    parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
+
+    parser.add_argument(
+        "--top-k-checkpoint-select-dataset",
+        type=str,
+        default="all",
+        help="The dataset of selecting top-k checkpoint.",
+    )
+
+    # @R10, @R@5, @R1, mAP@10
+    parser.add_argument(
+        "--top-k-checkpoint-select-metric",
+        type=str,
+        default="_R@10",
+        help="The metric for selecting top-k checkpoint.",
+    )
+    parser.add_argument(
+        "--openai-model-cache-dir",
+        type=str,
+        default="~/.cache/clip",
+        help="Directory to download OpenAI models.",
+    )
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        default="adamw",
+        help="can be AdamW or SGD",
+    )
+    parser.add_argument(
+        "--parallel-eval",
+        default=False,
+        action="store_true",
+        help="Eval in parallel (multi-GPU, multi-node).",
+    )
+
+    parser.add_argument(
+        "--no-eval",
+        default=False,
+        action="store_true",
+        help="Training without evaluation.",
+    )
+
+    parser.add_argument(
+        "--lp-mlp",
+        default=False,
+        action="store_true",
+        help="Linear Probe using MLP layer or not.",
+    )
+
+    parser.add_argument(
+        "--lp-freeze",
+        default=False,
+        action="store_true",
+        help="Linear Probe using Freeze CLAP or not",
+    )
+
+    parser.add_argument(
+        "--lp-act",
+        default="None",
+        type=str,
+        help="Options are ['relu','elu','prelu','softmax','sigmoid']",
+    )
+
+    parser.add_argument(
+        "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
+    )
+
+    parser.add_argument(
+        "--lp-metrics",
+        type=str,
+        default="map,mauc,acc",
+        help="Metrics of Linear Probe.",
+    )
+
+    parser.add_argument(
+        "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
+    )
+    parser.add_argument(
+        "--kappa",
+        type=float,
+        default=0,
+        help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
+    )
+
+    parser.add_argument(
+        "--data-filling",
+        type=str,
+        default="pad",
+        help="type of data filling when the audio length is shorter than the max length."
+        "Can be one of the following: repeat, repeatpad, pad",
+    )
+    parser.add_argument(
+        "--data-truncating",
+        type=str,
+        default="rand_trunc",
+        help="type of data truncation when the audio length is longer than the max length."
+        "Can be one of the following: rand_trunc, fusion",
+    )
+
+    parser.add_argument(
+        "--clap-mlploss",
+        default=False,
+        action="store_true",
+        help="Using MLP loss for CLAP model or not",
+    )
+
+    parser.add_argument(
+        "--wandb-id",
+        type=str,
+        default=None,
+        help="the id of wandb experiment to restore.",
+    )
+
+    parser.add_argument(
+        "--sleep", type=float, default=0, help="sleep n seconds before start training"
+    )
+
+    # variable length processing
+    parser.add_argument(
+        "--enable-fusion",
+        default=False,
+        action="store_true",
+        help="Enable feature funsion for variable-length data",
+    )
+
+    parser.add_argument(
+        "--fusion-type",
+        type=str,
+        default="None",
+        help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
+    )
+
+    parser.add_argument(
+        "--mixup",
+        default=False,
+        action="store_true",
+        help="Enable mixup in finetuning training.",
+    )
+    parser.add_argument(
+        "--text-augment-selection",
+        type=str,
+        default=None,
+        help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
+    )
+
+    args = parser.parse_args()
+
+    # If some params are not passed, we use the default values based on model name.
+    default_params = get_default_params(args.amodel)
+    for name, val in default_params.items():
+        if getattr(args, name) is None:
+            setattr(args, name, val)
+
+    return args
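+
+
+# Example (a sketch; sys.argv is patched here purely for illustration):
+#     import sys
+#     sys.argv = ["train", "--dataset-type", "csv", "--train-data", "train.csv",
+#                 "--amodel", "HTSAT-tiny", "--tmodel", "roberta"]
+#     args = parse_args()
+#     assert args.lr == 5.0e-4  # filled in from get_default_params("HTSAT-tiny")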
diff --git a/audioldm2/hifigan/LICENSE b/audioldm2/hifigan/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5afae394d6b37da0e12ba6b290d2512687f421ac
--- /dev/null
+++ b/audioldm2/hifigan/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Jungil Kong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/audioldm2/hifigan/__init__.py b/audioldm2/hifigan/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..34e055557bf2ecb457376663b67390543c71fb1f
--- /dev/null
+++ b/audioldm2/hifigan/__init__.py
@@ -0,0 +1,8 @@
+from .models_v2 import Generator
+from .models import Generator as Generator_old
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
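+
+
+# AttrDict mirrors its keys as attributes (h["num_mels"] == h.num_mels), which is
+# how the HiFi-GAN config dict is consumed by the Generator classes imported above.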
diff --git a/audioldm2/hifigan/models.py b/audioldm2/hifigan/models.py
new file mode 100755
index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36
--- /dev/null
+++ b/audioldm2/hifigan/models.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
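+
+# For instance, kernel_size=3 with dilation=1 gives padding=1 and dilation=3 gives
+# padding=3, i.e. "same" padding: stride-1 convolutions keep the temporal length.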
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.conv_pre = weight_norm(
+            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+        )
+        resblock = ResBlock
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        h.upsample_initial_channel // (2**i),
+                        h.upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        # print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
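+
+
+# Minimal usage sketch (hypothetical config values; the real settings come from the
+# HiFi-GAN config loaded elsewhere and wrapped in audioldm2.hifigan.AttrDict):
+#     h = AttrDict({"num_mels": 64, "upsample_initial_channel": 512,
+#                   "upsample_rates": [8, 8, 2, 2], "upsample_kernel_sizes": [16, 16, 4, 4],
+#                   "resblock_kernel_sizes": [3, 7, 11],
+#                   "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]]})
+#     vocoder = Generator(h)
+#     wav = vocoder(mel)            # mel: (B, num_mels, T) -> wav: (B, 1, T * prod(upsample_rates))
+#     vocoder.remove_weight_norm()  # typically called once before inference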
diff --git a/audioldm2/hifigan/models_v2.py b/audioldm2/hifigan/models_v2.py
new file mode 100755
index 0000000000000000000000000000000000000000..27a2df6b54bdd3a5b259645442624800ac0e8afe
--- /dev/null
+++ b/audioldm2/hifigan/models_v2.py
@@ -0,0 +1,395 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.h = h
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.conv_pre = weight_norm(
+            Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
+        )
+        resblock = ResBlock1 if h.resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        h.upsample_initial_channel // (2**i),
+                        h.upsample_initial_channel // (2 ** (i + 1)),
+                        u * 2,
+                        u,
+                        padding=u // 2 + u % 2,
+                        output_padding=u % 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h.upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
+        # import ipdb; ipdb.set_trace()
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        # print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
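+# Note: unlike the mel-conditioned Generator in models.py (whose conv_pre takes
+# h.num_mels input channels), this variant hard-codes a 256-channel conv_pre and so
+# appears to expect feature frames of shape (B, 256, T) rather than mel spectrograms.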
+
+##################################################################################################
+
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch.nn import Conv1d, ConvTranspose1d
+# from torch.nn.utils import weight_norm, remove_weight_norm
+
+# LRELU_SLOPE = 0.1
+
+
+# def init_weights(m, mean=0.0, std=0.01):
+#     classname = m.__class__.__name__
+#     if classname.find("Conv") != -1:
+#         m.weight.data.normal_(mean, std)
+
+
+# def get_padding(kernel_size, dilation=1):
+#     return int((kernel_size * dilation - dilation) / 2)
+
+
+# class ResBlock(torch.nn.Module):
+#     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+#         super(ResBlock, self).__init__()
+#         self.h = h
+#         self.convs1 = nn.ModuleList(
+#             [
+#                 weight_norm(
+#                     Conv1d(
+#                         channels,
+#                         channels,
+#                         kernel_size,
+#                         1,
+#                         dilation=dilation[0],
+#                         padding=get_padding(kernel_size, dilation[0]),
+#                     )
+#                 ),
+#                 weight_norm(
+#                     Conv1d(
+#                         channels,
+#                         channels,
+#                         kernel_size,
+#                         1,
+#                         dilation=dilation[1],
+#                         padding=get_padding(kernel_size, dilation[1]),
+#                     )
+#                 ),
+#                 weight_norm(
+#                     Conv1d(
+#                         channels,
+#                         channels,
+#                         kernel_size,
+#                         1,
+#                         dilation=dilation[2],
+#                         padding=get_padding(kernel_size, dilation[2]),
+#                     )
+#                 ),
+#             ]
+#         )
+#         self.convs1.apply(init_weights)
+
+#         self.convs2 = nn.ModuleList(
+#             [
+#                 weight_norm(
+#                     Conv1d(
+#                         channels,
+#                         channels,
+#                         kernel_size,
+#                         1,
+#                         dilation=1,
+#                         padding=get_padding(kernel_size, 1),
+#                     )
+#                 ),
+#                 weight_norm(
+#                     Conv1d(
+#                         channels,
+#                         channels,
+#                         kernel_size,
+#                         1,
+#                         dilation=1,
+#                         padding=get_padding(kernel_size, 1),
+#                     )
+#                 ),
+#                 weight_norm(
+#                     Conv1d(
+#                         channels,
+#                         channels,
+#                         kernel_size,
+#                         1,
+#                         dilation=1,
+#                         padding=get_padding(kernel_size, 1),
+#                     )
+#                 ),
+#             ]
+#         )
+#         self.convs2.apply(init_weights)
+
+#     def forward(self, x):
+#         for c1, c2 in zip(self.convs1, self.convs2):
+#             xt = F.leaky_relu(x, LRELU_SLOPE)
+#             xt = c1(xt)
+#             xt = F.leaky_relu(xt, LRELU_SLOPE)
+#             xt = c2(xt)
+#             x = xt + x
+#         return x
+
+#     def remove_weight_norm(self):
+#         for l in self.convs1:
+#             remove_weight_norm(l)
+#         for l in self.convs2:
+#             remove_weight_norm(l)
+
+# class Generator(torch.nn.Module):
+#     def __init__(self, h):
+#         super(Generator, self).__init__()
+#         self.h = h
+#         self.num_kernels = len(h.resblock_kernel_sizes)
+#         self.num_upsamples = len(h.upsample_rates)
+#         self.conv_pre = weight_norm(
+#             Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+#         )
+#         resblock = ResBlock
+
+#         self.ups = nn.ModuleList()
+#         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+#             self.ups.append(
+#                 weight_norm(
+#                     ConvTranspose1d(
+#                         h.upsample_initial_channel // (2**i),
+#                         h.upsample_initial_channel // (2 ** (i + 1)),
+#                         k,
+#                         u,
+#                         padding=(k - u) // 2,
+#                     )
+#                 )
+#             )
+
+#         self.resblocks = nn.ModuleList()
+#         for i in range(len(self.ups)):
+#             ch = h.upsample_initial_channel // (2 ** (i + 1))
+#             for j, (k, d) in enumerate(
+#                 zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+#             ):
+#                 self.resblocks.append(resblock(h, ch, k, d))
+
+#         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+#         self.ups.apply(init_weights)
+#         self.conv_post.apply(init_weights)
+
+#     def forward(self, x):
+#         x = self.conv_pre(x)
+#         for i in range(self.num_upsamples):
+#             x = F.leaky_relu(x, LRELU_SLOPE)
+#             x = self.ups[i](x)
+#             xs = None
+#             for j in range(self.num_kernels):
+#                 if xs is None:
+#                     xs = self.resblocks[i * self.num_kernels + j](x)
+#                 else:
+#                     xs += self.resblocks[i * self.num_kernels + j](x)
+#             x = xs / self.num_kernels
+#         x = F.leaky_relu(x)
+#         x = self.conv_post(x)
+#         x = torch.tanh(x)
+
+#         return x
+
+#     def remove_weight_norm(self):
+#         print("Removing weight norm...")
+#         for l in self.ups:
+#             remove_weight_norm(l)
+#         for l in self.resblocks:
+#             l.remove_weight_norm()
+#         remove_weight_norm(self.conv_pre)
+#         remove_weight_norm(self.conv_post)
diff --git a/audioldm2/latent_diffusion/__init__.py b/audioldm2/latent_diffusion/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/models/__init__.py b/audioldm2/latent_diffusion/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/models/ddim.py b/audioldm2/latent_diffusion/models/ddim.py
new file mode 100755
index 0000000000000000000000000000000000000000..0c07207af7959847552805f00831122304b4330e
--- /dev/null
+++ b/audioldm2/latent_diffusion/models/ddim.py
@@ -0,0 +1,487 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+    make_ddim_sampling_parameters,
+    make_ddim_timesteps,
+    noise_like,
+    extract_into_tensor,
+)
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+        self.device = device
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != self.device:
+                attr = attr.to(self.device)
+        setattr(self, name, attr)
+
+    def make_schedule(
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+    ):
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
+        alphas_cumprod = self.model.alphas_cumprod
+        assert (
+            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+        ), "alphas have to be defined for each timestep"
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer(
+            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+        )
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod",
+            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod",
+            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+        )
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer(
+            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+        )
+
+    @torch.no_grad()
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
+        conditioning=None,
+        callback=None,
+        normals_sequence=None,
+        img_callback=None,
+        quantize_x0=False,
+        eta=0.0,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        verbose=True,
+        x_T=None,
+        log_every_t=100,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+        dynamic_threshold=None,
+        ucg_schedule=None,
+        **kwargs,
+    ):
+        # if conditioning is not None:
+        #     if isinstance(conditioning, dict):
+        #         ctmp = conditioning[list(conditioning.keys())[0]]
+        #         while isinstance(ctmp, list): ctmp = ctmp[0]
+        #         cbs = ctmp.shape[0]
+        #         if cbs != batch_size:
+        #             print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+        #     elif isinstance(conditioning, list):
+        #         for ctmp in conditioning:
+        #             if ctmp.shape[0] != batch_size:
+        #                 print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+        #     else:
+        #         if conditioning.shape[0] != batch_size:
+        #             print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(
+            conditioning,
+            size,
+            callback=callback,
+            img_callback=img_callback,
+            quantize_denoised=quantize_x0,
+            mask=mask,
+            x0=x0,
+            ddim_use_original_steps=False,
+            noise_dropout=noise_dropout,
+            temperature=temperature,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+            x_T=x_T,
+            log_every_t=log_every_t,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning=unconditional_conditioning,
+            dynamic_threshold=dynamic_threshold,
+            ucg_schedule=ucg_schedule,
+        )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(
+        self,
+        cond,
+        shape,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+        ucg_schedule=None,
+    ):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = (
+                self.ddpm_num_timesteps
+                if ddim_use_original_steps
+                else self.ddim_timesteps
+            )
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = (
+                int(
+                    min(timesteps / self.ddim_timesteps.shape[0], 1)
+                    * self.ddim_timesteps.shape[0]
+                )
+                - 1
+            )
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {"x_inter": [img], "pred_x0": [img]}
+        time_range = (
+            reversed(range(0, timesteps))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(
+                    x0, ts
+                )  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                dynamic_threshold=dynamic_threshold,
+            )
+            img, pred_x0 = outs
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates["x_inter"].append(img)
+                intermediates["pred_x0"].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+    ):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            x_in = x
+            t_in = t
+
+            assert isinstance(c, dict)
+            assert isinstance(unconditional_conditioning, dict)
+
+            model_uncond = self.model.apply_model(
+                x_in, t_in, unconditional_conditioning
+            )
+            model_t = self.model.apply_model(x_in, t_in, c)
+
+            model_output = model_uncond + unconditional_guidance_scale * (
+                model_t - model_uncond
+            )
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", "not implemented"
+            e_t = score_corrector.modify_score(
+                self.model, e_t, x, t, c, **corrector_kwargs
+            )
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = (
+            self.model.alphas_cumprod_prev
+            if use_original_steps
+            else self.ddim_alphas_prev
+        )
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod
+            if use_original_steps
+            else self.ddim_sqrt_one_minus_alphas
+        )
+        sigmas = (
+            self.model.ddim_sigmas_for_original_num_steps
+            if use_original_steps
+            else self.ddim_sigmas
+        )
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full(
+            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+        )
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.0:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(
+        self,
+        x0,
+        c,
+        t_enc,
+        use_original_steps=False,
+        return_intermediates=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        callback=None,
+    ):
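+        # DDIM inversion: deterministically map a clean latent x0 forward for t_enc steps so that
+        # it can later be reconstructed (or edited) by running the reverse DDIM process.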
+        num_reference_steps = (
+            self.ddpm_num_timesteps
+            if use_original_steps
+            else self.ddim_timesteps.shape[0]
+        )
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc="Encoding Image"):
+            t = torch.full(
+                (x0.shape[0],), i, device=self.model.device, dtype=torch.long
+            )
+            if unconditional_guidance_scale == 1.0:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(
+                        torch.cat((x_next, x_next)),
+                        torch.cat((t, t)),
+                        torch.cat((unconditional_conditioning, c)),
+                    ),
+                    2,
+                )
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (
+                    noise_pred - e_t_uncond
+                )
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = (
+                alphas_next[i].sqrt()
+                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
+                * noise_pred
+            )
+            x_next = xt_weighted + weighted_noise_pred
+            if (
+                return_intermediates
+                and i % (num_steps // return_intermediates) == 0
+                and i < num_steps - 1
+            ):
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback:
+                callback(i)
+
+        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
+        if return_intermediates:
+            out.update({"intermediates": intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (
+            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+        )
+
+    @torch.no_grad()
+    def decode(
+        self,
+        x_latent,
+        cond,
+        t_start,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_original_steps=False,
+        callback=None,
+    ):
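+        # Reverse DDIM loop: start from an encoded latent and denoise for t_start steps,
+        # optionally with classifier-free guidance, to recover a decodable latent.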
+        timesteps = (
+            np.arange(self.ddpm_num_timesteps)
+            if use_original_steps
+            else self.ddim_timesteps
+        )
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+            )
+            x_dec, _ = self.p_sample_ddim(
+                x_dec,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=use_original_steps,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+            )
+            if callback:
+                callback(i)
+        return x_dec
diff --git a/audioldm2/latent_diffusion/models/ddpm.py b/audioldm2/latent_diffusion/models/ddpm.py
new file mode 100755
index 0000000000000000000000000000000000000000..df3a6c032ba2ec61250212a31d68184e763dcf0e
--- /dev/null
+++ b/audioldm2/latent_diffusion/models/ddpm.py
@@ -0,0 +1,1840 @@
+import os
+
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from audioldm2.latent_diffusion.modules.encoders.modules import *
+
+from audioldm2.latent_diffusion.util import (
+    exists,
+    default,
+    count_params,
+    instantiate_from_config,
+)
+from audioldm2.latent_diffusion.modules.ema import LitEma
+from audioldm2.latent_diffusion.modules.distributions.distributions import (
+    DiagonalGaussianDistribution,
+)
+
+# from latent_encoder.autoencoder import (
+#     VQModelInterface,
+#     IdentityFirstStage,
+#     AutoencoderKL,
+# )
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+    make_beta_schedule,
+    extract_into_tensor,
+    noise_like,
+)
+
+from audioldm2.latent_diffusion.models.ddim import DDIMSampler
+from audioldm2.latent_diffusion.models.plms import PLMSSampler
+import soundfile as sf
+
+__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
+
+CACHE_DIR = os.getenv(
+    "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
+)
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(nn.Module):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(
+        self,
+        unet_config,
+        sampling_rate=None,
+        timesteps=1000,
+        beta_schedule="linear",
+        loss_type="l2",
+        ckpt_path=None,
+        ignore_keys=[],
+        load_only_unet=False,
+        monitor="val/loss",
+        use_ema=True,
+        first_stage_key="image",
+        latent_t_size=256,
+        latent_f_size=16,
+        channels=3,
+        log_every_t=100,
+        clip_denoised=True,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+        given_betas=None,
+        original_elbo_weight=0.0,
+        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+        l_simple_weight=1.0,
+        conditioning_key=None,
+        parameterization="eps",  # all assuming fixed variance schedules
+        scheduler_config=None,
+        use_positional_encodings=False,
+        learn_logvar=False,
+        logvar_init=0.0,
+        evaluator=None,
+        device=None,
+    ):
+        super().__init__()
+        assert parameterization in [
+            "eps",
+            "x0",
+            "v",
+        ], 'currently only supporting "eps" and "x0" and "v"'
+        self.parameterization = parameterization
+        self.state = None
+        self.device = device
+        # print(
+        #     f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
+        # )
+        assert sampling_rate is not None
+        self.validation_folder_name = "temp_name"
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.sampling_rate = sampling_rate
+
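+        # CLAP audio embedding model, kept in audio embedding mode; presumably used downstream
+        # to score generated audio against the conditioning.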
+        self.clap = CLAPAudioEmbeddingClassifierFreev2(
+            pretrained_path="",
+            sampling_rate=self.sampling_rate,
+            embed_mode="audio",
+            amodel="HTSAT-base",
+        )
+
+        self.initialize_param_check_toolkit()
+
+        self.latent_t_size = latent_t_size
+        self.latent_f_size = latent_f_size
+
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        if ckpt_path is not None:
+            self.init_from_ckpt(
+                ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet
+            )
+
+        self.register_schedule(
+            given_betas=given_betas,
+            beta_schedule=beta_schedule,
+            timesteps=timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+        )
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+        else:
+            self.logvar = nn.Parameter(self.logvar, requires_grad=False)
+
+        self.logger_save_dir = None
+        self.logger_exp_name = None
+        self.logger_exp_group_name = None
+        self.logger_version = None
+
+        self.label_indices_total = None
+        # Placeholder metric values so checkpoint selection does not fail when a metric is missing
+        self.metrics_buffer = {
+            "val/kullback_leibler_divergence_sigmoid": 15.0,
+            "val/kullback_leibler_divergence_softmax": 10.0,
+            "val/psnr": 0.0,
+            "val/ssim": 0.0,
+            "val/inception_score_mean": 1.0,
+            "val/inception_score_std": 0.0,
+            "val/kernel_inception_distance_mean": 0.0,
+            "val/kernel_inception_distance_std": 0.0,
+            "val/frechet_inception_distance": 133.0,
+            "val/frechet_audio_distance": 32.0,
+        }
+        self.initial_learning_rate = None
+        self.test_data_subset_path = None
+
+    def get_log_dir(self):
+        return os.path.join(
+            self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name
+        )
+
+    def set_log_dir(self, save_dir, exp_group_name, exp_name):
+        self.logger_save_dir = save_dir
+        self.logger_exp_group_name = exp_group_name
+        self.logger_exp_name = exp_name
+
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
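+        # Precompute the full diffusion schedule (betas, cumulative alphas, and the posterior
+        # coefficients of q(x_{t-1} | x_t, x_0)) and register everything as non-trainable buffers.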
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(
+                beta_schedule,
+                timesteps,
+                linear_start=linear_start,
+                linear_end=linear_end,
+                cosine_s=cosine_s,
+            )
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+        (timesteps,) = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert (
+            alphas_cumprod.shape[0] == self.num_timesteps
+        ), "alphas have to be defined for each timestep"
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer("betas", to_torch(betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+        )
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (
+            1.0 - alphas_cumprod_prev
+        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer("posterior_variance", to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer(
+            "posterior_log_variance_clipped",
+            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+        )
+        self.register_buffer(
+            "posterior_mean_coef1",
+            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+        )
+        self.register_buffer(
+            "posterior_mean_coef2",
+            to_torch(
+                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+            ),
+        )
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas**2 / (
+                2
+                * self.posterior_variance
+                * to_torch(alphas)
+                * (1 - self.alphas_cumprod)
+            )
+        elif self.parameterization == "x0":
+            lvlb_weights = (
+                0.5
+                * np.sqrt(torch.Tensor(alphas_cumprod))
+                / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+            )
+        elif self.parameterization == "v":
+            lvlb_weights = torch.ones_like(
+                self.betas**2
+                / (
+                    2
+                    * self.posterior_variance
+                    * to_torch(alphas)
+                    * (1 - self.alphas_cumprod)
+                )
+            )
+        else:
+            raise NotImplementedError("mu not supported")
+        # TODO how to choose this term
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            # if context is not None:
+            #     print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                # if context is not None:
+                #     print(f"{context}: Restored training weights")
+
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        missing, unexpected = (
+            self.load_state_dict(sd, strict=False)
+            if not only_model
+            else self.model.load_state_dict(sd, strict=False)
+        )
+        print(
+            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+        )
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(
+            self.log_one_minus_alphas_cumprod, t, x_start.shape
+        )
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+            * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(
+            self.posterior_log_variance_clipped, t, x_t.shape
+        )
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1.0, 1.0)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t
+        )
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(
+            x=x, t=t, clip_denoised=clip_denoised
+        )
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (
+            (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+        )
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(
+            reversed(range(0, self.num_timesteps)),
+            desc="Sampling t",
+            total=self.num_timesteps,
+        ):
+            img = self.p_sample(
+                img,
+                torch.full((b,), i, device=device, dtype=torch.long),
+                clip_denoised=self.clip_denoised,
+            )
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
+        return self.p_sample_loop(shape, return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
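+        # Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form,
+        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.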
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == "l1":
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == "l2":
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction="none")
+        else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+        return loss
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
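+        # v-parameterization: x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v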
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+            * x_t
+        )
+
+    def get_v(self, x, noise, t):
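+        # Velocity target: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0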
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(
+            0, self.num_timesteps, (x.shape[0],), device=self.device
+        ).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch
+        # fbank, stft, label_indices, fname, waveform, text = batch
+        fname, text, waveform, stft, fbank = (
+            batch["fname"],
+            batch["text"],
+            batch["waveform"],
+            batch["stft"],
+            batch["log_mel_spec"],
+        )
+        # for i in range(fbank.size(0)):
+        #     fb = fbank[i].numpy()
+        #     seg_lb = seg_label[i].numpy()
+        #     logits = np.mean(seg_lb, axis=0)
+        #     index = np.argsort(logits)[::-1][:5]
+        #     plt.imshow(seg_lb[:,index], aspect="auto")
+        #     plt.title(index)
+        #     plt.savefig("%s_label.png" % i)
+        #     plt.close()
+        #     plt.imshow(fb, aspect="auto")
+        #     plt.savefig("%s_fb.png" % i)
+        #     plt.close()
+        ret = {}
+
+        ret["fbank"] = (
+            fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
+        )
+        ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
+        # ret["clip_label"] = clip_label.to(memory
+        # _format=torch.contiguous_format).float()
+        ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
+        ret["text"] = list(text)
+        ret["fname"] = fname
+
+        for key in batch.keys():
+            if key not in ret.keys():
+                ret[key] = batch[key]
+
+        return ret[k]
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, "n b c h w -> b n c h w")
+        denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w")
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(
+                    batch_size=N, return_intermediates=True
+                )
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+    def initialize_param_check_toolkit(self):
+        self.tracked_steps = 0
+        self.param_dict = {}
+
+    def statistic_require_grad_tensor_number(self, module, name=None):
+        requires_grad_num = 0
+        total_num = 0
+        require_grad_tensor = None
+        for p in module.parameters():
+            if p.requires_grad:
+                requires_grad_num += 1
+                if require_grad_tensor is None:
+                    require_grad_tensor = p
+            total_num += 1
+        print(
+            "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
+            % (name, requires_grad_num, total_num, requires_grad_num / total_num)
+        )
+        return require_grad_tensor
+
+
+class LatentDiffusion(DDPM):
+    """main class"""
+
+    def __init__(
+        self,
+        first_stage_config,
+        cond_stage_config=None,
+        num_timesteps_cond=None,
+        cond_stage_key="image",
+        optimize_ddpm_parameter=True,
+        unconditional_prob_cfg=0.1,
+        warmup_steps=10000,
+        cond_stage_trainable=False,
+        concat_mode=True,
+        cond_stage_forward=None,
+        conditioning_key=None,
+        scale_factor=1.0,
+        batchsize=None,
+        evaluation_params={},
+        scale_by_std=False,
+        base_learning_rate=None,
+        *args,
+        **kwargs,
+    ):
+        self.learning_rate = base_learning_rate
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        self.warmup_steps = warmup_steps
+
+        if optimize_ddpm_parameter:
+            if unconditional_prob_cfg == 0.0:
+                # Optimizing the DDPM requires classifier-free guidance dropout,
+                # so force the unconditional probability to 0.1.
+                unconditional_prob_cfg = 0.1
+        else:
+            if unconditional_prob_cfg == 0.1:
+                # The DDPM is not optimized, so no guidance dropout is needed;
+                # force the unconditional probability to 0.0.
+                unconditional_prob_cfg = 0.0
+
+        self.evaluation_params = evaluation_params
+        assert self.num_timesteps_cond <= kwargs["timesteps"]
+
+        # for backwards compatibility after implementation of DiffusionWrapper
+        # if conditioning_key is None:
+        #     conditioning_key = "concat" if concat_mode else "crossattn"
+        # if cond_stage_config == "__is_unconditional__":
+        #     conditioning_key = None
+
+        conditioning_key = list(cond_stage_config.keys())
+
+        self.conditioning_key = conditioning_key
+
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+
+        self.optimize_ddpm_parameter = optimize_ddpm_parameter
+        # if(not optimize_ddpm_parameter):
+        #     print("Warning: Close the optimization of the latent diffusion model")
+        #     for p in self.model.parameters():
+        #         p.requires_grad=False
+
+        self.concat_mode = concat_mode
+        self.cond_stage_key = cond_stage_key
+        self.cond_stage_key_orig = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer("scale_factor", torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.unconditional_prob_cfg = unconditional_prob_cfg
+        self.cond_stage_models = nn.ModuleList([])
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None
+        self.conditional_dry_run_finished = False
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+
+        for each in self.cond_stage_models:
+            params = params + list(
+                each.parameters()
+            )  # Add the parameter from the conditional stage
+
+        if self.learn_logvar:
+            print("Diffusion model optimizing logvar")
+            params.append(self.logvar)
+        opt = torch.optim.AdamW(params, lr=lr)
+        # if self.use_scheduler:
+        #     assert "target" in self.scheduler_config
+        #     scheduler = instantiate_from_config(self.scheduler_config)
+
+        #     print("Setting up LambdaLR scheduler...")
+        #     scheduler = [
+        #         {
+        #             "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+        #             "interval": "step",
+        #             "frequency": 1,
+        #         }
+        #     ]
+        #     return [opt], scheduler
+        return opt
+
+    def make_cond_schedule(
+        self,
+    ):
+        self.cond_ids = torch.full(
+            size=(self.num_timesteps,),
+            fill_value=self.num_timesteps - 1,
+            dtype=torch.long,
+        )
+        ids = torch.round(
+            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+        ).long()
+        self.cond_ids[: self.num_timesteps_cond] = ids
+
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx):
+        # only for very first batch
+        if (
+            self.scale_factor == 1
+            and self.scale_by_std
+            and self.current_epoch == 0
+            and self.global_step == 0
+            and batch_idx == 0
+            and not self.restarted_from_ckpt
+        ):
+            # assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer("scale_factor", 1.0 / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        super().register_schedule(
+            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+        )
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def instantiate_first_stage(self, config):
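+        # The first-stage autoencoder is frozen: it stays in eval mode and its parameters
+        # receive no gradients.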
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def make_decision(self, probability):
+        if float(torch.rand(1)) < probability:
+            return True
+        else:
+            return False
+
+    def instantiate_cond_stage(self, config):
+        self.cond_stage_model_metadata = {}
+        for i, cond_model_key in enumerate(config.keys()):
+            model = instantiate_from_config(config[cond_model_key])
+            self.cond_stage_models.append(model)
+            self.cond_stage_model_metadata[cond_model_key] = {
+                "model_idx": i,
+                "cond_stage_key": config[cond_model_key]["cond_stage_key"],
+                "conditioning_key": config[cond_model_key]["conditioning_key"],
+            }
+
+    def get_first_stage_encoding(self, encoder_posterior):
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(
+                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+            )
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c, key, unconditional_cfg):
+        assert key in self.cond_stage_model_metadata.keys()
+
+        # Classifier-free guidance
+        if not unconditional_cfg:
+            c = self.cond_stage_models[
+                self.cond_stage_model_metadata[key]["model_idx"]
+            ](c)
+        else:
+            # when the cond_stage_key is "all", pick one random element out
+            if isinstance(c, dict):
+                c = c[list(c.keys())[0]]
+
+            if isinstance(c, torch.Tensor):
+                batchsize = c.size(0)
+            elif isinstance(c, list):
+                batchsize = len(c)
+            else:
+                raise NotImplementedError()
+
+            c = self.cond_stage_models[
+                self.cond_stage_model_metadata[key]["model_idx"]
+            ].get_unconditional_condition(batchsize)
+
+        return c
+
+    def get_input(
+        self,
+        batch,
+        k,
+        return_first_stage_encode=True,
+        return_decoding_output=False,
+        return_encoder_input=False,
+        return_encoder_output=False,
+        unconditional_prob_cfg=0.1,
+    ):
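+        # Encode the input into the first-stage latent space and build the conditioning
+        # dictionary. With probability unconditional_prob_cfg the conditions are replaced by
+        # unconditional embeddings, which trains the model for classifier-free guidance.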
+        x = super().get_input(batch, k)
+
+        x = x.to(self.device)
+
+        if return_first_stage_encode:
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+        else:
+            z = None
+        cond_dict = {}
+        if len(self.cond_stage_model_metadata.keys()) > 0:
+            unconditional_cfg = False
+            if self.conditional_dry_run_finished and self.make_decision(
+                unconditional_prob_cfg
+            ):
+                unconditional_cfg = True
+            for cond_model_key in self.cond_stage_model_metadata.keys():
+                cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
+                    "cond_stage_key"
+                ]
+
+                if cond_model_key in cond_dict.keys():
+                    continue
+
+                if not self.training:
+                    if isinstance(
+                        self.cond_stage_models[
+                            self.cond_stage_model_metadata[cond_model_key]["model_idx"]
+                        ],
+                        CLAPAudioEmbeddingClassifierFreev2,
+                    ):
+                        print(
+                            "Warning: CLAP model normally should use text for evaluation"
+                        )
+
+                # The original data for conditioning
+                # If cond_stage_key is "all", the conditional model needs all the information in the batch
+
+                if cond_stage_key != "all":
+                    xc = super().get_input(batch, cond_stage_key)
+                    if type(xc) == torch.Tensor:
+                        xc = xc.to(self.device)
+                else:
+                    xc = batch
+
+                # if cond_stage_key is "all", xc will be a dictionary containing all keys
+                # Otherwise xc will be an entry of the dictionary
+                c = self.get_learned_conditioning(
+                    xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
+                )
+
+                # cond_dict will be used to condition the diffusion model
+                # If one conditional model returns multiple conditioning signals
+                if isinstance(c, dict):
+                    for k in c.keys():
+                        cond_dict[k] = c[k]
+                else:
+                    cond_dict[cond_model_key] = c
+
+        # If a key is accidentally added to the dictionary but not in the condition list, remove it
+        # for k in list(cond_dict.keys()):
+        #     if(k not in self.cond_stage_model_metadata.keys()):
+        #         del cond_dict[k]
+
+        out = [z, cond_dict]
+
+        if return_decoding_output:
+            xrec = self.decode_first_stage(z)
+            out += [xrec]
+
+        if return_encoder_input:
+            out += [x]
+
+        if return_encoder_output:
+            out += [encoder_posterior]
+
+        if not self.conditional_dry_run_finished:
+            self.conditional_dry_run_finished = True
+
+        # Output is a list: [z, cond_dict], optionally followed by the decoded reconstruction,
+        # the encoder input, and the encoder posterior
+        return out
+
+    def decode_first_stage(self, z):
+        with torch.no_grad():
+            z = 1.0 / self.scale_factor * z
+            decoding = self.first_stage_model.decode(z)
+        return decoding
+
+    def mel_spectrogram_to_waveform(
+        self, mel, savepath=".", bs=None, name="outwav", save=True
+    ):
+        # Mel: [bs, 1, t-steps, fbins]
+        if len(mel.size()) == 4:
+            mel = mel.squeeze(1)
+        mel = mel.permute(0, 2, 1)
+        waveform = self.first_stage_model.vocoder(mel)
+        waveform = waveform.cpu().detach().numpy()
+        if save:
+            self.save_waveform(waveform, savepath, name)
+        return waveform
+
+    def encode_first_stage(self, x):
+        with torch.no_grad():
+            return self.first_stage_model.encode(x)
+
+    def extract_possible_loss_in_cond_dict(self, cond_dict):
+        # Allow conditional modules to return auxiliary losses that are used to optimize them
+
+        assert isinstance(cond_dict, dict)
+        losses = {}
+
+        for cond_key in cond_dict.keys():
+            if "loss" in cond_key and "noncond" in cond_key:
+                assert cond_key not in losses.keys()
+                losses[cond_key] = cond_dict[cond_key]
+
+        return losses
+
+    def filter_useful_cond_dict(self, cond_dict):
+        new_cond_dict = {}
+        for key in cond_dict.keys():
+            if key in self.cond_stage_model_metadata.keys():
+                new_cond_dict[key] = cond_dict[key]
+
+        # Every conditional key in the metadata must be present
+        for key in self.cond_stage_model_metadata.keys():
+            assert key in new_cond_dict.keys(), "%s, %s" % (
+                key,
+                str(new_cond_dict.keys()),
+            )
+
+        return new_cond_dict
+
+    def shared_step(self, batch, **kwargs):
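+        # Compute the diffusion loss (if the DDPM is optimized) plus any auxiliary losses
+        # returned by the conditioning modules.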
+        if self.training:
+            # Classifier-free guidance
+            unconditional_prob_cfg = self.unconditional_prob_cfg
+        else:
+            unconditional_prob_cfg = 0.0  # TODO possible bug here
+
+        x, c = self.get_input(
+            batch, self.first_stage_key, unconditional_prob_cfg=unconditional_prob_cfg
+        )
+
+        if self.optimize_ddpm_parameter:
+            loss, loss_dict = self(x, self.filter_useful_cond_dict(c))
+        else:
+            loss_dict = {}
+            loss = None
+
+        additional_loss_for_cond_modules = self.extract_possible_loss_in_cond_dict(c)
+        assert isinstance(additional_loss_for_cond_modules, dict)
+
+        loss_dict.update(additional_loss_for_cond_modules)
+
+        if len(additional_loss_for_cond_modules.keys()) > 0:
+            for k in additional_loss_for_cond_modules.keys():
+                if loss is None:
+                    loss = additional_loss_for_cond_modules[k]
+                else:
+                    loss = loss + additional_loss_for_cond_modules[k]
+
+        # for k,v in additional_loss_for_cond_modules.items():
+        #     self.log(
+        #         "cond_stage/"+k,
+        #         float(v),
+        #         prog_bar=True,
+        #         logger=True,
+        #         on_step=True,
+        #         on_epoch=True,
+        #     )
+        if self.training:
+            assert loss is not None
+
+        return loss, loss_dict
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(
+            0, self.num_timesteps, (x.shape[0],), device=self.device
+        ).long()
+
+        # assert c is not None
+        # c = self.get_learned_conditioning(c)
+
+        loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs)
+        return loss, loss_dict
+
+    def reorder_cond_dict(self, cond_dict):
+        # To make sure the order is correct
+        new_cond_dict = {}
+        for key in self.conditioning_key:
+            new_cond_dict[key] = cond_dict[key]
+        return new_cond_dict
+
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
+        cond = self.reorder_cond_dict(cond)
+
+        x_recon = self.model(x_noisy, t, cond_dict=cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def p_losses(self, x_start, cond, t, noise=None):
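+        # Standard DDPM training objective: noise x_start to timestep t, predict the target
+        # (noise, x0 or v depending on the parameterization), and take a weighted reconstruction loss.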
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_output = self.apply_model(x_noisy, t, cond)
+
+        loss_dict = {}
+        prefix = "train" if self.training else "val"
+
+        if self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError()
+        # print(model_output.size(), target.size())
+        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+        loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
+
+        logvar_t = self.logvar[t].to(self.device)
+        loss = loss_simple / torch.exp(logvar_t) + logvar_t
+        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+        if self.learn_logvar:
+            loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
+            loss_dict.update({"logvar": self.logvar.data.mean()})
+
+        loss = self.l_simple_weight * loss.mean()
+
+        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+        loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
+        loss += self.original_elbo_weight * loss_vlb
+        loss_dict.update({f"{prefix}/loss": loss})
+
+        return loss, loss_dict
+
+    def p_mean_variance(
+        self,
+        x,
+        c,
+        t,
+        clip_denoised: bool,
+        return_codebook_ids=False,
+        quantize_denoised=False,
+        return_x0=False,
+        score_corrector=None,
+        corrector_kwargs=None,
+    ):
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(
+                self, model_out, x, t, c, **corrector_kwargs
+            )
+
+        if return_codebook_ids:
+            model_out, logits = model_out
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1.0, 1.0)
+        if quantize_denoised:
+            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t
+        )
+        if return_codebook_ids:
+            return model_mean, posterior_variance, posterior_log_variance, logits
+        elif return_x0:
+            return model_mean, posterior_variance, posterior_log_variance, x_recon
+        else:
+            return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(
+        self,
+        x,
+        c,
+        t,
+        clip_denoised=False,
+        repeat_noise=False,
+        return_codebook_ids=False,
+        quantize_denoised=False,
+        return_x0=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+    ):
+        b, *_, device = *x.shape, x.device
+        outputs = self.p_mean_variance(
+            x=x,
+            c=c,
+            t=t,
+            clip_denoised=clip_denoised,
+            return_codebook_ids=return_codebook_ids,
+            quantize_denoised=quantize_denoised,
+            return_x0=return_x0,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+        )
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+        elif return_x0:
+            model_mean, _, model_log_variance, x0 = outputs
+        else:
+            model_mean, _, model_log_variance = outputs
+
+        noise = noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.0:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        # no noise when t == 0
+        nonzero_mask = (
+            (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+        )
+
+        # if return_codebook_ids:
+        #     return model_mean + nonzero_mask * (
+        #         0.5 * model_log_variance
+        #     ).exp() * noise, logits.argmax(dim=1)
+        if return_x0:
+            return (
+                model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+                x0,
+            )
+        else:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def progressive_denoising(
+        self,
+        cond,
+        shape,
+        verbose=True,
+        callback=None,
+        quantize_denoised=False,
+        img_callback=None,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        batch_size=None,
+        x_T=None,
+        start_T=None,
+        log_every_t=None,
+    ):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=self.device)
+        else:
+            img = x_T
+        intermediates = []
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {
+                    key: cond[key][:batch_size]
+                    if not isinstance(cond[key], list)
+                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    for key in cond
+                }
+            else:
+                cond = (
+                    [c[:batch_size] for c in cond]
+                    if isinstance(cond, list)
+                    else cond[:batch_size]
+                )
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = (
+            tqdm(
+                reversed(range(0, timesteps)),
+                desc="Progressive Generation",
+                total=timesteps,
+            )
+            if verbose
+            else reversed(range(0, timesteps))
+        )
+        if type(temperature) == float:
+            temperature = [temperature] * timesteps
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != "hybrid"
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img, x0_partial = self.p_sample(
+                img,
+                cond,
+                ts,
+                clip_denoised=self.clip_denoised,
+                quantize_denoised=quantize_denoised,
+                return_x0=True,
+                temperature=temperature[i],
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+            )
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(x0_partial)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_loop(
+        self,
+        cond,
+        shape,
+        return_intermediates=False,
+        x_T=None,
+        verbose=True,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        start_T=None,
+        log_every_t=None,
+    ):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        device = self.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        intermediates = [img]
+        if timesteps is None:
+            timesteps = self.num_timesteps
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = (
+            tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
+            if verbose
+            else reversed(range(0, timesteps))
+        )
+
+        if mask is not None:
+            assert x0 is not None
+            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=device, dtype=torch.long)
+
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != "hybrid"
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img = self.p_sample(
+                img,
+                cond,
+                ts,
+                clip_denoised=self.clip_denoised,
+                quantize_denoised=quantize_denoised,
+            )
+
+            if mask is not None:
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1.0 - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(img)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
+
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(
+        self,
+        cond,
+        batch_size=16,
+        return_intermediates=False,
+        x_T=None,
+        verbose=True,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        shape=None,
+        **kwargs,
+    ):
+        if shape is None:
+            shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {
+                    key: cond[key][:batch_size]
+                    if not isinstance(cond[key], list)
+                    else list(map(lambda x: x[:batch_size], cond[key]))
+                    for key in cond
+                }
+            else:
+                cond = (
+                    [c[:batch_size] for c in cond]
+                    if isinstance(cond, list)
+                    else cond[:batch_size]
+                )
+        return self.p_sample_loop(
+            cond,
+            shape,
+            return_intermediates=return_intermediates,
+            x_T=x_T,
+            verbose=verbose,
+            timesteps=timesteps,
+            quantize_denoised=quantize_denoised,
+            mask=mask,
+            x0=x0,
+            **kwargs,
+        )
+
+    def save_waveform(self, waveform, savepath, name="outwav"):
+        for i in range(waveform.shape[0]):
+            if type(name) is str:
+                path = os.path.join(
+                    savepath, "%s_%s_%s.wav" % (self.global_step, i, name)
+                )
+            elif type(name) is list:
+                path = os.path.join(
+                    savepath,
+                    "%s.wav"
+                    % (
+                        os.path.basename(name[i])
+                        if ".wav" not in name[i]
+                        else os.path.basename(name[i]).split(".")[0]
+                    ),
+                )
+            else:
+                raise NotImplementedError
+            todo_waveform = waveform[i, 0]
+            todo_waveform = (
+                todo_waveform / np.max(np.abs(todo_waveform))
+            ) * 0.8  # Normalize the energy of the generation output
+            sf.write(path, todo_waveform, samplerate=self.sampling_rate)
+
+    @torch.no_grad()
+    def sample_log(
+        self,
+        cond,
+        batch_size,
+        ddim,
+        ddim_steps,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_plms=False,
+        mask=None,
+        **kwargs,
+    ):
+        if mask is not None:
+            shape = (self.channels, mask.size()[-2], mask.size()[-1])
+        else:
+            shape = (self.channels, self.latent_t_size, self.latent_f_size)
+
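+        # Pick the sampler: DDIM (accelerated sampling), PLMS, or the full DDPM reverse loop.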
+        intermediates = None
+        if ddim and not use_plms:
+            ddim_sampler = DDIMSampler(self)
+            samples, intermediates = ddim_sampler.sample(
+                ddim_steps,
+                batch_size,
+                shape,
+                cond,
+                verbose=False,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                mask=mask,
+                **kwargs,
+            )
+        elif use_plms:
+            plms_sampler = PLMSSampler(self)
+            samples, intermediates = plms_sampler.sample(
+                ddim_steps,
+                batch_size,
+                shape,
+                cond,
+                verbose=False,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                mask=mask,
+                unconditional_conditioning=unconditional_conditioning,
+                **kwargs,
+            )
+
+        else:
+            samples, intermediates = self.sample(
+                cond=cond,
+                batch_size=batch_size,
+                return_intermediates=True,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                mask=mask,
+                unconditional_conditioning=unconditional_conditioning,
+                **kwargs,
+            )
+
+        return samples, intermediates
+
+    @torch.no_grad()
+    def generate_batch(
+        self,
+        batch,
+        ddim_steps=200,
+        ddim_eta=1.0,
+        x_T=None,
+        n_gen=1,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_plms=False,
+        **kwargs,
+    ):
+        # Generate n_gen times and select the best
+        # Batch: audio, text, fnames
+        assert x_T is None
+
+        if use_plms:
+            assert ddim_steps is not None
+
+        use_ddim = ddim_steps is not None
+
+        # with self.ema_scope("Plotting"):
+        for i in range(1):
+            z, c = self.get_input(
+                batch,
+                self.first_stage_key,
+                unconditional_prob_cfg=0.0,  # never drop the condition here, so c stays fully conditional
+            )
+
+            c = self.filter_useful_cond_dict(c)
+
+            text = super().get_input(batch, "text")
+
+            # Generate multiple samples
+            batch_size = z.shape[0] * n_gen
+
+            # Generate multiple samples at a time and filter out the best
+            # The condition passed to the diffusion wrapper can come in several formats
+            for cond_key in c.keys():
+                if isinstance(c[cond_key], list):
+                    for i in range(len(c[cond_key])):
+                        c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
+                elif isinstance(c[cond_key], dict):
+                    for k in c[cond_key].keys():
+                        c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
+                else:
+                    c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
+
+            text = text * n_gen
+
+            if unconditional_guidance_scale != 1.0:
+                unconditional_conditioning = {}
+                for key in self.cond_stage_model_metadata:
+                    model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+                    unconditional_conditioning[key] = self.cond_stage_models[
+                        model_idx
+                    ].get_unconditional_condition(batch_size)
+
+            fnames = list(super().get_input(batch, "fname"))
+            samples, _ = self.sample_log(
+                cond=c,
+                batch_size=batch_size,
+                x_T=x_T,
+                ddim=use_ddim,
+                ddim_steps=ddim_steps,
+                eta=ddim_eta,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                use_plms=use_plms,
+            )
+
+            mel = self.decode_first_stage(samples)
+
+            waveform = self.mel_spectrogram_to_waveform(
+                mel, savepath="", bs=None, name=fnames, save=False
+            )
+
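+            # Rerank the n_gen candidates generated for each input by CLAP audio-text similarity
+            # and keep only the best-matching waveform per input.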
+            if n_gen > 1:
+                best_index = []
+                similarity = self.clap.cos_similarity(
+                    torch.FloatTensor(waveform).squeeze(1), text
+                )
+                for i in range(z.shape[0]):
+                    candidates = similarity[i :: z.shape[0]]
+                    max_index = torch.argmax(candidates).item()
+                    best_index.append(i + max_index * z.shape[0])
+
+                waveform = waveform[best_index]
+
+                print("Similarity between generated audio and text:")
+                print(' '.join('{:.2f}'.format(num) for num in similarity.detach().cpu().tolist()))
+                print("Choose the following indexes as the output:", best_index)
+
+            return waveform
+
+    @torch.no_grad()
+    def generate_sample(
+        self,
+        batchs,
+        ddim_steps=200,
+        ddim_eta=1.0,
+        x_T=None,
+        n_gen=1,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        name=None,
+        use_plms=False,
+        limit_num=None,
+        **kwargs,
+    ):
+        # Generate n_gen times and select the best
+        # Batch: audio, text, fnames
+        assert x_T is None
+        try:
+            batchs = iter(batchs)
+        except TypeError:
+            raise ValueError("The first input argument should be an iterable object")
+
+        if use_plms:
+            assert ddim_steps is not None
+
+        use_ddim = ddim_steps is not None
+        if name is None:
+            name = self.get_validation_folder_name()
+
+        waveform_save_path = os.path.join(self.get_log_dir(), name)
+        os.makedirs(waveform_save_path, exist_ok=True)
+        print("Waveform save path: ", waveform_save_path)
+
+        if (
+            "audiocaps" in waveform_save_path
+            and len(os.listdir(waveform_save_path)) >= 964
+        ):
+            print("The evaluation has already been done at %s" % waveform_save_path)
+            return waveform_save_path
+
+        with self.ema_scope("Plotting"):
+            for i, batch in enumerate(batchs):
+                z, c = self.get_input(
+                    batch,
+                    self.first_stage_key,
+                    unconditional_prob_cfg=0.0,  # never drop the condition here, so c stays fully conditional
+                )
+
+                if limit_num is not None and i * z.size(0) > limit_num:
+                    break
+
+                c = self.filter_useful_cond_dict(c)
+
+                text = super().get_input(batch, "text")
+
+                # Generate multiple samples
+                batch_size = z.shape[0] * n_gen
+
+                # Generate multiple samples at a time and filter out the best
+                # The condition passed to the diffusion wrapper can come in several formats
+                for cond_key in c.keys():
+                    if isinstance(c[cond_key], list):
+                        for i in range(len(c[cond_key])):
+                            c[cond_key][i] = torch.cat([c[cond_key][i]] * n_gen, dim=0)
+                    elif isinstance(c[cond_key], dict):
+                        for k in c[cond_key].keys():
+                            c[cond_key][k] = torch.cat([c[cond_key][k]] * n_gen, dim=0)
+                    else:
+                        c[cond_key] = torch.cat([c[cond_key]] * n_gen, dim=0)
+
+                text = text * n_gen
+
+                if unconditional_guidance_scale != 1.0:
+                    unconditional_conditioning = {}
+                    for key in self.cond_stage_model_metadata:
+                        model_idx = self.cond_stage_model_metadata[key]["model_idx"]
+                        unconditional_conditioning[key] = self.cond_stage_models[
+                            model_idx
+                        ].get_unconditional_condition(batch_size)
+
+                fnames = list(super().get_input(batch, "fname"))
+                samples, _ = self.sample_log(
+                    cond=c,
+                    batch_size=batch_size,
+                    x_T=x_T,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=unconditional_conditioning,
+                    use_plms=use_plms,
+                )
+
+                mel = self.decode_first_stage(samples)
+
+                waveform = self.mel_spectrogram_to_waveform(
+                    mel, savepath=waveform_save_path, bs=None, name=fnames, save=False
+                )
+
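+                # Rerank the n_gen candidates per input by CLAP audio-text similarity and keep the best one.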
+                if n_gen > 1:
+                    try:
+                        best_index = []
+                        similarity = self.clap.cos_similarity(
+                            torch.FloatTensor(waveform).squeeze(1), text
+                        )
+                        for i in range(z.shape[0]):
+                            candidates = similarity[i :: z.shape[0]]
+                            max_index = torch.argmax(candidates).item()
+                            best_index.append(i + max_index * z.shape[0])
+
+                        waveform = waveform[best_index]
+
+                        print("Similarity between generated audio and text", similarity)
+                        print("Choose the following indexes:", best_index)
+                    except Exception as e:
+                        print("Warning: while calculating CLAP score (not fatal), ", e)
+                self.save_waveform(waveform, waveform_save_path, name=fnames)
+        return waveform_save_path
+
+
+class DiffusionWrapper(nn.Module):
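+    """Wraps the diffusion UNet and routes each entry of the conditioning dictionary to the
+    matching UNet input (channel concatenation, FiLM-style embedding, or cross-attention context)."""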
+    def __init__(self, diff_model_config, conditioning_key):
+        super().__init__()
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+
+        self.conditioning_key = conditioning_key
+
+        for key in self.conditioning_key:
+            if (
+                "concat" in key
+                or "crossattn" in key
+                or "hybrid" in key
+                or "film" in key
+                or "noncond" in key
+            ):
+                continue
+            else:
+                raise ValueError("The conditioning key %s is illegal" % key)
+
+        self.being_verbosed_once = False
+
+    def forward(self, x, t, cond_dict: dict = {}):
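+        """Assemble the UNet inputs from ``cond_dict``: "concat" keys are concatenated to x along the
+        channel dimension, "film" keys are stacked into the embedding vector y, "crossattn" keys supply
+        (context, attention-mask) pairs, and "noncond" keys are skipped."""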
+        x = x.contiguous()
+        t = t.contiguous()
+
+        # x with condition (or maybe not)
+        xc = x
+
+        y = None
+        context_list, attn_mask_list = [], []
+
+        conditional_keys = cond_dict.keys()
+
+        for key in conditional_keys:
+            if "concat" in key:
+                xc = torch.cat([x, cond_dict[key].unsqueeze(1)], dim=1)
+            elif "film" in key:
+                if y is None:
+                    y = cond_dict[key].squeeze(1)
+                else:
+                    y = torch.cat([y, cond_dict[key].squeeze(1)], dim=-1)
+            elif "crossattn" in key:
+                # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys())
+                if isinstance(cond_dict[key], dict):
+                    for k in cond_dict[key].keys():
+                        if "crossattn" in k:
+                            context, attn_mask = cond_dict[key][
+                                k
+                            ]  # crossattn_audiomae_pooled: torch.Size([12, 128, 768])
+                else:
+                    assert len(cond_dict[key]) == 2, (
+                        "The context condition for %s you returned should have two element, one context one mask"
+                        % (key)
+                    )
+                    context, attn_mask = cond_dict[key]
+
+                # The input to the UNet model is a list of context matrices
+                context_list.append(context)
+                attn_mask_list.append(attn_mask)
+
+            elif (
+                "noncond" in key
+            ):  # If the conditioning module computes its own loss, include the keyword "noncond" in its return dictionary
+                continue
+            else:
+                raise NotImplementedError()
+
+        # if(not self.being_verbosed_once):
+        #     print("The input shape to the diffusion model is as follows:")
+        #     print("xc", xc.size())
+        #     print("t", t.size())
+        #     for i in range(len(context_list)):
+        #         print("context_%s" % i, context_list[i].size(), attn_mask_list[i].size())
+        #     if(y is not None):
+        #         print("y", y.size())
+        #     self.being_verbosed_once = True
+        out = self.diffusion_model(
+            xc, t, context_list=context_list, y=y, context_attn_mask_list=attn_mask_list
+        )
+        return out
+
+    # NOTE: The block below appears to be a leftover fragment of a PyTorch Lightning
+    # ``training_step``; it was unreachable dead code after the ``return`` above, so it
+    # is wrapped into a method here and kept only for reference.
+    def training_step(self, batch, batch_idx):
+        self.warmup_step()
+
+        if (
+            self.state is None
+            and len(self.trainer.optimizers[0].state_dict()["state"].keys()) > 0
+        ):
+            self.state = (
+                self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"].clone()
+            )
+        elif self.state is not None and batch_idx % 1000 == 0:
+            assert (
+                torch.sum(
+                    torch.abs(
+                        self.state
+                        - self.trainer.optimizers[0].state_dict()["state"][0]["exp_avg"]
+                    )
+                )
+                > 1e-7
+            ), "Optimizer is not working"
+
+        if len(self.metrics_buffer.keys()) > 0:
+            for k in self.metrics_buffer.keys():
+                self.log(
+                    k,
+                    self.metrics_buffer[k],
+                    prog_bar=False,
+                    logger=True,
+                    on_step=True,
+                    on_epoch=False,
+                )
+                print(k, self.metrics_buffer[k])
+            self.metrics_buffer = {}
+
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(
+            {k: float(v) for k, v in loss_dict.items()},
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=True,
+        )
+
+        self.log(
+            "global_step",
+            float(self.global_step),
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=False,
+        )
+
+        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+        self.log(
+            "lr_abs",
+            float(lr),
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=False,
+        )
+
+
+if __name__ == "__main__":
+    import yaml
+
+    model_config = "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/stable-diffusion/models/ldm/text2img256/config.yaml"
+    model_config = yaml.load(open(model_config, "r"), Loader=yaml.FullLoader)
+
+    latent_diffusion = LatentDiffusion(**model_config["model"]["params"])
+
+    import ipdb
+
+    ipdb.set_trace()
diff --git a/audioldm2/latent_diffusion/models/plms.py b/audioldm2/latent_diffusion/models/plms.py
new file mode 100755
index 0000000000000000000000000000000000000000..9c80796442bd653ac3dc1970c12f621068a4d821
--- /dev/null
+++ b/audioldm2/latent_diffusion/models/plms.py
@@ -0,0 +1,360 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+    make_ddim_sampling_parameters,
+    make_ddim_timesteps,
+    noise_like,
+)
+
+
+class PLMSSampler(object):
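+    """Pseudo Linear Multistep (PLMS) sampler: reuses previous noise predictions to take
+    higher-order steps along a deterministic (eta = 0) DDIM-style trajectory."""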
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != self.model.device:
+                attr = attr.to(self.model.device)
+        setattr(self, name, attr)
+
+    def make_schedule(
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+    ):
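+        # PLMS assumes a deterministic trajectory, so any nonzero ddim_eta is silently reset to 0.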
+        if ddim_eta != 0:
+            ddim_eta = 0
+            # raise ValueError('ddim_eta must be 0 for PLMS')
+
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
+        alphas_cumprod = self.model.alphas_cumprod
+        assert (
+            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+        ), "alphas have to be defined for each timestep"
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer(
+            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+        )
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod",
+            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod",
+            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+        )
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer(
+            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+        )
+
+    @torch.no_grad()
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
+        conditioning=None,
+        callback=None,
+        normals_sequence=None,
+        img_callback=None,
+        quantize_x0=False,
+        eta=0.0,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        verbose=True,
+        x_T=None,
+        log_every_t=100,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+        **kwargs,
+    ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(
+                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                    )
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(
+                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                    )
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f"Data shape for PLMS sampling is {size}")
+
+        samples, intermediates = self.plms_sampling(
+            conditioning,
+            size,
+            callback=callback,
+            img_callback=img_callback,
+            quantize_denoised=quantize_x0,
+            mask=mask,
+            x0=x0,
+            ddim_use_original_steps=False,
+            noise_dropout=noise_dropout,
+            temperature=temperature,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+            x_T=x_T,
+            log_every_t=log_every_t,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning=unconditional_conditioning,
+        )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def plms_sampling(
+        self,
+        cond,
+        shape,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+    ):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = (
+                self.ddpm_num_timesteps
+                if ddim_use_original_steps
+                else self.ddim_timesteps
+            )
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = (
+                int(
+                    min(timesteps / self.ddim_timesteps.shape[0], 1)
+                    * self.ddim_timesteps.shape[0]
+                )
+                - 1
+            )
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {"x_inter": [img], "pred_x0": [img]}
+        time_range = (
+            list(reversed(range(0, timesteps)))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
+        old_eps = []
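+        # Buffer of the most recent noise predictions; p_sample_plms uses up to the last three
+        # together with the current one for its linear multistep update.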
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts_next = torch.full(
+                (b,),
+                time_range[min(i + 1, len(time_range) - 1)],
+                device=device,
+                dtype=torch.long,
+            )
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(
+                    x0, ts
+                )  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1.0 - mask) * img
+
+            outs = self.p_sample_plms(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                old_eps=old_eps,
+                t_next=ts_next,
+            )
+            img, pred_x0, e_t = outs
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates["x_inter"].append(img)
+                intermediates["pred_x0"].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_plms(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        old_eps=None,
+        t_next=None,
+    ):
+        b, *_, device = *x.shape, x.device
+
+        def get_model_output(x, t):
+            if (
+                unconditional_conditioning is None
+                or unconditional_guidance_scale == 1.0
+            ):
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+            if score_corrector is not None:
+                assert self.model.parameterization == "eps"
+                e_t = score_corrector.modify_score(
+                    self.model, e_t, x, t, c, **corrector_kwargs
+                )
+
+            return e_t
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = (
+            self.model.alphas_cumprod_prev
+            if use_original_steps
+            else self.ddim_alphas_prev
+        )
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod
+            if use_original_steps
+            else self.ddim_sqrt_one_minus_alphas
+        )
+        sigmas = (
+            self.model.ddim_sigmas_for_original_num_steps
+            if use_original_steps
+            else self.ddim_sigmas
+        )
+
+        def get_x_prev_and_pred_x0(e_t, index):
+            # select parameters corresponding to the currently considered timestep
+            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+            sqrt_one_minus_at = torch.full(
+                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+            )
+
+            # current prediction for x_0
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            if quantize_denoised:
+                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            # direction pointing to x_t
+            dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.0:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (
+                55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
+            ) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        return x_prev, pred_x0, e_t
diff --git a/audioldm2/latent_diffusion/modules/__init__.py b/audioldm2/latent_diffusion/modules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/attention.py b/audioldm2/latent_diffusion/modules/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..6116342da98249c681ddb5f696b48dc0f5ac69f2
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/attention.py
@@ -0,0 +1,467 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import checkpoint
+
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
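+    """Gated-GELU feed-forward unit: projects to twice the output width and gates one half
+    with the GELU of the other."""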
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = (
+            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+            if not glu
+            else GEGLU(dim, inner_dim)
+        )
+
+        self.net = nn.Sequential(
+            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(
+        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+    )
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        q, k, v = rearrange(
+            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+        )
+        k = k.softmax(dim=-1)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        out = rearrange(
+            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+        )
+        return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = rearrange(q, "b c h w -> b (h w) c")
+        k = rearrange(k, "b c h w -> b c (h w)")
+        w_ = torch.einsum("bij,bjk->bik", q, k)
+
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = rearrange(v, "b c h w -> b c (h w)")
+        w_ = rearrange(w_, "b i j -> b j i")
+        h_ = torch.einsum("bij,bjk->bik", v, w_)
+        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+# class CrossAttention(nn.Module):
+#     """
+#     ### Cross Attention Layer
+#     This falls-back to self-attention when conditional embeddings are not specified.
+#     """
+
+#     use_flash_attention: bool = True
+
+#     # use_flash_attention: bool = False
+#     def __init__(
+#         self,
+#         query_dim,
+#         context_dim=None,
+#         heads=8,
+#         dim_head=64,
+#         dropout=0.0,
+#         is_inplace: bool = True,
+#     ):
+#         # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
+#         """
+#         :param d_model: is the input embedding size
+#         :param n_heads: is the number of attention heads
+#         :param d_head: is the size of a attention head
+#         :param d_cond: is the size of the conditional embeddings
+#         :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+#             save memory
+#         """
+#         super().__init__()
+
+#         self.is_inplace = is_inplace
+#         self.n_heads = heads
+#         self.d_head = dim_head
+
+#         # Attention scaling factor
+#         self.scale = dim_head**-0.5
+
+#         # The normal self-attention layer
+#         if context_dim is None:
+#             context_dim = query_dim
+
+#         # Query, key and value mappings
+#         d_attn = dim_head * heads
+#         self.to_q = nn.Linear(query_dim, d_attn, bias=False)
+#         self.to_k = nn.Linear(context_dim, d_attn, bias=False)
+#         self.to_v = nn.Linear(context_dim, d_attn, bias=False)
+
+#         # Final linear layer
+#         self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
+
+#         # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+#         # Flash attention is only used if it's installed
+#         # and `CrossAttention.use_flash_attention` is set to `True`.
+#         try:
+#             # You can install flash attention by cloning their Github repo,
+#             # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+#             # and then running `python setup.py install`
+#             from flash_attn.flash_attention import FlashAttention
+
+#             self.flash = FlashAttention()
+#             # Set the scale for scaled dot-product attention.
+#             self.flash.softmax_scale = self.scale
+#         # Set to `None` if it's not installed
+#         except ImportError:
+#             self.flash = None
+
+#     def forward(self, x, context=None, mask=None):
+#         """
+#         :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+#         :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+#         """
+
+#         # If `cond` is `None` we perform self attention
+#         has_cond = context is not None
+#         if not has_cond:
+#             context = x
+
+#         # Get query, key and value vectors
+#         q = self.to_q(x)
+#         k = self.to_k(context)
+#         v = self.to_v(context)
+
+#         # Use flash attention if it's available and the head size is less than or equal to `128`
+#         if (
+#             CrossAttention.use_flash_attention
+#             and self.flash is not None
+#             and not has_cond
+#             and self.d_head <= 128
+#         ):
+#             return self.flash_attention(q, k, v)
+#         # Otherwise, fallback to normal attention
+#         else:
+#             return self.normal_attention(q, k, v)
+
+#     def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+#         """
+#         #### Flash Attention
+#         :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+#         :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+#         :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+#         """
+
+#         # Get batch size and number of elements along sequence axis (`width * height`)
+#         batch_size, seq_len, _ = q.shape
+
+#         # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+#         # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+#         qkv = torch.stack((q, k, v), dim=2)
+#         # Split the heads
+#         qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+#         # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+#         # fit this size.
+#         if self.d_head <= 32:
+#             pad = 32 - self.d_head
+#         elif self.d_head <= 64:
+#             pad = 64 - self.d_head
+#         elif self.d_head <= 128:
+#             pad = 128 - self.d_head
+#         else:
+#             raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
+
+#         # Pad the heads
+#         if pad:
+#             qkv = torch.cat(
+#                 (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+#             )
+
+#         # Compute attention
+#         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+#         # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+#         # TODO here I add the dtype changing
+#         out, _ = self.flash(qkv.type(torch.float16))
+#         # Truncate the extra head size
+#         out = out[:, :, :, : self.d_head].float()
+#         # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+#         out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+#         # Map to `[batch_size, height * width, d_model]` with a linear layer
+#         return self.to_out(out)
+
+#     def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+#         """
+#         #### Normal Attention
+
+#         :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+#         :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+#         :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+#         """
+
+#         # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+#         q = q.view(*q.shape[:2], self.n_heads, -1)  # [bs, 64, 20, 32]
+#         k = k.view(*k.shape[:2], self.n_heads, -1)  # [bs, 1, 20, 32]
+#         v = v.view(*v.shape[:2], self.n_heads, -1)
+
+#         # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+#         attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
+
+#         # Compute softmax
+#         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+#         if self.is_inplace:
+#             half = attn.shape[0] // 2
+#             attn[half:] = attn[half:].softmax(dim=-1)
+#             attn[:half] = attn[:half].softmax(dim=-1)
+#         else:
+#             attn = attn.softmax(dim=-1)
+
+#         # Compute attention output
+#         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+#         # attn: [bs, 20, 64, 1]
+#         # v: [bs, 1, 20, 32]
+#         out = torch.einsum("bhij,bjhd->bihd", attn, v)
+#         # Reshape to `[batch_size, height * width, n_heads * d_head]`
+#         out = out.reshape(*out.shape[:2], -1)
+#         # Map to `[batch_size, height * width, d_model]` with a linear layer
+#         return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
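+    """Multi-head attention over an optional context; falls back to self-attention when
+    ``context`` is None. An optional boolean mask marks the valid context positions."""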
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+        )
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+        if exists(mask):
+            mask = rearrange(mask, "b ... -> b (...)")
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, "b j -> (b h) () j", h=h)
+            sim.masked_fill_(~(mask == 1), max_neg_value)
+
+        # attention, what we cannot get enough of
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
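+    """Pre-norm transformer block: self-attention, cross-attention over ``context``, and an
+    (optionally gated) feed-forward, each wrapped in a residual connection."""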
+    def __init__(
+        self,
+        dim,
+        n_heads,
+        d_head,
+        dropout=0.0,
+        context_dim=None,
+        gated_ff=True,
+        checkpoint=True,
+    ):
+        super().__init__()
+        self.attn1 = CrossAttention(
+            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+        )  # is a self-attention
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(
+            query_dim=dim,
+            context_dim=context_dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None, mask=None):
+        if context is None:
+            return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
+        else:
+            return checkpoint(
+                self._forward, (x, context, mask), self.parameters(), self.checkpoint
+            )
+
+    def _forward(self, x, context=None, mask=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+    Finally, reshape to image
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        n_heads,
+        d_head,
+        depth=1,
+        dropout=0.0,
+        context_dim=None,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+
+        self.proj_in = nn.Conv2d(
+            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
+                )
+                for d in range(depth)
+            ]
+        )
+
+        self.proj_out = zero_module(
+            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+        )
+
+    def forward(self, x, context=None, mask=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        for block in self.transformer_blocks:
+            x = block(x, context=context, mask=mask)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = self.proj_out(x)
+        return x + x_in
diff --git a/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py
new file mode 100755
index 0000000000000000000000000000000000000000..f02fa05e163076641b92bbeabceb5f89edb0f18e
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/AudioMAE.py
@@ -0,0 +1,149 @@
+"""
+Reference Repo: https://github.com/facebookresearch/AudioMAE
+"""
+
+import torch
+import torch.nn as nn
+from timm.models.layers import to_2tuple
+import audioldm2.latent_diffusion.modules.audiomae.models_vit as models_vit
+import audioldm2.latent_diffusion.modules.audiomae.models_mae as models_mae
+
+# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
+
+
+class PatchEmbed_new(nn.Module):
+    """Flexible Image to Patch Embedding"""
+
+    def __init__(
+        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
+    ):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=stride
+        )  # with overlapped patches
+        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+        # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
+        # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
+        self.patch_hw = (h, w)
+        self.num_patches = h * w
+
+    def get_output_shape(self, img_size):
+        # todo: don't be lazy..
+        return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #    f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
+class AudioMAE(nn.Module):
+    """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        model = models_vit.__dict__["vit_base_patch16"](
+            num_classes=527,
+            drop_path_rate=0.1,
+            global_pool=True,
+            mask_2d=True,
+            use_custom_patch=False,
+        )
+
+        img_size = (1024, 128)
+        emb_dim = 768
+
+        model.patch_embed = PatchEmbed_new(
+            img_size=img_size,
+            patch_size=(16, 16),
+            in_chans=1,
+            embed_dim=emb_dim,
+            stride=16,
+        )
+        num_patches = model.patch_embed.num_patches
+        # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
+        model.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
+        )  # fixed sin-cos embedding
+
+        # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth'
+        # checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        # msg = model.load_state_dict(checkpoint['model'], strict=False)
+        # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
+
+        self.model = model
+
+    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
+        """
+        x: mel fbank [Batch, 1, T, F]
+        mask_t_prob: 'T masking ratio (percentage of removed patches).'
+        mask_f_prob: 'F masking ratio (percentage of removed patches).'
+        """
+        return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
+
+
+class Vanilla_AudioMAE(nn.Module):
+    """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        model = models_mae.__dict__["mae_vit_base_patch16"](
+            in_chans=1, audio_exp=True, img_size=(1024, 128)
+        )
+
+        # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth'
+        # checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        # msg = model.load_state_dict(checkpoint['model'], strict=False)
+
+        # Skip the missing keys of decoder modules (not required)
+        # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
+
+        self.model = model.eval()
+
+    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
+        """
+        x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
+        mask_ratio: 'masking ratio (percentage of removed patches).'
+        """
+        with torch.no_grad():
+            # embed: [B, 513, 768] for mask_ratio=0.0
+            if no_mask:
+                if no_average:
+                    raise RuntimeError("This function is deprecated")
+                    embed = self.model.forward_encoder_no_random_mask_no_average(
+                        x
+                    )  # mask_ratio
+                else:
+                    embed = self.model.forward_encoder_no_mask(x)  # mask_ratio
+            else:
+                raise RuntimeError("This function is deprecated")
+                embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
+        return embed
+
+
+if __name__ == "__main__":
+    model = Vanilla_AudioMAE().cuda()
+    input = torch.randn(4, 1, 1024, 128).cuda()
+    print("The first run")
+    embed = model(input, mask_ratio=0.0, no_mask=True)
+    print(embed)
+    print("The second run")
+    embed = model(input, mask_ratio=0.0)
+    print(embed)
diff --git a/audioldm2/latent_diffusion/modules/audiomae/__init__.py b/audioldm2/latent_diffusion/modules/audiomae/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_mae.py b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py
new file mode 100755
index 0000000000000000000000000000000000000000..7ab0076710a08a7451dd4096bd6eb2f8f6e641aa
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/models_mae.py
@@ -0,0 +1,613 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.models.vision_transformer import Block
+from audioldm2.latent_diffusion.modules.audiomae.util.pos_embed import (
+    get_2d_sincos_pos_embed,
+    get_2d_sincos_pos_embed_flexible,
+)
+from audioldm2.latent_diffusion.modules.audiomae.util.patch_embed import (
+    PatchEmbed_new,
+    PatchEmbed_org,
+)
+
+
+class MaskedAutoencoderViT(nn.Module):
+    """Masked Autoencoder with VisionTransformer backbone"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        stride=10,
+        in_chans=3,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        decoder_embed_dim=512,
+        decoder_depth=8,
+        decoder_num_heads=16,
+        mlp_ratio=4.0,
+        norm_layer=nn.LayerNorm,
+        norm_pix_loss=False,
+        audio_exp=False,
+        alpha=0.0,
+        temperature=0.2,
+        mode=0,
+        contextual_depth=8,
+        use_custom_patch=False,
+        split_pos=False,
+        pos_trainable=False,
+        use_nce=False,
+        beta=4.0,
+        decoder_mode=0,
+        mask_t_prob=0.6,
+        mask_f_prob=0.5,
+        mask_2d=False,
+        epoch=0,
+        no_shift=False,
+    ):
+        super().__init__()
+
+        self.audio_exp = audio_exp
+        self.embed_dim = embed_dim
+        self.decoder_embed_dim = decoder_embed_dim
+        # --------------------------------------------------------------------------
+        # MAE encoder specifics
+        if use_custom_patch:
+            print(
+                f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
+            )
+            self.patch_embed = PatchEmbed_new(
+                img_size=img_size,
+                patch_size=patch_size,
+                in_chans=in_chans,
+                embed_dim=embed_dim,
+                stride=stride,
+            )
+        else:
+            self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
+        self.use_custom_patch = use_custom_patch
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+        # self.split_pos = split_pos # not useful
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
+        )  # fixed sin-cos embedding
+
+        self.encoder_depth = depth
+        self.contextual_depth = contextual_depth
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=norm_layer,
+                )  # qk_scale=None
+                for i in range(depth)
+            ]
+        )
+        self.norm = norm_layer(embed_dim)
+
+        # --------------------------------------------------------------------------
+        # MAE decoder specifics
+        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+        self.decoder_pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, decoder_embed_dim),
+            requires_grad=pos_trainable,
+        )  # fixed sin-cos embedding
+
+        self.no_shift = no_shift
+
+        self.decoder_mode = decoder_mode
+        if (
+            self.use_custom_patch
+        ):  # overlapped patches as in AST. Similar performance yet compute heavy
+            window_size = (6, 6)
+            feat_size = (102, 12)
+        else:
+            window_size = (4, 4)
+            feat_size = (64, 8)
+        if self.decoder_mode == 1:
+            decoder_modules = []
+            for index in range(16):
+                if self.no_shift:
+                    shift_size = (0, 0)
+                else:
+                    if (index % 2) == 0:
+                        shift_size = (0, 0)
+                    else:
+                        shift_size = (2, 0)
+                    # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
+                decoder_modules.append(
+                    SwinTransformerBlock(
+                        dim=decoder_embed_dim,
+                        num_heads=16,
+                        feat_size=feat_size,
+                        window_size=window_size,
+                        shift_size=shift_size,
+                        mlp_ratio=mlp_ratio,
+                        drop=0.0,
+                        drop_attn=0.0,
+                        drop_path=0.0,
+                        extra_norm=False,
+                        sequential_attn=False,
+                        norm_layer=norm_layer,  # nn.LayerNorm,
+                    )
+                )
+            self.decoder_blocks = nn.ModuleList(decoder_modules)
+        else:
+            # Transformer
+            self.decoder_blocks = nn.ModuleList(
+                [
+                    Block(
+                        decoder_embed_dim,
+                        decoder_num_heads,
+                        mlp_ratio,
+                        qkv_bias=True,
+                        norm_layer=norm_layer,
+                    )  # qk_scale=None,
+                    for i in range(decoder_depth)
+                ]
+            )
+
+        self.decoder_norm = norm_layer(decoder_embed_dim)
+        self.decoder_pred = nn.Linear(
+            decoder_embed_dim, patch_size**2 * in_chans, bias=True
+        )  # decoder to patch
+
+        # --------------------------------------------------------------------------
+
+        self.norm_pix_loss = norm_pix_loss
+
+        self.patch_size = patch_size
+        self.stride = stride
+
+        # audio exps
+        self.alpha = alpha
+        self.T = temperature
+        self.mode = mode
+        self.use_nce = use_nce
+        self.beta = beta
+
+        self.log_softmax = nn.LogSoftmax(dim=-1)
+
+        self.mask_t_prob = mask_t_prob
+        self.mask_f_prob = mask_f_prob
+        self.mask_2d = mask_2d
+
+        self.epoch = epoch
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # initialization
+        # initialize (and freeze) pos_embed by sin-cos embedding
+        if self.audio_exp:
+            pos_embed = get_2d_sincos_pos_embed_flexible(
+                self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
+            )
+        else:
+            pos_embed = get_2d_sincos_pos_embed(
+                self.pos_embed.shape[-1],
+                int(self.patch_embed.num_patches**0.5),
+                cls_token=True,
+            )
+        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+        if self.audio_exp:
+            decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
+                self.decoder_pos_embed.shape[-1],
+                self.patch_embed.patch_hw,
+                cls_token=True,
+            )
+        else:
+            decoder_pos_embed = get_2d_sincos_pos_embed(
+                self.decoder_pos_embed.shape[-1],
+                int(self.patch_embed.num_patches**0.5),
+                cls_token=True,
+            )
+        self.decoder_pos_embed.data.copy_(
+            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
+        )
+
+        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+        w = self.patch_embed.proj.weight.data
+        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+        torch.nn.init.normal_(self.cls_token, std=0.02)
+        torch.nn.init.normal_(self.mask_token, std=0.02)
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def patchify(self, imgs):
+        """
+        imgs: (N, 3, H, W)
+        x: (N, L, patch_size**2 *3)
+        L = (H/p)*(W/p)
+        """
+        p = self.patch_embed.patch_size[0]
+        # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+        if self.audio_exp:
+            if self.use_custom_patch:  # overlapped patch
+                h, w = self.patch_embed.patch_hw
+                # TODO: patch size and stride are currently fixed; make h/w configurable later
+                x = imgs.unfold(2, self.patch_size, self.stride).unfold(
+                    3, self.patch_size, self.stride
+                )  # n,1,H,W -> n,1,h,w,p,p
+                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
+                # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
+                # x = torch.einsum('nchpwq->nhwpqc', x)
+                # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
+            else:
+                h = imgs.shape[2] // p
+                w = imgs.shape[3] // p
+                # h,w = self.patch_embed.patch_hw
+                x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
+                x = torch.einsum("nchpwq->nhwpqc", x)
+                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
+        else:
+            h = w = imgs.shape[2] // p
+            x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+            x = torch.einsum("nchpwq->nhwpqc", x)
+            x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+
+        return x
+
+    def unpatchify(self, x):
+        """
+        x: (N, L, patch_size**2 *3)
+        specs: (N, 1, H, W)
+        """
+        p = self.patch_embed.patch_size[0]
+        h = 1024 // p
+        w = 128 // p
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
+        x = torch.einsum("nhwpqc->nchpwq", x)
+        specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
+        return specs
+
+    def random_masking(self, x, mask_ratio):
+        """
+        Perform per-sample random masking by per-sample shuffling.
+        Per-sample shuffling is done by argsort random noise.
+        x: [N, L, D], sequence
+        """
+        N, L, D = x.shape  # batch, length, dim
+        len_keep = int(L * (1 - mask_ratio))
+
+        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(
+            noise, dim=1
+        )  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([N, L], device=x.device)
+        mask[:, :len_keep] = 0
+        # unshuffle to get the binary mask
+        mask = torch.gather(mask, dim=1, index=ids_restore)
+
+        return x_masked, mask, ids_restore
+
+    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
+        """
+        2D: spectrogram masking along time (t) and frequency (f) with mask_t_prob and mask_f_prob
+        Perform per-sample random masking by per-sample shuffling.
+        Per-sample shuffling is done by argsort random noise.
+        x: [N, L, D], sequence
+        """
+        N, L, D = x.shape  # batch, length, dim
+        if self.use_custom_patch:  # overlapped patch
+            T = 101
+            F = 12
+        else:
+            T = 64
+            F = 8
+        # x = x.reshape(N, T, F, D)
+        len_keep_t = int(T * (1 - mask_t_prob))
+        len_keep_f = int(F * (1 - mask_f_prob))
+
+        # noise for mask in time
+        noise_t = torch.rand(N, T, device=x.device)  # noise in [0, 1]
+        # sort noise for each sample along time
+        ids_shuffle_t = torch.argsort(
+            noise_t, dim=1
+        )  # ascend: small is keep, large is remove
+        ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
+        ids_keep_t = ids_shuffle_t[:, :len_keep_t]
+        # noise mask in freq
+        noise_f = torch.rand(N, F, device=x.device)  # noise in [0, 1]
+        ids_shuffle_f = torch.argsort(
+            noise_f, dim=1
+        )  # ascend: small is keep, large is remove
+        ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
+        ids_keep_f = ids_shuffle_f[:, :len_keep_f]  #
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        # mask in freq
+        mask_f = torch.ones(N, F, device=x.device)
+        mask_f[:, :len_keep_f] = 0
+        mask_f = (
+            torch.gather(mask_f, dim=1, index=ids_restore_f)
+            .unsqueeze(1)
+            .repeat(1, T, 1)
+        )  # N,T,F
+        # mask in time
+        mask_t = torch.ones(N, T, device=x.device)
+        mask_t[:, :len_keep_t] = 0
+        mask_t = (
+            torch.gather(mask_t, dim=1, index=ids_restore_t)
+            .unsqueeze(1)
+            .repeat(1, F, 1)
+            .permute(0, 2, 1)
+        )  # N,T,F
+        mask = 1 - (1 - mask_t) * (1 - mask_f)  # N, T, F
+
+        # get masked x
+        id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
+        id2res = id2res + (T * F) * mask  # push masked elements past all kept ones within each sample
+        id2res2 = torch.argsort(id2res.flatten(start_dim=1))
+        ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+        ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
+        mask = mask.flatten(start_dim=1)
+
+        return x_masked, mask, ids_restore
+
+    def forward_encoder(self, x, mask_ratio, mask_2d=False):
+        # embed patches
+        x = self.patch_embed(x)
+        # add pos embed w/o cls token
+        x = x + self.pos_embed[:, 1:, :]
+
+        # masking: length -> length * mask_ratio
+        if mask_2d:
+            x, mask, ids_restore = self.random_masking_2d(
+                x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
+            )
+        else:
+            x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+        # append cls token
+        cls_token = self.cls_token + self.pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        return x, mask, ids_restore, None
+
+    def forward_encoder_no_random_mask_no_average(self, x):
+        # embed patches
+        x = self.patch_embed(x)
+        # add pos embed w/o cls token
+        x = x + self.pos_embed[:, 1:, :]
+
+        # masking: length -> length * mask_ratio
+        # if mask_2d:
+        #     x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
+        # else:
+        #     x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+        # append cls token
+        cls_token = self.cls_token + self.pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        return x
+
+    def forward_encoder_no_mask(self, x):
+        # embed patches
+        x = self.patch_embed(x)
+
+        # add pos embed w/o cls token
+        x = x + self.pos_embed[:, 1:, :]
+
+        # masking: length -> length * mask_ratio
+        # x, mask, ids_restore = self.random_masking(x, mask_ratio)
+        # append cls token
+        cls_token = self.cls_token + self.pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        contextual_embs = []
+        for n, blk in enumerate(self.blocks):
+            x = blk(x)
+            if n > self.contextual_depth:
+                contextual_embs.append(self.norm(x))
+        # x = self.norm(x)
+        contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
+
+        return contextual_emb
+
+    def forward_decoder(self, x, ids_restore):
+        # embed tokens
+        x = self.decoder_embed(x)
+
+        # append mask tokens to sequence
+        mask_tokens = self.mask_token.repeat(
+            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
+        )
+        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
+        x_ = torch.gather(
+            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
+        )  # unshuffle
+        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
+
+        # add pos embed
+        x = x + self.decoder_pos_embed
+
+        if self.decoder_mode != 0:
+            B, L, D = x.shape
+            x = x[:, 1:, :]
+            if self.use_custom_patch:
+                x = x.reshape(B, 101, 12, D)
+                x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1)  # hack: repeat the last time frame so 101 -> 102 rows (divisible by the Swin window)
+                x = x.reshape(B, 1224, D)
+        if self.decoder_mode > 3:  # mvit
+            x = self.decoder_blocks(x)
+        else:
+            # apply Transformer blocks
+            for blk in self.decoder_blocks:
+                x = blk(x)
+        x = self.decoder_norm(x)
+
+        # predictor projection
+        pred = self.decoder_pred(x)
+
+        # remove cls token
+        if self.decoder_mode != 0:
+            if self.use_custom_patch:
+                pred = pred.reshape(B, 102, 12, 256)
+                pred = pred[:, :101, :, :]  # drop the time frame padded in above
+                pred = pred.reshape(B, 1212, 256)
+            # cls token was already removed before the decoder blocks; nothing to strip
+        else:
+            pred = pred[:, 1:, :]
+        return pred, None, None  # emb, emb_pixel
+
+    def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
+        """
+        imgs: [N, 3, H, W]
+        pred: [N, L, p*p*3]
+        mask: [N, L], 0 is keep, 1 is remove,
+        """
+        target = self.patchify(imgs)
+        if norm_pix_loss:
+            mean = target.mean(dim=-1, keepdim=True)
+            var = target.var(dim=-1, keepdim=True)
+            target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+        loss = (pred - target) ** 2
+        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+
+        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+        return loss
+
+    def forward(self, imgs, mask_ratio=0.8):
+        emb_enc, mask, ids_restore, _ = self.forward_encoder(
+            imgs, mask_ratio, mask_2d=self.mask_2d
+        )
+        pred, _, _ = self.forward_decoder(emb_enc, ids_restore)  # [N, L, p*p*3]
+        loss_recon = self.forward_loss(
+            imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
+        )
+        loss_contrastive = torch.zeros(1, device=imgs.device)  # contrastive term is unused; keep it on the input device
+        return loss_recon, pred, mask, loss_contrastive
+
+
+def mae_vit_small_patch16_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=16,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        decoder_embed_dim=512,
+        decoder_num_heads=16,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    return model
+
+
+def mae_vit_base_patch16_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        decoder_embed_dim=512,
+        decoder_num_heads=16,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    return model
+
+
+def mae_vit_large_patch16_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        decoder_embed_dim=512,
+        decoder_num_heads=16,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    return model
+
+
+def mae_vit_huge_patch14_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=14,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        decoder_embed_dim=512,
+        decoder_num_heads=16,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
+    return model
+
+
+# set recommended archs
+mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
+mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
+mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
+mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
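+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the original training code): how the
+    # audio MAE might be instantiated for 1024x128 log-mel spectrograms. The
+    # img_size/in_chans/mask settings below are illustrative assumptions.
+    model = mae_vit_base_patch16(
+        img_size=(1024, 128),  # assumed (time, mel bins) spectrogram size
+        in_chans=1,  # single-channel spectrogram
+        audio_exp=True,  # use the flexible sin-cos positional embedding
+        mask_2d=True,  # mask time and frequency patches separately
+    )
+    dummy = torch.randn(2, 1, 1024, 128)  # (N, C, T, F)
+    loss, pred, mask, _ = model(dummy, mask_ratio=0.8)
+    print(loss.item(), pred.shape, mask.shape)  # scalar, (2, 512, 256), (2, 512)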
diff --git a/audioldm2/latent_diffusion/modules/audiomae/models_vit.py b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py
new file mode 100755
index 0000000000000000000000000000000000000000..cb37adbc16cfb9a232493c473c9400f199655b6c
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/models_vit.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import timm.models.vision_transformer
+
+
+class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+    """Vision Transformer with support for global average pooling"""
+
+    def __init__(
+        self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
+    ):
+        super(VisionTransformer, self).__init__(**kwargs)
+
+        self.global_pool = global_pool
+        if self.global_pool:
+            norm_layer = kwargs["norm_layer"]
+            embed_dim = kwargs["embed_dim"]
+            self.fc_norm = norm_layer(embed_dim)
+            del self.norm  # remove the original norm; fc_norm is applied after pooling instead
+        self.mask_2d = mask_2d
+        self.use_custom_patch = use_custom_patch
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+        x = x + self.pos_embed[:, 1:, :]
+        cls_token = self.cls_token + self.pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        if self.global_pool:
+            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
+            outcome = self.fc_norm(x)
+        else:
+            x = self.norm(x)
+            outcome = x[:, 0]
+
+        return outcome
+
+    def random_masking(self, x, mask_ratio):
+        """
+        Perform per-sample random masking by per-sample shuffling.
+        Per-sample shuffling is done by argsort random noise.
+        x: [N, L, D], sequence
+        """
+        N, L, D = x.shape  # batch, length, dim
+        len_keep = int(L * (1 - mask_ratio))
+
+        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(
+            noise, dim=1
+        )  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([N, L], device=x.device)
+        mask[:, :len_keep] = 0
+        # unshuffle to get the binary mask
+        mask = torch.gather(mask, dim=1, index=ids_restore)
+
+        return x_masked, mask, ids_restore
+
+    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
+        """
+        2D: spectrogram masking along time (t) and frequency (f) with mask_t_prob and mask_f_prob
+        Perform per-sample random masking by per-sample shuffling.
+        Per-sample shuffling is done by argsort random noise.
+        x: [N, L, D], sequence
+        """
+
+        N, L, D = x.shape  # batch, length, dim
+        if self.use_custom_patch:
+            # for AudioSet (AS)
+            T = 101
+            F = 12
+            # for ESC: T = 50, F = 12
+            # for SPC: T = 12, F = 12
+        else:
+            # for AudioSet (AS)
+            T = 64
+            F = 8
+            # for ESC: T = 32, F = 8
+            # for SPC: T = 8, F = 8
+
+        # mask T
+        x = x.reshape(N, T, F, D)
+        len_keep_T = int(T * (1 - mask_t_prob))
+        noise = torch.rand(N, T, device=x.device)  # noise in [0, 1]
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(
+            noise, dim=1
+        )  # ascend: small is keep, large is remove
+        ids_keep = ids_shuffle[:, :len_keep_T]
+        index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
+        # x_masked = torch.gather(x, dim=1, index=index)
+        # x_masked = x_masked.reshape(N,len_keep_T*F,D)
+        x = torch.gather(x, dim=1, index=index)  # N, len_keep_T(T'), F, D
+
+        # mask F
+        # x = x.reshape(N, T, F, D)
+        x = x.permute(0, 2, 1, 3)  # N T' F D => N F T' D
+        len_keep_F = int(F * (1 - mask_f_prob))
+        noise = torch.rand(N, F, device=x.device)  # noise in [0, 1]
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(
+            noise, dim=1
+        )  # ascend: small is keep, large is remove
+        ids_keep = ids_shuffle[:, :len_keep_F]
+        # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
+        index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
+        x_masked = torch.gather(x, dim=1, index=index)
+        x_masked = x_masked.permute(0, 2, 1, 3)  # N F' T' D => N T' F' D
+        # x_masked = x_masked.reshape(N,len_keep*T,D)
+        x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
+
+        return x_masked, None, None
+
+    def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
+        B = x.shape[0]  # 4,1,1024,128
+        x = self.patch_embed(x)  # 4, 512, 768
+
+        x = x + self.pos_embed[:, 1:, :]
+        if self.mask_2d:
+            x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
+        else:
+            x, mask, ids_restore = self.random_masking(x, mask_t_prob)
+        cls_token = self.cls_token + self.pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.pos_drop(x)
+
+        # apply Transformer blocks
+        for blk in self.blocks:
+            x = blk(x)
+
+        if self.global_pool:
+            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
+            outcome = self.fc_norm(x)
+        else:
+            x = self.norm(x)
+            outcome = x[:, 0]
+
+        return outcome
+
+    # overwrite original timm
+    def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
+        if mask_t_prob > 0.0 or mask_f_prob > 0.0:
+            x = self.forward_features_mask(
+                x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
+            )
+        else:
+            x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def vit_small_patch16(**kwargs):
+    model = VisionTransformer(
+        patch_size=16,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    return model
+
+
+def vit_base_patch16(**kwargs):
+    model = VisionTransformer(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    return model
+
+
+def vit_large_patch16(**kwargs):
+    model = VisionTransformer(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    return model
+
+
+def vit_huge_patch14(**kwargs):
+    model = VisionTransformer(
+        patch_size=14,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs
+    )
+    return model
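+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not from the original repo): build the audio ViT
+    # classifier and run a masked forward pass. The img_size, in_chans and
+    # num_classes values are illustrative assumptions; T=64, F=8 in
+    # random_masking_2d correspond to a 1024x128 input with 16x16 patches.
+    model = vit_base_patch16(
+        img_size=(1024, 128),  # assumed (time, mel bins) spectrogram size
+        in_chans=1,
+        num_classes=527,  # e.g. an AudioSet tagging head (assumption)
+        global_pool=True,
+        mask_2d=True,
+    )
+    x = torch.randn(2, 1, 1024, 128)
+    logits = model(x, mask_t_prob=0.2, mask_f_prob=0.2)
+    print(logits.shape)  # (2, 527)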
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/crop.py b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py
new file mode 100755
index 0000000000000000000000000000000000000000..525e3c783c3d348e593dc89c2b5fb8520918e9ea
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/crop.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+
+from torchvision import transforms
+from torchvision.transforms import functional as F
+
+
+class RandomResizedCrop(transforms.RandomResizedCrop):
+    """
+    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
+    This may lead to results different from torchvision's version.
+    Following BYOL's TF code:
+    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
+    """
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        width, height = F._get_image_size(img)
+        area = height * width
+
+        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+        log_ratio = torch.log(torch.tensor(ratio))
+        aspect_ratio = torch.exp(
+            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+        ).item()
+
+        w = int(round(math.sqrt(target_area * aspect_ratio)))
+        h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+        w = min(w, width)
+        h = min(h, height)
+
+        i = torch.randint(0, height - h + 1, size=(1,)).item()
+        j = torch.randint(0, width - w + 1, size=(1,)).item()
+
+        return i, j, h, w
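+
+
+# Hedged usage sketch (assumption): this class is a drop-in replacement for
+# torchvision's RandomResizedCrop in a standard augmentation pipeline, e.g.
+#
+#   train_transform = transforms.Compose(
+#       [RandomResizedCrop(224, scale=(0.2, 1.0)), transforms.RandomHorizontalFlip()]
+#   )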
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py
new file mode 100755
index 0000000000000000000000000000000000000000..b90f89a7d5f78c31bc9113dd88b632b0c234f10a
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/datasets.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+import os
+import PIL
+
+from torchvision import datasets, transforms
+
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def build_dataset(is_train, args):
+    transform = build_transform(is_train, args)
+
+    root = os.path.join(args.data_path, "train" if is_train else "val")
+    dataset = datasets.ImageFolder(root, transform=transform)
+
+    print(dataset)
+
+    return dataset
+
+
+def build_transform(is_train, args):
+    mean = IMAGENET_DEFAULT_MEAN
+    std = IMAGENET_DEFAULT_STD
+    # train transform
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=args.input_size,
+            is_training=True,
+            color_jitter=args.color_jitter,
+            auto_augment=args.aa,
+            interpolation="bicubic",
+            re_prob=args.reprob,
+            re_mode=args.remode,
+            re_count=args.recount,
+            mean=mean,
+            std=std,
+        )
+        return transform
+
+    # eval transform
+    t = []
+    if args.input_size <= 224:
+        crop_pct = 224 / 256
+    else:
+        crop_pct = 1.0
+    size = int(args.input_size / crop_pct)
+    t.append(
+        transforms.Resize(
+            size, interpolation=PIL.Image.BICUBIC
+        ),  # to maintain same ratio w.r.t. 224 images
+    )
+    t.append(transforms.CenterCrop(args.input_size))
+
+    t.append(transforms.ToTensor())
+    t.append(transforms.Normalize(mean, std))
+    return transforms.Compose(t)
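+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (assumption): `args` only needs the fields read
+    # above; argparse.Namespace stands in for the real training arguments and
+    # data_path is a placeholder.
+    from argparse import Namespace
+
+    args = Namespace(
+        input_size=224,
+        color_jitter=0.4,
+        aa="rand-m9-mstd0.5-inc1",
+        reprob=0.25,
+        remode="pixel",
+        recount=1,
+        data_path="/path/to/imagenet",
+    )
+    print(build_transform(is_train=False, args=args))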
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lars.py b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc43923d22cf2c9af4ae9166612c3f3477faf254
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/lars.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# LARS optimizer, implementation from MoCo v3:
+# https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+
+import torch
+
+
+class LARS(torch.optim.Optimizer):
+    """
+    LARS optimizer; trust-ratio scaling and weight decay are skipped for parameters with ndim <= 1 (biases and norm weights).
+    """
+
+    def __init__(
+        self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
+    ):
+        defaults = dict(
+            lr=lr,
+            weight_decay=weight_decay,
+            momentum=momentum,
+            trust_coefficient=trust_coefficient,
+        )
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        for g in self.param_groups:
+            for p in g["params"]:
+                dp = p.grad
+
+                if dp is None:
+                    continue
+
+                if p.ndim > 1:  # if not normalization gamma/beta or bias
+                    dp = dp.add(p, alpha=g["weight_decay"])
+                    param_norm = torch.norm(p)
+                    update_norm = torch.norm(dp)
+                    one = torch.ones_like(param_norm)
+                    q = torch.where(
+                        param_norm > 0.0,
+                        torch.where(
+                            update_norm > 0,
+                            (g["trust_coefficient"] * param_norm / update_norm),
+                            one,
+                        ),
+                        one,
+                    )
+                    dp = dp.mul(q)
+
+                param_state = self.state[p]
+                if "mu" not in param_state:
+                    param_state["mu"] = torch.zeros_like(p)
+                mu = param_state["mu"]
+                mu.mul_(g["momentum"]).add_(dp)
+                p.add_(mu, alpha=-g["lr"])
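+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (assumption): one optimization step on a toy model.
+    model = torch.nn.Linear(16, 4)
+    optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)
+    loss = model(torch.randn(8, 16)).pow(2).mean()
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()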
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py
new file mode 100755
index 0000000000000000000000000000000000000000..e90ed69d7b8d019dbf5d90571541668e2bd8efe8
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/lr_decay.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# ELECTRA https://github.com/google-research/electra
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+
+def param_groups_lrd(
+    model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
+):
+    """
+    Parameter groups for layer-wise lr decay
+    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+    """
+    param_group_names = {}
+    param_groups = {}
+
+    num_layers = len(model.blocks) + 1
+
+    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+
+        # no decay: all 1D parameters and model specific ones
+        if p.ndim == 1 or n in no_weight_decay_list:
+            g_decay = "no_decay"
+            this_decay = 0.0
+        else:
+            g_decay = "decay"
+            this_decay = weight_decay
+
+        layer_id = get_layer_id_for_vit(n, num_layers)
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+        if group_name not in param_group_names:
+            this_scale = layer_scales[layer_id]
+
+            param_group_names[group_name] = {
+                "lr_scale": this_scale,
+                "weight_decay": this_decay,
+                "params": [],
+            }
+            param_groups[group_name] = {
+                "lr_scale": this_scale,
+                "weight_decay": this_decay,
+                "params": [],
+            }
+
+        param_group_names[group_name]["params"].append(n)
+        param_groups[group_name]["params"].append(p)
+
+    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+    return list(param_groups.values())
+
+
+def get_layer_id_for_vit(name, num_layers):
+    """
+    Assign a parameter with its layer id
+    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+    """
+    if name in ["cls_token", "pos_embed"]:
+        return 0
+    elif name.startswith("patch_embed"):
+        return 0
+    elif name.startswith("blocks"):
+        return int(name.split(".")[1]) + 1
+    else:
+        return num_layers
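+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (assumption): build layer-wise decayed parameter
+    # groups for a timm ViT and pass them to AdamW; the extra "lr_scale" key
+    # is read later by the learning-rate schedule.
+    import timm
+    import torch
+
+    model = timm.create_model("vit_base_patch16_224", pretrained=False)
+    param_groups = param_groups_lrd(
+        model,
+        weight_decay=0.05,
+        no_weight_decay_list=["cls_token", "pos_embed"],
+        layer_decay=0.75,
+    )
+    optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
+    print(len(param_groups), "parameter groups")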
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py
new file mode 100755
index 0000000000000000000000000000000000000000..efe184d8e3fb63ec6b4f83375b6ea719985900de
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/lr_sched.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Decay the learning rate with half-cycle cosine after warmup"""
+    if epoch < args.warmup_epochs:
+        lr = args.lr * epoch / args.warmup_epochs
+    else:
+        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
+            1.0
+            + math.cos(
+                math.pi
+                * (epoch - args.warmup_epochs)
+                / (args.epochs - args.warmup_epochs)
+            )
+        )
+    for param_group in optimizer.param_groups:
+        if "lr_scale" in param_group:
+            param_group["lr"] = lr * param_group["lr_scale"]
+        else:
+            param_group["lr"] = lr
+    return lr
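+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (assumption): 5 warmup epochs, then a half-cycle
+    # cosine decay from lr down to min_lr over the remaining epochs.
+    from argparse import Namespace
+
+    import torch
+
+    args = Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=50)
+    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
+    for epoch in (0, 5, 25, 49):
+        print(epoch, adjust_learning_rate(optimizer, epoch, args))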
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/misc.py b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..74184e09e23e0e174350b894b0cff29600c18b71
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/misc.py
@@ -0,0 +1,453 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from math import inf  # torch._six was removed in recent PyTorch releases
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value,
+        )
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+        )
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append("{}: {}".format(name, str(meter)))
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ""
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+        log_msg = [
+            header,
+            "[{0" + space_fmt + "}/{1}]",
+            "eta: {eta}",
+            "{meters}",
+            "time: {time}",
+            "data: {data}",
+        ]
+        if torch.cuda.is_available():
+            log_msg.append("max mem: {memory:.0f}")
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
+                else:
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                        )
+                    )
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(
+            "{} Total time: {} ({:.4f} s / it)".format(
+                header, total_time_str, total_time / len(iterable)
+            )
+        )
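+
+
+# Hedged usage sketch (assumption): how MetricLogger is typically driven from a
+# training loop; data_loader and train_step are placeholders.
+#
+#   metric_logger = MetricLogger(delimiter="  ")
+#   header = "Epoch: [0]"
+#   for samples, targets in metric_logger.log_every(data_loader, 20, header):
+#       loss = train_step(samples, targets)
+#       metric_logger.update(loss=float(loss))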
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in the master process
+    """
+    builtin_print = builtins.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        force = force or (get_world_size() > 8)
+        if is_master or force:
+            now = datetime.datetime.now().time()
+            builtin_print("[{}] ".format(now), end="")  # print with time stamp
+            builtin_print(*args, **kwargs)
+
+    builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if args.dist_on_itp:
+        args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+        args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+        args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+        args.dist_url = "tcp://%s:%s" % (
+            os.environ["MASTER_ADDR"],
+            os.environ["MASTER_PORT"],
+        )
+        os.environ["LOCAL_RANK"] = str(args.gpu)
+        os.environ["RANK"] = str(args.rank)
+        os.environ["WORLD_SIZE"] = str(args.world_size)
+        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        setup_for_distributed(is_master=True)  # hack
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(
+        "| distributed init (rank {}): {}, gpu {}".format(
+            args.rank, args.dist_url, args.gpu
+        ),
+        flush=True,
+    )
+    torch.distributed.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+    )
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(
+        self,
+        loss,
+        optimizer,
+        clip_grad=None,
+        parameters=None,
+        create_graph=False,
+        update_grad=True,
+    ):
+        self._scaler.scale(loss).backward(create_graph=create_graph)
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(
+                    optimizer
+                )  # unscale the gradients of optimizer's assigned params in-place
+                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
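+
+
+# Hedged usage sketch (assumption): a typical AMP training step with the scaler
+# above; model/optimizer/samples/targets/criterion and clip_grad=1.0 are
+# placeholders, not names from this repository.
+#
+#   loss_scaler = NativeScalerWithGradNormCount()
+#   with torch.cuda.amp.autocast():
+#       loss = criterion(model(samples), targets)
+#   grad_norm = loss_scaler(
+#       loss, optimizer, clip_grad=1.0, parameters=model.parameters()
+#   )
+#   optimizer.zero_grad()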
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.0)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(
+            torch.stack(
+                [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
+            ),
+            norm_type,
+        )
+    return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+    if loss_scaler is not None:
+        checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)]
+        for checkpoint_path in checkpoint_paths:
+            to_save = {
+                "model": model_without_ddp.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "epoch": epoch,
+                "scaler": loss_scaler.state_dict(),
+                "args": args,
+            }
+
+            save_on_master(to_save, checkpoint_path)
+    else:
+        client_state = {"epoch": epoch}
+        model.save_checkpoint(
+            save_dir=args.output_dir,
+            tag="checkpoint-%s" % epoch_name,
+            client_state=client_state,
+        )
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+    if args.resume:
+        if args.resume.startswith("https"):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location="cpu", check_hash=True
+            )
+        else:
+            checkpoint = torch.load(args.resume, map_location="cpu")
+        model_without_ddp.load_state_dict(checkpoint["model"])
+        print("Resume checkpoint %s" % args.resume)
+        if (
+            "optimizer" in checkpoint
+            and "epoch" in checkpoint
+            and not (hasattr(args, "eval") and args.eval)
+        ):
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            args.start_epoch = checkpoint["epoch"] + 1
+            if "scaler" in checkpoint:
+                loss_scaler.load_state_dict(checkpoint["scaler"])
+            print("With optim & sched!")
+
+
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
+
+
+# utils
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [
+        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+    ]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt):
+    # keys_to_copy=['pos_embed','patch_embed']
+    # replaced=0
+
+    vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"]
+    vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"]
+
+    # pos_emb % not trainable, use default
+    pos_embed_v = vmae_ckpt["pos_embed_v"]  # 1,589,768
+    pos_embed = pos_embed_v[:, 1:, :]  # 1,588,768
+    cls_embed = pos_embed_v[:, 0, :].unsqueeze(1)
+    pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1)  # 1, 14, 14, 768
+    print("Position interpolate from 14,14 to 64,8")
+    pos_embed = pos_embed.permute(0, 3, 1, 2)  # 1, 14,14,768 -> 1,768,14,14
+    pos_embed = torch.nn.functional.interpolate(
+        pos_embed, size=(64, 8), mode="bicubic", align_corners=False
+    )
+    pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(
+        1, 2
+    )  # 1, 768, 64, 8 -> 1, 64, 8, 768 -> 1, 512, 768
+    pos_embed = torch.cat((cls_embed, pos_embed), dim=1)
+    assert vmae_ckpt["pos_embed"].shape == pos_embed.shape
+    vmae_ckpt["pos_embed"] = pos_embed
+    # patch_emb
+    # aggregate 3 channels in video-rgb ckpt to 1 channel for audio
+    v_weight = vmae_ckpt["patch_embed_v.proj.weight"]  # 768,3,2,16,16
+    new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1))
+    assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape
+    vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight
+    vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"]
+
+    # hack
+    vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"]
+    vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"]
+
+    # replace transformer encoder
+    for k, v in vmae_ckpt.items():
+        if k.startswith("blocks."):
+            kk = k.replace("blocks.", "blocks_v.")
+            vmae_ckpt[k] = vmae_ckpt[kk]
+        elif k.startswith("blocks_v."):
+            pass
+        else:
+            print(k)
+    print(k)
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py
new file mode 100755
index 0000000000000000000000000000000000000000..ac1e4d436c6f79aef9bf1de32cdac5d4f037c775
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/patch_embed.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import to_2tuple
+
+
+class PatchEmbed_org(nn.Module):
+    """Image to Patch Embedding"""
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+        )
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #    f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        y = x.flatten(2).transpose(1, 2)
+        return y
+
+
+class PatchEmbed_new(nn.Module):
+    """Flexible Image to Patch Embedding"""
+
+    def __init__(
+        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
+    ):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=stride
+        )  # with overlapped patches
+        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+        # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
+        # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
+        self.patch_hw = (h, w)
+        self.num_patches = h * w
+
+    def get_output_shape(self, img_size):
+        # todo: don't be lazy..
+        return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #    f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        # x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.proj(x)  # 32, 1, 1024, 128 -> 32, 768, 101, 12
+        x = x.flatten(2)  # 32, 768, 101, 12 -> 32, 768, 1212
+        x = x.transpose(1, 2)  # 32, 768, 1212 -> 32, 1212, 768
+        return x
+
+
+class PatchEmbed3D_new(nn.Module):
+    """Flexible Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        video_size=(16, 224, 224),
+        patch_size=(2, 16, 16),
+        in_chans=3,
+        embed_dim=768,
+        stride=(2, 16, 16),
+    ):
+        super().__init__()
+
+        self.video_size = video_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+
+        self.proj = nn.Conv3d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=stride
+        )
+        _, _, t, h, w = self.get_output_shape(video_size)  # n, emb_dim, h, w
+        self.patch_thw = (t, h, w)
+        self.num_patches = t * h * w
+
+    def get_output_shape(self, video_size):
+        # todo: don't be lazy..
+        return self.proj(
+            torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])
+        ).shape
+
+    def forward(self, x):
+        B, C, T, H, W = x.shape
+        x = self.proj(x)  # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14
+        x = x.flatten(2)  # 32, 768, 1568
+        x = x.transpose(1, 2)  # 32, 768, 1568 -> 32, 1568, 768
+        return x
+
+
+if __name__ == "__main__":
+    # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
+    # input = torch.rand(8,1,1024,128)
+    # output = patch_emb(input)
+    # print(output.shape) # (8,512,64)
+
+    patch_emb = PatchEmbed3D_new(
+        video_size=(6, 224, 224),
+        patch_size=(2, 16, 16),
+        in_chans=3,
+        embed_dim=768,
+        stride=(2, 16, 16),
+    )
+    input = torch.rand(8, 3, 6, 224, 224)
+    output = patch_emb(input)
+    print(output.shape)  # (8, 588, 768): 3*14*14 patches of dim 768
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py
new file mode 100755
index 0000000000000000000000000000000000000000..2d9177ed98dffcf35264f38aff94e7f00fb50abf
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/pos_embed.py
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: (height, width) of the grid
+    return:
+    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size[0], dtype=np.float32)
+    grid_w = np.arange(grid_size[1], dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    # omega = np.arange(embed_dim // 2, dtype=np.float)
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
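+    # Layout of the result: out[m] = [sin(pos_m * omega_0), ..., sin(pos_m * omega_{D/2-1}),
+    #                                 cos(pos_m * omega_0), ..., cos(pos_m * omega_{D/2-1})],
+    # with omega_d = 1 / 10000^(2d / D).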
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches**0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print(
+                "Position interpolate from %dx%d to %dx%d"
+                % (orig_size, orig_size, new_size, new_size)
+            )
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(
+                -1, orig_size, orig_size, embedding_size
+            ).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens,
+                size=(new_size, new_size),
+                mode="bicubic",
+                align_corners=False,
+            )
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model["pos_embed"] = new_pos_embed
+
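+# Minimal usage sketch (hypothetical checkpoint layout): adapt a pretrained checkpoint's
+# positional embedding to this model's patch grid before loading its weights, e.g.
+#   ckpt = torch.load("pretrained.pth", map_location="cpu")["model"]
+#   interpolate_pos_embed(model, ckpt)
+#   model.load_state_dict(ckpt, strict=False)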
+
+def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        # new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print(
+                "Position interpolate from %dx%d to %dx%d"
+                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
+            )
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(
+                -1, orig_size[0], orig_size[1], embedding_size
+            ).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens,
+                size=(new_size[0], new_size[1]),
+                mode="bicubic",
+                align_corners=False,
+            )
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        # num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # not needed here
+        if orig_size != new_size:
+            print(
+                "Position interpolate from %dx%d to %dx%d"
+                % (orig_size[0], orig_size[1], new_size[0], new_size[1])
+            )
+            # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
+            pos_tokens = pos_embed_checkpoint[:, 1:, :]  # drop the cls token (re-attached below)
+            pos_tokens = pos_tokens.reshape(
+                -1, orig_size[0], orig_size[1], embedding_size
+            )  # .permute(0, 3, 1, 2)
+            # pos_tokens = torch.nn.functional.interpolate(
+            #    pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
+
+            # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            pos_tokens = pos_tokens[:, :, : new_size[1], :]  # assume only the time axis differs; crop it to new_size[1]
+            pos_tokens = pos_tokens.flatten(1, 2)
+            new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
+            checkpoint_model["pos_embed"] = new_pos_embed
+
+
+def interpolate_patch_embed_audio(
+    model,
+    checkpoint_model,
+    orig_channel,
+    new_channel=1,
+    kernel_size=(16, 16),
+    stride=(16, 16),
+    padding=(0, 0),
+):
+    if orig_channel != new_channel:
+        if "patch_embed.proj.weight" in checkpoint_model:
+            # aggregate 3 channels in rgb ckpt to 1 channel for audio
+            new_proj_weight = torch.nn.Parameter(
+                torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
+                    1
+                )
+            )
+            checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
diff --git a/audioldm2/latent_diffusion/modules/audiomae/util/stat.py b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py
new file mode 100755
index 0000000000000000000000000000000000000000..3f8137249503f6eaa25c3170fe5ef6b87f187347
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/audiomae/util/stat.py
@@ -0,0 +1,76 @@
+import numpy as np
+from scipy import stats
+from sklearn import metrics
+import torch
+
+
+def d_prime(auc):
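+    # d' = sqrt(2) * Phi^{-1}(AUC), where Phi^{-1} is the inverse CDF (ppf) of the
+    # standard normal distribution.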
+    standard_normal = stats.norm()
+    d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
+    return d_prime
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [
+        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+    ]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def calculate_stats(output, target):
+    """Calculate statistics including mAP, AUC, etc.
+
+    Args:
+      output: 2d array, (samples_num, classes_num)
+      target: 2d array, (samples_num, classes_num)
+
+    Returns:
+      stats: list of statistic of each class.
+    """
+
+    classes_num = target.shape[-1]
+    stats = []
+
+    # Accuracy; only meaningful for single-label classification (e.g. ESC-50), not for multi-label datasets such as AudioSet
+    acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
+
+    # Class-wise statistics
+    for k in range(classes_num):
+        # Average precision
+        avg_precision = metrics.average_precision_score(
+            target[:, k], output[:, k], average=None
+        )
+
+        # AUC
+        # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None)
+
+        # Precisions, recalls
+        (precisions, recalls, thresholds) = metrics.precision_recall_curve(
+            target[:, k], output[:, k]
+        )
+
+        # FPR, TPR
+        (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k])
+
+        save_every_steps = 1000  # Sample statistics to reduce size
+        stat_dict = {
+            "precisions": precisions[0::save_every_steps],
+            "recalls": recalls[0::save_every_steps],
+            "AP": avg_precision,
+            "fpr": fpr[0::save_every_steps],
+            "fnr": 1.0 - tpr[0::save_every_steps],
+            # 'auc': auc,
+            # note acc is not class-wise, this is just to keep consistent with other metrics
+            "acc": acc,
+        }
+        stats.append(stat_dict)
+
+    return stats
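+
+
+# Usage sketch: with output and target of shape (num_samples, num_classes), the mean
+# average precision is the mean of the per-class "AP" entries, e.g.
+#   stats = calculate_stats(output, target)
+#   mAP = np.mean([s["AP"] for s in stats])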
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py b/audioldm2/latent_diffusion/modules/diffusionmodules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/model.py b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..851f8dd28e80046c5e3c9d95bd37726024f1367c
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/diffusionmodules/model.py
@@ -0,0 +1,1069 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from audioldm2.latent_diffusion.util import instantiate_from_config
+from audioldm2.latent_diffusion.modules.attention import LinearAttention
+
+# Needed by FirstStagePostProcessor.encode_with_pretrained below; the module path is
+# assumed to mirror the original latent-diffusion layout.
+from audioldm2.latent_diffusion.modules.distributions.distributions import (
+    DiagonalGaussianDistribution,
+)
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings, as used in Denoising Diffusion Probabilistic
+    Models (implementation taken from Fairseq). This matches the implementation in
+    tensor2tensor, but differs slightly from the description in Section 3.5 of
+    "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
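+    # Geometric frequency schedule: freq_i = 10000^(-i / (half_dim - 1)), i = 0 .. half_dim - 1.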
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+    return torch.nn.GroupNorm(
+        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+    )
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class UpsampleTimeStride4(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=5, stride=1, padding=2
+            )
+
+    def forward(self, x):
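+        # Interpolate the first spatial dimension (time, per the class name) by 4x and
+        # the second by 2x.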
+        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # Do time downsampling here
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=2, padding=0
+            )
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class DownsampleTimeStride4(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # Do time downsampling here
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
+            )
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout,
+        temb_channels=512,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                )
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
+                )
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class LinAttnBlock(LinearAttention):
+    """to match AttnBlock usage"""
+
+    def __init__(self, in_channels):
+        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w).contiguous()
+        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
+        k = k.reshape(b, c, h * w).contiguous()  # b,c,hw
+        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w).contiguous()
+        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(
+            v, w_
+        ).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w).contiguous()
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
+    # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        return AttnBlock(in_channels)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        use_timestep=True,
+        use_linear_attn=False,
+        attn_type="vanilla",
+    ):
+        super().__init__()
+        if use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = self.ch * 4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList(
+                [
+                    torch.nn.Linear(self.ch, self.temb_ch),
+                    torch.nn.Linear(self.temb_ch, self.temb_ch),
+                ]
+            )
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            skip_in = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch * in_ch_mult[i_level]
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in + skip_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_ch, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x, t=None, context=None):
+        # assert x.shape[2] == x.shape[3] == self.resolution
+        if context is not None:
+            # assume aligned context, cat along channel axis
+            x = torch.cat((x, context), dim=1)
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb
+                )
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+    def get_last_layer(self):
+        return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        z_channels,
+        double_z=True,
+        use_linear_attn=False,
+        attn_type="vanilla",
+        downsample_time_stride4_levels=[],
+        **ignore_kwargs,
+    ):
+        super().__init__()
+        if use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.downsample_time_stride4_levels = downsample_time_stride4_levels
+
+        if len(self.downsample_time_stride4_levels) > 0:
+            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
+                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
+                % str(self.num_resolutions)
+            )
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                if i_level in self.downsample_time_stride4_levels:
+                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
+                else:
+                    down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in,
+            2 * z_channels if double_z else z_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, x):
+        # timestep embedding
+        temb = None
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        z_channels,
+        give_pre_end=False,
+        tanh_out=False,
+        use_linear_attn=False,
+        downsample_time_stride4_levels=[],
+        attn_type="vanilla",
+        **ignorekwargs,
+    ):
+        super().__init__()
+        if use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+        self.tanh_out = tanh_out
+        self.downsample_time_stride4_levels = downsample_time_stride4_levels
+
+        if len(self.downsample_time_stride4_levels) > 0:
+            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
+                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
+                % str(self.num_resolutions)
+            )
+
+        # compute block_in and curr_res at lowest res
+        # in_ch_mult = (1,) + tuple(ch_mult)  # not needed in the decoder
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+        # print(
+        #     "Working with z of shape {} = {} dimensions.".format(
+        #         self.z_shape, np.prod(self.z_shape)
+        #     )
+        # )
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                if i_level - 1 in self.downsample_time_stride4_levels:
+                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
+                else:
+                    up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_ch, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, z):
+        # assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        if self.tanh_out:
+            h = torch.tanh(h)
+        return h
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        self.model = nn.ModuleList(
+            [
+                nn.Conv2d(in_channels, in_channels, 1),
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=2 * in_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                ),
+                ResnetBlock(
+                    in_channels=2 * in_channels,
+                    out_channels=4 * in_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                ),
+                ResnetBlock(
+                    in_channels=4 * in_channels,
+                    out_channels=2 * in_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                ),
+                nn.Conv2d(2 * in_channels, in_channels, 1),
+                Upsample(in_channels, with_conv=True),
+            ]
+        )
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
+            if i in [1, 2, 3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        ch,
+        num_res_blocks,
+        resolution,
+        ch_mult=(2, 2),
+        dropout=0.0,
+    ):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        for k, i_level in enumerate(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[k](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class LatentRescaler(nn.Module):
+    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+        super().__init__()
+        # residual block, interpolate, residual block
+        self.factor = factor
+        self.conv_in = nn.Conv2d(
+            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.res_block1 = nn.ModuleList(
+            [
+                ResnetBlock(
+                    in_channels=mid_channels,
+                    out_channels=mid_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.attn = AttnBlock(mid_channels)
+        self.res_block2 = nn.ModuleList(
+            [
+                ResnetBlock(
+                    in_channels=mid_channels,
+                    out_channels=mid_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+        self.conv_out = nn.Conv2d(
+            mid_channels,
+            out_channels,
+            kernel_size=1,
+        )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        x = torch.nn.functional.interpolate(
+            x,
+            size=(
+                int(round(x.shape[2] * self.factor)),
+                int(round(x.shape[3] * self.factor)),
+            ),
+        )
+        x = self.attn(x).contiguous()
+        for block in self.res_block2:
+            x = block(x, None)
+        x = self.conv_out(x)
+        return x
+
+
+class MergedRescaleEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        ch,
+        resolution,
+        out_ch,
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        ch_mult=(1, 2, 4, 8),
+        rescale_factor=1.0,
+        rescale_module_depth=1,
+    ):
+        super().__init__()
+        intermediate_chn = ch * ch_mult[-1]
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            num_res_blocks=num_res_blocks,
+            ch=ch,
+            ch_mult=ch_mult,
+            z_channels=intermediate_chn,
+            double_z=False,
+            resolution=resolution,
+            attn_resolutions=attn_resolutions,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            out_ch=None,
+        )
+        self.rescaler = LatentRescaler(
+            factor=rescale_factor,
+            in_channels=intermediate_chn,
+            mid_channels=intermediate_chn,
+            out_channels=out_ch,
+            depth=rescale_module_depth,
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.rescaler(x)
+        return x
+
+
+class MergedRescaleDecoder(nn.Module):
+    def __init__(
+        self,
+        z_channels,
+        out_ch,
+        resolution,
+        num_res_blocks,
+        attn_resolutions,
+        ch,
+        ch_mult=(1, 2, 4, 8),
+        dropout=0.0,
+        resamp_with_conv=True,
+        rescale_factor=1.0,
+        rescale_module_depth=1,
+    ):
+        super().__init__()
+        tmp_chn = z_channels * ch_mult[-1]
+        self.decoder = Decoder(
+            out_ch=out_ch,
+            z_channels=tmp_chn,
+            attn_resolutions=attn_resolutions,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            in_channels=None,
+            num_res_blocks=num_res_blocks,
+            ch_mult=ch_mult,
+            resolution=resolution,
+            ch=ch,
+        )
+        self.rescaler = LatentRescaler(
+            factor=rescale_factor,
+            in_channels=z_channels,
+            mid_channels=tmp_chn,
+            out_channels=tmp_chn,
+            depth=rescale_module_depth,
+        )
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size // in_size)) + 1
+        factor_up = 1.0 + (out_size % in_size)
+        print(
+            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+        )
+        self.rescaler = LatentRescaler(
+            factor=factor_up,
+            in_channels=in_channels,
+            mid_channels=2 * in_channels,
+            out_channels=in_channels,
+        )
+        self.decoder = Decoder(
+            out_ch=out_channels,
+            resolution=out_size,
+            z_channels=in_channels,
+            num_res_blocks=2,
+            attn_resolutions=[],
+            in_channels=None,
+            ch=in_channels,
+            ch_mult=[ch_mult for _ in range(num_blocks)],
+        )
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(
+                f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
+            )
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=4, stride=2, padding=1
+            )
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor == 1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(
+                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+            )
+        return x
+
+
+class FirstStagePostProcessor(nn.Module):
+    def __init__(
+        self,
+        ch_mult: list,
+        in_channels,
+        pretrained_model: nn.Module = None,
+        reshape=False,
+        n_channels=None,
+        dropout=0.0,
+        pretrained_config=None,
+    ):
+        super().__init__()
+        if pretrained_config is None:
+            assert (
+                pretrained_model is not None
+            ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+            self.pretrained_model = pretrained_model
+        else:
+            assert (
+                pretrained_config is not None
+            ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+            self.instantiate_pretrained(pretrained_config)
+
+        self.do_reshape = reshape
+
+        if n_channels is None:
+            n_channels = self.pretrained_model.encoder.ch
+
+        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+        self.proj = nn.Conv2d(
+            in_channels, n_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        blocks = []
+        downs = []
+        ch_in = n_channels
+        for m in ch_mult:
+            blocks.append(
+                ResnetBlock(
+                    in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
+                )
+            )
+            ch_in = m * n_channels
+            downs.append(Downsample(ch_in, with_conv=False))
+
+        self.model = nn.ModuleList(blocks)
+        self.downsampler = nn.ModuleList(downs)
+
+    def instantiate_pretrained(self, config):
+        model = instantiate_from_config(config)
+        self.pretrained_model = model.eval()
+        # self.pretrained_model.train = False
+        for param in self.pretrained_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def encode_with_pretrained(self, x):
+        c = self.pretrained_model.encode(x)
+        if isinstance(c, DiagonalGaussianDistribution):
+            c = c.mode()
+        return c
+
+    def forward(self, x):
+        z_fs = self.encode_with_pretrained(x)
+        z = self.proj_norm(z_fs)
+        z = self.proj(z)
+        z = nonlinearity(z)
+
+        for submodel, downmodel in zip(self.model, self.downsampler):
+            z = submodel(z, temb=None)
+            z = downmodel(z)
+
+        if self.do_reshape:
+            z = rearrange(z, "b c h w -> b (h w) c")
+        return z
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py
new file mode 100755
index 0000000000000000000000000000000000000000..e006e5a332c3cde5f4e221f003b270d86b34e933
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,1103 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+)
+from audioldm2.latent_diffusion.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+    pass
+
+
+def convert_module_to_f32(x):
+    pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+        )
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1).contiguous()  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb, context_list=None, mask_list=None):
+        # The first spatial transformer block does not have context
+        spatial_transformer_id = 0
+        context_list = [None] + (context_list if context_list is not None else [])
+        mask_list = [None] + (mask_list if mask_list is not None else [])
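+        # Index 0 (None) is consumed by the first SpatialTransformer, so it receives no
+        # external context.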
+
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            elif isinstance(layer, SpatialTransformer):
+                if spatial_transformer_id >= len(context_list):
+                    context, mask = None, None
+                else:
+                    context, mask = (
+                        context_list[spatial_transformer_id],
+                        mask_list[spatial_transformer_id],
+                    )
+
+                x = layer(x, context, mask=mask)
+                spatial_transformer_id += 1
+            else:
+                x = layer(x)
+        return x
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(
+                dims, self.channels, self.out_channels, 3, padding=padding
+            )
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class TransposedUpsample(nn.Module):
+    "Learned 2x upsampling without padding"
+
+    def __init__(self, channels, out_channels=None, ks=5):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+
+        self.up = nn.ConvTranspose2d(
+            self.channels, self.out_channels, kernel_size=ks, stride=2
+        )
+
+    def forward(self, x):
+        return self.up(x)
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
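+        # For 3D inputs, keep the first dimension and downsample only the inner two.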
+        if use_conv:
+            self.op = conv_nd(
+                dims,
+                self.channels,
+                self.out_channels,
+                3,
+                stride=stride,
+                padding=padding,
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
+
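+
+# Illustrative sketch (not part of the original module): how a ResBlock is
+# typically driven. The `_demo_resblock` name is a hypothetical helper added
+# for documentation only; it relies on the module-level imports (`th`, `nn`)
+# already used by the classes above.
+def _demo_resblock():
+    # A 64-channel feature map conditioned on a 256-dim timestep embedding.
+    block = ResBlock(channels=64, emb_channels=256, dropout=0.0)
+    x = th.randn(2, 64, 16, 16)
+    emb = th.randn(2, 256)
+    out = block(x, emb)
+    assert out.shape == x.shape  # channels unchanged, so the skip connection is Identity
+    return out
+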
+
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(
+            self._forward, (x,), self.parameters(), True
+        )  # TODO: checkpointing is hard-coded to True here instead of self.use_checkpoint; also revisit the .half() call
+        # return pt_checkpoint(self._forward, x)  # pytorch
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1).contiguous()
+        qkv = self.qkv(self.norm(x)).contiguous()
+        h = self.attention(qkv).contiguous()
+        h = self.proj_out(h).contiguous()
+        return (x + h).reshape(b, c, *spatial).contiguous()
+
+
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timesteps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial**2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = (
+            qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
+        )
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length).contiguous()
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum(
+            "bts,bcs->bct",
+            weight,
+            v.reshape(bs * self.n_heads, ch, length).contiguous(),
+        )
+        return a.reshape(bs, -1, length).contiguous()
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
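+
+# Illustrative sketch (added for documentation, not part of the original file):
+# both attention variants consume a fused qkv tensor of width 3 * H * C, but
+# interpret its channel layout differently, so their outputs differ numerically
+# while sharing the same shape. `_demo_qkv_layouts` is a hypothetical helper.
+def _demo_qkv_layouts():
+    bs, heads, ch, length = 2, 4, 16, 10
+    qkv = th.randn(bs, 3 * heads * ch, length)
+    legacy = QKVAttentionLegacy(heads)(qkv)  # expects [N x (H * 3 * C) x T]
+    new = QKVAttention(heads)(qkv)           # expects [N x (3 * H * C) x T]
+    assert legacy.shape == new.shape == (bs, heads * ch, length)
+    return legacy, new
+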
+
+class UNetModel(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+                              a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        extra_sa_layer=True,
+        num_classes=None,
+        extra_film_condition_dim=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=True,  # custom transformer support
+        transformer_depth=1,  # custom transformer support
+        context_dim=None,  # custom transformer support
+        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+    ):
+        super().__init__()
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert (
+                num_head_channels != -1
+            ), "Either num_heads or num_head_channels has to be set"
+
+        if num_head_channels == -1:
+            assert (
+                num_heads != -1
+            ), "Either num_heads or num_head_channels has to be set"
+
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.extra_film_condition_dim = extra_film_condition_dim
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        # assert not (
+        #     self.num_classes is not None and self.extra_film_condition_dim is not None
+        # ), "For conditioning the UNet model, you can only use either a class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
+
+        if self.num_classes is not None:
+            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+        self.use_extra_film_by_concat = self.extra_film_condition_dim is not None
+
+        if self.extra_film_condition_dim is not None:
+            self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
+            print(
+                "+ Using extra FiLM conditioning on the UNet. Extra condition dimension is %s."
+                % self.extra_film_condition_dim
+            )
+
+        if context_dim is not None and not use_spatial_transformer:
+            assert (
+                use_spatial_transformer
+            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+
+        if context_dim is not None and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        elif context_dim is None:
+            context_dim = [None]  # At least use one spatial transformer
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim
+                        if (not self.use_extra_film_by_concat)
+                        else time_embed_dim * 2,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        dim_head = (
+                            ch // num_heads
+                            if use_spatial_transformer
+                            else num_head_channels
+                        )
+                    if extra_sa_layer:
+                        layers.append(
+                            SpatialTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=None,
+                            )
+                        )
+                    for context_dim_id in range(len(context_dim)):
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            )
+                            if not use_spatial_transformer
+                            else SpatialTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim[context_dim_id],
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim
+                            if (not self.use_extra_film_by_concat)
+                            else time_embed_dim * 2,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        middle_layers = [
+            ResBlock(
+                ch,
+                time_embed_dim
+                if (not self.use_extra_film_by_concat)
+                else time_embed_dim * 2,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            )
+        ]
+        if extra_sa_layer:
+            middle_layers.append(
+                SpatialTransformer(
+                    ch, num_heads, dim_head, depth=transformer_depth, context_dim=None
+                )
+            )
+        for context_dim_id in range(len(context_dim)):
+            middle_layers.append(
+                AttentionBlock(
+                    ch,
+                    use_checkpoint=use_checkpoint,
+                    num_heads=num_heads,
+                    num_head_channels=dim_head,
+                    use_new_attention_order=use_new_attention_order,
+                )
+                if not use_spatial_transformer
+                else SpatialTransformer(
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=transformer_depth,
+                    context_dim=context_dim[context_dim_id],
+                )
+            )
+        middle_layers.append(
+            ResBlock(
+                ch,
+                time_embed_dim
+                if (not self.use_extra_film_by_concat)
+                else time_embed_dim * 2,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            )
+        )
+        self.middle_block = TimestepEmbedSequential(*middle_layers)
+
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(num_res_blocks + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim
+                        if (not self.use_extra_film_by_concat)
+                        else time_embed_dim * 2,
+                        dropout,
+                        out_channels=model_channels * mult,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = (
+                            ch // num_heads
+                            if use_spatial_transformer
+                            else num_head_channels
+                        )
+                    if extra_sa_layer:
+                        layers.append(
+                            SpatialTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=None,
+                            )
+                        )
+                    for context_dim_id in range(len(context_dim)):
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            )
+                            if not use_spatial_transformer
+                            else SpatialTransformer(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim[context_dim_id],
+                            )
+                        )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim
+                            if (not self.use_extra_film_by_concat)
+                            else time_embed_dim * 2,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+        if self.predict_codebook_ids:
+            self.id_predictor = nn.Sequential(
+                normalization(ch),
+                conv_nd(dims, model_channels, n_embed, 1),
+                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+            )
+
+        self.shape_reported = False
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(
+        self,
+        x,
+        timesteps=None,
+        y=None,
+        context_list=None,
+        context_attn_mask_list=None,
+        **kwargs,
+    ):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param context_list: a list of conditioning tensors plugged in via cross-attention.
+        :param y: an [N] Tensor of labels if class-conditional, or an [N x extra_film_condition_dim] Tensor if FiLM-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        if not self.shape_reported:
+            # print("The shape of UNet input is", x.size())
+            self.shape_reported = True
+
+        assert (y is not None) == (
+            self.num_classes is not None or self.extra_film_condition_dim is not None
+        ), "must specify y if and only if the model is class-conditional or film embedding conditional"
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        # if self.num_classes is not None:
+        #     assert y.shape == (x.shape[0],)
+        #     emb = emb + self.label_emb(y)
+
+        if self.use_extra_film_by_concat:
+            emb = th.cat([emb, self.film_emb(y)], dim=-1)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context_list, context_attn_mask_list)
+            hs.append(h)
+        h = self.middle_block(h, emb, context_list, context_attn_mask_list)
+        for module in self.output_blocks:
+            concate_tensor = hs.pop()
+            h = th.cat([h, concate_tensor], dim=1)
+            h = module(h, emb, context_list, context_attn_mask_list)
+        h = h.type(x.dtype)
+        if self.predict_codebook_ids:
+            return self.id_predictor(h)
+        else:
+            return self.out(h)
+
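+
+# Illustrative sketch (hypothetical helper, not used by the model): mirrors the
+# bookkeeping in UNetModel.__init__, where `ds` doubles after every level
+# except the last and attention layers are inserted whenever `ds` appears in
+# `attention_resolutions`. Handy for sanity-checking a config before training.
+def _sketch_attention_placement(
+    channel_mult=(1, 2, 4, 8),
+    attention_resolutions=(4, 8),
+    model_channels=64,
+    num_res_blocks=2,
+):
+    ds, placements = 1, []
+    for level, mult in enumerate(channel_mult):
+        for _ in range(num_res_blocks):
+            if ds in attention_resolutions:
+                placements.append((ds, mult * model_channels))
+        if level != len(channel_mult) - 1:
+            ds *= 2
+    # With the defaults: [(4, 256), (4, 256), (8, 512), (8, 512)]
+    return placements
+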
+
+class EncoderUNetModel(nn.Module):
+    """
+    The half UNet model with attention and timestep embedding.
+    For usage, see UNet.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        pool="adaptive",
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            use_checkpoint=use_checkpoint,
+                            num_heads=num_heads,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=num_head_channels,
+                use_new_attention_order=use_new_attention_order,
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+        self.pool = pool
+        if pool == "adaptive":
+            self.out = nn.Sequential(
+                normalization(ch),
+                nn.SiLU(),
+                nn.AdaptiveAvgPool2d((1, 1)),
+                zero_module(conv_nd(dims, ch, out_channels, 1)),
+                nn.Flatten(),
+            )
+        elif pool == "attention":
+            assert num_head_channels != -1
+            self.out = nn.Sequential(
+                normalization(ch),
+                nn.SiLU(),
+                AttentionPool2d(
+                    (image_size // ds), ch, num_head_channels, out_channels
+                ),
+            )
+        elif pool == "spatial":
+            self.out = nn.Sequential(
+                nn.Linear(self._feature_size, 2048),
+                nn.ReLU(),
+                nn.Linear(2048, self.out_channels),
+            )
+        elif pool == "spatial_v2":
+            self.out = nn.Sequential(
+                nn.Linear(self._feature_size, 2048),
+                normalization(2048),
+                nn.SiLU(),
+                nn.Linear(2048, self.out_channels),
+            )
+        else:
+            raise NotImplementedError(f"Unexpected {pool} pooling")
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+
+    def forward(self, x, timesteps):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :return: an [N x K] Tensor of outputs.
+        """
+        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+        results = []
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb)
+            if self.pool.startswith("spatial"):
+                results.append(h.type(x.dtype).mean(dim=(2, 3)))
+        h = self.middle_block(h, emb)
+        if self.pool.startswith("spatial"):
+            results.append(h.type(x.dtype).mean(dim=(2, 3)))
+            h = th.cat(results, dim=-1)
+            return self.out(h)
+        else:
+            h = h.type(x.dtype)
+            return self.out(h)
diff --git a/audioldm2/latent_diffusion/modules/diffusionmodules/util.py b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..0d486f919a7ccc0586bc40225dac0ffb33aed01c
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/diffusionmodules/util.py
@@ -0,0 +1,294 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from audioldm2.latent_diffusion.util import instantiate_from_config
+
+
+def make_beta_schedule(
+    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+    if schedule == "linear":
+        betas = (
+            torch.linspace(
+                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+            )
+            ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(
+            linear_start, linear_end, n_timestep, dtype=torch.float64
+        )
+    elif schedule == "sqrt":
+        betas = (
+            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+            ** 0.5
+        )
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+
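+
+# Illustrative sketch (hypothetical `_demo_*` helper added for documentation):
+# the default linear schedule yields 1000 monotonically increasing betas
+# between linear_start and linear_end.
+def _demo_beta_schedule():
+    betas = make_beta_schedule("linear", n_timestep=1000)
+    assert betas.shape == (1000,) and betas[0] < betas[-1]
+    return betas
+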
+
+def make_ddim_timesteps(
+    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
+):
+    if ddim_discr_method == "uniform":
+        c = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+    elif ddim_discr_method == "quad":
+        ddim_timesteps = (
+            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
+        ).astype(int)
+    else:
+        raise NotImplementedError(
+            f'There is no ddim discretization method called "{ddim_discr_method}"'
+        )
+
+    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+    # add one to get the final alpha values right (the ones from first scale to data during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f"Selected timesteps for ddim sampler: {steps_out}")
+    return steps_out
+
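+
+# Illustrative sketch (hypothetical helper): thinning 1000 DDPM steps down to
+# 50 DDIM steps with the "uniform" strategy picks every 20th step and shifts
+# by one, giving 1, 21, ..., 981.
+def _demo_ddim_timesteps():
+    steps = make_ddim_timesteps(
+        "uniform", num_ddim_timesteps=50, num_ddpm_timesteps=1000, verbose=False
+    )
+    assert len(steps) == 50 and steps[0] == 1 and steps[-1] == 981
+    return steps
+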
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+    sigmas = eta * np.sqrt(
+        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+    )
+    if verbose:
+        print(
+            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+        )
+        print(
+            f"For the chosen value of eta, which is {eta}, "
+            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+        )
+    return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas)
+
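+
+# Illustrative sketch (hypothetical helper): using the cosine alpha_bar from
+# Nichol & Dhariwal (2021) as the input function. The betas are clipped at
+# max_beta to avoid a singularity at the end of the schedule.
+def _demo_betas_for_alpha_bar():
+    alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+    betas = betas_for_alpha_bar(1000, alpha_bar)
+    assert betas.shape == (1000,) and betas.max() <= 0.999
+    return betas
+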
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t).contiguous()
+    return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
+
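+
+# Illustrative sketch (hypothetical helper): gather one schedule coefficient
+# per batch element and reshape it so it broadcasts against an [N x C x H x W]
+# tensor.
+def _demo_extract_into_tensor():
+    schedule = torch.linspace(0.0, 1.0, steps=1000)  # e.g. a cumulative-alpha table
+    t = torch.tensor([0, 499, 999])
+    out = extract_into_tensor(schedule, t, x_shape=(3, 4, 64, 64))
+    assert out.shape == (3, 1, 1, 1)
+    return out
+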
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
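+
+# Illustrative sketch (hypothetical helper): with flag=True the activations of
+# `func` are recomputed during the backward pass instead of being stored,
+# trading compute for memory. Gradients still reach both the input and the
+# module parameters.
+def _demo_checkpoint():
+    layer = nn.Linear(8, 8)
+    x = torch.randn(2, 8, requires_grad=True)
+    y = checkpoint(layer, (x,), layer.parameters(), True)
+    y.sum().backward()
+    assert x.grad is not None and layer.weight.grad is not None
+    return x.grad
+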
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad():
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+    else:
+        embedding = repeat(timesteps, "b -> b d", d=dim)
+    return embedding
+
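+
+# Illustrative sketch (hypothetical helper): sinusoidal embeddings for a batch
+# of four (possibly fractional) timesteps, producing one row per timestep.
+def _demo_timestep_embedding():
+    ts = torch.tensor([0.0, 10.0, 100.0, 999.0])
+    emb = timestep_embedding(ts, dim=128)
+    assert emb.shape == (4, 128)
+    return emb
+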
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+    def __init__(self, c_concat_config, c_crossattn_config):
+        super().__init__()
+        self.concat_conditioner = instantiate_from_config(c_concat_config)
+        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+    def forward(self, c_concat, c_crossattn):
+        c_concat = self.concat_conditioner(c_concat)
+        c_crossattn = self.crossattn_conditioner(c_crossattn)
+        return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+        shape[0], *((1,) * (len(shape) - 1))
+    )
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
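+
+
+# Illustrative sketch (hypothetical helper): with repeat=True a single noise
+# sample is drawn and tiled across the batch, so every batch entry is identical.
+def _demo_noise_like():
+    n = noise_like((4, 3, 8, 8), device="cpu", repeat=True)
+    assert torch.equal(n[0], n[1])
+    return n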
diff --git a/audioldm2/latent_diffusion/modules/ema.py b/audioldm2/latent_diffusion/modules/ema.py
new file mode 100755
index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/ema.py
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+    def __init__(self, model, decay=0.9999, use_num_upates=True):
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.m_name2s_name = {}
+        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer(
+            "num_updates",
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
+        )
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                # remove '.' since that character is not allowed in buffer names
+                s_name = name.replace(".", "")
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)
+
+        self.collected_params = []
+
+    def forward(self, model):
+        decay = self.decay
+
+        if self.num_updates >= 0:
+            self.num_updates += 1
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+        one_minus_decay = 1.0 - decay
+
+        with torch.no_grad():
+            m_param = dict(model.named_parameters())
+            shadow_params = dict(self.named_buffers())
+
+            for key in m_param:
+                if m_param[key].requires_grad:
+                    sname = self.m_name2s_name[key]
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
+                else:
+                    assert key not in self.m_name2s_name
+
+    def copy_to(self, model):
+        m_param = dict(model.named_parameters())
+        shadow_params = dict(self.named_buffers())
+        for key in m_param:
+            if m_param[key].requires_grad:
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+            else:
+                assert key not in self.m_name2s_name
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
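+
+
+# Illustrative sketch (hypothetical helper, not part of the original class):
+# the usual EMA workflow updates the shadow weights once per optimizer step,
+# then temporarily swaps them in for evaluation.
+def _demo_litema():
+    model = nn.Linear(4, 4)
+    ema = LitEma(model)
+    for _ in range(3):
+        # ... optimizer.step() would normally happen here ...
+        ema(model)                   # update the shadow parameters
+    ema.store(model.parameters())    # stash the raw weights
+    ema.copy_to(model)               # evaluate with the EMA weights
+    ema.restore(model.parameters())  # restore the raw weights for training
+    return ema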
diff --git a/audioldm2/latent_diffusion/modules/encoders/__init__.py b/audioldm2/latent_diffusion/modules/encoders/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/encoders/modules.py b/audioldm2/latent_diffusion/modules/encoders/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..7a72339840c0c3b667e907ea07ee7cb755eb66fd
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/encoders/modules.py
@@ -0,0 +1,736 @@
+import torch
+import logging
+import torch.nn as nn
+from audioldm2.clap.open_clip import create_model
+from audioldm2.clap.training.data import get_audio_features
+import torchaudio
+from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel
+import torch.nn.functional as F
+from audioldm2.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE
+from audioldm2.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder
+
+from transformers import AutoTokenizer, T5Config
+
+from audioldm2.audiomae_gen.sequence_input import Sequence2AudioMAE
+import numpy as np
+
+"""
+The model forward function can return three types of data:
+1. tensor: used directly as the conditioning signal.
+2. dict: the main key holds the condition; other keys can carry loss terms, intermediate results, etc.
+3. list: of length 2, where the first element is the conditioning tensor and the second is the attention mask.
+
+The output shape for the cross-attention condition should be:
+x, x_mask = [bs, seq_len, emb_dim], [bs, seq_len]
+
+All returned data that will be used as diffusion input must be float.
+"""
+
+
+class PhonemeEncoder(nn.Module):
+    def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
+        super().__init__()
+        """
+        Example:
+            encoder = PhonemeEncoder(vocabs_size=40, pad_token_id=39)
+            data = torch.randint(0, 39, (2, 250))
+            text_emb, text_emb_mask = encoder(data)
+        """
+        assert pad_token_id is not None
+
+        self.device = None
+        self.PAD_LENGTH = int(pad_length)
+        self.pad_token_id = pad_token_id
+        self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)
+
+        self.text_encoder = TextEncoder(
+            n_vocab=vocabs_size,
+            out_channels=192,
+            hidden_channels=192,
+            filter_channels=768,
+            n_heads=2,
+            n_layers=6,
+            kernel_size=3,
+            p_dropout=0.1,
+        )
+
+        self.learnable_positional_embedding = torch.nn.Parameter(
+            torch.zeros((1, 192, self.PAD_LENGTH))
+        )  # [1, emb_dim, pad_length]
+        self.learnable_positional_embedding.requires_grad = True
+
+    # Required
+    def get_unconditional_condition(self, batchsize):
+        unconditional_tokens = self.pad_token_sequence.expand(
+            batchsize, self.PAD_LENGTH
+        )
+        return self(unconditional_tokens)  # Need to return float type
+
+    # def get_unconditional_condition(self, batchsize):
+
+    #     hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device)
+    #     attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device)
+    #     return [hidden_state, attention_mask] # Need to return float type
+
+    def _get_src_mask(self, phoneme):
+        src_mask = phoneme != self.pad_token_id
+        return src_mask
+
+    def _get_src_length(self, phoneme):
+        src_mask = self._get_src_mask(phoneme)
+        length = torch.sum(src_mask, dim=-1)
+        return length
+
+    # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask):
+    #     # src_length: [bs]
+    #     # text_emb: [bs, 192, pad_length]
+    #     # attention_mask: [bs, pad_length]
+    #     mask = src_length[..., None, None] > 1
+    #     text_emb = text_emb * mask
+
+    #     attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0
+    #     return text_emb, attention_mask
+
+    def forward(self, phoneme_idx):
+        if self.device is None:
+            self.device = self.learnable_positional_embedding.device
+            self.pad_token_sequence = self.pad_token_sequence.to(self.device)
+
+        src_length = self._get_src_length(phoneme_idx)
+        text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
+        text_emb = text_emb + self.learnable_positional_embedding
+
+        # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask)
+
+        return [
+            text_emb.permute(0, 2, 1),
+            text_emb_mask.squeeze(1),
+        ]  # [2, 250, 192], [2, 250]
+
+
+class FlanT5HiddenState(nn.Module):
+    """
+    llama = FlanT5HiddenState()
+    data = ["","this is not an empty sentence"]
+    encoder_hidden_states = llama(data)
+    import ipdb;ipdb.set_trace()
+    """
+
+    def __init__(
+        self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True
+    ):
+        super().__init__()
+        self.freeze_text_encoder = freeze_text_encoder
+        self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
+        self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name))
+        if freeze_text_encoder:
+            self.model.eval()
+            for p in self.model.parameters():
+                p.requires_grad = False
+        else:
+            print("=> The text encoder is learnable")
+
+        self.empty_hidden_state_cfg = None
+        self.device = None
+
+    # Required
+    def get_unconditional_condition(self, batchsize):
+        param = next(self.model.parameters())
+        if self.freeze_text_encoder:
+            assert param.requires_grad == False
+
+        # device = param.device
+        if self.empty_hidden_state_cfg is None:
+            self.empty_hidden_state_cfg, _ = self([""])
+
+        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
+        attention_mask = (
+            torch.ones((batchsize, hidden_state.size(1)))
+            .to(hidden_state.device)
+            .float()
+        )
+        return [hidden_state, attention_mask]  # Need to return float type
+
+    def forward(self, batch):
+        param = next(self.model.parameters())
+        if self.freeze_text_encoder:
+            assert param.requires_grad == False
+
+        if self.device is None:
+            self.device = param.device
+
+        # print("Manually change text")
+        # for i in range(len(batch)):
+        #     batch[i] = "dog barking"
+        try:
+            return self.encode_text(batch)
+        except Exception as e:
+            print(e, batch)
+            logging.exception("An error occurred: %s", str(e))
+
+    def encode_text(self, prompt):
+        device = self.model.device
+        batch = self.tokenizer(
+            prompt,
+            max_length=128,  # self.tokenizer.model_max_length
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
+            device
+        )
+        # Get text encoding
+        if self.freeze_text_encoder:
+            with torch.no_grad():
+                encoder_hidden_states = self.model(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )[0]
+        else:
+            encoder_hidden_states = self.model(
+                input_ids=input_ids, attention_mask=attention_mask
+            )[0]
+        return [
+            encoder_hidden_states.detach(),
+            attention_mask.float(),
+        ]  # Attention mask == 1 means usable token
+
+
+class SequenceGenAudioMAECond(Sequence2AudioMAE):
+    def __init__(
+        self,
+        cond_stage_config,
+        base_learning_rate,
+        sequence_gen_length,
+        sequence_input_key,
+        sequence_input_embed_dim,
+        batchsize,
+        always_output_audiomae_gt=False,
+        pretrained_path=None,
+        force_reload_pretrain_avoid_overwrite=False,
+        learnable=True,
+        use_warmup=True,
+        device=None,
+        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: Use AudioMAE GT
+        use_gt_mae_prob=None,
+    ):  # The prob of using AudioMAE GT
+        if use_warmup:
+            use_warmup = False
+
+        super().__init__(
+            base_learning_rate=base_learning_rate,
+            cond_stage_config=cond_stage_config,
+            sequence_gen_length=sequence_gen_length,
+            sequence_input_key=sequence_input_key,
+            use_warmup=use_warmup,
+            sequence_input_embed_dim=sequence_input_embed_dim,
+            batchsize=batchsize,
+        )
+
+        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
+        self.always_output_audiomae_gt = always_output_audiomae_gt
+        self.force_reload_pretrain_avoid_overwrite = (
+            force_reload_pretrain_avoid_overwrite
+        )
+        self.pretrained_path = pretrained_path
+        self.device = device
+        if self.force_reload_pretrain_avoid_overwrite:
+            self.is_reload = False
+        else:
+            self.is_reload = True
+
+        self.load_pretrain_model()
+
+        self.use_gt_mae_output = use_gt_mae_output
+        self.use_gt_mae_prob = use_gt_mae_prob
+        self.learnable = learnable
+
+        if not learnable:
+            # Only optimize the GPT2 model
+            for p in self.model.parameters():
+                p.requires_grad = False
+            self.eval()
+
+    def load_pretrain_model(self):
+        if self.pretrained_path is not None:
+            print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
+            state_dict = torch.load(self.pretrained_path)["state_dict"]
+            self.load_state_dict(state_dict)
+
+    # Required
+    def get_unconditional_condition(self, batchsize):
+        return_dict = self.cfg_uncond(batchsize)
+        return_dict["crossattn_audiomae_generated"] = [
+            return_dict["crossattn_audiomae_pooled"][0],
+            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
+        ]
+        return return_dict
+
+    def forward(self, batch):
+        # The conditional module can return either a tensor or a dictionary.
+        # A returned tensor corresponds directly to the cond_stage_key;
+        # a returned dict has keys that correspond to the cond_stage_key.
+        ret_dict = {}
+
+        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
+            self.load_pretrain_model()
+            self.is_reload = True
+
+        # if(self.always_output_audiomae_gt or (self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob)):
+        #     cond_dict = self.get_input(batch)
+        #     ret_dict["crossattn_audiomae_generated"] = [cond_dict["crossattn_audiomae_pooled"][0], torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float()] # Input sequence and mask
+        # else:
+        input_embeds, cond_dict = self.generate(batch)
+        input_embeds_mask = (
+            torch.ones((input_embeds.size(0), input_embeds.size(1)))
+            .to(input_embeds.device)
+            .float()
+        )
+        ret_dict["crossattn_audiomae_generated"] = [
+            input_embeds,
+            input_embeds_mask,
+        ]  # Input sequence and mask
+
+        # Keys copied from cond_dict are only used as conditions if they appear in cond_stage_key.
+        for key in cond_dict.keys():
+            ret_dict[key] = cond_dict[key]
+
+        return ret_dict
+
+
+class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
+    """
+    audiomae = AudioMAEConditionCTPool2x2()
+    data = torch.randn((4, 1024, 128))
+    output = audiomae(data)
+    import ipdb;ipdb.set_trace()
+    exit(0)
+    """
+
+    def __init__(
+        self,
+        time_pooling_factors=[1, 2, 4, 8],
+        freq_pooling_factors=[1, 2, 4, 8],
+        eval_time_pooling=None,
+        eval_freq_pooling=None,
+        mask_ratio=0.0,
+        regularization=False,
+        no_audiomae_mask=True,
+        no_audiomae_average=False,
+    ):
+        super().__init__()
+        self.device = None
+        self.time_pooling_factors = time_pooling_factors
+        self.freq_pooling_factors = freq_pooling_factors
+        self.no_audiomae_mask = no_audiomae_mask
+        self.no_audiomae_average = no_audiomae_average
+
+        self.eval_freq_pooling = eval_freq_pooling
+        self.eval_time_pooling = eval_time_pooling
+        self.mask_ratio = mask_ratio
+        self.use_reg = regularization
+
+        self.audiomae = Vanilla_AudioMAE()
+        self.audiomae.eval()
+        for p in self.audiomae.parameters():
+            p.requires_grad = False
+
+    # Required
+    def get_unconditional_condition(self, batchsize):
+        param = next(self.audiomae.parameters())
+        assert not param.requires_grad
+        device = param.device
+        # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
+        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+            self.eval_freq_pooling, 8
+        )
+        # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
+        # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
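+        # AudioMAE yields a 64 (time) x 8 (frequency) grid of patch tokens, i.e. 512 tokens in total,
+        # so pooling by (time_pool, freq_pool) leaves 512 / (time_pool * freq_pool) tokens.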
+        token_num = int(512 / (time_pool * freq_pool))
+        return [
+            torch.zeros((batchsize, token_num, 768)).to(device).float(),
+            torch.ones((batchsize, token_num)).to(device).float(),
+        ]
+
+    def pool(self, representation, time_pool=None, freq_pool=None):
+        assert representation.size(-1) == 768
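+        # Drop the CLS token and arrange the remaining 512 patch embeddings into a (64 time x 8 frequency) grid.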
+        representation = representation[:, 1:, :].transpose(1, 2)
+        bs, embedding_dim, token_num = representation.size()
+        representation = representation.reshape(bs, embedding_dim, 64, 8)
+
+        if self.training:
+            if time_pool is None and freq_pool is None:
+                time_pool = min(
+                    64,
+                    self.time_pooling_factors[
+                        np.random.choice(list(range(len(self.time_pooling_factors))))
+                    ],
+                )
+                freq_pool = min(
+                    8,
+                    self.freq_pooling_factors[
+                        np.random.choice(list(range(len(self.freq_pooling_factors))))
+                    ],
+                )
+                # freq_pool = min(8, time_pool) # TODO here I make some modification.
+        else:
+            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+                self.eval_freq_pooling, 8
+            )
+
+        self.avgpooling = nn.AvgPool2d(
+            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+        )
+        self.maxpooling = nn.MaxPool2d(
+            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+        )
+
+        pooled = (
+            self.avgpooling(representation) + self.maxpooling(representation)
+        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
+        pooled = pooled.flatten(2).transpose(1, 2)
+        return pooled  # [bs, token_num, embedding_dim]
+
+    def regularization(self, x):
+        assert x.size(-1) == 768
+        x = F.normalize(x, p=2, dim=-1)
+        return x
+
+    # Required
+    def forward(self, batch, time_pool=None, freq_pool=None):
+        assert batch.size(-2) == 1024 and batch.size(-1) == 128
+
+        if self.device is None:
+            self.device = batch.device
+
+        batch = batch.unsqueeze(1)
+        with torch.no_grad():
+            representation = self.audiomae(
+                batch,
+                mask_ratio=self.mask_ratio,
+                no_mask=self.no_audiomae_mask,
+                no_average=self.no_audiomae_average,
+            )
+            representation = self.pool(representation, time_pool, freq_pool)
+            if self.use_reg:
+                representation = self.regularization(representation)
+            return [
+                representation,
+                torch.ones((representation.size(0), representation.size(1)))
+                .to(representation.device)
+                .float(),
+            ]
+
+
+class AudioMAEConditionCTPoolRand(nn.Module):
+    """
+    audiomae = AudioMAEConditionCTPool2x2()
+    data = torch.randn((4, 1024, 128))
+    output = audiomae(data)
+    import ipdb;ipdb.set_trace()
+    exit(0)
+    """
+
+    def __init__(
+        self,
+        time_pooling_factors=[1, 2, 4, 8],
+        freq_pooling_factors=[1, 2, 4, 8],
+        eval_time_pooling=None,
+        eval_freq_pooling=None,
+        mask_ratio=0.0,
+        regularization=False,
+        no_audiomae_mask=True,
+        no_audiomae_average=False,
+    ):
+        super().__init__()
+        self.device = None
+        self.time_pooling_factors = time_pooling_factors
+        self.freq_pooling_factors = freq_pooling_factors
+        self.no_audiomae_mask = no_audiomae_mask
+        self.no_audiomae_average = no_audiomae_average
+
+        self.eval_freq_pooling = eval_freq_pooling
+        self.eval_time_pooling = eval_time_pooling
+        self.mask_ratio = mask_ratio
+        self.use_reg = regularization
+
+        self.audiomae = Vanilla_AudioMAE()
+        self.audiomae.eval()
+        for p in self.audiomae.parameters():
+            p.requires_grad = False
+
+    # Required
+    def get_unconditional_condition(self, batchsize):
+        param = next(self.audiomae.parameters())
+        assert not param.requires_grad
+        device = param.device
+        # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
+        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+            self.eval_freq_pooling, 8
+        )
+        # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
+        # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
+        token_num = int(512 / (time_pool * freq_pool))
+        return [
+            torch.zeros((batchsize, token_num, 768)).to(device).float(),
+            torch.ones((batchsize, token_num)).to(device).float(),
+        ]
+
+    def pool(self, representation, time_pool=None, freq_pool=None):
+        assert representation.size(-1) == 768
+        representation = representation[:, 1:, :].transpose(1, 2)
+        bs, embedding_dim, token_num = representation.size()
+        representation = representation.reshape(bs, embedding_dim, 64, 8)
+
+        if self.training:
+            if time_pool is None and freq_pool is None:
+                time_pool = min(
+                    64,
+                    self.time_pooling_factors[
+                        np.random.choice(list(range(len(self.time_pooling_factors))))
+                    ],
+                )
+                # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
+                freq_pool = min(8, time_pool)  # NOTE: freq_pool is tied to time_pool here rather than sampled independently.
+        else:
+            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
+                self.eval_freq_pooling, 8
+            )
+
+        self.avgpooling = nn.AvgPool2d(
+            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+        )
+        self.maxpooling = nn.MaxPool2d(
+            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
+        )
+
+        pooled = (
+            self.avgpooling(representation) + self.maxpooling(representation)
+        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
+        pooled = pooled.flatten(2).transpose(1, 2)
+        return pooled  # [bs, token_num, embedding_dim]
+
+    def regularization(self, x):
+        assert x.size(-1) == 768
+        x = F.normalize(x, p=2, dim=-1)
+        return x
+
+    # Required
+    def forward(self, batch, time_pool=None, freq_pool=None):
+        assert batch.size(-2) == 1024 and batch.size(-1) == 128
+
+        if self.device is None:
+            self.device = batch.device
+
+        batch = batch.unsqueeze(1)
+        with torch.no_grad():
+            representation = self.audiomae(
+                batch,
+                mask_ratio=self.mask_ratio,
+                no_mask=self.no_audiomae_mask,
+                no_average=self.no_audiomae_average,
+            )
+            representation = self.pool(representation, time_pool, freq_pool)
+            if self.use_reg:
+                representation = self.regularization(representation)
+            return [
+                representation,
+                torch.ones((representation.size(0), representation.size(1)))
+                .to(representation.device)
+                .float(),
+            ]
+
+
+class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
+    def __init__(
+        self,
+        pretrained_path="",
+        sampling_rate=16000,
+        embed_mode="audio",
+        amodel="HTSAT-base",
+        unconditional_prob=0.1,
+        random_mute=False,
+        max_random_mute_portion=0.5,
+        training_mode=True,
+    ):
+        super().__init__()
+        self.device = "cpu"
+        self.precision = "fp32"
+        self.amodel = amodel  # or 'PANN-14'
+        self.tmodel = "roberta"  # the best text encoder in our training
+        self.enable_fusion = False  # False if you do not want to use the fusion model
+        self.fusion_type = "aff_2d"
+        self.pretrained = pretrained_path
+        self.embed_mode = embed_mode
+        self.embed_mode_orig = embed_mode
+        self.sampling_rate = sampling_rate
+        self.unconditional_prob = unconditional_prob
+        self.random_mute = random_mute
+        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
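+        # Note: self.tokenize is the raw HuggingFace tokenizer; the tokenizer() method below wraps it
+        # and squeezes the batch dimension of its outputs.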
+        self.max_random_mute_portion = max_random_mute_portion
+        self.training_mode = training_mode
+        self.model, self.model_cfg = create_model(
+            self.amodel,
+            self.tmodel,
+            self.pretrained,
+            precision=self.precision,
+            device=self.device,
+            enable_fusion=self.enable_fusion,
+            fusion_type=self.fusion_type,
+        )
+        audio_cfg = self.model_cfg["audio_cfg"]
+        self.mel_transform = torchaudio.transforms.MelSpectrogram(
+            sample_rate=audio_cfg["sample_rate"],
+            n_fft=audio_cfg["window_size"],
+            win_length=audio_cfg["window_size"],
+            hop_length=audio_cfg["hop_size"],
+            center=True,
+            pad_mode="reflect",
+            power=2.0,
+            norm=None,
+            onesided=True,
+            n_mels=64,
+            f_min=audio_cfg["fmin"],
+            f_max=audio_cfg["fmax"],
+        )
+        for p in self.model.parameters():
+            p.requires_grad = False
+        self.unconditional_token = None
+        self.model.eval()
+
+    def get_unconditional_condition(self, batchsize):
+        self.unconditional_token = self.model.get_text_embedding(
+            self.tokenizer(["", ""])
+        )[0:1]
+        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
+
+    def batch_to_list(self, batch):
+        ret = []
+        for i in range(batch.size(0)):
+            ret.append(batch[i])
+        return ret
+
+    def make_decision(self, probability):
+        if float(torch.rand(1)) < probability:
+            return True
+        else:
+            return False
+
+    def random_uniform(self, start, end):
+        val = torch.rand(1).item()
+        return start + (end - start) * val
+
+    def _random_mute(self, waveform):
+        # waveform: [bs, t-steps]
+        t_steps = waveform.size(-1)
+        for i in range(waveform.size(0)):
+            mute_size = int(
+                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
+            )
+            mute_start = int(self.random_uniform(0, t_steps - mute_size))
+            waveform[i, mute_start : mute_start + mute_size] = 0
+        return waveform
+
+    def cos_similarity(self, waveform, text):
+        # waveform: [bs, t_steps]
+        original_embed_mode = self.embed_mode
+        with torch.no_grad():
+            self.embed_mode = "audio"
+            audio_emb = self(waveform.cuda())
+            self.embed_mode = "text"
+            text_emb = self(text)
+            similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
+        self.embed_mode = original_embed_mode
+        return similarity.squeeze()
+
+    def build_unconditional_emb(self):
+        self.unconditional_token = self.model.get_text_embedding(
+            self.tokenizer(["", ""])
+        )[0:1]
+
+    def forward(self, batch):
+        # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
+        # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
+        if self.model.training and not self.training_mode:
+            print(
+                "The pretrained CLAP model should always be in eval mode. Reloading the model in case its parameters were changed."
+            )
+            self.model, self.model_cfg = create_model(
+                self.amodel,
+                self.tmodel,
+                self.pretrained,
+                precision=self.precision,
+                device="cuda",
+                enable_fusion=self.enable_fusion,
+                fusion_type=self.fusion_type,
+            )
+            for p in self.model.parameters():
+                p.requires_grad = False
+            self.model.eval()
+
+        if self.unconditional_token is None:
+            self.build_unconditional_emb()
+
+        # if(self.training_mode):
+        #     assert self.model.training == True
+        # else:
+        #     assert self.model.training == False
+
+        # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
+        if self.embed_mode == "audio":
+            if not self.training:
+                print("INFO: clap model calculate the audio embedding as condition")
+            with torch.no_grad():
+                # assert (
+                #     self.sampling_rate == 16000
+                # ), "We only support 16000 sampling rate"
+
+                # if self.random_mute:
+                #     batch = self._random_mute(batch)
+                # batch: [bs, 1, t-samples]
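+                # CLAP's audio branch operates on 48 kHz audio; the 480000-sample limit below corresponds to a 10-second window.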
+                if self.sampling_rate != 48000:
+                    batch = torchaudio.functional.resample(
+                        batch, orig_freq=self.sampling_rate, new_freq=48000
+                    )
+
+                audio_data = batch.squeeze(1)
+                mel = self.mel_transform(audio_data)
+                audio_dict = get_audio_features(
+                    audio_data,
+                    mel,
+                    480000,
+                    data_truncating="fusion",
+                    data_filling="repeatpad",
+                    audio_cfg=self.model_cfg["audio_cfg"],
+                )
+                # [bs, 512]
+                embed = self.model.get_audio_embedding(audio_dict)
+        elif self.embed_mode == "text":
+            with torch.no_grad():
+                # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
+                text_data = self.tokenizer(batch)
+
+                if isinstance(batch, str) or (
+                    isinstance(batch, list) and len(batch) == 1
+                ):
+                    for key in text_data.keys():
+                        text_data[key] = text_data[key].unsqueeze(0)
+
+                embed = self.model.get_text_embedding(text_data)
+
+        embed = embed.unsqueeze(1)
+        for i in range(embed.size(0)):
+            if self.make_decision(self.unconditional_prob):
+                embed[i] = self.unconditional_token
+        # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch)
+        return embed.detach()
+
+    def tokenizer(self, text):
+        result = self.tokenize(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=512,
+            return_tensors="pt",
+        )
+        return {k: v.squeeze(0) for k, v in result.items()}
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py
new file mode 100755
index 0000000000000000000000000000000000000000..3553a688d41b07a45a7ced25f740a55dbc0b6d94
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/attentions.py
@@ -0,0 +1,430 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=4,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    window_size=window_size,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=True,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+
+        self.drop = nn.Dropout(p_dropout)
+        self.self_attn_layers = nn.ModuleList()
+        self.norm_layers_0 = nn.ModuleList()
+        self.encdec_attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.self_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    proximal_bias=proximal_bias,
+                    proximal_init=proximal_init,
+                )
+            )
+            self.norm_layers_0.append(LayerNorm(hidden_channels))
+            self.encdec_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                    causal=True,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, h, h_mask):
+        """
+        x: decoder input
+        h: encoder output
+        """
+        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+            device=x.device, dtype=x.dtype
+        )
+        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            y = self.self_attn_layers[i](x, x, self_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_0[i](x + y)
+
+            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=None,
+        heads_share=True,
+        block_length=None,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                assert (
+                    t_s == t_t
+                ), "Local attention is only available for self-attention."
+                block_mask = (
+                    torch.ones_like(scores)
+                    .triu(-self.block_length)
+                    .tril(self.block_length)
+                )
+                scores = scores.masked_fill(block_mask == 0, -1e4)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = (
+            output.transpose(2, 3).contiguous().view(b, d, t_t)
+        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        # The relative embedding table spans 2 * window_size + 1 positions.
+        # Pad first before slicing to avoid conditional ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so that the flat tensor adds up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(
+            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+        )
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along the column dimension
+        x = F.pad(
+            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+        )
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+          length: an integer scalar.
+        Returns:
+          a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        filter_channels,
+        kernel_size,
+        p_dropout=0.0,
+        activation=None,
+        causal=False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+        self.causal = causal
+
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(self.padding(x * x_mask))
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(self.padding(x * x_mask))
+        return x * x_mask
+
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py
new file mode 100755
index 0000000000000000000000000000000000000000..9515724c12ab2f856b9a2ec14e38cc63df9b85d6
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/commons.py
@@ -0,0 +1,161 @@
+import math
+import torch
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
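+    # Standard sinusoidal positional encoding, returned with shape [1, channels, length].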
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
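+    # Lower-triangular causal mask of shape [1, 1, length, length].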
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
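+    # WaveNet-style gated activation: the first n_channels pass through tanh, the rest through sigmoid, and the two are multiplied.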
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
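+    # For each input position, mark the output frames up to its cumulative duration, then take the
+    # difference along the input axis so each output frame is assigned to exactly one input position.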
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
diff --git a/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..b39bf583b5ea88a4771181e491c8deb92b2d7559
--- /dev/null
+++ b/audioldm2/latent_diffusion/modules/phoneme_encoder/encoder.py
@@ -0,0 +1,50 @@
+import math
+import torch
+from torch import nn
+
+import audioldm2.latent_diffusion.modules.phoneme_encoder.commons as commons
+import audioldm2.latent_diffusion.modules.phoneme_encoder.attentions as attentions
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        n_vocab,
+        out_channels=192,
+        hidden_channels=192,
+        filter_channels=768,
+        n_heads=2,
+        n_layers=6,
+        kernel_size=3,
+        p_dropout=0.1,
+    ):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.emb = nn.Embedding(n_vocab, hidden_channels)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+        self.encoder = attentions.Encoder(
+            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths):
+        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+            x.dtype
+        )
+
+        x = self.encoder(x * x_mask, x_mask)
+        stats = self.proj(x) * x_mask
+
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return x, m, logs, x_mask
diff --git a/audioldm2/latent_diffusion/util.py b/audioldm2/latent_diffusion/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..3dd301b1c0a39a5b905aa23f4b98d224df7d87d9
--- /dev/null
+++ b/audioldm2/latent_diffusion/util.py
@@ -0,0 +1,217 @@
+import importlib
+
+import torch
+import numpy as np
+from collections import abc
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh a tuple of (width, height)
+    # xc a list of captions to plot
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+        nc = int(40 * (wh[0] / 256))
+        lines = "\n".join(
+            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
+        )
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def int16_to_float32(x):
+    return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+    x = np.clip(x, a_min=-1.0, a_max=1.0)
+    return (x * 32767.0).astype(np.int16)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == "__is_first_stage__":
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+    # Run the prefetch function on this worker's shard of the data, push the result keyed by
+    # the worker index, then push a "Done" sentinel so the consumer knows this worker finished.
+    if idx_to_fn:
+        res = func(data, worker_id=idx)
+    else:
+        res = func(data)
+    Q.put([idx, res])
+    Q.put("Done")
+
+
+def parallel_data_prefetch(
+    func: callable,
+    data,
+    n_proc,
+    target_data_type="ndarray",
+    cpu_intensive=True,
+    use_worker_id=False,
+):
+    # if target_data_type not in ["ndarray", "list"]:
+    #     raise ValueError(
+    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+    #     )
+    if isinstance(data, np.ndarray) and target_data_type == "list":
+        raise ValueError("list expected but function got ndarray.")
+    elif isinstance(data, abc.Iterable):
+        if isinstance(data, dict):
+            print(
+                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+            )
+            data = list(data.values())
+        if target_data_type == "ndarray":
+            data = np.asarray(data)
+        else:
+            data = list(data)
+    else:
+        raise TypeError(
+            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+        )
+
+    if cpu_intensive:
+        Q = mp.Queue(1000)
+        proc = mp.Process
+    else:
+        Q = Queue(1000)
+        proc = Thread
+    # spawn processes
+    if target_data_type == "ndarray":
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(np.array_split(data, n_proc))
+        ]
+    else:
+        step = (
+            int(len(data) / n_proc + 1)
+            if len(data) % n_proc != 0
+            else int(len(data) / n_proc)
+        )
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(
+                [data[i : i + step] for i in range(0, len(data), step)]
+            )
+        ]
+    processes = []
+    for i in range(n_proc):
+        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+        processes += [p]
+
+    # start processes
+    print(f"Start prefetching...")
+    import time
+
+    start = time.time()
+    gather_res = [[] for _ in range(n_proc)]
+    try:
+        for p in processes:
+            p.start()
+
+        k = 0
+        while k < n_proc:
+            # get result
+            res = Q.get()
+            if res == "Done":
+                k += 1
+            else:
+                gather_res[res[0]] = res[1]
+
+    except Exception as e:
+        print("Exception: ", e)
+        for p in processes:
+            p.terminate()
+
+        raise e
+    finally:
+        for p in processes:
+            p.join()
+        print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+    if target_data_type == "ndarray":
+        if not isinstance(gather_res[0], np.ndarray):
+            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+        # order outputs
+        return np.concatenate(gather_res, axis=0)
+    elif target_data_type == "list":
+        out = []
+        for r in gather_res:
+            out.extend(r)
+        return out
+    else:
+        return gather_res
diff --git a/audioldm2/latent_encoder/__init__.py b/audioldm2/latent_encoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm2/latent_encoder/autoencoder.py b/audioldm2/latent_encoder/autoencoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..f07075bb76a34edd8568797961752e4957129f92
--- /dev/null
+++ b/audioldm2/latent_encoder/autoencoder.py
@@ -0,0 +1,326 @@
+import torch
+import os
+
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from audioldm2.latent_diffusion.modules.ema import *
+
+from audioldm2.latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder
+from audioldm2.latent_diffusion.modules.distributions.distributions import (
+    DiagonalGaussianDistribution,
+)
+import soundfile as sf
+
+from audioldm2.utilities.model import get_vocoder
+from audioldm2.utilities.tools import synth_one_sample
+
+
+class AutoencoderKL(nn.Module):
+    def __init__(
+        self,
+        ddconfig=None,
+        lossconfig=None,
+        batchsize=None,
+        embed_dim=None,
+        time_shuffle=1,
+        subband=1,
+        sampling_rate=16000,
+        ckpt_path=None,
+        reload_from_ckpt=None,
+        ignore_keys=[],
+        image_key="fbank",
+        colorize_nlabels=None,
+        monitor=None,
+        base_learning_rate=1e-5,
+    ):
+        super().__init__()
+        self.automatic_optimization = False
+        assert (
+            "mel_bins" in ddconfig.keys()
+        ), "mel_bins is not specified in the Autoencoder config"
+        num_mel = ddconfig["mel_bins"]
+        self.image_key = image_key
+        self.sampling_rate = sampling_rate
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+
+        self.loss = None
+        self.subband = int(subband)
+
+        if self.subband > 1:
+            print("Use subband decomposition %s" % self.subband)
+
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+
+        if self.image_key == "fbank":
+            self.vocoder = get_vocoder(None, "cpu", num_mel)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels) == int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+        self.learning_rate = float(base_learning_rate)
+        # print("Initial learning rate %s" % self.learning_rate)
+
+        self.time_shuffle = time_shuffle
+        self.reload_from_ckpt = reload_from_ckpt
+        self.reloaded = False
+        self.mean, self.std = None, None
+
+        self.feature_cache = None
+        self.flag_first_run = True
+        self.train_step = 0
+
+        self.logger_save_dir = None
+        self.logger_exp_name = None
+
+    def get_log_dir(self):
+        if self.logger_save_dir is None and self.logger_exp_name is None:
+            return os.path.join(self.logger.save_dir, self.logger._project)
+        else:
+            return os.path.join(self.logger_save_dir, self.logger_exp_name)
+
+    def set_log_dir(self, save_dir, exp_name):
+        self.logger_save_dir = save_dir
+        self.logger_exp_name = exp_name
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    def encode(self, x):
+        # x = self.time_shuffle_operation(x)
+        # x = self.freq_split_subband(x)
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
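+        # "moments" packs the mean and log-variance of the latent Gaussian along the channel dimension.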
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        # bs, ch, shuffled_timesteps, fbins = dec.size()
+        # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
+        # dec = self.freq_merge_subband(dec)
+        return dec
+
+    def decode_to_waveform(self, dec):
+        from audioldm2.utilities.model import vocoder_infer
+
+        if self.image_key == "fbank":
+            dec = dec.squeeze(1).permute(0, 2, 1)
+            wav_reconstruction = vocoder_infer(dec, self.vocoder)
+        elif self.image_key == "stft":
+            dec = dec.squeeze(1).permute(0, 2, 1)
+            wav_reconstruction = self.wave_decoder(dec)
+        return wav_reconstruction
+
+    def visualize_latent(self, input):
+        import matplotlib.pyplot as plt
+
+        # for i in range(10):
+        #     zero_input = torch.zeros_like(input) - 11.59
+        #     zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
+
+        #     posterior = self.encode(zero_input)
+        #     latent = posterior.sample()
+        #     avg_latent = torch.mean(latent, dim=1)[0]
+        #     plt.imshow(avg_latent.cpu().detach().numpy().T)
+        #     plt.savefig("%s.png" % i)
+        #     plt.close()
+
+        np.save("input.npy", input.cpu().detach().numpy())
+        # zero_input = torch.zeros_like(input) - 11.59
+        time_input = input.clone()
+        time_input[:, :, :, :32] *= 0
+        time_input[:, :, :, :32] -= 11.59
+
+        np.save("time_input.npy", time_input.cpu().detach().numpy())
+
+        posterior = self.encode(time_input)
+        latent = posterior.sample()
+        np.save("time_latent.npy", latent.cpu().detach().numpy())
+        avg_latent = torch.mean(latent, dim=1)
+        for i in range(avg_latent.size(0)):
+            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
+            plt.savefig("freq_%s.png" % i)
+            plt.close()
+
+        freq_input = input.clone()
+        freq_input[:, :, :512, :] *= 0
+        freq_input[:, :, :512, :] -= 11.59
+
+        np.save("freq_input.npy", freq_input.cpu().detach().numpy())
+
+        posterior = self.encode(freq_input)
+        latent = posterior.sample()
+        np.save("freq_latent.npy", latent.cpu().detach().numpy())
+        avg_latent = torch.mean(latent, dim=1)
+        for i in range(avg_latent.size(0)):
+            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
+            plt.savefig("time_%s.png" % i)
+            plt.close()
+
+    def get_input(self, batch):
+        fname, text, label_indices, waveform, stft, fbank = (
+            batch["fname"],
+            batch["text"],
+            batch["label_vector"],
+            batch["waveform"],
+            batch["stft"],
+            batch["log_mel_spec"],
+        )
+        # if(self.time_shuffle != 1):
+        #     if(fbank.size(1) % self.time_shuffle != 0):
+        #         pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
+        #         fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
+
+        ret = {}
+
+        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
+            fbank.unsqueeze(1),
+            stft.unsqueeze(1),
+            fname,
+            waveform.unsqueeze(1),
+        )
+
+        return ret
+
+    def save_wave(self, batch_wav, fname, save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+
+        for wav, name in zip(batch_wav, fname):
+            name = os.path.basename(name)
+
+            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
+        log = dict()
+        x = batch.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            log["samples"] = self.decode(posterior.sample())
+            log["reconstructions"] = xrec
+
+        log["inputs"] = x
+        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
+        return wavs
+
+    def _log_img(self, log, train=True, index=0, waveform=None):
+        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
+        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
+        images_samples = self.tensor2numpy(log["samples"][index, 0]).T
+
+        if train:
+            name = "train"
+        else:
+            name = "val"
+
+        if getattr(self, "logger", None) is not None:
+            self.logger.log_image(
+                "img_%s" % name,
+                [images_input, images_reconstruct, images_samples],
+                caption=["input", "reconstruct", "samples"],
+            )
+
+        inputs, reconstructions, samples = (
+            log["inputs"],
+            log["reconstructions"],
+            log["samples"],
+        )
+
+        if self.image_key == "fbank":
+            wav_original, wav_prediction = synth_one_sample(
+                inputs[index],
+                reconstructions[index],
+                labels="validation",
+                vocoder=self.vocoder,
+            )
+            wav_original, wav_samples = synth_one_sample(
+                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
+            )
+            wav_original, wav_samples, wav_prediction = (
+                wav_original[0],
+                wav_samples[0],
+                wav_prediction[0],
+            )
+        elif self.image_key == "stft":
+            wav_prediction = (
+                self.decode_to_waveform(reconstructions)[index, 0]
+                .cpu()
+                .detach()
+                .numpy()
+            )
+            wav_samples = (
+                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
+            )
+            wav_original = waveform[index, 0].cpu().detach().numpy()
+
+        if self.logger is not None:
+            self.logger.experiment.log(
+                {
+                    "original_%s"
+                    % name: wandb.Audio(
+                        wav_original, caption="original", sample_rate=self.sampling_rate
+                    ),
+                    "reconstruct_%s"
+                    % name: wandb.Audio(
+                        wav_prediction,
+                        caption="reconstruct",
+                        sample_rate=self.sampling_rate,
+                    ),
+                    "samples_%s"
+                    % name: wandb.Audio(
+                        wav_samples, caption="samples", sample_rate=self.sampling_rate
+                    ),
+                }
+            )
+
+        return wav_original, wav_prediction, wav_samples
+
+    def tensor2numpy(self, tensor):
+        return tensor.cpu().detach().numpy()
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        super().__init__()
+        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
diff --git a/audioldm2/pipeline.py b/audioldm2/pipeline.py
new file mode 100755
index 0000000000000000000000000000000000000000..1eec55b0198049f8baf263c3b80a7a8a0584ebeb
--- /dev/null
+++ b/audioldm2/pipeline.py
@@ -0,0 +1,201 @@
+import os
+
+import yaml
+import torch
+import torchaudio
+
+from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion
+from audioldm2.utils import default_audioldm_config, get_metadata, download_checkpoint
+from audioldm2.utilities.audio import read_wav_file
+
+CACHE_DIR = os.getenv(
+    "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
+)
+
+
+def seed_everything(seed):
+    import random, os
+    import numpy as np
+    import torch
+
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False  # keep False so cuDNN autotuning does not break determinism
+
+
+def text_to_filename(text):
+    return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+
+def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
+    norm_mean = -4.2677393
+    norm_std = 4.5689974
+
+    if sampling_rate != 16000:
+        waveform_16k = torchaudio.functional.resample(
+            waveform, orig_freq=sampling_rate, new_freq=16000
+        )
+    else:
+        waveform_16k = waveform
+
+    waveform_16k = waveform_16k - waveform_16k.mean()
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform_16k,
+        htk_compat=True,
+        sample_frequency=16000,
+        use_energy=False,
+        window_type="hanning",
+        num_mel_bins=128,
+        dither=0.0,
+        frame_shift=10,
+    )
+
+    TARGET_LEN = log_mel_spec.size(0)
+
+    # cut and pad
+    n_frames = fbank.shape[0]
+    p = TARGET_LEN - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[:TARGET_LEN, :]
+
+    fbank = (fbank - norm_mean) / (norm_std * 2)
+
+    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
+
+
+def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
+    if batchsize < 1:
+        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
+        batchsize = 1
+    text = [text] * batchsize
+
+    if fbank is None:
+        fbank = torch.zeros(
+            (batchsize, 1024, 64)
+        )  # Not used, here to keep the code format
+    else:
+        fbank = torch.FloatTensor(fbank)
+        fbank = fbank.expand(batchsize, 1024, 64)
+        assert fbank.size(0) == batchsize
+
+    stft = torch.zeros((batchsize, 1024, 512))  # Not used
+
+    if waveform is None:
+        waveform = torch.zeros((batchsize, 160000))  # Not used
+        ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128))
+    else:
+        waveform = torch.FloatTensor(waveform)
+        waveform = waveform.expand(batchsize, -1)
+        assert waveform.size(0) == batchsize
+        ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank)[
+            "ta_kaldi_fbank"
+        ]
+
+    batch = {
+        "text": text,  # list
+        "fname": [text_to_filename(t) for t in text],  # list
+        "waveform": waveform,
+        "stft": stft,
+        "log_mel_spec": fbank,
+        "ta_kaldi_fbank": ta_kaldi_fbank,
+    }
+
+    return batch
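+
+# Example (illustrative): make_batch_for_text_to_audio("a dog barking", batchsize=1)
+# returns text == ["a dog barking"], fname == ["a_dog_barking"], and placeholder
+# zero tensors of shape (1, 160000) for waveform, (1, 1024, 512) for stft,
+# (1, 1024, 64) for log_mel_spec and (1, 1024, 128) for ta_kaldi_fbank.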
+
+
+def round_up_duration(duration):
+    return int(round(duration / 2.5) + 1) * 2.5
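+
+# For instance, round_up_duration(8.2) == (round(8.2 / 2.5) + 1) * 2.5 == 10.0.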
+
+
+def split_clap_weight_to_pth(checkpoint):
+    if os.path.exists(os.path.join(CACHE_DIR, "clap.pth")):
+        return
+    print("Constructing the weight for the CLAP model.")
+    include_keys = "cond_stage_models.0.cond_stage_models.0.model."
+    new_state_dict = {}
+    for each in checkpoint["state_dict"].keys():
+        if include_keys in each:
+            new_state_dict[each.replace(include_keys, "module.")] = checkpoint[
+                "state_dict"
+            ][each]
+    torch.save({"state_dict": new_state_dict}, os.path.join(CACHE_DIR, "clap.pth"))
+
+
+def build_model(ckpt_path=None, config=None, model_name="audioldm2-full"):
+    print("Loading AudioLDM-2: %s" % model_name)
+
+    if ckpt_path is None:
+        ckpt_path = get_metadata()[model_name]["path"]
+
+    if not os.path.exists(ckpt_path):
+        download_checkpoint(model_name)
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+    else:
+        device = torch.device("cpu")
+
+    if config is not None:
+        assert type(config) is str
+        with open(config, "r") as f:
+            config = yaml.load(f, Loader=yaml.FullLoader)
+    else:
+        config = default_audioldm_config(model_name)
+
+    # # Use text as condition instead of using waveform during training
+    config["model"]["params"]["device"] = device
+    # config["model"]["params"]["cond_stage_key"] = "text"
+
+    # No normalization here
+    latent_diffusion = LatentDiffusion(**config["model"]["params"])
+
+    resume_from_checkpoint = ckpt_path
+
+    checkpoint = torch.load(resume_from_checkpoint, map_location=device)
+
+    latent_diffusion.load_state_dict(checkpoint["state_dict"])
+    
+    latent_diffusion.eval()
+    latent_diffusion = latent_diffusion.to(device)
+    
+    return latent_diffusion
+
+def duration_to_latent_t_size(duration):
+    return int(duration * 25.6)
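+
+# For reference, duration_to_latent_t_size(10) == int(10 * 25.6) == 256 latent time frames.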
+
+def text_to_audio(
+    latent_diffusion,
+    text,
+    seed=42,
+    ddim_steps=200,
+    duration=10,
+    batchsize=1,
+    guidance_scale=3.5,
+    n_candidate_gen_per_text=3,
+    config=None,
+):
+    assert (
+        duration == 10
+    ), "Error: Currently we only support 10 seconds of generation. Generating longer files requires some extra coding, which would be a part of the future work."
+
+    seed_everything(int(seed))
+    waveform = None
+
+    batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)
+
+    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
+
+    with torch.no_grad():
+        waveform = latent_diffusion.generate_batch(
+            batch,
+            unconditional_guidance_scale=guidance_scale,
+            ddim_steps=ddim_steps,
+            n_gen=n_candidate_gen_per_text,
+            duration=duration,
+        )
+
+    return waveform
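+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): build the default model and
+    # synthesize one 10-second clip from a placeholder prompt. The exact array
+    # layout of the returned audio is defined by LatentDiffusion.generate_batch,
+    # so saving it to disk (e.g. with soundfile at 16 kHz) is left to the caller.
+    model = build_model(model_name="audioldm2-full")
+    audio = text_to_audio(
+        model,
+        "A dog barking in the distance",
+        seed=42,
+        ddim_steps=200,
+        duration=10,
+        guidance_scale=3.5,
+        n_candidate_gen_per_text=3,
+    )
+    print(type(audio))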
diff --git a/audioldm2/utilities/__init__.py b/audioldm2/utilities/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..495e8fe675337df0afacd3a31d06d0241b6b0e63
--- /dev/null
+++ b/audioldm2/utilities/__init__.py
@@ -0,0 +1,3 @@
+from .tools import *
+from .data import *
+from .model import *
diff --git a/audioldm2/utilities/audio/__init__.py b/audioldm2/utilities/audio/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..c39f9243d2d7b4fc5dea18f56b153b0f5c5bbd4c
--- /dev/null
+++ b/audioldm2/utilities/audio/__init__.py
@@ -0,0 +1,3 @@
+from .audio_processing import *
+from .stft import *
+from .tools import *
diff --git a/audioldm2/utilities/audio/audio_processing.py b/audioldm2/utilities/audio/audio_processing.py
new file mode 100755
index 0000000000000000000000000000000000000000..77a4057aa82f226f68474f4c2a19eba84510d663
--- /dev/null
+++ b/audioldm2/utilities/audio/audio_processing.py
@@ -0,0 +1,100 @@
+import torch
+import numpy as np
+import librosa.util as librosa_util
+from scipy.signal import get_window
+
+
+def window_sumsquare(
+    window,
+    n_frames,
+    hop_length,
+    win_length,
+    n_fft,
+    dtype=np.float32,
+    norm=None,
+):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function.  By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+    return x
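+
+# Illustration (assuming typical settings): with window="hann", win_length=n_fft=1024,
+# hop_length=160 and n_frames=101, the returned envelope has length
+# 1024 + 160 * 100 == 17024. STFT.inverse in stft.py divides the overlap-added
+# signal by this envelope to undo the analysis windowing.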
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+    """
+    PARAMS
+    ------
+    magnitudes: spectrogram magnitudes
+    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+    """
+
+    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+    angles = angles.astype(np.float32)
+    angles = torch.autograd.Variable(torch.from_numpy(angles))
+    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+    for i in range(n_iters):
+        _, angles = stft_fn.transform(signal)
+        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+    return signal
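+
+# Usage sketch (illustrative): given a magnitude spectrogram from STFT.transform,
+#
+#     magnitudes, _ = stft_fn.transform(wav)
+#     wav_hat = griffin_lim(magnitudes, stft_fn, n_iters=60)
+#
+# starts from random phase and alternates inverse/forward STFTs to re-estimate it.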
+
+
+def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return normalize_fun(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
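+
+
+# Note that, for C == 1, dynamic_range_decompression(dynamic_range_compression(x))
+# equals torch.clamp(x, min=1e-5), so the pair is lossless above the clip value.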
diff --git a/audioldm2/utilities/audio/stft.py b/audioldm2/utilities/audio/stft.py
new file mode 100755
index 0000000000000000000000000000000000000000..508f33674e6dd8a5557205c8e77e07955df13a87
--- /dev/null
+++ b/audioldm2/utilities/audio/stft.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from librosa.filters import mel as librosa_mel_fn
+
+from audioldm2.utilities.audio.audio_processing import (
+    dynamic_range_compression,
+    dynamic_range_decompression,
+    window_sumsquare,
+)
+
+
+class STFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+    def __init__(self, filter_length, hop_length, win_length, window="hann"):
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.forward_transform = None
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack(
+            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+        )
+
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+        )
+
+        if window is not None:
+            assert filter_length >= win_length
+            # get window and zero center pad it to filter_length
+            fft_window = get_window(window, win_length, fftbins=True)
+            fft_window = pad_center(fft_window, size=filter_length)
+            fft_window = torch.from_numpy(fft_window).float()
+
+            # window the bases
+            forward_basis *= fft_window
+            inverse_basis *= fft_window
+
+        self.register_buffer("forward_basis", forward_basis.float())
+        self.register_buffer("inverse_basis", inverse_basis.float())
+
+    def transform(self, input_data):
+        num_batches = input_data.size(0)
+        num_samples = input_data.size(1)
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+            mode="reflect",
+        )
+        input_data = input_data.squeeze(1)
+
+        forward_transform = F.conv1d(
+            input_data,
+            torch.autograd.Variable(self.forward_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        ).cpu()
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part**2 + imag_part**2)
+        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+        return magnitude, phase
+
+    def inverse(self, magnitude, phase):
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+        )
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window,
+                magnitude.size(-1),
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                n_fft=self.filter_length,
+                dtype=np.float32,
+            )
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0]
+            )
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False
+            )
+            window_sum = window_sum.to(inverse_transform.device)
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+                approx_nonzero_indices
+            ]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2)]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
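+
+# Round-trip sketch (illustrative parameter values):
+#
+#     stft = STFT(filter_length=1024, hop_length=160, win_length=1024)
+#     audio = torch.rand(1, 160000) * 2 - 1
+#     magnitude, phase = stft.transform(audio)         # each of shape (1, 513, 1001)
+#     reconstruction = stft.inverse(magnitude, phase)  # approximately recovers the input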
+
+
+class TacotronSTFT(torch.nn.Module):
+    def __init__(
+        self,
+        filter_length,
+        hop_length,
+        win_length,
+        n_mel_channels,
+        sampling_rate,
+        mel_fmin,
+        mel_fmax,
+    ):
+        super(TacotronSTFT, self).__init__()
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.stft_fn = STFT(filter_length, hop_length, win_length)
+        mel_basis = librosa_mel_fn(
+            sr=sampling_rate,
+            n_fft=filter_length,
+            n_mels=n_mel_channels,
+            fmin=mel_fmin,
+            fmax=mel_fmax,
+        )
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer("mel_basis", mel_basis)
+
+    def spectral_normalize(self, magnitudes, normalize_fun):
+        output = dynamic_range_compression(magnitudes, normalize_fun)
+        return output
+
+    def spectral_de_normalize(self, magnitudes):
+        output = dynamic_range_decompression(magnitudes)
+        return output
+
+    def mel_spectrogram(self, y, normalize_fun=torch.log):
+        """Computes mel-spectrograms from a batch of waves
+        PARAMS
+        ------
+        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+        RETURNS
+        -------
+        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+        """
+        assert torch.min(y.data) >= -1, torch.min(y.data)
+        assert torch.max(y.data) <= 1, torch.max(y.data)
+
+        magnitudes, phases = self.stft_fn.transform(y)
+        magnitudes = magnitudes.data
+        mel_output = torch.matmul(self.mel_basis, magnitudes)
+        mel_output = self.spectral_normalize(mel_output, normalize_fun)
+        energy = torch.norm(magnitudes, dim=1)
+
+        return mel_output, magnitudes, phases, energy
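+
+
+if __name__ == "__main__":
+    # Minimal sketch of the mel pipeline above, with illustrative settings
+    # (filter_length=1024, hop_length=160, win_length=1024, 64 mel bins, 16 kHz,
+    # fmin=0, fmax=8000); the real values come from the preprocessing config.
+    taco_stft = TacotronSTFT(1024, 160, 1024, 64, 16000, 0, 8000)
+    dummy_audio = torch.rand(1, 160000) * 2 - 1  # 10 s of noise in [-1, 1)
+    mel, magnitudes, phases, energy = taco_stft.mel_spectrogram(dummy_audio)
+    print(mel.shape)  # torch.Size([1, 64, 1001])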
diff --git a/audioldm2/utilities/audio/tools.py b/audioldm2/utilities/audio/tools.py
new file mode 100755
index 0000000000000000000000000000000000000000..8c666a7c67e0ae93edbad666520fd2e98fd29d18
--- /dev/null
+++ b/audioldm2/utilities/audio/tools.py
@@ -0,0 +1,69 @@
+import torch
+import numpy as np
+from scipy.io.wavfile import write
+import torchaudio
+
+from audioldm2.utilities.audio.audio_processing import griffin_lim
+
+
+def pad_wav(waveform, segment_length):
+    waveform_length = waveform.shape[-1]
+    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+    if segment_length is None or waveform_length == segment_length:
+        return waveform
+    elif waveform_length > segment_length:
+        return waveform[:, :segment_length]  # truncate along the sample axis, not the channel axis
+    elif waveform_length < segment_length:
+        temp_wav = np.zeros((1, segment_length))
+        temp_wav[:, :waveform_length] = waveform
+    return temp_wav
+
+
+def normalize_wav(waveform):
+    waveform = waveform - np.mean(waveform)
+    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+    return waveform * 0.5
+
+
+def read_wav_file(filename, segment_length):
+    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+    waveform, sr = torchaudio.load(filename)  # Faster!!!
+    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+    waveform = waveform.numpy()[0, ...]
+    waveform = normalize_wav(waveform)
+    waveform = waveform[None, ...]
+    waveform = pad_wav(waveform, segment_length)
+
+    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+    waveform = 0.5 * waveform
+
+    return waveform
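+
+# Example (illustrative): read_wav_file("some_clip.wav", segment_length=160000)
+# resamples the file to 16 kHz, normalizes it to a peak amplitude of 0.5 and
+# returns a numpy array of shape (1, 160000), i.e. exactly 10 seconds of audio.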
+
+
+def get_mel_from_wav(audio, _stft):
+    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
+    audio = torch.autograd.Variable(audio, requires_grad=False)
+    melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio)
+    melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
+    magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32)
+    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
+    return melspec, magnitudes, energy
+
+
+def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
+    mel = torch.stack([mel])
+    mel_decompress = _stft.spectral_de_normalize(mel)
+    mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
+    spec_from_mel_scaling = 1000
+    spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
+    spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
+    spec_from_mel = spec_from_mel * spec_from_mel_scaling
+
+    audio = griffin_lim(
+        torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters
+    )
+
+    audio = audio.squeeze()
+    audio = audio.cpu().numpy()
+    audio_path = out_filename
+    write(audio_path, _stft.sampling_rate, audio)
diff --git a/audioldm2/utilities/data/__init__.py b/audioldm2/utilities/data/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..13a9804e72b88e3b9078940aee87db73788c1fb5
--- /dev/null
+++ b/audioldm2/utilities/data/__init__.py
@@ -0,0 +1 @@
+from .dataset import AudioDataset as Dataset
diff --git a/audioldm2/utilities/data/add_on.py b/audioldm2/utilities/data/add_on.py
new file mode 100755
index 0000000000000000000000000000000000000000..4cfc6297e2f66759077c1540fc04b19560f3659c
--- /dev/null
+++ b/audioldm2/utilities/data/add_on.py
@@ -0,0 +1,508 @@
+import os
+import torch
+import numpy as np
+import torchaudio
+import matplotlib.pyplot as plt
+
+CACHE = {
+    "get_vits_phoneme_ids": {
+        "PAD_LENGTH": 310,
+        "_pad": "_",
+        "_punctuation": ';:,.!?¡¿—…"«»“” ',
+        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
+        "_special": "♪☎☒☝⚠",
+    }
+}
+
+CACHE["get_vits_phoneme_ids"]["symbols"] = (
+    [CACHE["get_vits_phoneme_ids"]["_pad"]]
+    + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
+    + list(CACHE["get_vits_phoneme_ids"]["_letters"])
+    + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
+    + list(CACHE["get_vits_phoneme_ids"]["_special"])
+)
+CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
+    s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
+}
+
+
+def get_vits_phoneme_ids(config, dl_output, metadata):
+    pad_token_id = 0
+    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
+    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
+
+    assert (
+        "phonemes" in metadata.keys()
+    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
+    clean_text = metadata["phonemes"]
+    sequence = []
+
+    for symbol in clean_text:
+        symbol_id = _symbol_to_id[symbol]
+        sequence += [symbol_id]
+
+    inserted_zero_sequence = [0] * (len(sequence) * 2)
+    inserted_zero_sequence[1::2] = sequence
+    inserted_zero_sequence = inserted_zero_sequence + [0]
+
+    def _pad_phonemes(phonemes_list):
+        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))
+
+    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(inserted_zero_sequence))}
+
+
+def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
+    pad_token_id = 0
+    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
+    _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
+
+    assert (
+        "phonemes" in metadata.keys()
+    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
+    clean_text = metadata["phonemes"] + "⚠"
+    sequence = []
+
+    for symbol in clean_text:
+        if symbol not in _symbol_to_id.keys():
+            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
+            symbol = "_"
+        symbol_id = _symbol_to_id[symbol]
+        sequence += [symbol_id]
+
+    def _pad_phonemes(phonemes_list):
+        return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))
+
+    sequence = sequence[:pad_length]
+
+    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))}
+
+
+def calculate_relative_bandwidth(config, dl_output, metadata):
+    assert "stft" in dl_output.keys()
+
+    # The last dimension of the stft feature is the frequency dimension
+    freq_dimensions = dl_output["stft"].size(-1)
+
+    freq_energy_dist = torch.sum(dl_output["stft"], dim=0)
+    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
+    total_energy = freq_energy_dist[-1]
+
+    percentile_5th = total_energy * 0.05
+    percentile_95th = total_energy * 0.95
+
+    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
+    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))
+
+    lower_idx = int((lower_idx / freq_dimensions) * 1000)
+    higher_idx = int((higher_idx / freq_dimensions) * 1000)
+
+    return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}
+
+
+def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
+    assert "stft" in dl_output.keys()
+    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))
+
+    # The last dimension of the stft feature is the frequency dimension
+    freq_dimensions = linear_mel_spec.size(-1)
+    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
+    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
+    total_energy = freq_energy_dist[-1]
+
+    percentile_5th = total_energy * 0.05
+    percentile_95th = total_energy * 0.95
+
+    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
+    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))
+
+    latent_t_size = config["model"]["params"]["latent_t_size"]
+    latent_f_size = config["model"]["params"]["latent_f_size"]
+
+    lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions)))
+    higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions)))
+
+    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
+    bandwidth_condition[:, lower_idx:higher_idx] += 1.0
+
+    return {
+        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
+        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
+    }
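+
+# Illustration (assuming latent_f_size == 16 and 64 mel bins): if 90% of the mel
+# energy lies between mel bins 8 and 48, then lower_idx == 2 and higher_idx == 12,
+# and the extra channel is 1.0 in latent-frequency columns 2..11 for every time step.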
+
+
+def waveform_rs_48k(config, dl_output, metadata):
+    waveform = dl_output["waveform"]  # [1, samples]
+    sampling_rate = dl_output["sampling_rate"]
+
+    if sampling_rate != 48000:
+        waveform_48k = torchaudio.functional.resample(
+            waveform, orig_freq=sampling_rate, new_freq=48000
+        )
+    else:
+        waveform_48k = waveform
+
+    return {"waveform_48k": waveform_48k}
+
+
+def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
+    assert (
+        "phoneme" not in metadata.keys()
+    ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json"
+
+    if "phonemes" in metadata.keys():
+        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
+        new_item["text"] = ""  # We assume TTS data does not have text description
+    else:
+        fake_metadata = {"phonemes": ""}  # Add empty phoneme sequence
+        new_item = get_vits_phoneme_ids_no_padding(config, dl_output, fake_metadata)
+
+    return new_item
+
+
+def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
+    if "phoneme" in metadata.keys():
+        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
+        new_item["text"] = ""
+    else:
+        fake_metadata = {"phoneme": []}
+        new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, fake_metadata)
+    return new_item
+
+
+def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
+    PAD_LENGTH = 135
+
+    phonemes_lookup_dict = {
+        "K": 0,
+        "IH2": 1,
+        "NG": 2,
+        "OW2": 3,
+        "AH2": 4,
+        "F": 5,
+        "AE0": 6,
+        "IY0": 7,
+        "SH": 8,
+        "G": 9,
+        "W": 10,
+        "UW1": 11,
+        "AO2": 12,
+        "AW2": 13,
+        "UW0": 14,
+        "EY2": 15,
+        "UW2": 16,
+        "AE2": 17,
+        "IH0": 18,
+        "P": 19,
+        "D": 20,
+        "ER1": 21,
+        "AA1": 22,
+        "EH0": 23,
+        "UH1": 24,
+        "N": 25,
+        "V": 26,
+        "AY1": 27,
+        "EY1": 28,
+        "UH2": 29,
+        "EH1": 30,
+        "L": 31,
+        "AA2": 32,
+        "R": 33,
+        "OY1": 34,
+        "Y": 35,
+        "ER2": 36,
+        "S": 37,
+        "AE1": 38,
+        "AH1": 39,
+        "JH": 40,
+        "ER0": 41,
+        "EH2": 42,
+        "IY2": 43,
+        "OY2": 44,
+        "AW1": 45,
+        "IH1": 46,
+        "IY1": 47,
+        "OW0": 48,
+        "AO0": 49,
+        "AY0": 50,
+        "EY0": 51,
+        "AY2": 52,
+        "UH0": 53,
+        "M": 54,
+        "TH": 55,
+        "T": 56,
+        "OY0": 57,
+        "AW0": 58,
+        "DH": 59,
+        "Z": 60,
+        "spn": 61,
+        "AH0": 62,
+        "sp": 63,
+        "AO1": 64,
+        "OW1": 65,
+        "ZH": 66,
+        "B": 67,
+        "AA0": 68,
+        "CH": 69,
+        "HH": 70,
+    }
+    pad_token_id = len(phonemes_lookup_dict.keys())
+
+    assert (
+        "phoneme" in metadata.keys()
+    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"
+
+    phonemes = [
+        phonemes_lookup_dict[x]
+        for x in metadata["phoneme"]
+        if (x in phonemes_lookup_dict.keys())
+    ]
+
+    if (len(phonemes) / PAD_LENGTH) > 5:
+        print(
+            "Warning: Phonemes length is too long and is truncated too much! %s"
+            % metadata
+        )
+
+    phonemes = phonemes[:PAD_LENGTH]
+
+    def _pad_phonemes(phonemes_list):
+        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))
+
+    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}
+
+
+def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
+    PAD_LENGTH = 250
+
+    phonemes_lookup_dict = {
+        " ": 0,
+        "AA": 1,
+        "AE": 2,
+        "AH": 3,
+        "AO": 4,
+        "AW": 5,
+        "AY": 6,
+        "B": 7,
+        "CH": 8,
+        "D": 9,
+        "DH": 10,
+        "EH": 11,
+        "ER": 12,
+        "EY": 13,
+        "F": 14,
+        "G": 15,
+        "HH": 16,
+        "IH": 17,
+        "IY": 18,
+        "JH": 19,
+        "K": 20,
+        "L": 21,
+        "M": 22,
+        "N": 23,
+        "NG": 24,
+        "OW": 25,
+        "OY": 26,
+        "P": 27,
+        "R": 28,
+        "S": 29,
+        "SH": 30,
+        "T": 31,
+        "TH": 32,
+        "UH": 33,
+        "UW": 34,
+        "V": 35,
+        "W": 36,
+        "Y": 37,
+        "Z": 38,
+        "ZH": 39,
+    }
+    pad_token_id = len(phonemes_lookup_dict.keys())
+
+    assert (
+        "phoneme" in metadata.keys()
+    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"
+    phonemes = [
+        phonemes_lookup_dict[x]
+        for x in metadata["phoneme"]
+        if (x in phonemes_lookup_dict.keys())
+    ]
+
+    if (len(phonemes) / PAD_LENGTH) > 5:
+        print(
+            "Warning: Phonemes length is too long and is truncated too much! %s"
+            % metadata
+        )
+
+    phonemes = phonemes[:PAD_LENGTH]
+
+    def _pad_phonemes(phonemes_list):
+        return phonemes_list + [pad_token_id] * (PAD_LENGTH - len(phonemes_list))
+
+    return {"phoneme_idx": torch.LongTensor(_pad_phonemes(phonemes))}
+
+
+def extract_kaldi_fbank_feature(config, dl_output, metadata):
+    norm_mean = -4.2677393
+    norm_std = 4.5689974
+
+    waveform = dl_output["waveform"]  # [1, samples]
+    sampling_rate = dl_output["sampling_rate"]
+    log_mel_spec_hifigan = dl_output["log_mel_spec"]
+
+    if sampling_rate != 16000:
+        waveform_16k = torchaudio.functional.resample(
+            waveform, orig_freq=sampling_rate, new_freq=16000
+        )
+    else:
+        waveform_16k = waveform
+
+    waveform_16k = waveform_16k - waveform_16k.mean()
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform_16k,
+        htk_compat=True,
+        sample_frequency=16000,
+        use_energy=False,
+        window_type="hanning",
+        num_mel_bins=128,
+        dither=0.0,
+        frame_shift=10,
+    )
+
+    TARGET_LEN = log_mel_spec_hifigan.size(0)
+
+    # cut and pad
+    n_frames = fbank.shape[0]
+    p = TARGET_LEN - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[:TARGET_LEN, :]
+
+    fbank = (fbank - norm_mean) / (norm_std * 2)
+
+    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
+
+
+def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
+    norm_mean = -4.2677393
+    norm_std = 4.5689974
+
+    waveform = dl_output["waveform"]  # [1, samples]
+    sampling_rate = dl_output["sampling_rate"]
+    log_mel_spec_hifigan = dl_output["log_mel_spec"]
+
+    if sampling_rate != 32000:
+        waveform_32k = torchaudio.functional.resample(
+            waveform, orig_freq=sampling_rate, new_freq=32000
+        )
+    else:
+        waveform_32k = waveform
+
+    waveform_32k = waveform_32k - waveform_32k.mean()
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform_32k,
+        htk_compat=True,
+        sample_frequency=32000,
+        use_energy=False,
+        window_type="hanning",
+        num_mel_bins=128,
+        dither=0.0,
+        frame_shift=10,
+    )
+
+    TARGET_LEN = log_mel_spec_hifigan.size(0)
+
+    # cut and pad
+    n_frames = fbank.shape[0]
+    p = TARGET_LEN - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[:TARGET_LEN, :]
+
+    fbank = (fbank - norm_mean) / (norm_std * 2)
+
+    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
+
+
+# Use the beat and downbeat information as music conditions
+def extract_drum_beat(config, dl_output, metadata):
+    def visualization(conditional_signal, mel_spectrogram, filename):
+        import soundfile as sf
+
+        sf.write(
+            os.path.basename(dl_output["fname"]),
+            np.array(dl_output["waveform"])[0],
+            dl_output["sampling_rate"],
+        )
+        plt.figure(figsize=(10, 10))
+
+        plt.subplot(211)
+        plt.imshow(np.array(conditional_signal).T, aspect="auto")
+        plt.title("Conditional Signal")
+
+        plt.subplot(212)
+        plt.imshow(np.array(mel_spectrogram).T, aspect="auto")
+        plt.title("Mel Spectrogram")
+
+        plt.savefig(filename)
+        plt.close()
+
+    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata
+
+    sampling_rate = metadata["sample_rate"]
+    duration = dl_output["duration"]
+    # The dataloader segment length before performing torch resampling
+    original_segment_length_before_resample = int(sampling_rate * duration)
+
+    random_start_sample = int(dl_output["random_start_sample_in_original_audio_file"])
+
+    # The sample idx for beat and downbeat, relatively to the segmented audio
+    beat = [
+        x - random_start_sample
+        for x in metadata["beat"]
+        if (
+            x - random_start_sample >= 0
+            and x - random_start_sample <= original_segment_length_before_resample
+        )
+    ]
+    downbeat = [
+        x - random_start_sample
+        for x in metadata["downbeat"]
+        if (
+            x - random_start_sample >= 0
+            and x - random_start_sample <= original_segment_length_before_resample
+        )
+    ]
+
+    latent_shape = (
+        config["model"]["params"]["latent_t_size"],
+        config["model"]["params"]["latent_f_size"],
+    )
+    conditional_signal = torch.zeros(latent_shape)
+
+    # beat: -0.5
+    # downbeat: +1.0
+    # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat
+    for each in beat:
+        beat_index = int(
+            (each / original_segment_length_before_resample) * latent_shape[0]
+        )
+        beat_index = min(beat_index, conditional_signal.size(0) - 1)
+
+        conditional_signal[beat_index, :] -= 0.5
+
+    for each in downbeat:
+        beat_index = int(
+            (each / original_segment_length_before_resample) * latent_shape[0]
+        )
+        beat_index = min(beat_index, conditional_signal.size(0) - 1)
+
+        conditional_signal[beat_index, :] += 1.0
+
+    # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png")
+
+    return {"cond_beat_downbeat": conditional_signal}
diff --git a/audioldm2/utilities/data/dataset.py b/audioldm2/utilities/data/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..f0bfbb7388ca6473beb4574ac4e29dcf0b7c0571
--- /dev/null
+++ b/audioldm2/utilities/data/dataset.py
@@ -0,0 +1,518 @@
+import os
+import pandas as pd
+
+import audioldm2.utilities.audio as Audio
+from audioldm2.utilities.tools import load_json
+
+import random
+from torch.utils.data import Dataset
+import torch.nn.functional
+import torch
+import numpy as np
+import torchaudio
+
+
+class AudioDataset(Dataset):
+    def __init__(
+        self,
+        config=None,
+        split="train",
+        waveform_only=False,
+        add_ons=[],
+        dataset_json_path=None,
+    ):
+        """
+        Dataset that manages audio recordings
+        :param audio_conf: Dictionary containing the audio loading and preprocessing settings
+        :param dataset_json_file
+        """
+        self.config = config
+        self.split = split
+        self.pad_wav_start_sample = 0  # If none, random choose
+        self.trim_wav = False
+        self.waveform_only = waveform_only
+        self.add_ons = [eval(x) for x in add_ons]
+        print("Add-ons:", self.add_ons)
+
+        self.build_setting_parameters()
+
+        # For an external dataset
+        if dataset_json_path is not None:
+            assert type(dataset_json_path) == str
+            print("Load metadata from %s" % dataset_json_path)
+            self.data = load_json(dataset_json_path)["data"]
+            self.id2label, self.index_dict, self.num2label = {}, {}, {}
+        else:
+            self.metadata_root = load_json(self.config["metadata_root"])
+            assert split in self.config["data"].keys(), (
+                "The dataset split %s you specified is not present in the config. You can choose from %s"
+                % (split, self.config["data"].keys())
+            )
+            self.dataset_name = self.config["data"][self.split]
+            self.build_dataset()
+            self.build_id_to_label()
+
+        self.build_dsp()
+        self.label_num = len(self.index_dict)
+        print("Dataset initialize finished")
+
+    def __getitem__(self, index):
+        (
+            fname,
+            waveform,
+            stft,
+            log_mel_spec,
+            label_vector,  # the one-hot representation of the audio class
+            # the metadata of the sampled audio file and the mixup audio file (if exist)
+            (datum, mix_datum),
+            random_start,
+        ) = self.feature_extraction(index)
+        text = self.get_sample_text_caption(datum, mix_datum, label_vector)
+
+        data = {
+            "text": text,  # list
+            "fname": self.text_to_filename(text)
+            if (len(fname) == 0)
+            else fname,  # list
+            # tensor, [batchsize, class_num]
+            "label_vector": "" if (label_vector is None) else label_vector.float(),
+            # tensor, [batchsize, 1, samples_num]
+            "waveform": "" if (waveform is None) else waveform.float(),
+            # tensor, [batchsize, t-steps, f-bins]
+            "stft": "" if (stft is None) else stft.float(),
+            # tensor, [batchsize, t-steps, mel-bins]
+            "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
+            "duration": self.duration,
+            "sampling_rate": self.sampling_rate,
+            "random_start_sample_in_original_audio_file": random_start,
+        }
+
+        for add_on in self.add_ons:
+            data.update(add_on(self.config, data, self.data[index]))
+
+        if data["text"] is None:
+            print("Warning: The model return None on key text", fname)
+            data["text"] = ""
+
+        return data
+
+    def text_to_filename(self, text):
+        return text.replace(" ", "_").replace("'", "_").replace('"', "_")
+
+    def get_dataset_root_path(self, dataset):
+        assert dataset in self.metadata_root.keys()
+        return self.metadata_root[dataset]
+
+    def get_dataset_metadata_path(self, dataset, key):
+        # key: train, test, val, class_label_indices
+        try:
+            if dataset in self.metadata_root["metadata"]["path"].keys():
+                return self.metadata_root["metadata"]["path"][dataset][key]
+        except Exception:
+            raise ValueError(
+                'Dataset %s does not have metadata "%s" specified' % (dataset, key)
+            )
+            # return None
+
+    def __len__(self):
+        return len(self.data)
+
+    def feature_extraction(self, index):
+        if index > len(self.data) - 1:
+            print(
+                "The index of the dataloader is out of range: %s/%s"
+                % (index, len(self.data))
+            )
+            index = random.randint(0, len(self.data) - 1)
+
+        # Read wave file and extract feature
+        while True:
+            try:
+                label_indices = np.zeros(self.label_num, dtype=np.float32)
+                datum = self.data[index]
+                (
+                    log_mel_spec,
+                    stft,
+                    mix_lambda,
+                    waveform,
+                    random_start,
+                ) = self.read_audio_file(datum["wav"])
+                mix_datum = None
+                if self.label_num > 0 and "labels" in datum.keys():
+                    for label_str in datum["labels"].split(","):
+                        label_indices[int(self.index_dict[label_str])] = 1.0
+
+                # If the key "label" is not in the metadata, return all zero vector
+                label_indices = torch.FloatTensor(label_indices)
+                break
+            except Exception as e:
+                index = (index + 1) % len(self.data)
+                print(
+                    "Error encounter during audio feature extraction: ", e, datum["wav"]
+                )
+                continue
+
+        # The filename of the wav file
+        fname = datum["wav"]
+        # t_step = log_mel_spec.size(0)
+        # waveform = torch.FloatTensor(waveform[..., : int(self.hopsize * t_step)])
+        waveform = torch.FloatTensor(waveform)
+
+        return (
+            fname,
+            waveform,
+            stft,
+            log_mel_spec,
+            label_indices,
+            (datum, mix_datum),
+            random_start,
+        )
+
+    # def augmentation(self, log_mel_spec):
+    #     assert torch.min(log_mel_spec) < 0
+    #     log_mel_spec = log_mel_spec.exp()
+
+    #     log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
+    #     # this is just to satisfy new torchaudio version.
+    #     log_mel_spec = log_mel_spec.unsqueeze(0)
+    #     if self.freqm != 0:
+    #         log_mel_spec = self.frequency_masking(log_mel_spec, self.freqm)
+    #     if self.timem != 0:
+    #         log_mel_spec = self.time_masking(
+    #             log_mel_spec, self.timem)  # self.timem=0
+
+    #     log_mel_spec = (log_mel_spec + 1e-7).log()
+    #     # squeeze back
+    #     log_mel_spec = log_mel_spec.squeeze(0)
+    #     log_mel_spec = torch.transpose(log_mel_spec, 0, 1)
+    #     return log_mel_spec
+
+    def build_setting_parameters(self):
+        # Read from the json config
+        self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
+        # self.freqm = self.config["preprocessing"]["mel"]["freqm"]
+        # self.timem = self.config["preprocessing"]["mel"]["timem"]
+        self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
+        self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
+        self.duration = self.config["preprocessing"]["audio"]["duration"]
+        self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
+
+        self.mixup = self.config["augmentation"]["mixup"]
+
+        # Calculate parameter derivations
+        # self.waveform_sample_length = int(self.target_length * self.hopsize)
+
+        # if (self.config["balance_sampling_weight"]):
+        #     self.samples_weight = np.loadtxt(
+        #         self.config["balance_sampling_weight"], delimiter=","
+        #     )
+
+        if "train" not in self.split:
+            self.mixup = 0.0
+            # self.freqm = 0
+            # self.timem = 0
+
+    def _relative_path_to_absolute_path(self, metadata, dataset_name):
+        root_path = self.get_dataset_root_path(dataset_name)
+        for i in range(len(metadata["data"])):
+            assert "wav" in metadata["data"][i].keys(), metadata["data"][i]
+            assert metadata["data"][i]["wav"][0] != "/", (
+                "The dataset metadata should only contain relative path to the audio file: "
+                + str(metadata["data"][i]["wav"])
+            )
+            metadata["data"][i]["wav"] = os.path.join(
+                root_path, metadata["data"][i]["wav"]
+            )
+        return metadata
+
+    def build_dataset(self):
+        self.data = []
+        print("Build dataset split %s from %s" % (self.split, self.dataset_name))
+        if type(self.dataset_name) is str:
+            data_json = load_json(
+                self.get_dataset_metadata_path(self.dataset_name, key=self.split)
+            )
+            data_json = self._relative_path_to_absolute_path(
+                data_json, self.dataset_name
+            )
+            self.data = data_json["data"]
+        elif type(self.dataset_name) is list:
+            for dataset_name in self.dataset_name:
+                data_json = load_json(
+                    self.get_dataset_metadata_path(dataset_name, key=self.split)
+                )
+                data_json = self._relative_path_to_absolute_path(
+                    data_json, dataset_name
+                )
+                self.data += data_json["data"]
+        else:
+            raise Exception("Invalid data format")
+        print("Data size: {}".format(len(self.data)))
+
+    def build_dsp(self):
+        self.STFT = Audio.stft.TacotronSTFT(
+            self.config["preprocessing"]["stft"]["filter_length"],
+            self.config["preprocessing"]["stft"]["hop_length"],
+            self.config["preprocessing"]["stft"]["win_length"],
+            self.config["preprocessing"]["mel"]["n_mel_channels"],
+            self.config["preprocessing"]["audio"]["sampling_rate"],
+            self.config["preprocessing"]["mel"]["mel_fmin"],
+            self.config["preprocessing"]["mel"]["mel_fmax"],
+        )
+        # self.stft_transform = torchaudio.transforms.Spectrogram(
+        #     n_fft=1024, hop_length=160
+        # )
+        # self.melscale_transform = torchaudio.transforms.MelScale(
+        #     sample_rate=16000, n_stft=1024 // 2 + 1, n_mels=64
+        # )
+
+    def build_id_to_label(self):
+        id2label = {}
+        id2num = {}
+        num2label = {}
+        class_label_indices_path = self.get_dataset_metadata_path(
+            dataset=self.config["data"]["class_label_indices"],
+            key="class_label_indices",
+        )
+        if class_label_indices_path is not None:
+            df = pd.read_csv(class_label_indices_path)
+            for _, row in df.iterrows():
+                index, mid, display_name = row["index"], row["mid"], row["display_name"]
+                id2label[mid] = display_name
+                id2num[mid] = index
+                num2label[index] = display_name
+            self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
+        else:
+            self.id2label, self.index_dict, self.num2label = {}, {}, {}
+
+    def resample(self, waveform, sr):
+        waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
+        # waveform = librosa.resample(waveform, sr, self.sampling_rate)
+        return waveform
+
+        # if sr == 16000:
+        #     return waveform
+        # if sr == 32000 and self.sampling_rate == 16000:
+        #     waveform = waveform[::2]
+        #     return waveform
+        # if sr == 48000 and self.sampling_rate == 16000:
+        #     waveform = waveform[::3]
+        #     return waveform
+        # else:
+        #     raise ValueError(
+        #         "We currently only support 16k audio generation. You need to resample you audio file to 16k, 32k, or 48k: %s, %s"
+        #         % (sr, self.sampling_rate)
+        #     )
+
+    def normalize_wav(self, waveform):
+        waveform = waveform - np.mean(waveform)
+        waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+        return waveform * 0.5  # Manually limit the maximum amplitude into 0.5
+
+    def random_segment_wav(self, waveform, target_length):
+        waveform_length = waveform.shape[-1]
+        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+
+        # Too short
+        if (waveform_length - target_length) <= 0:
+            return waveform, 0
+
+        random_start = int(self.random_uniform(0, waveform_length - target_length))
+        return waveform[:, random_start : random_start + target_length], random_start
+
+    def pad_wav(self, waveform, target_length):
+        waveform_length = waveform.shape[-1]
+        assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+
+        if waveform_length == target_length:
+            return waveform
+
+        # Pad
+        temp_wav = np.zeros((1, target_length), dtype=np.float32)
+        if self.pad_wav_start_sample is None:
+            rand_start = int(self.random_uniform(0, target_length - waveform_length))
+        else:
+            rand_start = 0
+
+        temp_wav[:, rand_start : rand_start + waveform_length] = waveform
+        return temp_wav
+
+    # Named trim_silence so that it does not clash with the boolean self.trim_wav flag set in __init__
+    def trim_silence(self, waveform):
+        if np.max(np.abs(waveform)) < 0.0001:
+            return waveform
+
+        def detect_leading_silence(waveform, threshold=0.0001):
+            chunk_size = 1000
+            waveform_length = waveform.shape[0]
+            start = 0
+            while start + chunk_size < waveform_length:
+                if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
+                    start += chunk_size
+                else:
+                    break
+            return start
+
+        def detect_ending_silence(waveform, threshold=0.0001):
+            chunk_size = 1000
+            waveform_length = waveform.shape[0]
+            start = waveform_length
+            while start - chunk_size > 0:
+                if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
+                    start -= chunk_size
+                else:
+                    break
+            if start == waveform_length:
+                return start
+            else:
+                return start + chunk_size
+
+        start = detect_leading_silence(waveform)
+        end = detect_ending_silence(waveform)
+
+        return waveform[start:end]
+
+    def read_wav_file(self, filename):
+        # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+        waveform, sr = torchaudio.load(filename)
+
+        waveform, random_start = self.random_segment_wav(
+            waveform, target_length=int(sr * self.duration)
+        )
+
+        waveform = self.resample(waveform, sr)
+        # random_start = int(random_start * (self.sampling_rate / sr))
+
+        waveform = waveform.numpy()[0, ...]
+
+        waveform = self.normalize_wav(waveform)
+
+        if self.trim_wav:
+            waveform = self.trim_silence(waveform)
+
+        waveform = waveform[None, ...]
+        waveform = self.pad_wav(
+            waveform, target_length=int(self.sampling_rate * self.duration)
+        )
+        return waveform, random_start
+
+    def mix_two_waveforms(self, waveform1, waveform2):
+        mix_lambda = np.random.beta(5, 5)
+        mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
+        return self.normalize_wav(mix_waveform), mix_lambda
+
+    def read_audio_file(self, filename, filename2=None):
+        if os.path.exists(filename):
+            waveform, random_start = self.read_wav_file(filename)
+        else:
+            print(
+                'Warning [dataset.py]: The wav path "',
+                filename,
+                '" is not find in the metadata. Use empty waveform instead.',
+            )
+            target_length = int(self.sampling_rate * self.duration)
+            waveform = torch.zeros((1, target_length))
+            random_start = 0
+
+        mix_lambda = 0.0
+        # log_mel_spec, stft = self.wav_feature_extraction_torchaudio(waveform) # this line is faster, but this implementation is not aligned with HiFi-GAN
+        if not self.waveform_only:
+            log_mel_spec, stft = self.wav_feature_extraction(waveform)
+        else:
+            # Load waveform data only
+            # Use zero array to keep the format unified
+            log_mel_spec, stft = None, None
+
+        return log_mel_spec, stft, mix_lambda, waveform, random_start
+
+    def get_sample_text_caption(self, datum, mix_datum, label_indices):
+        text = self.label_indices_to_text(datum, label_indices)
+        if mix_datum is not None:
+            text += " " + self.label_indices_to_text(mix_datum, label_indices)
+        return text
+
+    # This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1
+    def wav_feature_extraction(self, waveform):
+        waveform = waveform[0, ...]
+        waveform = torch.FloatTensor(waveform)
+
+        log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)
+
+        log_mel_spec = torch.FloatTensor(log_mel_spec.T)
+        stft = torch.FloatTensor(stft.T)
+
+        log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
+        return log_mel_spec, stft
+
+    # @profile
+    # def wav_feature_extraction_torchaudio(self, waveform):
+    #     waveform = waveform[0, ...]
+    #     waveform = torch.FloatTensor(waveform)
+
+    #     stft = self.stft_transform(waveform)
+    #     mel_spec = self.melscale_transform(stft)
+    #     log_mel_spec = torch.log(mel_spec + 1e-7)
+
+    #     log_mel_spec = torch.FloatTensor(log_mel_spec.T)
+    #     stft = torch.FloatTensor(stft.T)
+
+    #     log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
+    #     return log_mel_spec, stft
+
+    def pad_spec(self, log_mel_spec):
+        n_frames = log_mel_spec.shape[0]
+        p = self.target_length - n_frames
+        # cut and pad
+        if p > 0:
+            m = torch.nn.ZeroPad2d((0, 0, 0, p))
+            log_mel_spec = m(log_mel_spec)
+        elif p < 0:
+            log_mel_spec = log_mel_spec[0 : self.target_length, :]
+
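+        # If the last (mel-bin) dimension is odd, drop one bin so its size is even.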
+        if log_mel_spec.size(-1) % 2 != 0:
+            log_mel_spec = log_mel_spec[..., :-1]
+
+        return log_mel_spec
+
+    def _read_datum_caption(self, datum):
+        caption_keys = [x for x in datum.keys() if ("caption" in x)]
+        random_index = torch.randint(0, len(caption_keys), (1,))[0].item()
+        return datum[caption_keys[random_index]]
+
+    def _is_contain_caption(self, datum):
+        caption_keys = [x for x in datum.keys() if ("caption" in x)]
+        return len(caption_keys) > 0
+
+    def label_indices_to_text(self, datum, label_indices):
+        if self._is_contain_caption(datum):
+            return self._read_datum_caption(datum)
+        elif "label" in datum.keys():
+            name_indices = torch.where(label_indices > 0.1)[0]
+            # description_header = "This audio contains the sound of "
+            description_header = ""
+            labels = ""
+            for id, each in enumerate(name_indices):
+                if id == len(name_indices) - 1:
+                    labels += "%s." % self.num2label[int(each)]
+                else:
+                    labels += "%s, " % self.num2label[int(each)]
+            return description_header + labels
+        else:
+            return ""  # TODO, if both label and caption are not provided, return empty string
+
+    def random_uniform(self, start, end):
+        val = torch.rand(1).item()
+        return start + (end - start) * val
+
+    def frequency_masking(self, log_mel_spec, freqm):
+        bs, freq, tsteps = log_mel_spec.size()
+        mask_len = int(self.random_uniform(freqm // 8, freqm))
+        mask_start = int(self.random_uniform(start=0, end=freq - mask_len))
+        log_mel_spec[:, mask_start : mask_start + mask_len, :] *= 0.0
+        return log_mel_spec
+
+    def time_masking(self, log_mel_spec, timem):
+        bs, freq, tsteps = log_mel_spec.size()
+        mask_len = int(self.random_uniform(timem // 8, timem))
+        mask_start = int(self.random_uniform(start=0, end=tsteps - mask_len))
+        log_mel_spec[:, :, mask_start : mask_start + mask_len] *= 0.0
+        return log_mel_spec
diff --git a/audioldm2/utilities/model.py b/audioldm2/utilities/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..ffefac1212b85bfb8c4992371dbdf6d500a969e3
--- /dev/null
+++ b/audioldm2/utilities/model.py
@@ -0,0 +1,121 @@
+import torch
+
+import audioldm2.hifigan as hifigan
+
+
+def get_vocoder_config():
+    return {
+        "resblock": "1",
+        "num_gpus": 6,
+        "batch_size": 16,
+        "learning_rate": 0.0002,
+        "adam_b1": 0.8,
+        "adam_b2": 0.99,
+        "lr_decay": 0.999,
+        "seed": 1234,
+        "upsample_rates": [5, 4, 2, 2, 2],
+        "upsample_kernel_sizes": [16, 16, 8, 4, 4],
+        "upsample_initial_channel": 1024,
+        "resblock_kernel_sizes": [3, 7, 11],
+        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        "segment_size": 8192,
+        "num_mels": 64,
+        "num_freq": 1025,
+        "n_fft": 1024,
+        "hop_size": 160,
+        "win_size": 1024,
+        "sampling_rate": 16000,
+        "fmin": 0,
+        "fmax": 8000,
+        "fmax_for_loss": None,
+        "num_workers": 4,
+        "dist_config": {
+            "dist_backend": "nccl",
+            "dist_url": "tcp://localhost:54321",
+            "world_size": 1,
+        },
+    }
+
+
+def get_available_checkpoint_keys(model, ckpt):
+    state_dict = torch.load(ckpt)["state_dict"]
+    current_state_dict = model.state_dict()
+    new_state_dict = {}
+    for k in state_dict.keys():
+        if (
+            k in current_state_dict.keys()
+            and current_state_dict[k].size() == state_dict[k].size()
+        ):
+            new_state_dict[k] = state_dict[k]
+        else:
+            print("==> WARNING: Skipping %s" % k)
+    print(
+        "%s out of %s keys are matched"
+        % (len(new_state_dict.keys()), len(state_dict.keys()))
+    )
+    return new_state_dict
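+
+# Example usage (a minimal sketch; `model` and "checkpoint.ckpt" are placeholder
+# names): load only the weights whose names and shapes match the current model.
+#
+#     matched = get_available_checkpoint_keys(model, "checkpoint.ckpt")
+#     model.load_state_dict(matched, strict=False)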
+
+
+def get_param_num(model):
+    num_param = sum(param.numel() for param in model.parameters())
+    return num_param
+
+
+def torch_version_orig_mod_remove(state_dict):
+    new_state_dict = {}
+    new_state_dict["generator"] = {}
+    for key in state_dict["generator"].keys():
+        if "_orig_mod." in key:
+            new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[
+                "generator"
+            ][key]
+        else:
+            new_state_dict["generator"][key] = state_dict["generator"][key]
+    return new_state_dict
+
+
+def get_vocoder(config, device, mel_bins):
+    name = "HiFi-GAN"
+    speaker = ""
+    if name == "MelGAN":
+        if speaker == "LJSpeech":
+            vocoder = torch.hub.load(
+                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
+            )
+        elif speaker == "universal":
+            vocoder = torch.hub.load(
+                "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
+            )
+        vocoder.mel2wav.eval()
+        vocoder.mel2wav.to(device)
+    elif name == "HiFi-GAN":
+        config = get_vocoder_config()
+        config = hifigan.AttrDict(config)
+        vocoder = hifigan.Generator_old(config)
+        # print("Load hifigan/g_01080000")
+        # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
+        # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
+        # ckpt = torch_version_orig_mod_remove(ckpt)
+        # vocoder.load_state_dict(ckpt["generator"])
+        vocoder.eval()
+        vocoder.remove_weight_norm()
+        vocoder.to(device)
+    return vocoder
+
+
+def vocoder_infer(mels, vocoder, lengths=None):
+    with torch.no_grad():
+        wavs = vocoder(mels).squeeze(1)
+
+    wavs = (wavs.cpu().numpy() * 32768).astype("int16")
+
+    if lengths is not None:
+        # `lengths` may be a single sample count or a per-sample sequence of counts.
+        if isinstance(lengths, int):
+            wavs = wavs[:, :lengths]
+        else:
+            wavs = [wav[: int(length)] for wav, length in zip(wavs, lengths)]
+
+    return wavs
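+
+# Example usage (a minimal sketch; `mels` is a placeholder tensor of shape
+# [batch, 64, frames] matching the 64-bin, 16 kHz HiFi-GAN config above; note
+# that get_vocoder does not load pretrained weights here, so a checkpoint must
+# be loaded separately before the output is meaningful):
+#
+#     vocoder = get_vocoder(None, "cpu", mel_bins=64)
+#     wavs = vocoder_infer(mels, vocoder)  # int16 numpy array of waveforms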
diff --git a/audioldm2/utilities/sampler.py b/audioldm2/utilities/sampler.py
new file mode 100755
index 0000000000000000000000000000000000000000..cdaf4882715f53f39ead8bf71fb3dccc29cd8b94
--- /dev/null
+++ b/audioldm2/utilities/sampler.py
@@ -0,0 +1,588 @@
+from typing import Iterator, List, Optional, Union
+from collections import Counter
+import logging
+from operator import itemgetter
+import random
+
+import numpy as np
+
+from torch.utils.data import DistributedSampler
+from torch.utils.data.sampler import Sampler
+
+LOGGER = logging.getLogger(__name__)
+
+from torch.utils.data import Dataset, Sampler
+
+
+class DatasetFromSampler(Dataset):
+    """Dataset to create indexes from `Sampler`.
+    Args:
+        sampler: PyTorch sampler
+    """
+
+    def __init__(self, sampler: Sampler):
+        """Initialisation for DatasetFromSampler."""
+        self.sampler = sampler
+        self.sampler_list = None
+
+    def __getitem__(self, index: int):
+        """Gets element of the dataset.
+        Args:
+            index: index of the element in the dataset
+        Returns:
+            Single element by index
+        """
+        if self.sampler_list is None:
+            self.sampler_list = list(self.sampler)
+        return self.sampler_list[index]
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            int: length of the dataset
+        """
+        return len(self.sampler)
+
+
+class BalanceClassSampler(Sampler):
+    """Allows you to create stratified sample on unbalanced classes.
+
+    Args:
+        labels: list of class labels for each elem in the dataset
+        mode: Strategy to balance classes.
+            Must be one of ["downsampling", "upsampling"] or an integer
+            number of samples per class
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BalanceClassSampler(train_labels, mode=5000)
+        valid_data = MNIST(os.getcwd(), train=False)
+
+        loaders = {
+            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
+    """
+
+    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
+        """Sampler initialisation."""
+        super().__init__(labels)
+
+        labels = np.array(labels)
+        samples_per_class = {label: (labels == label).sum() for label in set(labels)}
+
+        self.lbl2idx = {
+            label: np.arange(len(labels))[labels == label].tolist()
+            for label in set(labels)
+        }
+
+        if isinstance(mode, str):
+            assert mode in ["downsampling", "upsampling"]
+
+        if isinstance(mode, int) or mode == "upsampling":
+            samples_per_class = (
+                mode if isinstance(mode, int) else max(samples_per_class.values())
+            )
+        else:
+            samples_per_class = min(samples_per_class.values())
+
+        self.labels = labels
+        self.samples_per_class = samples_per_class
+        self.length = self.samples_per_class * len(set(labels))
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Returns:
+            iterator of indices of stratified sample
+        """
+        indices = []
+        for key in sorted(self.lbl2idx):
+            replace_flag = self.samples_per_class > len(self.lbl2idx[key])
+            indices += np.random.choice(
+                self.lbl2idx[key], self.samples_per_class, replace=replace_flag
+            ).tolist()
+        assert len(indices) == self.length
+        np.random.shuffle(indices)
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """
+        Returns:
+             length of result sample
+        """
+        return self.length
+
+
+class BatchBalanceClassSampler(Sampler):
+    """
+    This kind of sampler can be used for both metric learning and classification tasks.
+
+    BatchSampler with the following strategy for a dataset with C unique classes:
+    - select `num_classes` of the C classes for each batch
+    - select `num_samples` instances of each class in the batch
+    The epoch ends after `num_batches` batches,
+    so the batch size is `num_classes` * `num_samples`.
+
+    One purpose of this sampler is to form triplets
+    and pos/neg pairs inside the batch.
+    To guarantee the existence of such pairs in the batch,
+    `num_classes` and `num_samples` should both be > 1. (1)
+
+    This type of sampling can be found in the classical paper of Person Re-Id,
+    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
+    `In Defense of the Triplet Loss for Person Re-Identification`_.
+
+    Args:
+        labels: list of class labels for each elem in the dataset
+        num_classes: number of classes in a batch, should be > 1
+        num_samples: number of instances of each class in a batch, should be > 1
+        num_batches: number of batches in epoch
+            (default = len(labels) // (num_classes * num_samples))
+
+    .. _In Defense of the Triplet Loss for Person Re-Identification:
+        https://arxiv.org/abs/1703.07737
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BatchBalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True)
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BatchBalanceClassSampler(
+            train_labels, num_classes=10, num_samples=4)
+        valid_data = MNIST(os.getcwd(), train=False)
+
+        loaders = {
+            "train": DataLoader(train_data, batch_sampler=train_sampler),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
+    """
+
+    def __init__(
+        self,
+        labels: Union[List[int], np.ndarray],
+        num_classes: int,
+        num_samples: int,
+        num_batches: int = None,
+    ):
+        """Sampler initialisation."""
+        super().__init__(labels)
+        classes = set(labels)
+
+        assert isinstance(num_classes, int) and isinstance(num_samples, int)
+        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
+        assert all(
+            n > 1 for n in Counter(labels).values()
+        ), "Each class shoud contain at least 2 instances to fit (1)"
+
+        labels = np.array(labels)
+        self._labels = list(set(labels.tolist()))
+        self._num_classes = num_classes
+        self._num_samples = num_samples
+        self._batch_size = self._num_classes * self._num_samples
+        self._num_batches = num_batches or len(labels) // self._batch_size
+        self.lbl2idx = {
+            label: np.arange(len(labels))[labels == label].tolist()
+            for label in set(labels)
+        }
+
+    @property
+    def batch_size(self) -> int:
+        """
+        Returns:
+            this value should be used in DataLoader as batch size
+        """
+        return self._batch_size
+
+    @property
+    def batches_in_epoch(self) -> int:
+        """
+        Returns:
+            number of batches in an epoch
+        """
+        return self._num_batches
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            number of samples in an epoch
+        """
+        return self._num_batches  # * self._batch_size
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Returns:
+            indices for sampling dataset elems during an epoch
+        """
+        indices = []
+        for _ in range(self._num_batches):
+            batch_indices = []
+            classes_for_batch = random.sample(self._labels, self._num_classes)
+            while self._num_classes != len(set(classes_for_batch)):
+                classes_for_batch = random.sample(self._labels, self._num_classes)
+            for cls_id in classes_for_batch:
+                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
+                batch_indices += np.random.choice(
+                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
+                ).tolist()
+            indices.append(batch_indices)
+        return iter(indices)
+
+
+class DynamicBalanceClassSampler(Sampler):
+    """
+    This kind of sampler can be used for classification tasks with significant
+    class imbalance.
+
+    The idea of this sampler is to start from the original class distribution
+    and gradually move towards a uniform class distribution, as with downsampling.
+
+    Let :math:`D_i = |C_i| / |C_{min}|`, where :math:`|C_i|` is the size of class
+    i and :math:`|C_{min}|` is the size of the rarest class, so :math:`D_i`
+    defines the class distribution. Also let :math:`g(n_{epoch})` be an
+    exponential scheduler. On each epoch the current :math:`D_i` is computed as
+    :math:`D_i^{current} = D_i^{g(n_{epoch})}`,
+    after which the data is sampled according to this distribution.
+
+    Notes:
+         At the end of training, epochs will contain only
+         min_class_size * n_classes examples, so it may not be
+         necessary to run validation on every epoch. For this reason, use
+         ControlFlowCallback.
+
+    Examples:
+
+        >>> import torch
+        >>> import numpy as np
+
+        >>> from catalyst.data import DynamicBalanceClassSampler
+        >>> from torch.utils import data
+
+        >>> features = torch.Tensor(np.random.random((200, 100)))
+        >>> labels = np.random.randint(0, 4, size=(200,))
+        >>> sampler = DynamicBalanceClassSampler(labels)
+        >>> labels = torch.LongTensor(labels)
+        >>> dataset = data.TensorDataset(features, labels)
+        >>> loader = data.dataloader.DataLoader(dataset, batch_size=8)
+
+        >>> for batch in loader:
+        >>>     b_features, b_labels = batch
+
+    Sampler was inspired by https://arxiv.org/abs/1901.06783
+    """
+
+    def __init__(
+        self,
+        labels: List[Union[int, str]],
+        exp_lambda: float = 0.9,
+        start_epoch: int = 0,
+        max_d: Optional[int] = None,
+        mode: Union[str, int] = "downsampling",
+        ignore_warning: bool = False,
+    ):
+        """
+        Args:
+            labels: list of labels for each elem in the dataset
+            exp_lambda: exponent factor for the schedule
+            start_epoch: starting epoch number, useful for multi-stage
+            experiments
+            max_d: if not None, a heuristic limit on the ratio between the most
+            frequent and the rarest classes
+            mode: number of samples per class at the end of training. Must be
+            "downsampling" or a number. Before changing it, make sure you
+            understand how it works
+            ignore_warning: ignore warning about min class size
+        """
+        assert isinstance(start_epoch, int)
+        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
+        super().__init__(labels)
+        self.exp_lambda = exp_lambda
+        if max_d is None:
+            max_d = np.inf
+        self.max_d = max_d
+        self.epoch = start_epoch
+        labels = np.array(labels)
+        samples_per_class = Counter(labels)
+        self.min_class_size = min(samples_per_class.values())
+
+        if self.min_class_size < 100 and not ignore_warning:
+            LOGGER.warning(
+                f"the smallest class contains only"
+                f" {self.min_class_size} examples. At the end of"
+                f" training, epochs will contain only"
+                f" {self.min_class_size * len(samples_per_class)}"
+                f" examples"
+            )
+
+        self.original_d = {
+            key: value / self.min_class_size for key, value in samples_per_class.items()
+        }
+        self.label2idxes = {
+            label: np.arange(len(labels))[labels == label].tolist()
+            for label in set(labels)
+        }
+
+        if isinstance(mode, int):
+            self.min_class_size = mode
+        else:
+            assert mode == "downsampling"
+
+        self.labels = labels
+        self._update()
+
+    def _update(self) -> None:
+        """Update d coefficients."""
+        current_d = {
+            key: min(value ** self._exp_scheduler(), self.max_d)
+            for key, value in self.original_d.items()
+        }
+        samples_per_classes = {
+            key: int(value * self.min_class_size) for key, value in current_d.items()
+        }
+        self.samples_per_classes = samples_per_classes
+        self.length = np.sum(list(samples_per_classes.values()))
+        self.epoch += 1
+
+    def _exp_scheduler(self) -> float:
+        return self.exp_lambda**self.epoch
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Returns:
+            iterator of indices of stratified sample
+        """
+        indices = []
+        for key in sorted(self.label2idxes):
+            samples_per_class = self.samples_per_classes[key]
+            replace_flag = samples_per_class > len(self.label2idxes[key])
+            indices += np.random.choice(
+                self.label2idxes[key], samples_per_class, replace=replace_flag
+            ).tolist()
+        assert len(indices) == self.length
+        np.random.shuffle(indices)
+        self._update()
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """
+        Returns:
+             length of result sample
+        """
+        return self.length
+
+
+class MiniEpochSampler(Sampler):
+    """
+    Sampler that iterates over mini-epochs of the dataset, each containing ``mini_epoch_len`` samples.
+
+    Args:
+        data_len: Size of the dataset
+        mini_epoch_len: Num samples from the dataset used in one
+          mini epoch.
+        drop_last: If ``True``, the sampler will drop the last batch
+          if its size is less than ``batches_per_epoch``
+        shuffle: one of ``"per_mini_epoch"``, ``"per_epoch"``, or ``None``.
+          The sampler will shuffle indices
+          > "per_mini_epoch" -- every mini epoch (every ``__iter__`` call)
+          > "per_epoch" -- every real epoch
+          > None -- don't shuffle
+
+    Example:
+        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
+        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
+        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
+        >>>     shuffle="per_epoch")
+    """
+
+    def __init__(
+        self,
+        data_len: int,
+        mini_epoch_len: int,
+        drop_last: bool = False,
+        shuffle: str = None,
+    ):
+        """Sampler initialisation."""
+        super().__init__(None)
+
+        self.data_len = int(data_len)
+        self.mini_epoch_len = int(mini_epoch_len)
+
+        self.steps = int(data_len / self.mini_epoch_len)
+        self.state_i = 0
+
+        has_remainder = data_len - self.steps * mini_epoch_len > 0
+        if self.steps == 0:
+            self.divider = 1
+        elif has_remainder and not drop_last:
+            self.divider = self.steps + 1
+        else:
+            self.divider = self.steps
+
+        self._indices = np.arange(self.data_len)
+        self.indices = self._indices
+        self.end_pointer = max(self.data_len, self.mini_epoch_len)
+
+        if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]):
+            raise ValueError(
+                "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. "
+                + f"Got {shuffle}"
+            )
+        self.shuffle_type = shuffle
+
+    def shuffle(self) -> None:
+        """Shuffle sampler indices."""
+        if self.shuffle_type == "per_mini_epoch" or (
+            self.shuffle_type == "per_epoch" and self.state_i == 0
+        ):
+            if self.data_len >= self.mini_epoch_len:
+                self.indices = self._indices
+                np.random.shuffle(self.indices)
+            else:
+                self.indices = np.random.choice(
+                    self._indices, self.mini_epoch_len, replace=True
+                )
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate over sampler.
+
+        Returns:
+            python iterator
+        """
+        self.state_i = self.state_i % self.divider
+        self.shuffle()
+
+        start = self.state_i * self.mini_epoch_len
+        stop = (
+            self.end_pointer
+            if (self.state_i == self.steps)
+            else (self.state_i + 1) * self.mini_epoch_len
+        )
+        indices = self.indices[start:stop].tolist()
+
+        self.state_i += 1
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            int: length of the mini-epoch
+        """
+        return self.mini_epoch_len
+
+
+class DistributedSamplerWrapper(DistributedSampler):
+    """
+    Wrapper over `Sampler` for distributed training.
+    Allows you to use any sampler in distributed mode.
+
+    It is especially useful in conjunction with
+    `torch.nn.parallel.DistributedDataParallel`. In such case, each
+    process can pass a DistributedSamplerWrapper instance as a DataLoader
+    sampler, and load a subset of subsampled data of the original dataset
+    that is exclusive to it.
+
+    .. note::
+        Sampler is assumed to be of constant size.
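+
+    Python API example (a minimal sketch; ``train_labels``, ``train_data`` and
+    ``num_epochs`` are placeholder names used for illustration only):
+
+    .. code-block:: python
+
+        from torch.utils.data import DataLoader
+
+        sampler = DistributedSamplerWrapper(
+            BalanceClassSampler(train_labels, mode="upsampling")
+        )
+        loader = DataLoader(train_data, sampler=sampler, batch_size=32)
+
+        for epoch in range(num_epochs):
+            sampler.set_epoch(epoch)  # reshuffle consistently across processes
+            for batch in loader:
+                ...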
+    """
+
+    def __init__(
+        self,
+        sampler,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+    ):
+        """
+
+        Args:
+            sampler: Sampler used for subsampling
+            num_replicas (int, optional): Number of processes participating in
+                distributed training
+            rank (int, optional): Rank of the current process
+                within ``num_replicas``
+            shuffle (bool, optional): If true (default),
+                sampler will shuffle the indices
+        """
+        super(DistributedSamplerWrapper, self).__init__(
+            DatasetFromSampler(sampler),
+            num_replicas=num_replicas,
+            rank=rank,
+            shuffle=shuffle,
+        )
+        self.sampler = sampler
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate over sampler.
+
+        Returns:
+            python iterator
+        """
+        self.dataset = DatasetFromSampler(self.sampler)
+        indexes_of_indexes = super().__iter__()
+        subsampler_indexes = self.dataset
+        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
+
+
+__all__ = [
+    "BalanceClassSampler",
+    "BatchBalanceClassSampler",
+    "DistributedSamplerWrapper",
+    "DynamicBalanceClassSampler",
+    "MiniEpochSampler",
+]
diff --git a/audioldm2/utilities/tools.py b/audioldm2/utilities/tools.py
new file mode 100755
index 0000000000000000000000000000000000000000..a647a272cdf076b2ae9785bc83724ebd7a897642
--- /dev/null
+++ b/audioldm2/utilities/tools.py
@@ -0,0 +1,541 @@
+# Author: Haohe Liu
+# Email: haoheliu@gmail.com
+# Date: 11 Feb 2023
+
+import os
+import json
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import matplotlib
+from scipy.io import wavfile
+from matplotlib import pyplot as plt
+
+
+matplotlib.use("Agg")
+
+import hashlib
+import os
+
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+    "vggishish_lpaps": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt",
+    "vggishish_mean_std_melspec_10s_22050hz": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt",
+    "melception": "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt",
+}
+
+CKPT_MAP = {
+    "vggishish_lpaps": "vggishish16.pt",
+    "vggishish_mean_std_melspec_10s_22050hz": "train_means_stds_melspec_10s_22050hz.txt",
+    "melception": "melception-21-05-10T09-28-40.pt",
+}
+
+MD5_MAP = {
+    "vggishish_lpaps": "197040c524a07ccacf7715d7080a80bd",
+    "vggishish_mean_std_melspec_10s_22050hz": "f449c6fd0e248936c16f6d22492bb625",
+    "melception": "a71a41041e945b457c7d3d814bbcf72d",
+}
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def load_json(fname):
+    with open(fname, "r") as f:
+        data = json.load(f)
+        return data
+
+
+def read_json(dataset_json_file):
+    with open(dataset_json_file, "r") as fp:
+        data_json = json.load(fp)
+    return data_json["data"]
+
+
+def copy_test_subset_data(metadata, testset_copy_target_path):
+    # metadata = read_json(testset_metadata)
+    os.makedirs(testset_copy_target_path, exist_ok=True)
+    if len(os.listdir(testset_copy_target_path)) == len(metadata):
+        return
+    else:
+        # delete files in folder testset_copy_target_path
+        for file in os.listdir(testset_copy_target_path):
+            try:
+                os.remove(os.path.join(testset_copy_target_path, file))
+            except Exception as e:
+                print(e)
+
+    print("Copying test subset data to {}".format(testset_copy_target_path))
+    for each in tqdm(metadata):
+        cmd = "cp {} {}".format(each["wav"], os.path.join(testset_copy_target_path))
+        os.system(cmd)
+
+
+def listdir_nohidden(path):
+    for f in os.listdir(path):
+        if not f.startswith("."):
+            yield f
+
+
+def get_restore_step(path):
+    checkpoints = os.listdir(path)
+    if os.path.exists(os.path.join(path, "final.ckpt")):
+        return "final.ckpt", 0
+    elif not os.path.exists(os.path.join(path, "last.ckpt")):
+        steps = [int(x.split(".ckpt")[0].split("step=")[1]) for x in checkpoints]
+        return checkpoints[np.argmax(steps)], np.max(steps)
+    else:
+        steps = []
+        fname = "last.ckpt"
+        for x in checkpoints:
+            if "last" in x and "-v" in x:
+                this_version = int(x.split(".ckpt")[0].split("-v")[1])
+                # Keep the newest versioned checkpoint, e.g. "last-v2.ckpt".
+                if len(steps) == 0 or this_version > np.max(steps):
+                    fname = "last-v%s.ckpt" % this_version
+                steps.append(this_version)
+        return fname, 0
+
+
+def download(url, local_path, chunk_size=1024):
+    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+    with requests.get(url, stream=True) as r:
+        total_size = int(r.headers.get("content-length", 0))
+        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+            with open(local_path, "wb") as f:
+                for data in r.iter_content(chunk_size=chunk_size):
+                    if data:
+                        f.write(data)
+                        pbar.update(chunk_size)
+
+
+def md5_hash(path):
+    with open(path, "rb") as f:
+        content = f.read()
+    return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+    assert name in URL_MAP
+    path = os.path.join(root, CKPT_MAP[name])
+    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+        download(URL_MAP[name], path)
+        md5 = md5_hash(path)
+        assert md5 == MD5_MAP[name], md5
+    return path
+
+
+class KeyNotFoundError(Exception):
+    def __init__(self, cause, keys=None, visited=None):
+        self.cause = cause
+        self.keys = keys
+        self.visited = visited
+        messages = list()
+        if keys is not None:
+            messages.append("Key not found: {}".format(keys))
+        if visited is not None:
+            messages.append("Visited: {}".format(visited))
+        messages.append("Cause:\n{}".format(cause))
+        message = "\n".join(messages)
+        super().__init__(message)
+
+
+def retrieve(
+    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+    """Given a nested list or dict return the desired value at key expanding
+    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+    is done in-place.
+
+    Parameters
+    ----------
+        list_or_dict : list or dict
+            Possibly nested list or dictionary.
+        key : str
+            key/to/value, path like string describing all keys necessary to
+            consider to get to the desired value. List indices can also be
+            passed here.
+        splitval : str
+            String that defines the delimiter between keys of the
+            different depth levels in `key`.
+        default : obj
+            Value returned if :attr:`key` is not found.
+        expand : bool
+            Whether to expand callable nodes on the path or not.
+
+    Returns
+    -------
+        The desired value or if :attr:`default` is not ``None`` and the
+        :attr:`key` is not found returns ``default``.
+
+    Raises
+    ------
+        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+        ``None``.
+    """
+
+    keys = key.split(splitval)
+
+    success = True
+    try:
+        visited = []
+        parent = None
+        last_key = None
+        for key in keys:
+            if callable(list_or_dict):
+                if not expand:
+                    raise KeyNotFoundError(
+                        ValueError(
+                            "Trying to get past callable node with expand=False."
+                        ),
+                        keys=keys,
+                        visited=visited,
+                    )
+                list_or_dict = list_or_dict()
+                parent[last_key] = list_or_dict
+
+            last_key = key
+            parent = list_or_dict
+
+            try:
+                if isinstance(list_or_dict, dict):
+                    list_or_dict = list_or_dict[key]
+                else:
+                    list_or_dict = list_or_dict[int(key)]
+            except (KeyError, IndexError, ValueError) as e:
+                raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+            visited += [key]
+        # final expansion of retrieved value
+        if expand and callable(list_or_dict):
+            list_or_dict = list_or_dict()
+            parent[last_key] = list_or_dict
+    except KeyNotFoundError as e:
+        if default is None:
+            raise e
+        else:
+            list_or_dict = default
+            success = False
+
+    if not pass_success:
+        return list_or_dict
+    else:
+        return list_or_dict, success
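+
+# Example (a minimal sketch of how `retrieve` walks a nested structure; the
+# dictionary below is illustrative only):
+#
+#     cfg = {"model": {"params": {"channels": 8}}}
+#     retrieve(cfg, "model/params/channels")             # -> 8
+#     retrieve(cfg, "model/params/missing", default=-1)  # -> -1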
+
+
+def to_device(data, device):
+    if len(data) == 12:
+        (
+            ids,
+            raw_texts,
+            speakers,
+            texts,
+            src_lens,
+            max_src_len,
+            mels,
+            mel_lens,
+            max_mel_len,
+            pitches,
+            energies,
+            durations,
+        ) = data
+
+        speakers = torch.from_numpy(speakers).long().to(device)
+        texts = torch.from_numpy(texts).long().to(device)
+        src_lens = torch.from_numpy(src_lens).to(device)
+        mels = torch.from_numpy(mels).float().to(device)
+        mel_lens = torch.from_numpy(mel_lens).to(device)
+        pitches = torch.from_numpy(pitches).float().to(device)
+        energies = torch.from_numpy(energies).to(device)
+        durations = torch.from_numpy(durations).long().to(device)
+
+        return (
+            ids,
+            raw_texts,
+            speakers,
+            texts,
+            src_lens,
+            max_src_len,
+            mels,
+            mel_lens,
+            max_mel_len,
+            pitches,
+            energies,
+            durations,
+        )
+
+    if len(data) == 6:
+        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data
+
+        speakers = torch.from_numpy(speakers).long().to(device)
+        texts = torch.from_numpy(texts).long().to(device)
+        src_lens = torch.from_numpy(src_lens).to(device)
+
+        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)
+
+
+def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=""):
+    # if losses is not None:
+    #     logger.add_scalar("Loss/total_loss", losses[0], step)
+    #     logger.add_scalar("Loss/mel_loss", losses[1], step)
+    #     logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
+    #     logger.add_scalar("Loss/pitch_loss", losses[3], step)
+    #     logger.add_scalar("Loss/energy_loss", losses[4], step)
+    #     logger.add_scalar("Loss/duration_loss", losses[5], step)
+    #     if(len(losses) > 6):
+    #         logger.add_scalar("Loss/disc_loss", losses[6], step)
+    #         logger.add_scalar("Loss/fmap_loss", losses[7], step)
+    #         logger.add_scalar("Loss/r_loss", losses[8], step)
+    #         logger.add_scalar("Loss/g_loss", losses[9], step)
+    #         logger.add_scalar("Loss/gen_loss", losses[10], step)
+    #         logger.add_scalar("Loss/diff_loss", losses[11], step)
+
+    if fig is not None:
+        logger.add_figure(tag, fig)
+
+    if audio is not None:
+        audio = audio / (max(abs(audio)) * 1.1)
+        logger.add_audio(
+            tag,
+            audio,
+            sample_rate=sampling_rate,
+        )
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+    batch_size = lengths.shape[0]
+    if max_len is None:
+        max_len = torch.max(lengths).item()
+
+    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
+    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
+
+    return mask
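+
+# Example (illustrative only): lengths = torch.tensor([2, 3]) with max_len=4 yields
+#     [[False, False,  True,  True],
+#      [False, False, False,  True]]
+# i.e. True marks the padded positions beyond each sequence length.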
+
+
+def expand(values, durations):
+    out = list()
+    for value, d in zip(values, durations):
+        out += [value] * max(0, int(d))
+    return np.array(out)
+
+
+def synth_one_sample_val(
+    targets, predictions, vocoder, model_config, preprocess_config
+):
+    index = np.random.choice(list(np.arange(targets[6].size(0))))
+
+    basename = targets[0][index]
+    src_len = predictions[8][index].item()
+    mel_len = predictions[9][index].item()
+    mel_target = targets[6][index, :mel_len].detach().transpose(0, 1)
+
+    mel_prediction = predictions[0][index, :mel_len].detach().transpose(0, 1)
+    postnet_mel_prediction = predictions[1][index, :mel_len].detach().transpose(0, 1)
+    duration = targets[11][index, :src_len].detach().cpu().numpy()
+
+    if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
+        pitch = predictions[2][index, :src_len].detach().cpu().numpy()
+        pitch = expand(pitch, duration)
+    else:
+        pitch = predictions[2][index, :mel_len].detach().cpu().numpy()
+
+    if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
+        energy = predictions[3][index, :src_len].detach().cpu().numpy()
+        energy = expand(energy, duration)
+    else:
+        energy = predictions[3][index, :mel_len].detach().cpu().numpy()
+
+    with open(
+        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
+    ) as f:
+        stats = json.load(f)
+        stats = stats["pitch"] + stats["energy"][:2]
+
+    # from datetime import datetime
+    # now = datetime.now()
+    # current_time = now.strftime("%D:%H:%M:%S")
+    # np.save(("mel_pred_%s.npy" % current_time).replace("/","-"), mel_prediction.cpu().numpy())
+    # np.save(("postnet_mel_prediction_%s.npy" % current_time).replace("/","-"), postnet_mel_prediction.cpu().numpy())
+    # np.save(("mel_target_%s.npy" % current_time).replace("/","-"), mel_target.cpu().numpy())
+
+    fig = plot_mel(
+        [
+            (mel_prediction.cpu().numpy(), pitch, energy),
+            (postnet_mel_prediction.cpu().numpy(), pitch, energy),
+            (mel_target.cpu().numpy(), pitch, energy),
+        ],
+        stats,
+        [
+            "Raw mel spectrogram prediction",
+            "Postnet mel prediction",
+            "Ground-Truth Spectrogram",
+        ],
+    )
+
+    if vocoder is not None:
+        from .model import vocoder_infer
+
+        wav_reconstruction = vocoder_infer(
+            mel_target.unsqueeze(0),
+            vocoder,
+        )[0]
+        wav_prediction = vocoder_infer(
+            postnet_mel_prediction.unsqueeze(0),
+            vocoder,
+        )[0]
+    else:
+        wav_reconstruction = wav_prediction = None
+
+    return fig, wav_reconstruction, wav_prediction, basename
+
+
+def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
+    if vocoder is not None:
+        from .model import vocoder_infer
+
+        wav_reconstruction = vocoder_infer(
+            mel_input.permute(0, 2, 1),
+            vocoder,
+        )
+        wav_prediction = vocoder_infer(
+            mel_prediction.permute(0, 2, 1),
+            vocoder,
+        )
+    else:
+        wav_reconstruction = wav_prediction = None
+
+    return wav_reconstruction, wav_prediction
+
+
+def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
+    # (diff_output, diff_loss, latent_loss) = diffusion
+
+    basenames = targets[0]
+
+    for i in range(len(predictions[1])):
+        basename = basenames[i]
+        src_len = predictions[8][i].item()
+        mel_len = predictions[9][i].item()
+        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
+        # diff_output = diff_output[i, :mel_len].detach().transpose(0, 1)
+        # duration = predictions[5][i, :src_len].detach().cpu().numpy()
+        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
+            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
+            # pitch = expand(pitch, duration)
+        else:
+            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
+        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
+            energy = predictions[3][i, :src_len].detach().cpu().numpy()
+            # energy = expand(energy, duration)
+        else:
+            energy = predictions[3][i, :mel_len].detach().cpu().numpy()
+        # import ipdb; ipdb.set_trace()
+        with open(
+            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
+        ) as f:
+            stats = json.load(f)
+            stats = stats["pitch"] + stats["energy"][:2]
+
+        fig = plot_mel(
+            [
+                (mel_prediction.cpu().numpy(), pitch, energy),
+            ],
+            stats,
+            ["Synthetized Spectrogram by PostNet"],
+        )
+        # np.save("{}_postnet.npy".format(basename), mel_prediction.cpu().numpy())
+        plt.savefig(os.path.join(path, "{}_postnet_2.png".format(basename)))
+        plt.close()
+
+    from .model import vocoder_infer
+
+    mel_predictions = predictions[1].transpose(1, 2)
+    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
+    wav_predictions = vocoder_infer(
+        mel_predictions, vocoder, lengths=lengths
+    )
+
+    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
+    for wav, basename in zip(wav_predictions, basenames):
+        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
+
+
+def plot_mel(data, stats=None, titles=None):
+    # `stats` is accepted for compatibility with the callers above but is not
+    # used by this simplified plotting routine.
+    fig, axes = plt.subplots(len(data), 1, squeeze=False)
+    if titles is None:
+        titles = [None for i in range(len(data))]
+
+    for i in range(len(data)):
+        # Each entry is either a bare mel spectrogram or a (mel, pitch, energy) tuple.
+        mel = data[i][0] if isinstance(data[i], tuple) else data[i]
+        axes[i][0].imshow(mel, origin="lower", aspect="auto")
+        axes[i][0].set_aspect(2.5, adjustable="box")
+        axes[i][0].set_ylim(0, mel.shape[0])
+        axes[i][0].set_title(titles[i], fontsize="medium")
+        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+        axes[i][0].set_anchor("W")
+
+    return fig
+
+
+def pad_1D(inputs, PAD=0):
+    def pad_data(x, length, PAD):
+        x_padded = np.pad(
+            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
+        )
+        return x_padded
+
+    max_len = max((len(x) for x in inputs))
+    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
+
+    return padded
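+
+# Example (illustrative only): padding variable-length 1-D arrays to a common length.
+#
+#     pad_1D([np.array([1, 2, 3]), np.array([4, 5])])
+#     # -> array([[1, 2, 3],
+#     #           [4, 5, 0]])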
+
+
+def pad_2D(inputs, maxlen=None):
+    def pad(x, max_len):
+        PAD = 0
+        if np.shape(x)[0] > max_len:
+            raise ValueError("not max_len")
+
+        s = np.shape(x)[1]
+        x_padded = np.pad(
+            x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
+        )
+        return x_padded[:, :s]
+
+    if maxlen:
+        output = np.stack([pad(x, maxlen) for x in inputs])
+    else:
+        max_len = max(np.shape(x)[0] for x in inputs)
+        output = np.stack([pad(x, max_len) for x in inputs])
+
+    return output
+
+
+def pad(input_ele, mel_max_length=None):
+    if mel_max_length:
+        max_len = mel_max_length
+    else:
+        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
+
+    out_list = list()
+    for i, batch in enumerate(input_ele):
+        if len(batch.shape) == 1:
+            one_batch_padded = F.pad(
+                batch, (0, max_len - batch.size(0)), "constant", 0.0
+            )
+        elif len(batch.shape) == 2:
+            one_batch_padded = F.pad(
+                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
+            )
+        out_list.append(one_batch_padded)
+    out_padded = torch.stack(out_list)
+    return out_padded
diff --git a/audioldm2/utils.py b/audioldm2/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..c098e25b99cf78b0c0befc71fb7ba7688e79c899
--- /dev/null
+++ b/audioldm2/utils.py
@@ -0,0 +1,352 @@
+import contextlib
+import importlib
+from huggingface_hub import hf_hub_download
+
+from inspect import isfunction
+import os
+import soundfile as sf
+import time
+import wave
+
+import progressbar
+
+CACHE_DIR = os.getenv(
+    "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
+)
+
+def read_list(fname):
+    result = []
+    with open(fname, "r", encoding="utf-8") as f:
+        for each in f.readlines():
+            each = each.strip('\n')
+            result.append(each)
+    return result
+
+def get_duration(fname):
+    with contextlib.closing(wave.open(fname, "r")) as f:
+        frames = f.getnframes()
+        rate = f.getframerate()
+        return frames / float(rate)
+
+
+def get_bit_depth(fname):
+    with contextlib.closing(wave.open(fname, "r")) as f:
+        bit_depth = f.getsampwidth() * 8
+        return bit_depth
+
+
+def get_time():
+    t = time.localtime()
+    return time.strftime("%d_%m_%Y_%H_%M_%S", t)
+
+
+def seed_everything(seed):
+    import random, os
+    import numpy as np
+    import torch
+
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
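+    # Note: cudnn.benchmark=True favours speed over strict reproducibility; set
+    # it to False if bit-exact determinism is required.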
+    torch.backends.cudnn.benchmark = True
+
+
+def save_wave(waveform, savepath, name="outwav"):
+    if type(name) is not list:
+        name = [name] * waveform.shape[0]
+
+    for i in range(waveform.shape[0]):
+        path = os.path.join(
+            savepath,
+            "%s_%s.wav"
+            % (
+                os.path.basename(name[i])
+                if (not ".wav" in name[i])
+                else os.path.basename(name[i]).split(".")[0],
+                i,
+            ),
+        )
+        print("Save audio to %s" % path)
+        sf.write(path, waveform[i, 0], samplerate=16000)
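+
+# Example usage (a minimal sketch; `waveform` here stands for a numpy array of
+# shape [batch, 1, samples] at 16 kHz, which is what this package produces):
+#
+#     save_wave(np.random.randn(1, 1, 16000).astype("float32"), "./output", name="demo")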
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+    return total_params
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == "__is_first_stage__":
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    try:
+        return get_obj_from_str(config["target"])(**config.get("params", dict()))
+    except Exception:
+        import ipdb
+
+        ipdb.set_trace()
+
+
+def default_audioldm_config(model_name="audioldm2-full"):
+    basic_config = {
+        "metadata_root": "/mnt/bn/lqhaoheliu/metadata/processed/dataset_root.json",
+        "log_directory": "./log/audiomae_pred",
+        "precision": "high",
+        "data": {
+            "train": [
+                "audiocaps",
+                "audioset",
+                "wavcaps",
+                "audiostock_music_250k",
+                "free_to_use_sounds",
+                "epidemic_sound_effects",
+                "vggsound",
+                "million_song_dataset",
+            ],
+            "val": "audiocaps",
+            "test": "audiocaps",
+            "class_label_indices": "audioset",
+            "dataloader_add_ons": [
+                "extract_kaldi_fbank_feature",
+                "extract_vits_phoneme_and_flant5_text",
+                "waveform_rs_48k",
+            ],
+        },
+        "variables": {
+            "sampling_rate": 16000,
+            "mel_bins": 64,
+            "latent_embed_dim": 8,
+            "latent_t_size": 256,
+            "latent_f_size": 16,
+            "in_channels": 8,
+            "optimize_ddpm_parameter": True,
+            "warmup_steps": 5000,
+        },
+        "step": {
+            "validation_every_n_epochs": 1,
+            "save_checkpoint_every_n_steps": 5000,
+            "limit_val_batches": 10,
+            "max_steps": 1500000,
+            "save_top_k": 2,
+        },
+        "preprocessing": {
+            "audio": {
+                "sampling_rate": 16000,
+                "max_wav_value": 32768,
+                "duration": 10.24,
+            },
+            "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
+            "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
+        },
+        "augmentation": {"mixup": 0},
+        "model": {
+            "target": "audioldm2.latent_diffusion.models.ddpm.LatentDiffusion",
+            "params": {
+                "first_stage_config": {
+                    "base_learning_rate": 0.000008,
+                    "target": "audioldm2.latent_encoder.autoencoder.AutoencoderKL",
+                    "params": {
+                        "sampling_rate": 16000,
+                        "batchsize": 4,
+                        "monitor": "val/rec_loss",
+                        "image_key": "fbank",
+                        "subband": 1,
+                        "embed_dim": 8,
+                        "time_shuffle": 1,
+                        "lossconfig": {
+                            "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
+                            "params": {
+                                "disc_start": 50001,
+                                "kl_weight": 1000,
+                                "disc_weight": 0.5,
+                                "disc_in_channels": 1,
+                            },
+                        },
+                        "ddconfig": {
+                            "double_z": True,
+                            "mel_bins": 64,
+                            "z_channels": 8,
+                            "resolution": 256,
+                            "downsample_time": False,
+                            "in_channels": 1,
+                            "out_ch": 1,
+                            "ch": 128,
+                            "ch_mult": [1, 2, 4],
+                            "num_res_blocks": 2,
+                            "attn_resolutions": [],
+                            "dropout": 0,
+                        },
+                    },
+                },
+                "base_learning_rate": 0.0001,
+                "warmup_steps": 5000,
+                "optimize_ddpm_parameter": True,
+                "sampling_rate": 16000,
+                "batchsize": 16,
+                "linear_start": 0.0015,
+                "linear_end": 0.0195,
+                "num_timesteps_cond": 1,
+                "log_every_t": 200,
+                "timesteps": 1000,
+                "unconditional_prob_cfg": 0.1,
+                "parameterization": "eps",
+                "first_stage_key": "fbank",
+                "latent_t_size": 256,
+                "latent_f_size": 16,
+                "channels": 8,
+                "monitor": "val/loss_simple_ema",
+                "scale_by_std": True,
+                "unet_config": {
+                    "target": "audioldm2.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
+                    "params": {
+                        "image_size": 64,
+                        "context_dim": [768, 1024],
+                        "in_channels": 8,
+                        "out_channels": 8,
+                        "model_channels": 128,
+                        "attention_resolutions": [8, 4, 2],
+                        "num_res_blocks": 2,
+                        "channel_mult": [1, 2, 3, 5],
+                        "num_head_channels": 32,
+                        "use_spatial_transformer": True,
+                        "transformer_depth": 1,
+                    },
+                },
+                "evaluation_params": {
+                    "unconditional_guidance_scale": 3.5,
+                    "ddim_sampling_steps": 200,
+                    "n_candidates_per_samples": 3,
+                },
+                "cond_stage_config": {
+                    "crossattn_audiomae_generated": {
+                        "cond_stage_key": "all",
+                        "conditioning_key": "crossattn",
+                        "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond",
+                        "params": {
+                            "always_output_audiomae_gt": False,
+                            "learnable": True,
+                            "device": "cuda",
+                            "use_gt_mae_output": True,
+                            "use_gt_mae_prob": 0.25,
+                            "base_learning_rate": 0.0002,
+                            "sequence_gen_length": 8,
+                            "use_warmup": True,
+                            "sequence_input_key": [
+                                "film_clap_cond1",
+                                "crossattn_flan_t5",
+                            ],
+                            "sequence_input_embed_dim": [512, 1024],
+                            "batchsize": 16,
+                            "cond_stage_config": {
+                                "film_clap_cond1": {
+                                    "cond_stage_key": "text",
+                                    "conditioning_key": "film",
+                                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.CLAPAudioEmbeddingClassifierFreev2",
+                                    "params": {
+                                        "sampling_rate": 48000,
+                                        "embed_mode": "text",
+                                        "amodel": "HTSAT-base",
+                                    },
+                                },
+                                "crossattn_flan_t5": {
+                                    "cond_stage_key": "text",
+                                    "conditioning_key": "crossattn",
+                                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
+                                },
+                                "crossattn_audiomae_pooled": {
+                                    "cond_stage_key": "ta_kaldi_fbank",
+                                    "conditioning_key": "crossattn",
+                                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.AudioMAEConditionCTPoolRand",
+                                    "params": {
+                                        "regularization": False,
+                                        "no_audiomae_mask": True,
+                                        "time_pooling_factors": [8],
+                                        "freq_pooling_factors": [8],
+                                        "eval_time_pooling": 8,
+                                        "eval_freq_pooling": 8,
+                                        "mask_ratio": 0,
+                                    },
+                                },
+                            },
+                        },
+                    },
+                    "crossattn_flan_t5": {
+                        "cond_stage_key": "text",
+                        "conditioning_key": "crossattn",
+                        "target": "audioldm2.latent_diffusion.modules.encoders.modules.FlanT5HiddenState",
+                    },
+                },
+            },
+        },
+    }
+    return basic_config
+
+
+def get_metadata():
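+    # Map each supported checkpoint name to its local cache path and download URL.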
+    return {
+        "audioldm2-full": {
+            "path": os.path.join(
+                CACHE_DIR,
+                "audioldm2-full.pth",
+            ),
+            "url": "https://huggingface.co/haoheliu/audioldm2-full/resolve/main/audioldm2-full.pth",
+        },
+    }
+
+
+class MyProgressBar:
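+    """Console progress bar whose __call__(block_num, block_size, total_size)
+    signature matches the reporthook interface of urllib.request.urlretrieve."""
+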
+    def __init__(self):
+        self.pbar = None
+
+    def __call__(self, block_num, block_size, total_size):
+        if not self.pbar:
+            self.pbar = progressbar.ProgressBar(maxval=total_size)
+            self.pbar.start()
+
+        downloaded = block_num * block_size
+        if downloaded < total_size:
+            self.pbar.update(downloaded)
+        else:
+            self.pbar.finish()
+
+
+def download_checkpoint(checkpoint_name="audioldm2-full"):
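+    # Fetch the named checkpoint from the Hugging Face Hub into the local cache directory.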
+    meta = get_metadata()
+    if checkpoint_name not in meta.keys():
+        print(
+            "The model name you provided is not supported. Please use one of the following: ",
+            list(meta.keys()),
+        )
+        return
+
+    model_id = "haoheliu/%s" % checkpoint_name
+    hf_hub_download(
+        repo_id=model_id,
+        filename=os.path.basename(meta[checkpoint_name]["path"]),
+        local_dir=os.path.dirname(meta[checkpoint_name]["path"]),
+    )
diff --git a/batch.lst b/batch.lst
new file mode 100644
index 0000000000000000000000000000000000000000..c52c7a52523775ad729bfbb350f9cd70ddfbf3e4
--- /dev/null
+++ b/batch.lst
@@ -0,0 +1,4 @@
+A forest of wind chimes singing a soothing melody in the breeze.
+A violin playing a heartfelt melody.
+A saxophone playing a soulful melody.
+Musical constellations twinkling in the night sky, forming a cosmic melody.
\ No newline at end of file
diff --git a/bg.png b/bg.png
new file mode 100644
index 0000000000000000000000000000000000000000..2811a3593a7492c5af5754ab6949a6e60f2635bd
Binary files /dev/null and b/bg.png differ
diff --git a/bin/audioldm2 b/bin/audioldm2
new file mode 100755
index 0000000000000000000000000000000000000000..2ff95674ad63326a4b0d7f7b633c2f87949502f3
--- /dev/null
+++ b/bin/audioldm2
@@ -0,0 +1,131 @@
+#!/usr/bin/python3
+import os
+import torch
+import logging
+from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
+import argparse
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+matplotlib_logger = logging.getLogger('matplotlib')
+matplotlib_logger.setLevel(logging.WARNING)
+
+
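+# Checkpoints are cached under ~/.cache/audioldm2 unless AUDIOLDM_CACHE_DIR overrides it.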
+CACHE_DIR = os.getenv(
+    "AUDIOLDM_CACHE_DIR",
+    os.path.join(os.path.expanduser("~"), ".cache/audioldm2"))
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "-t",
+    "--text",
+    type=str,
+    required=False,
+    default="",
+    help="Text prompt to the model for audio generation",
+)
+
+parser.add_argument(
+    "-tl",
+    "--text_list",
+    type=str,
+    required=False,
+    default="",
+    help="A file that contains text prompt to the model for audio generation",
+)
+
+parser.add_argument(
+    "-s",
+    "--save_path",
+    type=str,
+    required=False,
+    help="The path to save model output",
+    default="./output",
+)
+
+parser.add_argument(
+    "--model_name",
+    type=str,
+    required=False,
+    help="The checkpoint you gonna use",
+    default="audioldm2-full",
+    choices=["audioldm2-full"]
+)
+
+parser.add_argument(
+    "-b",
+    "--batchsize",
+    type=int,
+    required=False,
+    default=1,
+    help="Generate how many samples at the same time",
+)
+
+parser.add_argument(
+    "--ddim_steps",
+    type=int,
+    required=False,
+    default=200,
+    help="The sampling step for DDIM",
+)
+
+parser.add_argument(
+    "-gs",
+    "--guidance_scale",
+    type=float,
+    required=False,
+    default=3.5,
+    help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
+)
+
+parser.add_argument(
+    "-n",
+    "--n_candidate_gen_per_text",
+    type=int,
+    required=False,
+    default=3,
+    help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+)
+
+parser.add_argument(
+    "--seed",
+    type=int,
+    required=False,
+    default=0,
+    help="Change this value (any integer number) will lead to a different generation result.",
+)
+
+args = parser.parse_args()
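+
+# Example invocations, assuming the installed bin/audioldm2 script is on your PATH
+# (the prompt below is illustrative; see batch.lst for more):
+#   audioldm2 -t "A violin playing a heartfelt melody" --ddim_steps 200 -gs 3.5
+#   audioldm2 -tl batch.lst -s ./output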
+
+torch.set_float32_matmul_precision("high")
+
+save_path = os.path.join(args.save_path, get_time())
+
+text = args.text
+random_seed = args.seed
+duration = 10
+guidance_scale = args.guidance_scale
+n_candidate_gen_per_text = args.n_candidate_gen_per_text
+
+os.makedirs(save_path, exist_ok=True)
+audioldm2 = build_model(model_name=args.model_name)
+
+if args.text_list:
+    print("Generate audio based on the text prompts in %s" % args.text_list)
+    prompt_todo = read_list(args.text_list)
+else:
+    prompt_todo = [text]
+
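+# Generate one waveform per prompt; per the --help text above, a larger
+# n_candidate_gen_per_text trades extra computation for automatic quality control.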
+for text in prompt_todo:
+    waveform = text_to_audio(
+        audioldm2,
+        text,
+        seed=random_seed,
+        duration=duration,
+        guidance_scale=guidance_scale,
+        ddim_steps=args.ddim_steps,
+        n_candidate_gen_per_text=n_candidate_gen_per_text,
+        batchsize=args.batchsize,
+    )
+
+    save_wave(waveform, save_path, name=text)
diff --git a/bin/audioldm2.cmd b/bin/audioldm2.cmd
new file mode 100755
index 0000000000000000000000000000000000000000..c164fbfb6a194858b6d9019c8e29df3e57b3172a
--- /dev/null
+++ b/bin/audioldm2.cmd
@@ -0,0 +1,2 @@
+@echo OFF
+python -m audioldm2 %*
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ea2a8af90d37d8a1994cc87a7731b03ea065cd8f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+git+https://github.com/huggingface/diffusers.git
+git+https://github.com/huggingface/transformers.git
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch >= 2.0
+huggingface_hub
+transformers==4.30.2
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f32a48bc29a8e293efbbf3fb05080e940de108
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# python3 setup.py sdist bdist_wheel
+"""
+@File    :   setup.py
+@Contact :   haoheliu@gmail.com
+@License :   (C)Copyright 2020-2100
+
+@Modify Time      @Author    @Version    @Description
+------------      -------    --------    -----------
+9/6/21 5:16 PM   Haohe Liu      1.0         None
+"""
+
+# Note: To use the 'upload' functionality of this file, you must:
+#   $ pipenv install twine --dev
+
+import io
+import os
+import sys
+from shutil import rmtree
+
+from setuptools import find_packages, setup, Command
+
+# Package meta-data.
+NAME = "audioldm2"
+DESCRIPTION = "This package is written for text-to-audio/music generation."
+URL = "https://github.com/haoheliu/audioldm2"
+EMAIL = "haoheliu@gmail.com"
+AUTHOR = "Haohe Liu"
+REQUIRES_PYTHON = ">=3.7.0"
+VERSION = "0.0.2"
+
+# What packages are required for this module to be executed?
+REQUIRED = [
+    "torch>=1.13.0",
+    "torchaudio>=0.13.0",
+    "torchvision>=0.14.0",
+    "tqdm",
+    "gradio",
+    "pyyaml",
+    "einops",
+    "chardet",
+    "numpy<=1.23.5",
+    "soundfile",
+    "librosa==0.9.2",
+    "scipy",
+    "pandas",
+    "torchlibrosa==0.0.9",
+    "transformers",
+    "progressbar",
+    "ftfy",
+]
+
+# What packages are optional?
+EXTRAS = {}
+
+# The rest you shouldn't have to touch too much :)
+# ------------------------------------------------
+# Except, perhaps the License and Trove Classifiers!
+# If you do change the License, remember to change the Trove Classifier for that!
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+# Import the README and use it as the long-description.
+# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
+try:
+    with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
+        long_description = "\n" + f.read()
+except FileNotFoundError:
+    long_description = DESCRIPTION
+
+# Load the package's __version__.py module as a dictionary.
+about = {}
+if not VERSION:
+    project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
+    with open(os.path.join(here, project_slug, "__version__.py")) as f:
+        exec(f.read(), about)
+else:
+    about["__version__"] = VERSION
+
+
+class UploadCommand(Command):
+    """Support setup.py upload."""
+
+    description = "Build and publish the package."
+    user_options = []
+
+    @staticmethod
+    def status(s):
+        """Prints things in bold."""
+        print("\033[1m{0}\033[0m".format(s))
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def run(self):
+        try:
+            self.status("Removing previous builds…")
+            rmtree(os.path.join(here, "dist"))
+        except OSError:
+            pass
+
+        self.status("Building Source and Wheel (universal) distribution…")
+        os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
+
+        self.status("Uploading the package to PyPI via Twine…")
+        os.system("twine upload dist/*")
+
+        self.status("Pushing git tags…")
+        os.system("git tag v{0}".format(about["__version__"]))
+        os.system("git push --tags")
+
+        sys.exit()
+
+
+# Where the magic happens:
+setup(
+    name=NAME,
+    version=about["__version__"],
+    description=DESCRIPTION,
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    author=AUTHOR,
+    author_email=EMAIL,
+    python_requires=REQUIRES_PYTHON,
+    url=URL,
+    # packages=find_packages(exclude=[]),
+    # If your package is a single module, use this instead of 'packages':
+    # entry_points={
+    #     'console_scripts': ['mycli=mymodule:cli'],
+    # },
+    install_requires=REQUIRED,
+    extras_require=EXTRAS,
+    packages=find_packages(),
+    include_package_data=True,
+    license="MIT",
+    classifiers=[
+        # Trove classifiers
+        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: Implementation :: CPython",
+        "Programming Language :: Python :: Implementation :: PyPy",
+    ],
+    # $ setup.py publish support.
+    cmdclass={
+        "upload": UploadCommand,
+    },
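+    # Install both the POSIX entry script (bin/audioldm2) and the Windows wrapper (bin/audioldm2.cmd).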
+    scripts=["bin/audioldm2.cmd", "bin/audioldm2"],
+)
diff --git a/share_btn.py b/share_btn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0378607680fa5468e9034d230f546f5f0913ae0
--- /dev/null
+++ b/share_btn.py
@@ -0,0 +1,74 @@
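+# HTML/JS assets for the Gradio demo's share button: the community and loading icons,
+# plus a script that uploads the generated output and opens a pre-filled Spaces discussion.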
+community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
+    <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
+    <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
+</svg>"""
+
+loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
+   style="color: #ffffff; 
+"
+   xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
+
+share_js = """async () => {
+	async function uploadFile(file){
+		const UPLOAD_URL = 'https://huggingface.co/uploads';
+		const response = await fetch(UPLOAD_URL, {
+			method: 'POST',
+			headers: {
+				'Content-Type': file.type,
+				'X-Requested-With': 'XMLHttpRequest',
+			},
+			body: file, /// <- File inherits from Blob
+		});
+		const url = await response.text();
+		return url;
+	}
+    async function getInputVideoFile(videoEl){
+        const res = await fetch(videoEl.src);
+        const blob = await res.blob();
+        const videoId = Date.now() % 200;
+        const fileName = `sd-perception-${videoId}.mp4`;
+        return new File([blob], fileName, { type: 'video/mp4' }); 
+	}
+    
+    async function audioToBase64(audioFile) {
+        return new Promise((resolve, reject) => {
+            let reader = new FileReader();
+            reader.readAsDataURL(audioFile);
+            reader.onload = () => resolve(reader.result);
+            reader.onerror = error => reject(error);
+          
+        });
+      }
+    const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
+    const inputPromptEl = gradioEl.querySelector('#prompt-in input').value;
+    const outputVideoEl = gradioEl.querySelector('#output-video video');
+    
+    let titleTxt = `Text-to-Audio: ${inputPromptEl}`;
+
+    const shareBtnEl = gradioEl.querySelector('#share-btn');
+    const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
+    const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
+    if(!outputVideoEl){
+        return;
+    };
+    shareBtnEl.style.pointerEvents = 'none';
+    shareIconEl.style.display = 'none';
+    loadingIconEl.style.removeProperty('display');
+    const outputVideo = await getInputVideoFile(outputVideoEl);
+    const urlOutputVideo = await uploadFile(outputVideo);
+    
+    const descriptionMd = `
+##### ${inputPromptEl}
+    
+${urlOutputVideo}
+`;
+    const params = new URLSearchParams({
+        title: titleTxt,
+        description: descriptionMd,
+    });
+	const paramsStr = params.toString();
+	window.open(`https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/discussions/new?${paramsStr}`, '_blank');
+    shareBtnEl.style.removeProperty('pointer-events');
+    shareIconEl.style.removeProperty('display');
+    loadingIconEl.style.display = 'none';
+}"""
diff --git a/tests/code_coverage.py b/tests/code_coverage.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb035e9fedffacd8bf3a9c37d5566fa8fd4e819
--- /dev/null
+++ b/tests/code_coverage.py
@@ -0,0 +1,3 @@
+import os
+
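+# Smoke test: run the CLI end-to-end on a single text prompt.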
+os.system('python3 bin/audioldm2 -t "A toilet flushing and water trickling"')
diff --git a/tests/code_coverage.sh b/tests/code_coverage.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a5920c645c262c80436ff586e7ad4825e9e5622
--- /dev/null
+++ b/tests/code_coverage.sh
@@ -0,0 +1 @@
+pytest --cov=audioldm2 tests/*
\ No newline at end of file