kevinwang676 committed on
Commit 8ad4e11 · verified · 1 Parent(s): 784a8e5

Add files using upload-large-folder tool

Files changed (43)
  1. .github/ISSUE_TEMPLATE/bug_report.md +38 -0
  2. .github/ISSUE_TEMPLATE/feature_request.md +20 -0
  3. .github/workflows/lint.yml +56 -0
  4. .github/workflows/stale-issues.yml +22 -0
  5. .gitignore +52 -0
  6. .gitmodules +3 -0
  7. CODE_OF_CONDUCT.md +76 -0
  8. FAQ.md +16 -0
  9. LICENSE +201 -0
  10. README.md +237 -3
  11. asset/dingding.png +0 -0
  12. cosyvoice/__init__.py +0 -0
  13. cosyvoice/bin/export_jit.py +91 -0
  14. cosyvoice/bin/export_onnx.py +116 -0
  15. cosyvoice/bin/export_trt.sh +10 -0
  16. cosyvoice/bin/inference.py +115 -0
  17. cosyvoice/bin/train.py +170 -0
  18. cosyvoice/cli/cosyvoice.py +173 -0
  19. cosyvoice/cli/model.py +411 -0
  20. cosyvoice/dataset/__init__.py +0 -0
  21. cosyvoice/dataset/dataset.py +164 -0
  22. cosyvoice/flow/decoder.py +301 -0
  23. cosyvoice/flow/flow.py +239 -0
  24. cosyvoice/flow/flow_matching.py +217 -0
  25. cosyvoice/flow/length_regulator.py +69 -0
  26. cosyvoice/hifigan/discriminator.py +140 -0
  27. cosyvoice/hifigan/f0_predictor.py +55 -0
  28. cosyvoice/hifigan/generator.py +411 -0
  29. cosyvoice/llm/llm.py +434 -0
  30. cosyvoice/transformer/__init__.py +0 -0
  31. cosyvoice/transformer/attention.py +330 -0
  32. cosyvoice/transformer/convolution.py +145 -0
  33. cosyvoice/transformer/decoder.py +396 -0
  34. cosyvoice/transformer/decoder_layer.py +132 -0
  35. cosyvoice/utils/__init__.py +0 -0
  36. cosyvoice/utils/class_utils.py +83 -0
  37. cosyvoice/utils/scheduler.py +738 -0
  38. examples/libritts/cosyvoice/local/prepare_data.py +53 -0
  39. examples/libritts/cosyvoice/path.sh +3 -0
  40. requirements.txt +38 -0
  41. runtime/python/Dockerfile +13 -0
  42. runtime/python/grpc/cosyvoice.proto +43 -0
  43. webui.py +200 -0
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,38 @@
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+ Steps to reproduce the behavior:
15
+ 1. Go to '...'
16
+ 2. Click on '....'
17
+ 3. Scroll down to '....'
18
+ 4. See error
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Screenshots**
24
+ If applicable, add screenshots to help explain your problem.
25
+
26
+ **Desktop (please complete the following information):**
27
+ - OS: [e.g. iOS]
28
+ - Browser [e.g. chrome, safari]
29
+ - Version [e.g. 22]
30
+
31
+ **Smartphone (please complete the following information):**
32
+ - Device: [e.g. iPhone6]
33
+ - OS: [e.g. iOS8.1]
34
+ - Browser [e.g. stock browser, safari]
35
+ - Version [e.g. 22]
36
+
37
+ **Additional context**
38
+ Add any other context about the problem here.
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12
+
13
+ **Describe the solution you'd like**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Describe alternatives you've considered**
17
+ A clear and concise description of any alternative solutions or features you've considered.
18
+
19
+ **Additional context**
20
+ Add any other context or screenshots about the feature request here.
.github/workflows/lint.yml ADDED
@@ -0,0 +1,56 @@
1
+ name: Lint
2
+
3
+ on:
4
+ pull_request:
5
+ push:
6
+
7
+ jobs:
8
+ quick-checks:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - name: Fetch CosyVoice
12
+ uses: actions/checkout@v1
13
+ - name: Checkout PR tip
14
+ run: |
15
+ set -eux
16
+ if [[ "${{ github.event_name }}" == "pull_request" ]]; then
17
+ # We are on a PR, so actions/checkout leaves us on a merge commit.
18
+ # Check out the actual tip of the branch.
19
+ git checkout ${{ github.event.pull_request.head.sha }}
20
+ fi
21
+ echo ::set-output name=commit_sha::$(git rev-parse HEAD)
22
+ id: get_pr_tip
23
+ - name: Ensure no tabs
24
+ run: |
25
+ (! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
26
+ - name: Ensure no trailing whitespace
27
+ run: |
28
+ (! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))
29
+
30
+ flake8-py3:
31
+ runs-on: ubuntu-latest
32
+ steps:
33
+ - name: Setup Python
34
+ uses: actions/setup-python@v1
35
+ with:
36
+ python-version: 3.9
37
+ architecture: x64
38
+ - name: Fetch CosyVoice
39
+ uses: actions/checkout@v1
40
+ - name: Checkout PR tip
41
+ run: |
42
+ set -eux
43
+ if [[ "${{ github.event_name }}" == "pull_request" ]]; then
44
+ # We are on a PR, so actions/checkout leaves us on a merge commit.
45
+ # Check out the actual tip of the branch.
46
+ git checkout ${{ github.event.pull_request.head.sha }}
47
+ fi
48
+ echo ::set-output name=commit_sha::$(git rev-parse HEAD)
49
+ id: get_pr_tip
50
+ - name: Run flake8
51
+ run: |
52
+ set -eux
53
+ pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
54
+ flake8 --version
55
+ flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
56
+ if [ $? != 0 ]; then exit 1; fi
.github/workflows/stale-issues.yml ADDED
@@ -0,0 +1,22 @@
1
+ name: Close inactive issues
2
+ on:
3
+ schedule:
4
+ - cron: "30 1 * * *"
5
+
6
+ jobs:
7
+ close-issues:
8
+ runs-on: ubuntu-latest
9
+ permissions:
10
+ issues: write
11
+ pull-requests: write
12
+ steps:
13
+ - uses: actions/stale@v5
14
+ with:
15
+ days-before-issue-stale: 30
16
+ days-before-issue-close: 14
17
+ stale-issue-label: "stale"
18
+ stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
19
+ close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
20
+ days-before-pr-stale: -1
21
+ days-before-pr-close: -1
22
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,52 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Visual Studio Code files
7
+ .vscode
8
+ .vs
9
+
10
+ # PyCharm files
11
+ .idea
12
+
13
+ # Eclipse Project settings
14
+ *.*project
15
+ .settings
16
+
17
+ # Sublime Text settings
18
+ *.sublime-workspace
19
+ *.sublime-project
20
+
21
+ # Editor temporaries
22
+ *.swn
23
+ *.swo
24
+ *.swp
25
+ *.swm
26
+ *~
27
+
28
+ # IPython notebook checkpoints
29
+ .ipynb_checkpoints
30
+
31
+ # macOS dir files
32
+ .DS_Store
33
+
34
+ exp
35
+ data
36
+ raw_wav
37
+ tensorboard
38
+ **/*build*
39
+
40
+ # Clangd files
41
+ .cache
42
+ compile_commands.json
43
+
44
+ # train/inference files
45
+ *.wav
46
+ *.m4a
47
+ *.aac
48
+ *.pt
49
+ pretrained_models/*
50
+ *_pb2_grpc.py
51
+ *_pb2.py
52
+ *.tar
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "third_party/Matcha-TTS"]
2
+ path = third_party/Matcha-TTS
3
+ url = https://github.com/shivammehta25/Matcha-TTS.git
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,76 @@
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to making participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies both within project spaces and in public spaces
49
+ when an individual is representing the project or its community. Examples of
50
+ representing a project or community include using an official project e-mail
51
+ address, posting via an official social media account, or acting as an appointed
52
+ representative at an online or offline event. Representation of a project may be
53
+ further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at [email protected]. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
FAQ.md ADDED
@@ -0,0 +1,16 @@
1
+ ## ModuleNotFoundError: No module named 'matcha'
2
+
3
+ Matcha-TTS is a third-party submodule. Please check the `third_party` directory. If `Matcha-TTS` is missing, run `git submodule update --init --recursive`.
4
+
5
+ Run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in a Python script.
6
+
7
+ ## cannot find resource.zip or cannot unzip resource.zip
8
+
9
+ Please make sure you have git-lfs installed, then execute:
10
+
11
+ ```sh
12
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
13
+ cd pretrained_models/CosyVoice-ttsfrd/
14
+ unzip resource.zip -d .
15
+ pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
16
+ ```
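As an alternative to exporting `PYTHONPATH`, the same effect can be achieved from inside a script. This is a minimal sketch mirroring the `sys.path.append('third_party/Matcha-TTS')` pattern used in the README below; it only works once the submodule has actually been fetched:

```python
import sys

# make the Matcha-TTS submodule importable without exporting PYTHONPATH
sys.path.append('third_party/Matcha-TTS')

from cosyvoice.cli.cosyvoice import CosyVoice
```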
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,237 @@
1
- ---
2
- license: mit
3
- ---
1
+ [![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
2
+
3
+ ## 👉🏻 CosyVoice 👈🏻
4
+ **CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
5
+
6
+ **CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
7
+
8
+ ## Highlights 🔥
9
+
10
+ **CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and higher-quality speech generation.
11
+ ### Multilingual
12
+ - **Supported Languages**: Chinese, English, Japanese, Korean, and Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
13
+ - **Cross-lingual & Mixed-lingual**: Supports zero-shot voice cloning for cross-lingual and code-switching scenarios.
14
+ ### Ultra-Low Latency
15
+ - **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
16
+ - **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
17
+ ### High Accuracy
18
+ - **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
19
+ - **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
20
+ ### Strong Stability
21
+ - **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
22
+ - **Cross-language Synthesis**: Marked improvements compared to version 1.0.
23
+ ### Natural Experience
24
+ - **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
25
+ - **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
26
+
27
+ ## Roadmap
28
+
29
+ - [x] 2024/12
30
+
31
+ - [x] 25 Hz CosyVoice 2.0 released
32
+
33
+ - [x] 2024/09
34
+
35
+ - [x] 25 Hz CosyVoice base model
36
+ - [x] 25 Hz CosyVoice voice conversion model
37
+
38
+ - [x] 2024/08
39
+
40
+ - [x] Repetition Aware Sampling (RAS) inference for LLM stability
41
+ - [x] Streaming inference mode support, including KV cache and SDPA for RTF optimization
42
+
43
+ - [x] 2024/07
44
+
45
+ - [x] Flow matching training support
46
+ - [x] WeTextProcessing support when ttsfrd is not available
47
+ - [x] FastAPI server and client
48
+
49
+
50
+ ## Install
51
+
52
+ **Clone and install**
53
+
54
+ - Clone the repo
55
+ ``` sh
56
+ git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
57
+ # If you failed to clone the submodule due to network issues, re-run the following command until it succeeds
58
+ cd CosyVoice
59
+ git submodule update --init --recursive
60
+ ```
61
+
62
+ - Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
63
+ - Create Conda env:
64
+
65
+ ``` sh
66
+ conda create -n cosyvoice -y python=3.10
67
+ conda activate cosyvoice
68
+ # pynini is required by WeTextProcessing; use conda to install it, as it can be executed on all platforms.
69
+ conda install -y -c conda-forge pynini==2.1.5
70
+ pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
71
+
72
+ # If you encounter sox compatibility issues
73
+ # ubuntu
74
+ sudo apt-get install sox libsox-dev
75
+ # centos
76
+ sudo yum install sox sox-devel
77
+ ```
78
+
79
+ **Model download**
80
+
81
+ We strongly recommend that you download our pretrained `CosyVoice2-0.5B`, `CosyVoice-300M`, `CosyVoice-300M-SFT`, and `CosyVoice-300M-Instruct` models and the `CosyVoice-ttsfrd` resource.
82
+
83
+ ``` python
84
+ # Download models with the ModelScope SDK
85
+ from modelscope import snapshot_download
86
+ snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
87
+ snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
88
+ snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
89
+ snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
90
+ snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
91
+ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
92
+ ```
93
+
94
+ ``` sh
95
+ # Download models with git; make sure git-lfs is installed
96
+ mkdir -p pretrained_models
97
+ git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
98
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
99
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
100
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
101
+ git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
102
+ git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
103
+ ```
104
+
105
+ Optionally, you can unzip the `ttsfrd` resource and install the `ttsfrd` package for better text normalization performance.
106
+
107
+ Note that this step is optional. If you do not install the `ttsfrd` package, WeTextProcessing is used by default.
108
+
109
+ ``` sh
110
+ cd pretrained_models/CosyVoice-ttsfrd/
111
+ unzip resource.zip -d .
112
+ pip install ttsfrd_dependency-0.1-py3-none-any.whl
113
+ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
114
+ ```
115
+
116
+ **Basic Usage**
117
+
118
+ We strongly recommend using `CosyVoice2-0.5B` for better performance.
119
+ Follow the code below for detailed usage of each model.
120
+
121
+ ``` python
122
+ import sys
123
+ sys.path.append('third_party/Matcha-TTS')
124
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
125
+ from cosyvoice.utils.file_utils import load_wav
126
+ import torchaudio
127
+ ```
128
+
129
+ **CosyVoice2 Usage**
130
+ ```python
131
+ cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
132
+
133
+ # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
134
+ # zero_shot usage
135
+ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
136
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
137
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
138
+
139
+ # fine-grained control; for the supported control tokens, check cosyvoice/tokenizer/tokenizer.py#L248
140
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
141
+ torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
142
+
143
+ # instruct usage
144
+ for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
145
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
146
+
147
+ # bistream usage: you can pass a generator as input, which is useful when the text comes from a text LLM
148
+ # NOTE you should still apply some basic sentence splitting, because the LLM cannot handle arbitrary sentence lengths
149
+ def text_generator():
150
+ yield '收到好友从远方寄来的生日礼物,'
151
+ yield '那份意外的惊喜与深深的祝福'
152
+ yield '让我心中充满了甜蜜的快乐,'
153
+ yield '笑容如花儿般绽放。'
154
+ for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
155
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
156
+ ```
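The streaming mode mentioned in the Highlights is driven through the same API: passing `stream=True` makes the `inference_*` calls yield audio chunks incrementally instead of a single final waveform. A minimal sketch, reusing `cosyvoice` and `prompt_speech_16k` from the block above (saving each chunk to its own file is purely illustrative):

```python
# streaming sketch: chunks are yielded as soon as they are synthesized
for i, chunk in enumerate(cosyvoice.inference_zero_shot(
        '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
        '希望你以后能够做的比我还好呦。',
        prompt_speech_16k,
        stream=True)):
    torchaudio.save('zero_shot_stream_{}.wav'.format(i), chunk['tts_speech'], cosyvoice.sample_rate)
```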
157
+
158
+ **CosyVoice Usage**
159
+ ```python
160
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
161
+ # sft usage
162
+ print(cosyvoice.list_available_spks())
163
+ # change stream=True for chunk stream inference
164
+ for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
165
+ torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
166
+
167
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
168
+ # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
169
+ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
170
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
171
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
172
+ # cross_lingual usage
173
+ prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
174
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
175
+ torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
176
+ # vc usage
177
+ prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
178
+ source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
179
+ for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
180
+ torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
181
+
182
+ cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
183
+ # instruct usage; supports <laughter></laughter>, <strong></strong>, [laughter], [breath]
184
+ for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
185
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
186
+ ```
187
+
188
+ **Start web demo**
189
+
190
+ You can use our web demo page to get familiar with CosyVoice quickly.
191
+
192
+ Please see the demo website for details.
193
+
194
+ ``` sh
195
+ # change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
196
+ python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
197
+ ```
198
+
199
+ **Advanced Usage**
200
+
201
+ For advanced users, we provide training and inference scripts in `examples/libritts/cosyvoice/run.sh`.
202
+
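For orientation, a minimal sketch of how such a recipe is typically launched; the actual stages and options are defined inside `run.sh` itself (not shown in this commit), so treat this invocation as an assumption rather than the documented command:

``` sh
# run the LibriTTS recipe from its own directory so relative paths (path.sh, conf, data dirs) resolve
cd examples/libritts/cosyvoice
bash run.sh
```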
203
+ **Build for deployment**
204
+
205
+ Optionally, if you want to deploy CosyVoice as a service,
206
+ you can run the following steps.
207
+
208
+ ``` sh
209
+ cd runtime/python
210
+ docker build -t cosyvoice:v1.0 .
211
+ # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
212
+ # for grpc usage
213
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
214
+ cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
215
+ # for fastapi usage
216
+ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
217
+ cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
218
+ ```
219
+
220
+ ## Discussion & Communication
221
+
222
+ You can discuss directly on [GitHub Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
223
+
224
+ You can also scan the QR code to join our official Dingding chat group.
225
+
226
+ <img src="./asset/dingding.png" width="250px">
227
+
228
+ ## Acknowledgements
229
+
230
+ 1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
231
+ 2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
232
+ 3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
233
+ 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
234
+ 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
235
+
236
+ ## Disclaimer
237
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
asset/dingding.png ADDED
cosyvoice/__init__.py ADDED
File without changes
cosyvoice/bin/export_jit.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import sys
22
+ import torch
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/../..'.format(ROOT_DIR))
25
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
27
+
28
+
29
+ def get_args():
30
+ parser = argparse.ArgumentParser(description='export your model for deployment')
31
+ parser.add_argument('--model_dir',
32
+ type=str,
33
+ default='pretrained_models/CosyVoice-300M',
34
+ help='local path')
35
+ args = parser.parse_args()
36
+ print(args)
37
+ return args
38
+
39
+
40
+ def get_optimized_script(model, preserved_attrs=[]):
41
+ script = torch.jit.script(model)
42
+ if preserved_attrs != []:
43
+ script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
44
+ else:
45
+ script = torch.jit.freeze(script)
46
+ script = torch.jit.optimize_for_inference(script)
47
+ return script
48
+
49
+
50
+ def main():
51
+ args = get_args()
52
+ logging.basicConfig(level=logging.DEBUG,
53
+ format='%(asctime)s %(levelname)s %(message)s')
54
+
55
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
56
+ torch._C._jit_set_profiling_mode(False)
57
+ torch._C._jit_set_profiling_executor(False)
58
+
59
+ try:
60
+ model = CosyVoice(args.model_dir)
61
+ except Exception:
62
+ try:
63
+ model = CosyVoice2(args.model_dir)
64
+ except Exception:
65
+ raise TypeError('no valid model_type!')
66
+
67
+ if not isinstance(model, CosyVoice2):
68
+ # 1. export llm text_encoder
69
+ llm_text_encoder = model.model.llm.text_encoder
70
+ script = get_optimized_script(llm_text_encoder)
71
+ script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
72
+ script = get_optimized_script(llm_text_encoder.half())
73
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
74
+
75
+ # 2. export llm llm
76
+ llm_llm = model.model.llm.llm
77
+ script = get_optimized_script(llm_llm, ['forward_chunk'])
78
+ script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
79
+ script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
80
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
81
+
82
+ # 3. export flow encoder
83
+ flow_encoder = model.model.flow.encoder
84
+ script = get_optimized_script(flow_encoder)
85
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
86
+ script = get_optimized_script(flow_encoder.half())
87
+ script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import print_function
17
+
18
+ import argparse
19
+ import logging
20
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
21
+ import os
22
+ import sys
23
+ import onnxruntime
24
+ import random
25
+ import torch
26
+ from tqdm import tqdm
27
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28
+ sys.path.append('{}/../..'.format(ROOT_DIR))
29
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
31
+
32
+
33
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
34
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
35
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
36
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
37
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
38
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
39
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
40
+ return x, mask, mu, t, spks, cond
41
+
42
+
43
+ def get_args():
44
+ parser = argparse.ArgumentParser(description='export your model for deployment')
45
+ parser.add_argument('--model_dir',
46
+ type=str,
47
+ default='pretrained_models/CosyVoice-300M',
48
+ help='local path')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+
59
+ try:
60
+ model = CosyVoice(args.model_dir)
61
+ except Exception:
62
+ try:
63
+ model = CosyVoice2(args.model_dir)
64
+ except Exception:
65
+ raise TypeError('no valid model_type!')
66
+
67
+ # 1. export flow decoder estimator
68
+ estimator = model.model.flow.decoder.estimator
69
+
70
+ device = model.model.device
71
+ batch_size, seq_len = 2, 256
72
+ out_channels = model.model.flow.decoder.estimator.out_channels
73
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
74
+ torch.onnx.export(
75
+ estimator,
76
+ (x, mask, mu, t, spks, cond),
77
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
78
+ export_params=True,
79
+ opset_version=18,
80
+ do_constant_folding=True,
81
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
82
+ output_names=['estimator_out'],
83
+ dynamic_axes={
84
+ 'x': {2: 'seq_len'},
85
+ 'mask': {2: 'seq_len'},
86
+ 'mu': {2: 'seq_len'},
87
+ 'cond': {2: 'seq_len'},
88
+ 'estimator_out': {2: 'seq_len'},
89
+ }
90
+ )
91
+
92
+ # 2. test computation consistency
93
+ option = onnxruntime.SessionOptions()
94
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
95
+ option.intra_op_num_threads = 1
96
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
97
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
98
+ sess_options=option, providers=providers)
99
+
100
+ for _ in tqdm(range(10)):
101
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
102
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
103
+ ort_inputs = {
104
+ 'x': x.cpu().numpy(),
105
+ 'mask': mask.cpu().numpy(),
106
+ 'mu': mu.cpu().numpy(),
107
+ 't': t.cpu().numpy(),
108
+ 'spks': spks.cpu().numpy(),
109
+ 'cond': cond.cpu().numpy()
110
+ }
111
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
112
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
113
+
114
+
115
+ if __name__ == "__main__":
116
+ main()
cosyvoice/bin/export_trt.sh ADDED
@@ -0,0 +1,10 @@
1
+ #!/bin/bash
2
+ # Copyright 2024 Alibaba Inc. All Rights Reserved.
3
+ # download TensorRT from https://developer.nvidia.com/tensorrt/download/10x; check your system and CUDA version for compatibility
4
+ # for example, for Linux + CUDA 12.4 you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
5
+ TRT_DIR=<YOUR_TRT_DIR>
6
+ MODEL_DIR=<COSYVOICE2_MODEL_DIR>
7
+
8
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
9
+ $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
10
+ $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+ import torch
22
+ from torch.utils.data import DataLoader
23
+ import torchaudio
24
+ from hyperpyyaml import load_hyperpyyaml
25
+ from tqdm import tqdm
26
+ from cosyvoice.cli.model import CosyVoiceModel
27
+ from cosyvoice.dataset.dataset import Dataset
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser(description='inference with your model')
32
+ parser.add_argument('--config', required=True, help='config file')
33
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
34
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35
+ parser.add_argument('--tts_text', required=True, help='tts input file')
36
+ parser.add_argument('--llm_model', required=True, help='llm model file')
37
+ parser.add_argument('--flow_model', required=True, help='flow model file')
38
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
39
+ parser.add_argument('--gpu',
40
+ type=int,
41
+ default=-1,
42
+ help='gpu id for this rank, -1 for cpu')
43
+ parser.add_argument('--mode',
44
+ default='sft',
45
+ choices=['sft', 'zero_shot'],
46
+ help='inference mode')
47
+ parser.add_argument('--result_dir', required=True, help='asr result file')
48
+ args = parser.parse_args()
49
+ print(args)
50
+ return args
51
+
52
+
53
+ def main():
54
+ args = get_args()
55
+ logging.basicConfig(level=logging.DEBUG,
56
+ format='%(asctime)s %(levelname)s %(message)s')
57
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
58
+
59
+ # Init cosyvoice models from configs
60
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
61
+ device = torch.device('cuda' if use_cuda else 'cpu')
62
+ with open(args.config, 'r') as f:
63
+ configs = load_hyperpyyaml(f)
64
+
65
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
66
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
67
+
68
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
69
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for _, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only supports batch size 1"
80
+ text_token = batch["text_token"].to(device)
81
+ text_token_len = batch["text_token_len"].to(device)
82
+ tts_index = batch["tts_index"]
83
+ tts_text_token = batch["tts_text_token"].to(device)
84
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
85
+ speech_token = batch["speech_token"].to(device)
86
+ speech_token_len = batch["speech_token_len"].to(device)
87
+ speech_feat = batch["speech_feat"].to(device)
88
+ speech_feat_len = batch["speech_feat_len"].to(device)
89
+ utt_embedding = batch["utt_embedding"].to(device)
90
+ spk_embedding = batch["spk_embedding"].to(device)
91
+ if args.mode == 'sft':
92
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
93
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
94
+ else:
95
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
96
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
97
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
98
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
99
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
100
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
101
+ tts_speeches = []
102
+ for model_output in model.tts(**model_input):
103
+ tts_speeches.append(model_output['tts_speech'])
104
+ tts_speeches = torch.concat(tts_speeches, dim=1)
105
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
106
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
107
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
108
+ f.write('{} {}\n'.format(tts_key, tts_fn))
109
+ f.flush()
110
+ f.close()
111
+ logging.info('Result wav.scp saved in {}'.format(fn))
112
+
113
+
114
+ if __name__ == '__main__':
115
+ main()
cosyvoice/bin/train.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import os
22
+ import torch
23
+ import torch.distributed as dist
24
+ import deepspeed
25
+
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ from torch.distributed.elastic.multiprocessing.errors import record
29
+
30
+ from cosyvoice.utils.executor import Executor
31
+ from cosyvoice.utils.train_utils import (
32
+ init_distributed,
33
+ init_dataset_and_dataloader,
34
+ init_optimizer_and_scheduler,
35
+ init_summarywriter, save_model,
36
+ wrap_cuda_model, check_modify_and_save_config)
37
+
38
+
39
+ def get_args():
40
+ parser = argparse.ArgumentParser(description='training your network')
41
+ parser.add_argument('--train_engine',
42
+ default='torch_ddp',
43
+ choices=['torch_ddp', 'deepspeed'],
44
+ help='Engine for paralleled training')
45
+ parser.add_argument('--model', required=True, help='model which will be trained')
46
+ parser.add_argument('--config', required=True, help='config file')
47
+ parser.add_argument('--train_data', required=True, help='train data file')
48
+ parser.add_argument('--cv_data', required=True, help='cv data file')
49
+ parser.add_argument('--checkpoint', help='checkpoint model')
50
+ parser.add_argument('--model_dir', required=True, help='save model dir')
51
+ parser.add_argument('--tensorboard_dir',
52
+ default='tensorboard',
53
+ help='tensorboard log dir')
54
+ parser.add_argument('--ddp.dist_backend',
55
+ dest='dist_backend',
56
+ default='nccl',
57
+ choices=['nccl', 'gloo'],
58
+ help='distributed backend')
59
+ parser.add_argument('--num_workers',
60
+ default=0,
61
+ type=int,
62
+ help='num of subprocess workers for reading')
63
+ parser.add_argument('--prefetch',
64
+ default=100,
65
+ type=int,
66
+ help='prefetch number')
67
+ parser.add_argument('--pin_memory',
68
+ action='store_true',
69
+ default=False,
70
+ help='Use pinned memory buffers used for reading')
71
+ parser.add_argument('--use_amp',
72
+ action='store_true',
73
+ default=False,
74
+ help='Use automatic mixed precision training')
75
+ parser.add_argument('--deepspeed.save_states',
76
+ dest='save_states',
77
+ default='model_only',
78
+ choices=['model_only', 'model+optimizer'],
79
+ help='save model/optimizer states')
80
+ parser.add_argument('--timeout',
81
+ default=60,
82
+ type=int,
83
+ help='timeout (in seconds) of cosyvoice_join.')
84
+ parser = deepspeed.add_config_arguments(parser)
85
+ args = parser.parse_args()
86
+ return args
87
+
88
+
89
+ @record
90
+ def main():
91
+ args = get_args()
92
+ logging.basicConfig(level=logging.DEBUG,
93
+ format='%(asctime)s %(levelname)s %(message)s')
94
+ # GAN training has some special initialization logic
95
+ gan = True if args.model == 'hifigan' else False
96
+
97
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
98
+ if gan is True:
99
+ override_dict.pop('hift')
100
+ with open(args.config, 'r') as f:
101
+ configs = load_hyperpyyaml(f, overrides=override_dict)
102
+ if gan is True:
103
+ configs['train_conf'] = configs['train_conf_gan']
104
+ configs['train_conf'].update(vars(args))
105
+
106
+ # Init env for ddp
107
+ init_distributed(args)
108
+
109
+ # Get dataset & dataloader
110
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
111
+ init_dataset_and_dataloader(args, configs, gan)
112
+
113
+ # Do some sanity checks and save config to args.model_dir
114
+ configs = check_modify_and_save_config(args, configs)
115
+
116
+ # Tensorboard summary
117
+ writer = init_summarywriter(args)
118
+
119
+ # load checkpoint
120
+ model = configs[args.model]
121
+ start_step, start_epoch = 0, -1
122
+ if args.checkpoint is not None:
123
+ if os.path.exists(args.checkpoint):
124
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
125
+ model.load_state_dict(state_dict, strict=False)
126
+ if 'step' in state_dict:
127
+ start_step = state_dict['step']
128
+ if 'epoch' in state_dict:
129
+ start_epoch = state_dict['epoch']
130
+ else:
131
+ logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
132
+
133
+ # Dispatch model from cpu to gpu
134
+ model = wrap_cuda_model(args, model)
135
+
136
+ # Get optimizer & scheduler
137
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
138
+ scheduler.set_step(start_step)
139
+ if scheduler_d is not None:
140
+ scheduler_d.set_step(start_step)
141
+
142
+ # Save init checkpoints
143
+ info_dict = deepcopy(configs['train_conf'])
144
+ info_dict['step'] = start_step
145
+ info_dict['epoch'] = start_epoch
146
+ save_model(model, 'init', info_dict)
147
+
148
+ # Get executor
149
+ executor = Executor(gan=gan)
150
+ executor.step = start_step
151
+
152
+ # Init scaler, used for pytorch amp mixed precision training
153
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
154
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
155
+ # Start training loop
156
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
157
+ executor.epoch = epoch
158
+ train_dataset.set_epoch(epoch)
159
+ dist.barrier()
160
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
161
+ if gan is True:
162
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
163
+ writer, info_dict, scaler, group_join)
164
+ else:
165
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
166
+ dist.destroy_process_group(group_join)
167
+
168
+
169
+ if __name__ == '__main__':
170
+ main()
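
Note on the config handling above: only the sub-model named by --model is kept in the hyperpyyaml config; every other section is overridden to None, and for GAN training the hift section is re-enabled. A minimal sketch of that selection logic, assuming a hypothetical choice of --model flow:

# Sketch of the override selection in main(); 'flow' is a hypothetical model choice.
model = 'flow'
gan = (model == 'hifigan')
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != model}
if gan:
    override_dict.pop('hift')  # keep the hift section when training hifigan
print(override_dict)  # -> {'llm': None, 'hift': None, 'hifigan': None}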
cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import time
16
+ from typing import Generator
17
+ from tqdm import tqdm
18
+ from hyperpyyaml import load_hyperpyyaml
19
+ from modelscope import snapshot_download
20
+ import torch
21
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
22
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
23
+ from cosyvoice.utils.file_utils import logging
24
+ from cosyvoice.utils.class_utils import get_model_type
25
+
26
+
27
+ class CosyVoice:
28
+
29
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
30
+ self.instruct = True if '-Instruct' in model_dir else False
31
+ self.model_dir = model_dir
32
+ self.fp16 = fp16
33
+ if not os.path.exists(model_dir):
34
+ model_dir = snapshot_download(model_dir)
35
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
36
+ configs = load_hyperpyyaml(f)
37
+ assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
38
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
39
+ configs['feat_extractor'],
40
+ '{}/campplus.onnx'.format(model_dir),
41
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
42
+ '{}/spk2info.pt'.format(model_dir),
43
+ configs['allowed_special'])
44
+ self.sample_rate = configs['sample_rate']
45
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
46
+ load_jit, load_trt, fp16 = False, False, False
47
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
48
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
49
+ self.model.load('{}/llm.pt'.format(model_dir),
50
+ '{}/flow.pt'.format(model_dir),
51
+ '{}/hift.pt'.format(model_dir))
52
+ if load_jit:
53
+ self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
54
+ '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
55
+ '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
56
+ if load_trt:
57
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
58
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
59
+ self.fp16)
60
+ del configs
61
+
62
+ def list_available_spks(self):
63
+ spks = list(self.frontend.spk2info.keys())
64
+ return spks
65
+
66
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
67
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
68
+ model_input = self.frontend.frontend_sft(i, spk_id)
69
+ start_time = time.time()
70
+ logging.info('synthesis text {}'.format(i))
71
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
72
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
73
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
74
+ yield model_output
75
+ start_time = time.time()
76
+
77
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
78
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
79
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
80
+ if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
81
+ logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
82
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
83
+ start_time = time.time()
84
+ logging.info('synthesis text {}'.format(i))
85
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
86
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
87
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
88
+ yield model_output
89
+ start_time = time.time()
90
+
91
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
92
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
93
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
94
+ start_time = time.time()
95
+ logging.info('synthesis text {}'.format(i))
96
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
97
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
98
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
99
+ yield model_output
100
+ start_time = time.time()
101
+
102
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
103
+ assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
104
+ if self.instruct is False:
105
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
106
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
107
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
108
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
109
+ start_time = time.time()
110
+ logging.info('synthesis text {}'.format(i))
111
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
112
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
113
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
114
+ yield model_output
115
+ start_time = time.time()
116
+
117
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
118
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
119
+ start_time = time.time()
120
+ for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
121
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
122
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
123
+ yield model_output
124
+ start_time = time.time()
125
+
126
+
127
+ class CosyVoice2(CosyVoice):
128
+
129
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
130
+ self.instruct = True if '-Instruct' in model_dir else False
131
+ self.model_dir = model_dir
132
+ self.fp16 = fp16
133
+ if not os.path.exists(model_dir):
134
+ model_dir = snapshot_download(model_dir)
135
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
136
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
137
+ assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
138
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
139
+ configs['feat_extractor'],
140
+ '{}/campplus.onnx'.format(model_dir),
141
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
142
+ '{}/spk2info.pt'.format(model_dir),
143
+ configs['allowed_special'])
144
+ self.sample_rate = configs['sample_rate']
145
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
146
+ load_jit, load_trt, fp16 = False, False, False
147
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
148
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
149
+ self.model.load('{}/llm.pt'.format(model_dir),
150
+ '{}/flow.pt'.format(model_dir),
151
+ '{}/hift.pt'.format(model_dir))
152
+ if load_jit:
153
+ self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
154
+ if load_trt:
155
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
156
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
157
+ self.fp16)
158
+ del configs
159
+
160
+ def inference_instruct(self, *args, **kwargs):
161
+ raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
162
+
163
+ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
164
+ assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
165
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
166
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
167
+ start_time = time.time()
168
+ logging.info('synthesis text {}'.format(i))
169
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
170
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
171
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
172
+ yield model_output
173
+ start_time = time.time()
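
A minimal usage sketch of the wrappers defined above. The model directory, prompt wav, and transcript are hypothetical placeholders, and load_wav is assumed to be the helper in cosyvoice.utils.file_utils; each yielded dict holds a (1, samples) waveform at self.sample_rate.

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav  # assumed helper for reading/resampling wavs

# hypothetical model dir and prompt audio
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
prompt_speech_16k = load_wav('prompt.wav', 16000)
for idx, out in enumerate(cosyvoice.inference_zero_shot('Hello, this is a test sentence.',
                                                        'transcript of prompt.wav',
                                                        prompt_speech_16k)):
    torchaudio.save('zero_shot_{}.wav'.format(idx), out['tts_speech'], cosyvoice.sample_rate)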
cosyvoice/cli/model.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Generator
16
+ import torch
17
+ import numpy as np
18
+ import threading
19
+ import time
20
+ from torch.nn import functional as F
21
+ from contextlib import nullcontext
22
+ import uuid
23
+ from cosyvoice.utils.common import fade_in_out
24
+ from cosyvoice.utils.file_utils import convert_onnx_to_trt
25
+
26
+
27
+ class CosyVoiceModel:
28
+
29
+ def __init__(self,
30
+ llm: torch.nn.Module,
31
+ flow: torch.nn.Module,
32
+ hift: torch.nn.Module,
33
+ fp16: bool):
34
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+ self.llm = llm
36
+ self.flow = flow
37
+ self.hift = hift
38
+ self.fp16 = fp16
39
+ self.llm.fp16 = fp16
40
+ self.flow.fp16 = fp16
41
+ if self.fp16 is True:
42
+ self.llm.half()
43
+ self.flow.half()
44
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
45
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
46
+ self.token_overlap_len = 20
47
+ # here we set flow.decoder.estimator.static_chunk_size = 0 for compatibility
48
+ self.flow.decoder.estimator.static_chunk_size = 0
49
+ # mel fade in out
50
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
51
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
52
+ # hift cache
53
+ self.mel_cache_len = 20
54
+ self.source_cache_len = int(self.mel_cache_len * 256)
55
+ # speech fade in out
56
+ self.speech_window = np.hamming(2 * self.source_cache_len)
57
+ # rtf and decoding related
58
+ self.stream_scale_factor = 1
59
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
60
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
61
+ self.lock = threading.Lock()
62
+ # dict used to store session related variable
63
+ self.tts_speech_token_dict = {}
64
+ self.llm_end_dict = {}
65
+ self.mel_overlap_dict = {}
66
+ self.flow_cache_dict = {}
67
+ self.hift_cache_dict = {}
68
+
69
+ def load(self, llm_model, flow_model, hift_model):
70
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
71
+ self.llm.to(self.device).eval()
72
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
73
+ self.flow.to(self.device).eval()
74
+ # in case hift_model is a hifigan model
75
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
76
+ self.hift.load_state_dict(hift_state_dict, strict=True)
77
+ self.hift.to(self.device).eval()
78
+
79
+ def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
80
+ llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
81
+ self.llm.text_encoder = llm_text_encoder
82
+ llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
83
+ self.llm.llm = llm_llm
84
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
85
+ self.flow.encoder = flow_encoder
86
+
87
+ def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
88
+ assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
89
+ if not os.path.exists(flow_decoder_estimator_model):
90
+ convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
91
+ if os.path.getsize(flow_decoder_estimator_model) == 0:
92
+ raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
93
+ del self.flow.decoder.estimator
94
+ import tensorrt as trt
95
+ with open(flow_decoder_estimator_model, 'rb') as f:
96
+ self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
97
+ if self.flow.decoder.estimator_engine is None:
98
+ raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
99
+ self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
100
+
101
+ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
102
+ with self.llm_context:
103
+ if isinstance(text, Generator):
104
+ assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
105
+ for i in self.llm.inference_bistream(text=text,
106
+ prompt_text=prompt_text.to(self.device),
107
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
108
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
109
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
110
+ embedding=llm_embedding.to(self.device)):
111
+ self.tts_speech_token_dict[uuid].append(i)
112
+ else:
113
+ for i in self.llm.inference(text=text.to(self.device),
114
+ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
115
+ prompt_text=prompt_text.to(self.device),
116
+ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
117
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
118
+ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
119
+ embedding=llm_embedding.to(self.device)):
120
+ self.tts_speech_token_dict[uuid].append(i)
121
+ self.llm_end_dict[uuid] = True
122
+
123
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
124
+ tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
125
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
126
+ prompt_token=prompt_token.to(self.device),
127
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
128
+ prompt_feat=prompt_feat.to(self.device),
129
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
130
+ embedding=embedding.to(self.device),
131
+ flow_cache=self.flow_cache_dict[uuid])
132
+ self.flow_cache_dict[uuid] = flow_cache
133
+
134
+ # mel overlap fade in out
135
+ if self.mel_overlap_dict[uuid].shape[2] != 0:
136
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
137
+ # append hift cache
138
+ if self.hift_cache_dict[uuid] is not None:
139
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
140
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
141
+ else:
142
+ hift_cache_source = torch.zeros(1, 1, 0)
143
+ # keep overlap mel and hift cache
144
+ if finalize is False:
145
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
146
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
147
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
148
+ if self.hift_cache_dict[uuid] is not None:
149
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
150
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
151
+ 'source': tts_source[:, :, -self.source_cache_len:],
152
+ 'speech': tts_speech[:, -self.source_cache_len:]}
153
+ tts_speech = tts_speech[:, :-self.source_cache_len]
154
+ else:
155
+ if speed != 1.0:
156
+ assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
157
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
158
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
159
+ if self.hift_cache_dict[uuid] is not None:
160
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
161
+ return tts_speech
162
+
163
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
164
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
165
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
166
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
167
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
168
+ # this_uuid is used to track variables related to this inference thread
169
+ this_uuid = str(uuid.uuid1())
170
+ with self.lock:
171
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
172
+ self.hift_cache_dict[this_uuid] = None
173
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
174
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
175
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
176
+ p.start()
177
+ if stream is True:
178
+ token_hop_len = self.token_min_hop_len
179
+ while True:
180
+ time.sleep(0.1)
181
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
182
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
183
+ .unsqueeze(dim=0)
184
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
185
+ prompt_token=flow_prompt_speech_token,
186
+ prompt_feat=prompt_speech_feat,
187
+ embedding=flow_embedding,
188
+ uuid=this_uuid,
189
+ finalize=False)
190
+ yield {'tts_speech': this_tts_speech.cpu()}
191
+ with self.lock:
192
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
193
+ # increase token_hop_len for better speech quality
194
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
195
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
196
+ break
197
+ p.join()
198
+ # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
199
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
200
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
201
+ prompt_token=flow_prompt_speech_token,
202
+ prompt_feat=prompt_speech_feat,
203
+ embedding=flow_embedding,
204
+ uuid=this_uuid,
205
+ finalize=True)
206
+ yield {'tts_speech': this_tts_speech.cpu()}
207
+ else:
208
+ # deal with all tokens
209
+ p.join()
210
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
211
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
212
+ prompt_token=flow_prompt_speech_token,
213
+ prompt_feat=prompt_speech_feat,
214
+ embedding=flow_embedding,
215
+ uuid=this_uuid,
216
+ finalize=True,
217
+ speed=speed)
218
+ yield {'tts_speech': this_tts_speech.cpu()}
219
+ with self.lock:
220
+ self.tts_speech_token_dict.pop(this_uuid)
221
+ self.llm_end_dict.pop(this_uuid)
222
+ self.mel_overlap_dict.pop(this_uuid)
223
+ self.hift_cache_dict.pop(this_uuid)
224
+ self.flow_cache_dict.pop(this_uuid)
225
+ torch.cuda.empty_cache()
226
+
227
+ def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
228
+ # this_uuid is used to track variables related to this inference thread
229
+ this_uuid = str(uuid.uuid1())
230
+ with self.lock:
231
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
232
+ self.hift_cache_dict[this_uuid] = None
233
+ self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
234
+ self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
235
+ if stream is True:
236
+ token_hop_len = self.token_min_hop_len
237
+ while True:
238
+ if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
239
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
240
+ .unsqueeze(dim=0)
241
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
242
+ prompt_token=flow_prompt_speech_token,
243
+ prompt_feat=prompt_speech_feat,
244
+ embedding=flow_embedding,
245
+ uuid=this_uuid,
246
+ finalize=False)
247
+ yield {'tts_speech': this_tts_speech.cpu()}
248
+ with self.lock:
249
+ self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
250
+ # increase token_hop_len for better speech quality
251
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
252
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
253
+ break
254
+ # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
255
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
256
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
257
+ prompt_token=flow_prompt_speech_token,
258
+ prompt_feat=prompt_speech_feat,
259
+ embedding=flow_embedding,
260
+ uuid=this_uuid,
261
+ finalize=True)
262
+ yield {'tts_speech': this_tts_speech.cpu()}
263
+ else:
264
+ # deal with all tokens
265
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
266
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
267
+ prompt_token=flow_prompt_speech_token,
268
+ prompt_feat=prompt_speech_feat,
269
+ embedding=flow_embedding,
270
+ uuid=this_uuid,
271
+ finalize=True,
272
+ speed=speed)
273
+ yield {'tts_speech': this_tts_speech.cpu()}
274
+ with self.lock:
275
+ self.tts_speech_token_dict.pop(this_uuid)
276
+ self.llm_end_dict.pop(this_uuid)
277
+ self.mel_overlap_dict.pop(this_uuid)
278
+ self.hift_cache_dict.pop(this_uuid)
279
+ torch.cuda.empty_cache()
280
+
281
+
282
+ class CosyVoice2Model(CosyVoiceModel):
283
+
284
+ def __init__(self,
285
+ llm: torch.nn.Module,
286
+ flow: torch.nn.Module,
287
+ hift: torch.nn.Module,
288
+ fp16: bool):
289
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
290
+ self.llm = llm
291
+ self.flow = flow
292
+ self.hift = hift
293
+ self.fp16 = fp16
294
+ self.llm.fp16 = fp16
295
+ self.flow.fp16 = fp16
296
+ if self.fp16 is True:
297
+ self.llm.half()
298
+ self.flow.half()
299
+ self.token_hop_len = 2 * self.flow.input_frame_rate
300
+ # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument or use a cache
301
+ self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
302
+ self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
303
+ # hift cache
304
+ self.mel_cache_len = 8
305
+ self.source_cache_len = int(self.mel_cache_len * 480)
306
+ # speech fade in out
307
+ self.speech_window = np.hamming(2 * self.source_cache_len)
308
+ # rtf and decoding related
309
+ self.stream_scale_factor = 1
310
+ self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
311
+ self.lock = threading.Lock()
312
+ # dict used to store session related variable
313
+ self.tts_speech_token_dict = {}
314
+ self.llm_end_dict = {}
315
+ self.hift_cache_dict = {}
316
+
317
+ def load_jit(self, flow_encoder_model):
318
+ flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
319
+ self.flow.encoder = flow_encoder
320
+
321
+ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
322
+ tts_mel, _ = self.flow.inference(token=token.to(self.device),
323
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
324
+ prompt_token=prompt_token.to(self.device),
325
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
326
+ prompt_feat=prompt_feat.to(self.device),
327
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
328
+ embedding=embedding.to(self.device),
329
+ finalize=finalize)
330
+ tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
331
+ # append hift cache
332
+ if self.hift_cache_dict[uuid] is not None:
333
+ hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
334
+ tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
335
+ else:
336
+ hift_cache_source = torch.zeros(1, 1, 0)
337
+ # keep overlap mel and hift cache
338
+ if finalize is False:
339
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
340
+ if self.hift_cache_dict[uuid] is not None:
341
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
342
+ self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
343
+ 'source': tts_source[:, :, -self.source_cache_len:],
344
+ 'speech': tts_speech[:, -self.source_cache_len:]}
345
+ tts_speech = tts_speech[:, :-self.source_cache_len]
346
+ else:
347
+ if speed != 1.0:
348
+ assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
349
+ tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
350
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
351
+ if self.hift_cache_dict[uuid] is not None:
352
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
353
+ return tts_speech
354
+
355
+ def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
356
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32),
357
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
358
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
359
+ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
360
+ # this_uuid is used to track variables related to this inference thread
361
+ this_uuid = str(uuid.uuid1())
362
+ with self.lock:
363
+ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
364
+ self.hift_cache_dict[this_uuid] = None
365
+ p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
366
+ p.start()
367
+ if stream is True:
368
+ token_offset = 0
369
+ while True:
370
+ time.sleep(0.1)
371
+ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
372
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
373
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
374
+ prompt_token=flow_prompt_speech_token,
375
+ prompt_feat=prompt_speech_feat,
376
+ embedding=flow_embedding,
377
+ uuid=this_uuid,
378
+ token_offset=token_offset,
379
+ finalize=False)
380
+ token_offset += self.token_hop_len
381
+ yield {'tts_speech': this_tts_speech.cpu()}
382
+ if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
383
+ break
384
+ p.join()
385
+ # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
386
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
387
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
388
+ prompt_token=flow_prompt_speech_token,
389
+ prompt_feat=prompt_speech_feat,
390
+ embedding=flow_embedding,
391
+ uuid=this_uuid,
392
+ token_offset=token_offset,
393
+ finalize=True)
394
+ yield {'tts_speech': this_tts_speech.cpu()}
395
+ else:
396
+ # deal with all tokens
397
+ p.join()
398
+ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
399
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
400
+ prompt_token=flow_prompt_speech_token,
401
+ prompt_feat=prompt_speech_feat,
402
+ embedding=flow_embedding,
403
+ uuid=this_uuid,
404
+ token_offset=0,
405
+ finalize=True,
406
+ speed=speed)
407
+ yield {'tts_speech': this_tts_speech.cpu()}
408
+ with self.lock:
409
+ self.tts_speech_token_dict.pop(this_uuid)
410
+ self.llm_end_dict.pop(this_uuid)
411
+ torch.cuda.empty_cache()
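
A minimal sketch of consuming the streaming generator above: each yielded dict carries a (1, samples) chunk under 'tts_speech' that can be concatenated into the full waveform. Here model is assumed to be a loaded CosyVoiceModel or CosyVoice2Model and model_input a dict produced by the frontend (see the CosyVoice wrapper above).

import torch

def synthesize_streaming(model, model_input):
    # model: loaded CosyVoiceModel / CosyVoice2Model; model_input: frontend-produced dict
    chunks = [out['tts_speech'] for out in model.tts(**model_input, stream=True)]
    return torch.concat(chunks, dim=1)  # (1, total_samples), already on CPU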
cosyvoice/dataset/__init__.py ADDED
File without changes
cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
+ class Processor(IterableDataset):
28
+
29
+ def __init__(self, source, f, *args, **kw):
30
+ assert callable(f)
31
+ self.source = source
32
+ self.f = f
33
+ self.args = args
34
+ self.kw = kw
35
+
36
+ def set_epoch(self, epoch):
37
+ self.source.set_epoch(epoch)
38
+
39
+ def __iter__(self):
40
+ """ Return an iterator over the source dataset processed by the
41
+ given processor.
42
+ """
43
+ assert self.source is not None
44
+ assert callable(self.f)
45
+ return self.f(iter(self.source), *self.args, **self.kw)
46
+
47
+ def apply(self, f):
48
+ assert callable(f)
49
+ return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+ # force the data list to be evenly divisible across ranks and workers
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+
108
+ class DataList(IterableDataset):
109
+
110
+ def __init__(self, lists, shuffle=True, partition=True):
111
+ self.lists = lists
112
+ self.sampler = DistributedSampler(shuffle, partition)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.sampler.set_epoch(epoch)
116
+
117
+ def __iter__(self):
118
+ sampler_info = self.sampler.update()
119
+ indexes = self.sampler.sample(self.lists)
120
+ for index in indexes:
121
+ data = dict(src=self.lists[index])
122
+ data.update(sampler_info)
123
+ yield data
124
+
125
+
126
+ def Dataset(data_list_file,
127
+ data_pipeline,
128
+ mode='train',
129
+ gan=False,
130
+ shuffle=True,
131
+ partition=True,
132
+ tts_file='',
133
+ prompt_utt2data=''):
134
+ """ Construct dataset from arguments
135
+
136
+ We have two shuffle stages in the Dataset. The first is global
137
+ shuffle at shards tar/raw file level. The second is global shuffle
138
+ at training samples level.
139
+
140
+ Args:
141
+ data_type(str): raw/shard
142
+ tokenizer (BaseTokenizer): tokenizer to tokenize
143
+ partition(bool): whether to do data partition in terms of rank
144
+ """
145
+ assert mode in ['train', 'inference']
146
+ lists = read_lists(data_list_file)
147
+ if mode == 'inference':
148
+ with open(tts_file) as f:
149
+ tts_data = json.load(f)
150
+ utt2lists = read_json_lists(prompt_utt2data)
151
+ # filter unnecessary file in inference mode
152
+ lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
153
+ dataset = DataList(lists,
154
+ shuffle=shuffle,
155
+ partition=partition)
156
+ if mode == 'inference':
157
+ # map partial arg to parquet_opener func in inference mode
158
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
159
+ if gan is True:
160
+ # map partial arg to padding func in gan mode
161
+ data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
162
+ for func in data_pipeline:
163
+ dataset = Processor(dataset, func, mode=mode)
164
+ return dataset
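
A minimal sketch of how the Dataset factory above is meant to be driven. The real recipes configure data_pipeline in the hyperpyyaml config (parquet opening, tokenization, padding); here a single pass-through stage and the path train.data.list are hypothetical stand-ins, just to show the DataList/Processor wiring.

from torch.utils.data import DataLoader
from cosyvoice.dataset.dataset import Dataset

def passthrough(data, mode='train'):
    # toy pipeline stage: each item is a dict with 'src' plus rank/worker info
    for sample in data:
        yield sample

train_dataset = Dataset('train.data.list', data_pipeline=[passthrough],
                        mode='train', gan=False, shuffle=True, partition=True)
train_dataset.set_epoch(0)
train_loader = DataLoader(train_dataset, batch_size=None, num_workers=0)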
cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import pack, rearrange, repeat
18
+ from cosyvoice.utils.common import mask_to_bias
19
+ from cosyvoice.utils.mask import add_optional_chunk_mask
20
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
21
+ from matcha.models.components.transformer import BasicTransformerBlock
22
+
23
+
24
+ class Transpose(torch.nn.Module):
25
+ def __init__(self, dim0: int, dim1: int):
26
+ super().__init__()
27
+ self.dim0 = dim0
28
+ self.dim1 = dim1
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = torch.transpose(x, self.dim0, self.dim1)
32
+ return x
33
+
34
+
35
+ class CausalBlock1D(Block1D):
36
+ def __init__(self, dim: int, dim_out: int):
37
+ super(CausalBlock1D, self).__init__(dim, dim_out)
38
+ self.block = torch.nn.Sequential(
39
+ CausalConv1d(dim, dim_out, 3),
40
+ Transpose(1, 2),
41
+ nn.LayerNorm(dim_out),
42
+ Transpose(1, 2),
43
+ nn.Mish(),
44
+ )
45
+
46
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
47
+ output = self.block(x * mask)
48
+ return output * mask
49
+
50
+
51
+ class CausalResnetBlock1D(ResnetBlock1D):
52
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
53
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
54
+ self.block1 = CausalBlock1D(dim, dim_out)
55
+ self.block2 = CausalBlock1D(dim_out, dim_out)
56
+
57
+
58
+ class CausalConv1d(torch.nn.Conv1d):
59
+ def __init__(
60
+ self,
61
+ in_channels: int,
62
+ out_channels: int,
63
+ kernel_size: int,
64
+ stride: int = 1,
65
+ dilation: int = 1,
66
+ groups: int = 1,
67
+ bias: bool = True,
68
+ padding_mode: str = 'zeros',
69
+ device=None,
70
+ dtype=None
71
+ ) -> None:
72
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
73
+ kernel_size, stride,
74
+ padding=0, dilation=dilation,
75
+ groups=groups, bias=bias,
76
+ padding_mode=padding_mode,
77
+ device=device, dtype=dtype)
78
+ assert stride == 1
79
+ self.causal_padding = (kernel_size - 1, 0)
80
+
81
+ def forward(self, x: torch.Tensor):
82
+ x = F.pad(x, self.causal_padding)
83
+ x = super(CausalConv1d, self).forward(x)
84
+ return x
85
+
86
+
87
+ class ConditionalDecoder(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_channels,
91
+ out_channels,
92
+ causal=False,
93
+ channels=(256, 256),
94
+ dropout=0.05,
95
+ attention_head_dim=64,
96
+ n_blocks=1,
97
+ num_mid_blocks=2,
98
+ num_heads=4,
99
+ act_fn="snake",
100
+ ):
101
+ """
102
+ This decoder requires an input with the same shape as the target. So, if your text content
104
+ is shorter or longer than the output, please re-sample it before feeding it to the decoder.
104
+ """
105
+ super().__init__()
106
+ channels = tuple(channels)
107
+ self.in_channels = in_channels
108
+ self.out_channels = out_channels
109
+ self.causal = causal
110
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
111
+ time_embed_dim = channels[0] * 4
112
+ self.time_mlp = TimestepEmbedding(
113
+ in_channels=in_channels,
114
+ time_embed_dim=time_embed_dim,
115
+ act_fn="silu",
116
+ )
117
+ self.down_blocks = nn.ModuleList([])
118
+ self.mid_blocks = nn.ModuleList([])
119
+ self.up_blocks = nn.ModuleList([])
120
+
121
+ output_channel = in_channels
122
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
123
+ input_channel = output_channel
124
+ output_channel = channels[i]
125
+ is_last = i == len(channels) - 1
126
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
127
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
128
+ transformer_blocks = nn.ModuleList(
129
+ [
130
+ BasicTransformerBlock(
131
+ dim=output_channel,
132
+ num_attention_heads=num_heads,
133
+ attention_head_dim=attention_head_dim,
134
+ dropout=dropout,
135
+ activation_fn=act_fn,
136
+ )
137
+ for _ in range(n_blocks)
138
+ ]
139
+ )
140
+ downsample = (
141
+ Downsample1D(output_channel) if not is_last else
142
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
143
+ )
144
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
145
+
146
+ for _ in range(num_mid_blocks):
147
+ input_channel = channels[-1]
148
+ out_channels = channels[-1]
149
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
150
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
151
+
152
+ transformer_blocks = nn.ModuleList(
153
+ [
154
+ BasicTransformerBlock(
155
+ dim=output_channel,
156
+ num_attention_heads=num_heads,
157
+ attention_head_dim=attention_head_dim,
158
+ dropout=dropout,
159
+ activation_fn=act_fn,
160
+ )
161
+ for _ in range(n_blocks)
162
+ ]
163
+ )
164
+
165
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
166
+
167
+ channels = channels[::-1] + (channels[0],)
168
+ for i in range(len(channels) - 1):
169
+ input_channel = channels[i] * 2
170
+ output_channel = channels[i + 1]
171
+ is_last = i == len(channels) - 2
172
+ resnet = CausalResnetBlock1D(
173
+ dim=input_channel,
174
+ dim_out=output_channel,
175
+ time_emb_dim=time_embed_dim,
176
+ ) if self.causal else ResnetBlock1D(
177
+ dim=input_channel,
178
+ dim_out=output_channel,
179
+ time_emb_dim=time_embed_dim,
180
+ )
181
+ transformer_blocks = nn.ModuleList(
182
+ [
183
+ BasicTransformerBlock(
184
+ dim=output_channel,
185
+ num_attention_heads=num_heads,
186
+ attention_head_dim=attention_head_dim,
187
+ dropout=dropout,
188
+ activation_fn=act_fn,
189
+ )
190
+ for _ in range(n_blocks)
191
+ ]
192
+ )
193
+ upsample = (
194
+ Upsample1D(output_channel, use_conv_transpose=True)
195
+ if not is_last
196
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
197
+ )
198
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
199
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
200
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
201
+ self.initialize_weights()
202
+
203
+ def initialize_weights(self):
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv1d):
206
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
207
+ if m.bias is not None:
208
+ nn.init.constant_(m.bias, 0)
209
+ elif isinstance(m, nn.GroupNorm):
210
+ nn.init.constant_(m.weight, 1)
211
+ nn.init.constant_(m.bias, 0)
212
+ elif isinstance(m, nn.Linear):
213
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
214
+ if m.bias is not None:
215
+ nn.init.constant_(m.bias, 0)
216
+
217
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
218
+ """Forward pass of the UNet1DConditional model.
219
+
220
+ Args:
221
+ x (torch.Tensor): shape (batch_size, in_channels, time)
222
+ mask (_type_): shape (batch_size, 1, time)
223
+ t (_type_): shape (batch_size)
224
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
225
+ cond (_type_, optional): placeholder for future use. Defaults to None.
226
+
227
+ Raises:
228
+ ValueError: _description_
229
+ ValueError: _description_
230
+
231
+ Returns:
232
+ _type_: _description_
233
+ """
234
+
235
+ t = self.time_embeddings(t).to(t.dtype)
236
+ t = self.time_mlp(t)
237
+
238
+ x = pack([x, mu], "b * t")[0]
239
+
240
+ if spks is not None:
241
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
242
+ x = pack([x, spks], "b * t")[0]
243
+ if cond is not None:
244
+ x = pack([x, cond], "b * t")[0]
245
+
246
+ hiddens = []
247
+ masks = [mask]
248
+ for resnet, transformer_blocks, downsample in self.down_blocks:
249
+ mask_down = masks[-1]
250
+ x = resnet(x, mask_down, t)
251
+ x = rearrange(x, "b c t -> b t c").contiguous()
252
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
253
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
254
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
255
+ for transformer_block in transformer_blocks:
256
+ x = transformer_block(
257
+ hidden_states=x,
258
+ attention_mask=attn_mask,
259
+ timestep=t,
260
+ )
261
+ x = rearrange(x, "b t c -> b c t").contiguous()
262
+ hiddens.append(x) # Save hidden states for skip connections
263
+ x = downsample(x * mask_down)
264
+ masks.append(mask_down[:, :, ::2])
265
+ masks = masks[:-1]
266
+ mask_mid = masks[-1]
267
+
268
+ for resnet, transformer_blocks in self.mid_blocks:
269
+ x = resnet(x, mask_mid, t)
270
+ x = rearrange(x, "b c t -> b t c").contiguous()
271
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
272
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
273
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
274
+ for transformer_block in transformer_blocks:
275
+ x = transformer_block(
276
+ hidden_states=x,
277
+ attention_mask=attn_mask,
278
+ timestep=t,
279
+ )
280
+ x = rearrange(x, "b t c -> b c t").contiguous()
281
+
282
+ for resnet, transformer_blocks, upsample in self.up_blocks:
283
+ mask_up = masks.pop()
284
+ skip = hiddens.pop()
285
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
286
+ x = resnet(x, mask_up, t)
287
+ x = rearrange(x, "b c t -> b t c").contiguous()
288
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
289
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
290
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
291
+ for transformer_block in transformer_blocks:
292
+ x = transformer_block(
293
+ hidden_states=x,
294
+ attention_mask=attn_mask,
295
+ timestep=t,
296
+ )
297
+ x = rearrange(x, "b t c -> b c t").contiguous()
298
+ x = upsample(x * mask_up)
299
+ x = self.final_block(x, mask_up)
300
+ output = self.final_proj(x * mask_up)
301
+ return output * mask
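
A minimal check of the causal convolution defined above: the (kernel_size - 1, 0) left padding keeps the time dimension unchanged, so frame t only sees frames at or before t. Shapes are illustrative.

import torch
from cosyvoice.flow.decoder import CausalConv1d

conv = CausalConv1d(in_channels=80, out_channels=80, kernel_size=3)
x = torch.randn(1, 80, 50)      # (batch, channels, time)
y = conv(x)
assert y.shape == (1, 80, 50)   # left-only padding preserves the sequence length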
cosyvoice/flow/flow.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from cosyvoice.utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
41
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
42
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+ self.output_size = output_size
46
+ self.decoder_conf = decoder_conf
47
+ self.mel_feat_conf = mel_feat_conf
48
+ self.vocab_size = vocab_size
49
+ self.output_type = output_type
50
+ self.input_frame_rate = input_frame_rate
51
+ logging.info(f"input frame rate={self.input_frame_rate}")
52
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
53
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
54
+ self.encoder = encoder
55
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
56
+ self.decoder = decoder
57
+ self.length_regulator = length_regulator
58
+ self.only_mask_loss = only_mask_loss
59
+
60
+ def forward(
61
+ self,
62
+ batch: dict,
63
+ device: torch.device,
64
+ ) -> Dict[str, Optional[torch.Tensor]]:
65
+ token = batch['speech_token'].to(device)
66
+ token_len = batch['speech_token_len'].to(device)
67
+ feat = batch['speech_feat'].to(device)
68
+ feat_len = batch['speech_feat_len'].to(device)
69
+ embedding = batch['embedding'].to(device)
70
+
71
+ # xvec projection
72
+ embedding = F.normalize(embedding, dim=1)
73
+ embedding = self.spk_embed_affine_layer(embedding)
74
+
75
+ # concat text and prompt_text
76
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
77
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
78
+
79
+ # text encode
80
+ h, h_lengths = self.encoder(token, token_len)
81
+ h = self.encoder_proj(h)
82
+ h, h_lengths = self.length_regulator(h, feat_len)
83
+
84
+ # get conditions
85
+ conds = torch.zeros(feat.shape, device=token.device)
86
+ for i, j in enumerate(feat_len):
87
+ if random.random() < 0.5:
88
+ continue
89
+ index = random.randint(0, int(0.3 * j))
90
+ conds[i, :index] = feat[i, :index]
91
+ conds = conds.transpose(1, 2)
92
+
93
+ mask = (~make_pad_mask(feat_len)).to(h)
94
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
95
+ loss, _ = self.decoder.compute_loss(
96
+ feat.transpose(1, 2).contiguous(),
97
+ mask.unsqueeze(1),
98
+ h.transpose(1, 2).contiguous(),
99
+ embedding,
100
+ cond=conds
101
+ )
102
+ return {'loss': loss}
103
+
104
+ @torch.inference_mode()
105
+ def inference(self,
106
+ token,
107
+ token_len,
108
+ prompt_token,
109
+ prompt_token_len,
110
+ prompt_feat,
111
+ prompt_feat_len,
112
+ embedding,
113
+ flow_cache):
114
+ if self.fp16 is True:
115
+ prompt_feat = prompt_feat.half()
116
+ embedding = embedding.half()
117
+
118
+ assert token.shape[0] == 1
119
+ # xvec projection
120
+ embedding = F.normalize(embedding, dim=1)
121
+ embedding = self.spk_embed_affine_layer(embedding)
122
+
123
+ # concat text and prompt_text
124
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
125
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
126
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
127
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
128
+
129
+ # text encode
130
+ h, h_lengths = self.encoder(token, token_len)
131
+ h = self.encoder_proj(h)
132
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
133
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
134
+
135
+ # get conditions
136
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
137
+ conds[:, :mel_len1] = prompt_feat
138
+ conds = conds.transpose(1, 2)
139
+
140
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
141
+ feat, flow_cache = self.decoder(
142
+ mu=h.transpose(1, 2).contiguous(),
143
+ mask=mask.unsqueeze(1),
144
+ spks=embedding,
145
+ cond=conds,
146
+ n_timesteps=10,
147
+ prompt_len=mel_len1,
148
+ flow_cache=flow_cache
149
+ )
150
+ feat = feat[:, :, mel_len1:]
151
+ assert feat.shape[2] == mel_len2
152
+ return feat.float(), flow_cache
153
+
154
+
155
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
156
+ def __init__(self,
157
+ input_size: int = 512,
158
+ output_size: int = 80,
159
+ spk_embed_dim: int = 192,
160
+ output_type: str = "mel",
161
+ vocab_size: int = 4096,
162
+ input_frame_rate: int = 50,
163
+ only_mask_loss: bool = True,
164
+ token_mel_ratio: int = 2,
165
+ pre_lookahead_len: int = 3,
166
+ encoder: torch.nn.Module = None,
167
+ decoder: torch.nn.Module = None,
168
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
169
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
170
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
171
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
172
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
173
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
174
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
175
+ super().__init__()
176
+ self.input_size = input_size
177
+ self.output_size = output_size
178
+ self.decoder_conf = decoder_conf
179
+ self.mel_feat_conf = mel_feat_conf
180
+ self.vocab_size = vocab_size
181
+ self.output_type = output_type
182
+ self.input_frame_rate = input_frame_rate
183
+ logging.info(f"input frame rate={self.input_frame_rate}")
184
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
185
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
186
+ self.encoder = encoder
187
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
188
+ self.decoder = decoder
189
+ self.only_mask_loss = only_mask_loss
190
+ self.token_mel_ratio = token_mel_ratio
191
+ self.pre_lookahead_len = pre_lookahead_len
192
+
193
+ @torch.inference_mode()
194
+ def inference(self,
195
+ token,
196
+ token_len,
197
+ prompt_token,
198
+ prompt_token_len,
199
+ prompt_feat,
200
+ prompt_feat_len,
201
+ embedding,
202
+ finalize):
203
+ if self.fp16 is True:
204
+ prompt_feat = prompt_feat.half()
205
+ embedding = embedding.half()
206
+
207
+ assert token.shape[0] == 1
208
+ # xvec projection
209
+ embedding = F.normalize(embedding, dim=1)
210
+ embedding = self.spk_embed_affine_layer(embedding)
211
+
212
+ # concat text and prompt_text
213
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
214
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
215
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
216
+
217
+ # text encode
218
+ h, h_lengths = self.encoder(token, token_len)
219
+ if finalize is False:
220
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
221
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
222
+ h = self.encoder_proj(h)
223
+
224
+ # get conditions
225
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
226
+ conds[:, :mel_len1] = prompt_feat
227
+ conds = conds.transpose(1, 2)
228
+
229
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
230
+ feat, _ = self.decoder(
231
+ mu=h.transpose(1, 2).contiguous(),
232
+ mask=mask.unsqueeze(1),
233
+ spks=embedding,
234
+ cond=conds,
235
+ n_timesteps=10
236
+ )
237
+ feat = feat[:, :, mel_len1:]
238
+ assert feat.shape[2] == mel_len2
239
+ return feat.float(), None
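
A small sketch (not part of the diff) of the token-to-mel length bookkeeping used in `MaskedDiffWithXvec.inference` above, assuming the 22.05 kHz sampling rate and 256-sample hop from `mel_feat_conf`; the helper name `token_to_mel_len` is hypothetical and only mirrors the `mel_len2` arithmetic.

```python
# Hypothetical helper illustrating how the flow module maps speech-token
# counts to mel-frame counts before calling the CFM decoder.
def token_to_mel_len(num_tokens: int, input_frame_rate: int = 50,
                     sampling_rate: int = 22050, hop_size: int = 256) -> int:
    # tokens -> seconds -> mel frames (same arithmetic as mel_len2 above)
    return int(num_tokens / input_frame_rate * sampling_rate / hop_size)

# e.g. 100 tokens at 50 Hz is about 2 s of audio, i.e. 172 mel frames
assert token_to_mel_len(100) == 172
```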
cosyvoice/flow/flow_matching.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from matcha.models.components.flow_matching import BASECFM
18
+
19
+
20
+ class ConditionalCFM(BASECFM):
21
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
22
+ super().__init__(
23
+ n_feats=in_channels,
24
+ cfm_params=cfm_params,
25
+ n_spks=n_spks,
26
+ spk_emb_dim=spk_emb_dim,
27
+ )
28
+ self.t_scheduler = cfm_params.t_scheduler
29
+ self.training_cfg_rate = cfm_params.training_cfg_rate
30
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
31
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
32
+ # Just change the architecture of the estimator here
33
+ self.estimator = estimator
34
+ self.lock = threading.Lock()
35
+
36
+ @torch.inference_mode()
37
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
38
+ """Forward diffusion
39
+
40
+ Args:
41
+ mu (torch.Tensor): output of encoder
42
+ shape: (batch_size, n_feats, mel_timesteps)
43
+ mask (torch.Tensor): output_mask
44
+ shape: (batch_size, 1, mel_timesteps)
45
+ n_timesteps (int): number of diffusion steps
46
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
47
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
48
+ shape: (batch_size, spk_emb_dim)
49
+ cond: Not used but kept for future purposes
50
+
51
+ Returns:
52
+ sample: generated mel-spectrogram
53
+ shape: (batch_size, n_feats, mel_timesteps)
54
+ """
55
+
56
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
57
+ cache_size = flow_cache.shape[2]
58
+ # keep the prompt and overlap parts of mu and z fixed across streaming chunks
59
+ if cache_size != 0:
60
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
61
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
62
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
63
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
64
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
65
+
66
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
67
+ if self.t_scheduler == 'cosine':
68
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
69
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
70
+
71
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
72
+ """
73
+ Fixed euler solver for ODEs.
74
+ Args:
75
+ x (torch.Tensor): random noise
76
+ t_span (torch.Tensor): n_timesteps interpolated
77
+ shape: (n_timesteps + 1,)
78
+ mu (torch.Tensor): output of encoder
79
+ shape: (batch_size, n_feats, mel_timesteps)
80
+ mask (torch.Tensor): output_mask
81
+ shape: (batch_size, 1, mel_timesteps)
82
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
83
+ shape: (batch_size, spk_emb_dim)
84
+ cond: Not used but kept for future purposes
85
+ """
86
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
87
+ t = t.unsqueeze(dim=0)
88
+
89
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
90
+ # Or in future might add like a return_all_steps flag
91
+ sol = []
92
+
93
+ # Do not use concat; it may change the memory format and make TensorRT inference return wrong results!
94
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
95
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
96
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
97
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
98
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
99
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
100
+ for step in range(1, len(t_span)):
101
+ # Classifier-Free Guidance inference introduced in VoiceBox
102
+ x_in[:] = x
103
+ mask_in[:] = mask
104
+ mu_in[0] = mu
105
+ t_in[:] = t.unsqueeze(0)
106
+ spks_in[0] = spks
107
+ cond_in[0] = cond
108
+ dphi_dt = self.forward_estimator(
109
+ x_in, mask_in,
110
+ mu_in, t_in,
111
+ spks_in,
112
+ cond_in
113
+ )
114
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
115
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
116
+ x = x + dt * dphi_dt
117
+ t = t + dt
118
+ sol.append(x)
119
+ if step < len(t_span) - 1:
120
+ dt = t_span[step + 1] - t
121
+
122
+ return sol[-1].float()
123
+
124
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
125
+ if isinstance(self.estimator, torch.nn.Module):
126
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
127
+ else:
128
+ with self.lock:
129
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
130
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
131
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
132
+ self.estimator.set_input_shape('t', (2,))
133
+ self.estimator.set_input_shape('spks', (2, 80))
134
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
135
+ # run trt engine
136
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
137
+ mask.contiguous().data_ptr(),
138
+ mu.contiguous().data_ptr(),
139
+ t.contiguous().data_ptr(),
140
+ spks.contiguous().data_ptr(),
141
+ cond.contiguous().data_ptr(),
142
+ x.data_ptr()])
143
+ return x
144
+
145
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
146
+ """Computes diffusion loss
147
+
148
+ Args:
149
+ x1 (torch.Tensor): Target
150
+ shape: (batch_size, n_feats, mel_timesteps)
151
+ mask (torch.Tensor): target mask
152
+ shape: (batch_size, 1, mel_timesteps)
153
+ mu (torch.Tensor): output of encoder
154
+ shape: (batch_size, n_feats, mel_timesteps)
155
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
156
+ shape: (batch_size, spk_emb_dim)
157
+
158
+ Returns:
159
+ loss: conditional flow matching loss
160
+ y: conditional flow
161
+ shape: (batch_size, n_feats, mel_timesteps)
162
+ """
163
+ b, _, t = mu.shape
164
+
165
+ # random timestep
166
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
167
+ if self.t_scheduler == 'cosine':
168
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
169
+ # sample noise p(x_0)
170
+ z = torch.randn_like(x1)
171
+
172
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
173
+ u = x1 - (1 - self.sigma_min) * z
174
+
175
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
176
+ if self.training_cfg_rate > 0:
177
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
178
+ mu = mu * cfg_mask.view(-1, 1, 1)
179
+ spks = spks * cfg_mask.view(-1, 1)
180
+ cond = cond * cfg_mask.view(-1, 1, 1)
181
+
182
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
183
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
184
+ return loss, y
185
+
186
+
187
+ class CausalConditionalCFM(ConditionalCFM):
188
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
189
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
190
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
191
+
192
+ @torch.inference_mode()
193
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
194
+ """Forward diffusion
195
+
196
+ Args:
197
+ mu (torch.Tensor): output of encoder
198
+ shape: (batch_size, n_feats, mel_timesteps)
199
+ mask (torch.Tensor): output_mask
200
+ shape: (batch_size, 1, mel_timesteps)
201
+ n_timesteps (int): number of diffusion steps
202
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
203
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
204
+ shape: (batch_size, spk_emb_dim)
205
+ cond: Not used but kept for future purposes
206
+
207
+ Returns:
208
+ sample: generated mel-spectrogram
209
+ shape: (batch_size, n_feats, mel_timesteps)
210
+ """
211
+
212
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
213
+ # fix prompt and overlap part mu and z
214
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
215
+ if self.t_scheduler == 'cosine':
216
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
217
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
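
A self-contained sketch (not part of the diff) of the Euler loop with classifier-free guidance that `ConditionalCFM.solve_euler` implements above. The `toy_estimator` closed form is an assumption purely for illustration; the real vector field comes from the decoder network (or its TensorRT engine).

```python
import torch

def toy_estimator(x, mu, t):
    # stand-in for the conditional/unconditional estimator pair (hypothetical)
    return mu - x

def euler_cfg(z, mu, n_timesteps=10, cfg_rate=0.7):
    # cosine time schedule, as in ConditionalCFM.forward
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
    x, t = z, t_span[0]
    dt = t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        cond = toy_estimator(x, mu, t)                       # conditional branch
        uncond = toy_estimator(x, torch.zeros_like(mu), t)   # dropped condition
        dphi_dt = (1.0 + cfg_rate) * cond - cfg_rate * uncond
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

mel = euler_cfg(torch.randn(1, 80, 172), torch.zeros(1, 80, 172))
```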
cosyvoice/flow/length_regulator.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Tuple
15
+ import torch.nn as nn
16
+ import torch
17
+ from torch.nn import functional as F
18
+ from cosyvoice.utils.mask import make_pad_mask
19
+
20
+
21
+ class InterpolateRegulator(nn.Module):
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ sampling_ratios: Tuple,
26
+ out_channels: int = None,
27
+ groups: int = 1,
28
+ ):
29
+ super().__init__()
30
+ self.sampling_ratios = sampling_ratios
31
+ out_channels = out_channels or channels
32
+ model = nn.ModuleList([])
33
+ if len(sampling_ratios) > 0:
34
+ for _ in sampling_ratios:
35
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
36
+ norm = nn.GroupNorm(groups, channels)
37
+ act = nn.Mish()
38
+ model.extend([module, norm, act])
39
+ model.append(
40
+ nn.Conv1d(channels, out_channels, 1, 1)
41
+ )
42
+ self.model = nn.Sequential(*model)
43
+
44
+ def forward(self, x, ylens=None):
45
+ # x in (B, T, D)
46
+ mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48
+ out = self.model(x).transpose(1, 2).contiguous()
49
+ olens = ylens
50
+ return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53
+ # in inference mode, interpolate the prompt token and the token (head/mid/tail) separately, so we get a clear separation point in the mel
54
+ # x in (B, T, D)
55
+ if x2.shape[1] > 40:
56
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
57
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
58
+ mode='linear')
59
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
60
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
61
+ else:
62
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
63
+ if x1.shape[1] != 0:
64
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
65
+ x = torch.concat([x1, x2], dim=2)
66
+ else:
67
+ x = x2
68
+ out = self.model(x).transpose(1, 2).contiguous()
69
+ return out, mel_len1 + mel_len2
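
A brief usage sketch (not part of the diff) of the interpolation that does the actual length regulation above; shapes follow the `(B, T, D)` convention documented in `InterpolateRegulator.forward`, and the concrete sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

# 25 encoder frames stretched to 86 mel frames by linear interpolation,
# which is what the regulator does before its small conv stack.
x = torch.randn(1, 25, 80)                      # (B, T_tokens, D)
y = F.interpolate(x.transpose(1, 2), size=86, mode='linear')
print(y.shape)                                  # torch.Size([1, 80, 86])
```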
cosyvoice/hifigan/discriminator.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.parametrizations import weight_norm
4
+ from typing import List, Optional, Tuple
5
+ from einops import rearrange
6
+ from torchaudio.transforms import Spectrogram
7
+
8
+
9
+ class MultipleDiscriminator(nn.Module):
10
+ def __init__(
11
+ self, mpd: nn.Module, mrd: nn.Module
12
+ ):
13
+ super().__init__()
14
+ self.mpd = mpd
15
+ self.mrd = mrd
16
+
17
+ def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
18
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
19
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
20
+ y_d_rs += this_y_d_rs
21
+ y_d_gs += this_y_d_gs
22
+ fmap_rs += this_fmap_rs
23
+ fmap_gs += this_fmap_gs
24
+ this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
25
+ y_d_rs += this_y_d_rs
26
+ y_d_gs += this_y_d_gs
27
+ fmap_rs += this_fmap_rs
28
+ fmap_gs += this_fmap_gs
29
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
30
+
31
+
32
+ class MultiResolutionDiscriminator(nn.Module):
33
+ def __init__(
34
+ self,
35
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
36
+ num_embeddings: Optional[int] = None,
37
+ ):
38
+ """
39
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
40
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
41
+
42
+ Args:
43
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
44
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
45
+ Defaults to None.
46
+ """
47
+
48
+ super().__init__()
49
+ self.discriminators = nn.ModuleList(
50
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
51
+ )
52
+
53
+ def forward(
54
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
55
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
56
+ y_d_rs = []
57
+ y_d_gs = []
58
+ fmap_rs = []
59
+ fmap_gs = []
60
+
61
+ for d in self.discriminators:
62
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
63
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
64
+ y_d_rs.append(y_d_r)
65
+ fmap_rs.append(fmap_r)
66
+ y_d_gs.append(y_d_g)
67
+ fmap_gs.append(fmap_g)
68
+
69
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
70
+
71
+
72
+ class DiscriminatorR(nn.Module):
73
+ def __init__(
74
+ self,
75
+ window_length: int,
76
+ num_embeddings: Optional[int] = None,
77
+ channels: int = 32,
78
+ hop_factor: float = 0.25,
79
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
80
+ ):
81
+ super().__init__()
82
+ self.window_length = window_length
83
+ self.hop_factor = hop_factor
84
+ self.spec_fn = Spectrogram(
85
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
86
+ )
87
+ n_fft = window_length // 2 + 1
88
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
89
+ self.bands = bands
90
+ convs = lambda: nn.ModuleList(
91
+ [
92
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
93
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
94
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
95
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
96
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
97
+ ]
98
+ )
99
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
100
+
101
+ if num_embeddings is not None:
102
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
103
+ torch.nn.init.zeros_(self.emb.weight)
104
+
105
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
106
+
107
+ def spectrogram(self, x):
108
+ # Remove DC offset
109
+ x = x - x.mean(dim=-1, keepdims=True)
110
+ # Peak normalize the volume of input audio
111
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
112
+ x = self.spec_fn(x)
113
+ x = torch.view_as_real(x)
114
+ x = rearrange(x, "b f t c -> b c t f")
115
+ # Split into bands
116
+ x_bands = [x[..., b[0]: b[1]] for b in self.bands]
117
+ return x_bands
118
+
119
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
120
+ x_bands = self.spectrogram(x)
121
+ fmap = []
122
+ x = []
123
+ for band, stack in zip(x_bands, self.band_convs):
124
+ for i, layer in enumerate(stack):
125
+ band = layer(band)
126
+ band = torch.nn.functional.leaky_relu(band, 0.1)
127
+ if i > 0:
128
+ fmap.append(band)
129
+ x.append(band)
130
+ x = torch.cat(x, dim=-1)
131
+ if cond_embedding_id is not None:
132
+ emb = self.emb(cond_embedding_id)
133
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
134
+ else:
135
+ h = 0
136
+ x = self.conv_post(x)
137
+ fmap.append(x)
138
+ x += h
139
+
140
+ return x, fmap
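
A short sketch (not part of the diff) of how `DiscriminatorR` turns the normalized band fractions into FFT-bin ranges, assuming the default `window_length=2048` and the default `bands` tuple shown above.

```python
# Band fractions -> bin indices for a 2048-point window (1025 bins),
# mirroring the list comprehension in DiscriminatorR.__init__.
n_fft = 2048 // 2 + 1
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bins = [(int(lo * n_fft), int(hi * n_fft)) for lo, hi in bands]
print(bins)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
```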
cosyvoice/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils.parametrizations import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
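
A usage sketch (not part of the diff) for `ConvRNNF0Predictor` above: the input is an 80-bin mel spectrogram in `(B, n_mels, T)` layout and the output is one non-negative F0 value per frame. The mel tensor here is random and only checks shapes; the import assumes the package is on `PYTHONPATH`.

```python
import torch
from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
mel = torch.randn(1, 80, 172)     # (B, n_mels, T), random placeholder
f0 = predictor(mel)               # (B, T); abs() keeps it non-negative
print(f0.shape)                   # torch.Size([1, 172])
```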
cosyvoice/hifigan/generator.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils.parametrizations import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from cosyvoice.transformer.activation import Snake
30
+ from cosyvoice.utils.common import get_padding
31
+ from cosyvoice.utils.common import init_weights
32
+
33
+
34
+ """hifigan based generator implementation.
35
+
36
+ This code is modified from https://github.com/jik876/hifi-gan
37
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
38
+ https://github.com/NVIDIA/BigVGAN
39
+
40
+ """
41
+
42
+
43
+ class ResBlock(torch.nn.Module):
44
+ """Residual block module in HiFiGAN/BigVGAN."""
45
+ def __init__(
46
+ self,
47
+ channels: int = 512,
48
+ kernel_size: int = 3,
49
+ dilations: List[int] = [1, 3, 5],
50
+ ):
51
+ super(ResBlock, self).__init__()
52
+ self.convs1 = nn.ModuleList()
53
+ self.convs2 = nn.ModuleList()
54
+
55
+ for dilation in dilations:
56
+ self.convs1.append(
57
+ weight_norm(
58
+ Conv1d(
59
+ channels,
60
+ channels,
61
+ kernel_size,
62
+ 1,
63
+ dilation=dilation,
64
+ padding=get_padding(kernel_size, dilation)
65
+ )
66
+ )
67
+ )
68
+ self.convs2.append(
69
+ weight_norm(
70
+ Conv1d(
71
+ channels,
72
+ channels,
73
+ kernel_size,
74
+ 1,
75
+ dilation=1,
76
+ padding=get_padding(kernel_size, 1)
77
+ )
78
+ )
79
+ )
80
+ self.convs1.apply(init_weights)
81
+ self.convs2.apply(init_weights)
82
+ self.activations1 = nn.ModuleList([
83
+ Snake(channels, alpha_logscale=False)
84
+ for _ in range(len(self.convs1))
85
+ ])
86
+ self.activations2 = nn.ModuleList([
87
+ Snake(channels, alpha_logscale=False)
88
+ for _ in range(len(self.convs2))
89
+ ])
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ for idx in range(len(self.convs1)):
93
+ xt = self.activations1[idx](x)
94
+ xt = self.convs1[idx](xt)
95
+ xt = self.activations2[idx](xt)
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x
99
+
100
+ def remove_weight_norm(self):
101
+ for idx in range(len(self.convs1)):
102
+ remove_weight_norm(self.convs1[idx])
103
+ remove_weight_norm(self.convs2[idx])
104
+
105
+
106
+ class SineGen(torch.nn.Module):
107
+ """ Definition of sine generator
108
+ SineGen(samp_rate, harmonic_num = 0,
109
+ sine_amp = 0.1, noise_std = 0.003,
110
+ voiced_threshold = 0,
111
+ flag_for_pulse=False)
112
+ samp_rate: sampling rate in Hz
113
+ harmonic_num: number of harmonic overtones (default 0)
114
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
115
+ noise_std: std of Gaussian noise (default 0.003)
116
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
117
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
118
+ Note: when flag_for_pulse is True, the first time step of a voiced
119
+ segment is always sin(np.pi) or cos(0)
120
+ """
121
+
122
+ def __init__(self, samp_rate, harmonic_num=0,
123
+ sine_amp=0.1, noise_std=0.003,
124
+ voiced_threshold=0):
125
+ super(SineGen, self).__init__()
126
+ self.sine_amp = sine_amp
127
+ self.noise_std = noise_std
128
+ self.harmonic_num = harmonic_num
129
+ self.sampling_rate = samp_rate
130
+ self.voiced_threshold = voiced_threshold
131
+
132
+ def _f02uv(self, f0):
133
+ # generate uv signal
134
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
135
+ return uv
136
+
137
+ @torch.no_grad()
138
+ def forward(self, f0):
139
+ """
140
+ :param f0: [B, 1, sample_len], Hz
141
+ :return: [B, 1, sample_len]
142
+ """
143
+
144
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
145
+ for i in range(self.harmonic_num + 1):
146
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
147
+
148
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
149
+ u_dist = Uniform(low=-np.pi, high=np.pi)
150
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
151
+ phase_vec[:, 0, :] = 0
152
+
153
+ # generate sine waveforms
154
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
155
+
156
+ # generate uv signal
157
+ uv = self._f02uv(f0)
158
+
159
+ # noise: for unvoiced should be similar to sine_amp
160
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
161
+ # . for voiced regions is self.noise_std
162
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
163
+ noise = noise_amp * torch.randn_like(sine_waves)
164
+
165
+ # first: set the unvoiced part to 0 by uv
166
+ # then: additive noise
167
+ sine_waves = sine_waves * uv + noise
168
+ return sine_waves, uv, noise
169
+
170
+
171
+ class SourceModuleHnNSF(torch.nn.Module):
172
+ """ SourceModule for hn-nsf
173
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
174
+ add_noise_std=0.003, voiced_threshod=0)
175
+ sampling_rate: sampling_rate in Hz
176
+ harmonic_num: number of harmonic above F0 (default: 0)
177
+ sine_amp: amplitude of sine source signal (default: 0.1)
178
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
179
+ note that amplitude of noise in unvoiced is decided
180
+ by sine_amp
181
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
182
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
183
+ F0_sampled (batchsize, length, 1)
184
+ Sine_source (batchsize, length, 1)
185
+ noise_source (batchsize, length 1)
186
+ uv (batchsize, length, 1)
187
+ """
188
+
189
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
190
+ add_noise_std=0.003, voiced_threshod=0):
191
+ super(SourceModuleHnNSF, self).__init__()
192
+
193
+ self.sine_amp = sine_amp
194
+ self.noise_std = add_noise_std
195
+
196
+ # to produce sine waveforms
197
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
198
+ sine_amp, add_noise_std, voiced_threshod)
199
+
200
+ # to merge source harmonics into a single excitation
201
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
202
+ self.l_tanh = torch.nn.Tanh()
203
+
204
+ def forward(self, x):
205
+ """
206
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
207
+ F0_sampled (batchsize, length, 1)
208
+ Sine_source (batchsize, length, 1)
209
+ noise_source (batchsize, length 1)
210
+ """
211
+ # source for harmonic branch
212
+ with torch.no_grad():
213
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
214
+ sine_wavs = sine_wavs.transpose(1, 2)
215
+ uv = uv.transpose(1, 2)
216
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
217
+
218
+ # source for noise branch, in the same shape as uv
219
+ noise = torch.randn_like(uv) * self.sine_amp / 3
220
+ return sine_merge, noise, uv
221
+
222
+
223
+ class HiFTGenerator(nn.Module):
224
+ """
225
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
226
+ https://arxiv.org/abs/2309.09493
227
+ """
228
+ def __init__(
229
+ self,
230
+ in_channels: int = 80,
231
+ base_channels: int = 512,
232
+ nb_harmonics: int = 8,
233
+ sampling_rate: int = 22050,
234
+ nsf_alpha: float = 0.1,
235
+ nsf_sigma: float = 0.003,
236
+ nsf_voiced_threshold: float = 10,
237
+ upsample_rates: List[int] = [8, 8],
238
+ upsample_kernel_sizes: List[int] = [16, 16],
239
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
240
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
241
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
242
+ source_resblock_kernel_sizes: List[int] = [7, 11],
243
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
244
+ lrelu_slope: float = 0.1,
245
+ audio_limit: float = 0.99,
246
+ f0_predictor: torch.nn.Module = None,
247
+ ):
248
+ super(HiFTGenerator, self).__init__()
249
+
250
+ self.out_channels = 1
251
+ self.nb_harmonics = nb_harmonics
252
+ self.sampling_rate = sampling_rate
253
+ self.istft_params = istft_params
254
+ self.lrelu_slope = lrelu_slope
255
+ self.audio_limit = audio_limit
256
+
257
+ self.num_kernels = len(resblock_kernel_sizes)
258
+ self.num_upsamples = len(upsample_rates)
259
+ self.m_source = SourceModuleHnNSF(
260
+ sampling_rate=sampling_rate,
261
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
262
+ harmonic_num=nb_harmonics,
263
+ sine_amp=nsf_alpha,
264
+ add_noise_std=nsf_sigma,
265
+ voiced_threshod=nsf_voiced_threshold)
266
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
267
+
268
+ self.conv_pre = weight_norm(
269
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
270
+ )
271
+
272
+ # Up
273
+ self.ups = nn.ModuleList()
274
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
275
+ self.ups.append(
276
+ weight_norm(
277
+ ConvTranspose1d(
278
+ base_channels // (2**i),
279
+ base_channels // (2**(i + 1)),
280
+ k,
281
+ u,
282
+ padding=(k - u) // 2,
283
+ )
284
+ )
285
+ )
286
+
287
+ # Down
288
+ self.source_downs = nn.ModuleList()
289
+ self.source_resblocks = nn.ModuleList()
290
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
291
+ downsample_cum_rates = np.cumprod(downsample_rates)
292
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
293
+ if u == 1:
294
+ self.source_downs.append(
295
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
296
+ )
297
+ else:
298
+ self.source_downs.append(
299
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
300
+ )
301
+
302
+ self.source_resblocks.append(
303
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
304
+ )
305
+
306
+ self.resblocks = nn.ModuleList()
307
+ for i in range(len(self.ups)):
308
+ ch = base_channels // (2**(i + 1))
309
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
310
+ self.resblocks.append(ResBlock(ch, k, d))
311
+
312
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
313
+ self.ups.apply(init_weights)
314
+ self.conv_post.apply(init_weights)
315
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
316
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
317
+ self.f0_predictor = f0_predictor
318
+
319
+ def remove_weight_norm(self):
320
+ print('Removing weight norm...')
321
+ for l in self.ups:
322
+ remove_weight_norm(l)
323
+ for l in self.resblocks:
324
+ l.remove_weight_norm()
325
+ remove_weight_norm(self.conv_pre)
326
+ remove_weight_norm(self.conv_post)
327
+ self.m_source.remove_weight_norm()
328
+ for l in self.source_downs:
329
+ remove_weight_norm(l)
330
+ for l in self.source_resblocks:
331
+ l.remove_weight_norm()
332
+
333
+ def _stft(self, x):
334
+ spec = torch.stft(
335
+ x,
336
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
337
+ return_complex=True)
338
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
339
+ return spec[..., 0], spec[..., 1]
340
+
341
+ def _istft(self, magnitude, phase):
342
+ magnitude = torch.clip(magnitude, max=1e2)
343
+ real = magnitude * torch.cos(phase)
344
+ img = magnitude * torch.sin(phase)
345
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
346
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
347
+ return inverse_transform
348
+
349
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
350
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
351
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
352
+
353
+ x = self.conv_pre(x)
354
+ for i in range(self.num_upsamples):
355
+ x = F.leaky_relu(x, self.lrelu_slope)
356
+ x = self.ups[i](x)
357
+
358
+ if i == self.num_upsamples - 1:
359
+ x = self.reflection_pad(x)
360
+
361
+ # fusion
362
+ si = self.source_downs[i](s_stft)
363
+ si = self.source_resblocks[i](si)
364
+ x = x + si
365
+
366
+ xs = None
367
+ for j in range(self.num_kernels):
368
+ if xs is None:
369
+ xs = self.resblocks[i * self.num_kernels + j](x)
370
+ else:
371
+ xs += self.resblocks[i * self.num_kernels + j](x)
372
+ x = xs / self.num_kernels
373
+
374
+ x = F.leaky_relu(x)
375
+ x = self.conv_post(x)
376
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
377
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
378
+
379
+ x = self._istft(magnitude, phase)
380
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
381
+ return x
382
+
383
+ def forward(
384
+ self,
385
+ batch: dict,
386
+ device: torch.device,
387
+ ) -> Dict[str, Optional[torch.Tensor]]:
388
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
389
+ # mel->f0
390
+ f0 = self.f0_predictor(speech_feat)
391
+ # f0->source
392
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
393
+ s, _, _ = self.m_source(s)
394
+ s = s.transpose(1, 2)
395
+ # mel+source->speech
396
+ generated_speech = self.decode(x=speech_feat, s=s)
397
+ return generated_speech, f0
398
+
399
+ @torch.inference_mode()
400
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
401
+ # mel->f0
402
+ f0 = self.f0_predictor(speech_feat)
403
+ # f0->source
404
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
405
+ s, _, _ = self.m_source(s)
406
+ s = s.transpose(1, 2)
407
+ # use cache_source to avoid glitch
408
+ if cache_source.shape[2] != 0:
409
+ s[:, :, :cache_source.shape[2]] = cache_source
410
+ generated_speech = self.decode(x=speech_feat, s=s)
411
+ return generated_speech, s
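
A minimal sketch (not part of the diff) of the magnitude/phase inverse STFT that `HiFTGenerator._istft` performs above, using the default `n_fft=16`, `hop_len=4` iSTFT parameters; the magnitude and phase tensors are random placeholders rather than real network outputs.

```python
import torch

n_fft, hop_len = 16, 4
frames = 200
window = torch.hann_window(n_fft)

magnitude = torch.rand(1, n_fft // 2 + 1, frames)        # placeholder spectra
phase = torch.rand(1, n_fft // 2 + 1, frames) * 2 * torch.pi

# complex spectrum from polar form, then overlap-add back to a waveform
real = magnitude * torch.cos(phase)
imag = magnitude * torch.sin(phase)
audio = torch.istft(torch.complex(real, imag), n_fft, hop_len, n_fft, window=window)
print(audio.shape)   # (1, hop_len * (frames - 1)) samples with center=True
```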
cosyvoice/llm/llm.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Callable, List, Generator
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from transformers import Qwen2ForCausalLM
19
+ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
20
+ from cosyvoice.utils.common import IGNORE_ID
21
+ from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
22
+ from cosyvoice.utils.common import th_accuracy
23
+ from cosyvoice.utils.file_utils import logging
24
+
25
+
26
+ class TransformerLM(torch.nn.Module):
27
+ def __init__(
28
+ self,
29
+ text_encoder_input_size: int,
30
+ llm_input_size: int,
31
+ llm_output_size: int,
32
+ text_token_size: int,
33
+ speech_token_size: int,
34
+ text_encoder: torch.nn.Module,
35
+ llm: torch.nn.Module,
36
+ sampling: Callable,
37
+ length_normalized_loss: bool = True,
38
+ lsm_weight: float = 0.0,
39
+ spk_embed_dim: int = 192,
40
+ ):
41
+ super().__init__()
42
+ self.llm_input_size = llm_input_size
43
+ self.speech_token_size = speech_token_size
44
+ # 1. build text token inputs related modules
45
+ self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
46
+ self.text_encoder = text_encoder
47
+ self.text_encoder_affine_layer = nn.Linear(
48
+ self.text_encoder.output_size(),
49
+ llm_input_size
50
+ )
51
+
52
+ # 2. build speech token language model related modules
53
+ self.sos_eos = 0
54
+ self.task_id = 1
55
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
56
+ self.llm = llm
57
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
58
+ self.criterion_ce = LabelSmoothingLoss(
59
+ size=speech_token_size + 1,
60
+ padding_idx=IGNORE_ID,
61
+ smoothing=lsm_weight,
62
+ normalize_length=length_normalized_loss,
63
+ )
64
+
65
+ # 3. [Optional] build speech token related modules
66
+ self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
67
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
68
+
69
+ # 4. sampling method
70
+ self.sampling = sampling
71
+
72
+ def encode(
73
+ self,
74
+ text: torch.Tensor,
75
+ text_lengths: torch.Tensor,
76
+ ):
77
+ encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
78
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
79
+ encoder_out = self.text_encoder_affine_layer(encoder_out)
80
+ return encoder_out, encoder_out_lens
81
+
82
+ def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
83
+ text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
84
+ speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
85
+ lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
86
+ for i in range(len(text_token))]
87
+ lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
88
+ lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
89
+ return lm_input, lm_input_len
90
+
91
+ def forward(
92
+ self,
93
+ batch: dict,
94
+ device: torch.device,
95
+ ) -> Dict[str, Optional[torch.Tensor]]:
96
+ """
97
+ Args:
98
+ text: (B, L, D)
99
+ text_lengths: (B,)
100
+ audio: (B, T, N) or (B, T)
101
+ audio_lengths: (B,)
102
+ """
103
+ text_token = batch['text_token'].to(device)
104
+ text_token_len = batch['text_token_len'].to(device)
105
+ speech_token = batch['speech_token'].to(device)
106
+ speech_token_len = batch['speech_token_len'].to(device)
107
+ embedding = batch['embedding'].to(device)
108
+
109
+ # 1. prepare llm_target
110
+ lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
111
+ [self.speech_token_size]) for i in range(text_token.size(0))]
112
+ lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
113
+
114
+ # 1. encode text_token
115
+ text_token = self.text_embedding(text_token)
116
+ text_token, text_token_len = self.encode(text_token, text_token_len)
117
+
118
+ # 2. embedding projection
119
+ embedding = F.normalize(embedding, dim=1)
120
+ embedding = self.spk_embed_affine_layer(embedding)
121
+ embedding = embedding.unsqueeze(1)
122
+
123
+ # 3. eos and task_id
124
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
125
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
126
+
127
+ # 4. encode speech_token
128
+ speech_token = self.speech_embedding(speech_token)
129
+
130
+ # 5. unpad and pad
131
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
132
+ task_id_emb, speech_token, speech_token_len)
133
+
134
+ # 6. run lm forward
135
+ lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
136
+ logits = self.llm_decoder(lm_output)
137
+ loss = self.criterion_ce(logits, lm_target)
138
+ acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
139
+ return {'loss': loss, 'acc': acc}
140
+
141
+ def sampling_ids(
142
+ self,
143
+ weighted_scores: torch.Tensor,
144
+ decoded_tokens: List,
145
+ sampling: int,
146
+ ignore_eos: bool = True,
147
+ ):
148
+ num_trials, max_trials = 0, 100
149
+ while True:
150
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
151
+ if (not ignore_eos) or (self.speech_token_size not in top_ids):
152
+ break
153
+ num_trials += 1
154
+ if num_trials > max_trials:
155
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
156
+ return top_ids
157
+
158
+ @torch.inference_mode()
159
+ def inference(
160
+ self,
161
+ text: torch.Tensor,
162
+ text_len: torch.Tensor,
163
+ prompt_text: torch.Tensor,
164
+ prompt_text_len: torch.Tensor,
165
+ prompt_speech_token: torch.Tensor,
166
+ prompt_speech_token_len: torch.Tensor,
167
+ embedding: torch.Tensor,
168
+ sampling: int = 25,
169
+ max_token_text_ratio: float = 20,
170
+ min_token_text_ratio: float = 2,
171
+ ) -> Generator[torch.Tensor, None, None]:
172
+ if self.fp16 is True:
173
+ embedding = embedding.half()
174
+
175
+ device = text.device
176
+ text = torch.concat([prompt_text, text], dim=1)
177
+ text_len += prompt_text_len
178
+ text = self.text_embedding(text)
179
+
180
+ # 1. encode text
181
+ text, text_len = self.encode(text, text_len)
182
+
183
+ # 2. encode embedding
184
+ if embedding.shape[0] != 0:
185
+ embedding = F.normalize(embedding, dim=1)
186
+ embedding = self.spk_embed_affine_layer(embedding)
187
+ embedding = embedding.unsqueeze(dim=1)
188
+ else:
189
+ embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
190
+
191
+ # 3. concat llm_input
192
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
193
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
194
+ if prompt_speech_token_len != 0:
195
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
196
+ else:
197
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
198
+ lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
199
+
200
+ # 4. cal min/max_length
201
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
202
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
203
+
204
+ # 5. step by step decode
205
+ out_tokens = []
206
+ offset = 0
207
+ att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
208
+ for i in range(max_len):
209
+ y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
210
+ att_cache=att_cache, cnn_cache=cnn_cache,
211
+ att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
212
+ device=lm_input.device)).to(torch.bool))
213
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
214
+ # force continue decode first token
215
+ if i == 0:
216
+ logp[:, self.speech_token_size] = -float('inf')
217
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
218
+ if top_ids == self.speech_token_size:
219
+ break
220
+ # in stream mode, yield token one by one
221
+ yield top_ids
222
+ out_tokens.append(top_ids)
223
+ offset += lm_input.size(1)
224
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
225
+
226
+
227
+ class Qwen2Encoder(torch.nn.Module):
228
+ def __init__(self, pretrain_path):
229
+ super().__init__()
230
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
231
+
232
+ def forward_one_step(self, xs, masks, cache=None):
233
+ input_masks = masks[:, -1, :]
234
+ outs = self.model(
235
+ inputs_embeds=xs,
236
+ attention_mask=input_masks,
237
+ output_hidden_states=True,
238
+ return_dict=True,
239
+ use_cache=True,
240
+ past_key_values=cache,
241
+ )
242
+ xs = outs.hidden_states[-1]
243
+ new_cache = outs.past_key_values
244
+ return xs, new_cache
245
+
246
+
247
+ class Qwen2LM(TransformerLM):
248
+ def __init__(
249
+ self,
250
+ llm_input_size: int,
251
+ llm_output_size: int,
252
+ speech_token_size: int,
253
+ llm: torch.nn.Module,
254
+ sampling: Callable,
255
+ length_normalized_loss: bool = True,
256
+ lsm_weight: float = 0.0,
257
+ mix_ratio: List[int] = [5, 15],
258
+ ):
259
+ torch.nn.Module.__init__(self)
260
+ self.llm_input_size = llm_input_size
261
+ self.llm_output_size = llm_output_size
262
+ self.speech_token_size = speech_token_size
263
+
264
+ # 2. build speech token language model related modules
265
+ self.sos_eos = 0
266
+ self.task_id = 1
267
+ self.fill_token = 2
268
+
269
+ self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
270
+ self.llm = llm
271
+ self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
272
+ self.criterion_ce = LabelSmoothingLoss(
273
+ size=speech_token_size + 3,
274
+ padding_idx=IGNORE_ID,
275
+ smoothing=lsm_weight,
276
+ normalize_length=length_normalized_loss,
277
+ )
278
+
279
+ # 3. [Optional] build speech token related modules
280
+ self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
281
+
282
+ # 4. sampling method
283
+ self.sampling = sampling
284
+ self.mix_ratio = mix_ratio
285
+
286
+ @torch.inference_mode()
287
+ def inference(
288
+ self,
289
+ text: torch.Tensor,
290
+ text_len: torch.Tensor,
291
+ prompt_text: torch.Tensor,
292
+ prompt_text_len: torch.Tensor,
293
+ prompt_speech_token: torch.Tensor,
294
+ prompt_speech_token_len: torch.Tensor,
295
+ embedding: torch.Tensor,
296
+ sampling: int = 25,
297
+ max_token_text_ratio: float = 20,
298
+ min_token_text_ratio: float = 2,
299
+ ) -> Generator[torch.Tensor, None, None]:
300
+ device = text.device
301
+ text = torch.concat([prompt_text, text], dim=1)
302
+ text_len += prompt_text_len
303
+ text = self.llm.model.model.embed_tokens(text)
304
+
305
+ # 3. concat llm_input
306
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
307
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
308
+ if prompt_speech_token_len != 0:
309
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
310
+ else:
311
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
312
+ lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
313
+
314
+ # 4. cal min/max_length
315
+ min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
316
+ max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
317
+
318
+ # 5. step by step decode
319
+ out_tokens = []
320
+ cache = None
321
+ for i in range(max_len):
322
+ y_pred, cache = self.llm.forward_one_step(lm_input,
323
+ masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
324
+ cache=cache)
325
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
326
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
327
+ if top_ids == self.speech_token_size:
328
+ break
329
+ if top_ids > self.speech_token_size:
330
+ continue
331
+ # in stream mode, yield token one by one
332
+ yield top_ids
333
+ out_tokens.append(top_ids)
334
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
335
+
336
+ @torch.inference_mode()
337
+ def inference_bistream(
338
+ self,
339
+ text: Generator,
340
+ prompt_text: torch.Tensor,
341
+ prompt_text_len: torch.Tensor,
342
+ prompt_speech_token: torch.Tensor,
343
+ prompt_speech_token_len: torch.Tensor,
344
+ embedding: torch.Tensor,
345
+ sampling: int = 25,
346
+ max_token_text_ratio: float = 20,
347
+ min_token_text_ratio: float = 2,
348
+ ) -> Generator[torch.Tensor, None, None]:
349
+
350
+ device = prompt_text.device
351
+ # 1. prepare input
352
+ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
353
+ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
354
+ if prompt_speech_token_len != 0:
355
+ prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
356
+ else:
357
+ prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
358
+ lm_input = torch.concat([sos_eos_emb], dim=1)
359
+
360
+ # 2. iterate text
361
+ out_tokens = []
362
+ cache = None
363
+ # NOTE: initialize text_cache with prompt_text, since prompt_speech_token/prompt_text is essentially never below the 15/5 mix ratio
364
+ text_cache = self.llm.model.model.embed_tokens(prompt_text)
365
+ next_fill_index = -1
366
+ for this_text in text:
367
+ text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
368
+ # prompt_speech_token_emb not empty, try append to lm_input
369
+ while prompt_speech_token_emb.size(1) != 0:
370
+ if text_cache.size(1) >= self.mix_ratio[0]:
371
+ lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
372
+ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
373
+ lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
374
+ text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
375
+ else:
376
+ logging.info('not enough text token to decode, wait for more')
377
+ break
378
+ # no prompt_speech_token_emb remain, can decode some speech token
379
+ if prompt_speech_token_emb.size(1) == 0:
380
+ if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
381
+ logging.info('get fill token, need to append more text token')
382
+ if text_cache.size(1) >= self.mix_ratio[0]:
383
+ lm_input_text = text_cache[:, :self.mix_ratio[0]]
384
+ logging.info('append {} text token'.format(lm_input_text.size(1)))
385
+ if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
386
+ lm_input = lm_input_text
387
+ else:
388
+ lm_input = torch.concat([lm_input, lm_input_text], dim=1)
389
+ text_cache = text_cache[:, self.mix_ratio[0]:]
390
+ else:
391
+ logging.info('not enough text token to decode, wait for more')
392
+ continue
393
+ while True:
394
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
395
+ y_pred, cache = self.llm.forward_one_step(lm_input,
396
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
397
+ cache=cache)
398
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
399
+ if next_fill_index != -1 and len(out_tokens) == next_fill_index:
400
+ top_ids = self.speech_token_size + 2
401
+ next_fill_index += (self.mix_ratio[1] + 1)
402
+ else:
403
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
404
+ if top_ids == self.speech_token_size + 2:
405
+ next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
406
+ logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
407
+ out_tokens.append(top_ids)
408
+ if top_ids >= self.speech_token_size:
409
+ if top_ids == self.speech_token_size + 2:
410
+ break
411
+ else:
412
+ raise ValueError('should not get token {}'.format(top_ids))
413
+ yield top_ids
414
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
415
+
416
+ # 3. final decode
417
+ lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
418
+ logging.info('no more text token, decode until met eos')
419
+ while True:
420
+ seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
421
+ y_pred, cache = self.llm.forward_one_step(lm_input,
422
+ masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
423
+ cache=cache)
424
+ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
425
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
426
+ out_tokens.append(top_ids)
427
+ if top_ids >= self.speech_token_size:
428
+ if top_ids == self.speech_token_size:
429
+ break
430
+ else:
431
+ raise ValueError('should not get token {}'.format(top_ids))
432
+ # in stream mode, yield token one by one
433
+ yield top_ids
434
+ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
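Both inference() and inference_bistream() above are generators that yield speech token ids one at a time, which is what makes streaming synthesis possible. A hypothetical consumption sketch, not part of the commit: `llm` is assumed to be an instance of the LM class these methods belong to, and the text/prompt tensors are assumed to be prepared by the caller.

```python
# Hypothetical sketch: stream speech tokens out of the generator as they are produced.
# `llm`, `text`, `text_len` and the prompt tensors are assumed to be prepared elsewhere.
out_tokens = []
for token_id in llm.inference(text=text, text_len=text_len,
                              prompt_text=prompt_text, prompt_text_len=prompt_text_len,
                              prompt_speech_token=prompt_speech_token,
                              prompt_speech_token_len=prompt_speech_token_len,
                              embedding=embedding, sampling=25):
    # each yielded id can be forwarded to the flow/vocoder stage immediately
    out_tokens.append(token_id)
```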
cosyvoice/transformer/__init__.py ADDED
File without changes
cosyvoice/transformer/attention.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will always be `True`
175
+ # and we will always do splitting and
176
+ # concatenation (this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
+ x_padded = torch.cat([zero_pad, x], dim=-1)
240
+
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
+ x = x_padded[:, :, 1:].view_as(x)[
245
+ :, :, :, : x.size(-1) // 2 + 1
246
+ ] # only keep the positions from 0 to time2
247
+ return x
248
+
249
+ def forward(
250
+ self,
251
+ query: torch.Tensor,
252
+ key: torch.Tensor,
253
+ value: torch.Tensor,
254
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
+ pos_emb: torch.Tensor = torch.empty(0),
256
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
257
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
+ Args:
260
+ query (torch.Tensor): Query tensor (#batch, time1, size).
261
+ key (torch.Tensor): Key tensor (#batch, time2, size).
262
+ value (torch.Tensor): Value tensor (#batch, time2, size).
263
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
265
+ pos_emb (torch.Tensor): Positional embedding tensor
266
+ (#batch, time2, size).
267
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
+ where `cache_t == chunk_size * num_decoding_left_chunks`
269
+ and `head * d_k == size`
270
+ Returns:
271
+ torch.Tensor: Output tensor (#batch, time1, d_model).
272
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
+ where `cache_t == chunk_size * num_decoding_left_chunks`
274
+ and `head * d_k == size`
275
+ """
276
+ q, k, v = self.forward_qkv(query, key, value)
277
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
+
279
+ # NOTE(xcsong):
280
+ # when export onnx model, for 1st chunk, we feed
281
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
+ # In all modes, `if cache.size(0) > 0` will always be `True`
284
+ # and we will always do splitting and
285
+ # concatenation (this will simplify onnx export). Note that
286
+ # it's OK to concat & split zero-shaped tensors(see code below).
287
+ # when export jit model, for 1st chunk, we always feed
288
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
+ # >>> a = torch.ones((1, 2, 0, 4))
290
+ # >>> b = torch.ones((1, 2, 3, 4))
291
+ # >>> c = torch.cat((a, b), dim=2)
292
+ # >>> torch.equal(b, c) # True
293
+ # >>> d = torch.split(a, 2, dim=-1)
294
+ # >>> torch.equal(d[0], d[1]) # True
295
+ if cache.size(0) > 0:
296
+ key_cache, value_cache = torch.split(cache,
297
+ cache.size(-1) // 2,
298
+ dim=-1)
299
+ k = torch.cat([key_cache, k], dim=2)
300
+ v = torch.cat([value_cache, v], dim=2)
301
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
302
+ # non-trivial to calculate `next_cache_start` here.
303
+ new_cache = torch.cat((k, v), dim=-1)
304
+
305
+ n_batch_pos = pos_emb.size(0)
306
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
307
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
308
+
309
+ # (batch, head, time1, d_k)
310
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
311
+ # (batch, head, time1, d_k)
312
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
313
+
314
+ # compute attention score
315
+ # first compute matrix a and matrix c
316
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
317
+ # (batch, head, time1, time2)
318
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
319
+
320
+ # compute matrix b and matrix d
321
+ # (batch, head, time1, time2)
322
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
323
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
324
+ if matrix_ac.shape != matrix_bd.shape:
325
+ matrix_bd = self.rel_shift(matrix_bd)
326
+
327
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
328
+ self.d_k) # (batch, head, time1, time2)
329
+
330
+ return self.forward_attention(v, scores, mask), new_cache
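A minimal usage sketch for the attention classes above, not part of the commit and assuming the cosyvoice package from this commit is importable. It exercises the incremental KV-cache convention documented in the NOTE comments: the returned cache stores keys and values concatenated on the last dimension, shape (batch, head, cache_t + time1, d_k * 2).

```python
import torch
from cosyvoice.transformer.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
x1 = torch.randn(1, 10, 256)             # first chunk: self-attention with an empty cache
out1, cache = attn(x1, x1, x1)           # cache: (1, 4, 10, 128) = keys/values concatenated
x2 = torch.randn(1, 5, 256)              # next chunk also attends to the cached keys/values
out2, cache = attn(x2, x2, x2, cache=cache)
print(out2.shape, cache.shape)           # torch.Size([1, 5, 256]) torch.Size([1, 4, 15, 128])
```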
cosyvoice/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (bool): Whether to use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for non-causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) means fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It's better we just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache
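A minimal usage sketch for ConvolutionModule, not part of the commit and assuming the cosyvoice package from this commit is importable. With causal=True the module keeps a left-context cache of kernel_size - 1 frames, so chunked streaming calls line up with a full-utterance call.

```python
import torch
from cosyvoice.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
conv.eval()                                   # keep batch-norm running stats fixed for the demo
x = torch.randn(1, 20, 256)                   # (batch, time, channels)
y, cache = conv(x)                            # cache holds the last kernel_size - 1 = 14 frames
y2, cache = conv(torch.randn(1, 8, 256), cache=cache)
print(y.shape, cache.shape, y2.shape)         # (1, 20, 256) (1, 256, 14) (1, 8, 256)
```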
cosyvoice/transformer/decoder.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Decoder definition."""
17
+ from typing import Tuple, List, Optional
18
+
19
+ import torch
20
+ import torch.utils.checkpoint as ckpt
21
+ import logging
22
+
23
+ from cosyvoice.transformer.decoder_layer import DecoderLayer
24
+ from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
25
+ from cosyvoice.utils.class_utils import (
26
+ COSYVOICE_EMB_CLASSES,
27
+ COSYVOICE_ATTENTION_CLASSES,
28
+ COSYVOICE_ACTIVATION_CLASSES,
29
+ )
30
+ from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
31
+
32
+
33
+ class TransformerDecoder(torch.nn.Module):
34
+ """Base class of Transfomer decoder module.
35
+ Args:
36
+ vocab_size: output dim
37
+ encoder_output_size: dimension of attention
38
+ attention_heads: the number of heads of multi head attention
39
+ linear_units: the hidden units number of position-wise feedforward
40
+ num_blocks: the number of decoder blocks
41
+ dropout_rate: dropout rate
42
+ self_attention_dropout_rate: dropout rate for attention
43
+ input_layer: input layer type
44
+ use_output_layer: whether to use output layer
45
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
46
+ normalize_before:
47
+ True: use layer_norm before each sub-block of a layer.
48
+ False: use layer_norm after each sub-block of a layer.
49
+ src_attention: if false, encoder-decoder cross attention is not
50
+ applied, such as CIF model
51
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
52
+ gradient_checkpointing: rerunning a forward-pass segment for each
53
+ checkpointed segment during backward.
54
+ tie_word_embedding: Tie or clone module weights depending on whether we are
55
+ using TorchScript or not
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ vocab_size: int,
61
+ encoder_output_size: int,
62
+ attention_heads: int = 4,
63
+ linear_units: int = 2048,
64
+ num_blocks: int = 6,
65
+ dropout_rate: float = 0.1,
66
+ positional_dropout_rate: float = 0.1,
67
+ self_attention_dropout_rate: float = 0.0,
68
+ src_attention_dropout_rate: float = 0.0,
69
+ input_layer: str = "embed",
70
+ use_output_layer: bool = True,
71
+ normalize_before: bool = True,
72
+ src_attention: bool = True,
73
+ key_bias: bool = True,
74
+ activation_type: str = "relu",
75
+ gradient_checkpointing: bool = False,
76
+ tie_word_embedding: bool = False,
77
+ ):
78
+ super().__init__()
79
+ attention_dim = encoder_output_size
80
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
81
+
82
+ self.embed = torch.nn.Sequential(
83
+ torch.nn.Identity() if input_layer == "no_pos" else
84
+ torch.nn.Embedding(vocab_size, attention_dim),
85
+ COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
86
+ positional_dropout_rate),
87
+ )
88
+
89
+ self.normalize_before = normalize_before
90
+ self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
91
+ self.use_output_layer = use_output_layer
92
+ if use_output_layer:
93
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
94
+ else:
95
+ self.output_layer = torch.nn.Identity()
96
+ self.num_blocks = num_blocks
97
+ self.decoders = torch.nn.ModuleList([
98
+ DecoderLayer(
99
+ attention_dim,
100
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
101
+ attention_heads, attention_dim,
102
+ self_attention_dropout_rate, key_bias),
103
+ COSYVOICE_ATTENTION_CLASSES["selfattn"](
104
+ attention_heads, attention_dim, src_attention_dropout_rate,
105
+ key_bias) if src_attention else None,
106
+ PositionwiseFeedForward(attention_dim, linear_units,
107
+ dropout_rate, activation),
108
+ dropout_rate,
109
+ normalize_before,
110
+ ) for _ in range(self.num_blocks)
111
+ ])
112
+
113
+ self.gradient_checkpointing = gradient_checkpointing
114
+ self.tie_word_embedding = tie_word_embedding
115
+
116
+ def forward(
117
+ self,
118
+ memory: torch.Tensor,
119
+ memory_mask: torch.Tensor,
120
+ ys_in_pad: torch.Tensor,
121
+ ys_in_lens: torch.Tensor,
122
+ r_ys_in_pad: torch.Tensor = torch.empty(0),
123
+ reverse_weight: float = 0.0,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
125
+ """Forward decoder.
126
+ Args:
127
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
128
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
129
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
130
+ ys_in_lens: input lengths of this batch (batch)
131
+ r_ys_in_pad: not used in transformer decoder, in order to unify api
132
+ with bidirectional decoder
133
+ reverse_weight: not used in transformer decoder, in order to unify
134
+ api with bidirectional decoder
135
+ Returns:
136
+ (tuple): tuple containing:
137
+ x: decoded token score before softmax (batch, maxlen_out,
138
+ vocab_size) if use_output_layer is True,
139
+ torch.tensor(0.0), in order to unify api with bidirectional decoder
140
+ olens: (batch, )
141
+ NOTE(xcsong):
142
+ We pass the `__call__` method of the modules instead of `forward` to the
143
+ checkpointing API because `__call__` attaches all the hooks of the module.
144
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
145
+ """
146
+ tgt = ys_in_pad
147
+ maxlen = tgt.size(1)
148
+ # tgt_mask: (B, 1, L)
149
+ tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
150
+ tgt_mask = tgt_mask.to(tgt.device)
151
+ # m: (1, L, L)
152
+ m = subsequent_mask(tgt_mask.size(-1),
153
+ device=tgt_mask.device).unsqueeze(0)
154
+ # tgt_mask: (B, L, L)
155
+ tgt_mask = tgt_mask & m
156
+ x, _ = self.embed(tgt)
157
+ if self.gradient_checkpointing and self.training:
158
+ x = self.forward_layers_checkpointed(x, tgt_mask, memory,
159
+ memory_mask)
160
+ else:
161
+ x = self.forward_layers(x, tgt_mask, memory, memory_mask)
162
+ if self.normalize_before:
163
+ x = self.after_norm(x)
164
+ if self.use_output_layer:
165
+ x = self.output_layer(x)
166
+ olens = tgt_mask.sum(1)
167
+ return x, torch.tensor(0.0), olens
168
+
169
+ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
170
+ memory: torch.Tensor,
171
+ memory_mask: torch.Tensor) -> torch.Tensor:
172
+ for layer in self.decoders:
173
+ x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
174
+ memory_mask)
175
+ return x
176
+
177
+ @torch.jit.unused
178
+ def forward_layers_checkpointed(self, x: torch.Tensor,
179
+ tgt_mask: torch.Tensor,
180
+ memory: torch.Tensor,
181
+ memory_mask: torch.Tensor) -> torch.Tensor:
182
+ for layer in self.decoders:
183
+ x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
184
+ layer.__call__, x, tgt_mask, memory, memory_mask)
185
+ return x
186
+
187
+ def forward_one_step(
188
+ self,
189
+ memory: torch.Tensor,
190
+ memory_mask: torch.Tensor,
191
+ tgt: torch.Tensor,
192
+ tgt_mask: torch.Tensor,
193
+ cache: Optional[List[torch.Tensor]] = None,
194
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
195
+ """Forward one step.
196
+ This is only used for decoding.
197
+ Args:
198
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
199
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
200
+ tgt: input token ids, int64 (batch, maxlen_out)
201
+ tgt_mask: input token mask, (batch, maxlen_out)
202
+ dtype=torch.uint8 in PyTorch 1.2-
203
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
204
+ cache: cached output list of (batch, max_time_out-1, size)
205
+ Returns:
206
+ y, cache: NN output value and cache per `self.decoders`.
207
+ `y.shape` is (batch, maxlen_out, token)
208
+ """
209
+ x, _ = self.embed(tgt)
210
+ new_cache = []
211
+ for i, decoder in enumerate(self.decoders):
212
+ if cache is None:
213
+ c = None
214
+ else:
215
+ c = cache[i]
216
+ x, tgt_mask, memory, memory_mask = decoder(x,
217
+ tgt_mask,
218
+ memory,
219
+ memory_mask,
220
+ cache=c)
221
+ new_cache.append(x)
222
+ if self.normalize_before:
223
+ y = self.after_norm(x[:, -1])
224
+ else:
225
+ y = x[:, -1]
226
+ if self.use_output_layer:
227
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
228
+ return y, new_cache
229
+
230
+ def tie_or_clone_weights(self, jit_mode: bool = True):
231
+ """Tie or clone module weights (between word_emb and output_layer)
232
+ depending on whether we are using TorchScript or not"""
233
+ if not self.use_output_layer:
234
+ return
235
+ if jit_mode:
236
+ logging.info("clone emb.weight to output.weight")
237
+ self.output_layer.weight = torch.nn.Parameter(
238
+ self.embed[0].weight.clone())
239
+ else:
240
+ logging.info("tie emb.weight with output.weight")
241
+ self.output_layer.weight = self.embed[0].weight
242
+
243
+ if getattr(self.output_layer, "bias", None) is not None:
244
+ self.output_layer.bias.data = torch.nn.functional.pad(
245
+ self.output_layer.bias.data,
246
+ (
247
+ 0,
248
+ self.output_layer.weight.shape[0] -
249
+ self.output_layer.bias.shape[0],
250
+ ),
251
+ "constant",
252
+ 0,
253
+ )
254
+
255
+
256
+ class BiTransformerDecoder(torch.nn.Module):
257
+ """Base class of Transfomer decoder module.
258
+ Args:
259
+ vocab_size: output dim
260
+ encoder_output_size: dimension of attention
261
+ attention_heads: the number of heads of multi head attention
262
+ linear_units: the hidden units number of position-wise feedforward
263
+ num_blocks: the number of decoder blocks
264
+ r_num_blocks: the number of right to left decoder blocks
265
+ dropout_rate: dropout rate
266
+ self_attention_dropout_rate: dropout rate for attention
267
+ input_layer: input layer type
268
+ use_output_layer: whether to use output layer
269
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
270
+ normalize_before:
271
+ True: use layer_norm before each sub-block of a layer.
272
+ False: use layer_norm after each sub-block of a layer.
273
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ vocab_size: int,
279
+ encoder_output_size: int,
280
+ attention_heads: int = 4,
281
+ linear_units: int = 2048,
282
+ num_blocks: int = 6,
283
+ r_num_blocks: int = 0,
284
+ dropout_rate: float = 0.1,
285
+ positional_dropout_rate: float = 0.1,
286
+ self_attention_dropout_rate: float = 0.0,
287
+ src_attention_dropout_rate: float = 0.0,
288
+ input_layer: str = "embed",
289
+ use_output_layer: bool = True,
290
+ normalize_before: bool = True,
291
+ key_bias: bool = True,
292
+ gradient_checkpointing: bool = False,
293
+ tie_word_embedding: bool = False,
294
+ ):
295
+
296
+ super().__init__()
297
+ self.tie_word_embedding = tie_word_embedding
298
+ self.left_decoder = TransformerDecoder(
299
+ vocab_size,
300
+ encoder_output_size,
301
+ attention_heads,
302
+ linear_units,
303
+ num_blocks,
304
+ dropout_rate,
305
+ positional_dropout_rate,
306
+ self_attention_dropout_rate,
307
+ src_attention_dropout_rate,
308
+ input_layer,
309
+ use_output_layer,
310
+ normalize_before,
311
+ key_bias=key_bias,
312
+ gradient_checkpointing=gradient_checkpointing,
313
+ tie_word_embedding=tie_word_embedding)
314
+
315
+ self.right_decoder = TransformerDecoder(
316
+ vocab_size,
317
+ encoder_output_size,
318
+ attention_heads,
319
+ linear_units,
320
+ r_num_blocks,
321
+ dropout_rate,
322
+ positional_dropout_rate,
323
+ self_attention_dropout_rate,
324
+ src_attention_dropout_rate,
325
+ input_layer,
326
+ use_output_layer,
327
+ normalize_before,
328
+ key_bias=key_bias,
329
+ gradient_checkpointing=gradient_checkpointing,
330
+ tie_word_embedding=tie_word_embedding)
331
+
332
+ def forward(
333
+ self,
334
+ memory: torch.Tensor,
335
+ memory_mask: torch.Tensor,
336
+ ys_in_pad: torch.Tensor,
337
+ ys_in_lens: torch.Tensor,
338
+ r_ys_in_pad: torch.Tensor,
339
+ reverse_weight: float = 0.0,
340
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
341
+ """Forward decoder.
342
+ Args:
343
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
344
+ memory_mask: encoder memory mask, (batch, 1, maxlen_in)
345
+ ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
346
+ ys_in_lens: input lengths of this batch (batch)
347
+ r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
348
+ used for right to left decoder
349
+ reverse_weight: used for right to left decoder
350
+ Returns:
351
+ (tuple): tuple containing:
352
+ x: decoded token score before softmax (batch, maxlen_out,
353
+ vocab_size) if use_output_layer is True,
354
+ r_x: decoded token score (right to left decoder)
355
+ before softmax (batch, maxlen_out, vocab_size)
356
+ if use_output_layer is True,
357
+ olens: (batch, )
358
+ """
359
+ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
360
+ ys_in_lens)
361
+ r_x = torch.tensor(0.0)
362
+ if reverse_weight > 0.0:
363
+ r_x, _, olens = self.right_decoder(memory, memory_mask,
364
+ r_ys_in_pad, ys_in_lens)
365
+ return l_x, r_x, olens
366
+
367
+ def forward_one_step(
368
+ self,
369
+ memory: torch.Tensor,
370
+ memory_mask: torch.Tensor,
371
+ tgt: torch.Tensor,
372
+ tgt_mask: torch.Tensor,
373
+ cache: Optional[List[torch.Tensor]] = None,
374
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
375
+ """Forward one step.
376
+ This is only used for decoding.
377
+ Args:
378
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
379
+ memory_mask: encoded memory mask, (batch, 1, maxlen_in)
380
+ tgt: input token ids, int64 (batch, maxlen_out)
381
+ tgt_mask: input token mask, (batch, maxlen_out)
382
+ dtype=torch.uint8 in PyTorch 1.2-
383
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
384
+ cache: cached output list of (batch, max_time_out-1, size)
385
+ Returns:
386
+ y, cache: NN output value and cache per `self.decoders`.
387
+ `y.shape` is (batch, maxlen_out, token)
388
+ """
389
+ return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
390
+ tgt_mask, cache)
391
+
392
+ def tie_or_clone_weights(self, jit_mode: bool = True):
393
+ """Tie or clone module weights (between word_emb and output_layer)
394
+ depending of whether we are using TorchScript or not"""
395
+ self.left_decoder.tie_or_clone_weights(jit_mode)
396
+ self.right_decoder.tie_or_clone_weights(jit_mode)
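A minimal teacher-forcing sketch for TransformerDecoder, not part of the commit and assuming the cosyvoice package from this commit is importable; tensor shapes follow the forward() docstring above.

```python
import torch
from cosyvoice.transformer.decoder import TransformerDecoder

decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256, num_blocks=2)
memory = torch.randn(2, 50, 256)                      # encoder output (batch, maxlen_in, feat)
memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, maxlen_in)
ys_in_pad = torch.randint(0, 100, (2, 12))            # padded target token ids
ys_in_lens = torch.tensor([12, 9])
logits, _, _ = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
print(logits.shape)                                   # torch.Size([2, 12, 100])
```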
cosyvoice/transformer/decoder_layer.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Decoder self-attention layer definition."""
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class DecoderLayer(nn.Module):
23
+ """Single decoder layer module.
24
+
25
+ Args:
26
+ size (int): Input dimension.
27
+ self_attn (torch.nn.Module): Self-attention module instance.
28
+ `MultiHeadedAttention` instance can be used as the argument.
29
+ src_attn (torch.nn.Module): Inter-attention module instance.
30
+ `MultiHeadedAttention` instance can be used as the argument.
31
+ If `None` is passed, Inter-attention is not used, such as
32
+ CIF, GPT, and other decoder only model.
33
+ feed_forward (torch.nn.Module): Feed-forward module instance.
34
+ `PositionwiseFeedForward` instance can be used as the argument.
35
+ dropout_rate (float): Dropout rate.
36
+ normalize_before (bool):
37
+ True: use layer_norm before each sub-block.
38
+ False: use layer_norm after each sub-block.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ size: int,
44
+ self_attn: nn.Module,
45
+ src_attn: Optional[nn.Module],
46
+ feed_forward: nn.Module,
47
+ dropout_rate: float,
48
+ normalize_before: bool = True,
49
+ ):
50
+ """Construct an DecoderLayer object."""
51
+ super().__init__()
52
+ self.size = size
53
+ self.self_attn = self_attn
54
+ self.src_attn = src_attn
55
+ self.feed_forward = feed_forward
56
+ self.norm1 = nn.LayerNorm(size, eps=1e-5)
57
+ self.norm2 = nn.LayerNorm(size, eps=1e-5)
58
+ self.norm3 = nn.LayerNorm(size, eps=1e-5)
59
+ self.dropout = nn.Dropout(dropout_rate)
60
+ self.normalize_before = normalize_before
61
+
62
+ def forward(
63
+ self,
64
+ tgt: torch.Tensor,
65
+ tgt_mask: torch.Tensor,
66
+ memory: torch.Tensor,
67
+ memory_mask: torch.Tensor,
68
+ cache: Optional[torch.Tensor] = None
69
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70
+ """Compute decoded features.
71
+
72
+ Args:
73
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74
+ tgt_mask (torch.Tensor): Mask for input tensor
75
+ (#batch, maxlen_out).
76
+ memory (torch.Tensor): Encoded memory
77
+ (#batch, maxlen_in, size).
78
+ memory_mask (torch.Tensor): Encoded memory mask
79
+ (#batch, maxlen_in).
80
+ cache (torch.Tensor): cached tensors.
81
+ (#batch, maxlen_out - 1, size).
82
+
83
+ Returns:
84
+ torch.Tensor: Output tensor (#batch, maxlen_out, size).
85
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88
+
89
+ """
90
+ residual = tgt
91
+ if self.normalize_before:
92
+ tgt = self.norm1(tgt)
93
+
94
+ if cache is None:
95
+ tgt_q = tgt
96
+ tgt_q_mask = tgt_mask
97
+ else:
98
+ # compute only the last frame query keeping dim: max_time_out -> 1
99
+ assert cache.shape == (
100
+ tgt.shape[0],
101
+ tgt.shape[1] - 1,
102
+ self.size,
103
+ ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104
+ tgt_q = tgt[:, -1:, :]
105
+ residual = residual[:, -1:, :]
106
+ tgt_q_mask = tgt_mask[:, -1:, :]
107
+
108
+ x = residual + self.dropout(
109
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110
+ if not self.normalize_before:
111
+ x = self.norm1(x)
112
+
113
+ if self.src_attn is not None:
114
+ residual = x
115
+ if self.normalize_before:
116
+ x = self.norm2(x)
117
+ x = residual + self.dropout(
118
+ self.src_attn(x, memory, memory, memory_mask)[0])
119
+ if not self.normalize_before:
120
+ x = self.norm2(x)
121
+
122
+ residual = x
123
+ if self.normalize_before:
124
+ x = self.norm3(x)
125
+ x = residual + self.dropout(self.feed_forward(x))
126
+ if not self.normalize_before:
127
+ x = self.norm3(x)
128
+
129
+ if cache is not None:
130
+ x = torch.cat([cache, x], dim=1)
131
+
132
+ return x, tgt_mask, memory, memory_mask
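A minimal assembly sketch for DecoderLayer, not part of the commit and assuming the cosyvoice package from this commit is importable; the attention and feed-forward modules come from the other files in this commit.

```python
import torch
from cosyvoice.transformer.attention import MultiHeadedAttention
from cosyvoice.transformer.decoder_layer import DecoderLayer
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward

layer = DecoderLayer(
    size=256,
    self_attn=MultiHeadedAttention(4, 256, 0.0),
    src_attn=MultiHeadedAttention(4, 256, 0.0),
    feed_forward=PositionwiseFeedForward(256, 1024, 0.0, torch.nn.ReLU()),
    dropout_rate=0.0,
)
tgt = torch.randn(2, 12, 256)                          # decoder input states
tgt_mask = torch.ones(2, 12, 12, dtype=torch.bool)     # a causal mask would go here
memory = torch.randn(2, 50, 256)                       # encoder output
memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)
x, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask)
print(x.shape)                                         # torch.Size([2, 12, 256])
```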
cosyvoice/utils/__init__.py ADDED
File without changes
cosyvoice/utils/class_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice.transformer.activation import Swish
18
+ from cosyvoice.transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from cosyvoice.transformer.embedding import (PositionalEncoding,
27
+ RelPositionalEncoding,
28
+ WhisperPositionalEncoding,
29
+ LearnablePositionalEncoding,
30
+ NoPositionalEncoding)
31
+ from cosyvoice.transformer.attention import (MultiHeadedAttention,
32
+ RelPositionMultiHeadedAttention)
33
+ from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34
+ from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35
+ from cosyvoice.llm.llm import TransformerLM, Qwen2LM
36
+ from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
37
+ from cosyvoice.hifigan.generator import HiFTGenerator
38
+ from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
39
+
40
+
41
+ COSYVOICE_ACTIVATION_CLASSES = {
42
+ "hardtanh": torch.nn.Hardtanh,
43
+ "tanh": torch.nn.Tanh,
44
+ "relu": torch.nn.ReLU,
45
+ "selu": torch.nn.SELU,
46
+ "swish": getattr(torch.nn, "SiLU", Swish),
47
+ "gelu": torch.nn.GELU,
48
+ }
49
+
50
+ COSYVOICE_SUBSAMPLE_CLASSES = {
51
+ "linear": LinearNoSubsampling,
52
+ "linear_legacy": LegacyLinearNoSubsampling,
53
+ "embed": EmbedinigNoSubsampling,
54
+ "conv1d2": Conv1dSubsampling2,
55
+ "conv2d": Conv2dSubsampling4,
56
+ "conv2d6": Conv2dSubsampling6,
57
+ "conv2d8": Conv2dSubsampling8,
58
+ 'paraformer_dummy': torch.nn.Identity
59
+ }
60
+
61
+ COSYVOICE_EMB_CLASSES = {
62
+ "embed": PositionalEncoding,
63
+ "abs_pos": PositionalEncoding,
64
+ "rel_pos": RelPositionalEncoding,
65
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
66
+ "no_pos": NoPositionalEncoding,
67
+ "abs_pos_whisper": WhisperPositionalEncoding,
68
+ "embed_learnable_pe": LearnablePositionalEncoding,
69
+ }
70
+
71
+ COSYVOICE_ATTENTION_CLASSES = {
72
+ "selfattn": MultiHeadedAttention,
73
+ "rel_selfattn": RelPositionMultiHeadedAttention,
74
+ }
75
+
76
+
77
+ def get_model_type(configs):
78
+ # NOTE CosyVoice2Model inherits CosyVoiceModel
79
+ if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
80
+ return CosyVoiceModel
81
+ if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
82
+ return CosyVoice2Model
83
+ raise TypeError('No valid model type found!')
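A small lookup sketch for the registries above, not part of the commit and assuming the cosyvoice package from this commit is importable. These dictionaries are how string options in the configs (for example the activation_type and "selfattn" keys used in decoder.py) are resolved to concrete classes.

```python
from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
)

act = COSYVOICE_ACTIVATION_CLASSES["swish"]()                     # SiLU on recent torch versions
attn = COSYVOICE_ATTENTION_CLASSES["rel_selfattn"](4, 256, 0.1)   # RelPositionMultiHeadedAttention
print(type(act).__name__, type(attn).__name__)
```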
cosyvoice/utils/scheduler.py ADDED
@@ -0,0 +1,738 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2022 Ximalaya Inc (Yuguang Yang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ # NeMo(https://github.com/NVIDIA/NeMo)
18
+
19
+ from typing import Union
20
+
21
+ import math
22
+ import warnings
23
+ import torch
24
+ from torch.optim.lr_scheduler import _LRScheduler
25
+
26
+
27
+ class WarmupLR(_LRScheduler):
28
+ """The WarmupLR scheduler
29
+
30
+ This scheduler is almost the same as the NoamLR scheduler except for the following
31
+ difference:
32
+
33
+ NoamLR:
34
+ lr = optimizer.lr * model_size ** -0.5
35
+ * min(step ** -0.5, step * warmup_step ** -1.5)
36
+ WarmupLR:
37
+ lr = optimizer.lr * warmup_step ** 0.5
38
+ * min(step ** -0.5, step * warmup_step ** -1.5)
39
+
40
+ Note that the maximum lr equals optimizer.lr in this scheduler.
41
+
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ optimizer: torch.optim.Optimizer,
47
+ warmup_steps: Union[int, float] = 25000,
48
+ last_epoch: int = -1,
49
+ ):
50
+ self.warmup_steps = warmup_steps
51
+
52
+ # __init__() must be invoked before setting field
53
+ # because step() is also invoked in __init__()
54
+ super().__init__(optimizer, last_epoch)
55
+
56
+ def __repr__(self):
57
+ return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
58
+
59
+ def get_lr(self):
60
+ step_num = self.last_epoch + 1
61
+ if self.warmup_steps == 0:
62
+ return [lr * step_num**-0.5 for lr in self.base_lrs]
63
+ else:
64
+ return [
65
+ lr * self.warmup_steps**0.5 *
66
+ min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
67
+ for lr in self.base_lrs
68
+ ]
69
+
70
+ def set_step(self, step: int):
71
+ self.last_epoch = step
72
+
73
+
74
+ class WarmupPolicy(_LRScheduler):
75
+ """Adds warmup kwargs and warmup logic to lr policy.
76
+ All arguments should be passed as kwargs for clarity,
77
+ Args:
78
+ warmup_steps: Number of training steps in warmup stage
79
+ warmup_ratio: Ratio of warmup steps to total steps
80
+ max_steps: Total number of steps while training or `None` for
81
+ infinite training
82
+ """
83
+
84
+ def __init__(self,
85
+ optimizer,
86
+ *,
87
+ warmup_steps=None,
88
+ warmup_ratio=None,
89
+ max_steps=None,
90
+ min_lr=0.0,
91
+ last_epoch=-1):
92
+ assert not (warmup_steps is not None and warmup_ratio is not None),\
93
+ "Either use particular number of step or ratio"
94
+ assert warmup_ratio is None or max_steps is not None, \
95
+ "If there is a ratio, there should be a total steps"
96
+
97
+ # It is necessary to assign all attributes *before* __init__,
98
+ # as class is wrapped by an inner class.
99
+ self.max_steps = max_steps
100
+ if warmup_steps is not None:
101
+ self.warmup_steps = warmup_steps
102
+ elif warmup_ratio is not None:
103
+ self.warmup_steps = int(warmup_ratio * max_steps)
104
+ else:
105
+ self.warmup_steps = 0
106
+
107
+ self.min_lr = min_lr
108
+ super().__init__(optimizer, last_epoch)
109
+
110
+ def get_lr(self):
111
+ if not self._get_lr_called_within_step:
112
+ warnings.warn(
113
+ "To get the last learning rate computed "
114
+ "by the scheduler, please use `get_last_lr()`.",
115
+ UserWarning,
116
+ stacklevel=2)
117
+
118
+ step = self.last_epoch
119
+
120
+ if step <= self.warmup_steps and self.warmup_steps > 0:
121
+ return self._get_warmup_lr(step)
122
+
123
+ if step > self.max_steps:
124
+ return [self.min_lr for _ in self.base_lrs]
125
+
126
+ return self._get_lr(step)
127
+
128
+ def _get_warmup_lr(self, step):
129
+ lr_val = (step + 1) / (self.warmup_steps + 1)
130
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
131
+
132
+ def _get_lr(self, step):
133
+ """Simple const lr policy"""
134
+ return self.base_lrs
135
+
136
+
137
+ class SquareRootConstantPolicy(_LRScheduler):
138
+ """Adds warmup kwargs and warmup logic to lr policy.
139
+ All arguments should be passed as kwargs for clarity,
140
+ Args:
141
+ warmup_steps: Number of training steps in warmup stage
142
+ warmup_ratio: Ratio of warmup steps to total steps
143
+ max_steps: Total number of steps while training or `None` for
144
+ infinite training
145
+ """
146
+
147
+ def __init__(self,
148
+ optimizer,
149
+ *,
150
+ constant_steps=None,
151
+ constant_ratio=None,
152
+ max_steps=None,
153
+ min_lr=0.0,
154
+ last_epoch=-1):
155
+ assert not (constant_steps is not None
156
+ and constant_ratio is not None), \
157
+ "Either use particular number of step or ratio"
158
+ assert constant_ratio is None or max_steps is not None, \
159
+ "If there is a ratio, there should be a total steps"
160
+
161
+ # It is necessary to assign all attributes *before* __init__,
162
+ # as class is wrapped by an inner class.
163
+ self.max_steps = max_steps
164
+ if constant_steps is not None:
165
+ self.constant_steps = constant_steps
166
+ elif constant_ratio is not None:
167
+ self.constant_steps = int(constant_ratio * max_steps)
168
+ else:
169
+ self.constant_steps = 0
170
+
171
+ self.constant_lr = 1 / (constant_steps**0.5)
172
+ self.min_lr = min_lr
173
+ super().__init__(optimizer, last_epoch)
174
+
175
+ def get_lr(self):
176
+ if not self._get_lr_called_within_step:
177
+ warnings.warn(
178
+ "To get the last learning rate computed "
179
+ "by the scheduler, please use `get_last_lr()`.",
180
+ UserWarning,
181
+ stacklevel=2)
182
+
183
+ step = self.last_epoch
184
+
185
+ if step <= self.constant_steps:
186
+ return [self.constant_lr for _ in self.base_lrs]
187
+
188
+ if step > self.max_steps:
189
+ return [self.min_lr for _ in self.base_lrs]
190
+
191
+ return self._get_lr(step)
192
+
193
+ def _get_lr(self, step):
194
+ """Simple const lr policy"""
195
+ return self.base_lrs
196
+
197
+
198
+ class WarmupHoldPolicy(WarmupPolicy):
199
+ """Variant of WarmupPolicy which maintains high
200
+ learning rate for a defined number of steps.
201
+ All arguments should be passed as kwargs for clarity,
202
+ Args:
203
+ warmup_steps: Number of training steps in warmup stage
204
+ warmup_ratio: Ratio of warmup steps to total steps
205
+ hold_steps: Number of training steps to
206
+ hold the learning rate after warm up
207
+ hold_ratio: Ratio of hold steps to total steps
208
+ max_steps: Total number of steps while training or `None` for
209
+ infinite training
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ optimizer,
215
+ *,
216
+ warmup_steps=None,
217
+ warmup_ratio=None,
218
+ hold_steps=None,
219
+ hold_ratio=None,
220
+ max_steps=None,
221
+ min_lr=0.0,
222
+ last_epoch=-1,
223
+ ):
224
+ assert not (hold_steps is not None and hold_ratio is not None), \
225
+ "Either use particular number of step or ratio"
226
+ assert hold_ratio is None or max_steps is not None, \
227
+ "If there is a ratio, there should be a total steps"
228
+
229
+ self.min_lr = min_lr
230
+ self._last_warmup_lr = 0.0
231
+
232
+ # Necessary to duplicate as class attributes are hidden in inner class
233
+ self.max_steps = max_steps
234
+ if warmup_steps is not None:
235
+ self.warmup_steps = warmup_steps
236
+ elif warmup_ratio is not None:
237
+ self.warmup_steps = int(warmup_ratio * max_steps)
238
+ else:
239
+ self.warmup_steps = 0
240
+
241
+ if hold_steps is not None:
242
+ self.hold_steps = hold_steps + self.warmup_steps
243
+ elif hold_ratio is not None:
244
+ self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
245
+ else:
246
+ self.hold_steps = 0
247
+
248
+ super().__init__(
249
+ optimizer,
250
+ warmup_steps=warmup_steps,
251
+ warmup_ratio=warmup_ratio,
252
+ max_steps=max_steps,
253
+ last_epoch=last_epoch,
254
+ min_lr=min_lr,
255
+ )
256
+
257
+ def get_lr(self):
258
+ if not self._get_lr_called_within_step:
259
+ warnings.warn(
260
+ "To get the last learning rate computed by the scheduler,"
261
+ " "
262
+ "please use `get_last_lr()`.",
263
+ UserWarning,
264
+ stacklevel=2)
265
+
266
+ step = self.last_epoch
267
+
268
+ # Warmup phase
269
+ if step <= self.warmup_steps and self.warmup_steps > 0:
270
+ return self._get_warmup_lr(step)
271
+
272
+ # Hold phase
273
+ if (step >= self.warmup_steps) and (step < self.hold_steps):
274
+ return self.base_lrs
275
+
276
+ if step > self.max_steps:
277
+ return [self.min_lr for _ in self.base_lrs]
278
+
279
+ return self._get_lr(step)
280
+
281
+
282
+ class WarmupAnnealHoldPolicy(_LRScheduler):
283
+ """Adds warmup kwargs and warmup logic to lr policy.
284
+ All arguments should be passed as kwargs for clarity,
285
+ Args:
286
+ warmup_steps: Number of training steps in warmup stage
287
+ warmup_ratio: Ratio of warmup steps to total steps
288
+ max_steps: Total number of steps while training or `None` for
289
+ infinite training
290
+ min_lr: Minimum lr to hold the learning rate after decay at.
291
+ constant_steps: Number of steps to keep lr constant at.
292
+ constant_ratio: Ratio of steps to keep lr constant.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ optimizer,
298
+ *,
299
+ warmup_steps=None,
300
+ warmup_ratio=None,
301
+ constant_steps=None,
302
+ constant_ratio=None,
303
+ max_steps=None,
304
+ min_lr=0.0,
305
+ last_epoch=-1,
306
+ ):
307
+ assert not (warmup_steps is not None
308
+ and warmup_ratio is not None), \
309
+ "Either use particular number of step or ratio"
310
+ assert not (constant_steps is not None
311
+ and constant_ratio is not None), \
312
+ "Either use constant_steps or constant_ratio"
313
+ assert warmup_ratio is None or max_steps is not None, \
314
+ "If there is a ratio, there should be a total steps"
315
+
316
+ # It is necessary to assign all attributes *before* __init__,
317
+ # as class is wrapped by an inner class.
318
+ self.max_steps = max_steps
319
+
320
+ if warmup_steps is not None:
321
+ self.warmup_steps = warmup_steps
322
+ elif warmup_ratio is not None:
323
+ self.warmup_steps = int(warmup_ratio * max_steps)
324
+ else:
325
+ self.warmup_steps = 0
326
+
327
+ if constant_steps is not None:
328
+ self.constant_steps = constant_steps
329
+ elif constant_ratio is not None:
330
+ self.constant_steps = int(constant_ratio * max_steps)
331
+ else:
332
+ self.constant_steps = 0
333
+
334
+ self.decay_steps = max_steps - (self.constant_steps +
335
+ self.warmup_steps)
336
+
337
+ self.min_lr = min_lr
338
+ super().__init__(optimizer, last_epoch)
339
+
340
+ def get_lr(self):
341
+ if not self._get_lr_called_within_step:
342
+ warnings.warn(
343
+ "To get the last learning rate computed "
344
+ "by the scheduler, please use `get_last_lr()`.",
345
+ UserWarning,
346
+ stacklevel=2)
347
+
348
+ step = self.last_epoch
349
+
350
+ # Warmup steps
351
+ if self.warmup_steps > 0 and step <= self.warmup_steps:
352
+ return self._get_warmup_lr(step)
353
+
354
+ # Constant steps after warmup and decay
355
+ if self.constant_steps > 0 and (
356
+ self.warmup_steps + self.decay_steps) < step <= self.max_steps:
357
+ return self._get_constant_lr(step)
358
+
359
+ # Min lr after max steps of updates
360
+ if step > self.max_steps:
361
+ return [self.min_lr for _ in self.base_lrs]
362
+
363
+ return self._get_lr(step)
364
+
365
+ def _get_warmup_lr(self, step):
366
+ lr_val = (step + 1) / (self.warmup_steps + 1)
367
+ return [initial_lr * lr_val for initial_lr in self.base_lrs]
368
+
369
+ def _get_constant_lr(self, step):
370
+ return [self.min_lr for _ in self.base_lrs]
371
+
372
+ def _get_lr(self, step):
373
+ """Simple const lr policy"""
374
+ return self.base_lrs
375
+
376
+
377
+ def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
378
+ mult = ((max_steps - step) / max_steps)**0.5
379
+ out_lr = initial_lr * mult
380
+ out_lr = max(out_lr, min_lr)
381
+ return out_lr
382
+
383
+
384
+ def _square_annealing(initial_lr, step, max_steps, min_lr):
385
+ mult = ((max_steps - step) / max_steps)**2
386
+ out_lr = initial_lr * mult
387
+ out_lr = max(out_lr, min_lr)
388
+ return out_lr
389
+
390
+
391
+ def _cosine_annealing(initial_lr, step, max_steps, min_lr):
392
+ mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
393
+ out_lr = (initial_lr - min_lr) * mult + min_lr
394
+ return out_lr
395
+
396
+
397
+ def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
398
+ decay_steps, min_lr):
399
+ assert max_lr > min_lr
400
+ # Use linear warmup for the initial part.
401
+ if warmup_steps > 0 and step <= warmup_steps:
402
+ return max_lr * float(step) / float(warmup_steps)
403
+
404
+ # For any steps larger than `decay_steps`, use `min_lr`.
405
+ if step > warmup_steps + decay_steps:
406
+ return min_lr
407
+
408
+ # If we are done with the warmup period, use the decay style.
409
+ num_steps_ = step - warmup_steps
410
+ decay_steps_ = decay_steps
411
+ decay_ratio = float(num_steps_) / float(decay_steps_)
412
+ assert decay_ratio >= 0.0
413
+ assert decay_ratio <= 1.0
414
+ delta_lr = max_lr - min_lr
415
+
416
+ coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
417
+
418
+ return min_lr + coeff * delta_lr
419
+
420
+
421
+ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
422
+ if cycle:
423
+ multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
424
+ decay_steps *= multiplier
425
+ else:
426
+ step = min(step, decay_steps)
427
+ p = step / decay_steps
428
+ lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
429
+ lr += min_lr
430
+ return lr
431
+
432
+
433
+ def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
434
+ decay_rate, min_lr):
435
+ # hold_steps = total number of steps
436
+ # to hold the LR, not the warmup + hold steps.
437
+ T_warmup_decay = max(1, warmup_steps**decay_rate)
438
+ T_hold_decay = max(1, (step - hold_steps)**decay_rate)
439
+ lr = (initial_lr * T_warmup_decay) / T_hold_decay
440
+ lr = max(lr, min_lr)
441
+ return lr
442
+
443
+
444
+ class SquareAnnealing(WarmupPolicy):
445
+
446
+ def __init__(self,
447
+ optimizer,
448
+ *,
449
+ max_steps,
450
+ min_lr=1e-5,
451
+ last_epoch=-1,
452
+ **kwargs):
453
+ super().__init__(optimizer=optimizer,
454
+ max_steps=max_steps,
455
+ last_epoch=last_epoch,
456
+ min_lr=min_lr,
457
+ **kwargs)
458
+
459
+ def _get_lr(self, step):
460
+ new_lrs = [
461
+ _square_annealing(
462
+ initial_lr=initial_lr,
463
+ step=step - self.warmup_steps,
464
+ max_steps=self.max_steps - self.warmup_steps,
465
+ min_lr=self.min_lr,
466
+ ) for initial_lr in self.base_lrs
467
+ ]
468
+ return new_lrs
469
+
470
+
471
+ class SquareRootAnnealing(WarmupPolicy):
472
+
473
+ def __init__(self,
474
+ optimizer,
475
+ *,
476
+ max_steps,
477
+ min_lr=0,
478
+ last_epoch=-1,
479
+ **kwargs):
480
+ super().__init__(optimizer=optimizer,
481
+ max_steps=max_steps,
482
+ last_epoch=last_epoch,
483
+ min_lr=min_lr,
484
+ **kwargs)
485
+
486
+ def _get_lr(self, step):
487
+ new_lrs = [
488
+ _squareroot_annealing(initial_lr=initial_lr,
489
+ step=step,
490
+ max_steps=self.max_steps,
491
+ min_lr=self.min_lr)
492
+ for initial_lr in self.base_lrs
493
+ ]
494
+ return new_lrs
495
+
496
+
497
+ class CosineAnnealing(WarmupAnnealHoldPolicy):
498
+
499
+ def __init__(self,
500
+ optimizer,
501
+ *,
502
+ max_steps,
503
+ min_lr=0,
504
+ last_epoch=-1,
505
+ **kwargs):
506
+ super().__init__(optimizer=optimizer,
507
+ max_steps=max_steps,
508
+ last_epoch=last_epoch,
509
+ min_lr=min_lr,
510
+ **kwargs)
511
+
512
+ def _get_lr(self, step):
513
+ for initial_lr in self.base_lrs:
514
+ if initial_lr < self.min_lr:
515
+ raise ValueError(
516
+ f"{self} received an initial learning rate "
517
+ f"that was lower than the minimum learning rate.")
518
+
519
+ if self.constant_steps is None or self.constant_steps == 0:
520
+ new_lrs = [
521
+ _cosine_annealing(
522
+ initial_lr=initial_lr,
523
+ step=step - self.warmup_steps,
524
+ max_steps=self.max_steps - self.warmup_steps,
525
+ min_lr=self.min_lr,
526
+ ) for initial_lr in self.base_lrs
527
+ ]
528
+ else:
529
+ new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
530
+ return new_lrs
531
+
532
+ def _get_warmup_lr(self, step):
533
+ if self.constant_steps is None or self.constant_steps == 0:
534
+ return super()._get_warmup_lr(step)
535
+ else:
536
+ # Use linear warmup for the initial part.
537
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
538
+
539
+ def _get_constant_lr(self, step):
540
+ # Only called when `constant_steps` > 0.
541
+ return self._get_linear_warmup_with_cosine_annealing_lr(step)
542
+
543
+ def _get_linear_warmup_with_cosine_annealing_lr(self, step):
544
+ # Cosine Schedule for Megatron LM,
545
+ # slightly different warmup schedule + constant LR at the end.
546
+ new_lrs = [
547
+ _linear_warmup_with_cosine_annealing(
548
+ max_lr=self.base_lrs[0],
549
+ warmup_steps=self.warmup_steps,
550
+ step=step,
551
+ decay_steps=self.decay_steps,
552
+ min_lr=self.min_lr,
553
+ ) for _ in self.base_lrs
554
+ ]
555
+ return new_lrs
556
+
557
+
558
+ class NoamAnnealing(_LRScheduler):
559
+
560
+ def __init__(self,
561
+ optimizer,
562
+ *,
563
+ d_model,
564
+ warmup_steps=None,
565
+ warmup_ratio=None,
566
+ max_steps=None,
567
+ min_lr=0.0,
568
+ last_epoch=-1):
569
+ self._normalize = d_model**(-0.5)
570
+ assert not (warmup_steps is not None and warmup_ratio is not None), \
571
+ "Either use particular number of step or ratio"
572
+ assert warmup_ratio is None or max_steps is not None, \
573
+ "If there is a ratio, there should be a total steps"
574
+
575
+ # It is necessary to assign all attributes *before* __init__,
576
+ # as class is wrapped by an inner class.
577
+ self.max_steps = max_steps
578
+ if warmup_steps is not None:
579
+ self.warmup_steps = warmup_steps
580
+ elif warmup_ratio is not None:
581
+ self.warmup_steps = int(warmup_ratio * max_steps)
582
+ else:
583
+ self.warmup_steps = 0
584
+
585
+ self.min_lr = min_lr
586
+ super().__init__(optimizer, last_epoch)
587
+
588
+ def get_lr(self):
589
+ if not self._get_lr_called_within_step:
590
+ warnings.warn(
591
+ "To get the last learning rate computed "
592
+ "by the scheduler, please use `get_last_lr()`.",
593
+ UserWarning,
594
+ stacklevel=2)
595
+
596
+ step = max(1, self.last_epoch)
597
+
598
+ for initial_lr in self.base_lrs:
599
+ if initial_lr < self.min_lr:
600
+ raise ValueError(
601
+ f"{self} received an initial learning rate "
602
+ f"that was lower than the minimum learning rate.")
603
+
604
+ new_lrs = [
605
+ self._noam_annealing(initial_lr=initial_lr, step=step)
606
+ for initial_lr in self.base_lrs
607
+ ]
608
+ return new_lrs
609
+
610
+ def _noam_annealing(self, initial_lr, step):
611
+ if self.warmup_steps > 0:
612
+ mult = self._normalize * min(step**(-0.5),
613
+ step * (self.warmup_steps**(-1.5)))
614
+ else:
615
+ mult = self._normalize * step**(-0.5)
616
+
617
+ out_lr = initial_lr * mult
618
+ if step > self.warmup_steps:
619
+ out_lr = max(out_lr, self.min_lr)
620
+ return out_lr
621
+
622
+
623
+ class NoamHoldAnnealing(WarmupHoldPolicy):
624
+
625
+ def __init__(self,
626
+ optimizer,
627
+ *,
628
+ max_steps,
629
+ decay_rate=0.5,
630
+ min_lr=0.0,
631
+ last_epoch=-1,
632
+ **kwargs):
633
+ """
634
+ From Nemo:
635
+ Implementation of the Noam Hold Annealing policy
636
+ from the SqueezeFormer paper.
637
+
638
+ Unlike NoamAnnealing, the peak learning rate
639
+ can be explicitly set for this scheduler.
640
+ The schedule first performs linear warmup,
641
+ then holds the peak LR, then decays with some schedule for
642
+ the remainder of the steps.
643
+ Therefore the min-lr is still dependent
644
+ on the hyperparameters selected.
645
+
646
+ Its schedule is determined by three factors:
647
+
648
+ Warmup Steps: Initial stage, where linear warmup
649
+ occurs until the peak LR is reached. Unlike NoamAnnealing,
650
+ the peak LR is explicitly stated here instead of a scaling factor.
651
+
652
+ Hold Steps: Intermediate stage, where the peak LR
653
+ is maintained for some number of steps. In this region,
654
+ the high peak LR allows the model to converge faster
655
+ if training is stable. However, the high LR
656
+ may also cause instability during training.
657
+ Should usually be a significant fraction of training
658
+ steps (around 30-40% of the entire training steps).
659
+
660
+ Decay Steps: Final stage, where the LR rapidly decays
661
+ with some scaling rate (set by decay rate).
662
+ To attain Noam decay, use 0.5,
663
+ for Squeezeformer recommended decay, use 1.0.
664
+ The fast decay after prolonged high LR during
665
+ hold phase allows for rapid convergence.
666
+
667
+ References:
668
+ - [Squeezeformer:
669
+ An Efficient Transformer for Automatic Speech Recognition]
670
+ (https://arxiv.org/abs/2206.00888)
671
+
672
+ Args:
673
+ optimizer: Pytorch compatible Optimizer object.
674
+ warmup_steps: Number of training steps in warmup stage
675
+ warmup_ratio: Ratio of warmup steps to total steps
676
+ hold_steps: Number of training steps to
677
+ hold the learning rate after warm up
678
+ hold_ratio: Ratio of hold steps to total steps
679
+ max_steps: Total number of steps while training or `None` for
680
+ infinite training
681
+ decay_rate: Float value describing the polynomial decay
682
+ after the hold period. Default value
683
+ of 0.5 corresponds to Noam decay.
684
+ min_lr: Minimum learning rate.
685
+ """
686
+ self.decay_rate = decay_rate
687
+ super().__init__(optimizer=optimizer,
688
+ max_steps=max_steps,
689
+ last_epoch=last_epoch,
690
+ min_lr=min_lr,
691
+ **kwargs)
692
+
693
+ def _get_lr(self, step):
694
+ if self.warmup_steps is None or self.warmup_steps == 0:
695
+ raise ValueError(
696
+ "Noam scheduler cannot be used without warmup steps")
697
+
698
+ if self.hold_steps > 0:
699
+ hold_steps = self.hold_steps - self.warmup_steps
700
+ else:
701
+ hold_steps = 0
702
+
703
+ new_lrs = [
704
+ _noam_hold_annealing(
705
+ initial_lr,
706
+ step=step,
707
+ warmup_steps=self.warmup_steps,
708
+ hold_steps=hold_steps,
709
+ decay_rate=self.decay_rate,
710
+ min_lr=self.min_lr,
711
+ ) for initial_lr in self.base_lrs
712
+ ]
713
+ return new_lrs
714
+
715
+ def set_step(self, step: int):
716
+ self.last_epoch = step
717
+
718
+
719
+ class ConstantLR(_LRScheduler):
720
+ """The ConstantLR scheduler
721
+
722
+ This scheduler keeps a constant lr
723
+
724
+ """
725
+
726
+ def __init__(
727
+ self,
728
+ optimizer: torch.optim.Optimizer,
729
+ ):
730
+ # __init__() must be invoked before setting field
731
+ # because step() is also invoked in __init__()
732
+ super().__init__(optimizer)
733
+
734
+ def get_lr(self):
735
+ return self.base_lrs
736
+
737
+ def set_step(self, step: int):
738
+ self.last_epoch = step
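Note: a minimal usage sketch for the schedulers above (values are illustrative, not taken from the training recipes; it assumes the `torch` import at the top of `cosyvoice/utils/scheduler.py` and that the package is on `PYTHONPATH`):

import torch
from cosyvoice.utils.scheduler import NoamHoldAnnealing

model = torch.nn.Linear(16, 16)                       # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = NoamHoldAnnealing(optimizer,
                              max_steps=10000,        # illustrative total number of updates
                              warmup_steps=1000,      # linear warmup to the peak lr
                              hold_steps=3000,        # hold the peak lr before decaying
                              decay_rate=0.5,         # 0.5 gives Noam-style decay
                              min_lr=1e-5)
for step in range(10000):
    optimizer.step()                                  # normally preceded by forward/backward
    scheduler.step()                                  # advances last_epoch, recomputes the lr
lrs = scheduler.get_last_lr()                         # one lr per param group

When resuming training, `set_step(step)` can be used to fast-forward the schedule to the restored global step.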
examples/libritts/cosyvoice/local/prepare_data.py ADDED
@@ -0,0 +1,53 @@
+ import argparse
+ import logging
+ import glob
+ import os
+ from tqdm import tqdm
+
+
+ logger = logging.getLogger()
+
+
+ def main():
+     wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
+
+     utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
+     for wav in tqdm(wavs):
+         txt = wav.replace('.wav', '.normalized.txt')
+         if not os.path.exists(txt):
+             logger.warning('{} does not exist'.format(txt))
+             continue
+         with open(txt) as f:
+             content = ''.join(l.replace('\n', '') for l in f.readlines())
+         utt = os.path.basename(wav).replace('.wav', '')
+         spk = utt.split('_')[0]
+         utt2wav[utt] = wav
+         utt2text[utt] = content
+         utt2spk[utt] = spk
+         if spk not in spk2utt:
+             spk2utt[spk] = []
+         spk2utt[spk].append(utt)
+
+     with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
+         for k, v in utt2wav.items():
+             f.write('{} {}\n'.format(k, v))
+     with open('{}/text'.format(args.des_dir), 'w') as f:
+         for k, v in utt2text.items():
+             f.write('{} {}\n'.format(k, v))
+     with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
+         for k, v in utt2spk.items():
+             f.write('{} {}\n'.format(k, v))
+     with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
+         for k, v in spk2utt.items():
+             f.write('{} {}\n'.format(k, ' '.join(v)))
+     return
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--src_dir',
+                         type=str)
+     parser.add_argument('--des_dir',
+                         type=str)
+     args = parser.parse_args()
+     main()
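Note: prepare_data.py emits Kaldi-style mapping files (wav.scp, text, utt2spk, spk2utt), one space-separated `key value` pair per line. A small sketch of reading them back, assuming the script was run with `--des_dir data/train` (placeholder path; the helper name is ours):

def read_table(path):
    table = {}
    with open(path) as f:
        for line in f:
            key, value = line.strip().split(' ', 1)   # first token is the utt/spk id
            table[key] = value
    return table

utt2wav = read_table('data/train/wav.scp')    # utt id -> wav path
utt2text = read_table('data/train/text')      # utt id -> normalized transcript
utt2spk = read_table('data/train/utt2spk')    # utt id -> speaker id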
examples/libritts/cosyvoice/path.sh ADDED
@@ -0,0 +1,3 @@
+ # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+ export PYTHONIOENCODING=UTF-8
+ export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
requirements.txt ADDED
@@ -0,0 +1,38 @@
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
+ conformer==0.3.2
+ diffusers==0.29.0
+ gdown==5.1.0
+ gradio==5.4.0
+ grpcio==1.57.0
+ grpcio-tools==1.57.0
+ hydra-core==1.3.2
+ HyperPyYAML==1.2.2
+ inflect==7.3.1
+ librosa==0.10.2
+ lightning==2.2.4
+ matplotlib==3.7.5
+ modelscope==1.15.0
+ networkx==3.1
+ omegaconf==2.3.0
+ onnx==1.16.0
+ onnxruntime-gpu==1.18.0; sys_platform == 'linux'
+ onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
+ openai-whisper==20231117
+ protobuf==4.25
+ pydantic==2.7.0
+ pyworld==0.3.4
+ rich==13.7.1
+ soundfile==0.12.1
+ tensorboard==2.14.0
+ tensorrt-cu12==10.0.1; sys_platform == 'linux'
+ tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
+ tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
+ torch==2.3.1
+ torchaudio==2.3.1
+ transformers==4.40.1
+ uvicorn==0.30.0
+ wget==3.2
+ fastapi==0.115.6
+ fastapi-cli==0.0.4
+ WeTextProcessing==1.0.3
runtime/python/Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ WORKDIR /opt/CosyVoice
+
+ RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
+ RUN apt-get update -y
+ RUN apt-get -y install git unzip git-lfs
+ RUN git lfs install
+ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
+ # here we use python==3.10 because we cannot find an image that has both python3.8 and torch2.0.1-cu118 installed
+ RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+ RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
runtime/python/grpc/cosyvoice.proto ADDED
@@ -0,0 +1,43 @@
+ syntax = "proto3";
+
+ package cosyvoice;
+ option go_package = "protos/";
+
+ service CosyVoice{
+   rpc Inference(Request) returns (stream Response) {}
+ }
+
+ message Request{
+   oneof RequestPayload {
+     sftRequest sft_request = 1;
+     zeroshotRequest zero_shot_request = 2;
+     crosslingualRequest cross_lingual_request = 3;
+     instructRequest instruct_request = 4;
+   }
+ }
+
+ message sftRequest{
+   string spk_id = 1;
+   string tts_text = 2;
+ }
+
+ message zeroshotRequest{
+   string tts_text = 1;
+   string prompt_text = 2;
+   bytes prompt_audio = 3;
+ }
+
+ message crosslingualRequest{
+   string tts_text = 1;
+   bytes prompt_audio = 2;
+ }
+
+ message instructRequest{
+   string tts_text = 1;
+   string spk_id = 2;
+   string instruct_text = 3;
+ }
+
+ message Response{
+   bytes tts_audio = 1;
+ }
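Note: a sketch of a Python client for the `Inference` RPC above. It assumes the `cosyvoice_pb2`/`cosyvoice_pb2_grpc` stubs generated by the `grpc_tools.protoc` command in the Dockerfile and a server listening on a placeholder address; the format of the streamed `tts_audio` bytes depends on the server implementation:

import grpc
import cosyvoice_pb2
import cosyvoice_pb2_grpc

with grpc.insecure_channel('localhost:50000') as channel:   # placeholder host:port
    stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
    request = cosyvoice_pb2.Request(
        sft_request=cosyvoice_pb2.sftRequest(spk_id='your_spk_id',   # placeholder speaker id
                                             tts_text='Text to synthesize.'))
    audio = b''
    for response in stub.Inference(request):    # server-streaming: one Response per chunk
        audio += response.tts_audio             # concatenate the raw audio bytes in order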
webui.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ import argparse
17
+ import gradio as gr
18
+ import numpy as np
19
+ import torch
20
+ import torchaudio
21
+ import random
22
+ import librosa
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
25
+ from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
26
+ from cosyvoice.utils.file_utils import load_wav, logging
27
+ from cosyvoice.utils.common import set_all_random_seed
28
+
29
+ inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
30
+ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
31
+ '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
32
+ '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
33
+ '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
34
+ stream_mode_list = [('否', False), ('是', True)]
35
+ max_val = 0.8
36
+
37
+
38
+ def generate_seed():
39
+ seed = random.randint(1, 100000000)
40
+ return {
41
+ "__type__": "update",
42
+ "value": seed
43
+ }
44
+
45
+
46
+ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
47
+ speech, _ = librosa.effects.trim(
48
+ speech, top_db=top_db,
49
+ frame_length=win_length,
50
+ hop_length=hop_length
51
+ )
52
+ if speech.abs().max() > max_val:
53
+ speech = speech / speech.abs().max() * max_val
54
+ speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1)
55
+ return speech
56
+
57
+
58
+ def change_instruction(mode_checkbox_group):
59
+ return instruct_dict[mode_checkbox_group]
60
+
61
+
62
+ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
63
+ seed, stream, speed):
64
+ if prompt_wav_upload is not None:
65
+ prompt_wav = prompt_wav_upload
66
+ elif prompt_wav_record is not None:
67
+ prompt_wav = prompt_wav_record
68
+ else:
69
+ prompt_wav = None
70
+ # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode
71
+ if mode_checkbox_group in ['自然语言控制']:
72
+ if cosyvoice.instruct is False:
73
+ gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
74
+ yield (cosyvoice.sample_rate, default_data)
75
+ if instruct_text == '':
76
+ gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
77
+ yield (cosyvoice.sample_rate, default_data)
78
+ if prompt_wav is not None or prompt_text != '':
79
+ gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
80
+ # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
81
+ if mode_checkbox_group in ['跨语种复刻']:
82
+ if cosyvoice.instruct is True:
83
+ gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
84
+ yield (cosyvoice.sample_rate, default_data)
85
+ if instruct_text != '':
86
+ gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
87
+ if prompt_wav is None:
88
+ gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
89
+ yield (cosyvoice.sample_rate, default_data)
90
+ gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
91
+ # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
92
+ if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
93
+ if prompt_wav is None:
94
+ gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
95
+ yield (cosyvoice.sample_rate, default_data)
96
+ if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
97
+ gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
98
+ yield (cosyvoice.sample_rate, default_data)
99
+ # sft mode only use sft_dropdown
100
+ if mode_checkbox_group in ['预训练音色']:
101
+ if instruct_text != '' or prompt_wav is not None or prompt_text != '':
102
+ gr.Info('您正在使用预训练音色模式,prompt文本/prompt音频/instruct文本会被忽略!')
103
+ if sft_dropdown == '':
104
+ gr.Warning('没有可用的预训练音色!')
105
+ yield (cosyvoice.sample_rate, default_data)
106
+ # zero_shot mode only use prompt_wav prompt text
107
+ if mode_checkbox_group in ['3s极速复刻']:
108
+ if prompt_text == '':
109
+ gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
110
+ yield (cosyvoice.sample_rate, default_data)
111
+ if instruct_text != '':
112
+ gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
113
+
114
+ if mode_checkbox_group == '预训练音色':
115
+ logging.info('get sft inference request')
116
+ set_all_random_seed(seed)
117
+ for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
118
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
119
+ elif mode_checkbox_group == '3s极速复刻':
120
+ logging.info('get zero_shot inference request')
121
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
122
+ set_all_random_seed(seed)
123
+ for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
124
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
125
+ elif mode_checkbox_group == '跨语种复刻':
126
+ logging.info('get cross_lingual inference request')
127
+ prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
128
+ set_all_random_seed(seed)
129
+ for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
130
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
131
+ else:
132
+ logging.info('get instruct inference request')
133
+ set_all_random_seed(seed)
134
+ for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
135
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
136
+
137
+
138
+ def main():
139
+ with gr.Blocks() as demo:
140
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
141
+ 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
142
+ [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
143
+ [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
144
+ gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
145
+
146
+ tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
147
+ with gr.Row():
148
+ mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
149
+ instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
150
+ sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
151
+ stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
152
+ speed = gr.Number(value=1, label="速度调节(仅支持非流式推理)", minimum=0.5, maximum=2.0, step=0.1)
153
+ with gr.Column(scale=0.25):
154
+ seed_button = gr.Button(value="\U0001F3B2")
155
+ seed = gr.Number(value=0, label="随机推理种子")
156
+
157
+ with gr.Row():
158
+ prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
159
+ prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
160
+ prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...", value='')
161
+ instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.", value='')
162
+
163
+ generate_button = gr.Button("生成音频")
164
+
165
+ audio_output = gr.Audio(label="合成音频", autoplay=True, streaming=True)
166
+
167
+ seed_button.click(generate_seed, inputs=[], outputs=seed)
168
+ generate_button.click(generate_audio,
169
+ inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
170
+ seed, stream, speed],
171
+ outputs=[audio_output])
172
+ mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
173
+ demo.queue(max_size=4, default_concurrency_limit=2)
174
+ demo.launch(share=True, server_name='0.0.0.0', server_port=args.port)
175
+
176
+
177
+ if __name__ == '__main__':
178
+ parser = argparse.ArgumentParser()
179
+ parser.add_argument('--port',
180
+ type=int,
181
+ default=8000)
182
+ parser.add_argument('--model_dir',
183
+ type=str,
184
+ default='pretrained_models/CosyVoice2-0.5B',
185
+ help='local path or modelscope repo id')
186
+ args = parser.parse_args()
187
+ try:
188
+ cosyvoice = CosyVoice(args.model_dir)
189
+ except Exception:
190
+ try:
191
+ cosyvoice = CosyVoice2(args.model_dir)
192
+ except Exception:
193
+ raise TypeError('no valid model_type!')
194
+
195
+ sft_spk = cosyvoice.list_available_spks()
196
+ if len(sft_spk) == 0:
197
+ sft_spk = ['']
198
+ prompt_sr = 16000
199
+ default_data = np.zeros(cosyvoice.sample_rate)
200
+ main()
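Note: the Gradio handlers above are thin wrappers around the CosyVoice inference generators. A minimal sketch of the same zero-shot call pattern outside the web UI, assuming the default `pretrained_models/CosyVoice2-0.5B` directory has been downloaded and a 16 kHz prompt recording exists at a placeholder path:

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B')
prompt_speech_16k = load_wav('./prompt.wav', 16000)   # same prompt_sr as the web UI
for idx, out in enumerate(cosyvoice.inference_zero_shot('Text to synthesize.',
                                                        'Transcript of the prompt audio.',
                                                        prompt_speech_16k,
                                                        stream=False, speed=1.0)):
    # each yielded dict carries a 'tts_speech' tensor, saved at the model's sample rate
    torchaudio.save('zero_shot_{}.wav'.format(idx), out['tts_speech'], cosyvoice.sample_rate)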