	Upload folder using huggingface_hub
- .gitignore +173 -0
- LICENSE +21 -0
- README.md +8 -7
- README_REPO.md +196 -0
- app.py +824 -0
- inference-cli.py +378 -0
- inference-cli.toml +8 -0
- model/__init__.py +7 -0
- model/backbones/README.md +20 -0
- model/backbones/dit.py +158 -0
- model/backbones/mmdit.py +136 -0
- model/backbones/unett.py +201 -0
- model/cfm.py +279 -0
- model/dataset.py +242 -0
- model/ecapa_tdnn.py +268 -0
- model/modules.py +575 -0
- model/trainer.py +250 -0
- model/utils.py +574 -0
- requirements.txt +29 -0
- scripts/count_max_epoch.py +32 -0
- scripts/count_params_gflops.py +35 -0
- scripts/eval_infer_batch.py +199 -0
- scripts/eval_infer_batch.sh +13 -0
- scripts/eval_librispeech_test_clean.py +67 -0
- scripts/eval_seedtts_testset.py +69 -0
- scripts/prepare_emilia.py +143 -0
- scripts/prepare_wenetspeech4tts.py +116 -0
- speech_edit.py +182 -0
- train.py +91 -0
    	
.gitignore ADDED
@@ -0,0 +1,173 @@
+# Customed
+.vscode/
+tests/
+runs/
+data/
+ckpts/
+wandb/
+results/
+
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
    	
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Yushen CHEN
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title: F5
-emoji: 
-colorFrom: 
-colorTo: 
+title: F5-TTS
+emoji: 🗣️
+colorFrom: green
+colorTo: green
 sdk: gradio
-sdk_version: 5.1.0
 app_file: app.py
-pinned: 
+pinned: true
+short_description: 'F5-TTS & E2-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
+sdk_version: 5.1.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
    	
README_REPO.md ADDED
@@ -0,0 +1,196 @@
+# F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
+
+[](https://github.com/SWivid/F5-TTS)
+[](https://arxiv.org/abs/2410.06885)
+[](https://swivid.github.io/F5-TTS/)
+[](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+
+**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster training and inference.
+
+**E2 TTS**: Flat-UNet Transformer, closest reproduction.
+
+**Sway Sampling**: Inference-time flow step sampling strategy that greatly improves performance.
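The sway schedule itself is implemented in `model/cfm.py`; the following is only a rough, hedged sketch of the idea, reusing the `nfe_step` and `sway_sampling_coef` defaults that appear in `app.py`:

```python
# Hedged sketch of sway sampling: bend uniformly spaced flow steps toward the
# early, noisier part of the trajectory. The authoritative version is in model/cfm.py.
import torch

def sway_sample_steps(nfe_step: int = 32, sway_sampling_coef: float = -1.0) -> torch.Tensor:
    t = torch.linspace(0.0, 1.0, nfe_step + 1)  # uniform flow steps in [0, 1]
    # with a negative coefficient, intermediate steps are pulled toward t = 0
    return t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
```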
+
+## Installation
+
+Clone the repository:
+
+```bash
+git clone https://github.com/SWivid/F5-TTS.git
+cd F5-TTS
+```
+
+Install torch matching your CUDA version, e.g.:
+
+```bash
+pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
+pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
+```
+
+Install the other packages:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Prepare Dataset
+
+Example data processing scripts are provided for Emilia and WenetSpeech4TTS; you may tailor your own, along with a Dataset class in `model/dataset.py` (a sketch follows the commands below).
+
+```bash
+# prepare custom dataset up to your need
+# download corresponding dataset first, and fill in the path in scripts
+
+# Prepare the Emilia dataset
+python scripts/prepare_emilia.py
+
+# Prepare the Wenetspeech4TTS dataset
+python scripts/prepare_wenetspeech4tts.py
+```
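If you bring your own corpus, a minimal, hypothetical sketch of a PyTorch dataset yielding audio/text pairs is shown below; the actual interface expected by `model/dataset.py` and `train.py` may differ, so treat the names and fields as illustrative only.

```python
# Hypothetical sketch only: a bare-bones dataset of (audio, text) pairs.
# The class name and metadata layout are illustrative, not the repo's API.
import torchaudio
from torch.utils.data import Dataset

class CustomAudioTextDataset(Dataset):
    def __init__(self, metadata):
        # metadata: list of (wav_path, transcript) tuples
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        wav_path, text = self.metadata[idx]
        audio, sample_rate = torchaudio.load(wav_path)  # (channels, samples)
        return {"audio": audio, "sample_rate": sample_rate, "text": text}
```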
+
+## Training
+
+Once your datasets are prepared, you can start the training process.
+
+```bash
+# setup accelerate config, e.g. use multi-gpu ddp, fp16
+# config will be written to: ~/.cache/huggingface/accelerate/default_config.yaml
+accelerate config
+accelerate launch train.py
+```
+Initial guidance on finetuning is available in [#57](https://github.com/SWivid/F5-TTS/discussions/57).
+
+## Inference
+
+To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or let `inference-cli` and `gradio_app` download them automatically.
+
+A single generation currently supports up to 30s, which is the **TOTAL** length of the prompt audio plus the generated speech. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
+- To avoid possible inference failures, make sure you have read through the following instructions.
+- A longer prompt audio leaves less room for the generated output; anything beyond 30s cannot be generated properly. Consider using a prompt audio shorter than 15s (a sketch of the duration estimate follows this list).
+- Uppercase letters are uttered letter by letter, so use lowercase for normal words.
+- Add spaces (" ") or punctuation (e.g. "," ".") to explicitly introduce pauses. If the first few words are skipped in code-switched generation (because different languages are spoken at different speeds), this may help.
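As a rough guide to how the demo sizes its output, the following sketch mirrors the byte-length duration heuristic that appears in `app.py` (the `hop_length` and `speed` values are the defaults set there):

```python
# Sketch of the duration heuristic from app.py: scale the reference audio length
# (in mel frames) by the ratio of generated-text bytes to reference-text bytes.
import re

hop_length = 256   # samples per mel frame, as set in app.py
speed = 1.0        # default speaking-speed factor, as set in app.py
# note: app.py passes this literal sequence to re.findall;
# a character class like r"[。,、;:?!]" may be the intent
zh_pause_punc = r"。,、;:?!"

def estimate_duration_frames(ref_audio_samples: int, ref_text: str, gen_text: str) -> int:
    ref_audio_len = ref_audio_samples // hop_length
    ref_text_len = len(ref_text.encode("utf-8")) + 3 * len(re.findall(zh_pause_punc, ref_text))
    gen_text_len = len(gen_text.encode("utf-8")) + 3 * len(re.findall(zh_pause_punc, gen_text))
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
```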
+
+### CLI Inference
+
+You can either specify everything in `inference-cli.toml` or override it with flags. Leaving `--ref_text ""` will have the ASR model transcribe the reference audio automatically (using extra GPU memory); a sketch of that fallback follows. If you encounter a network error, consider using a local checkpoint: set `ckpt_path` in `inference-cli.py`.
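A hedged sketch of the empty `--ref_text` fallback, using the same `transformers` ASR pipeline that `app.py` constructs (the helper name is illustrative, not the repo's API):

```python
# Illustrative only: auto-transcribe the reference audio when no ref_text is given.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",  # same checkpoint app.py loads
)

def resolve_ref_text(ref_audio_path: str, ref_text: str) -> str:
    if ref_text.strip():
        return ref_text                      # a user-provided transcript wins
    # the ASR pipeline returns a dict with a "text" field
    return asr(ref_audio_path)["text"].strip()
```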
+
+```bash
+python inference-cli.py \
+--model "F5-TTS" \
+--ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
+--ref_text "Some call me nature, others call me mother nature." \
+--gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
+
+python inference-cli.py \
+--model "E2-TTS" \
+--ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
+--ref_text "对,这就是我,万人敬仰的太乙真人。" \
+--gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
+```
+
+### Gradio App
+Currently supported features:
+- Chunk inference
+- Podcast Generation
+- Multiple Speech-Type Generation
+
+You can launch a Gradio app (web interface) for inference (it will load checkpoints from Hugging Face; you may set `ckpt_path` to a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS, and E2 TTS all at once, so it uses more GPU memory than `inference-cli`.
+
+```bash
+python gradio_app.py
+```
+
+You can specify the port/host:
+
+```bash
+python gradio_app.py --port 7860 --host 0.0.0.0
+```
+
+Or launch a share link:
+
+```bash
+python gradio_app.py --share
+```
+
+### Speech Editing
+
+To test speech editing capabilities, use the following command:
+
+```bash
+python speech_edit.py
+```
+
+## Evaluation
+
+### Prepare Test Datasets
+
+1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
+2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
+3. Unzip the downloaded datasets and place them in the data/ directory.
+4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`.
+5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo.
+
+### Batch Inference for Test Set
+
+To run batch inference for evaluations, execute the following commands:
+
+```bash
+# batch inference for evaluations
+accelerate config  # if not set before
+bash scripts/eval_infer_batch.sh
+```
+
+### Download Evaluation Model Checkpoints
+
+1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
+2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
+3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
+
+### Objective Evaluation
+
+**Some Notes**
+
+For faster-whisper with CUDA 11:
+
+```bash
+pip install --force-reinstall ctranslate2==3.24.0
+```
+
+(Recommended) To avoid possible ASR failures, such as abnormal repetitions in the output:
+
+```bash
+pip install faster-whisper==0.10.1
+```
+
+Update the path to your batch-inferenced results, then carry out WER / SIM evaluations:
+```bash
+# Evaluation for Seed-TTS test set
+python scripts/eval_seedtts_testset.py
+
+# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
+python scripts/eval_librispeech_test_clean.py
+```
+
+## Acknowledgements
+
+- [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
+- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
+- [lucidrains](https://github.com/lucidrains) initial CFM structure, with [bfs18](https://github.com/bfs18) for discussion
+- [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
+- [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
+- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
+- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
+- [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
+
+## Citation
+```
+@article{chen-etal-2024-f5tts,
+      title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
+      author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
+      journal={arXiv preprint arXiv:2410.06885},
+      year={2024},
+}
+```
+## License
+
+Our code is released under the MIT License.
    	
app.py ADDED
@@ -0,0 +1,824 @@
+import os
+import re
+import torch
+import torchaudio
+import gradio as gr
+import numpy as np
+import tempfile
+from einops import rearrange
+from vocos import Vocos
+from pydub import AudioSegment, silence
+from model import CFM, UNetT, DiT, MMDiT
+from cached_path import cached_path
+from model.utils import (
+    load_checkpoint,
+    get_tokenizer,
+    convert_char_to_pinyin,
+    save_spectrogram,
+)
+from transformers import pipeline
+import librosa
+import click
+import soundfile as sf
+
+try:
+    import spaces
+    USING_SPACES = True
+except ImportError:
+    USING_SPACES = False
+
+def gpu_decorator(func):
+    if USING_SPACES:
+        return spaces.GPU(func)
+    else:
+        return func
+
+
+
+SPLIT_WORDS = [
+    "but", "however", "nevertheless", "yet", "still",
+    "therefore", "thus", "hence", "consequently",
+    "moreover", "furthermore", "additionally",
+    "meanwhile", "alternatively", "otherwise",
+    "namely", "specifically", "for example", "such as",
+    "in fact", "indeed", "notably",
+    "in contrast", "on the other hand", "conversely",
+    "in conclusion", "to summarize", "finally"
+]
+
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "mps" if torch.backends.mps.is_available() else "cpu"
+)
+
+print(f"Using {device} device")
+
+pipe = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-large-v3-turbo",
+    torch_dtype=torch.float16,
+    device=device,
+)
+vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+# --------------------- Settings -------------------- #
+
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+target_rms = 0.1
+nfe_step = 32  # 16, 32
+cfg_strength = 2.0
+ode_method = "euler"
+sway_sampling_coef = -1.0
+speed = 1.0
+# fix_duration = 27  # None or float (duration in seconds)
+fix_duration = None
+
+
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
+    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
+    model = CFM(
+        transformer=model_cls(
+            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
+        ),
+        mel_spec_kwargs=dict(
+            target_sample_rate=target_sample_rate,
+            n_mel_channels=n_mel_channels,
+            hop_length=hop_length,
+        ),
+        odeint_kwargs=dict(
+            method=ode_method,
+        ),
+        vocab_char_map=vocab_char_map,
+    ).to(device)
+
+    model = load_checkpoint(model, ckpt_path, device, use_ema = True)
+
+    return model
+
+
+# load models
+F5TTS_model_cfg = dict(
+    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
+)
+E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+F5TTS_ema_model = load_model(
+    "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
+)
+E2TTS_ema_model = load_model(
+    "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
+)
+
| 117 | 
            +
            def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
         | 
| 118 | 
            +
                if len(text.encode('utf-8')) <= max_chars:
         | 
| 119 | 
            +
                    return [text]
         | 
| 120 | 
            +
                if text[-1] not in ['。', '.', '!', '!', '?', '?']:
         | 
| 121 | 
            +
                    text += '.'
         | 
| 122 | 
            +
                    
         | 
| 123 | 
            +
                sentences = re.split('([。.!?!?])', text)
         | 
| 124 | 
            +
                sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
         | 
| 125 | 
            +
                
         | 
| 126 | 
            +
                batches = []
         | 
| 127 | 
            +
                current_batch = ""
         | 
| 128 | 
            +
                
         | 
| 129 | 
            +
                def split_by_words(text):
         | 
| 130 | 
            +
                    words = text.split()
         | 
| 131 | 
            +
                    current_word_part = ""
         | 
| 132 | 
            +
                    word_batches = []
         | 
| 133 | 
            +
                    for word in words:
         | 
| 134 | 
            +
                        if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
         | 
| 135 | 
            +
                            current_word_part += word + ' '
         | 
| 136 | 
            +
                        else:
         | 
| 137 | 
            +
                            if current_word_part:
         | 
| 138 | 
            +
                                # Try to find a suitable split word
         | 
| 139 | 
            +
                                for split_word in split_words:
         | 
| 140 | 
            +
                                    split_index = current_word_part.rfind(' ' + split_word + ' ')
         | 
| 141 | 
            +
                                    if split_index != -1:
         | 
| 142 | 
            +
                                        word_batches.append(current_word_part[:split_index].strip())
         | 
| 143 | 
            +
                                        current_word_part = current_word_part[split_index:].strip() + ' '
         | 
| 144 | 
            +
                                        break
         | 
| 145 | 
            +
                                else:
         | 
| 146 | 
            +
                                    # If no suitable split word found, just append the current part
         | 
| 147 | 
            +
                                    word_batches.append(current_word_part.strip())
         | 
| 148 | 
            +
                                    current_word_part = ""
         | 
| 149 | 
            +
                            current_word_part += word + ' '
         | 
| 150 | 
            +
                    if current_word_part:
         | 
| 151 | 
            +
                        word_batches.append(current_word_part.strip())
         | 
| 152 | 
            +
                    return word_batches
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                for sentence in sentences:
         | 
| 155 | 
            +
                    if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
         | 
| 156 | 
            +
                        current_batch += sentence
         | 
| 157 | 
            +
                    else:
         | 
| 158 | 
            +
                        # If adding this sentence would exceed the limit
         | 
| 159 | 
            +
                        if current_batch:
         | 
| 160 | 
            +
                            batches.append(current_batch)
         | 
| 161 | 
            +
                            current_batch = ""
         | 
| 162 | 
            +
                        
         | 
| 163 | 
            +
                        # If the sentence itself is longer than max_chars, split it
         | 
| 164 | 
            +
                        if len(sentence.encode('utf-8')) > max_chars:
         | 
| 165 | 
            +
                            # First, try to split by colon
         | 
| 166 | 
            +
                            colon_parts = sentence.split(':')
         | 
| 167 | 
            +
                            if len(colon_parts) > 1:
         | 
| 168 | 
            +
                                for part in colon_parts:
         | 
| 169 | 
            +
                                    if len(part.encode('utf-8')) <= max_chars:
         | 
| 170 | 
            +
                                        batches.append(part)
         | 
| 171 | 
            +
                        else:
                            # If colon part is still too long, split by comma
                            comma_parts = re.split('[,，]', part)
                            if len(comma_parts) > 1:
                                current_comma_part = ""
                                for comma_part in comma_parts:
                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
                                        current_comma_part += comma_part + ','
                                    else:
                                        if current_comma_part:
                                            batches.append(current_comma_part.rstrip(','))
                                        current_comma_part = comma_part + ','
                                if current_comma_part:
                                    batches.append(current_comma_part.rstrip(','))
                            else:
                                # If no comma, split by words
                                batches.extend(split_by_words(part))
                else:
                    # If no colon, split by comma
                    comma_parts = re.split('[,，]', sentence)
                    if len(comma_parts) > 1:
                        current_comma_part = ""
                        for comma_part in comma_parts:
                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
                                current_comma_part += comma_part + ','
                            else:
                                if current_comma_part:
                                    batches.append(current_comma_part.rstrip(','))
                                current_comma_part = comma_part + ','
                        if current_comma_part:
                            batches.append(current_comma_part.rstrip(','))
                    else:
                        # If no comma, split by words
                        batches.extend(split_by_words(sentence))
            else:
                current_batch = sentence

    if current_batch:
        batches.append(current_batch)

    return batches

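# Illustrative sketch, not part of the original app: how the batching helper
# behaves. Assuming the SPLIT_WORDS default defined earlier in this file, a long
# passage is split first on sentence punctuation, then on colons and commas, and
# finally on words, aiming to keep each batch under `max_chars` UTF-8 bytes.
def _demo_split_text_into_batches():
    sample = (
        "Model training requires several steps: data cleaning, tokenization, "
        "alignment, and evaluation. Each step can be run on its own."
    )
    for i, batch in enumerate(split_text_into_batches(sample, max_chars=60)):
        print(i, len(batch.encode('utf-8')), batch)
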
def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
    if exp_name == "F5-TTS":
        ema_model = F5TTS_ema_model
    elif exp_name == "E2-TTS":
        ema_model = E2TTS_ema_model

    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
        # Prepare the text
        if len(ref_text[-1].encode('utf-8')) == 1:
            ref_text = ref_text + " "
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate duration
        ref_audio_len = audio.shape[-1] // hop_length
        zh_pause_punc = r"。，、；：？！"
        ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
        gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        generated_waves.append(generated_wave)
        spectrograms.append(generated_mel_spec[0].cpu().numpy())

    # Combine all generated waves
    final_wave = np.concatenate(generated_waves)

    # Remove silence
    if remove_silence:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, target_sample_rate)
            aseg = AudioSegment.from_file(f.name)
            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                non_silent_wave += non_silent_seg
            aseg = non_silent_wave
            aseg.export(f.name, format="wav")
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
        save_spectrogram(combined_spectrogram, spectrogram_path)

    return (target_sample_rate, final_wave), spectrogram_path

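# Worked example of the duration estimate in infer_batch above (illustrative,
# with made-up numbers): the generated length is scaled from the reference by
# the ratio of text lengths in UTF-8 bytes, where each Chinese pause mark counts
# as 3 extra bytes, then divided by the global `speed` factor.
#   ref_audio_len = 250 frames, ref_text_len = 50, gen_text_len = 80, speed = 1.0
#   duration = 250 + int(250 / 50 * 80 / 1.0) = 250 + 400 = 650 frames
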
def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
    if custom_split_words.strip():
        custom_words = [word.strip() for word in custom_split_words.split(',')]
        global SPLIT_WORDS
        SPLIT_WORDS = custom_words

    print(gen_text)

    gr.Info("Converting audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
        non_silent_wave = AudioSegment.silent(duration=0)
        for non_silent_seg in non_silent_segs:
            non_silent_wave += non_silent_seg
        aseg = non_silent_wave

        audio_duration = len(aseg)
        if audio_duration > 15000:
            gr.Warning("Audio is over 15s, clipping to only first 15s.")
            aseg = aseg[:15000]
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    if not ref_text.strip():
        gr.Info("No reference text provided, transcribing reference audio...")
        ref_text = pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )["text"].strip()
        gr.Info("Finished transcription")
    else:
        gr.Info("Using custom reference text...")

    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
    print('ref_text', ref_text)
    for i, gen_text in enumerate(gen_text_batches):
        print(f'gen_text {i}', gen_text)

    gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
    return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)

def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
    # Split the script into speaker blocks
    speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
    speaker_blocks = speaker_pattern.split(script)[1:]  # Skip the first empty element

    generated_audio_segments = []

    for i in range(0, len(speaker_blocks), 2):
        speaker = speaker_blocks[i]
        text = speaker_blocks[i+1].strip()

        # Determine which speaker is talking
        if speaker == speaker1_name:
            ref_audio = ref_audio1
            ref_text = ref_text1
        elif speaker == speaker2_name:
            ref_audio = ref_audio2
            ref_text = ref_text2
        else:
            continue  # Skip if the speaker is neither speaker1 nor speaker2

        # Generate audio for this block
        audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)

        # Convert the generated audio to a numpy array
        sr, audio_data = audio

        # Save the audio data as a WAV file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            sf.write(temp_file.name, audio_data, sr)
            audio_segment = AudioSegment.from_wav(temp_file.name)

        generated_audio_segments.append(audio_segment)

        # Add a short pause between speakers
        pause = AudioSegment.silent(duration=500)  # 500ms pause
        generated_audio_segments.append(pause)

    # Concatenate all audio segments
    final_podcast = sum(generated_audio_segments)

    # Export the final podcast
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        podcast_path = temp_file.name
        final_podcast.export(podcast_path, format="wav")

    return podcast_path

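# Illustrative sketch, not part of the original app: the speaker regex in
# generate_podcast splits a script into alternating name/text blocks, so each
# pair becomes one infer() call followed by a 500 ms pause. For example:
#   script = "Sean: How did you get into TTS?\nMeghan: Mostly by accident."
#   re.compile(r"^(Sean|Meghan):", re.MULTILINE).split(script)[1:]
#   -> ['Sean', ' How did you get into TTS?\n', 'Meghan', ' Mostly by accident.']
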
def parse_speechtypes_text(gen_text):
    # Pattern to find (Emotion)
    pattern = r'\((.*?)\)'

    # Split the text by the pattern
    tokens = re.split(pattern, gen_text)

    segments = []

    current_emotion = 'Regular'

    for i in range(len(tokens)):
        if i % 2 == 0:
            # This is text
            text = tokens[i].strip()
            if text:
                segments.append({'emotion': current_emotion, 'text': text})
        else:
            # This is emotion
            emotion = tokens[i].strip()
            current_emotion = emotion

    return segments

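# Illustrative sketch, not part of the original app: what parse_speechtypes_text
# returns for a tagged prompt. Text before the first tag falls back to 'Regular'.
def _demo_parse_speechtypes_text():
    segments = parse_speechtypes_text("Hello there. (Sad) I miss you. (Shouting) Why?!")
    # -> [{'emotion': 'Regular', 'text': 'Hello there.'},
    #     {'emotion': 'Sad', 'text': 'I miss you.'},
    #     {'emotion': 'Shouting', 'text': 'Why?!'}]
    return segments
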
def update_speed(new_speed):
    global speed
    speed = new_speed
    return f"Speed set to: {speed}"

with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
""")
with gr.Blocks() as app_tts:
    gr.Markdown("# Batched TTS")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    model_choice = gr.Radio(
        choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
    )
    generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Accordion("Advanced Settings", open=False):
        ref_text_input = gr.Textbox(
            label="Reference Text",
            info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
            lines=2,
        )
        remove_silence = gr.Checkbox(
            label="Remove Silences",
            info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
            value=True,
        )
        split_words_input = gr.Textbox(
            label="Custom Split Words",
            info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
            lines=2,
        )
        speed_slider = gr.Slider(
            label="Speed",
            minimum=0.3,
            maximum=2.0,
            value=speed,
            step=0.1,
            info="Adjust the speed of the audio.",
        )
    speed_slider.change(update_speed, inputs=speed_slider)

    audio_output = gr.Audio(label="Synthesized Audio")
    spectrogram_output = gr.Image(label="Spectrogram")

    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            model_choice,
            remove_silence,
            split_words_input,
        ],
        outputs=[audio_output, spectrogram_output],
    )

with gr.Blocks() as app_podcast:
    gr.Markdown("# Podcast Generation")
    speaker1_name = gr.Textbox(label="Speaker 1 Name")
    ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
    ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)

    speaker2_name = gr.Textbox(label="Speaker 2 Name")
    ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
    ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)

    script_input = gr.Textbox(label="Podcast Script", lines=10,
                              placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")

    podcast_model_choice = gr.Radio(
        choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
    )
    podcast_remove_silence = gr.Checkbox(
        label="Remove Silences",
        value=True,
    )
    generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
    podcast_output = gr.Audio(label="Generated Podcast")

    def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
        return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)

    generate_podcast_btn.click(
        podcast_generation,
        inputs=[
            script_input,
            speaker1_name,
            ref_audio_input1,
            ref_text_input1,
            speaker2_name,
            ref_audio_input2,
            ref_text_input2,
            podcast_model_choice,
            podcast_remove_silence,
        ],
        outputs=podcast_output,
    )

def parse_emotional_text(gen_text):
    # Pattern to find (Emotion)
    pattern = r'\((.*?)\)'

    # Split the text by the pattern
    tokens = re.split(pattern, gen_text)

    segments = []

    current_emotion = 'Regular'

    for i in range(len(tokens)):
        if i % 2 == 0:
            # This is text
            text = tokens[i].strip()
            if text:
                segments.append({'emotion': current_emotion, 'text': text})
        else:
            # This is emotion
            emotion = tokens[i].strip()
            current_emotion = emotion

    return segments

with gr.Blocks() as app_emotional:
    # New section for emotional generation
    gr.Markdown(
        """
    # Multiple Speech-Type Generation

    This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.

    **Example Input:**

    (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
    """
    )

    gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")

    # Regular speech type (mandatory)
    with gr.Row():
        regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
        regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
        regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)

    # Additional speech types (up to 9 more)
    max_speech_types = 10
    speech_type_names = []
    speech_type_audios = []
    speech_type_ref_texts = []
    speech_type_delete_btns = []

    for i in range(max_speech_types - 1):
        with gr.Row():
            name_input = gr.Textbox(label='Speech Type Name', visible=False)
            audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
            ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
            delete_btn = gr.Button("Delete", variant="secondary", visible=False)
        speech_type_names.append(name_input)
        speech_type_audios.append(audio_input)
        speech_type_ref_texts.append(ref_text_input)
        speech_type_delete_btns.append(delete_btn)

    # Button to add speech type
    add_speech_type_btn = gr.Button("Add Speech Type")

    # Keep track of current number of speech types
    speech_type_count = gr.State(value=0)

    # Function to add a speech type
    def add_speech_type_fn(speech_type_count):
        if speech_type_count < max_speech_types - 1:
            speech_type_count += 1
            # Prepare updates for the components
            name_updates = []
            audio_updates = []
            ref_text_updates = []
            delete_btn_updates = []
            for i in range(max_speech_types - 1):
                if i < speech_type_count:
                    name_updates.append(gr.update(visible=True))
                    audio_updates.append(gr.update(visible=True))
                    ref_text_updates.append(gr.update(visible=True))
                    delete_btn_updates.append(gr.update(visible=True))
                else:
                    name_updates.append(gr.update())
                    audio_updates.append(gr.update())
                    ref_text_updates.append(gr.update())
                    delete_btn_updates.append(gr.update())
        else:
            # Optionally, show a warning
            # gr.Warning("Maximum number of speech types reached.")
            name_updates = [gr.update() for _ in range(max_speech_types - 1)]
            audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
            ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
            delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
        return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates

    add_speech_type_btn.click(
        add_speech_type_fn,
        inputs=speech_type_count,
        outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
    )

    # Function to delete a speech type
    def make_delete_speech_type_fn(index):
        def delete_speech_type_fn(speech_type_count):
            # Prepare updates
            name_updates = []
            audio_updates = []
            ref_text_updates = []
            delete_btn_updates = []

            for i in range(max_speech_types - 1):
                if i == index:
                    name_updates.append(gr.update(visible=False, value=''))
                    audio_updates.append(gr.update(visible=False, value=None))
                    ref_text_updates.append(gr.update(visible=False, value=''))
                    delete_btn_updates.append(gr.update(visible=False))
                else:
                    name_updates.append(gr.update())
                    audio_updates.append(gr.update())
                    ref_text_updates.append(gr.update())
                    delete_btn_updates.append(gr.update())

            speech_type_count = max(0, speech_type_count - 1)

            return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates

        return delete_speech_type_fn

    for i, delete_btn in enumerate(speech_type_delete_btns):
        delete_fn = make_delete_speech_type_fn(i)
        delete_btn.click(
            delete_fn,
            inputs=speech_type_count,
            outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
        )

    # Text input for the prompt
    gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)

    # Model choice
    model_choice_emotional = gr.Radio(
        choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
    )

    with gr.Accordion("Advanced Settings", open=False):
        remove_silence_emotional = gr.Checkbox(
            label="Remove Silences",
            value=True,
        )

    # Generate button
    generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")

    # Output audio
    audio_output_emotional = gr.Audio(label="Synthesized Audio")

    def generate_emotional_speech(
        regular_audio,
        regular_ref_text,
        gen_text,
        *args,
    ):
        num_additional_speech_types = max_speech_types - 1
        speech_type_names_list = args[:num_additional_speech_types]
        speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
        speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
        model_choice = args[3 * num_additional_speech_types]
        remove_silence = args[3 * num_additional_speech_types + 1]

        # Collect the speech types and their audios into a dict
        speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}

        for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
            if name_input and audio_input:
                speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}

        # Parse the gen_text into segments
        segments = parse_speechtypes_text(gen_text)

        # For each segment, generate speech
        generated_audio_segments = []
        current_emotion = 'Regular'

        for segment in segments:
            emotion = segment['emotion']
            text = segment['text']

            if emotion in speech_types:
                current_emotion = emotion
            else:
                # If emotion not available, default to Regular
                current_emotion = 'Regular'

            ref_audio = speech_types[current_emotion]['audio']
            ref_text = speech_types[current_emotion].get('ref_text', '')

            # Generate speech for this segment
            audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
            sr, audio_data = audio

            generated_audio_segments.append(audio_data)

        # Concatenate all audio segments
        if generated_audio_segments:
            final_audio_data = np.concatenate(generated_audio_segments)
            return (sr, final_audio_data)
        else:
            gr.Warning("No audio generated.")
            return None

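    # Note added for clarity (not in the original code): Gradio flattens the
    # component lists below, so generate_emotional_speech receives *args as
    # 9 speech-type names, then 9 audios, then 9 reference texts, followed by
    # the model choice and the remove-silence flag, in the same order as the
    # inputs list passed to .click().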
    generate_emotional_btn.click(
        generate_emotional_speech,
        inputs=[
            regular_audio,
            regular_ref_text,
            gen_text_input_emotional,
        ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
            model_choice_emotional,
            remove_silence_emotional,
        ],
        outputs=audio_output_emotional,
    )

    # Validation function to disable Generate button if speech types are missing
    def validate_speech_types(
        gen_text,
        regular_name,
        *args
    ):
        num_additional_speech_types = max_speech_types - 1
        speech_type_names_list = args[:num_additional_speech_types]

        # Collect the speech types names
        speech_types_available = set()
        if regular_name:
            speech_types_available.add(regular_name)
        for name_input in speech_type_names_list:
            if name_input:
                speech_types_available.add(name_input)

        # Parse the gen_text to get the speech types used
        segments = parse_emotional_text(gen_text)
        speech_types_in_text = set(segment['emotion'] for segment in segments)

        # Check if all speech types in text are available
        missing_speech_types = speech_types_in_text - speech_types_available

        if missing_speech_types:
            # Disable the generate button
            return gr.update(interactive=False)
        else:
            # Enable the generate button
            return gr.update(interactive=True)

    gen_text_input_emotional.change(
        validate_speech_types,
        inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
        outputs=generate_emotional_btn
    )
with gr.Blocks() as app:
    gr.Markdown(
        """
# E2/F5 TTS

This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:

* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)

The checkpoints support English and Chinese.

If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.

**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
"""
    )
    gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])

@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
@click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
def main(port, host, share, api):
    global app
    print(f"Starting app...")
    app.queue(api_open=api).launch(
        server_name=host, server_port=port, share=share, show_api=api
    )


if __name__ == "__main__":
    main()
    	
inference-cli.py
ADDED
@@ -0,0 +1,378 @@
import re
import torch
import torchaudio
import numpy as np
import tempfile
from einops import rearrange
from vocos import Vocos
from pydub import AudioSegment, silence
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)
from transformers import pipeline
import soundfile as sf
import tomli
import argparse
import tqdm
from pathlib import Path

parser = argparse.ArgumentParser(
    prog="python3 inference-cli.py",
    description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=inference-cli.toml",
    default="inference-cli.toml",
)
parser.add_argument(
    "-m",
    "--model",
    help="F5-TTS | E2-TTS",
)
parser.add_argument(
    "-r",
    "--ref_audio",
    type=str,
    help="Reference audio file < 15 seconds.",
)
parser.add_argument(
    "-s",
    "--ref_text",
    type=str,
    default="666",
    help="Subtitle for the reference audio.",
)
parser.add_argument(
    "-t",
    "--gen_text",
    type=str,
    help="Text to generate.",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    help="Path to output folder.",
)
parser.add_argument(
    "--remove_silence",
    help="Remove silence.",
)
args = parser.parse_args()

with open(args.config, "rb") as f:
    config = tomli.load(f)

ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
output_dir = args.output_dir if args.output_dir else config["output_dir"]
model = args.model if args.model else config["model"]
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
wave_path = Path(output_dir) / "out.wav"
spectrogram_path = Path(output_dir) / "out.png"

SPLIT_WORDS = [
    "but", "however", "nevertheless", "yet", "still",
    "therefore", "thus", "hence", "consequently",
    "moreover", "furthermore", "additionally",
    "meanwhile", "alternatively", "otherwise",
    "namely", "specifically", "for example", "such as",
    "in fact", "indeed", "notably",
    "in contrast", "on the other hand", "conversely",
    "in conclusion", "to summarize", "finally"
]

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

print(f"Using {device} device")

# --------------------- Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None

def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    model = CFM(
        transformer=model_cls(
            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
        ),
        mel_spec_kwargs=dict(
            target_sample_rate=target_sample_rate,
            n_mel_channels=n_mel_channels,
            hop_length=hop_length,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    model = load_checkpoint(model, ckpt_path, device, use_ema=True)

    return model


# load models
F5TTS_model_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
    if len(text.encode('utf-8')) <= max_chars:
        return [text]
    if text[-1] not in ['。', '.', '!', '!', '?', '?']:
        text += '.'

    sentences = re.split('([。.!?!?])', text)
    sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]

    batches = []
    current_batch = ""

    def split_by_words(text):
        words = text.split()
        current_word_part = ""
        word_batches = []
        for word in words:
            if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
                current_word_part += word + ' '
            else:
                if current_word_part:
                    # Try to find a suitable split word
                    for split_word in split_words:
                        split_index = current_word_part.rfind(' ' + split_word + ' ')
                        if split_index != -1:
                            word_batches.append(current_word_part[:split_index].strip())
                            current_word_part = current_word_part[split_index:].strip() + ' '
                            break
                    else:
                        # If no suitable split word is found, just append the current part
                        word_batches.append(current_word_part.strip())
                        current_word_part = ""
                current_word_part += word + ' '
        if current_word_part:
            word_batches.append(current_word_part.strip())
        return word_batches

    for sentence in sentences:
        if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
            current_batch += sentence
        else:
            # If adding this sentence would exceed the limit
            if current_batch:
                batches.append(current_batch)
                current_batch = ""

            # If the sentence itself is longer than max_chars, split it
            if len(sentence.encode('utf-8')) > max_chars:
                # First, try to split by colon
                colon_parts = sentence.split(':')
                if len(colon_parts) > 1:
                    for part in colon_parts:
                        if len(part.encode('utf-8')) <= max_chars:
                            batches.append(part)
                        else:
                            # If the colon part is still too long, split by comma
                            comma_parts = re.split('[,,]', part)
                            if len(comma_parts) > 1:
                                current_comma_part = ""
                                for comma_part in comma_parts:
                                    if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
                                        current_comma_part += comma_part + ','
                                    else:
                                        if current_comma_part:
                                            batches.append(current_comma_part.rstrip(','))
                                        current_comma_part = comma_part + ','
                                if current_comma_part:
                                    batches.append(current_comma_part.rstrip(','))
                            else:
                                # If no comma, split by words
                                batches.extend(split_by_words(part))
                else:
                    # If no colon, split by comma
                    comma_parts = re.split('[,,]', sentence)
                    if len(comma_parts) > 1:
                        current_comma_part = ""
                        for comma_part in comma_parts:
                            if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
                                current_comma_part += comma_part + ','
                            else:
                                if current_comma_part:
                                    batches.append(current_comma_part.rstrip(','))
                                current_comma_part = comma_part + ','
                        if current_comma_part:
                            batches.append(current_comma_part.rstrip(','))
                    else:
                        # If no comma, split by words
                        batches.extend(split_by_words(sentence))
            else:
                current_batch = sentence

    if current_batch:
        batches.append(current_batch)

    return batches

def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
    if model == "F5-TTS":
        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
    elif model == "E2-TTS":
        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
        # Prepare the text
        if len(ref_text[-1].encode('utf-8')) == 1:
            ref_text = ref_text + " "
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate duration
        ref_audio_len = audio.shape[-1] // hop_length
        zh_pause_punc = r"[。,、;:?!]"  # character class so each pause mark is counted individually
        ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
        gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        generated_waves.append(generated_wave)
        spectrograms.append(generated_mel_spec[0].cpu().numpy())

    # Combine all generated waves
    final_wave = np.concatenate(generated_waves)

    with open(wave_path, "wb") as f:
        sf.write(f.name, final_wave, target_sample_rate)
        # Remove silence
        if remove_silence:
            aseg = AudioSegment.from_file(f.name)
            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                non_silent_wave += non_silent_seg
            aseg = non_silent_wave
            aseg.export(f.name, format="wav")
        print(f.name)

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)
    save_spectrogram(combined_spectrogram, spectrogram_path)
    print(spectrogram_path)


def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
    if custom_split_words.strip():  # only override the default split words when custom ones are actually provided
        custom_words = [word.strip() for word in custom_split_words.split(',')]
        global SPLIT_WORDS
        SPLIT_WORDS = custom_words

    print(gen_text)

    print("Converting audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
        non_silent_wave = AudioSegment.silent(duration=0)
        for non_silent_seg in non_silent_segs:
            non_silent_wave += non_silent_seg
        aseg = non_silent_wave

        audio_duration = len(aseg)
        if audio_duration > 15000:
            print("Audio is over 15s, clipping to only first 15s.")
            aseg = aseg[:15000]
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    if not ref_text.strip():
        print("No reference text provided, transcribing reference audio...")
        pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=torch.float16,
            device=device,
        )
        ref_text = pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )["text"].strip()
        print("Finished transcription")
    else:
        print("Using custom reference text...")

    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
    max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
    gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
    print('ref_text', ref_text)
    for i, gen_text in enumerate(gen_text_batches):
        print(f'gen_text {i}', gen_text)

    print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)


infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
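A minimal usage sketch (illustrative only, assuming the functions above are in scope) of how split_text_into_batches chunks a long passage under a byte budget, falling back from sentences to commas to split words:

# Illustrative only: exercise the batching helper defined above.
sample = (
    "The model reads a short reference clip. "
    "However, long prompts must be chunked before synthesis, "
    "otherwise the duration estimate becomes unreliable."
)
# The 80-byte budget is arbitrary for this demo; the CLI derives max_chars
# from the reference audio length and transcript instead.
for piece in split_text_into_batches(sample, max_chars=80):
    print(repr(piece))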
    	
        inference-cli.toml
    ADDED
    
@@ -0,0 +1,8 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
# If empty (""), the reference audio is transcribed automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
remove_silence = true
output_dir = "tests"
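For orientation, a small sketch (assumptions: the tomli package and this file's default path) of how these settings are consumed; any value can also be overridden on the command line, as the CLI's epilog notes:

# Sketch: load inference-cli.toml and override one field, mirroring the
# args-over-config resolution in inference-cli.py.
import tomli

with open("inference-cli.toml", "rb") as fp:
    cfg = tomli.load(fp)

cfg["model"] = "E2-TTS"  # hypothetical override, same effect as passing `-m E2-TTS`
print(cfg["model"], cfg["output_dir"])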
    	
        model/__init__.py
    ADDED
    
@@ -0,0 +1,7 @@
from model.cfm import CFM

from model.backbones.unett import UNetT
from model.backbones.dit import DiT
from model.backbones.mmdit import MMDiT

from model.trainer import Trainer
    	
        model/backbones/README.md
    ADDED
    
@@ -0,0 +1,20 @@
## Backbones quick introduction


### unett.py
- flat unet transformer
- structure same as in the e2-tts & voicebox papers, except using rotary pos emb
- update: allow optional abs pos emb & convnextv2 blocks for embedded text before concat

### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatenated noised_input + masked_cond + embedded_text, then a linear proj in
- optional abs pos emb & convnextv2 blocks for embedded text before concat
- optional long skip connection (first layer to last layer)

### mmdit.py
- sd3 structure
- timestep as condition
- left stream: text embedded with an abs pos emb applied
- right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
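As rough orientation (a sketch under stated assumptions, not part of the upload): the DiT backbone described above is the one the F5-TTS checkpoints use, configured elsewhere in this repo with dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4. The dummy tensor shapes and text_num_embeds=256 below are illustrative; the real scripts pass the tokenizer's vocab size.

# Sketch: instantiate the DiT backbone with the F5-TTS config from
# inference-cli.py; shapes and vocab size are illustrative only.
import torch
from model.backbones.dit import DiT

model = DiT(dim=1024, depth=22, heads=16, ff_mult=2,
            text_dim=512, conv_layers=4, mel_dim=100, text_num_embeds=256)

x = torch.randn(1, 400, 100)                  # noised mel frames (b n d)
cond = torch.randn(1, 400, 100)               # masked conditioning mel
text = torch.zeros(1, 120, dtype=torch.long)  # padded text token ids
time = torch.rand(1)                          # flow-matching time step
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)  # (1, 400, 100): predicted mel for every frame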
    	
        model/backbones/dit.py
    ADDED
    
@@ -0,0 +1,158 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
import torch.nn.functional as F

from einops import repeat

from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis, get_pos_embed_indices,
)


# Text embedding

class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
        else:
            self.extra_modeling = False

    def forward(self, text: int['b nt'], seq_len, drop_text = False):
        batch, text_len = text.shape[0], text.shape[1]
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, seq_len - text_len), value = 0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding

class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using DiT blocks

class DiT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
                 long_skip_connection = False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    ff_mult = ff_mult,
                    dropout = dropout
                )
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],     # noised input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],     # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,        # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            x = block(x, t, mask = mask, rope = rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim = -1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
    	
        model/backbones/mmdit.py
    ADDED
    
    | @@ -0,0 +1,136 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            ein notation:
         | 
| 3 | 
            +
            b - batch
         | 
| 4 | 
            +
            n - sequence
         | 
| 5 | 
            +
            nt - text sequence
         | 
| 6 | 
            +
            nw - raw wave length
         | 
| 7 | 
            +
            d - dimension
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from __future__ import annotations
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from torch import nn
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from einops import repeat
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from x_transformers.x_transformers import RotaryEmbedding
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from model.modules import (
         | 
| 20 | 
            +
                TimestepEmbedding,
         | 
| 21 | 
            +
                ConvPositionEmbedding,
         | 
| 22 | 
            +
                MMDiTBlock,
         | 
| 23 | 
            +
                AdaLayerNormZero_Final,
         | 
| 24 | 
            +
                precompute_freqs_cis, get_pos_embed_indices,
         | 
| 25 | 
            +
            )
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            # text embedding
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            class TextEmbedding(nn.Module):
         | 
| 31 | 
            +
                def __init__(self, out_dim, text_num_embeds):
         | 
| 32 | 
            +
                    super().__init__()
         | 
| 33 | 
            +
                    self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    self.precompute_max_pos = 1024
         | 
| 36 | 
            +
                    self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
         | 
| 39 | 
            +
                    text = text + 1
         | 
| 40 | 
            +
                    if drop_text:
         | 
| 41 | 
            +
                        text = torch.zeros_like(text)
         | 
| 42 | 
            +
                    text = self.text_embed(text)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # sinus pos emb
         | 
| 45 | 
            +
                    batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
         | 
| 46 | 
            +
                    batch_text_len = text.shape[1]
         | 
| 47 | 
            +
                    pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
         | 
| 48 | 
            +
                    text_pos_embed = self.freqs_cis[pos_idx]
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    text = text + text_pos_embed
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    return text
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            # noised input & masked cond audio embedding
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            class AudioEmbedding(nn.Module):
         | 
| 58 | 
            +
                def __init__(self, in_dim, out_dim):
         | 
| 59 | 
            +
                    super().__init__()
         | 
| 60 | 
            +
                    self.linear = nn.Linear(2 * in_dim, out_dim)
         | 
| 61 | 
            +
                    self.conv_pos_embed = ConvPositionEmbedding(out_dim)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
         | 
| 64 | 
            +
                    if drop_audio_cond:
         | 
| 65 | 
            +
                        cond = torch.zeros_like(cond)
         | 
| 66 | 
            +
                    x = torch.cat((x, cond), dim = -1)
         | 
| 67 | 
            +
                    x = self.linear(x)
         | 
| 68 | 
            +
                    x = self.conv_pos_embed(x) + x
         | 
| 69 | 
            +
                    return x
         | 
| 70 | 
            +
                
         | 
| 71 | 
            +
             | 
| 72 | 
# Transformer backbone using MM-DiT blocks

class MMDiT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 text_num_embeds = 256, mel_dim = 100,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds)
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = dropout,
                    ff_mult = ff_mult,
                    context_pre_only = i == depth - 1,
                )
                for i in range(depth)
            ]
        )
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],     # noised input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],     # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,        # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch = x.shape[0]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        c = self.text_embed(text, drop_text = drop_text)
        x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)

        seq_len = x.shape[1]
        text_len = text.shape[1]
        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
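A minimal shape-check sketch (illustrative only, not one of the uploaded files) of how the MMDiT backbone above could be instantiated and called; the hyperparameters and dummy tensor shapes below are assumptions, not the released configuration.

import torch

from model.backbones.mmdit import MMDiT

# hypothetical small configuration for a quick forward-pass check
model = MMDiT(dim = 512, depth = 8, heads = 8, dim_head = 64,
              text_num_embeds = 256, mel_dim = 100)

x    = torch.randn(2, 300, 100)        # noised mel frames: (batch, frames, mel_dim)
cond = torch.randn(2, 300, 100)        # masked conditioning mel, same shape
text = torch.randint(0, 256, (2, 40))  # character token ids: (batch, text_len)
time = torch.rand(2)                   # flow-matching time step per sample

out = model(x, cond, text, time, drop_audio_cond = False, drop_text = False)
print(out.shape)  # expected: torch.Size([2, 300, 100])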
    	
        model/backbones/unett.py
    ADDED
    
    @@ -0,0 +1,201 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Literal

import torch
from torch import nn
import torch.nn.functional as F

from einops import repeat, pack, unpack

from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    Attention,
    AttnProcessor,
    FeedForward,
    precompute_freqs_cis, get_pos_embed_indices,
)


# Text embedding

class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24 kHz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
        else:
            self.extra_modeling = False

    def forward(self, text: int['b nt'], seq_len, drop_text = False):
        batch, text_len = text.shape[0], text.shape[1]
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, seq_len - text_len), value = 0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinusoidal position embedding
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding

class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
        x = self.conv_pos_embed(x) + x
        return x


# Flat UNet Transformer backbone

class UNetT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
                 skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
    ):
        super().__init__()
        assert depth % 2 == 0, "UNet-Transformer's depth should be even."

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        # transformer layers & skip connections

        self.dim = dim
        self.skip_connect_type = skip_connect_type
        needs_skip_proj = skip_connect_type == 'concat'

        self.depth = depth
        self.layers = nn.ModuleList([])

        for idx in range(depth):
            is_later_half = idx >= (depth // 2)

            attn_norm = RMSNorm(dim)
            attn = Attention(
                processor = AttnProcessor(),
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                dropout = dropout,
            )

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")

            skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None

            self.layers.append(nn.ModuleList([
                skip_proj,
                attn_norm,
                attn,
                ff_norm,
                ff,
            ]))

        self.norm_out = RMSNorm(dim)
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],     # noised input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],     # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,        # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)

        # prepend time token t to input x, [b n d] -> [b n+1 d]
        x, ps = pack((t, x), 'b * d')
        if mask is not None:
            mask = F.pad(mask, (1, 0), value=1)

        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)

        # flat unet transformer
        skip_connect_type = self.skip_connect_type
        skips = []
        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
            layer = idx + 1

            # skip connection logic
            is_first_half = layer <= (self.depth // 2)
            is_later_half = not is_first_half

            if is_first_half:
                skips.append(x)

            if is_later_half:
                skip = skips.pop()
                if skip_connect_type == 'concat':
                    x = torch.cat((x, skip), dim = -1)
                    x = maybe_skip_proj(x)
                elif skip_connect_type == 'add':
                    x = x + skip

            # attention and feedforward blocks
            x = attn(attn_norm(x), rope = rope, mask = mask) + x
            x = ff(ff_norm(x)) + x

        assert len(skips) == 0

        _, x = unpack(self.norm_out(x), ps, 'b * d')

        return self.proj_out(x)
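The same kind of shape check for the flat UNet-Transformer backbone above (again illustrative, not part of the uploaded files); note that depth must be even and that text_dim falls back to mel_dim when not given. Values below are assumed for the sketch.

import torch

from model.backbones.unett import UNetT

# hypothetical small configuration; conv_layers > 0 enables the extra text modeling path
model = UNetT(dim = 512, depth = 8, heads = 8, dim_head = 64,
              mel_dim = 100, text_num_embeds = 256, conv_layers = 4,
              skip_connect_type = 'concat')

x    = torch.randn(2, 300, 100)        # noised mel frames: (batch, frames, mel_dim)
cond = torch.randn(2, 300, 100)        # masked conditioning mel, same shape
text = torch.randint(0, 256, (2, 40))  # character token ids: (batch, text_len)
time = torch.rand(2)                   # flow-matching time step per sample

out = model(x, cond, text, time, drop_audio_cond = False, drop_text = False)
print(out.shape)  # expected: torch.Size([2, 300, 100])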
    	
        model/cfm.py
    ADDED
    
    @@ -0,0 +1,279 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Callable
from random import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from torchdiffeq import odeint

from einops import rearrange

from model.modules import MelSpec

from model.utils import (
    default, exists,
    list_str_to_idx, list_str_to_tensor,
    lens_to_mask, mask_from_frac_lengths,
)


class CFM(nn.Module):
    def __init__(
        self,
        transformer: nn.Module,
        sigma = 0.,
        odeint_kwargs: dict = dict(
            # atol = 1e-5,
            # rtol = 1e-5,
            method = 'euler'  # 'midpoint'
        ),
        audio_drop_prob = 0.3,
        cond_drop_prob = 0.2,
        num_channels = None,
        mel_spec_module: nn.Module | None = None,
        mel_spec_kwargs: dict = dict(),
        frac_lengths_mask: tuple[float, float] = (0.7, 1.),
        vocab_char_map: dict[str, int] | None = None
    ):
        super().__init__()

        self.frac_lengths_mask = frac_lengths_mask

        # mel spec
        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
        self.num_channels = num_channels

        # classifier-free guidance
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # transformer
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        self.odeint_kwargs = odeint_kwargs

        # vocab map for tokenization
        self.vocab_char_map = vocab_char_map

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        cond: float['b n d'] | float['b nw'],
        text: int['b nt'] | list[str],
        duration: int | int['b'],
        *,
        lens: int['b'] | None = None,
        steps = 32,
        cfg_strength = 1.,
        sway_sampling_coef = None,
        seed: int | None = None,
        max_duration = 4096,
        vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
        no_ref_audio = False,
        duplicate_test = False,
        t_inter = 0.1,
        edit_mask = None,
    ):
        self.eval()

        # raw wave

        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = rearrange(cond, 'b d n -> b n d')
            assert cond.shape[-1] == self.num_channels

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)

        # text

        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        if exists(text):
            text_lens = (text != -1).sum(dim = -1)
            lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters

        # duration

        cond_mask = lens_to_mask(lens)
        if edit_mask is not None:
            cond_mask = cond_mask & edit_mask

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device = device, dtype = torch.long)

        duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
        duration = duration.clamp(max = max_duration)
        max_duration = duration.amax()

        # duplicate test corner for inner time step observation
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value = 0.)

        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
        cond_mask = rearrange(cond_mask, '... -> ... 1')
        step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))  # allow direct control (cut cond audio) with lens passed in

        if batch > 1:
            mask = lens_to_mask(duration)
        else:  # save memory and speed up, as single-sample inference currently needs no mask
            mask = None

        # test for no ref audio
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        # neural ode

        def fn(t, x):
            # at each step, conditioning is fixed
            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

            # predict flow
            pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
            if cfg_strength < 1e-5:
                return pred

            null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
            return pred + (pred - null_pred) * cfg_strength

        # noise input
        # seed per sample so batched inference matches single-sample inference regardless of batch size
        # (some difference may still remain, likely due to convolutional layers)
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(torch.randn(dur, self.num_channels, device = self.device))
        y0 = pad_sequence(y0, padding_value = 0, batch_first = True)

        t_start = 0

        # duplicate test corner for inner time step observation
        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        t = torch.linspace(t_start, 1, steps, device = self.device)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)

        sampled = trajectory[-1]
        out = sampled
        out = torch.where(cond_mask, cond, out)

        if exists(vocoder):
            out = rearrange(out, 'b n d -> b d n')
            out = vocoder(out)

        return out, trajectory

    def forward(
        self,
        inp: float['b n d'] | float['b nw'],  # mel or raw wave
        text: int['b nt'] | list[str],
        *,
        lens: int['b'] | None = None,
        noise_scheduler: str | None = None,
    ):
        # handle raw wave
        if inp.ndim == 2:
            inp = self.mel_spec(inp)
            inp = rearrange(inp, 'b d n -> b n d')
            assert inp.shape[-1] == self.num_channels

        batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma

        # handle text as string
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device = device)

        mask = lens_to_mask(lens, length = seq_len)  # useless here, as collate_fn will pad to max length in batch

        # get a random span to mask out for training conditionally
        frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):
            rand_span_mask &= mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # time step
        time = torch.rand((batch,), dtype = dtype, device = self.device)
        # TODO. noise_scheduler

        # sample xt (φ_t(x) in the paper)
        t = rearrange(time, 'b -> b 1 1')
        φ = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only predict what is within the random mask span for infilling
        cond = torch.where(
            rand_span_mask[..., None],
            torch.zeros_like(x1), x1
        )

        # transformer and cfg training with a drop rate
        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
        if random() < self.cond_drop_prob:  # p_uncond in voicebox paper
            drop_audio_cond = True
            drop_text = True
        else:
            drop_text = False

        # to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
        # adding a mask uses more memory, so the batch sampler threshold also needs to be scaled down for long sequences
        pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)

        # flow matching loss
        loss = F.mse_loss(pred, flow, reduction = 'none')
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred
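A minimal end-to-end sketch (not part of the uploaded files) of how a backbone is wrapped in CFM for a training step and for sampling; the hyperparameters, shapes, and texts below are illustrative assumptions rather than the released configuration.

import torch

from model.cfm import CFM
from model.backbones.unett import UNetT

# wrap a small, hypothetical backbone in the conditional flow-matching model
cfm = CFM(transformer = UNetT(dim = 512, depth = 8, mel_dim = 100, text_num_embeds = 256))

# training step: raw waveforms (b, nw) or precomputed mels (b, n, d) plus transcripts
wav = torch.randn(2, 24_000 * 3)  # ~3 s of 24 kHz audio per sample
loss, cond, pred = cfm(wav, ["hello world", "flow matching"])
loss.backward()

# inference: infill / continue a reference mel up to a requested total frame count
ref_mel = torch.randn(1, 200, 100)  # (batch, frames, mel channels)
generated, trajectory = cfm.sample(
    ref_mel, ["reference transcript plus target text"], duration = 500,
    steps = 32, cfg_strength = 2., sway_sampling_coef = -1.,
)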
    	
        model/dataset.py
    ADDED
    
    @@ -0,0 +1,242 @@
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
            from tqdm import tqdm
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
            from torch.utils.data import Dataset, Sampler
         | 
| 8 | 
            +
            import torchaudio
         | 
| 9 | 
            +
            from datasets import load_dataset, load_from_disk
         | 
| 10 | 
            +
            from datasets import Dataset as Dataset_
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from einops import rearrange
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from model.modules import MelSpec
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class HFDataset(Dataset):
         | 
| 18 | 
            +
                def __init__(
         | 
| 19 | 
            +
                    self,
         | 
| 20 | 
            +
                    hf_dataset: Dataset,
         | 
| 21 | 
            +
                    target_sample_rate = 24_000,
         | 
| 22 | 
            +
                    n_mel_channels = 100,
         | 
| 23 | 
            +
                    hop_length = 256,
         | 
| 24 | 
            +
                ):
         | 
| 25 | 
            +
                    self.data = hf_dataset
         | 
| 26 | 
            +
                    self.target_sample_rate = target_sample_rate
         | 
| 27 | 
            +
                    self.hop_length = hop_length
         | 
| 28 | 
            +
                    self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
         | 
| 29 | 
            +
                    
         | 
| 30 | 
            +
                def get_frame_len(self, index):
         | 
| 31 | 
            +
                    row = self.data[index]
         | 
| 32 | 
            +
                    audio = row['audio']['array']
         | 
| 33 | 
            +
                    sample_rate = row['audio']['sampling_rate']
         | 
| 34 | 
            +
                    return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                def __len__(self):
         | 
| 37 | 
            +
                    return len(self.data)
         | 
| 38 | 
            +
                
         | 
| 39 | 
            +
                def __getitem__(self, index):
         | 
| 40 | 
            +
                    row = self.data[index]
         | 
| 41 | 
            +
                    audio = row['audio']['array']
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    # logger.info(f"Audio shape: {audio.shape}")
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    sample_rate = row['audio']['sampling_rate']
         | 
| 46 | 
            +
                    duration = audio.shape[-1] / sample_rate
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    if duration > 30 or duration < 0.3:
         | 
| 49 | 
            +
                        return self.__getitem__((index + 1) % len(self.data))
         | 
| 50 | 
            +
                    
         | 
| 51 | 
            +
                    audio_tensor = torch.from_numpy(audio).float()
         | 
| 52 | 
            +
                    
         | 
| 53 | 
            +
                    if sample_rate != self.target_sample_rate:
         | 
| 54 | 
            +
                        resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
         | 
| 55 | 
            +
                        audio_tensor = resampler(audio_tensor)
         | 
| 56 | 
            +
                    
         | 
| 57 | 
            +
                    audio_tensor = rearrange(audio_tensor, 't -> 1 t')
         | 
| 58 | 
            +
                    
         | 
| 59 | 
            +
                    mel_spec = self.mel_spectrogram(audio_tensor)
         | 
| 60 | 
            +
                    
         | 
| 61 | 
            +
                    mel_spec = rearrange(mel_spec, '1 d t -> d t')
         | 
| 62 | 
            +
                    
         | 
| 63 | 
            +
                    text = row['text']
         | 
| 64 | 
            +
                    
         | 
| 65 | 
            +
                    return dict(
         | 
| 66 | 
            +
                        mel_spec = mel_spec,
         | 
| 67 | 
            +
                        text = text,
         | 
| 68 | 
            +
                    )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            class CustomDataset(Dataset):
         | 
| 72 | 
            +
                def __init__(
         | 
| 73 | 
            +
                    self,
         | 
| 74 | 
            +
                    custom_dataset: Dataset,
         | 
| 75 | 
            +
                    durations = None,
         | 
| 76 | 
            +
                    target_sample_rate = 24_000,
         | 
| 77 | 
            +
                    hop_length = 256,
         | 
| 78 | 
            +
                    n_mel_channels = 100,
         | 
| 79 | 
            +
                    preprocessed_mel = False,
         | 
| 80 | 
            +
                ):
         | 
| 81 | 
            +
                    self.data = custom_dataset
         | 
| 82 | 
            +
                    self.durations = durations
         | 
| 83 | 
            +
                    self.target_sample_rate = target_sample_rate
         | 
| 84 | 
            +
                    self.hop_length = hop_length
         | 
| 85 | 
            +
                    self.preprocessed_mel = preprocessed_mel
         | 
| 86 | 
            +
                    if not preprocessed_mel:
         | 
| 87 | 
            +
                        self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def get_frame_len(self, index):
         | 
| 90 | 
            +
                    if self.durations is not None:  # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
         | 
| 91 | 
            +
                        return self.durations[index] * self.target_sample_rate / self.hop_length
         | 
| 92 | 
            +
                    return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                def __len__(self):
         | 
| 95 | 
            +
                    return len(self.data)
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
                def __getitem__(self, index):
         | 
| 98 | 
            +
                    row = self.data[index]
         | 
| 99 | 
            +
                    audio_path = row["audio_path"]
         | 
| 100 | 
            +
                    text = row["text"]
         | 
| 101 | 
            +
                    duration = row["duration"]
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    if self.preprocessed_mel:
         | 
| 104 | 
            +
                        mel_spec = torch.tensor(row["mel_spec"])
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    else:
         | 
| 107 | 
            +
                        audio, source_sample_rate = torchaudio.load(audio_path)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                        if duration > 30 or duration < 0.3:
         | 
| 110 | 
            +
                            return self.__getitem__((index + 1) % len(self.data))
         | 
| 111 | 
            +
                        
         | 
| 112 | 
            +
                        if source_sample_rate != self.target_sample_rate:
         | 
| 113 | 
            +
                            resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
         | 
| 114 | 
            +
                            audio = resampler(audio)
         | 
| 115 | 
            +
                        
         | 
| 116 | 
            +
                        mel_spec = self.mel_spectrogram(audio)
         | 
| 117 | 
            +
                        mel_spec = rearrange(mel_spec, '1 d t -> d t')
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                    return dict(
         | 
| 120 | 
            +
                        mel_spec = mel_spec,
         | 
| 121 | 
            +
                        text = text,
         | 
| 122 | 
            +
                    )
         | 
| 123 | 
            +
                
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            # Dynamic Batch Sampler
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            class DynamicBatchSampler(Sampler[list[int]]):
         | 
| 128 | 
            +
                """ Extension of Sampler that will do the following:
         | 
| 129 | 
            +
                    1.  Change the batch size (essentially number of sequences)
         | 
| 130 | 
            +
                        in a batch to ensure that the total number of frames are less
         | 
| 131 | 
            +
                        than a certain threshold.
         | 
| 132 | 
            +
                    2.  Make sure the padding efficiency in the batch is high.
         | 
| 133 | 
            +
                """
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
         | 
| 136 | 
            +
                    self.sampler = sampler
         | 
| 137 | 
            +
                    self.frames_threshold = frames_threshold
         | 
| 138 | 
            +
                    self.max_samples = max_samples
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    indices, batches = [], []
         | 
| 141 | 
            +
                    data_source = self.sampler.data_source
         | 
| 142 | 
            +
                    
         | 
| 143 | 
            +
                    for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
         | 
| 144 | 
            +
                        indices.append((idx, data_source.get_frame_len(idx)))
         | 
| 145 | 
            +
                    indices.sort(key=lambda elem : elem[1])
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    batch = []
         | 
| 148 | 
            +
                    batch_frames = 0
         | 
| 149 | 
            +
                    for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
         | 
| 150 | 
            +
                        if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
         | 
| 151 | 
            +
                            batch.append(idx)
         | 
| 152 | 
            +
                            batch_frames += frame_len
         | 
| 153 | 
            +
                        else:
         | 
| 154 | 
            +
                            if len(batch) > 0:
         | 
| 155 | 
            +
                                batches.append(batch)
         | 
| 156 | 
            +
                            if frame_len <= self.frames_threshold:
         | 
| 157 | 
            +
                                batch = [idx]
         | 
| 158 | 
            +
                                batch_frames = frame_len
         | 
| 159 | 
            +
                            else:
         | 
| 160 | 
            +
                                batch = []
         | 
| 161 | 
            +
                                batch_frames = 0
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    if not drop_last and len(batch) > 0:
         | 
| 164 | 
            +
                        batches.append(batch)
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    del indices
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    # to get different batches between epochs, just set a seed and log it in the checkpoint
         | 
| 169 | 
            +
                    # because in multi-GPU training, although the per-GPU batches do not change between epochs, the formed global minibatch does
         | 
| 170 | 
            +
                    # e.g. for epoch n, use (random_seed + n)
         | 
| 171 | 
            +
                    random.seed(random_seed)
         | 
| 172 | 
            +
                    random.shuffle(batches)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    self.batches = batches
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def __iter__(self):
         | 
| 177 | 
            +
                    return iter(self.batches)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def __len__(self):
         | 
| 180 | 
            +
                    return len(self.batches)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            # Load dataset
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            def load_dataset(
         | 
| 186 | 
            +
                    dataset_name: str,
         | 
| 187 | 
            +
                    tokenizer: str,
         | 
| 188 | 
            +
                    dataset_type: str = "CustomDataset", 
         | 
| 189 | 
            +
                    audio_type: str = "raw", 
         | 
| 190 | 
            +
                    mel_spec_kwargs: dict = dict()
         | 
| 191 | 
            +
                    ) -> CustomDataset:
         | 
| 192 | 
            +
                
         | 
| 193 | 
            +
                print("Loading dataset ...")
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                if dataset_type == "CustomDataset":
         | 
| 196 | 
            +
                    if audio_type == "raw":
         | 
| 197 | 
            +
                        try:
         | 
| 198 | 
            +
                            train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
         | 
| 199 | 
            +
                        except:
         | 
| 200 | 
            +
                            train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
         | 
| 201 | 
            +
                        preprocessed_mel = False
         | 
| 202 | 
            +
                    elif audio_type == "mel":
         | 
| 203 | 
            +
                        train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
         | 
| 204 | 
            +
                        preprocessed_mel = True
         | 
| 205 | 
            +
                    with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
         | 
| 206 | 
            +
                        data_dict = json.load(f)
         | 
| 207 | 
            +
                    durations = data_dict["duration"]
         | 
| 208 | 
            +
                    train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                elif dataset_type == "HFDataset":
         | 
| 211 | 
            +
                    print("Should manually modify the path of huggingface dataset to your need.\n" +
         | 
| 212 | 
            +
                          "May also the corresponding script cuz different dataset may have different format.")
         | 
| 213 | 
            +
                    pre, post = dataset_name.split("_")
         | 
| 214 | 
            +
                    train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                return train_dataset
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
            # collation
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            def collate_fn(batch):
         | 
| 222 | 
            +
                mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
         | 
| 223 | 
            +
                mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
         | 
| 224 | 
            +
                max_mel_length = mel_lengths.amax()
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                padded_mel_specs = []
         | 
| 227 | 
            +
                for spec in mel_specs:  # TODO: maybe record a mask for attention here
         | 
| 228 | 
            +
                    padding = (0, max_mel_length - spec.size(-1))
         | 
| 229 | 
            +
                    padded_spec = F.pad(spec, padding, value = 0)
         | 
| 230 | 
            +
                    padded_mel_specs.append(padded_spec)
         | 
| 231 | 
            +
                
         | 
| 232 | 
            +
                mel_specs = torch.stack(padded_mel_specs)
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                text = [item['text'] for item in batch]
         | 
| 235 | 
            +
                text_lengths = torch.LongTensor([len(item) for item in text])
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                return dict(
         | 
| 238 | 
            +
                    mel = mel_specs,
         | 
| 239 | 
            +
                    mel_lengths = mel_lengths,
         | 
| 240 | 
            +
                    text = text,
         | 
| 241 | 
            +
                    text_lengths = text_lengths,
         | 
| 242 | 
            +
                )
         | 
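
Taken together, this file provides everything needed to build a training dataloader: load_dataset constructs the dataset, DynamicBatchSampler groups indices so each batch stays under a per-GPU frame budget, and collate_fn pads the mel spectrograms. Below is a minimal wiring sketch, not part of the repository; the dataset name "Emilia_ZH_EN", tokenizer "pinyin", and the threshold/seed values are placeholders to adjust for whatever data/<name>_<tokenizer>/ directory was prepared.

from torch.utils.data import DataLoader, SequentialSampler

from model.dataset import load_dataset, DynamicBatchSampler, collate_fn

# build the dataset from a prepared data/<name>_<tokenizer>/ directory (placeholder names)
train_dataset = load_dataset("Emilia_ZH_EN", "pinyin", dataset_type="CustomDataset")

# cap each batch at roughly 38400 mel frames per GPU and at most 64 samples,
# then shuffle the formed batches with a fixed (logged) seed
batch_sampler = DynamicBatchSampler(
    SequentialSampler(train_dataset),
    frames_threshold=38400,   # placeholder frame budget
    max_samples=64,
    random_seed=666,          # placeholder seed
    drop_last=False,
)

train_loader = DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,  # each batch is a list of indices from the sampler
    collate_fn=collate_fn,        # pads mels, returns mel / mel_lengths / text / text_lengths
    num_workers=4,
)

for batch in train_loader:
    mel = batch["mel"]                  # (b, n_mel, max_frames), zero-padded
    mel_lengths = batch["mel_lengths"]  # original frame counts before padding
    break
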
    	
        model/ecapa_tdnn.py
    ADDED
    
    | @@ -0,0 +1,268 @@ | |
| 1 | 
            +
            # just for speaker similarity evaluation, third-party code
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
         | 
| 4 | 
            +
            # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn as nn
         | 
| 9 | 
            +
            import torch.nn.functional as F
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            ''' Res2Conv1d + BatchNorm1d + ReLU
         | 
| 13 | 
            +
            '''
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            class Res2Conv1dReluBn(nn.Module):
         | 
| 16 | 
            +
                '''
         | 
| 17 | 
            +
                in_channels == out_channels == channels
         | 
| 18 | 
            +
                '''
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
         | 
| 21 | 
            +
                    super().__init__()
         | 
| 22 | 
            +
                    assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
         | 
| 23 | 
            +
                    self.scale = scale
         | 
| 24 | 
            +
                    self.width = channels // scale
         | 
| 25 | 
            +
                    self.nums = scale if scale == 1 else scale - 1
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    self.convs = []
         | 
| 28 | 
            +
                    self.bns = []
         | 
| 29 | 
            +
                    for i in range(self.nums):
         | 
| 30 | 
            +
                        self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
         | 
| 31 | 
            +
                        self.bns.append(nn.BatchNorm1d(self.width))
         | 
| 32 | 
            +
                    self.convs = nn.ModuleList(self.convs)
         | 
| 33 | 
            +
                    self.bns = nn.ModuleList(self.bns)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                def forward(self, x):
         | 
| 36 | 
            +
                    out = []
         | 
| 37 | 
            +
                    spx = torch.split(x, self.width, 1)
         | 
| 38 | 
            +
                    for i in range(self.nums):
         | 
| 39 | 
            +
                        if i == 0:
         | 
| 40 | 
            +
                            sp = spx[i]
         | 
| 41 | 
            +
                        else:
         | 
| 42 | 
            +
                            sp = sp + spx[i]
         | 
| 43 | 
            +
                        # Order: conv -> relu -> bn
         | 
| 44 | 
            +
                        sp = self.convs[i](sp)
         | 
| 45 | 
            +
                        sp = self.bns[i](F.relu(sp))
         | 
| 46 | 
            +
                        out.append(sp)
         | 
| 47 | 
            +
                    if self.scale != 1:
         | 
| 48 | 
            +
                        out.append(spx[self.nums])
         | 
| 49 | 
            +
                    out = torch.cat(out, dim=1)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    return out
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            ''' Conv1d + BatchNorm1d + ReLU
         | 
| 55 | 
            +
            '''
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            class Conv1dReluBn(nn.Module):
         | 
| 58 | 
            +
                def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
         | 
| 59 | 
            +
                    super().__init__()
         | 
| 60 | 
            +
                    self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
         | 
| 61 | 
            +
                    self.bn = nn.BatchNorm1d(out_channels)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def forward(self, x):
         | 
| 64 | 
            +
                    return self.bn(F.relu(self.conv(x)))
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            ''' The SE connection of 1D case.
         | 
| 68 | 
            +
            '''
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            class SE_Connect(nn.Module):
         | 
| 71 | 
            +
                def __init__(self, channels, se_bottleneck_dim=128):
         | 
| 72 | 
            +
                    super().__init__()
         | 
| 73 | 
            +
                    self.linear1 = nn.Linear(channels, se_bottleneck_dim)
         | 
| 74 | 
            +
                    self.linear2 = nn.Linear(se_bottleneck_dim, channels)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def forward(self, x):
         | 
| 77 | 
            +
                    out = x.mean(dim=2)
         | 
| 78 | 
            +
                    out = F.relu(self.linear1(out))
         | 
| 79 | 
            +
                    out = torch.sigmoid(self.linear2(out))
         | 
| 80 | 
            +
                    out = x * out.unsqueeze(2)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    return out
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            ''' SE-Res2Block of the ECAPA-TDNN architecture.
         | 
| 86 | 
            +
            '''
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
         | 
| 89 | 
            +
            #     return nn.Sequential(
         | 
| 90 | 
            +
            #         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
         | 
| 91 | 
            +
            #         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
         | 
| 92 | 
            +
            #         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
         | 
| 93 | 
            +
            #         SE_Connect(channels)
         | 
| 94 | 
            +
            #     )
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            class SE_Res2Block(nn.Module):
         | 
| 97 | 
            +
                def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
         | 
| 98 | 
            +
                    super().__init__()
         | 
| 99 | 
            +
                    self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
         | 
| 100 | 
            +
                    self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
         | 
| 101 | 
            +
                    self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
         | 
| 102 | 
            +
                    self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self.shortcut = None
         | 
| 105 | 
            +
                    if in_channels != out_channels:
         | 
| 106 | 
            +
                        self.shortcut = nn.Conv1d(
         | 
| 107 | 
            +
                            in_channels=in_channels,
         | 
| 108 | 
            +
                            out_channels=out_channels,
         | 
| 109 | 
            +
                            kernel_size=1,
         | 
| 110 | 
            +
                        )
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def forward(self, x):
         | 
| 113 | 
            +
                    residual = x
         | 
| 114 | 
            +
                    if self.shortcut:
         | 
| 115 | 
            +
                        residual = self.shortcut(x)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    x = self.Conv1dReluBn1(x)
         | 
| 118 | 
            +
                    x = self.Res2Conv1dReluBn(x)
         | 
| 119 | 
            +
                    x = self.Conv1dReluBn2(x)
         | 
| 120 | 
            +
                    x = self.SE_Connect(x)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    return x + residual
         | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            ''' Attentive weighted mean and standard deviation pooling.
         | 
| 126 | 
            +
            '''
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            class AttentiveStatsPool(nn.Module):
         | 
| 129 | 
            +
                def __init__(self, in_dim, attention_channels=128, global_context_att=False):
         | 
| 130 | 
            +
                    super().__init__()
         | 
| 131 | 
            +
                    self.global_context_att = global_context_att
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
         | 
| 134 | 
            +
                    if global_context_att:
         | 
| 135 | 
            +
                        self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
         | 
| 136 | 
            +
                    else:
         | 
| 137 | 
            +
                        self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
         | 
| 138 | 
            +
                    self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def forward(self, x):
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    if self.global_context_att:
         | 
| 143 | 
            +
                        context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
         | 
| 144 | 
            +
                        context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
         | 
| 145 | 
            +
                        x_in = torch.cat((x, context_mean, context_std), dim=1)
         | 
| 146 | 
            +
                    else:
         | 
| 147 | 
            +
                        x_in = x
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
         | 
| 150 | 
            +
                    alpha = torch.tanh(self.linear1(x_in))
         | 
| 151 | 
            +
                    # alpha = F.relu(self.linear1(x_in))
         | 
| 152 | 
            +
                    alpha = torch.softmax(self.linear2(alpha), dim=2)
         | 
| 153 | 
            +
                    mean = torch.sum(alpha * x, dim=2)
         | 
| 154 | 
            +
                    residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
         | 
| 155 | 
            +
                    std = torch.sqrt(residuals.clamp(min=1e-9))
         | 
| 156 | 
            +
                    return torch.cat([mean, std], dim=1)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            class ECAPA_TDNN(nn.Module):
         | 
| 160 | 
            +
                def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
         | 
| 161 | 
            +
                             feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
         | 
| 162 | 
            +
                    super().__init__()
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    self.feat_type = feat_type
         | 
| 165 | 
            +
                    self.feature_selection = feature_selection
         | 
| 166 | 
            +
                    self.update_extract = update_extract
         | 
| 167 | 
            +
                    self.sr = sr
         | 
| 168 | 
            +
                    
         | 
| 169 | 
            +
                    torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
         | 
| 170 | 
            +
                    try:
         | 
| 171 | 
            +
                        local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
         | 
| 172 | 
            +
                        self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
         | 
| 173 | 
            +
                    except:
         | 
| 174 | 
            +
                        self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
         | 
| 177 | 
            +
                        self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
         | 
| 178 | 
            +
                    if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
         | 
| 179 | 
            +
                        self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    self.feat_num = self.get_feat_num()
         | 
| 182 | 
            +
                    self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    if feat_type != 'fbank' and feat_type != 'mfcc':
         | 
| 185 | 
            +
                        freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
         | 
| 186 | 
            +
                        for name, param in self.feature_extract.named_parameters():
         | 
| 187 | 
            +
                            for freeze_val in freeze_list:
         | 
| 188 | 
            +
                                if freeze_val in name:
         | 
| 189 | 
            +
                                    param.requires_grad = False
         | 
| 190 | 
            +
                                    break
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    if not self.update_extract:
         | 
| 193 | 
            +
                        for param in self.feature_extract.parameters():
         | 
| 194 | 
            +
                            param.requires_grad = False
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    self.instance_norm = nn.InstanceNorm1d(feat_dim)
         | 
| 197 | 
            +
                    # self.channels = [channels] * 4 + [channels * 3]
         | 
| 198 | 
            +
                    self.channels = [channels] * 4 + [1536]
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
         | 
| 201 | 
            +
                    self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
         | 
| 202 | 
            +
                    self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
         | 
| 203 | 
            +
                    self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
         | 
| 206 | 
            +
                    cat_channels = channels * 3
         | 
| 207 | 
            +
                    self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
         | 
| 208 | 
            +
                    self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
         | 
| 209 | 
            +
                    self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
         | 
| 210 | 
            +
                    self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
                def get_feat_num(self):
         | 
| 214 | 
            +
                    self.feature_extract.eval()
         | 
| 215 | 
            +
                    wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
         | 
| 216 | 
            +
                    with torch.no_grad():
         | 
| 217 | 
            +
                        features = self.feature_extract(wav)
         | 
| 218 | 
            +
                    select_feature = features[self.feature_selection]
         | 
| 219 | 
            +
                    if isinstance(select_feature, (list, tuple)):
         | 
| 220 | 
            +
                        return len(select_feature)
         | 
| 221 | 
            +
                    else:
         | 
| 222 | 
            +
                        return 1
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                def get_feat(self, x):
         | 
| 225 | 
            +
                    if self.update_extract:
         | 
| 226 | 
            +
                        x = self.feature_extract([sample for sample in x])
         | 
| 227 | 
            +
                    else:
         | 
| 228 | 
            +
                        with torch.no_grad():
         | 
| 229 | 
            +
                            if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
         | 
| 230 | 
            +
                                x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
         | 
| 231 | 
            +
                            else:
         | 
| 232 | 
            +
                                x = self.feature_extract([sample for sample in x])
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    if self.feat_type == 'fbank':
         | 
| 235 | 
            +
                        x = x.log()
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    if self.feat_type != "fbank" and self.feat_type != "mfcc":
         | 
| 238 | 
            +
                        x = x[self.feature_selection]
         | 
| 239 | 
            +
                        if isinstance(x, (list, tuple)):
         | 
| 240 | 
            +
                            x = torch.stack(x, dim=0)
         | 
| 241 | 
            +
                        else:
         | 
| 242 | 
            +
                            x = x.unsqueeze(0)
         | 
| 243 | 
            +
                        norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         | 
| 244 | 
            +
                        x = (norm_weights * x).sum(dim=0)
         | 
| 245 | 
            +
                        x = torch.transpose(x, 1, 2) + 1e-6
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    x = self.instance_norm(x)
         | 
| 248 | 
            +
                    return x
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                def forward(self, x):
         | 
| 251 | 
            +
                    x = self.get_feat(x)
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                    out1 = self.layer1(x)
         | 
| 254 | 
            +
                    out2 = self.layer2(out1)
         | 
| 255 | 
            +
                    out3 = self.layer3(out2)
         | 
| 256 | 
            +
                    out4 = self.layer4(out3)
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    out = torch.cat([out2, out3, out4], dim=1)
         | 
| 259 | 
            +
                    out = F.relu(self.conv(out))
         | 
| 260 | 
            +
                    out = self.bn(self.pooling(out))
         | 
| 261 | 
            +
                    out = self.linear(out)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    return out
         | 
| 264 | 
            +
             | 
| 265 | 
            +
             | 
| 266 | 
            +
            def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
         | 
| 267 | 
            +
                return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
         | 
| 268 | 
            +
                                  feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
         | 
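
This file is only used to extract speaker embeddings for similarity scoring during evaluation. The sketch below shows one way it might be used, assuming a WavLM-large based verification checkpoint; the checkpoint path and its layout are placeholders, and feat_dim=1024 simply matches the WavLM-large hidden size.

import torch
import torch.nn.functional as F

from model.ecapa_tdnn import ECAPA_TDNN_SMALL

# constructing the model loads the s3prl upstream via torch.hub (local cache or network)
model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')

ckpt = torch.load("ckpts/wavlm_large_finetune.pth", map_location="cpu")  # placeholder path
model.load_state_dict(ckpt.get("model", ckpt), strict=False)             # tolerate either checkpoint layout
model.eval()

# stand-ins for two 5-second utterances at 16 kHz, shape (batch, samples)
wav1 = torch.randn(1, 16000 * 5)
wav2 = torch.randn(1, 16000 * 5)

with torch.no_grad():
    emb1 = model(wav1)  # (1, emb_dim) speaker embedding
    emb2 = model(wav2)

similarity = F.cosine_similarity(emb1, emb2).item()
print(f"speaker similarity: {similarity:.3f}")
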
    	
        model/modules.py
    ADDED
    
    | @@ -0,0 +1,575 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            ein notation:
         | 
| 3 | 
            +
            b - batch
         | 
| 4 | 
            +
            n - sequence
         | 
| 5 | 
            +
            nt - text sequence
         | 
| 6 | 
            +
            nw - raw wave length
         | 
| 7 | 
            +
            d - dimension
         | 
| 8 | 
            +
            """
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from __future__ import annotations
         | 
| 11 | 
            +
            from typing import Optional
         | 
| 12 | 
            +
            import math
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            from torch import nn
         | 
| 16 | 
            +
            import torch.nn.functional as F
         | 
| 17 | 
            +
            import torchaudio
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from einops import rearrange
         | 
| 20 | 
            +
            from x_transformers.x_transformers import apply_rotary_pos_emb
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            # raw wav to mel spec
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            class MelSpec(nn.Module):
         | 
| 26 | 
            +
                def __init__(
         | 
| 27 | 
            +
                    self,
         | 
| 28 | 
            +
                    filter_length = 1024,
         | 
| 29 | 
            +
                    hop_length = 256,
         | 
| 30 | 
            +
                    win_length = 1024,
         | 
| 31 | 
            +
                    n_mel_channels = 100,
         | 
| 32 | 
            +
                    target_sample_rate = 24_000,
         | 
| 33 | 
            +
                    normalize = False,
         | 
| 34 | 
            +
                    power = 1,
         | 
| 35 | 
            +
                    norm = None,
         | 
| 36 | 
            +
                    center = True,
         | 
| 37 | 
            +
                ):
         | 
| 38 | 
            +
                    super().__init__()
         | 
| 39 | 
            +
                    self.n_mel_channels = n_mel_channels
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    self.mel_stft = torchaudio.transforms.MelSpectrogram(
         | 
| 42 | 
            +
                        sample_rate = target_sample_rate,
         | 
| 43 | 
            +
                        n_fft = filter_length,
         | 
| 44 | 
            +
                        win_length = win_length,
         | 
| 45 | 
            +
                        hop_length = hop_length,
         | 
| 46 | 
            +
                        n_mels = n_mel_channels,
         | 
| 47 | 
            +
                        power = power,
         | 
| 48 | 
            +
                        center = center,
         | 
| 49 | 
            +
                        normalized = normalize,
         | 
| 50 | 
            +
                        norm = norm,
         | 
| 51 | 
            +
                    )
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    self.register_buffer('dummy', torch.tensor(0), persistent = False)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def forward(self, inp):
         | 
| 56 | 
            +
                    if len(inp.shape) == 3:
         | 
| 57 | 
            +
                        inp = rearrange(inp, 'b 1 nw -> b nw')
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    assert len(inp.shape) == 2
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    if self.dummy.device != inp.device:
         | 
| 62 | 
            +
                        self.to(inp.device)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    mel = self.mel_stft(inp)
         | 
| 65 | 
            +
                    mel = mel.clamp(min = 1e-5).log()
         | 
| 66 | 
            +
                    return mel
         | 
| 67 | 
            +
                
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            # sinusoidal position embedding
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            class SinusPositionEmbedding(nn.Module):
         | 
| 72 | 
            +
                def __init__(self, dim):
         | 
| 73 | 
            +
                    super().__init__()
         | 
| 74 | 
            +
                    self.dim = dim
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def forward(self, x, scale=1000):
         | 
| 77 | 
            +
                    device = x.device
         | 
| 78 | 
            +
                    half_dim = self.dim // 2
         | 
| 79 | 
            +
                    emb = math.log(10000) / (half_dim - 1)
         | 
| 80 | 
            +
                    emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
         | 
| 81 | 
            +
                    emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
         | 
| 82 | 
            +
                    emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         | 
| 83 | 
            +
                    return emb
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            # convolutional position embedding
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            class ConvPositionEmbedding(nn.Module):
         | 
| 89 | 
            +
                def __init__(self, dim, kernel_size = 31, groups = 16):
         | 
| 90 | 
            +
                    super().__init__()
         | 
| 91 | 
            +
                    assert kernel_size % 2 != 0
         | 
| 92 | 
            +
                    self.conv1d = nn.Sequential(
         | 
| 93 | 
            +
                        nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
         | 
| 94 | 
            +
                        nn.Mish(),
         | 
| 95 | 
            +
                        nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
         | 
| 96 | 
            +
                        nn.Mish(),
         | 
| 97 | 
            +
                    )
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def forward(self, x: float['b n d'], mask: bool['b n'] | None  = None):
         | 
| 100 | 
            +
                    if mask is not None:
         | 
| 101 | 
            +
                        mask = mask[..., None]
         | 
| 102 | 
            +
                        x = x.masked_fill(~mask, 0.)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    x = rearrange(x, 'b n d -> b d n')
         | 
| 105 | 
            +
                    x = self.conv1d(x)
         | 
| 106 | 
            +
                    out = rearrange(x, 'b d n -> b n d')
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    if mask is not None:
         | 
| 109 | 
            +
                        out = out.masked_fill(~mask, 0.)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    return out
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            # rotary positional embedding related
         | 
| 115 | 
            +
             | 
| 116 | 
            +
            def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
         | 
| 117 | 
            +
                # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
         | 
| 118 | 
            +
                # has some connection to NTK literature
         | 
| 119 | 
            +
                # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         | 
| 120 | 
            +
                # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
         | 
| 121 | 
            +
                theta *= theta_rescale_factor ** (dim / (dim - 2))
         | 
| 122 | 
            +
                freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
         | 
| 123 | 
            +
                t = torch.arange(end, device=freqs.device)  # type: ignore
         | 
| 124 | 
            +
                freqs = torch.outer(t, freqs).float()  # type: ignore
         | 
| 125 | 
            +
                freqs_cos = torch.cos(freqs)  # real part
         | 
| 126 | 
            +
                freqs_sin = torch.sin(freqs)  # imaginary part
         | 
| 127 | 
            +
                return torch.cat([freqs_cos, freqs_sin], dim=-1)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            def get_pos_embed_indices(start, length, max_pos, scale=1.):
         | 
| 130 | 
            +
                # length = length if isinstance(length, int) else length.max()
         | 
| 131 | 
            +
                scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
         | 
| 132 | 
            +
                pos = start.unsqueeze(1) + (
         | 
| 133 | 
            +
                        torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
         | 
| 134 | 
            +
                        scale.unsqueeze(1)).long()
         | 
| 135 | 
            +
                # clamp to avoid out-of-range indices for very long sequences
         | 
| 136 | 
            +
                pos = torch.where(pos < max_pos, pos, max_pos - 1)
         | 
| 137 | 
            +
                return pos
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            # Global Response Normalization layer (Instance Normalization ?)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            class GRN(nn.Module):
         | 
| 143 | 
            +
                def __init__(self, dim):
         | 
| 144 | 
            +
                    super().__init__()
         | 
| 145 | 
            +
                    self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
         | 
| 146 | 
            +
                    self.beta = nn.Parameter(torch.zeros(1, 1, dim))
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                def forward(self, x):
         | 
| 149 | 
            +
                    Gx = torch.norm(x, p=2, dim=1, keepdim=True)
         | 
| 150 | 
            +
                    Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
         | 
| 151 | 
            +
                    return self.gamma * (x * Nx) + self.beta + x
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
         | 
| 155 | 
            +
            # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
         | 
| 156 | 
            +
             | 
| 157 | 
            +
            class ConvNeXtV2Block(nn.Module):
         | 
| 158 | 
            +
                def __init__(
         | 
| 159 | 
            +
                    self,
         | 
| 160 | 
            +
                    dim: int,
         | 
| 161 | 
            +
                    intermediate_dim: int,
         | 
| 162 | 
            +
                    dilation: int = 1,
         | 
| 163 | 
            +
                ):
         | 
| 164 | 
            +
                    super().__init__()
         | 
| 165 | 
            +
                    padding = (dilation * (7 - 1)) // 2
         | 
| 166 | 
            +
                    self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation)  # depthwise conv
         | 
| 167 | 
            +
                    self.norm = nn.LayerNorm(dim, eps=1e-6)
         | 
| 168 | 
            +
                    self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
         | 
| 169 | 
            +
                    self.act = nn.GELU()
         | 
| 170 | 
            +
                    self.grn = GRN(intermediate_dim)
         | 
| 171 | 
            +
                    self.pwconv2 = nn.Linear(intermediate_dim, dim)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                def forward(self, x: torch.Tensor) -> torch.Tensor:
         | 
| 174 | 
            +
                    residual = x
         | 
| 175 | 
            +
                    x = x.transpose(1, 2)  # b n d -> b d n
         | 
| 176 | 
            +
                    x = self.dwconv(x)
         | 
| 177 | 
            +
                    x = x.transpose(1, 2)  # b d n -> b n d
         | 
| 178 | 
            +
                    x = self.norm(x)
         | 
| 179 | 
            +
                    x = self.pwconv1(x)
         | 
| 180 | 
            +
                    x = self.act(x)
         | 
| 181 | 
            +
                    x = self.grn(x)
         | 
| 182 | 
            +
                    x = self.pwconv2(x)
         | 
| 183 | 
            +
                    return residual + x
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            # AdaLayerNormZero
         | 
| 187 | 
            +
            # return with modulated x for attn input, and params for later mlp modulation
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            class AdaLayerNormZero(nn.Module):
         | 
| 190 | 
            +
                def __init__(self, dim):
         | 
| 191 | 
            +
                    super().__init__()
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    self.silu = nn.SiLU()
         | 
| 194 | 
            +
                    self.linear = nn.Linear(dim, dim * 6)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def forward(self, x, emb = None):
         | 
| 199 | 
            +
                    emb = self.linear(self.silu(emb))
         | 
| 200 | 
            +
                    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         | 
| 203 | 
            +
                    return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
         | 
| 204 | 
            +
             | 
| 205 | 
            +
             | 
| 206 | 
            +
            # AdaLayerNormZero for final layer
         | 
| 207 | 
            +
            # return only with modulated x for attn input, since there is no more mlp modulation
         | 
| 208 | 
            +
             | 
| 209 | 
            +
            class AdaLayerNormZero_Final(nn.Module):
         | 
| 210 | 
            +
                def __init__(self, dim):
         | 
| 211 | 
            +
                    super().__init__()
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    self.silu = nn.SiLU()
         | 
| 214 | 
            +
                    self.linear = nn.Linear(dim, dim * 2)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                def forward(self, x, emb):
         | 
| 219 | 
            +
                    emb = self.linear(self.silu(emb))
         | 
| 220 | 
            +
                    scale, shift = torch.chunk(emb, 2, dim=1)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
         | 
| 223 | 
            +
                    return x
         | 
| 224 | 
            +
             | 
| 225 | 
            +
             | 
| 226 | 
            +
            # FeedForward
         | 
| 227 | 
            +
             | 
| 228 | 
            +
            class FeedForward(nn.Module):
         | 
| 229 | 
            +
                def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
         | 
| 230 | 
            +
                    super().__init__()
         | 
| 231 | 
            +
                    inner_dim = int(dim * mult)
         | 
| 232 | 
            +
                    dim_out = dim_out if dim_out is not None else dim
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    activation = nn.GELU(approximate=approximate)
         | 
| 235 | 
            +
                    project_in = nn.Sequential(
         | 
| 236 | 
            +
                        nn.Linear(dim, inner_dim),
         | 
| 237 | 
            +
                        activation
         | 
| 238 | 
            +
                    )
         | 
| 239 | 
            +
                    self.ff = nn.Sequential(
         | 
| 240 | 
            +
                        project_in,
         | 
| 241 | 
            +
                        nn.Dropout(dropout),
         | 
| 242 | 
            +
                        nn.Linear(inner_dim, dim_out)
         | 
| 243 | 
            +
                    )
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                def forward(self, x):
         | 
| 246 | 
            +
                    return self.ff(x)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
| 249 | 
            +
            # Attention with possible joint part
         | 
| 250 | 
            +
            # modified from diffusers/src/diffusers/models/attention_processor.py
         | 
| 251 | 
            +
             | 
| 252 | 
            +
            class Attention(nn.Module):
         | 
| 253 | 
            +
                def __init__(
         | 
| 254 | 
            +
                    self,
         | 
| 255 | 
            +
                    processor: JointAttnProcessor | AttnProcessor,
         | 
| 256 | 
            +
                    dim: int,
         | 
| 257 | 
            +
                    heads: int = 8,
         | 
| 258 | 
            +
                    dim_head: int = 64,
         | 
| 259 | 
            +
                    dropout: float = 0.0,
         | 
| 260 | 
            +
                    context_dim: Optional[int] = None, # if not None -> joint attention
         | 
| 261 | 
            +
                    context_pre_only = None,
         | 
| 262 | 
            +
                ):
         | 
| 263 | 
            +
                    super().__init__()
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    if not hasattr(F, "scaled_dot_product_attention"):
         | 
| 266 | 
            +
                        raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    self.processor = processor
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    self.dim = dim
         | 
| 271 | 
            +
                    self.heads = heads
         | 
| 272 | 
            +
                    self.inner_dim = dim_head * heads
         | 
| 273 | 
            +
                    self.dropout = dropout
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    self.context_dim = context_dim
         | 
| 276 | 
            +
                    self.context_pre_only = context_pre_only
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    self.to_q = nn.Linear(dim, self.inner_dim)
         | 
| 279 | 
            +
                    self.to_k = nn.Linear(dim, self.inner_dim)
         | 
| 280 | 
            +
                    self.to_v = nn.Linear(dim, self.inner_dim)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    if self.context_dim is not None:
         | 
| 283 | 
            +
                        self.to_k_c = nn.Linear(context_dim, self.inner_dim)
         | 
| 284 | 
            +
                        self.to_v_c = nn.Linear(context_dim, self.inner_dim)
         | 
| 285 | 
            +
                        if self.context_pre_only is not None:
         | 
| 286 | 
            +
                            self.to_q_c = nn.Linear(context_dim, self.inner_dim)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    self.to_out = nn.ModuleList([])
         | 
| 289 | 
            +
                    self.to_out.append(nn.Linear(self.inner_dim, dim))
         | 
| 290 | 
            +
                    self.to_out.append(nn.Dropout(dropout))
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    if self.context_pre_only is not None and not self.context_pre_only:
         | 
| 293 | 
            +
                        self.to_out_c = nn.Linear(self.inner_dim, dim)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                def forward(
         | 
| 296 | 
            +
                    self,
         | 
| 297 | 
            +
                    x: float['b n d'], # noised input x
         | 
| 298 | 
            +
                    c: float['b n d'] = None,  # context c
         | 
| 299 | 
            +
                    mask: bool['b n'] | None = None,
         | 
| 300 | 
            +
                    rope = None,  # rotary position embedding for x
         | 
| 301 | 
            +
                    c_rope = None,  # rotary position embedding for c
         | 
| 302 | 
            +
                ) -> torch.Tensor:
         | 
| 303 | 
            +
                    if c is not None:
         | 
| 304 | 
            +
                        return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
         | 
| 305 | 
            +
                    else:
         | 
| 306 | 
            +
                        return self.processor(self, x, mask = mask, rope = rope)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
             | 
| 309 | 
            +
            # Attention processor
         | 
| 310 | 
            +
             | 
| 311 | 
            +
            class AttnProcessor:
         | 
| 312 | 
            +
                def __init__(self):
         | 
| 313 | 
            +
                    pass
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                def __call__(
         | 
| 316 | 
            +
                    self,
         | 
| 317 | 
            +
                    attn: Attention,
         | 
| 318 | 
            +
                    x: float['b n d'], # noised input x
         | 
| 319 | 
            +
                    mask: bool['b n'] | None = None,
         | 
| 320 | 
            +
                    rope = None,  # rotary position embedding
         | 
| 321 | 
            +
                ) -> torch.FloatTensor:
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    batch_size = x.shape[0]
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    # `sample` projections.
         | 
| 326 | 
            +
                    query = attn.to_q(x)
         | 
| 327 | 
            +
                    key = attn.to_k(x)
         | 
| 328 | 
            +
                    value = attn.to_v(x)
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    # apply rotary position embedding
         | 
| 331 | 
            +
                    if rope is not None:
         | 
| 332 | 
            +
                        freqs, xpos_scale = rope
         | 
| 333 | 
            +
                        q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                        query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
         | 
| 336 | 
            +
                        key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    # attention
         | 
| 339 | 
            +
                    inner_dim = key.shape[-1]
         | 
| 340 | 
            +
                    head_dim = inner_dim // attn.heads
         | 
| 341 | 
            +
                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 342 | 
            +
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 343 | 
            +
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    # mask. e.g. inference got a batch with different target durations, mask out the padding
         | 
| 346 | 
            +
                    if mask is not None:
         | 
| 347 | 
            +
                        attn_mask = mask
         | 
| 348 | 
            +
                        attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
         | 
| 349 | 
            +
                        attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
         | 
| 350 | 
            +
                    else:
         | 
| 351 | 
            +
                        attn_mask = None
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
         | 
| 354 | 
            +
                    x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         | 
| 355 | 
            +
                    x = x.to(query.dtype)
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                    # linear proj
         | 
| 358 | 
            +
                    x = attn.to_out[0](x)
         | 
| 359 | 
            +
                    # dropout
         | 
| 360 | 
            +
                    x = attn.to_out[1](x)
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    if mask is not None:
         | 
| 363 | 
            +
                        mask = rearrange(mask, 'b n -> b n 1')
         | 
| 364 | 
            +
                        x = x.masked_fill(~mask, 0.)
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    return x
         | 
| 367 | 
            +
                
         | 
| 368 | 
            +
             | 
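As an illustrative aside (not part of this file), this is how the (b, n) boolean padding mask handled by AttnProcessor above broadcasts to (b, heads, n, n) before F.scaled_dot_product_attention; the tensor shapes are made up for the demo.

import torch
import torch.nn.functional as F
from einops import rearrange

b, heads, n, head_dim = 2, 8, 6, 64
query = key = value = torch.randn(b, heads, n, head_dim)

mask = torch.ones(b, n, dtype=torch.bool)
mask[1, 4:] = False                                        # pretend the second sample is padded after 4 frames

attn_mask = rearrange(mask, 'b n -> b 1 1 n').expand(b, heads, n, n)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
assert out.shape == (b, heads, n, head_dim)                # padded key positions are ignored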
| 369 | 
            +
            # Joint Attention processor for MM-DiT
         | 
| 370 | 
            +
            # modified from diffusers/src/diffusers/models/attention_processor.py
         | 
| 371 | 
            +
             | 
| 372 | 
            +
            class JointAttnProcessor:
         | 
| 373 | 
            +
                def __init__(self):
         | 
| 374 | 
            +
                    pass
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                def __call__(
         | 
| 377 | 
            +
                    self,
         | 
| 378 | 
            +
                    attn: Attention,
         | 
| 379 | 
            +
                    x: float['b n d'], # noised input x
         | 
| 380 | 
            +
                    c: float['b nt d'] = None,  # context c, here text
         | 
| 381 | 
            +
                    mask: bool['b n'] | None = None,
         | 
| 382 | 
            +
                    rope = None,  # rotary position embedding for x
         | 
| 383 | 
            +
                    c_rope = None,  # rotary position embedding for c
         | 
| 384 | 
            +
                ) -> torch.FloatTensor:
         | 
| 385 | 
            +
                    residual = x
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    batch_size = c.shape[0]
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                    # `sample` projections.
         | 
| 390 | 
            +
                    query = attn.to_q(x)
         | 
| 391 | 
            +
                    key = attn.to_k(x)
         | 
| 392 | 
            +
                    value = attn.to_v(x)
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                    # `context` projections.
         | 
| 395 | 
            +
                    c_query = attn.to_q_c(c)
         | 
| 396 | 
            +
                    c_key = attn.to_k_c(c)
         | 
| 397 | 
            +
                    c_value = attn.to_v_c(c)
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    # apply rope for context and noised input independently
         | 
| 400 | 
            +
                    if rope is not None:
         | 
| 401 | 
            +
                        freqs, xpos_scale = rope
         | 
| 402 | 
            +
                        q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
         | 
| 403 | 
            +
                        query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
         | 
| 404 | 
            +
                        key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
         | 
| 405 | 
            +
                    if c_rope is not None:
         | 
| 406 | 
            +
                        freqs, xpos_scale = c_rope
         | 
| 407 | 
            +
                        q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
         | 
| 408 | 
            +
                        c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
         | 
| 409 | 
            +
                        c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    # attention
         | 
| 412 | 
            +
                    query = torch.cat([query, c_query], dim=1)
         | 
| 413 | 
            +
                    key = torch.cat([key, c_key], dim=1)
         | 
| 414 | 
            +
                    value = torch.cat([value, c_value], dim=1)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                    inner_dim = key.shape[-1]
         | 
| 417 | 
            +
                    head_dim = inner_dim // attn.heads
         | 
| 418 | 
            +
                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 419 | 
            +
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 420 | 
            +
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    # mask. e.g. inference got a batch with different target durations, mask out the padding
         | 
| 423 | 
            +
                    if mask is not None:
         | 
| 424 | 
            +
                        attn_mask = F.pad(mask, (0, c.shape[1]), value = True)  # no mask for c (text)
         | 
| 425 | 
            +
                        attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
         | 
| 426 | 
            +
                        attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
         | 
| 427 | 
            +
                    else:
         | 
| 428 | 
            +
                        attn_mask = None
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                    x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
         | 
| 431 | 
            +
                    x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         | 
| 432 | 
            +
                    x = x.to(query.dtype)
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                    # Split the attention outputs.
         | 
| 435 | 
            +
                    x, c = (
         | 
| 436 | 
            +
                        x[:, :residual.shape[1]],
         | 
| 437 | 
            +
                        x[:, residual.shape[1]:],
         | 
| 438 | 
            +
                    )
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    # linear proj
         | 
| 441 | 
            +
                    x = attn.to_out[0](x)
         | 
| 442 | 
            +
                    # dropout
         | 
| 443 | 
            +
                    x = attn.to_out[1](x)
         | 
| 444 | 
            +
                    if not attn.context_pre_only:
         | 
| 445 | 
            +
                        c = attn.to_out_c(c)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                    if mask is not None:
         | 
| 448 | 
            +
                        mask = rearrange(mask, 'b n -> b n 1')
         | 
| 449 | 
            +
                        x = x.masked_fill(~mask, 0.)
         | 
| 450 | 
            +
                        # c = c.masked_fill(~mask, 0.)  # no mask for c (text)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                    return x, c
         | 
| 453 | 
            +
             | 
| 454 | 
            +
             | 
| 455 | 
            +
            # DiT Block
         | 
| 456 | 
            +
             | 
| 457 | 
            +
            class DiTBlock(nn.Module):
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
         | 
| 460 | 
            +
                    super().__init__()
         | 
| 461 | 
            +
                    
         | 
| 462 | 
            +
                    self.attn_norm = AdaLayerNormZero(dim)
         | 
| 463 | 
            +
                    self.attn = Attention(
         | 
| 464 | 
            +
                        processor = AttnProcessor(),
         | 
| 465 | 
            +
                        dim = dim,
         | 
| 466 | 
            +
                        heads = heads,
         | 
| 467 | 
            +
                        dim_head = dim_head,
         | 
| 468 | 
            +
                        dropout = dropout,
         | 
| 469 | 
            +
                        )
         | 
| 470 | 
            +
                    
         | 
| 471 | 
            +
                    self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         | 
| 472 | 
            +
                    self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
         | 
| 475 | 
            +
                    # pre-norm & modulation for attention input
         | 
| 476 | 
            +
                    norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                    # attention
         | 
| 479 | 
            +
                    attn_output = self.attn(x=norm, mask=mask, rope=rope)
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                    # process attention output for input x
         | 
| 482 | 
            +
                    x = x + gate_msa.unsqueeze(1) * attn_output
         | 
| 483 | 
            +
                    
         | 
| 484 | 
            +
                    norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         | 
| 485 | 
            +
                    ff_output = self.ff(norm)
         | 
| 486 | 
            +
                    x = x + gate_mlp.unsqueeze(1) * ff_output
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                    return x
         | 
| 489 | 
            +
             | 
| 490 | 
            +
             | 
| 491 | 
            +
            # MMDiT Block https://arxiv.org/abs/2403.03206
         | 
| 492 | 
            +
             | 
| 493 | 
            +
            class MMDiTBlock(nn.Module):
         | 
| 494 | 
            +
                r""" 
         | 
| 495 | 
            +
                modified from diffusers/src/diffusers/models/attention.py
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                notes.
         | 
| 498 | 
            +
                _c: context related. text, cond, etc. (left part in sd3 fig2.b)
         | 
| 499 | 
            +
                _x: noised input related. (right part)
         | 
| 500 | 
            +
                context_pre_only: the last layer only does pre-norm + modulation, since there is no ffn afterwards
         | 
| 501 | 
            +
                """
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
         | 
| 504 | 
            +
                    super().__init__()
         | 
| 505 | 
            +
             | 
| 506 | 
            +
                    self.context_pre_only = context_pre_only
         | 
| 507 | 
            +
                    
         | 
| 508 | 
            +
                    self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
         | 
| 509 | 
            +
                    self.attn_norm_x = AdaLayerNormZero(dim)
         | 
| 510 | 
            +
                    self.attn = Attention(
         | 
| 511 | 
            +
                        processor = JointAttnProcessor(),
         | 
| 512 | 
            +
                        dim = dim,
         | 
| 513 | 
            +
                        heads = heads,
         | 
| 514 | 
            +
                        dim_head = dim_head,
         | 
| 515 | 
            +
                        dropout = dropout,
         | 
| 516 | 
            +
                        context_dim = dim,
         | 
| 517 | 
            +
                        context_pre_only = context_pre_only,
         | 
| 518 | 
            +
                        )
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                    if not context_pre_only:
         | 
| 521 | 
            +
                        self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         | 
| 522 | 
            +
                        self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
         | 
| 523 | 
            +
                    else:
         | 
| 524 | 
            +
                        self.ff_norm_c = None
         | 
| 525 | 
            +
                        self.ff_c = None
         | 
| 526 | 
            +
                    self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         | 
| 527 | 
            +
                    self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
         | 
| 530 | 
            +
                    # pre-norm & modulation for attention input
         | 
| 531 | 
            +
                    if self.context_pre_only:
         | 
| 532 | 
            +
                        norm_c = self.attn_norm_c(c, t)
         | 
| 533 | 
            +
                    else:
         | 
| 534 | 
            +
                        norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
         | 
| 535 | 
            +
                    norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    # attention
         | 
| 538 | 
            +
                    x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                    # process attention output for context c
         | 
| 541 | 
            +
                    if self.context_pre_only:
         | 
| 542 | 
            +
                        c = None
         | 
| 543 | 
            +
                    else: # if not last layer
         | 
| 544 | 
            +
                        c = c + c_gate_msa.unsqueeze(1) * c_attn_output
         | 
| 545 | 
            +
             | 
| 546 | 
            +
                        norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
         | 
| 547 | 
            +
                        c_ff_output = self.ff_c(norm_c)
         | 
| 548 | 
            +
                        c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    # process attention output for input x
         | 
| 551 | 
            +
                    x = x + x_gate_msa.unsqueeze(1) * x_attn_output
         | 
| 552 | 
            +
                    
         | 
| 553 | 
            +
                    norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
         | 
| 554 | 
            +
                    x_ff_output = self.ff_x(norm_x)
         | 
| 555 | 
            +
                    x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                    return c, x
         | 
| 558 | 
            +
             | 
| 559 | 
            +
             | 
| 560 | 
            +
            # time step conditioning embedding
         | 
| 561 | 
            +
             | 
| 562 | 
            +
            class TimestepEmbedding(nn.Module):
         | 
| 563 | 
            +
                def __init__(self, dim, freq_embed_dim=256):
         | 
| 564 | 
            +
                    super().__init__()
         | 
| 565 | 
            +
                    self.time_embed = SinusPositionEmbedding(freq_embed_dim)
         | 
| 566 | 
            +
                    self.time_mlp = nn.Sequential(
         | 
| 567 | 
            +
                        nn.Linear(freq_embed_dim, dim),
         | 
| 568 | 
            +
                        nn.SiLU(),
         | 
| 569 | 
            +
                        nn.Linear(dim, dim)
         | 
| 570 | 
            +
                    )
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                def forward(self, timestep: float['b']):
         | 
| 573 | 
            +
                    time_hidden = self.time_embed(timestep)
         | 
| 574 | 
            +
                    time = self.time_mlp(time_hidden)  # b d
         | 
| 575 | 
            +
                    return time
         | 
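For orientation, a minimal usage sketch of the blocks defined above (not part of the upload; the import path assumes the repo root is on PYTHONPATH, and all shapes are illustrative only):

import torch
from model.modules import DiTBlock, TimestepEmbedding

dim = 512
block = DiTBlock(dim=dim, heads=8, dim_head=64)
time_embed = TimestepEmbedding(dim)

x = torch.randn(2, 100, dim)                    # (batch, seq_len, dim) noised mel features
t = time_embed(torch.rand(2))                   # (batch, dim) time conditioning
mask = torch.ones(2, 100, dtype=torch.bool)     # all frames valid

out = block(x, t, mask=mask)                    # rope left at its default of None
print(out.shape)                                # torch.Size([2, 100, 512])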
    	
        model/trainer.py
    ADDED
    
    | @@ -0,0 +1,250 @@ | |
|  | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import gc
         | 
| 5 | 
            +
            from tqdm import tqdm
         | 
| 6 | 
            +
            import wandb
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            from torch.optim import AdamW
         | 
| 10 | 
            +
            from torch.utils.data import DataLoader, Dataset, SequentialSampler
         | 
| 11 | 
            +
            from torch.optim.lr_scheduler import LinearLR, SequentialLR
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from einops import rearrange
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from accelerate import Accelerator
         | 
| 16 | 
            +
            from accelerate.utils import DistributedDataParallelKwargs
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from ema_pytorch import EMA
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from model import CFM
         | 
| 21 | 
            +
            from model.utils import exists, default
         | 
| 22 | 
            +
            from model.dataset import DynamicBatchSampler, collate_fn
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            # trainer
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            class Trainer:
         | 
| 28 | 
            +
                def __init__(
         | 
| 29 | 
            +
                    self,
         | 
| 30 | 
            +
                    model: CFM,
         | 
| 31 | 
            +
                    epochs,
         | 
| 32 | 
            +
                    learning_rate,
         | 
| 33 | 
            +
                    num_warmup_updates = 20000,
         | 
| 34 | 
            +
                    save_per_updates = 1000, 
         | 
| 35 | 
            +
                    checkpoint_path = None,
         | 
| 36 | 
            +
                    batch_size = 32, 
         | 
| 37 | 
            +
                    batch_size_type: str = "sample",
         | 
| 38 | 
            +
                    max_samples = 32,
         | 
| 39 | 
            +
                    grad_accumulation_steps = 1,
         | 
| 40 | 
            +
                    max_grad_norm = 1.0,
         | 
| 41 | 
            +
                    noise_scheduler: str | None = None,
         | 
| 42 | 
            +
                    duration_predictor: torch.nn.Module | None = None,
         | 
| 43 | 
            +
                    wandb_project = "test_e2-tts",
         | 
| 44 | 
            +
                    wandb_run_name = "test_run",
         | 
| 45 | 
            +
                    wandb_resume_id: str = None,
         | 
| 46 | 
            +
                    last_per_steps = None,
         | 
| 47 | 
            +
                    accelerate_kwargs: dict = dict(),
         | 
| 48 | 
            +
                    ema_kwargs: dict = dict()
         | 
| 49 | 
            +
                ):
         | 
| 50 | 
            +
                    
         | 
| 51 | 
            +
                    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    self.accelerator = Accelerator(
         | 
| 54 | 
            +
                        log_with = "wandb",
         | 
| 55 | 
            +
                        kwargs_handlers = [ddp_kwargs],
         | 
| 56 | 
            +
                        gradient_accumulation_steps = grad_accumulation_steps,
         | 
| 57 | 
            +
                        **accelerate_kwargs
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    if exists(wandb_resume_id):
         | 
| 61 | 
            +
                        init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
         | 
| 62 | 
            +
                    else:
         | 
| 63 | 
            +
                        init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
         | 
| 64 | 
            +
                    self.accelerator.init_trackers(
         | 
| 65 | 
            +
                        project_name = wandb_project, 
         | 
| 66 | 
            +
                        init_kwargs=init_kwargs,
         | 
| 67 | 
            +
                        config={"epochs": epochs,
         | 
| 68 | 
            +
                                "learning_rate": learning_rate,
         | 
| 69 | 
            +
                                "num_warmup_updates": num_warmup_updates, 
         | 
| 70 | 
            +
                                "batch_size": batch_size,
         | 
| 71 | 
            +
                                "batch_size_type": batch_size_type,
         | 
| 72 | 
            +
                                "max_samples": max_samples,
         | 
| 73 | 
            +
                                "grad_accumulation_steps": grad_accumulation_steps,
         | 
| 74 | 
            +
                                "max_grad_norm": max_grad_norm,
         | 
| 75 | 
            +
                                "gpus": self.accelerator.num_processes,
         | 
| 76 | 
            +
                                "noise_scheduler": noise_scheduler}
         | 
| 77 | 
            +
                        )
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    self.model = model
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    if self.is_main:
         | 
| 82 | 
            +
                        self.ema_model = EMA(
         | 
| 83 | 
            +
                            model,
         | 
| 84 | 
            +
                            include_online_model = False,
         | 
| 85 | 
            +
                            **ema_kwargs
         | 
| 86 | 
            +
                        )
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                        self.ema_model.to(self.accelerator.device)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    self.epochs = epochs
         | 
| 91 | 
            +
                    self.num_warmup_updates = num_warmup_updates
         | 
| 92 | 
            +
                    self.save_per_updates = save_per_updates
         | 
| 93 | 
            +
                    self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
         | 
| 94 | 
            +
                    self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    self.batch_size = batch_size
         | 
| 97 | 
            +
                    self.batch_size_type = batch_size_type
         | 
| 98 | 
            +
                    self.max_samples = max_samples
         | 
| 99 | 
            +
                    self.grad_accumulation_steps = grad_accumulation_steps
         | 
| 100 | 
            +
                    self.max_grad_norm = max_grad_norm
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    self.noise_scheduler = noise_scheduler
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self.duration_predictor = duration_predictor
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         | 
| 107 | 
            +
                    self.model, self.optimizer = self.accelerator.prepare(
         | 
| 108 | 
            +
                        self.model, self.optimizer
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                @property
         | 
| 112 | 
            +
                def is_main(self):
         | 
| 113 | 
            +
                    return self.accelerator.is_main_process
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def save_checkpoint(self, step, last=False):
         | 
| 116 | 
            +
                    self.accelerator.wait_for_everyone()
         | 
| 117 | 
            +
                    if self.is_main:
         | 
| 118 | 
            +
                        checkpoint = dict(
         | 
| 119 | 
            +
                            model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
         | 
| 120 | 
            +
                            optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
         | 
| 121 | 
            +
                            ema_model_state_dict = self.ema_model.state_dict(),
         | 
| 122 | 
            +
                            scheduler_state_dict = self.scheduler.state_dict(),
         | 
| 123 | 
            +
                            step = step
         | 
| 124 | 
            +
                        )
         | 
| 125 | 
            +
                        if not os.path.exists(self.checkpoint_path):
         | 
| 126 | 
            +
                            os.makedirs(self.checkpoint_path)
         | 
| 127 | 
            +
                        if last:
         | 
| 128 | 
            +
                            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
         | 
| 129 | 
            +
                            print(f"Saved last checkpoint at step {step}")
         | 
| 130 | 
            +
                        else:
         | 
| 131 | 
            +
                            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def load_checkpoint(self):
         | 
| 134 | 
            +
                    if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
         | 
| 135 | 
            +
                        return 0
         | 
| 136 | 
            +
                    
         | 
| 137 | 
            +
                    self.accelerator.wait_for_everyone()
         | 
| 138 | 
            +
                    if "model_last.pt" in os.listdir(self.checkpoint_path):
         | 
| 139 | 
            +
                        latest_checkpoint = "model_last.pt"
         | 
| 140 | 
            +
                    else:
         | 
| 141 | 
            +
                        latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
         | 
| 142 | 
            +
                    # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         | 
| 143 | 
            +
                    checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    if self.is_main:
         | 
| 146 | 
            +
                        self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    if 'step' in checkpoint:
         | 
| 149 | 
            +
                        self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
         | 
| 150 | 
            +
                        self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
         | 
| 151 | 
            +
                        if self.scheduler:
         | 
| 152 | 
            +
                            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         | 
| 153 | 
            +
                        step = checkpoint['step']
         | 
| 154 | 
            +
                    else:
         | 
| 155 | 
            +
                        checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
         | 
| 156 | 
            +
                        self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
         | 
| 157 | 
            +
                        step = 0
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    del checkpoint; gc.collect()
         | 
| 160 | 
            +
                    return step
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         | 
| 163 | 
            +
                    
         | 
| 164 | 
            +
                    if exists(resumable_with_seed):
         | 
| 165 | 
            +
                        generator = torch.Generator()
         | 
| 166 | 
            +
                        generator.manual_seed(resumable_with_seed)
         | 
| 167 | 
            +
                    else: 
         | 
| 168 | 
            +
                        generator = None
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    if self.batch_size_type == "sample":
         | 
| 171 | 
            +
                        train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
         | 
| 172 | 
            +
                                                      batch_size=self.batch_size, shuffle=True, generator=generator)
         | 
| 173 | 
            +
                    elif self.batch_size_type == "frame":
         | 
| 174 | 
            +
                        self.accelerator.even_batches = False
         | 
| 175 | 
            +
                        sampler = SequentialSampler(train_dataset)
         | 
| 176 | 
            +
                        batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
         | 
| 177 | 
            +
                        train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
         | 
| 178 | 
            +
                                                      batch_sampler=batch_sampler)
         | 
| 179 | 
            +
                    else:
         | 
| 180 | 
            +
                        raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
         | 
| 181 | 
            +
                    
         | 
| 182 | 
            +
                    #  accelerator.prepare() dispatches batches to devices;
         | 
| 183 | 
            +
                    #  so the dataloader length computed above must account for the number of devices
         | 
| 184 | 
            +
                    warmup_steps = self.num_warmup_updates * self.accelerator.num_processes  # consider a fixed warmup steps while using accelerate multi-gpu ddp
         | 
| 185 | 
            +
                                                                                             # otherwise by default with split_batches=False, warmup steps change with num_processes
         | 
| 186 | 
            +
                    total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
         | 
| 187 | 
            +
                    decay_steps = total_steps - warmup_steps
         | 
| 188 | 
            +
                    warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
         | 
| 189 | 
            +
                    decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
         | 
| 190 | 
            +
                    self.scheduler = SequentialLR(self.optimizer, 
         | 
| 191 | 
            +
                                                  schedulers=[warmup_scheduler, decay_scheduler],
         | 
| 192 | 
            +
                                                  milestones=[warmup_steps])
         | 
| 193 | 
            +
                    train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler)  # actual steps = 1 gpu steps / gpus
         | 
| 194 | 
            +
                    start_step = self.load_checkpoint()
         | 
| 195 | 
            +
                    global_step = start_step
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    if exists(resumable_with_seed):
         | 
| 198 | 
            +
                        orig_epoch_step = len(train_dataloader)
         | 
| 199 | 
            +
                        skipped_epoch = int(start_step // orig_epoch_step)
         | 
| 200 | 
            +
                        skipped_batch = start_step % orig_epoch_step
         | 
| 201 | 
            +
                        skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
         | 
| 202 | 
            +
                    else:
         | 
| 203 | 
            +
                        skipped_epoch = 0
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    for epoch in range(skipped_epoch, self.epochs):
         | 
| 206 | 
            +
                        self.model.train()
         | 
| 207 | 
            +
                        if exists(resumable_with_seed) and epoch == skipped_epoch:
         | 
| 208 | 
            +
                            progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, 
         | 
| 209 | 
            +
                                                initial=skipped_batch, total=orig_epoch_step)
         | 
| 210 | 
            +
                        else:
         | 
| 211 | 
            +
                            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                        for batch in progress_bar:
         | 
| 214 | 
            +
                            with self.accelerator.accumulate(self.model):
         | 
| 215 | 
            +
                                text_inputs = batch['text']
         | 
| 216 | 
            +
                                mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
         | 
| 217 | 
            +
                                mel_lengths = batch["mel_lengths"]
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                                # TODO. add duration predictor training
         | 
| 220 | 
            +
                                if self.duration_predictor is not None and self.accelerator.is_local_main_process:
         | 
| 221 | 
            +
                                    dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
         | 
| 222 | 
            +
                                    self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                                loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
         | 
| 225 | 
            +
                                self.accelerator.backward(loss)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                                if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
         | 
| 228 | 
            +
                                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                                self.optimizer.step()
         | 
| 231 | 
            +
                                self.scheduler.step()
         | 
| 232 | 
            +
                                self.optimizer.zero_grad()
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                            if self.is_main:
         | 
| 235 | 
            +
                                self.ema_model.update()
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                            global_step += 1
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                            if self.accelerator.is_local_main_process:
         | 
| 240 | 
            +
                                self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
         | 
| 241 | 
            +
                            
         | 
| 242 | 
            +
                            progress_bar.set_postfix(step=str(global_step), loss=loss.item())
         | 
| 243 | 
            +
                            
         | 
| 244 | 
            +
                            if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
         | 
| 245 | 
            +
                                self.save_checkpoint(global_step)
         | 
| 246 | 
            +
                            
         | 
| 247 | 
            +
                            if global_step % self.last_per_steps == 0:
         | 
| 248 | 
            +
                                self.save_checkpoint(global_step, last=True)
         | 
| 249 | 
            +
                    
         | 
| 250 | 
            +
                    self.accelerator.end_training()
         | 
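As a self-contained aside (not part of the upload), the warmup-then-linear-decay schedule that train() assembles from LinearLR and SequentialLR behaves as sketched below; the step counts and learning rate are made-up toy values:

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=7.5e-5)

warmup_steps, total_steps = 100, 1000            # toy numbers
decay_steps = total_steps - warmup_steps
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(total_steps):
    optimizer.step()                             # no real gradients needed for the demo
    scheduler.step()
    if step in (0, warmup_steps - 1, total_steps - 1):
        print(step, scheduler.get_last_lr()[0])  # ramps up to 7.5e-5, then decays back down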
    	
        model/utils.py
    ADDED
    
    | @@ -0,0 +1,574 @@ | |
|  | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import re
         | 
| 5 | 
            +
            import math
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
            import string
         | 
| 8 | 
            +
            from tqdm import tqdm
         | 
| 9 | 
            +
            from collections import defaultdict
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import matplotlib
         | 
| 12 | 
            +
            matplotlib.use("Agg")
         | 
| 13 | 
            +
            import matplotlib.pylab as plt
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import torch
         | 
| 16 | 
            +
            import torch.nn.functional as F
         | 
| 17 | 
            +
            from torch.nn.utils.rnn import pad_sequence
         | 
| 18 | 
            +
            import torchaudio
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            import einx
         | 
| 21 | 
            +
            from einops import rearrange, reduce
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            import jieba
         | 
| 24 | 
            +
            from pypinyin import lazy_pinyin, Style
         | 
| 25 | 
            +
            import zhconv
         | 
| 26 | 
            +
            from zhon.hanzi import punctuation
         | 
| 27 | 
            +
            from jiwer import compute_measures
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            from funasr import AutoModel
         | 
| 30 | 
            +
            from faster_whisper import WhisperModel
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            from model.ecapa_tdnn import ECAPA_TDNN_SMALL
         | 
| 33 | 
            +
            from model.modules import MelSpec
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            # seed everything
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            def seed_everything(seed = 0):
         | 
| 39 | 
            +
                random.seed(seed)
         | 
| 40 | 
            +
                os.environ['PYTHONHASHSEED'] = str(seed)
         | 
| 41 | 
            +
                torch.manual_seed(seed)
         | 
| 42 | 
            +
                torch.cuda.manual_seed(seed)
         | 
| 43 | 
            +
                torch.cuda.manual_seed_all(seed)
         | 
| 44 | 
            +
                torch.backends.cudnn.deterministic = True
         | 
| 45 | 
            +
                torch.backends.cudnn.benchmark = False
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            # helpers
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            def exists(v):
         | 
| 50 | 
            +
                return v is not None
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            def default(v, d):
         | 
| 53 | 
            +
                return v if exists(v) else d
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            # tensor helpers
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            def lens_to_mask(
         | 
| 58 | 
            +
                t: int['b'],
         | 
| 59 | 
            +
                length: int | None = None
         | 
| 60 | 
            +
            ) -> bool['b n']:
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                if not exists(length):
         | 
| 63 | 
            +
                    length = t.amax()
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                seq = torch.arange(length, device = t.device)
         | 
| 66 | 
            +
                return einx.less('n, b -> b n', seq, t)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            def mask_from_start_end_indices(
         | 
| 69 | 
            +
                seq_len: int['b'],
         | 
| 70 | 
            +
                start: int['b'],
         | 
| 71 | 
            +
                end: int['b']
         | 
| 72 | 
            +
            ):
         | 
| 73 | 
            +
                max_seq_len = seq_len.max().item()  
         | 
| 74 | 
            +
                seq = torch.arange(max_seq_len, device = start.device).long()
         | 
| 75 | 
            +
                return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            def mask_from_frac_lengths(
         | 
| 78 | 
            +
                seq_len: int['b'],
         | 
| 79 | 
            +
                frac_lengths: float['b']
         | 
| 80 | 
            +
            ):
         | 
| 81 | 
            +
                lengths = (frac_lengths * seq_len).long()
         | 
| 82 | 
            +
                max_start = seq_len - lengths
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                rand = torch.rand_like(frac_lengths)
         | 
| 85 | 
            +
                start = (max_start * rand).long().clamp(min = 0)
         | 
| 86 | 
            +
                end = start + lengths
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                return mask_from_start_end_indices(seq_len, start, end)
         | 
| 89 | 
            +
             | 
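             # Usage sketch (illustrative values):
             #   >>> lens_to_mask(torch.tensor([2, 3]))
             #   tensor([[ True,  True, False],
             #           [ True,  True,  True]])
             # mask_from_frac_lengths then carves one random contiguous span per sequence,
             # covering the requested fraction of its length.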
| 90 | 
            +
            def maybe_masked_mean(
         | 
| 91 | 
            +
                t: float['b n d'],
         | 
| 92 | 
            +
                mask: bool['b n'] = None
         | 
| 93 | 
            +
            ) -> float['b d']:
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                if not exists(mask):
         | 
| 96 | 
            +
                    return t.mean(dim = 1)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
         | 
| 99 | 
            +
                num = reduce(t, 'b n d -> b d', 'sum')
         | 
| 100 | 
            +
                den = reduce(mask.float(), 'b n -> b', 'sum')
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
             # simple utf-8 byte tokenizer, since the paper went character-based
         | 
| 106 | 
            +
            def list_str_to_tensor(
         | 
| 107 | 
            +
                text: list[str],
         | 
| 108 | 
            +
                padding_value = -1
         | 
| 109 | 
            +
            ) -> int['b nt']:
         | 
| 110 | 
            +
                list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text]  # ByT5 style
         | 
| 111 | 
            +
                text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
         | 
| 112 | 
            +
                return text
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            # char tokenizer, based on custom dataset's extracted .txt file
         | 
| 115 | 
            +
            def list_str_to_idx(
         | 
| 116 | 
            +
                text: list[str] | list[list[str]],
         | 
| 117 | 
            +
                vocab_char_map: dict[str, int],  # {char: idx}
         | 
| 118 | 
            +
                padding_value = -1
         | 
| 119 | 
            +
            ) -> int['b nt']:
         | 
| 120 | 
            +
                list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
         | 
| 121 | 
            +
                text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
         | 
| 122 | 
            +
                return text
         | 
| 123 | 
            +
             | 
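             # Quick check of the two tokenizers (illustrative):
             #   >>> list_str_to_tensor(["ab"])                          # ByT5-style raw utf-8 bytes
             #   tensor([[97, 98]])
             #   >>> list_str_to_idx(["ab"], {" ": 0, "a": 1, "b": 2})   # custom char vocab
             #   tensor([[1, 2]])
             # Characters missing from vocab_char_map fall back to index 0, hence the
             # space-at-index-0 assertion in get_tokenizer below.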
| 124 | 
            +
             | 
| 125 | 
            +
            # Get tokenizer
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
         | 
| 128 | 
            +
                ''' 
         | 
| 129 | 
            +
                 tokenizer   - "pinyin" does g2p for Chinese characters only, needs a .txt vocab_file
         | 
| 130 | 
            +
                            - "char" for char-wise tokenizer, need .txt vocab_file
         | 
| 131 | 
            +
                            - "byte" for utf-8 tokenizer
         | 
| 132 | 
            +
                 vocab_size  - with "pinyin": all available pinyin types, common alphabets (including accented ones) and symbols
         | 
| 133 | 
            +
                            - if use "char", derived from unfiltered character & symbol counts of custom dataset
         | 
| 134 | 
            +
                            - if use "byte", set to 256 (unicode byte range) 
         | 
| 135 | 
            +
                ''' 
         | 
| 136 | 
            +
                if tokenizer in ["pinyin", "char"]:
         | 
| 137 | 
            +
                    with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
         | 
| 138 | 
            +
                        vocab_char_map = {}
         | 
| 139 | 
            +
                        for i, char in enumerate(f):
         | 
| 140 | 
            +
                            vocab_char_map[char[:-1]] = i
         | 
| 141 | 
            +
                    vocab_size = len(vocab_char_map)
         | 
| 142 | 
            +
                     assert vocab_char_map[" "] == 0, "make sure space is at index 0 in vocab.txt, because index 0 is used for the unknown char"
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                elif tokenizer == "byte":
         | 
| 145 | 
            +
                    vocab_char_map = None
         | 
| 146 | 
            +
                    vocab_size = 256
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                return vocab_char_map, vocab_size
         | 
| 149 | 
            +
             | 
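             # Usage sketch ("Emilia_ZH_EN" mirrors the default dataset name in scripts/eval_infer_batch.py
             # and assumes data/Emilia_ZH_EN_pinyin/vocab.txt exists):
             #   >>> vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
             #   >>> vocab_char_map[" "]
             #   0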
| 150 | 
            +
             | 
| 151 | 
            +
            # convert char to pinyin
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            def convert_char_to_pinyin(text_list, polyphone = True):
         | 
| 154 | 
            +
                final_text_list = []
         | 
| 155 | 
            +
                god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"})  # in case librispeech (orig no-pc) test-clean
         | 
| 156 | 
            +
                custom_trans = str.maketrans({';': ','})  # add custom trans here, to address oov
         | 
| 157 | 
            +
                for text in text_list:
         | 
| 158 | 
            +
                    char_list = []
         | 
| 159 | 
            +
                    text = text.translate(god_knows_why_en_testset_contains_zh_quote)
         | 
| 160 | 
            +
                    text = text.translate(custom_trans)
         | 
| 161 | 
            +
                    for seg in jieba.cut(text):
         | 
| 162 | 
            +
                        seg_byte_len = len(bytes(seg, 'UTF-8'))
         | 
| 163 | 
            +
                        if seg_byte_len == len(seg):  # if pure alphabets and symbols
         | 
| 164 | 
            +
                            if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
         | 
| 165 | 
            +
                                char_list.append(" ")
         | 
| 166 | 
            +
                            char_list.extend(seg)
         | 
| 167 | 
            +
                        elif polyphone and seg_byte_len == 3 * len(seg):  # if pure chinese characters
         | 
| 168 | 
            +
                            seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
         | 
| 169 | 
            +
                            for c in seg:
         | 
| 170 | 
            +
                                if c not in "。,、;:?!《》【】—…":
         | 
| 171 | 
            +
                                    char_list.append(" ")
         | 
| 172 | 
            +
                                char_list.append(c)
         | 
| 173 | 
            +
                        else:  # if mixed chinese characters, alphabets and symbols
         | 
| 174 | 
            +
                            for c in seg:
         | 
| 175 | 
            +
                                if ord(c) < 256:
         | 
| 176 | 
            +
                                    char_list.extend(c)
         | 
| 177 | 
            +
                                else:
         | 
| 178 | 
            +
                                    if c not in "。,、;:?!《》【】—…":
         | 
| 179 | 
            +
                                        char_list.append(" ")
         | 
| 180 | 
            +
                                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
         | 
| 181 | 
            +
                                    else:  # if is zh punc
         | 
| 182 | 
            +
                                        char_list.append(c)
         | 
| 183 | 
            +
                    final_text_list.append(char_list)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                return final_text_list
         | 
| 186 | 
            +
             | 
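             # Illustrative behaviour: Chinese segments become tone-numbered pinyin with one
             # leading space per syllable, while ASCII text passes through character by character:
             #   >>> convert_char_to_pinyin(["世界"])
             #   [[' ', 'shi4', ' ', 'jie4']]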
| 187 | 
            +
             | 
| 188 | 
            +
            # save spectrogram
         | 
| 189 | 
            +
            def save_spectrogram(spectrogram, path):
         | 
| 190 | 
            +
                plt.figure(figsize=(12, 4))
         | 
| 191 | 
            +
                plt.imshow(spectrogram, origin='lower', aspect='auto')
         | 
| 192 | 
            +
                plt.colorbar()
         | 
| 193 | 
            +
                plt.savefig(path)
         | 
| 194 | 
            +
                plt.close()
         | 
| 195 | 
            +
             | 
| 196 | 
            +
             | 
| 197 | 
            +
            # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
         | 
| 198 | 
            +
            def get_seedtts_testset_metainfo(metalst):
         | 
| 199 | 
            +
                f = open(metalst); lines = f.readlines(); f.close()
         | 
| 200 | 
            +
                metainfo = []
         | 
| 201 | 
            +
                for line in lines:
         | 
| 202 | 
            +
                    if len(line.strip().split('|')) == 5:
         | 
| 203 | 
            +
                        utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
         | 
| 204 | 
            +
                    elif len(line.strip().split('|')) == 4:
         | 
| 205 | 
            +
                        utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
         | 
| 206 | 
            +
                        gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
         | 
| 207 | 
            +
                    if not os.path.isabs(prompt_wav):
         | 
| 208 | 
            +
                        prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
         | 
| 209 | 
            +
                    metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
         | 
| 210 | 
            +
                return metainfo
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
         | 
| 214 | 
            +
            def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
         | 
| 215 | 
            +
                f = open(metalst); lines = f.readlines(); f.close()
         | 
| 216 | 
            +
                metainfo = []
         | 
| 217 | 
            +
                for line in lines:
         | 
| 218 | 
            +
                    ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
         | 
| 221 | 
            +
                    ref_spk_id, ref_chaptr_id, _ =  ref_utt.split('-')
         | 
| 222 | 
            +
                    ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
         | 
| 225 | 
            +
                    gen_spk_id, gen_chaptr_id, _ =  gen_utt.split('-')
         | 
| 226 | 
            +
                    gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                return metainfo
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
             # pad a batch of mel spectrograms to the shared max length
         | 
| 234 | 
            +
            def padded_mel_batch(ref_mels):
         | 
| 235 | 
            +
                max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
         | 
| 236 | 
            +
                padded_ref_mels = []
         | 
| 237 | 
            +
                for mel in ref_mels:
         | 
| 238 | 
            +
                    padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
         | 
| 239 | 
            +
                    padded_ref_mels.append(padded_ref_mel)
         | 
| 240 | 
            +
                padded_ref_mels = torch.stack(padded_ref_mels)
         | 
| 241 | 
            +
                padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
         | 
| 242 | 
            +
                return padded_ref_mels
         | 
| 243 | 
            +
             | 
| 244 | 
            +
             | 
| 245 | 
            +
            # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
         | 
| 246 | 
            +
             | 
| 247 | 
            +
            def get_inference_prompt(
         | 
| 248 | 
            +
                metainfo, 
         | 
| 249 | 
            +
                speed = 1., tokenizer = "pinyin", polyphone = True, 
         | 
| 250 | 
            +
                target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
         | 
| 251 | 
            +
                use_truth_duration = False,
         | 
| 252 | 
            +
                infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
         | 
| 253 | 
            +
            ):
         | 
| 254 | 
            +
                prompts_all = []
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                min_tokens = min_secs * target_sample_rate // hop_length
         | 
| 257 | 
            +
                max_tokens = max_secs * target_sample_rate // hop_length
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                batch_accum = [0] * num_buckets
         | 
| 260 | 
            +
                utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
         | 
| 261 | 
            +
                    ([[] for _ in range(num_buckets)] for _ in range(6))
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    # Audio
         | 
| 268 | 
            +
                    ref_audio, ref_sr = torchaudio.load(prompt_wav)
         | 
| 269 | 
            +
                    ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
         | 
| 270 | 
            +
                    if ref_rms < target_rms:
         | 
| 271 | 
            +
                        ref_audio = ref_audio * target_rms / ref_rms
         | 
| 272 | 
            +
                    assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
         | 
| 273 | 
            +
                    if ref_sr != target_sample_rate:
         | 
| 274 | 
            +
                        resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
         | 
| 275 | 
            +
                        ref_audio = resampler(ref_audio)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    # Text
         | 
| 278 | 
            +
                    if len(prompt_text[-1].encode('utf-8')) == 1:
         | 
| 279 | 
            +
                        prompt_text = prompt_text + " "
         | 
| 280 | 
            +
                    text = [prompt_text + gt_text]
         | 
| 281 | 
            +
                    if tokenizer == "pinyin":
         | 
| 282 | 
            +
                        text_list = convert_char_to_pinyin(text, polyphone = polyphone)
         | 
| 283 | 
            +
                    else:
         | 
| 284 | 
            +
                        text_list = text
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    # Duration, mel frame length
         | 
| 287 | 
            +
                    ref_mel_len = ref_audio.shape[-1] // hop_length
         | 
| 288 | 
            +
                    if use_truth_duration:
         | 
| 289 | 
            +
                        gt_audio, gt_sr = torchaudio.load(gt_wav)
         | 
| 290 | 
            +
                        if gt_sr != target_sample_rate:
         | 
| 291 | 
            +
                            resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
         | 
| 292 | 
            +
                            gt_audio = resampler(gt_audio)
         | 
| 293 | 
            +
                        total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                        # # test vocoder resynthesis
         | 
| 296 | 
            +
                        # ref_audio = gt_audio
         | 
| 297 | 
            +
                    else:
         | 
| 298 | 
            +
                        zh_pause_punc = r"。,、;:?!"
         | 
| 299 | 
            +
                        ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
         | 
| 300 | 
            +
                        gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
         | 
| 301 | 
            +
                        total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # to mel spectrogram
         | 
| 304 | 
            +
                    ref_mel = mel_spectrogram(ref_audio)
         | 
| 305 | 
            +
                    ref_mel = rearrange(ref_mel, '1 d n -> d n')
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    # deal with batch
         | 
| 308 | 
            +
                    assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
         | 
| 309 | 
            +
                    assert min_tokens <= total_mel_len <= max_tokens, \
         | 
| 310 | 
            +
                        f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
         | 
| 311 | 
            +
                    bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    utts[bucket_i].append(utt)
         | 
| 314 | 
            +
                    ref_rms_list[bucket_i].append(ref_rms)
         | 
| 315 | 
            +
                    ref_mels[bucket_i].append(ref_mel)
         | 
| 316 | 
            +
                    ref_mel_lens[bucket_i].append(ref_mel_len)
         | 
| 317 | 
            +
                    total_mel_lens[bucket_i].append(total_mel_len)
         | 
| 318 | 
            +
                    final_text_list[bucket_i].extend(text_list)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    batch_accum[bucket_i] += total_mel_len
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    if batch_accum[bucket_i] >= infer_batch_size:
         | 
| 323 | 
            +
                        # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
         | 
| 324 | 
            +
                        prompts_all.append((
         | 
| 325 | 
            +
                            utts[bucket_i], 
         | 
| 326 | 
            +
                            ref_rms_list[bucket_i], 
         | 
| 327 | 
            +
                            padded_mel_batch(ref_mels[bucket_i]), 
         | 
| 328 | 
            +
                            ref_mel_lens[bucket_i], 
         | 
| 329 | 
            +
                            total_mel_lens[bucket_i], 
         | 
| 330 | 
            +
                            final_text_list[bucket_i]
         | 
| 331 | 
            +
                        ))
         | 
| 332 | 
            +
                        batch_accum[bucket_i] = 0
         | 
| 333 | 
            +
                        utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                 # flush residual (not-yet-full) buckets
         | 
| 336 | 
            +
                for bucket_i, bucket_frames in enumerate(batch_accum):
         | 
| 337 | 
            +
                    if bucket_frames > 0:
         | 
| 338 | 
            +
                        prompts_all.append((
         | 
| 339 | 
            +
                            utts[bucket_i], 
         | 
| 340 | 
            +
                            ref_rms_list[bucket_i], 
         | 
| 341 | 
            +
                            padded_mel_batch(ref_mels[bucket_i]), 
         | 
| 342 | 
            +
                            ref_mel_lens[bucket_i], 
         | 
| 343 | 
            +
                            total_mel_lens[bucket_i], 
         | 
| 344 | 
            +
                            final_text_list[bucket_i]
         | 
| 345 | 
            +
                        ))
         | 
| 346 | 
            +
                 # shuffle so the last workers are not left with only the short, easy samples
         | 
| 347 | 
            +
                random.seed(666)
         | 
| 348 | 
            +
                random.shuffle(prompts_all)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                return prompts_all
         | 
| 351 | 
            +
             | 
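             # Each element of prompts_all is one padded batch:
             #   (utts, ref_rms_list, ref_mels [b, n, d], ref_mel_lens, total_mel_lens, final_text_list)
             # With the defaults above, bucket bounds are 3 s * 24000 // 256 = 281 frames
             # up to 40 s * 24000 // 256 = 3750 frames.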
| 352 | 
            +
             | 
| 353 | 
            +
            # get wav_res_ref_text of seed-tts test metalst
         | 
| 354 | 
            +
            # https://github.com/BytedanceSpeech/seed-tts-eval
         | 
| 355 | 
            +
             | 
| 356 | 
            +
            def get_seed_tts_test(metalst, gen_wav_dir, gpus):
         | 
| 357 | 
            +
                f = open(metalst)
         | 
| 358 | 
            +
                lines = f.readlines()
         | 
| 359 | 
            +
                f.close()
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                test_set_ = []
         | 
| 362 | 
            +
                for line in tqdm(lines):
         | 
| 363 | 
            +
                    if len(line.strip().split('|')) == 5:
         | 
| 364 | 
            +
                        utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
         | 
| 365 | 
            +
                    elif len(line.strip().split('|')) == 4:
         | 
| 366 | 
            +
                        utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                    if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
         | 
| 369 | 
            +
                        continue
         | 
| 370 | 
            +
                    gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
         | 
| 371 | 
            +
                    if not os.path.isabs(prompt_wav):
         | 
| 372 | 
            +
                        prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    test_set_.append((gen_wav, prompt_wav, gt_text))
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                num_jobs = len(gpus)
         | 
| 377 | 
            +
                if num_jobs == 1:
         | 
| 378 | 
            +
                    return [(gpus[0], test_set_)]
         | 
| 379 | 
            +
                
         | 
| 380 | 
            +
                wav_per_job = len(test_set_) // num_jobs + 1
         | 
| 381 | 
            +
                test_set = []
         | 
| 382 | 
            +
                for i in range(num_jobs):
         | 
| 383 | 
            +
                    test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                return test_set
         | 
| 386 | 
            +
             | 
| 387 | 
            +
             | 
| 388 | 
            +
            # get librispeech test-clean cross sentence test
         | 
| 389 | 
            +
             | 
| 390 | 
            +
            def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
         | 
| 391 | 
            +
                f = open(metalst)
         | 
| 392 | 
            +
                lines = f.readlines()
         | 
| 393 | 
            +
                f.close()
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                test_set_ = []
         | 
| 396 | 
            +
                for line in tqdm(lines):
         | 
| 397 | 
            +
                    ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    if eval_ground_truth:
         | 
| 400 | 
            +
                        gen_spk_id, gen_chaptr_id, _ =  gen_utt.split('-')
         | 
| 401 | 
            +
                        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
         | 
| 402 | 
            +
                    else:
         | 
| 403 | 
            +
                        if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
         | 
| 404 | 
            +
                            raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
         | 
| 405 | 
            +
                        gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                    ref_spk_id, ref_chaptr_id, _ =  ref_utt.split('-')
         | 
| 408 | 
            +
                    ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    test_set_.append((gen_wav, ref_wav, gen_txt))
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                num_jobs = len(gpus)
         | 
| 413 | 
            +
                if num_jobs == 1:
         | 
| 414 | 
            +
                    return [(gpus[0], test_set_)]
         | 
| 415 | 
            +
                
         | 
| 416 | 
            +
                wav_per_job = len(test_set_) // num_jobs + 1
         | 
| 417 | 
            +
                test_set = []
         | 
| 418 | 
            +
                for i in range(num_jobs):
         | 
| 419 | 
            +
                    test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                return test_set
         | 
| 422 | 
            +
             | 
| 423 | 
            +
             | 
| 424 | 
            +
            # load asr model
         | 
| 425 | 
            +
             | 
| 426 | 
            +
            def load_asr_model(lang, ckpt_dir = ""):
         | 
| 427 | 
            +
                if lang == "zh":
         | 
| 428 | 
            +
                    model = AutoModel(
         | 
| 429 | 
            +
                        model = os.path.join(ckpt_dir, "paraformer-zh"), 
         | 
| 430 | 
            +
                        # vad_model = os.path.join(ckpt_dir, "fsmn-vad"), 
         | 
| 431 | 
            +
                        # punc_model = os.path.join(ckpt_dir, "ct-punc"),
         | 
| 432 | 
            +
                        # spk_model = os.path.join(ckpt_dir, "cam++"), 
         | 
| 433 | 
            +
                        disable_update=True,
         | 
| 434 | 
            +
                        )  # following seed-tts setting
         | 
| 435 | 
            +
                elif lang == "en":
         | 
| 436 | 
            +
                    model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
         | 
| 437 | 
            +
                    model = WhisperModel(model_size, device="cuda", compute_type="float16")
         | 
| 438 | 
            +
                return model
         | 
| 439 | 
            +
             | 
| 440 | 
            +
             | 
| 441 | 
            +
            # WER Evaluation, the way Seed-TTS does
         | 
| 442 | 
            +
             | 
| 443 | 
            +
            def run_asr_wer(args):
         | 
| 444 | 
            +
                rank, lang, test_set, ckpt_dir = args
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                if lang == "zh":
         | 
| 447 | 
            +
                    torch.cuda.set_device(rank)
         | 
| 448 | 
            +
                elif lang == "en":
         | 
| 449 | 
            +
                    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
         | 
| 450 | 
            +
                else:
         | 
| 451 | 
            +
                    raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                punctuation_all = punctuation + string.punctuation
         | 
| 456 | 
            +
                wers = []
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                for gen_wav, prompt_wav, truth in tqdm(test_set):
         | 
| 459 | 
            +
                    if lang == "zh":
         | 
| 460 | 
            +
                        res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
         | 
| 461 | 
            +
                        hypo = res[0]["text"]
         | 
| 462 | 
            +
                        hypo = zhconv.convert(hypo, 'zh-cn')
         | 
| 463 | 
            +
                    elif lang == "en":
         | 
| 464 | 
            +
                        segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
         | 
| 465 | 
            +
                        hypo = ''
         | 
| 466 | 
            +
                        for segment in segments:
         | 
| 467 | 
            +
                            hypo = hypo + ' ' + segment.text
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                    # raw_truth = truth
         | 
| 470 | 
            +
                    # raw_hypo = hypo
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                    for x in punctuation_all:
         | 
| 473 | 
            +
                        truth = truth.replace(x, '')
         | 
| 474 | 
            +
                        hypo = hypo.replace(x, '')
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    truth = truth.replace('  ', ' ')
         | 
| 477 | 
            +
                    hypo = hypo.replace('  ', ' ')
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    if lang == "zh":
         | 
| 480 | 
            +
                        truth = " ".join([x for x in truth])
         | 
| 481 | 
            +
                        hypo = " ".join([x for x in hypo])
         | 
| 482 | 
            +
                    elif lang == "en":
         | 
| 483 | 
            +
                        truth = truth.lower()
         | 
| 484 | 
            +
                        hypo = hypo.lower()
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                    measures = compute_measures(truth, hypo)
         | 
| 487 | 
            +
                    wer = measures["wer"]
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    # ref_list = truth.split(" ")
         | 
| 490 | 
            +
                    # subs = measures["substitutions"] / len(ref_list)
         | 
| 491 | 
            +
                    # dele = measures["deletions"] / len(ref_list)
         | 
| 492 | 
            +
                    # inse = measures["insertions"] / len(ref_list)
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    wers.append(wer)
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                return wers
         | 
| 497 | 
            +
             | 
| 498 | 
            +
             | 
| 499 | 
            +
            # SIM Evaluation
         | 
| 500 | 
            +
             | 
| 501 | 
            +
            def run_sim(args):
         | 
| 502 | 
            +
                rank, test_set, ckpt_dir = args
         | 
| 503 | 
            +
                device = f"cuda:{rank}"
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
         | 
| 506 | 
            +
                state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
         | 
| 507 | 
            +
                model.load_state_dict(state_dict['model'], strict=False)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                 use_gpu = torch.cuda.is_available()
         | 
| 510 | 
            +
                if use_gpu:
         | 
| 511 | 
            +
                    model = model.cuda(device)
         | 
| 512 | 
            +
                model.eval()
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                sim_list = []
         | 
| 515 | 
            +
                for wav1, wav2, truth in tqdm(test_set):
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                    wav1, sr1 = torchaudio.load(wav1)
         | 
| 518 | 
            +
                    wav2, sr2 = torchaudio.load(wav2)
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                    resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
         | 
| 521 | 
            +
                    resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
         | 
| 522 | 
            +
                    wav1 = resample1(wav1)
         | 
| 523 | 
            +
                    wav2 = resample2(wav2)
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    if use_gpu:
         | 
| 526 | 
            +
                        wav1 = wav1.cuda(device)
         | 
| 527 | 
            +
                        wav2 = wav2.cuda(device)
         | 
| 528 | 
            +
                    with torch.no_grad():
         | 
| 529 | 
            +
                        emb1 = model(wav1)
         | 
| 530 | 
            +
                        emb2 = model(wav2)
         | 
| 531 | 
            +
                    
         | 
| 532 | 
            +
                    sim = F.cosine_similarity(emb1, emb2)[0].item()
         | 
| 533 | 
            +
                    # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
         | 
| 534 | 
            +
                    sim_list.append(sim)
         | 
| 535 | 
            +
                
         | 
| 536 | 
            +
                return sim_list
         | 
| 537 | 
            +
             | 
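             # run_asr_wer and run_sim each take one (gpu_rank, ...) tuple, matching the per-GPU
             # splits produced by get_seed_tts_test / get_librispeech_test. Single-process sketch
             # (generated-wav dir and checkpoint paths are placeholders):
             #   >>> gpu, sub_set = get_seed_tts_test("data/seedtts_testset/zh/meta.lst", "<GEN_WAV_DIR>", [0])[0]
             #   >>> wers = run_asr_wer((gpu, "zh", sub_set, "<ASR_CKPT_DIR>"))
             #   >>> sims = run_sim((gpu, sub_set, "<WAVLM_CKPT_PATH>"))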
| 538 | 
            +
             | 
| 539 | 
            +
            # filter func for dirty data with many repetitions
         | 
| 540 | 
            +
             | 
| 541 | 
            +
            def repetition_found(text, length = 2, tolerance = 10):
         | 
| 542 | 
            +
                pattern_count = defaultdict(int)
         | 
| 543 | 
            +
                for i in range(len(text) - length + 1):
         | 
| 544 | 
            +
                    pattern = text[i:i + length]
         | 
| 545 | 
            +
                    pattern_count[pattern] += 1
         | 
| 546 | 
            +
                for pattern, count in pattern_count.items():
         | 
| 547 | 
            +
                    if count > tolerance:
         | 
| 548 | 
            +
                        return True
         | 
| 549 | 
            +
                return False
         | 
| 550 | 
            +
             | 
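             # Illustrative: flags any string in which some 2-gram repeats more than `tolerance`
             # times, e.g. repetition_found("ha" * 20) is True, while ordinary transcripts
             # stay below the default tolerance of 10 and return False.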
| 551 | 
            +
             | 
| 552 | 
            +
            # load model checkpoint for inference
         | 
| 553 | 
            +
             | 
| 554 | 
            +
            def load_checkpoint(model, ckpt_path, device, use_ema = True):
         | 
| 555 | 
            +
                from ema_pytorch import EMA
         | 
| 556 | 
            +
             | 
| 557 | 
            +
                ckpt_type = ckpt_path.split(".")[-1]
         | 
| 558 | 
            +
                if ckpt_type == "safetensors":
         | 
| 559 | 
            +
                    from safetensors.torch import load_file
         | 
| 560 | 
            +
                    checkpoint = load_file(ckpt_path, device=device)
         | 
| 561 | 
            +
                else:
         | 
| 562 | 
            +
                    checkpoint = torch.load(ckpt_path, map_location=device)
         | 
| 563 | 
            +
             | 
| 564 | 
            +
                 if use_ema:
         | 
| 565 | 
            +
                    ema_model = EMA(model, include_online_model = False).to(device)
         | 
| 566 | 
            +
                    if ckpt_type == "safetensors":
         | 
| 567 | 
            +
                        ema_model.load_state_dict(checkpoint)
         | 
| 568 | 
            +
                    else:
         | 
| 569 | 
            +
                        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
         | 
| 570 | 
            +
                    ema_model.copy_params_from_ema_to_model()
         | 
| 571 | 
            +
                else:
         | 
| 572 | 
            +
                    model.load_state_dict(checkpoint['model_state_dict'])
         | 
| 573 | 
            +
                    
         | 
| 574 | 
            +
                return model
         | 
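             # Usage sketch (checkpoint path mirrors scripts/eval_infer_batch.py; model construction omitted):
             #   >>> model = load_checkpoint(model, "ckpts/F5TTS_Base/model_1200000.pt", "cuda", use_ema = True)
             # A .safetensors checkpoint is assumed to hold the EMA state dict directly, while a .pt
             # checkpoint is read via the 'ema_model_state_dict' / 'model_state_dict' keys above.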
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,29 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            accelerate>=0.33.0
         | 
| 2 | 
            +
            cached_path
         | 
| 3 | 
            +
            click
         | 
| 4 | 
            +
            datasets
         | 
| 5 | 
            +
            einops>=0.8.0
         | 
| 6 | 
            +
            einx>=0.3.0
         | 
| 7 | 
            +
            ema_pytorch>=0.5.2
         | 
| 8 | 
            +
            faster_whisper
         | 
| 9 | 
            +
            funasr
         | 
| 10 | 
            +
            gradio
         | 
| 11 | 
            +
            jieba
         | 
| 12 | 
            +
            jiwer
         | 
| 13 | 
            +
            librosa
         | 
| 14 | 
            +
            matplotlib
         | 
| 15 | 
            +
            numpy==1.23.5
         | 
| 16 | 
            +
            pydub
         | 
| 17 | 
            +
            pypinyin
         | 
| 18 | 
            +
            safetensors
         | 
| 19 | 
            +
            soundfile
         | 
| 20 | 
            +
            # torch>=2.0
         | 
| 21 | 
            +
            # torchaudio>=2.3.0
         | 
| 22 | 
            +
            torchdiffeq
         | 
| 23 | 
            +
            tqdm>=4.65.0
         | 
| 24 | 
            +
            transformers
         | 
| 25 | 
            +
            vocos
         | 
| 26 | 
            +
            wandb
         | 
| 27 | 
            +
            x_transformers>=1.31.14
         | 
| 28 | 
            +
            zhconv
         | 
| 29 | 
            +
            zhon
         | 
    	
        scripts/count_max_epoch.py
    ADDED
    
    | @@ -0,0 +1,32 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            '''ADAPTIVE BATCH SIZE'''
         | 
| 2 | 
            +
             print('Adaptive batch size: using grouping batch sampler with a fixed frames_per_gpu fed in')
         | 
| 3 | 
            +
             print('  -> least padding: gather wavs into a batch until the accumulated frame budget is reached\n')
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # data
         | 
| 6 | 
            +
            total_hours = 95282
         | 
| 7 | 
            +
            mel_hop_length = 256
         | 
| 8 | 
            +
            mel_sampling_rate = 24000
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # target
         | 
| 11 | 
            +
            wanted_max_updates = 1000000
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # train params
         | 
| 14 | 
            +
            gpus = 8
         | 
| 15 | 
            +
            frames_per_gpu = 38400  # 8 * 38400 = 307200
         | 
| 16 | 
            +
            grad_accum = 1
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # intermediate
         | 
| 19 | 
            +
            mini_batch_frames = frames_per_gpu * grad_accum * gpus
         | 
| 20 | 
            +
            mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
         | 
| 21 | 
            +
            updates_per_epoch = total_hours / mini_batch_hours
         | 
| 22 | 
            +
            steps_per_epoch = updates_per_epoch * grad_accum
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            # result
         | 
| 25 | 
            +
            epochs = wanted_max_updates / updates_per_epoch
         | 
| 26 | 
            +
            print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
         | 
| 27 | 
            +
            print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
         | 
| 28 | 
            +
            print(f"                      or approx. 0/{steps_per_epoch:.0f} steps")
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # others
         | 
| 31 | 
            +
            print(f"total {total_hours:.0f} hours")
         | 
| 32 | 
            +
            print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
         | 
    	
        scripts/count_params_gflops.py
    ADDED
    
    | @@ -0,0 +1,35 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys, os
         | 
| 2 | 
            +
            sys.path.append(os.getcwd())
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from model import M2_TTS, UNetT, DiT, MMDiT
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import thop
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            ''' ~155M '''
         | 
| 11 | 
            +
            # transformer =     UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
         | 
| 12 | 
            +
            # transformer =     UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
         | 
| 13 | 
            +
            # transformer =       DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
         | 
| 14 | 
            +
            # transformer =       DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
         | 
| 15 | 
            +
            # transformer =       DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
         | 
| 16 | 
            +
            # transformer =     MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            ''' ~335M '''
         | 
| 19 | 
            +
            # FLOPs: 622.1 G, Params: 333.2 M
         | 
| 20 | 
            +
            # transformer =     UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
         | 
| 21 | 
            +
            # FLOPs: 363.4 G, Params: 335.8 M
         | 
| 22 | 
            +
            transformer =       DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            model = M2_TTS(transformer=transformer)
         | 
| 26 | 
            +
            target_sample_rate = 24000
         | 
| 27 | 
            +
            n_mel_channels = 100
         | 
| 28 | 
            +
            hop_length = 256
         | 
| 29 | 
            +
            duration = 20
         | 
| 30 | 
            +
            frame_length = int(duration * target_sample_rate / hop_length)
         | 
| 31 | 
            +
            text_length = 150
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
         | 
| 34 | 
            +
            print(f"FLOPs: {flops / 1e9} G")
         | 
| 35 | 
            +
            print(f"Params: {params / 1e6} M")
         | 
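             # Note: the dummy inputs correspond to 20 s of audio (20 * 24000 / 256 = 1875 mel frames)
             # plus 150 text tokens; thop.profile runs one forward pass to count MACs and parameters.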
    	
        scripts/eval_infer_batch.py
    ADDED
    
    | @@ -0,0 +1,199 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys, os
         | 
| 2 | 
            +
            sys.path.append(os.getcwd())
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import time
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
            from tqdm import tqdm
         | 
| 7 | 
            +
            import argparse
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torchaudio
         | 
| 11 | 
            +
            from accelerate import Accelerator
         | 
| 12 | 
            +
            from einops import rearrange
         | 
| 13 | 
            +
            from vocos import Vocos
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from model import CFM, UNetT, DiT
         | 
| 16 | 
            +
            from model.utils import (
         | 
| 17 | 
            +
                load_checkpoint,
         | 
| 18 | 
            +
                get_tokenizer, 
         | 
| 19 | 
            +
                get_seedtts_testset_metainfo, 
         | 
| 20 | 
            +
                get_librispeech_test_clean_metainfo, 
         | 
| 21 | 
            +
                get_inference_prompt,
         | 
| 22 | 
            +
            )
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            accelerator = Accelerator()
         | 
| 25 | 
            +
            device = f"cuda:{accelerator.process_index}"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            # --------------------- Dataset Settings -------------------- #
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            target_sample_rate = 24000
         | 
| 31 | 
            +
            n_mel_channels = 100
         | 
| 32 | 
            +
            hop_length = 256
         | 
| 33 | 
            +
            target_rms = 0.1
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            tokenizer = "pinyin"
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            # ---------------------- infer setting ---------------------- #
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            parser = argparse.ArgumentParser(description="batch inference")
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            parser.add_argument('-s', '--seed', default=None, type=int)
         | 
| 43 | 
            +
            parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
         | 
| 44 | 
            +
            parser.add_argument('-n', '--expname', required=True)
         | 
| 45 | 
            +
            parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            parser.add_argument('-nfe', '--nfestep', default=32, type=int)
         | 
| 48 | 
            +
            parser.add_argument('-o', '--odemethod', default="euler")
         | 
| 49 | 
            +
            parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            parser.add_argument('-t', '--testset', required=True)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            args = parser.parse_args()
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            seed = args.seed
         | 
| 57 | 
            +
            dataset_name = args.dataset
         | 
| 58 | 
            +
            exp_name = args.expname
         | 
| 59 | 
            +
            ckpt_step = args.ckptstep
         | 
| 60 | 
            +
            ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            nfe_step = args.nfestep
         | 
| 63 | 
            +
            ode_method = args.odemethod
         | 
| 64 | 
            +
            sway_sampling_coef = args.swaysampling
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            testset = args.testset
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
             infer_batch_size = 1  # frame-accumulation threshold per batch; 1 flushes after every utterance, i.e. single-sample inference per ddp process (recommended)
         | 
| 70 | 
            +
            cfg_strength = 2.
         | 
| 71 | 
            +
            speed = 1.
         | 
| 72 | 
            +
            use_truth_duration = False
         | 
| 73 | 
            +
            no_ref_audio = False
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            if exp_name == "F5TTS_Base":
         | 
| 77 | 
            +
                model_cls = DiT
         | 
| 78 | 
            +
                model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            elif exp_name == "E2TTS_Base":
         | 
| 81 | 
            +
                model_cls = UNetT
         | 
| 82 | 
            +
                model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            if testset == "ls_pc_test_clean":
         | 
| 86 | 
            +
                metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
         | 
| 87 | 
            +
                librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
         | 
| 88 | 
            +
                metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
            elif testset == "seedtts_test_zh":
         | 
| 91 | 
            +
                metalst = "data/seedtts_testset/zh/meta.lst"
         | 
| 92 | 
            +
                metainfo = get_seedtts_testset_metainfo(metalst)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            elif testset == "seedtts_test_en":
         | 
| 95 | 
            +
                metalst = "data/seedtts_testset/en/meta.lst"
         | 
| 96 | 
            +
                metainfo = get_seedtts_testset_metainfo(metalst)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            # path to save generated wavs
         | 
| 100 | 
            +
            if seed is None: seed = random.randint(-10000, 10000)
         | 
| 101 | 
            +
            output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
         | 
| 102 | 
            +
                f"seed{seed}_{ode_method}_nfe{nfe_step}" \
         | 
| 103 | 
            +
                f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
         | 
| 104 | 
            +
                f"_cfg{cfg_strength}_speed{speed}" \
         | 
| 105 | 
            +
                f"{'_gt-dur' if use_truth_duration else ''}" \
         | 
| 106 | 
            +
                f"{'_no-ref-audio' if no_ref_audio else ''}"
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
            # -------------------------------------------------#
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            use_ema = True
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            prompts_all = get_inference_prompt(
         | 
| 114 | 
            +
                metainfo, 
         | 
| 115 | 
            +
                speed = speed, 
         | 
| 116 | 
            +
                tokenizer = tokenizer, 
         | 
| 117 | 
            +
                target_sample_rate = target_sample_rate, 
         | 
| 118 | 
            +
                n_mel_channels = n_mel_channels,
         | 
| 119 | 
            +
                hop_length = hop_length,
         | 
| 120 | 
            +
                target_rms = target_rms,
         | 
| 121 | 
            +
                use_truth_duration = use_truth_duration,
         | 
| 122 | 
            +
                infer_batch_size = infer_batch_size,
         | 
| 123 | 
            +
            )
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            # Vocoder model
         | 
| 126 | 
            +
            local = False
         | 
| 127 | 
            +
            if local:
         | 
| 128 | 
            +
                vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
         | 
| 129 | 
            +
                vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
         | 
| 130 | 
            +
                state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
         | 
| 131 | 
            +
                vocos.load_state_dict(state_dict)
         | 
| 132 | 
            +
                vocos.eval()
         | 
| 133 | 
            +
            else:
         | 
| 134 | 
            +
                vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
         | 
| 135 | 
            +
             | 
| 136 | 
            +
            # Tokenizer
         | 
| 137 | 
            +
            vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            # Model
         | 
| 140 | 
            +
            model = CFM(
         | 
| 141 | 
            +
                transformer = model_cls(
         | 
| 142 | 
            +
                    **model_cfg,
         | 
| 143 | 
            +
                    text_num_embeds = vocab_size, 
         | 
| 144 | 
            +
                    mel_dim = n_mel_channels
         | 
| 145 | 
            +
                ),
         | 
| 146 | 
            +
                mel_spec_kwargs = dict(
         | 
| 147 | 
            +
                    target_sample_rate = target_sample_rate, 
         | 
| 148 | 
            +
                    n_mel_channels = n_mel_channels,
         | 
| 149 | 
            +
                    hop_length = hop_length,
         | 
| 150 | 
            +
                ),
         | 
| 151 | 
            +
                odeint_kwargs = dict(
         | 
| 152 | 
            +
                    method = ode_method,
         | 
| 153 | 
            +
                ),
         | 
| 154 | 
            +
                vocab_char_map = vocab_char_map,
         | 
| 155 | 
            +
            ).to(device)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
            model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            if not os.path.exists(output_dir) and accelerator.is_main_process:
         | 
| 160 | 
            +
                os.makedirs(output_dir)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
            # start batch inference
         | 
| 163 | 
            +
            accelerator.wait_for_everyone()
         | 
| 164 | 
            +
            start = time.time()
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            with accelerator.split_between_processes(prompts_all) as prompts:
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
         | 
| 169 | 
            +
                    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
         | 
| 170 | 
            +
                    ref_mels = ref_mels.to(device)
         | 
| 171 | 
            +
                    ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
         | 
| 172 | 
            +
                    total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
         | 
| 173 | 
            +
                    
         | 
| 174 | 
            +
                    # Inference
         | 
| 175 | 
            +
                    with torch.inference_mode():
         | 
| 176 | 
            +
                        generated, _ = model.sample(
         | 
| 177 | 
            +
                            cond = ref_mels,
         | 
| 178 | 
            +
                            text = final_text_list,
         | 
| 179 | 
            +
                            duration = total_mel_lens,
         | 
| 180 | 
            +
                            lens = ref_mel_lens,
         | 
| 181 | 
            +
                            steps = nfe_step,
         | 
| 182 | 
            +
                            cfg_strength = cfg_strength,
         | 
| 183 | 
            +
                            sway_sampling_coef = sway_sampling_coef,
         | 
| 184 | 
            +
                            no_ref_audio = no_ref_audio,
         | 
| 185 | 
            +
                            seed = seed,
         | 
| 186 | 
            +
                        )
         | 
| 187 | 
            +
                    # Final result
         | 
| 188 | 
            +
                    for i, gen in enumerate(generated):
         | 
| 189 | 
            +
                        gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
         | 
| 190 | 
            +
                        gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
         | 
| 191 | 
            +
                        generated_wave = vocos.decode(gen_mel_spec.cpu())
         | 
| 192 | 
            +
                        if ref_rms_list[i] < target_rms:
         | 
| 193 | 
            +
                            generated_wave = generated_wave * ref_rms_list[i] / target_rms
         | 
| 194 | 
            +
                        torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            accelerator.wait_for_everyone()
         | 
| 197 | 
            +
            if accelerator.is_main_process:
         | 
| 198 | 
            +
                timediff = time.time() - start
         | 
| 199 | 
            +
                print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
         | 
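The result directory name built above encodes every inference setting. As a rough, self-contained sketch (values mirror this script's defaults; the seed shown is purely illustrative), the naming resolves like this:

# Sketch: reproduce the results-directory naming used in eval_infer_batch.py above.
exp_name, ckpt_step, testset = "F5TTS_Base", 1200000, "seedtts_test_zh"
seed, ode_method, nfe_step = 42, "euler", 32          # seed is illustrative
sway_sampling_coef, cfg_strength, speed = -1.0, 2.0, 1.0
use_truth_duration, no_ref_audio = False, False

output_dir = (
    f"results/{exp_name}_{ckpt_step}/{testset}/"
    f"seed{seed}_{ode_method}_nfe{nfe_step}"
    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
    f"_cfg{cfg_strength}_speed{speed}"
    f"{'_gt-dur' if use_truth_duration else ''}"
    f"{'_no-ref-audio' if no_ref_audio else ''}"
)
print(output_dir)
# results/F5TTS_Base_1200000/seedtts_test_zh/seed42_euler_nfe32_ss-1.0_cfg2.0_speed1.0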
    	
        scripts/eval_infer_batch.sh
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
| 1 | 
            +
            #!/bin/bash
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # e.g. F5-TTS, 16 NFE
         | 
| 4 | 
            +
            accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
         | 
| 5 | 
            +
            accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
         | 
| 6 | 
            +
            accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # e.g. Vanilla E2 TTS, 32 NFE
         | 
| 9 | 
            +
            accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
         | 
| 10 | 
            +
            accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
         | 
| 11 | 
            +
            accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # etc.
         | 
    	
        scripts/eval_librispeech_test_clean.py
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
| 1 | 
            +
            # Evaluate with LibriSpeech test-clean: ~3s prompt to generate 4-10s audio (following the VALL-E / Voicebox evaluation protocol)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import sys, os
         | 
| 4 | 
            +
            sys.path.append(os.getcwd())
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import multiprocessing as mp
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from model.utils import (
         | 
| 10 | 
            +
                get_librispeech_test,
         | 
| 11 | 
            +
                run_asr_wer,
         | 
| 12 | 
            +
                run_sim,
         | 
| 13 | 
            +
            )
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            eval_task = "wer"  # sim | wer
         | 
| 17 | 
            +
            lang = "en"
         | 
| 18 | 
            +
            metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
         | 
| 19 | 
            +
            librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
         | 
| 20 | 
            +
            gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            gpus = [0,1,2,3,4,5,6,7]
         | 
| 23 | 
            +
            test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
         | 
| 26 | 
            +
            ## leading to a low similarity for the ground truth in some cases.
         | 
| 27 | 
            +
            # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True)  # eval ground truth
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            local = False
         | 
| 30 | 
            +
            if local:  # use local custom checkpoint dir
         | 
| 31 | 
            +
                asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
         | 
| 32 | 
            +
            else:
         | 
| 33 | 
            +
                asr_ckpt_dir = ""  # auto download to cache dir
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            # --------------------------- WER ---------------------------
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            if eval_task == "wer":
         | 
| 41 | 
            +
                wers = []
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                with mp.Pool(processes=len(gpus)) as pool:
         | 
| 44 | 
            +
                    args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
         | 
| 45 | 
            +
                    results = pool.map(run_asr_wer, args)
         | 
| 46 | 
            +
                    for wers_ in results:
         | 
| 47 | 
            +
                        wers.extend(wers_)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                wer = round(np.mean(wers)*100, 3)
         | 
| 50 | 
            +
                print(f"\nTotal {len(wers)} samples")
         | 
| 51 | 
            +
                print(f"WER      : {wer}%")
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            # --------------------------- SIM ---------------------------
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            if eval_task == "sim":
         | 
| 57 | 
            +
                sim_list = []
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                with mp.Pool(processes=len(gpus)) as pool:
         | 
| 60 | 
            +
                    args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
         | 
| 61 | 
            +
                    results = pool.map(run_sim, args)
         | 
| 62 | 
            +
                    for sim_ in results:
         | 
| 63 | 
            +
                        sim_list.extend(sim_)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                sim = round(sum(sim_list)/len(sim_list), 3)
         | 
| 66 | 
            +
                print(f"\nTotal {len(sim_list)} samples")
         | 
| 67 | 
            +
                print(f"SIM      : {sim}")
         | 
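For context, the mp.Pool pattern above expects test_set to be a list of (gpu_rank, sub_test_set) pairs, one chunk per GPU. The actual splitting lives in model.utils.get_librispeech_test; the helper below is only a hypothetical sketch of that shape:

# Hypothetical helper: split a flat sample list into (gpu_rank, chunk) pairs,
# matching the shape consumed by run_asr_wer / run_sim above.
def split_across_gpus(samples, gpus):
    chunks = [[] for _ in gpus]
    for i, sample in enumerate(samples):
        chunks[i % len(gpus)].append(sample)   # round-robin assignment
    return list(zip(gpus, chunks))

# split_across_gpus(list(range(6)), [0, 1]) -> [(0, [0, 2, 4]), (1, [1, 3, 5])]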
    	
        scripts/eval_seedtts_testset.py
    ADDED
    
    | @@ -0,0 +1,69 @@ | |
| 1 | 
            +
            # Evaluate with Seed-TTS testset
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import sys, os
         | 
| 4 | 
            +
            sys.path.append(os.getcwd())
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import multiprocessing as mp
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from model.utils import (
         | 
| 10 | 
            +
                get_seed_tts_test,
         | 
| 11 | 
            +
                run_asr_wer,
         | 
| 12 | 
            +
                run_sim,
         | 
| 13 | 
            +
            )
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            eval_task = "wer"  # sim | wer
         | 
| 17 | 
            +
            lang = "zh"        # zh | en
         | 
| 18 | 
            +
            metalst = f"data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
         | 
| 19 | 
            +
            # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs"  # ground truth wavs
         | 
| 20 | 
            +
            gen_wav_dir = f"PATH_TO_GENERATED"  # generated wavs
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            # NOTE: the paraformer-zh result differs slightly with the number of GPUs, because the per-worker batch size changes
         | 
| 24 | 
            +
            #       zh 1.254 appears to be the result obtained with 4 workers (wer_seed_tts)
         | 
| 25 | 
            +
            gpus = [0,1,2,3,4,5,6,7]
         | 
| 26 | 
            +
            test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            local = False
         | 
| 29 | 
            +
            if local:  # use local custom checkpoint dir
         | 
| 30 | 
            +
                if lang == "zh":
         | 
| 31 | 
            +
                    asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
         | 
| 32 | 
            +
                elif lang == "en":
         | 
| 33 | 
            +
                    asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
         | 
| 34 | 
            +
            else:
         | 
| 35 | 
            +
                asr_ckpt_dir = ""  # auto download to cache dir
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            # --------------------------- WER ---------------------------
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            if eval_task == "wer":
         | 
| 43 | 
            +
                wers = []
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                with mp.Pool(processes=len(gpus)) as pool:
         | 
| 46 | 
            +
                    args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
         | 
| 47 | 
            +
                    results = pool.map(run_asr_wer, args)
         | 
| 48 | 
            +
                    for wers_ in results:
         | 
| 49 | 
            +
                        wers.extend(wers_)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                wer = round(np.mean(wers)*100, 3)
         | 
| 52 | 
            +
                print(f"\nTotal {len(wers)} samples")
         | 
| 53 | 
            +
                print(f"WER      : {wer}%")
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            # --------------------------- SIM ---------------------------
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            if eval_task == "sim":
         | 
| 59 | 
            +
                sim_list = []
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                with mp.Pool(processes=len(gpus)) as pool:
         | 
| 62 | 
            +
                    args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
         | 
| 63 | 
            +
                    results = pool.map(run_sim, args)
         | 
| 64 | 
            +
                    for sim_ in results:
         | 
| 65 | 
            +
                        sim_list.extend(sim_)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                sim = round(sum(sim_list)/len(sim_list), 3)
         | 
| 68 | 
            +
                print(f"\nTotal {len(sim_list)} samples")
         | 
| 69 | 
            +
                print(f"SIM      : {sim}")
         | 
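Since the NOTE above points out that the paraformer-zh WER drifts slightly with the number of workers, one way to keep numbers comparable across machines is to pin the worker count rather than using every visible GPU. A minimal sketch under that assumption (placeholders kept from the script):

# Sketch: pin the ASR worker count so the zh WER stays comparable across machines
# (see the NOTE above: the paraformer-zh result depends on the per-worker batch size).
from model.utils import get_seed_tts_test

lang = "zh"
metalst = f"data/seedtts_testset/{lang}/meta.lst"
gen_wav_dir = "PATH_TO_GENERATED"
gpus = [0, 1, 2, 3]  # always 4 workers, matching the reference zh figure noted above
test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)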
    	
        scripts/prepare_emilia.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
| 1 | 
            +
            # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
         | 
| 2 | 
            +
            # if using the updated version (WebDataset format), feel free to modify / draft your own script
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # generate audio text map for Emilia ZH & EN
         | 
| 5 | 
            +
            # evaluate for vocab size
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import sys, os
         | 
| 8 | 
            +
            sys.path.append(os.getcwd())
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from pathlib import Path
         | 
| 11 | 
            +
            import json
         | 
| 12 | 
            +
            from tqdm import tqdm
         | 
| 13 | 
            +
            from concurrent.futures import ProcessPoolExecutor
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from datasets import Dataset
         | 
| 16 | 
            +
            from datasets.arrow_writer import ArrowWriter
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from model.utils import (
         | 
| 19 | 
            +
                repetition_found,
         | 
| 20 | 
            +
                convert_char_to_pinyin,
         | 
| 21 | 
            +
            )
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
         | 
| 25 | 
            +
            zh_filters = ["い", "て"]
         | 
| 26 | 
            +
            # these appear to be synthesized audio, or heavily code-switched
         | 
| 27 | 
            +
            out_en = {
         | 
| 28 | 
            +
                "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
         | 
| 29 | 
            +
                
         | 
| 30 | 
            +
                "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
         | 
| 31 | 
            +
            }
         | 
| 32 | 
            +
            en_filters = ["ا", "い", "て"]
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def deal_with_audio_dir(audio_dir):
         | 
| 36 | 
            +
                audio_jsonl = audio_dir.with_suffix(".jsonl")
         | 
| 37 | 
            +
                sub_result, durations = [], []
         | 
| 38 | 
            +
                vocab_set = set()
         | 
| 39 | 
            +
                bad_case_zh = 0
         | 
| 40 | 
            +
                bad_case_en = 0
         | 
| 41 | 
            +
                with open(audio_jsonl, "r") as f:
         | 
| 42 | 
            +
                    lines = f.readlines()
         | 
| 43 | 
            +
                    for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
         | 
| 44 | 
            +
                        obj = json.loads(line)
         | 
| 45 | 
            +
                        text = obj["text"]
         | 
| 46 | 
            +
                        if obj['language'] == "zh":
         | 
| 47 | 
            +
                            if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
         | 
| 48 | 
            +
                                bad_case_zh += 1
         | 
| 49 | 
            +
                                continue
         | 
| 50 | 
            +
                            else:
         | 
| 51 | 
            +
                            text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'}))  # full-width ,!? to half-width; "。" left as-is due to heavy code-switching
         | 
| 52 | 
            +
                        if obj['language'] == "en":
         | 
| 53 | 
            +
                            if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
         | 
| 54 | 
            +
                                bad_case_en += 1
         | 
| 55 | 
            +
                                continue
         | 
| 56 | 
            +
                        if tokenizer == "pinyin":
         | 
| 57 | 
            +
                            text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
         | 
| 58 | 
            +
                        duration = obj["duration"]
         | 
| 59 | 
            +
                        sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
         | 
| 60 | 
            +
                        durations.append(duration)
         | 
| 61 | 
            +
                        vocab_set.update(list(text))
         | 
| 62 | 
            +
                return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def main():
         | 
| 66 | 
            +
                assert tokenizer in ["pinyin", "char"]
         | 
| 67 | 
            +
                result = []
         | 
| 68 | 
            +
                duration_list = []
         | 
| 69 | 
            +
                text_vocab_set = set()
         | 
| 70 | 
            +
                total_bad_case_zh = 0
         | 
| 71 | 
            +
                total_bad_case_en = 0
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                # process raw data
         | 
| 74 | 
            +
                executor = ProcessPoolExecutor(max_workers=max_workers)
         | 
| 75 | 
            +
                futures = []
         | 
| 76 | 
            +
                for lang in langs:
         | 
| 77 | 
            +
                    dataset_path = Path(os.path.join(dataset_dir, lang))
         | 
| 78 | 
            +
                    [
         | 
| 79 | 
            +
                        futures.append(executor.submit(deal_with_audio_dir, audio_dir))
         | 
| 80 | 
            +
                        for audio_dir in dataset_path.iterdir()
         | 
| 81 | 
            +
                        if audio_dir.is_dir()
         | 
| 82 | 
            +
                    ]
         | 
| 83 | 
            +
                for future in tqdm(futures, total=len(futures)):
         | 
| 84 | 
            +
                    sub_result, durations, vocab_set, bad_case_zh, bad_case_en = future.result()
         | 
| 85 | 
            +
                    result.extend(sub_result)
         | 
| 86 | 
            +
                    duration_list.extend(durations)
         | 
| 87 | 
            +
                    text_vocab_set.update(vocab_set)
         | 
| 88 | 
            +
                    total_bad_case_zh += bad_case_zh
         | 
| 89 | 
            +
                    total_bad_case_en += bad_case_en
         | 
| 90 | 
            +
                executor.shutdown()
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                # save preprocessed dataset to disk
         | 
| 93 | 
            +
                if not os.path.exists(f"data/{dataset_name}"):
         | 
| 94 | 
            +
                    os.makedirs(f"data/{dataset_name}")
         | 
| 95 | 
            +
                print(f"\nSaving to data/{dataset_name} ...")
         | 
| 96 | 
            +
                # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
         | 
| 97 | 
            +
                # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
         | 
| 98 | 
            +
                with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
         | 
| 99 | 
            +
                    for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
         | 
| 100 | 
            +
                        writer.write(line)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                # also dump durations into a separate JSON, to make DynamicBatchSampler setup easier
         | 
| 103 | 
            +
                with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
         | 
| 104 | 
            +
                    json.dump({"duration": duration_list}, f, ensure_ascii=False)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                # vocab map, i.e. tokenizer
         | 
| 107 | 
            +
                # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
         | 
| 108 | 
            +
                # if tokenizer == "pinyin":
         | 
| 109 | 
            +
                #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
         | 
| 110 | 
            +
                with open(f"data/{dataset_name}/vocab.txt", "w") as f:
         | 
| 111 | 
            +
                    for vocab in sorted(text_vocab_set):
         | 
| 112 | 
            +
                        f.write(vocab + "\n")
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                print(f"\nFor {dataset_name}, sample count: {len(result)}")
         | 
| 115 | 
            +
                print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
         | 
| 116 | 
            +
                print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
         | 
| 117 | 
            +
                if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
         | 
| 118 | 
            +
                if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
| 121 | 
            +
            if __name__ == "__main__":
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                max_workers = 32
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                tokenizer = "pinyin"  # "pinyin" | "char"
         | 
| 126 | 
            +
                polyphone = True
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                langs = ["ZH", "EN"]
         | 
| 129 | 
            +
                dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
         | 
| 130 | 
            +
                dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
         | 
| 131 | 
            +
                print(f"\nPrepare for {dataset_name}\n")
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                main()
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                # Emilia               ZH & EN
         | 
| 136 | 
            +
                # samples count       37837916   (after removal)
         | 
| 137 | 
            +
                # pinyin vocab size       2543   (polyphone)
         | 
| 138 | 
            +
                # total duration      95281.87   (hours)
         | 
| 139 | 
            +
                # bad zh asr cnt        230435   (samples)
         | 
| 140 | 
            +
                # bad en asr cnt         37217   (samples)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. polyphone handling)
         | 
| 143 | 
            +
                # be careful if using a pretrained model: make sure its vocab.txt is identical
         | 
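Given the warning above about vocab.txt mismatches when starting from a pretrained checkpoint, here is a small sketch for diffing a freshly written vocab against a reference one; the checkpoint path is illustrative, not something this repo guarantees:

# Sketch: compare a newly generated vocab.txt with the vocab used by a pretrained checkpoint.
# "ckpts/F5TTS_Base/vocab.txt" below is an illustrative path.
new_vocab = set(open("data/Emilia_ZH_EN_pinyin/vocab.txt", encoding="utf-8").read().splitlines())
ref_vocab = set(open("ckpts/F5TTS_Base/vocab.txt", encoding="utf-8").read().splitlines())
print("tokens missing from the new vocab:", len(ref_vocab - new_vocab))
print("extra tokens in the new vocab    :", len(new_vocab - ref_vocab))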
    	
        scripts/prepare_wenetspeech4tts.py
    ADDED
    
    | @@ -0,0 +1,116 @@ | |
| 1 | 
            +
            # generate audio text map for WenetSpeech4TTS
         | 
| 2 | 
            +
            # evaluate for vocab size
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import sys, os
         | 
| 5 | 
            +
            sys.path.append(os.getcwd())
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
            from tqdm import tqdm
         | 
| 9 | 
            +
            from concurrent.futures import ProcessPoolExecutor
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import torchaudio
         | 
| 12 | 
            +
            from datasets import Dataset
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from model.utils import convert_char_to_pinyin
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def deal_with_sub_path_files(dataset_path, sub_path):
         | 
| 18 | 
            +
                print(f"Dealing with: {sub_path}")
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                text_dir = os.path.join(dataset_path, sub_path, "txts")
         | 
| 21 | 
            +
                audio_dir = os.path.join(dataset_path, sub_path, "wavs")
         | 
| 22 | 
            +
                text_files = os.listdir(text_dir)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                audio_paths, texts, durations = [], [], []
         | 
| 25 | 
            +
                for text_file in tqdm(text_files):
         | 
| 26 | 
            +
                    with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
         | 
| 27 | 
            +
                        first_line = file.readline().split("\t")
         | 
| 28 | 
            +
                    audio_nm = first_line[0]
         | 
| 29 | 
            +
                    audio_path = os.path.join(audio_dir, audio_nm + ".wav")
         | 
| 30 | 
            +
                    text = first_line[1].strip()
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    audio_paths.append(audio_path)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    if tokenizer == "pinyin":
         | 
| 35 | 
            +
                        texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
         | 
| 36 | 
            +
                    elif tokenizer == "char":
         | 
| 37 | 
            +
                        texts.append(text)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    audio, sample_rate = torchaudio.load(audio_path)
         | 
| 40 | 
            +
                    durations.append(audio.shape[-1] / sample_rate)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                return audio_paths, texts, durations
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def main():
         | 
| 46 | 
            +
                assert tokenizer in ["pinyin", "char"]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                audio_path_list, text_list, duration_list = [], [], []
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                executor = ProcessPoolExecutor(max_workers=max_workers)
         | 
| 51 | 
            +
                futures = []
         | 
| 52 | 
            +
                for dataset_path in dataset_paths:
         | 
| 53 | 
            +
                    sub_items = os.listdir(dataset_path)
         | 
| 54 | 
            +
                    sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
         | 
| 55 | 
            +
                    for sub_path in sub_paths:
         | 
| 56 | 
            +
                        futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
         | 
| 57 | 
            +
                for future in tqdm(futures, total=len(futures)):
         | 
| 58 | 
            +
                    audio_paths, texts, durations = future.result()
         | 
| 59 | 
            +
                    audio_path_list.extend(audio_paths)
         | 
| 60 | 
            +
                    text_list.extend(texts)
         | 
| 61 | 
            +
                    duration_list.extend(durations)
         | 
| 62 | 
            +
                executor.shutdown()
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                if not os.path.exists("data"):
         | 
| 65 | 
            +
                    os.makedirs("data")
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
         | 
| 68 | 
            +
                dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
         | 
| 69 | 
            +
                dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB")  # arrow format
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
         | 
| 72 | 
            +
                    json.dump({"duration": duration_list}, f, ensure_ascii=False)  # dup a json separately saving duration in case for DynamicBatchSampler ease
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
         | 
| 75 | 
            +
                text_vocab_set = set()
         | 
| 76 | 
            +
                for text in tqdm(text_list):
         | 
| 77 | 
            +
                    text_vocab_set.update(list(text))
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
         | 
| 80 | 
            +
                if tokenizer == "pinyin":
         | 
| 81 | 
            +
                    text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
         | 
| 84 | 
            +
                    for vocab in sorted(text_vocab_set):
         | 
| 85 | 
            +
                        f.write(vocab + "\n")
         | 
| 86 | 
            +
                print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
         | 
| 87 | 
            +
                print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
         | 
| 88 | 
            +
                
         | 
| 89 | 
            +
             | 
| 90 | 
            +
            if __name__ == "__main__":
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                max_workers = 32
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                tokenizer = "pinyin"  # "pinyin" | "char"
         | 
| 95 | 
            +
                polyphone = True
         | 
| 96 | 
            +
                dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
         | 
| 99 | 
            +
                dataset_paths = [
         | 
| 100 | 
            +
                    "<SOME_PATH>/WenetSpeech4TTS/Basic",
         | 
| 101 | 
            +
                    "<SOME_PATH>/WenetSpeech4TTS/Standard",
         | 
| 102 | 
            +
                    "<SOME_PATH>/WenetSpeech4TTS/Premium",
         | 
| 103 | 
            +
                    ][-dataset_choice:]
         | 
| 104 | 
            +
                print(f"\nChoose Dataset: {dataset_name}\n")
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                main()
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                # Results (if adding alphabets with accents and symbols):
         | 
| 109 | 
            +
                # WenetSpeech4TTS       Basic     Standard     Premium
         | 
| 110 | 
            +
                # samples count       3932473      1941220      407494
         | 
| 111 | 
            +
                # pinyin vocab size      1349         1348        1344   (no polyphone)
         | 
| 112 | 
            +
                #                           -            -        1459   (polyphone) 
         | 
| 113 | 
            +
                # char   vocab size      5264         5219        5042
         | 
| 114 | 
            +
                
         | 
| 115 | 
            +
                # vocab size may differ slightly due to the jieba tokenizer and pypinyin (e.g. polyphone handling)
         | 
| 116 | 
            +
                # be careful if using a pretrained model: make sure its vocab.txt is identical
         | 
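The duration.json written above is just {"duration": [...]} in sample order; a quick sanity-check sketch before training (the path assumes the Premium choice with the pinyin tokenizer):

# Sketch: sanity-check the duration.json written by this script.
import json

with open("data/WenetSpeech4TTS_Premium_pinyin/duration.json", encoding="utf-8") as f:
    durations = json.load(f)["duration"]
print(f"{len(durations)} samples, {sum(durations) / 3600:.2f} hours")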
    	
        speech_edit.py
    ADDED
    
    | @@ -0,0 +1,182 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
            import torchaudio
         | 
| 6 | 
            +
            from einops import rearrange
         | 
| 7 | 
            +
            from vocos import Vocos
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from model import CFM, UNetT, DiT, MMDiT
         | 
| 10 | 
            +
            from model.utils import (
         | 
| 11 | 
            +
                load_checkpoint,
         | 
| 12 | 
            +
                get_tokenizer, 
         | 
| 13 | 
            +
                convert_char_to_pinyin, 
         | 
| 14 | 
            +
                save_spectrogram,
         | 
| 15 | 
            +
            )
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            # --------------------- Dataset Settings -------------------- #
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            target_sample_rate = 24000
         | 
| 23 | 
            +
            n_mel_channels = 100
         | 
| 24 | 
            +
            hop_length = 256
         | 
| 25 | 
            +
            target_rms = 0.1
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            tokenizer = "pinyin"
         | 
| 28 | 
            +
            dataset_name = "Emilia_ZH_EN"
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            # ---------------------- infer setting ---------------------- #
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            seed = None  # int | None
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
         | 
| 36 | 
            +
            ckpt_step = 1200000
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            nfe_step = 32  # 16, 32
         | 
| 39 | 
            +
            cfg_strength = 2.
         | 
| 40 | 
            +
            ode_method = 'euler'  # euler | midpoint
         | 
| 41 | 
            +
            sway_sampling_coef = -1.
         | 
| 42 | 
            +
            speed = 1.
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            if exp_name == "F5TTS_Base":
         | 
| 45 | 
            +
                model_cls = DiT
         | 
| 46 | 
            +
                model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            elif exp_name == "E2TTS_Base":
         | 
| 49 | 
            +
                model_cls = UNetT
         | 
| 50 | 
            +
                model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
         | 
| 53 | 
            +
            output_dir = "tests"
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
         | 
| 56 | 
            +
            # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
         | 
| 57 | 
            +
            # [write the origin_text into a file, e.g. tests/test_edit.txt]
         | 
| 58 | 
            +
            # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
         | 
| 59 | 
            +
            # [the result will be saved next to the audio file]
         | 
| 60 | 
            +
            # [--language "zho" for Chinese, "eng" for English]
         | 
| 61 | 
            +
            # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
         | 
| 64 | 
            +
            origin_text = "Some call me nature, others call me mother nature."
         | 
| 65 | 
            +
            target_text = "Some call me optimist, others call me realist."
         | 
| 66 | 
            +
            parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ]  # start/end times of "nature" & "mother nature", in seconds
         | 
| 67 | 
            +
            fix_duration = [1.2, 1, ]  # fix duration for "optimist" & "realist", in seconds
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
         | 
| 70 | 
            +
            # origin_text = "对,这就是我,万人敬仰的太乙真人。"
         | 
| 71 | 
            +
            # target_text = "对,那就是你,万人敬仰的太白金星。"
         | 
| 72 | 
            +
            # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
         | 
| 73 | 
            +
            # fix_duration = None  # use origin text duration
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            # -------------------------------------------------#
         | 

use_ema = True

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Vocoder model
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# Tokenizer
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model
model = CFM(
    transformer = model_cls(
        **model_cfg,
        text_num_embeds = vocab_size,
        mel_dim = n_mel_channels
    ),
    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    ),
    odeint_kwargs = dict(
        method = ode_method,
    ),
    vocab_char_map = vocab_char_map,
).to(device)

model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)

# Audio
audio, sr = torchaudio.load(audio_to_edit)
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
    audio = audio * target_rms / rms
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)
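# NOTE (commentary added to this listing, not in the original script): rms here
# is sqrt(mean(x^2)) over the clip. Quiet inputs are scaled up so their RMS
# matches target_rms before conditioning the model, and the inverse scale
# (rms / target_rms) is applied to the generated waveform at the end of this
# script so the output loudness matches the source recording.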
offset = 0
audio_ = torch.zeros(1, 0)
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
for part in parts_to_edit:
    start, end = part
    part_dur = end - start if fix_duration is None else fix_duration.pop(0)
    part_dur = part_dur * target_sample_rate
    start = start * target_sample_rate
    audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
    edit_mask = torch.cat((edit_mask,
                           torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
                           torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
                           ), dim = -1)
    offset = end * target_sample_rate
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)
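# NOTE (commentary added to this listing, not in the original script): edit_mask
# is a frame-level mask over the mel spectrogram. True frames (the audio outside
# parts_to_edit, plus the F.pad tail) are kept as conditioning, while False
# frames mark the regions the model is asked to re-generate with the new text.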

# Text
text_list = [target_text]
if tokenizer == "pinyin":
    final_text_list = convert_char_to_pinyin(text_list)
else:
    final_text_list = [text_list]
print(f"text  : {text_list}")
print(f"pinyin: {final_text_list}")

# Duration
ref_audio_len = 0
duration = audio.shape[-1] // hop_length

# Inference
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond = audio,
        text = final_text_list,
        duration = duration,
        steps = nfe_step,
        cfg_strength = cfg_strength,
        sway_sampling_coef = sway_sampling_coef,
        seed = seed,
        edit_mask = edit_mask,
    )
print(f"Generated mel: {generated.shape}")
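# NOTE (commentary added to this listing, not in the original script): duration
# is given in mel frames (audio samples // hop_length), and ref_audio_len = 0
# because the whole edited clip is kept rather than a reference prompt being
# trimmed off afterwards; generated is expected to come back with shape
# (1, duration, n_mel_channels).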

# Final result
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
    generated_wave = generated_wave * rms / target_rms

save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")
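
A minimal sketch of how the span-to-mask logic above could be factored into a reusable helper. The function name build_edit_mask and its defaults are illustrative (not part of the repository), and it assumes the same 24 kHz / hop-length-256 mel settings used in this script.

import torch
import torch.nn.functional as F

def build_edit_mask(num_samples, parts_to_edit, fix_duration=None,
                    sample_rate=24000, hop_length=256):
    # Returns a (1, n_frames) bool mask: True = keep the frame as conditioning,
    # False = the frame falls inside a span that should be re-generated.
    mask = torch.zeros(1, 0, dtype=torch.bool)
    offset = 0.0
    durations = list(fix_duration) if fix_duration is not None else None
    for start, end in parts_to_edit:
        part_dur = (end - start) if durations is None else durations.pop(0)
        keep_frames = round((start - offset) * sample_rate / hop_length)
        edit_frames = round(part_dur * sample_rate / hop_length)
        mask = torch.cat((mask,
                          torch.ones(1, keep_frames, dtype=torch.bool),
                          torch.zeros(1, edit_frames, dtype=torch.bool)), dim=-1)
        offset = end
    # Everything after the last edited span is kept as conditioning.
    tail = num_samples // hop_length - mask.shape[-1] + 1
    return F.pad(mask, (0, max(tail, 0)), value=True)

Called as build_edit_mask(audio.shape[-1], parts_to_edit, fix_duration), it reproduces the mask built inline above, up to rounding at span boundaries.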
    	
train.py
ADDED
@@ -0,0 +1,91 @@
from model import CFM, UNetT, DiT, MMDiT, Trainer
from model.utils import get_tokenizer
from model.dataset import load_dataset


# -------------------------- Dataset Settings --------------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256

tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
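
# NOTE (commentary added to this listing, not in the original config): with
# these settings one mel frame spans hop_length / target_sample_rate =
# 256 / 24000 ≈ 10.7 ms of audio, which is the unit the frame-wise batch size
# below is counted in.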


# -------------------------- Training Settings -------------------------- #

exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base

learning_rate = 7.5e-5

batch_size_per_gpu = 38400  # 8 GPUs, 8 * 38400 = 307200
batch_size_type = "frame"  # "frame" or "sample"
max_samples = 64  # max sequences per batch when using frame-wise batch_size; 32 for small models, 64 for base models
grad_accumulation_steps = 1  # note: updates = steps / grad_accumulation_steps
max_grad_norm = 1.

epochs = 11  # linear decay is used, so epochs controls the slope
num_warmup_updates = 20000  # warmup updates
save_per_updates = 50000  # save a checkpoint every this many updates
last_per_steps = 5000  # save the latest-state checkpoint every this many steps
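
# NOTE (commentary added to this listing, not in the original config): with
# batch_size_type = "frame", 38400 frames per GPU is 38400 * 256 / 24000 =
# 409.6 s of audio per GPU per batch; across 8 GPUs that is 307200 frames,
# about 3276.8 s (≈ 0.91 h) of audio per optimizer update at
# grad_accumulation_steps = 1.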

# model params
if exp_name == "F5TTS_Base":
    wandb_resume_id = None
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
elif exp_name == "E2TTS_Base":
    wandb_resume_id = None
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
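
# NOTE (commentary added to this listing, not in the original config): the two
# presets differ only in the backbone and its config: F5TTS_Base uses DiT
# (dim 1024, depth 22, ff_mult 2) with a separate 512-dim text embedding and
# conv_layers = 4, while E2TTS_Base uses UNetT (dim 1024, depth 24, ff_mult 4)
# with no separate text_dim / conv_layers options.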


# ----------------------------------------------------------------------- #

def main():

    vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    )

    e2tts = CFM(
        transformer = model_cls(
            **model_cfg,
            text_num_embeds = vocab_size,
            mel_dim = n_mel_channels
        ),
        mel_spec_kwargs = mel_spec_kwargs,
        vocab_char_map = vocab_char_map,
    )

    trainer = Trainer(
        e2tts,
        epochs,
        learning_rate,
        num_warmup_updates = num_warmup_updates,
        save_per_updates = save_per_updates,
        checkpoint_path = f'ckpts/{exp_name}',
        batch_size = batch_size_per_gpu,
        batch_size_type = batch_size_type,
        max_samples = max_samples,
        grad_accumulation_steps = grad_accumulation_steps,
        max_grad_norm = max_grad_norm,
        wandb_project = "CFM-TTS",
        wandb_run_name = exp_name,
        wandb_resume_id = wandb_resume_id,
        last_per_steps = last_per_steps,
    )

    train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
    trainer.train(train_dataset,
                  resumable_with_seed = 666  # seed for shuffling dataset
                  )


if __name__ == '__main__':
    main()
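
For a quick smoke test of this training loop, the settings block at the top can be shrunk before launching; the values below are illustrative only (not taken from the repository) and assume a single GPU with the frame-wise batching described above.

# Hypothetical smoke-test overrides for the settings block above:
batch_size_per_gpu = 4800        # 4800 * 256 / 24000 = 51.2 s of audio per batch
batch_size_type = "frame"
max_samples = 16
grad_accumulation_steps = 1
epochs = 1
num_warmup_updates = 200
save_per_updates = 1000
last_per_steps = 500

Frame counts translate back to audio time exactly as in the note above (frames * hop_length / target_sample_rate), so the schedule can be scaled to whatever hardware is available.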
