Spaces:
Running
Running
start deploy
Browse files- .gitignore +6 -0
- LICENSE +201 -0
- app.py +278 -0
- cli/SparkTTS.py +236 -0
- cli/inference.py +116 -0
- datasets/.gitkeep +0 -0
- pretrained_models/.gitkeep +0 -0
- requirements.txt +14 -0
- runtime/triton_trtllm/Dockerfile.server +5 -0
- runtime/triton_trtllm/README.md +94 -0
- runtime/triton_trtllm/client_grpc.py +831 -0
- runtime/triton_trtllm/client_http.py +165 -0
- runtime/triton_trtllm/docker-compose.yml +20 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +137 -0
- runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt +58 -0
- runtime/triton_trtllm/model_repo/spark_tts/1/model.py +404 -0
- runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt +86 -0
- runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep +0 -0
- runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt +857 -0
- runtime/triton_trtllm/model_repo/vocoder/1/model.py +106 -0
- runtime/triton_trtllm/model_repo/vocoder/config.pbtxt +53 -0
- runtime/triton_trtllm/run.sh +109 -0
- runtime/triton_trtllm/scripts/convert_checkpoint.py +335 -0
- runtime/triton_trtllm/scripts/fill_template.py +70 -0
- sparktts/models/audio_tokenizer.py +163 -0
- sparktts/models/bicodec.py +247 -0
- sparktts/modules/blocks/layers.py +73 -0
- sparktts/modules/blocks/samper.py +115 -0
- sparktts/modules/blocks/vocos.py +373 -0
- sparktts/modules/encoder_decoder/feat_decoder.py +115 -0
- sparktts/modules/encoder_decoder/feat_encoder.py +105 -0
- sparktts/modules/encoder_decoder/wave_generator.py +88 -0
- sparktts/modules/fsq/finite_scalar_quantization.py +251 -0
- sparktts/modules/fsq/residual_fsq.py +355 -0
- sparktts/modules/speaker/ecapa_tdnn.py +267 -0
- sparktts/modules/speaker/perceiver_encoder.py +360 -0
- sparktts/modules/speaker/pooling_layers.py +298 -0
- sparktts/modules/speaker/speaker_encoder.py +136 -0
- sparktts/modules/vq/factorized_vector_quantize.py +187 -0
- sparktts/utils/__init__.py +0 -0
- sparktts/utils/audio.py +271 -0
- sparktts/utils/file.py +221 -0
- sparktts/utils/parse_options.sh +97 -0
- sparktts/utils/token_parser.py +187 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
example/
|
2 |
+
src/
|
3 |
+
pretrained_models/*
|
4 |
+
!pretrained_models/.gitkeep
|
5 |
+
datasets/*
|
6 |
+
!datasets/.gitkeep
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import torch
|
18 |
+
import soundfile as sf
|
19 |
+
import logging
|
20 |
+
import argparse
|
21 |
+
import gradio as gr
|
22 |
+
import platform
|
23 |
+
|
24 |
+
from datetime import datetime
|
25 |
+
from cli.SparkTTS import SparkTTS
|
26 |
+
from sparktts.utils.token_parser import LEVELS_MAP_UI
|
27 |
+
from huggingface_hub import snapshot_download
|
28 |
+
|
29 |
+
|
30 |
+
def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
|
31 |
+
"""Load the model once at the beginning."""
|
32 |
+
logging.info(f"Loading model from: {model_dir}")
|
33 |
+
|
34 |
+
# Determine appropriate device based on platform and availability
|
35 |
+
if platform.system() == "Darwin":
|
36 |
+
# macOS with MPS support (Apple Silicon)
|
37 |
+
device = torch.device(f"mps:{device}")
|
38 |
+
logging.info(f"Using MPS device: {device}")
|
39 |
+
elif torch.cuda.is_available():
|
40 |
+
# System with CUDA support
|
41 |
+
device = torch.device(f"cuda:{device}")
|
42 |
+
logging.info(f"Using CUDA device: {device}")
|
43 |
+
else:
|
44 |
+
# Fall back to CPU
|
45 |
+
device = torch.device("cpu")
|
46 |
+
logging.info("GPU acceleration not available, using CPU")
|
47 |
+
|
48 |
+
model = SparkTTS(model_dir, device)
|
49 |
+
return model
|
50 |
+
|
51 |
+
|
52 |
+
def run_tts(
|
53 |
+
text,
|
54 |
+
model,
|
55 |
+
prompt_text=None,
|
56 |
+
prompt_speech=None,
|
57 |
+
gender=None,
|
58 |
+
pitch=None,
|
59 |
+
speed=None,
|
60 |
+
save_dir="example/results",
|
61 |
+
):
|
62 |
+
"""Perform TTS inference and save the generated audio."""
|
63 |
+
logging.info(f"Saving audio to: {save_dir}")
|
64 |
+
|
65 |
+
if prompt_text is not None:
|
66 |
+
prompt_text = None if len(prompt_text) <= 1 else prompt_text
|
67 |
+
|
68 |
+
# Ensure the save directory exists
|
69 |
+
os.makedirs(save_dir, exist_ok=True)
|
70 |
+
|
71 |
+
# Generate unique filename using timestamp
|
72 |
+
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
73 |
+
save_path = os.path.join(save_dir, f"{timestamp}.wav")
|
74 |
+
|
75 |
+
logging.info("Starting inference...")
|
76 |
+
|
77 |
+
# Perform inference and save the output audio
|
78 |
+
with torch.no_grad():
|
79 |
+
wav = model.inference(
|
80 |
+
text,
|
81 |
+
prompt_speech,
|
82 |
+
prompt_text,
|
83 |
+
gender,
|
84 |
+
pitch,
|
85 |
+
speed,
|
86 |
+
)
|
87 |
+
|
88 |
+
sf.write(save_path, wav, samplerate=16000)
|
89 |
+
|
90 |
+
logging.info(f"Audio saved at: {save_path}")
|
91 |
+
|
92 |
+
return save_path
|
93 |
+
|
94 |
+
|
95 |
+
def build_ui(model_dir, device=0):
|
96 |
+
|
97 |
+
# Initialize model
|
98 |
+
model = initialize_model(model_dir, device=device)
|
99 |
+
|
100 |
+
# Define callback function for voice cloning
|
101 |
+
def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
|
102 |
+
"""
|
103 |
+
Gradio callback to clone voice using text and optional prompt speech.
|
104 |
+
- text: The input text to be synthesised.
|
105 |
+
- prompt_text: Additional textual info for the prompt (optional).
|
106 |
+
- prompt_wav_upload/prompt_wav_record: Audio files used as reference.
|
107 |
+
"""
|
108 |
+
prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
|
109 |
+
prompt_text_clean = None if len(prompt_text) < 2 else prompt_text
|
110 |
+
|
111 |
+
audio_output_path = run_tts(
|
112 |
+
text,
|
113 |
+
model,
|
114 |
+
prompt_text=prompt_text_clean,
|
115 |
+
prompt_speech=prompt_speech
|
116 |
+
)
|
117 |
+
return audio_output_path
|
118 |
+
|
119 |
+
# Define callback function for creating new voices
|
120 |
+
def voice_creation(text, gender, pitch, speed):
|
121 |
+
"""
|
122 |
+
Gradio callback to create a synthetic voice with adjustable parameters.
|
123 |
+
- text: The input text for synthesis.
|
124 |
+
- gender: 'male' or 'female'.
|
125 |
+
- pitch/speed: Ranges mapped by LEVELS_MAP_UI.
|
126 |
+
"""
|
127 |
+
pitch_val = LEVELS_MAP_UI[int(pitch)]
|
128 |
+
speed_val = LEVELS_MAP_UI[int(speed)]
|
129 |
+
audio_output_path = run_tts(
|
130 |
+
text,
|
131 |
+
model,
|
132 |
+
gender=gender,
|
133 |
+
pitch=pitch_val,
|
134 |
+
speed=speed_val
|
135 |
+
)
|
136 |
+
return audio_output_path
|
137 |
+
|
138 |
+
with gr.Blocks() as demo:
|
139 |
+
# Use HTML for centered title
|
140 |
+
gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
|
141 |
+
with gr.Tabs():
|
142 |
+
# Voice Clone Tab
|
143 |
+
with gr.TabItem("Voice Clone"):
|
144 |
+
gr.Markdown(
|
145 |
+
"### Upload reference audio or recording (上传参考音频或者录音)"
|
146 |
+
)
|
147 |
+
|
148 |
+
with gr.Row():
|
149 |
+
prompt_wav_upload = gr.Audio(
|
150 |
+
sources="upload",
|
151 |
+
type="filepath",
|
152 |
+
label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
|
153 |
+
)
|
154 |
+
prompt_wav_record = gr.Audio(
|
155 |
+
sources="microphone",
|
156 |
+
type="filepath",
|
157 |
+
label="Record the prompt audio file.",
|
158 |
+
)
|
159 |
+
|
160 |
+
with gr.Row():
|
161 |
+
text_input = gr.Textbox(
|
162 |
+
label="Text", lines=3, placeholder="Enter text here"
|
163 |
+
)
|
164 |
+
prompt_text_input = gr.Textbox(
|
165 |
+
label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
|
166 |
+
lines=3,
|
167 |
+
placeholder="Enter text of the prompt speech.",
|
168 |
+
)
|
169 |
+
|
170 |
+
audio_output = gr.Audio(
|
171 |
+
label="Generated Audio", autoplay=True, streaming=True
|
172 |
+
)
|
173 |
+
|
174 |
+
generate_buttom_clone = gr.Button("Generate")
|
175 |
+
|
176 |
+
generate_buttom_clone.click(
|
177 |
+
voice_clone,
|
178 |
+
inputs=[
|
179 |
+
text_input,
|
180 |
+
prompt_text_input,
|
181 |
+
prompt_wav_upload,
|
182 |
+
prompt_wav_record,
|
183 |
+
],
|
184 |
+
outputs=[audio_output],
|
185 |
+
)
|
186 |
+
|
187 |
+
# Voice Creation Tab
|
188 |
+
with gr.TabItem("Voice Creation"):
|
189 |
+
gr.Markdown(
|
190 |
+
"### Create your own voice based on the following parameters"
|
191 |
+
)
|
192 |
+
|
193 |
+
with gr.Row():
|
194 |
+
with gr.Column():
|
195 |
+
gender = gr.Radio(
|
196 |
+
choices=["male", "female"], value="male", label="Gender"
|
197 |
+
)
|
198 |
+
pitch = gr.Slider(
|
199 |
+
minimum=1, maximum=5, step=1, value=3, label="Pitch"
|
200 |
+
)
|
201 |
+
speed = gr.Slider(
|
202 |
+
minimum=1, maximum=5, step=1, value=3, label="Speed"
|
203 |
+
)
|
204 |
+
with gr.Column():
|
205 |
+
text_input_creation = gr.Textbox(
|
206 |
+
label="Input Text",
|
207 |
+
lines=3,
|
208 |
+
placeholder="Enter text here",
|
209 |
+
value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
|
210 |
+
)
|
211 |
+
create_button = gr.Button("Create Voice")
|
212 |
+
|
213 |
+
audio_output = gr.Audio(
|
214 |
+
label="Generated Audio", autoplay=True, streaming=True
|
215 |
+
)
|
216 |
+
create_button.click(
|
217 |
+
voice_creation,
|
218 |
+
inputs=[text_input_creation, gender, pitch, speed],
|
219 |
+
outputs=[audio_output],
|
220 |
+
)
|
221 |
+
|
222 |
+
return demo
|
223 |
+
|
224 |
+
|
225 |
+
def parse_arguments():
|
226 |
+
"""
|
227 |
+
Parse command-line arguments such as model directory and device ID.
|
228 |
+
"""
|
229 |
+
parser = argparse.ArgumentParser(description="Spark TTS Gradio server.")
|
230 |
+
parser.add_argument(
|
231 |
+
"--model_dir",
|
232 |
+
type=str,
|
233 |
+
default="pretrained_models/Spark-TTS-0.5B",
|
234 |
+
help="Path to the model directory."
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--device",
|
238 |
+
type=int,
|
239 |
+
default=0,
|
240 |
+
help="ID of the GPU device to use (e.g., 0 for cuda:0)."
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--server_name",
|
244 |
+
type=str,
|
245 |
+
default="0.0.0.0",
|
246 |
+
help="Server host/IP for Gradio app."
|
247 |
+
)
|
248 |
+
parser.add_argument(
|
249 |
+
"--server_port",
|
250 |
+
type=int,
|
251 |
+
default=7860,
|
252 |
+
help="Server port for Gradio app."
|
253 |
+
)
|
254 |
+
return parser.parse_args()
|
255 |
+
|
256 |
+
if __name__ == "__main__":
|
257 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
|
258 |
+
torch.backends.mps.is_available = lambda: False
|
259 |
+
|
260 |
+
|
261 |
+
## if model not downloaded, download it
|
262 |
+
if not os.path.exists("pretrained_models/Spark-TTS-0.5B"):
|
263 |
+
snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
|
264 |
+
|
265 |
+
# Parse command-line arguments
|
266 |
+
args = parse_arguments()
|
267 |
+
|
268 |
+
# Build the Gradio demo by specifying the model directory and GPU device
|
269 |
+
demo = build_ui(
|
270 |
+
model_dir=args.model_dir,
|
271 |
+
device=args.device
|
272 |
+
)
|
273 |
+
|
274 |
+
# Launch Gradio with the specified server name and port
|
275 |
+
demo.launch(
|
276 |
+
server_name=args.server_name,
|
277 |
+
server_port=args.server_port
|
278 |
+
)
|
cli/SparkTTS.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import re
|
17 |
+
import torch
|
18 |
+
from typing import Tuple
|
19 |
+
from pathlib import Path
|
20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
21 |
+
|
22 |
+
from sparktts.utils.file import load_config
|
23 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
24 |
+
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
|
25 |
+
|
26 |
+
|
27 |
+
class SparkTTS:
|
28 |
+
"""
|
29 |
+
Spark-TTS for text-to-speech generation.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
|
33 |
+
"""
|
34 |
+
Initializes the SparkTTS model with the provided configurations and device.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
model_dir (Path): Directory containing the model and config files.
|
38 |
+
device (torch.device): The device (CPU/GPU) to run the model on.
|
39 |
+
"""
|
40 |
+
self.device = "cpu"
|
41 |
+
self.model_dir = model_dir
|
42 |
+
self.configs = load_config(f"{model_dir}/config.yaml")
|
43 |
+
self.sample_rate = self.configs["sample_rate"]
|
44 |
+
self._initialize_inference()
|
45 |
+
|
46 |
+
def _initialize_inference(self):
|
47 |
+
"""Initializes the tokenizer, model, and audio tokenizer for inference."""
|
48 |
+
self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
|
49 |
+
self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
|
50 |
+
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
51 |
+
self.model.to(self.device)
|
52 |
+
|
53 |
+
def process_prompt(
|
54 |
+
self,
|
55 |
+
text: str,
|
56 |
+
prompt_speech_path: Path,
|
57 |
+
prompt_text: str = None,
|
58 |
+
) -> Tuple[str, torch.Tensor]:
|
59 |
+
"""
|
60 |
+
Process input for voice cloning.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
text (str): The text input to be converted to speech.
|
64 |
+
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
65 |
+
prompt_text (str, optional): Transcript of the prompt audio.
|
66 |
+
|
67 |
+
Return:
|
68 |
+
Tuple[str, torch.Tensor]: Input prompt; global tokens
|
69 |
+
"""
|
70 |
+
|
71 |
+
global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
|
72 |
+
prompt_speech_path
|
73 |
+
)
|
74 |
+
global_tokens = "".join(
|
75 |
+
[f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
|
76 |
+
)
|
77 |
+
|
78 |
+
# Prepare the input tokens for the model
|
79 |
+
if prompt_text is not None:
|
80 |
+
semantic_tokens = "".join(
|
81 |
+
[f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
|
82 |
+
)
|
83 |
+
inputs = [
|
84 |
+
TASK_TOKEN_MAP["tts"],
|
85 |
+
"<|start_content|>",
|
86 |
+
prompt_text,
|
87 |
+
text,
|
88 |
+
"<|end_content|>",
|
89 |
+
"<|start_global_token|>",
|
90 |
+
global_tokens,
|
91 |
+
"<|end_global_token|>",
|
92 |
+
"<|start_semantic_token|>",
|
93 |
+
semantic_tokens,
|
94 |
+
]
|
95 |
+
else:
|
96 |
+
inputs = [
|
97 |
+
TASK_TOKEN_MAP["tts"],
|
98 |
+
"<|start_content|>",
|
99 |
+
text,
|
100 |
+
"<|end_content|>",
|
101 |
+
"<|start_global_token|>",
|
102 |
+
global_tokens,
|
103 |
+
"<|end_global_token|>",
|
104 |
+
]
|
105 |
+
|
106 |
+
inputs = "".join(inputs)
|
107 |
+
|
108 |
+
return inputs, global_token_ids
|
109 |
+
|
110 |
+
def process_prompt_control(
|
111 |
+
self,
|
112 |
+
gender: str,
|
113 |
+
pitch: str,
|
114 |
+
speed: str,
|
115 |
+
text: str,
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
Process input for voice creation.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
gender (str): female | male.
|
122 |
+
pitch (str): very_low | low | moderate | high | very_high
|
123 |
+
speed (str): very_low | low | moderate | high | very_high
|
124 |
+
text (str): The text input to be converted to speech.
|
125 |
+
|
126 |
+
Return:
|
127 |
+
str: Input prompt
|
128 |
+
"""
|
129 |
+
assert gender in GENDER_MAP.keys()
|
130 |
+
assert pitch in LEVELS_MAP.keys()
|
131 |
+
assert speed in LEVELS_MAP.keys()
|
132 |
+
|
133 |
+
gender_id = GENDER_MAP[gender]
|
134 |
+
pitch_level_id = LEVELS_MAP[pitch]
|
135 |
+
speed_level_id = LEVELS_MAP[speed]
|
136 |
+
|
137 |
+
pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
|
138 |
+
speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
|
139 |
+
gender_tokens = f"<|gender_{gender_id}|>"
|
140 |
+
|
141 |
+
attribte_tokens = "".join(
|
142 |
+
[gender_tokens, pitch_label_tokens, speed_label_tokens]
|
143 |
+
)
|
144 |
+
|
145 |
+
control_tts_inputs = [
|
146 |
+
TASK_TOKEN_MAP["controllable_tts"],
|
147 |
+
"<|start_content|>",
|
148 |
+
text,
|
149 |
+
"<|end_content|>",
|
150 |
+
"<|start_style_label|>",
|
151 |
+
attribte_tokens,
|
152 |
+
"<|end_style_label|>",
|
153 |
+
]
|
154 |
+
|
155 |
+
return "".join(control_tts_inputs)
|
156 |
+
|
157 |
+
@torch.no_grad()
|
158 |
+
def inference(
|
159 |
+
self,
|
160 |
+
text: str,
|
161 |
+
prompt_speech_path: Path = None,
|
162 |
+
prompt_text: str = None,
|
163 |
+
gender: str = None,
|
164 |
+
pitch: str = None,
|
165 |
+
speed: str = None,
|
166 |
+
temperature: float = 0.8,
|
167 |
+
top_k: float = 50,
|
168 |
+
top_p: float = 0.95,
|
169 |
+
) -> torch.Tensor:
|
170 |
+
"""
|
171 |
+
Performs inference to generate speech from text, incorporating prompt audio and/or text.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
text (str): The text input to be converted to speech.
|
175 |
+
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
176 |
+
prompt_text (str, optional): Transcript of the prompt audio.
|
177 |
+
gender (str): female | male.
|
178 |
+
pitch (str): very_low | low | moderate | high | very_high
|
179 |
+
speed (str): very_low | low | moderate | high | very_high
|
180 |
+
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
|
181 |
+
top_k (float, optional): Top-k sampling parameter. Default is 50.
|
182 |
+
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
|
183 |
+
|
184 |
+
Returns:
|
185 |
+
torch.Tensor: Generated waveform as a tensor.
|
186 |
+
"""
|
187 |
+
if gender is not None:
|
188 |
+
prompt = self.process_prompt_control(gender, pitch, speed, text)
|
189 |
+
|
190 |
+
else:
|
191 |
+
prompt, global_token_ids = self.process_prompt(
|
192 |
+
text, prompt_speech_path, prompt_text
|
193 |
+
)
|
194 |
+
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
195 |
+
|
196 |
+
# Generate speech using the model
|
197 |
+
generated_ids = self.model.generate(
|
198 |
+
**model_inputs,
|
199 |
+
max_new_tokens=3000,
|
200 |
+
do_sample=True,
|
201 |
+
top_k=top_k,
|
202 |
+
top_p=top_p,
|
203 |
+
temperature=temperature,
|
204 |
+
)
|
205 |
+
|
206 |
+
# Trim the output tokens to remove the input tokens
|
207 |
+
generated_ids = [
|
208 |
+
output_ids[len(input_ids) :]
|
209 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
210 |
+
]
|
211 |
+
|
212 |
+
# Decode the generated tokens into text
|
213 |
+
predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
214 |
+
|
215 |
+
# Extract semantic token IDs from the generated text
|
216 |
+
pred_semantic_ids = (
|
217 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
|
218 |
+
.long()
|
219 |
+
.unsqueeze(0)
|
220 |
+
)
|
221 |
+
|
222 |
+
if gender is not None:
|
223 |
+
global_token_ids = (
|
224 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
|
225 |
+
.long()
|
226 |
+
.unsqueeze(0)
|
227 |
+
.unsqueeze(0)
|
228 |
+
)
|
229 |
+
|
230 |
+
# Convert semantic tokens back to waveform
|
231 |
+
wav = self.audio_tokenizer.detokenize(
|
232 |
+
global_token_ids.to(self.device).squeeze(0),
|
233 |
+
pred_semantic_ids.to(self.device),
|
234 |
+
)
|
235 |
+
|
236 |
+
return wav
|
cli/inference.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import argparse
|
19 |
+
import torch
|
20 |
+
import soundfile as sf
|
21 |
+
import logging
|
22 |
+
from datetime import datetime
|
23 |
+
import platform
|
24 |
+
|
25 |
+
from cli.SparkTTS import SparkTTS
|
26 |
+
|
27 |
+
|
28 |
+
def parse_args():
|
29 |
+
"""Parse command-line arguments."""
|
30 |
+
parser = argparse.ArgumentParser(description="Run TTS inference.")
|
31 |
+
|
32 |
+
parser.add_argument(
|
33 |
+
"--model_dir",
|
34 |
+
type=str,
|
35 |
+
default="pretrained_models/Spark-TTS-0.5B",
|
36 |
+
help="Path to the model directory",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--save_dir",
|
40 |
+
type=str,
|
41 |
+
default="example/results",
|
42 |
+
help="Directory to save generated audio files",
|
43 |
+
)
|
44 |
+
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
|
45 |
+
parser.add_argument(
|
46 |
+
"--text", type=str, required=True, help="Text for TTS generation"
|
47 |
+
)
|
48 |
+
parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
|
49 |
+
parser.add_argument(
|
50 |
+
"--prompt_speech_path",
|
51 |
+
type=str,
|
52 |
+
help="Path to the prompt audio file",
|
53 |
+
)
|
54 |
+
parser.add_argument("--gender", choices=["male", "female"])
|
55 |
+
parser.add_argument(
|
56 |
+
"--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
|
60 |
+
)
|
61 |
+
return parser.parse_args()
|
62 |
+
|
63 |
+
|
64 |
+
def run_tts(args):
|
65 |
+
"""Perform TTS inference and save the generated audio."""
|
66 |
+
logging.info(f"Using model from: {args.model_dir}")
|
67 |
+
logging.info(f"Saving audio to: {args.save_dir}")
|
68 |
+
|
69 |
+
# Ensure the save directory exists
|
70 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
71 |
+
|
72 |
+
# Convert device argument to torch.device
|
73 |
+
if platform.system() == "Darwin" and torch.backends.mps.is_available():
|
74 |
+
# macOS with MPS support (Apple Silicon)
|
75 |
+
device = torch.device(f"mps:{args.device}")
|
76 |
+
logging.info(f"Using MPS device: {device}")
|
77 |
+
elif torch.cuda.is_available():
|
78 |
+
# System with CUDA support
|
79 |
+
device = torch.device(f"cuda:{args.device}")
|
80 |
+
logging.info(f"Using CUDA device: {device}")
|
81 |
+
else:
|
82 |
+
# Fall back to CPU
|
83 |
+
device = torch.device("cpu")
|
84 |
+
logging.info("GPU acceleration not available, using CPU")
|
85 |
+
|
86 |
+
# Initialize the model
|
87 |
+
model = SparkTTS(args.model_dir, device)
|
88 |
+
|
89 |
+
# Generate unique filename using timestamp
|
90 |
+
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
91 |
+
save_path = os.path.join(args.save_dir, f"{timestamp}.wav")
|
92 |
+
|
93 |
+
logging.info("Starting inference...")
|
94 |
+
|
95 |
+
# Perform inference and save the output audio
|
96 |
+
with torch.no_grad():
|
97 |
+
wav = model.inference(
|
98 |
+
args.text,
|
99 |
+
args.prompt_speech_path,
|
100 |
+
prompt_text=args.prompt_text,
|
101 |
+
gender=args.gender,
|
102 |
+
pitch=args.pitch,
|
103 |
+
speed=args.speed,
|
104 |
+
)
|
105 |
+
sf.write(save_path, wav, samplerate=16000)
|
106 |
+
|
107 |
+
logging.info(f"Audio saved at: {save_path}")
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
logging.basicConfig(
|
112 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
113 |
+
)
|
114 |
+
|
115 |
+
args = parse_args()
|
116 |
+
run_tts(args)
|
datasets/.gitkeep
ADDED
File without changes
|
pretrained_models/.gitkeep
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops==0.8.1
|
2 |
+
einx==0.3.0
|
3 |
+
numpy==2.2.3
|
4 |
+
omegaconf==2.3.0
|
5 |
+
packaging==24.2
|
6 |
+
safetensors==0.5.2
|
7 |
+
soundfile==0.12.1
|
8 |
+
soxr==0.5.0.post1
|
9 |
+
torch==2.5.1
|
10 |
+
torchaudio==2.5.1
|
11 |
+
tqdm==4.66.5
|
12 |
+
transformers==4.46.2
|
13 |
+
gradio==5.18.0
|
14 |
+
huggingface_hub
|
runtime/triton_trtllm/Dockerfile.server
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
|
2 |
+
RUN apt-get update && apt-get install -y cmake
|
3 |
+
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
|
4 |
+
RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
|
5 |
+
WORKDIR /workspace
|
runtime/triton_trtllm/README.md
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Nvidia Triton Inference Serving Best Practice for Spark TTS
|
2 |
+
|
3 |
+
### Quick Start
|
4 |
+
Directly launch the service using docker compose.
|
5 |
+
```sh
|
6 |
+
docker compose up
|
7 |
+
```
|
8 |
+
|
9 |
+
### Build Image
|
10 |
+
Build the docker image from scratch.
|
11 |
+
```sh
|
12 |
+
docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
|
13 |
+
```
|
14 |
+
|
15 |
+
### Create Docker Container
|
16 |
+
```sh
|
17 |
+
your_mount_dir=/mnt:/mnt
|
18 |
+
docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
|
19 |
+
```
|
20 |
+
|
21 |
+
### Understanding `run.sh`
|
22 |
+
|
23 |
+
The `run.sh` script automates various steps using stages. You can run specific stages using:
|
24 |
+
```sh
|
25 |
+
bash run.sh <start_stage> <stop_stage> [service_type]
|
26 |
+
```
|
27 |
+
- `<start_stage>`: The stage to begin execution from (0-5).
|
28 |
+
- `<stop_stage>`: The stage to end execution at (0-5).
|
29 |
+
- `[service_type]`: Optional, specifies the service type ('streaming' or 'offline', defaults may apply based on script logic). Required for stages 4 and 5.
|
30 |
+
|
31 |
+
Stages:
|
32 |
+
- **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace.
|
33 |
+
- **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
|
34 |
+
- **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline).
|
35 |
+
- **Stage 3**: Launch the Triton Inference Server.
|
36 |
+
- **Stage 4**: Run the gRPC benchmark client.
|
37 |
+
- **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline).
|
38 |
+
|
39 |
+
### Export Models to TensorRT-LLM and Launch Server
|
40 |
+
Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server.
|
41 |
+
```sh
|
42 |
+
# This runs stages 0, 1, 2, and 3
|
43 |
+
bash run.sh 0 3
|
44 |
+
```
|
45 |
+
*Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.*
|
46 |
+
|
47 |
+
|
48 |
+
### Single Utterance Client
|
49 |
+
Run a single inference request. Specify `streaming` or `offline` as the third argument.
|
50 |
+
|
51 |
+
**Streaming Mode (gRPC):**
|
52 |
+
```sh
|
53 |
+
bash run.sh 5 5 streaming
|
54 |
+
```
|
55 |
+
This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode.
|
56 |
+
|
57 |
+
**Offline Mode (HTTP):**
|
58 |
+
```sh
|
59 |
+
bash run.sh 5 5 offline
|
60 |
+
```
|
61 |
+
|
62 |
+
### Benchmark using Dataset
|
63 |
+
Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument.
|
64 |
+
```sh
|
65 |
+
# Run benchmark in streaming mode
|
66 |
+
bash run.sh 4 4 streaming
|
67 |
+
|
68 |
+
# Run benchmark in offline mode
|
69 |
+
bash run.sh 4 4 offline
|
70 |
+
|
71 |
+
# You can also customize parameters like num_task directly in client_grpc.py or via args if supported
|
72 |
+
# Example from run.sh (streaming):
|
73 |
+
# python3 client_grpc.py \
|
74 |
+
# --server-addr localhost \
|
75 |
+
# --model-name spark_tts \
|
76 |
+
# --num-tasks 2 \
|
77 |
+
# --mode streaming \
|
78 |
+
# --log-dir ./log_concurrent_tasks_2_streaming_new
|
79 |
+
|
80 |
+
# Example customizing dataset (requires modifying client_grpc.py or adding args):
|
81 |
+
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline]
|
82 |
+
```
|
83 |
+
|
84 |
+
### Benchmark Results
|
85 |
+
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs.
|
86 |
+
|
87 |
+
| Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF |
|
88 |
+
|-------|-----------|-----------------------|---------|----------------|-|
|
89 |
+
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362|
|
90 |
+
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737|
|
91 |
+
| Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704|
|
92 |
+
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 |
|
93 |
+
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 |
|
94 |
+
| Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 |
|
runtime/triton_trtllm/client_grpc.py
ADDED
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
3 |
+
# 2023 Nvidia (authors: Yuekai Zhang)
|
4 |
+
# 2023 Recurrent.ai (authors: Songtao Shi)
|
5 |
+
# See LICENSE for clarification regarding multiple authors
|
6 |
+
#
|
7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8 |
+
# you may not use this file except in compliance with the License.
|
9 |
+
# You may obtain a copy of the License at
|
10 |
+
#
|
11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12 |
+
#
|
13 |
+
# Unless required by applicable law or agreed to in writing, software
|
14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16 |
+
# See the License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
"""
|
19 |
+
This script supports to load dataset from huggingface and sends it to the server
|
20 |
+
for decoding, in parallel.
|
21 |
+
|
22 |
+
Usage:
|
23 |
+
num_task=2
|
24 |
+
|
25 |
+
# For offline F5-TTS
|
26 |
+
python3 client_grpc.py \
|
27 |
+
--server-addr localhost \
|
28 |
+
--model-name f5_tts \
|
29 |
+
--num-tasks $num_task \
|
30 |
+
--huggingface-dataset yuekai/seed_tts \
|
31 |
+
--split-name test_zh \
|
32 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
33 |
+
|
34 |
+
# For offline Spark-TTS-0.5B
|
35 |
+
python3 client_grpc.py \
|
36 |
+
--server-addr localhost \
|
37 |
+
--model-name spark_tts \
|
38 |
+
--num-tasks $num_task \
|
39 |
+
--huggingface-dataset yuekai/seed_tts \
|
40 |
+
--split-name wenetspeech4tts \
|
41 |
+
--log-dir ./log_concurrent_tasks_${num_task}
|
42 |
+
"""
|
43 |
+
|
44 |
+
import argparse
|
45 |
+
import asyncio
|
46 |
+
import json
|
47 |
+
import queue # Added
|
48 |
+
import uuid # Added
|
49 |
+
import functools # Added
|
50 |
+
|
51 |
+
import os
|
52 |
+
import time
|
53 |
+
import types
|
54 |
+
from pathlib import Path
|
55 |
+
|
56 |
+
import numpy as np
|
57 |
+
import soundfile as sf
|
58 |
+
import tritonclient
|
59 |
+
import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
|
60 |
+
import tritonclient.grpc as grpcclient_sync # Added sync client import
|
61 |
+
from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
|
62 |
+
|
63 |
+
|
64 |
+
# --- Added UserData and callback ---
|
65 |
+
class UserData:
|
66 |
+
def __init__(self):
|
67 |
+
self._completed_requests = queue.Queue()
|
68 |
+
self._first_chunk_time = None
|
69 |
+
self._start_time = None
|
70 |
+
|
71 |
+
def record_start_time(self):
|
72 |
+
self._start_time = time.time()
|
73 |
+
|
74 |
+
def get_first_chunk_latency(self):
|
75 |
+
if self._first_chunk_time and self._start_time:
|
76 |
+
return self._first_chunk_time - self._start_time
|
77 |
+
return None
|
78 |
+
|
79 |
+
def callback(user_data, result, error):
|
80 |
+
if user_data._first_chunk_time is None and not error:
|
81 |
+
user_data._first_chunk_time = time.time() # Record time of first successful chunk
|
82 |
+
if error:
|
83 |
+
user_data._completed_requests.put(error)
|
84 |
+
else:
|
85 |
+
user_data._completed_requests.put(result)
|
86 |
+
# --- End Added UserData and callback ---
|
87 |
+
|
88 |
+
|
89 |
+
def write_triton_stats(stats, summary_file):
|
90 |
+
with open(summary_file, "w") as summary_f:
|
91 |
+
model_stats = stats["model_stats"]
|
92 |
+
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
|
93 |
+
summary_f.write(
|
94 |
+
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
|
95 |
+
)
|
96 |
+
summary_f.write("To learn more about the log, please refer to: \n")
|
97 |
+
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
|
98 |
+
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
|
99 |
+
summary_f.write(
|
100 |
+
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
|
101 |
+
)
|
102 |
+
summary_f.write(
|
103 |
+
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
|
104 |
+
)
|
105 |
+
summary_f.write(
|
106 |
+
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
|
107 |
+
)
|
108 |
+
summary_f.write(
|
109 |
+
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
|
110 |
+
)
|
111 |
+
for model_state in model_stats:
|
112 |
+
if "last_inference" not in model_state:
|
113 |
+
continue
|
114 |
+
summary_f.write(f"model name is {model_state['name']} \n")
|
115 |
+
model_inference_stats = model_state["inference_stats"]
|
116 |
+
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
|
117 |
+
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
|
118 |
+
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
|
119 |
+
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
|
120 |
+
summary_f.write(
|
121 |
+
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
|
122 |
+
)
|
123 |
+
model_batch_stats = model_state["batch_stats"]
|
124 |
+
for batch in model_batch_stats:
|
125 |
+
batch_size = int(batch["batch_size"])
|
126 |
+
compute_input = batch["compute_input"]
|
127 |
+
compute_output = batch["compute_output"]
|
128 |
+
compute_infer = batch["compute_infer"]
|
129 |
+
batch_count = int(compute_infer["count"])
|
130 |
+
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
|
131 |
+
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
|
132 |
+
compute_input_time_ms = int(compute_input["ns"]) / 1e6
|
133 |
+
compute_output_time_ms = int(compute_output["ns"]) / 1e6
|
134 |
+
summary_f.write(
|
135 |
+
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
|
136 |
+
)
|
137 |
+
summary_f.write(
|
138 |
+
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
|
139 |
+
)
|
140 |
+
summary_f.write(
|
141 |
+
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
def get_args():
|
146 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
147 |
+
|
148 |
+
parser.add_argument(
|
149 |
+
"--server-addr",
|
150 |
+
type=str,
|
151 |
+
default="localhost",
|
152 |
+
help="Address of the server",
|
153 |
+
)
|
154 |
+
|
155 |
+
parser.add_argument(
|
156 |
+
"--server-port",
|
157 |
+
type=int,
|
158 |
+
default=8001,
|
159 |
+
help="Grpc port of the triton server, default is 8001",
|
160 |
+
)
|
161 |
+
|
162 |
+
parser.add_argument(
|
163 |
+
"--reference-audio",
|
164 |
+
type=str,
|
165 |
+
default=None,
|
166 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
167 |
+
)
|
168 |
+
|
169 |
+
parser.add_argument(
|
170 |
+
"--reference-text",
|
171 |
+
type=str,
|
172 |
+
default="",
|
173 |
+
help="",
|
174 |
+
)
|
175 |
+
|
176 |
+
parser.add_argument(
|
177 |
+
"--target-text",
|
178 |
+
type=str,
|
179 |
+
default="",
|
180 |
+
help="",
|
181 |
+
)
|
182 |
+
|
183 |
+
parser.add_argument(
|
184 |
+
"--huggingface-dataset",
|
185 |
+
type=str,
|
186 |
+
default="yuekai/seed_tts",
|
187 |
+
help="dataset name in huggingface dataset hub",
|
188 |
+
)
|
189 |
+
|
190 |
+
parser.add_argument(
|
191 |
+
"--split-name",
|
192 |
+
type=str,
|
193 |
+
default="wenetspeech4tts",
|
194 |
+
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
195 |
+
help="dataset split name, default is 'test'",
|
196 |
+
)
|
197 |
+
|
198 |
+
parser.add_argument(
|
199 |
+
"--manifest-path",
|
200 |
+
type=str,
|
201 |
+
default=None,
|
202 |
+
help="Path to the manifest dir which includes wav.scp trans.txt files.",
|
203 |
+
)
|
204 |
+
|
205 |
+
parser.add_argument(
|
206 |
+
"--model-name",
|
207 |
+
type=str,
|
208 |
+
default="f5_tts",
|
209 |
+
choices=["f5_tts", "spark_tts"],
|
210 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
211 |
+
)
|
212 |
+
|
213 |
+
parser.add_argument(
|
214 |
+
"--num-tasks",
|
215 |
+
type=int,
|
216 |
+
default=1,
|
217 |
+
help="Number of concurrent tasks for sending",
|
218 |
+
)
|
219 |
+
|
220 |
+
parser.add_argument(
|
221 |
+
"--log-interval",
|
222 |
+
type=int,
|
223 |
+
default=5,
|
224 |
+
help="Controls how frequently we print the log.",
|
225 |
+
)
|
226 |
+
|
227 |
+
parser.add_argument(
|
228 |
+
"--compute-wer",
|
229 |
+
action="store_true",
|
230 |
+
default=False,
|
231 |
+
help="""True to compute WER.
|
232 |
+
""",
|
233 |
+
)
|
234 |
+
|
235 |
+
parser.add_argument(
|
236 |
+
"--log-dir",
|
237 |
+
type=str,
|
238 |
+
required=False,
|
239 |
+
default="./tmp",
|
240 |
+
help="log directory",
|
241 |
+
)
|
242 |
+
|
243 |
+
# --- Added arguments ---
|
244 |
+
parser.add_argument(
|
245 |
+
"--mode",
|
246 |
+
type=str,
|
247 |
+
default="offline",
|
248 |
+
choices=["offline", "streaming"],
|
249 |
+
help="Select offline or streaming benchmark mode."
|
250 |
+
)
|
251 |
+
parser.add_argument(
|
252 |
+
"--chunk-overlap-duration",
|
253 |
+
type=float,
|
254 |
+
default=0.1,
|
255 |
+
help="Chunk overlap duration for streaming reconstruction (in seconds)."
|
256 |
+
)
|
257 |
+
# --- End Added arguments ---
|
258 |
+
|
259 |
+
return parser.parse_args()
|
260 |
+
|
261 |
+
|
262 |
+
def load_audio(wav_path, target_sample_rate=16000):
|
263 |
+
assert target_sample_rate == 16000, "hard coding in server"
|
264 |
+
if isinstance(wav_path, dict):
|
265 |
+
waveform = wav_path["array"]
|
266 |
+
sample_rate = wav_path["sampling_rate"]
|
267 |
+
else:
|
268 |
+
waveform, sample_rate = sf.read(wav_path)
|
269 |
+
if sample_rate != target_sample_rate:
|
270 |
+
from scipy.signal import resample
|
271 |
+
|
272 |
+
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
273 |
+
waveform = resample(waveform, num_samples)
|
274 |
+
return waveform, target_sample_rate
|
275 |
+
|
276 |
+
def prepare_request_input_output(
|
277 |
+
protocol_client, # Can be grpcclient_aio or grpcclient_sync
|
278 |
+
waveform,
|
279 |
+
reference_text,
|
280 |
+
target_text,
|
281 |
+
sample_rate=16000,
|
282 |
+
padding_duration: int = None # Optional padding for offline mode
|
283 |
+
):
|
284 |
+
"""Prepares inputs for Triton inference (offline or streaming)."""
|
285 |
+
assert len(waveform.shape) == 1, "waveform should be 1D"
|
286 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
287 |
+
|
288 |
+
# Apply padding only if padding_duration is provided (for offline)
|
289 |
+
if padding_duration:
|
290 |
+
duration = len(waveform) / sample_rate
|
291 |
+
# Estimate target duration based on text length ratio (crude estimation)
|
292 |
+
# Avoid division by zero if reference_text is empty
|
293 |
+
if reference_text:
|
294 |
+
estimated_target_duration = duration / len(reference_text) * len(target_text)
|
295 |
+
else:
|
296 |
+
estimated_target_duration = duration # Assume target duration similar to reference if no text
|
297 |
+
|
298 |
+
# Calculate required samples based on estimated total duration
|
299 |
+
required_total_samples = padding_duration * sample_rate * (
|
300 |
+
(int(estimated_target_duration + duration) // padding_duration) + 1
|
301 |
+
)
|
302 |
+
samples = np.zeros((1, required_total_samples), dtype=np.float32)
|
303 |
+
samples[0, : len(waveform)] = waveform
|
304 |
+
else:
|
305 |
+
# No padding for streaming or if padding_duration is None
|
306 |
+
samples = waveform.reshape(1, -1).astype(np.float32)
|
307 |
+
|
308 |
+
# Common input creation logic
|
309 |
+
inputs = [
|
310 |
+
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
|
311 |
+
protocol_client.InferInput(
|
312 |
+
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
313 |
+
),
|
314 |
+
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
|
315 |
+
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
|
316 |
+
]
|
317 |
+
inputs[0].set_data_from_numpy(samples)
|
318 |
+
inputs[1].set_data_from_numpy(lengths)
|
319 |
+
|
320 |
+
input_data_numpy = np.array([reference_text], dtype=object)
|
321 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
322 |
+
inputs[2].set_data_from_numpy(input_data_numpy)
|
323 |
+
|
324 |
+
input_data_numpy = np.array([target_text], dtype=object)
|
325 |
+
input_data_numpy = input_data_numpy.reshape((1, 1))
|
326 |
+
inputs[3].set_data_from_numpy(input_data_numpy)
|
327 |
+
|
328 |
+
outputs = [protocol_client.InferRequestedOutput("waveform")]
|
329 |
+
|
330 |
+
return inputs, outputs
|
331 |
+
|
332 |
+
def run_sync_streaming_inference(
|
333 |
+
sync_triton_client: tritonclient.grpc.InferenceServerClient,
|
334 |
+
model_name: str,
|
335 |
+
inputs: list,
|
336 |
+
outputs: list,
|
337 |
+
request_id: str,
|
338 |
+
user_data: UserData,
|
339 |
+
chunk_overlap_duration: float,
|
340 |
+
save_sample_rate: int,
|
341 |
+
audio_save_path: str,
|
342 |
+
):
|
343 |
+
"""Helper function to run the blocking sync streaming call."""
|
344 |
+
start_time_total = time.time()
|
345 |
+
user_data.record_start_time() # Record start time for first chunk latency calculation
|
346 |
+
|
347 |
+
# Establish stream
|
348 |
+
sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
|
349 |
+
|
350 |
+
# Send request
|
351 |
+
sync_triton_client.async_stream_infer(
|
352 |
+
model_name,
|
353 |
+
inputs,
|
354 |
+
request_id=request_id,
|
355 |
+
outputs=outputs,
|
356 |
+
enable_empty_final_response=True,
|
357 |
+
)
|
358 |
+
|
359 |
+
# Process results
|
360 |
+
audios = []
|
361 |
+
while True:
|
362 |
+
try:
|
363 |
+
result = user_data._completed_requests.get() # Add timeout
|
364 |
+
if isinstance(result, InferenceServerException):
|
365 |
+
print(f"Received InferenceServerException: {result}")
|
366 |
+
sync_triton_client.stop_stream()
|
367 |
+
return None, None, None # Indicate error
|
368 |
+
# Get response metadata
|
369 |
+
response = result.get_response()
|
370 |
+
final = response.parameters["triton_final_response"].bool_param
|
371 |
+
if final is True:
|
372 |
+
break
|
373 |
+
|
374 |
+
audio_chunk = result.as_numpy("waveform").reshape(-1)
|
375 |
+
if audio_chunk.size > 0: # Only append non-empty chunks
|
376 |
+
audios.append(audio_chunk)
|
377 |
+
else:
|
378 |
+
print("Warning: received empty audio chunk.")
|
379 |
+
|
380 |
+
except queue.Empty:
|
381 |
+
print(f"Timeout waiting for response for request id {request_id}")
|
382 |
+
sync_triton_client.stop_stream()
|
383 |
+
return None, None, None # Indicate error
|
384 |
+
|
385 |
+
sync_triton_client.stop_stream()
|
386 |
+
end_time_total = time.time()
|
387 |
+
total_request_latency = end_time_total - start_time_total
|
388 |
+
first_chunk_latency = user_data.get_first_chunk_latency()
|
389 |
+
|
390 |
+
# Reconstruct audio using cross-fade (from client_grpc_streaming.py)
|
391 |
+
actual_duration = 0
|
392 |
+
if audios:
|
393 |
+
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
|
394 |
+
fade_out = np.linspace(1, 0, cross_fade_samples)
|
395 |
+
fade_in = np.linspace(0, 1, cross_fade_samples)
|
396 |
+
reconstructed_audio = None
|
397 |
+
|
398 |
+
# Simplified reconstruction based on client_grpc_streaming.py
|
399 |
+
if not audios:
|
400 |
+
print("Warning: No audio chunks received.")
|
401 |
+
reconstructed_audio = np.array([], dtype=np.float32) # Empty array
|
402 |
+
elif len(audios) == 1:
|
403 |
+
reconstructed_audio = audios[0]
|
404 |
+
else:
|
405 |
+
reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
|
406 |
+
for i in range(1, len(audios)):
|
407 |
+
# Cross-fade section
|
408 |
+
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
|
409 |
+
audios[i - 1][-cross_fade_samples:] * fade_out)
|
410 |
+
# Middle section of the current chunk
|
411 |
+
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
|
412 |
+
# Concatenate
|
413 |
+
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
|
414 |
+
# Add the last part of the final chunk
|
415 |
+
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
|
416 |
+
|
417 |
+
if reconstructed_audio is not None and reconstructed_audio.size > 0:
|
418 |
+
actual_duration = len(reconstructed_audio) / save_sample_rate
|
419 |
+
# Save reconstructed audio
|
420 |
+
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
|
421 |
+
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
|
422 |
+
else:
|
423 |
+
print("Warning: No audio chunks received or reconstructed.")
|
424 |
+
actual_duration = 0 # Set duration to 0 if no audio
|
425 |
+
|
426 |
+
else:
|
427 |
+
print("Warning: No audio chunks received.")
|
428 |
+
actual_duration = 0
|
429 |
+
|
430 |
+
return total_request_latency, first_chunk_latency, actual_duration
|
431 |
+
|
432 |
+
|
433 |
+
async def send_streaming(
|
434 |
+
manifest_item_list: list,
|
435 |
+
name: str,
|
436 |
+
server_url: str, # Changed from sync_triton_client
|
437 |
+
protocol_client: types.ModuleType,
|
438 |
+
log_interval: int,
|
439 |
+
model_name: str,
|
440 |
+
audio_save_dir: str = "./",
|
441 |
+
save_sample_rate: int = 16000,
|
442 |
+
chunk_overlap_duration: float = 0.1,
|
443 |
+
padding_duration: int = None,
|
444 |
+
):
|
445 |
+
total_duration = 0.0
|
446 |
+
latency_data = []
|
447 |
+
task_id = int(name[5:])
|
448 |
+
sync_triton_client = None # Initialize client variable
|
449 |
+
|
450 |
+
try: # Wrap in try...finally to ensure client closing
|
451 |
+
print(f"{name}: Initializing sync client for streaming...")
|
452 |
+
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
|
453 |
+
|
454 |
+
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
|
455 |
+
for i, item in enumerate(manifest_item_list):
|
456 |
+
if i % log_interval == 0:
|
457 |
+
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
|
458 |
+
|
459 |
+
try:
|
460 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
461 |
+
reference_text, target_text = item["reference_text"], item["target_text"]
|
462 |
+
|
463 |
+
inputs, outputs = prepare_request_input_output(
|
464 |
+
protocol_client,
|
465 |
+
waveform,
|
466 |
+
reference_text,
|
467 |
+
target_text,
|
468 |
+
sample_rate,
|
469 |
+
padding_duration=padding_duration
|
470 |
+
)
|
471 |
+
request_id = str(uuid.uuid4())
|
472 |
+
user_data = UserData()
|
473 |
+
|
474 |
+
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
475 |
+
|
476 |
+
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
|
477 |
+
run_sync_streaming_inference,
|
478 |
+
sync_triton_client,
|
479 |
+
model_name,
|
480 |
+
inputs,
|
481 |
+
outputs,
|
482 |
+
request_id,
|
483 |
+
user_data,
|
484 |
+
chunk_overlap_duration,
|
485 |
+
save_sample_rate,
|
486 |
+
audio_save_path
|
487 |
+
)
|
488 |
+
|
489 |
+
if total_request_latency is not None:
|
490 |
+
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
|
491 |
+
latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
|
492 |
+
total_duration += actual_duration
|
493 |
+
else:
|
494 |
+
print(f"{name}: Item {i} failed.")
|
495 |
+
|
496 |
+
|
497 |
+
except FileNotFoundError:
|
498 |
+
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
|
499 |
+
except Exception as e:
|
500 |
+
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
|
501 |
+
import traceback
|
502 |
+
traceback.print_exc()
|
503 |
+
|
504 |
+
|
505 |
+
finally: # Ensure client is closed
|
506 |
+
if sync_triton_client:
|
507 |
+
try:
|
508 |
+
print(f"{name}: Closing sync client...")
|
509 |
+
sync_triton_client.close()
|
510 |
+
except Exception as e:
|
511 |
+
print(f"{name}: Error closing sync client: {e}")
|
512 |
+
|
513 |
+
|
514 |
+
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
|
515 |
+
return total_duration, latency_data
|
516 |
+
|
517 |
+
async def send(
|
518 |
+
manifest_item_list: list,
|
519 |
+
name: str,
|
520 |
+
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
521 |
+
protocol_client: types.ModuleType,
|
522 |
+
log_interval: int,
|
523 |
+
model_name: str,
|
524 |
+
padding_duration: int = None,
|
525 |
+
audio_save_dir: str = "./",
|
526 |
+
save_sample_rate: int = 16000,
|
527 |
+
):
|
528 |
+
total_duration = 0.0
|
529 |
+
latency_data = []
|
530 |
+
task_id = int(name[5:])
|
531 |
+
|
532 |
+
print(f"manifest_item_list: {manifest_item_list}")
|
533 |
+
for i, item in enumerate(manifest_item_list):
|
534 |
+
if i % log_interval == 0:
|
535 |
+
print(f"{name}: {i}/{len(manifest_item_list)}")
|
536 |
+
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
537 |
+
reference_text, target_text = item["reference_text"], item["target_text"]
|
538 |
+
|
539 |
+
inputs, outputs = prepare_request_input_output(
|
540 |
+
protocol_client,
|
541 |
+
waveform,
|
542 |
+
reference_text,
|
543 |
+
target_text,
|
544 |
+
sample_rate,
|
545 |
+
padding_duration=padding_duration
|
546 |
+
)
|
547 |
+
sequence_id = 100000000 + i + task_id * 10
|
548 |
+
start = time.time()
|
549 |
+
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
|
550 |
+
|
551 |
+
audio = response.as_numpy("waveform").reshape(-1)
|
552 |
+
actual_duration = len(audio) / save_sample_rate
|
553 |
+
|
554 |
+
end = time.time() - start
|
555 |
+
|
556 |
+
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
557 |
+
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
558 |
+
|
559 |
+
latency_data.append((end, actual_duration))
|
560 |
+
total_duration += actual_duration
|
561 |
+
|
562 |
+
return total_duration, latency_data
|
563 |
+
|
564 |
+
|
565 |
+
def load_manifests(manifest_path):
|
566 |
+
with open(manifest_path, "r") as f:
|
567 |
+
manifest_list = []
|
568 |
+
for line in f:
|
569 |
+
assert len(line.strip().split("|")) == 4
|
570 |
+
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
571 |
+
utt = Path(utt).stem
|
572 |
+
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
|
573 |
+
if not os.path.isabs(prompt_wav):
|
574 |
+
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
|
575 |
+
manifest_list.append(
|
576 |
+
{
|
577 |
+
"audio_filepath": prompt_wav,
|
578 |
+
"reference_text": prompt_text,
|
579 |
+
"target_text": gt_text,
|
580 |
+
"target_audio_path": utt,
|
581 |
+
}
|
582 |
+
)
|
583 |
+
return manifest_list
|
584 |
+
|
585 |
+
|
586 |
+
def split_data(data, k):
|
587 |
+
n = len(data)
|
588 |
+
if n < k:
|
589 |
+
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
|
590 |
+
k = n
|
591 |
+
|
592 |
+
quotient = n // k
|
593 |
+
remainder = n % k
|
594 |
+
|
595 |
+
result = []
|
596 |
+
start = 0
|
597 |
+
for i in range(k):
|
598 |
+
if i < remainder:
|
599 |
+
end = start + quotient + 1
|
600 |
+
else:
|
601 |
+
end = start + quotient
|
602 |
+
|
603 |
+
result.append(data[start:end])
|
604 |
+
start = end
|
605 |
+
|
606 |
+
return result
|
607 |
+
|
608 |
+
async def main():
|
609 |
+
args = get_args()
|
610 |
+
url = f"{args.server_addr}:{args.server_port}"
|
611 |
+
|
612 |
+
# --- Client Initialization based on mode ---
|
613 |
+
triton_client = None
|
614 |
+
protocol_client = None
|
615 |
+
if args.mode == "offline":
|
616 |
+
print("Initializing gRPC client for offline mode...")
|
617 |
+
# Use the async client for offline tasks
|
618 |
+
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
619 |
+
protocol_client = grpcclient_aio
|
620 |
+
elif args.mode == "streaming":
|
621 |
+
print("Initializing gRPC client for streaming mode...")
|
622 |
+
# Use the sync client for streaming tasks, handled via asyncio.to_thread
|
623 |
+
# We will create one sync client instance PER TASK inside send_streaming.
|
624 |
+
# triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
|
625 |
+
protocol_client = grpcclient_sync # protocol client for input prep
|
626 |
+
else:
|
627 |
+
raise ValueError(f"Invalid mode: {args.mode}")
|
628 |
+
# --- End Client Initialization ---
|
629 |
+
|
630 |
+
if args.reference_audio:
|
631 |
+
args.num_tasks = 1
|
632 |
+
args.log_interval = 1
|
633 |
+
manifest_item_list = [
|
634 |
+
{
|
635 |
+
"reference_text": args.reference_text,
|
636 |
+
"target_text": args.target_text,
|
637 |
+
"audio_filepath": args.reference_audio,
|
638 |
+
"target_audio_path": "test",
|
639 |
+
}
|
640 |
+
]
|
641 |
+
elif args.huggingface_dataset:
|
642 |
+
import datasets
|
643 |
+
|
644 |
+
dataset = datasets.load_dataset(
|
645 |
+
args.huggingface_dataset,
|
646 |
+
split=args.split_name,
|
647 |
+
trust_remote_code=True,
|
648 |
+
)
|
649 |
+
manifest_item_list = []
|
650 |
+
for i in range(len(dataset)):
|
651 |
+
manifest_item_list.append(
|
652 |
+
{
|
653 |
+
"audio_filepath": dataset[i]["prompt_audio"],
|
654 |
+
"reference_text": dataset[i]["prompt_text"],
|
655 |
+
"target_audio_path": dataset[i]["id"],
|
656 |
+
"target_text": dataset[i]["target_text"],
|
657 |
+
}
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
manifest_item_list = load_manifests(args.manifest_path)
|
661 |
+
|
662 |
+
num_tasks = min(args.num_tasks, len(manifest_item_list))
|
663 |
+
manifest_item_list = split_data(manifest_item_list, num_tasks)
|
664 |
+
|
665 |
+
os.makedirs(args.log_dir, exist_ok=True)
|
666 |
+
tasks = []
|
667 |
+
start_time = time.time()
|
668 |
+
for i in range(num_tasks):
|
669 |
+
# --- Task Creation based on mode ---
|
670 |
+
if args.mode == "offline":
|
671 |
+
task = asyncio.create_task(
|
672 |
+
send(
|
673 |
+
manifest_item_list[i],
|
674 |
+
name=f"task-{i}",
|
675 |
+
triton_client=triton_client,
|
676 |
+
protocol_client=protocol_client,
|
677 |
+
log_interval=args.log_interval,
|
678 |
+
model_name=args.model_name,
|
679 |
+
audio_save_dir=args.log_dir,
|
680 |
+
padding_duration=1,
|
681 |
+
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
682 |
+
)
|
683 |
+
)
|
684 |
+
elif args.mode == "streaming":
|
685 |
+
task = asyncio.create_task(
|
686 |
+
send_streaming(
|
687 |
+
manifest_item_list[i],
|
688 |
+
name=f"task-{i}",
|
689 |
+
server_url=url, # Pass URL instead of client
|
690 |
+
protocol_client=protocol_client,
|
691 |
+
log_interval=args.log_interval,
|
692 |
+
model_name=args.model_name,
|
693 |
+
audio_save_dir=args.log_dir,
|
694 |
+
padding_duration=10,
|
695 |
+
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
696 |
+
chunk_overlap_duration=args.chunk_overlap_duration,
|
697 |
+
)
|
698 |
+
)
|
699 |
+
# --- End Task Creation ---
|
700 |
+
tasks.append(task)
|
701 |
+
|
702 |
+
ans_list = await asyncio.gather(*tasks)
|
703 |
+
|
704 |
+
end_time = time.time()
|
705 |
+
elapsed = end_time - start_time
|
706 |
+
|
707 |
+
total_duration = 0.0
|
708 |
+
latency_data = []
|
709 |
+
for ans in ans_list:
|
710 |
+
if ans:
|
711 |
+
total_duration += ans[0]
|
712 |
+
latency_data.extend(ans[1]) # Use extend for list of lists
|
713 |
+
else:
|
714 |
+
print("Warning: A task returned None, possibly due to an error.")
|
715 |
+
|
716 |
+
|
717 |
+
if total_duration == 0:
|
718 |
+
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
|
719 |
+
rtf = float('inf')
|
720 |
+
else:
|
721 |
+
rtf = elapsed / total_duration
|
722 |
+
|
723 |
+
s = f"Mode: {args.mode}\n"
|
724 |
+
s += f"RTF: {rtf:.4f}\n"
|
725 |
+
s += f"total_duration: {total_duration:.3f} seconds\n"
|
726 |
+
s += f"({total_duration / 3600:.2f} hours)\n"
|
727 |
+
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
|
728 |
+
|
729 |
+
# --- Statistics Reporting based on mode ---
|
730 |
+
if latency_data:
|
731 |
+
if args.mode == "offline":
|
732 |
+
# Original offline latency calculation
|
733 |
+
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
|
734 |
+
if latency_list:
|
735 |
+
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
736 |
+
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
737 |
+
s += f"latency_variance: {latency_variance:.2f}\n"
|
738 |
+
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
739 |
+
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
740 |
+
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
|
741 |
+
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
742 |
+
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
743 |
+
else:
|
744 |
+
s += "No latency data collected for offline mode.\n"
|
745 |
+
|
746 |
+
elif args.mode == "streaming":
|
747 |
+
# Calculate stats for total request latency and first chunk latency
|
748 |
+
total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
|
749 |
+
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
|
750 |
+
|
751 |
+
s += "\n--- Total Request Latency ---\n"
|
752 |
+
if total_latency_list:
|
753 |
+
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
|
754 |
+
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
|
755 |
+
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
|
756 |
+
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
|
757 |
+
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
|
758 |
+
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
|
759 |
+
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
|
760 |
+
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
|
761 |
+
else:
|
762 |
+
s += "No total request latency data collected.\n"
|
763 |
+
|
764 |
+
s += "\n--- First Chunk Latency ---\n"
|
765 |
+
if first_chunk_latency_list:
|
766 |
+
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
|
767 |
+
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
|
768 |
+
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
|
769 |
+
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
|
770 |
+
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
|
771 |
+
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
|
772 |
+
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
|
773 |
+
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
|
774 |
+
else:
|
775 |
+
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
|
776 |
+
else:
|
777 |
+
s += "No latency data collected.\n"
|
778 |
+
# --- End Statistics Reporting ---
|
779 |
+
|
780 |
+
print(s)
|
781 |
+
if args.manifest_path:
|
782 |
+
name = Path(args.manifest_path).stem
|
783 |
+
elif args.split_name:
|
784 |
+
name = args.split_name
|
785 |
+
elif args.reference_audio:
|
786 |
+
name = Path(args.reference_audio).stem
|
787 |
+
else:
|
788 |
+
name = "results" # Default name if no manifest/split/audio provided
|
789 |
+
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
|
790 |
+
f.write(s)
|
791 |
+
|
792 |
+
# --- Statistics Fetching using temporary Async Client ---
|
793 |
+
# Use a separate async client for fetching stats regardless of mode
|
794 |
+
stats_client = None
|
795 |
+
try:
|
796 |
+
print("Initializing temporary async client for fetching stats...")
|
797 |
+
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
|
798 |
+
print("Fetching inference statistics...")
|
799 |
+
# Fetching for all models, filtering might be needed depending on server setup
|
800 |
+
stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
|
801 |
+
print("Fetching model config...")
|
802 |
+
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
|
803 |
+
|
804 |
+
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
|
805 |
+
|
806 |
+
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
|
807 |
+
json.dump(metadata, f, indent=4)
|
808 |
+
|
809 |
+
except Exception as e:
|
810 |
+
print(f"Could not retrieve statistics or config: {e}")
|
811 |
+
finally:
|
812 |
+
if stats_client:
|
813 |
+
try:
|
814 |
+
print("Closing temporary async stats client...")
|
815 |
+
await stats_client.close()
|
816 |
+
except Exception as e:
|
817 |
+
print(f"Error closing async stats client: {e}")
|
818 |
+
# --- End Statistics Fetching ---
|
819 |
+
|
820 |
+
|
821 |
+
if __name__ == "__main__":
|
822 |
+
# asyncio.run(main()) # Use TaskGroup for better exception handling if needed
|
823 |
+
async def run_main():
|
824 |
+
try:
|
825 |
+
await main()
|
826 |
+
except Exception as e:
|
827 |
+
print(f"An error occurred in main: {e}")
|
828 |
+
import traceback
|
829 |
+
traceback.print_exc()
|
830 |
+
|
831 |
+
asyncio.run(run_main())
|
runtime/triton_trtllm/client_http.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import requests
|
27 |
+
import soundfile as sf
|
28 |
+
import json
|
29 |
+
import numpy as np
|
30 |
+
import argparse
|
31 |
+
|
32 |
+
def get_args():
|
33 |
+
parser = argparse.ArgumentParser(
|
34 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
35 |
+
)
|
36 |
+
|
37 |
+
parser.add_argument(
|
38 |
+
"--server-url",
|
39 |
+
type=str,
|
40 |
+
default="localhost:8000",
|
41 |
+
help="Address of the server",
|
42 |
+
)
|
43 |
+
|
44 |
+
parser.add_argument(
|
45 |
+
"--reference-audio",
|
46 |
+
type=str,
|
47 |
+
default="../../example/prompt_audio.wav",
|
48 |
+
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
|
49 |
+
)
|
50 |
+
|
51 |
+
parser.add_argument(
|
52 |
+
"--reference-text",
|
53 |
+
type=str,
|
54 |
+
default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
|
55 |
+
help="",
|
56 |
+
)
|
57 |
+
|
58 |
+
parser.add_argument(
|
59 |
+
"--target-text",
|
60 |
+
type=str,
|
61 |
+
default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
|
62 |
+
help="",
|
63 |
+
)
|
64 |
+
|
65 |
+
parser.add_argument(
|
66 |
+
"--model-name",
|
67 |
+
type=str,
|
68 |
+
default="spark_tts",
|
69 |
+
choices=[
|
70 |
+
"f5_tts", "spark_tts"
|
71 |
+
],
|
72 |
+
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
73 |
+
)
|
74 |
+
|
75 |
+
parser.add_argument(
|
76 |
+
"--output-audio",
|
77 |
+
type=str,
|
78 |
+
default="output.wav",
|
79 |
+
help="Path to save the output audio",
|
80 |
+
)
|
81 |
+
return parser.parse_args()
|
82 |
+
|
83 |
+
def prepare_request(
|
84 |
+
waveform,
|
85 |
+
reference_text,
|
86 |
+
target_text,
|
87 |
+
sample_rate=16000,
|
88 |
+
padding_duration: int = None,
|
89 |
+
audio_save_dir: str = "./",
|
90 |
+
):
|
91 |
+
assert len(waveform.shape) == 1, "waveform should be 1D"
|
92 |
+
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
93 |
+
if padding_duration:
|
94 |
+
# padding to nearset 10 seconds
|
95 |
+
samples = np.zeros(
|
96 |
+
(
|
97 |
+
1,
|
98 |
+
padding_duration
|
99 |
+
* sample_rate
|
100 |
+
* ((int(duration) // padding_duration) + 1),
|
101 |
+
),
|
102 |
+
dtype=np.float32,
|
103 |
+
)
|
104 |
+
|
105 |
+
samples[0, : len(waveform)] = waveform
|
106 |
+
else:
|
107 |
+
samples = waveform
|
108 |
+
|
109 |
+
samples = samples.reshape(1, -1).astype(np.float32)
|
110 |
+
|
111 |
+
data = {
|
112 |
+
"inputs":[
|
113 |
+
{
|
114 |
+
"name": "reference_wav",
|
115 |
+
"shape": samples.shape,
|
116 |
+
"datatype": "FP32",
|
117 |
+
"data": samples.tolist()
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"name": "reference_wav_len",
|
121 |
+
"shape": lengths.shape,
|
122 |
+
"datatype": "INT32",
|
123 |
+
"data": lengths.tolist(),
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"name": "reference_text",
|
127 |
+
"shape": [1, 1],
|
128 |
+
"datatype": "BYTES",
|
129 |
+
"data": [reference_text]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"name": "target_text",
|
133 |
+
"shape": [1, 1],
|
134 |
+
"datatype": "BYTES",
|
135 |
+
"data": [target_text]
|
136 |
+
}
|
137 |
+
]
|
138 |
+
}
|
139 |
+
|
140 |
+
return data
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
args = get_args()
|
144 |
+
server_url = args.server_url
|
145 |
+
if not server_url.startswith(("http://", "https://")):
|
146 |
+
server_url = f"http://{server_url}"
|
147 |
+
|
148 |
+
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
149 |
+
waveform, sr = sf.read(args.reference_audio)
|
150 |
+
assert sr == 16000, "sample rate hardcoded in server"
|
151 |
+
|
152 |
+
samples = np.array(waveform, dtype=np.float32)
|
153 |
+
data = prepare_request(samples, args.reference_text, args.target_text)
|
154 |
+
|
155 |
+
rsp = requests.post(
|
156 |
+
url,
|
157 |
+
headers={"Content-Type": "application/json"},
|
158 |
+
json=data,
|
159 |
+
verify=False,
|
160 |
+
params={"request_id": '0'}
|
161 |
+
)
|
162 |
+
result = rsp.json()
|
163 |
+
audio = result["outputs"][0]["data"]
|
164 |
+
audio = np.array(audio, dtype=np.float32)
|
165 |
+
sf.write(args.output_audio, audio, 16000, "PCM_16")
|
runtime/triton_trtllm/docker-compose.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
tts:
|
3 |
+
image: soar97/triton-spark-tts:25.02
|
4 |
+
shm_size: '1gb'
|
5 |
+
ports:
|
6 |
+
- "8000:8000"
|
7 |
+
- "8001:8001"
|
8 |
+
- "8002:8002"
|
9 |
+
environment:
|
10 |
+
- PYTHONIOENCODING=utf-8
|
11 |
+
- MODEL_ID=${MODEL_ID}
|
12 |
+
deploy:
|
13 |
+
resources:
|
14 |
+
reservations:
|
15 |
+
devices:
|
16 |
+
- driver: nvidia
|
17 |
+
device_ids: ['0']
|
18 |
+
capabilities: [gpu]
|
19 |
+
command: >
|
20 |
+
/bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
|
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
import json
|
27 |
+
import torch
|
28 |
+
from torch.utils.dlpack import to_dlpack
|
29 |
+
|
30 |
+
import triton_python_backend_utils as pb_utils
|
31 |
+
|
32 |
+
import os
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
36 |
+
|
37 |
+
class TritonPythonModel:
|
38 |
+
"""Triton Python model for audio tokenization.
|
39 |
+
|
40 |
+
This model takes reference audio input and extracts semantic and global tokens
|
41 |
+
using BiCodec tokenizer.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def initialize(self, args):
|
45 |
+
"""Initialize the model.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
args: Dictionary containing model configuration
|
49 |
+
"""
|
50 |
+
# Parse model parameters
|
51 |
+
parameters = json.loads(args['model_config'])['parameters']
|
52 |
+
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
53 |
+
|
54 |
+
# Initialize tokenizer
|
55 |
+
self.device = torch.device("cuda")
|
56 |
+
self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
|
57 |
+
device=self.device)
|
58 |
+
|
59 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
60 |
+
"""Extract reference audio clip for speaker embedding.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
wav: Input waveform array
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Reference clip of fixed duration
|
67 |
+
"""
|
68 |
+
SAMPLE_RATE = 16000
|
69 |
+
REF_SEGMENT_DURATION = 6 # seconds
|
70 |
+
LATENT_HOP_LENGTH = 320
|
71 |
+
|
72 |
+
ref_segment_length = (
|
73 |
+
int(SAMPLE_RATE * REF_SEGMENT_DURATION)
|
74 |
+
// LATENT_HOP_LENGTH
|
75 |
+
* LATENT_HOP_LENGTH
|
76 |
+
)
|
77 |
+
wav_length = len(wav)
|
78 |
+
|
79 |
+
if ref_segment_length > wav_length:
|
80 |
+
# Repeat and truncate if input is too short
|
81 |
+
repeat_times = ref_segment_length // wav_length + 1
|
82 |
+
wav = np.tile(wav, repeat_times)
|
83 |
+
|
84 |
+
return wav[:ref_segment_length]
|
85 |
+
|
86 |
+
def execute(self, requests):
|
87 |
+
"""Execute inference on the batched requests.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
requests: List of inference requests
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
List of inference responses containing tokenized outputs
|
94 |
+
"""
|
95 |
+
reference_wav_list = []
|
96 |
+
reference_wav_ref_clip_list = []
|
97 |
+
|
98 |
+
# Process each request in batch
|
99 |
+
for request in requests:
|
100 |
+
# Extract input tensors
|
101 |
+
wav_array = pb_utils.get_input_tensor_by_name(
|
102 |
+
request, "reference_wav").as_numpy()
|
103 |
+
wav_len = pb_utils.get_input_tensor_by_name(
|
104 |
+
request, "reference_wav_len").as_numpy().item()
|
105 |
+
|
106 |
+
# Prepare inputs
|
107 |
+
wav = wav_array[:, :wav_len].squeeze(0)
|
108 |
+
reference_wav_list.append(wav)
|
109 |
+
|
110 |
+
wav_ref_clip = self.get_ref_clip(wav)
|
111 |
+
reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
|
112 |
+
|
113 |
+
# Batch process through tokenizer
|
114 |
+
ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
|
115 |
+
wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
|
116 |
+
reference_wav_list)
|
117 |
+
|
118 |
+
audio_tokenizer_input = {
|
119 |
+
"ref_wav": ref_wav_clip_tensor.to(self.device),
|
120 |
+
"feat": wav2vec2_features.to(self.device),
|
121 |
+
}
|
122 |
+
semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
|
123 |
+
audio_tokenizer_input)
|
124 |
+
|
125 |
+
# Prepare responses
|
126 |
+
responses = []
|
127 |
+
for i in range(len(requests)):
|
128 |
+
global_tokens_tensor = pb_utils.Tensor.from_dlpack(
|
129 |
+
"global_tokens", to_dlpack(global_tokens[i]))
|
130 |
+
semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
|
131 |
+
"semantic_tokens", to_dlpack(semantic_tokens[i]))
|
132 |
+
|
133 |
+
inference_response = pb_utils.InferenceResponse(
|
134 |
+
output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
|
135 |
+
responses.append(inference_response)
|
136 |
+
|
137 |
+
return responses
|
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "audio_tokenizer"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "model_dir",
|
24 |
+
value: {string_value:"${model_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "reference_wav"
|
31 |
+
data_type: TYPE_FP32
|
32 |
+
dims: [-1]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
name: "reference_wav_len"
|
36 |
+
data_type: TYPE_INT32
|
37 |
+
dims: [1]
|
38 |
+
}
|
39 |
+
]
|
40 |
+
output [
|
41 |
+
{
|
42 |
+
name: "global_tokens"
|
43 |
+
data_type: TYPE_INT32
|
44 |
+
dims: [-1]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
name: "semantic_tokens"
|
48 |
+
data_type: TYPE_INT32
|
49 |
+
dims: [-1]
|
50 |
+
}
|
51 |
+
]
|
52 |
+
|
53 |
+
instance_group [
|
54 |
+
{
|
55 |
+
count: 1
|
56 |
+
kind: KIND_CPU
|
57 |
+
}
|
58 |
+
]
|
runtime/triton_trtllm/model_repo/spark_tts/1/model.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
import json
|
28 |
+
import math
|
29 |
+
import os
|
30 |
+
import re
|
31 |
+
from typing import Dict, List, Tuple, Optional, Union
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
from torch.utils.dlpack import from_dlpack, to_dlpack
|
36 |
+
import triton_python_backend_utils as pb_utils
|
37 |
+
from transformers import AutoTokenizer
|
38 |
+
|
39 |
+
from sparktts.utils.token_parser import TASK_TOKEN_MAP
|
40 |
+
|
41 |
+
def process_prompt(
|
42 |
+
text: str,
|
43 |
+
prompt_text: Optional[str] = None,
|
44 |
+
global_token_ids: torch.Tensor = None,
|
45 |
+
semantic_token_ids: torch.Tensor = None,
|
46 |
+
) -> Tuple[str, torch.Tensor]:
|
47 |
+
"""
|
48 |
+
Process input for voice cloning.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
text: The text input to be converted to speech.
|
52 |
+
prompt_text: Transcript of the prompt audio.
|
53 |
+
global_token_ids: Global token IDs extracted from reference audio.
|
54 |
+
semantic_token_ids: Semantic token IDs extracted from reference audio.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
Tuple containing the formatted input prompt and global token IDs.
|
58 |
+
"""
|
59 |
+
# Convert global tokens to string format
|
60 |
+
global_tokens = "".join(
|
61 |
+
[f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
# Prepare the input tokens for the model
|
66 |
+
if prompt_text is not None:
|
67 |
+
# Include semantic tokens when prompt text is provided
|
68 |
+
semantic_tokens = "".join(
|
69 |
+
[f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
|
70 |
+
)
|
71 |
+
|
72 |
+
inputs = [
|
73 |
+
TASK_TOKEN_MAP["tts"],
|
74 |
+
"<|start_content|>",
|
75 |
+
prompt_text,
|
76 |
+
text,
|
77 |
+
"<|end_content|>",
|
78 |
+
"<|start_global_token|>",
|
79 |
+
global_tokens,
|
80 |
+
"<|end_global_token|>",
|
81 |
+
"<|start_semantic_token|>",
|
82 |
+
semantic_tokens,
|
83 |
+
]
|
84 |
+
else:
|
85 |
+
# Without prompt text, exclude semantic tokens
|
86 |
+
inputs = [
|
87 |
+
TASK_TOKEN_MAP["tts"],
|
88 |
+
"<|start_content|>",
|
89 |
+
text,
|
90 |
+
"<|end_content|>",
|
91 |
+
"<|start_global_token|>",
|
92 |
+
global_tokens,
|
93 |
+
"<|end_global_token|>",
|
94 |
+
]
|
95 |
+
|
96 |
+
# Join all input components into a single string
|
97 |
+
inputs = "".join(inputs)
|
98 |
+
return inputs, global_token_ids
|
99 |
+
|
100 |
+
|
101 |
+
class TritonPythonModel:
|
102 |
+
"""Triton Python model for Spark TTS.
|
103 |
+
|
104 |
+
This model orchestrates the end-to-end TTS pipeline by coordinating
|
105 |
+
between audio tokenizer, LLM, and vocoder components.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def initialize(self, args):
|
109 |
+
"""Initialize the model.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
args: Dictionary containing model configuration
|
113 |
+
"""
|
114 |
+
self.logger = pb_utils.Logger
|
115 |
+
# Parse model parameters
|
116 |
+
self.model_config = json.loads(args['model_config'])
|
117 |
+
parameters = self.model_config['parameters']
|
118 |
+
model_params = {k: v["string_value"] for k, v in parameters.items()}
|
119 |
+
self.logger.log_info(f"model_params:{model_params}")
|
120 |
+
# streaming TTS parameters
|
121 |
+
assert (
|
122 |
+
float(model_params["audio_chunk_duration"]) >= 0.5
|
123 |
+
), f"audio_chunk_duration at least 0.5 seconds"
|
124 |
+
self.audio_chunk_duration = float(model_params["audio_chunk_duration"])
|
125 |
+
self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"])
|
126 |
+
assert (
|
127 |
+
float(model_params["audio_chunk_size_scale_factor"]) >= 1.0
|
128 |
+
), "audio_chunk_size_scale_factor should be greater than 1, change it according to your actual rtf"
|
129 |
+
self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"]) # scale speed
|
130 |
+
self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"])
|
131 |
+
self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"])
|
132 |
+
|
133 |
+
# Initialize tokenizer
|
134 |
+
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
|
135 |
+
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
|
136 |
+
self.device = torch.device("cuda")
|
137 |
+
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
|
138 |
+
|
139 |
+
def forward_llm(self, input_ids):
|
140 |
+
"""
|
141 |
+
Prepares the response from the language model based on the provided
|
142 |
+
inputs. Creates a `pb_utils.InferenceRequest` object with passed
|
143 |
+
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
|
144 |
+
For each response from the language model:
|
145 |
+
- Checks for errors and raise an exception if any are found.
|
146 |
+
- Extracts the "output_ids" tensor from the response.
|
147 |
+
- Determines the finish reason based on the presence of the
|
148 |
+
end-of-sequence token or reaching the maximum length.
|
149 |
+
- Appends the generated token IDs to `output_ids`.
|
150 |
+
- If the finish reason is determined, decodes the output IDs to text
|
151 |
+
and prepares the final response.
|
152 |
+
|
153 |
+
The final response includes the generated text, finish reason,
|
154 |
+
completion tokens, prompt tokens, and total tokens.
|
155 |
+
|
156 |
+
Parameters
|
157 |
+
----------
|
158 |
+
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
|
159 |
+
|
160 |
+
Returns
|
161 |
+
-------
|
162 |
+
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
|
163 |
+
"""
|
164 |
+
# convert input_ids to numpy, with shape [1, sequence_length]
|
165 |
+
input_ids = input_ids.cpu().numpy()
|
166 |
+
max_tokens = 512
|
167 |
+
input_dict = {
|
168 |
+
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
|
169 |
+
"end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
|
170 |
+
"pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
|
171 |
+
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
|
172 |
+
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
|
173 |
+
"runtime_top_k": np.array([[50]], dtype=np.int32),
|
174 |
+
"temperature": np.array([[0.8]], dtype=np.float32),
|
175 |
+
"input_ids": input_ids,
|
176 |
+
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
|
177 |
+
}
|
178 |
+
|
179 |
+
# Convert inputs to Triton tensors
|
180 |
+
input_tensor_list = [
|
181 |
+
pb_utils.Tensor(k, v) for k, v in input_dict.items()
|
182 |
+
]
|
183 |
+
|
184 |
+
# Create and execute inference request
|
185 |
+
llm_request = pb_utils.InferenceRequest(
|
186 |
+
model_name="tensorrt_llm",
|
187 |
+
requested_output_names=["output_ids", "sequence_length"],
|
188 |
+
inputs=input_tensor_list,
|
189 |
+
)
|
190 |
+
|
191 |
+
llm_responses = llm_request.exec(decoupled=self.decoupled)
|
192 |
+
if self.decoupled:
|
193 |
+
for llm_response in llm_responses:
|
194 |
+
if llm_response.has_error():
|
195 |
+
raise pb_utils.TritonModelException(llm_response.error().message())
|
196 |
+
|
197 |
+
# Extract and process output
|
198 |
+
output_ids = pb_utils.get_output_tensor_by_name(
|
199 |
+
llm_response, "output_ids").as_numpy()
|
200 |
+
seq_lens = pb_utils.get_output_tensor_by_name(
|
201 |
+
llm_response, "sequence_length").as_numpy()
|
202 |
+
|
203 |
+
# Get actual output IDs up to the sequence length
|
204 |
+
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
205 |
+
|
206 |
+
yield actual_output_ids
|
207 |
+
else:
|
208 |
+
llm_response = llm_responses
|
209 |
+
if llm_response.has_error():
|
210 |
+
raise pb_utils.TritonModelException(llm_response.error().message())
|
211 |
+
|
212 |
+
# Extract and process output
|
213 |
+
output_ids = pb_utils.get_output_tensor_by_name(
|
214 |
+
llm_response, "output_ids").as_numpy()
|
215 |
+
seq_lens = pb_utils.get_output_tensor_by_name(
|
216 |
+
llm_response, "sequence_length").as_numpy()
|
217 |
+
|
218 |
+
# Get actual output IDs up to the sequence length
|
219 |
+
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
|
220 |
+
|
221 |
+
yield actual_output_ids
|
222 |
+
|
223 |
+
def forward_audio_tokenizer(self, wav, wav_len):
|
224 |
+
"""Forward pass through the audio tokenizer component.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
wav: Input waveform tensor
|
228 |
+
wav_len: Waveform length tensor
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
Tuple of global and semantic tokens
|
232 |
+
"""
|
233 |
+
inference_request = pb_utils.InferenceRequest(
|
234 |
+
model_name='audio_tokenizer',
|
235 |
+
requested_output_names=['global_tokens', 'semantic_tokens'],
|
236 |
+
inputs=[wav, wav_len]
|
237 |
+
)
|
238 |
+
|
239 |
+
inference_response = inference_request.exec()
|
240 |
+
if inference_response.has_error():
|
241 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
242 |
+
|
243 |
+
# Extract and convert output tensors
|
244 |
+
global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
|
245 |
+
global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
|
246 |
+
|
247 |
+
semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
|
248 |
+
semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
|
249 |
+
|
250 |
+
return global_tokens, semantic_tokens
|
251 |
+
|
252 |
+
def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
|
253 |
+
"""Forward pass through the vocoder component.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
global_token_ids: Global token IDs tensor
|
257 |
+
pred_semantic_ids: Predicted semantic token IDs tensor
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
Generated waveform tensor
|
261 |
+
"""
|
262 |
+
# Convert tensors to Triton format
|
263 |
+
global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
|
264 |
+
pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
|
265 |
+
|
266 |
+
# Create and execute inference request
|
267 |
+
inference_request = pb_utils.InferenceRequest(
|
268 |
+
model_name='vocoder',
|
269 |
+
requested_output_names=['waveform'],
|
270 |
+
inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
|
271 |
+
)
|
272 |
+
|
273 |
+
inference_response = inference_request.exec()
|
274 |
+
if inference_response.has_error():
|
275 |
+
raise pb_utils.TritonModelException(inference_response.error().message())
|
276 |
+
|
277 |
+
# Extract and convert output waveform
|
278 |
+
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
|
279 |
+
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
|
280 |
+
|
281 |
+
return waveform
|
282 |
+
|
283 |
+
def token2wav(self, generated_token_ids, global_token_ids):
|
284 |
+
# Decode and extract semantic token IDs from generated text
|
285 |
+
predicted_text = self.tokenizer.batch_decode(
|
286 |
+
[generated_token_ids],
|
287 |
+
skip_special_tokens=True,
|
288 |
+
)[0]
|
289 |
+
pred_semantic_ids = (
|
290 |
+
torch.tensor(
|
291 |
+
[int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]
|
292 |
+
)
|
293 |
+
.unsqueeze(0)
|
294 |
+
.to(torch.int32)
|
295 |
+
)
|
296 |
+
|
297 |
+
# Generate audio with vocoder
|
298 |
+
audio = self.forward_vocoder(
|
299 |
+
global_token_ids.to(self.device),
|
300 |
+
pred_semantic_ids.to(self.device),
|
301 |
+
)
|
302 |
+
|
303 |
+
return audio
|
304 |
+
|
305 |
+
def execute(self, requests):
|
306 |
+
"""Execute inference on the batched requests.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
requests: List of inference requests
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
List of inference responses containing generated audio
|
313 |
+
"""
|
314 |
+
responses = []
|
315 |
+
|
316 |
+
for request in requests:
|
317 |
+
# Extract input tensors
|
318 |
+
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
|
319 |
+
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
|
320 |
+
|
321 |
+
# Process reference audio through audio tokenizer
|
322 |
+
global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
|
323 |
+
|
324 |
+
# Extract text inputs
|
325 |
+
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
|
326 |
+
reference_text = reference_text[0][0].decode('utf-8')
|
327 |
+
|
328 |
+
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
|
329 |
+
target_text = target_text[0][0].decode('utf-8')
|
330 |
+
|
331 |
+
# Prepare prompt for LLM
|
332 |
+
prompt, global_token_ids = process_prompt(
|
333 |
+
text=target_text,
|
334 |
+
prompt_text=reference_text,
|
335 |
+
global_token_ids=global_tokens,
|
336 |
+
semantic_token_ids=semantic_tokens,
|
337 |
+
)
|
338 |
+
|
339 |
+
|
340 |
+
# Tokenize prompt for LLM
|
341 |
+
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
342 |
+
input_ids = model_inputs.input_ids.to(torch.int32)
|
343 |
+
|
344 |
+
# Generate semantic tokens with LLM
|
345 |
+
generated_ids_iter = self.forward_llm(input_ids)
|
346 |
+
|
347 |
+
if self.decoupled:
|
348 |
+
response_sender = request.get_response_sender()
|
349 |
+
request_id = request.request_id()
|
350 |
+
semantic_token_ids_arr = []
|
351 |
+
max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate)
|
352 |
+
chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate)
|
353 |
+
overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate)
|
354 |
+
self.logger.log_info(
|
355 |
+
f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}"
|
356 |
+
)
|
357 |
+
for generated_ids in generated_ids_iter:
|
358 |
+
if generated_ids is None or len(generated_ids) == 0:
|
359 |
+
break
|
360 |
+
|
361 |
+
semantic_token_ids_arr.append(generated_ids)
|
362 |
+
if len(semantic_token_ids_arr) >= chunk_size:
|
363 |
+
chunk = semantic_token_ids_arr[:chunk_size]
|
364 |
+
generated_semantic_token_ids = np.hstack(chunk)
|
365 |
+
# Process each chunk
|
366 |
+
sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
|
367 |
+
# Prepare response to send
|
368 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
369 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
370 |
+
response_sender.send(inference_response)
|
371 |
+
|
372 |
+
semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:]
|
373 |
+
# increase chunk size for better speech quality
|
374 |
+
chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor))
|
375 |
+
self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}")
|
376 |
+
|
377 |
+
if len(semantic_token_ids_arr) > 0: # end to finalize
|
378 |
+
generated_semantic_token_ids = np.hstack(semantic_token_ids_arr)
|
379 |
+
# Process each chunk
|
380 |
+
sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
|
381 |
+
# Prepare response to send
|
382 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
|
383 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
384 |
+
response_sender.send(inference_response)
|
385 |
+
self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}")
|
386 |
+
else:
|
387 |
+
generated_ids = next(generated_ids_iter)
|
388 |
+
if generated_ids is None or len(generated_ids) == 0:
|
389 |
+
raise pb_utils.TritonModelException("Generated IDs is None or empty")
|
390 |
+
|
391 |
+
audio = self.token2wav(generated_ids, global_token_ids)
|
392 |
+
|
393 |
+
# Prepare response
|
394 |
+
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
395 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
|
396 |
+
responses.append(inference_response)
|
397 |
+
|
398 |
+
if self.decoupled:
|
399 |
+
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
|
400 |
+
self.logger.log_info(f"send tritonserver_response_complete_final to end")
|
401 |
+
|
402 |
+
if not self.decoupled:
|
403 |
+
return responses
|
404 |
+
|
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "spark_tts"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
model_transaction_policy {
|
22 |
+
decoupled: ${decoupled_mode}
|
23 |
+
}
|
24 |
+
parameters [
|
25 |
+
{
|
26 |
+
key: "llm_tokenizer_dir",
|
27 |
+
value: {string_value:"${llm_tokenizer_dir}"}
|
28 |
+
},
|
29 |
+
{
|
30 |
+
key: "audio_chunk_duration",
|
31 |
+
value: {string_value:"${audio_chunk_duration}"}
|
32 |
+
},
|
33 |
+
{
|
34 |
+
key: "audio_chunk_size_scale_factor",
|
35 |
+
value: {string_value:"${audio_chunk_size_scale_factor}"}
|
36 |
+
},
|
37 |
+
{
|
38 |
+
key: "max_audio_chunk_duration",
|
39 |
+
value: {string_value:"${max_audio_chunk_duration}"}
|
40 |
+
},
|
41 |
+
{
|
42 |
+
key: "audio_chunk_overlap_duration",
|
43 |
+
value: {string_value:"${audio_chunk_overlap_duration}"}
|
44 |
+
},
|
45 |
+
{
|
46 |
+
key: "audio_tokenizer_frame_rate",
|
47 |
+
value: {string_value:"50"}
|
48 |
+
}
|
49 |
+
]
|
50 |
+
|
51 |
+
input [
|
52 |
+
{
|
53 |
+
name: "reference_wav"
|
54 |
+
data_type: TYPE_FP32
|
55 |
+
dims: [-1]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
name: "reference_wav_len"
|
59 |
+
data_type: TYPE_INT32
|
60 |
+
dims: [1]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
name: "reference_text"
|
64 |
+
data_type: TYPE_STRING
|
65 |
+
dims: [1]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
name: "target_text"
|
69 |
+
data_type: TYPE_STRING
|
70 |
+
dims: [1]
|
71 |
+
}
|
72 |
+
]
|
73 |
+
output [
|
74 |
+
{
|
75 |
+
name: "waveform"
|
76 |
+
data_type: TYPE_FP32
|
77 |
+
dims: [ -1 ]
|
78 |
+
}
|
79 |
+
]
|
80 |
+
|
81 |
+
instance_group [
|
82 |
+
{
|
83 |
+
count: ${bls_instance_num}
|
84 |
+
kind: KIND_CPU
|
85 |
+
}
|
86 |
+
]
|
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep
ADDED
File without changes
|
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt
ADDED
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
name: "tensorrt_llm"
|
28 |
+
backend: "${triton_backend}"
|
29 |
+
max_batch_size: ${triton_max_batch_size}
|
30 |
+
|
31 |
+
model_transaction_policy {
|
32 |
+
decoupled: ${decoupled_mode}
|
33 |
+
}
|
34 |
+
|
35 |
+
dynamic_batching {
|
36 |
+
preferred_batch_size: [ ${triton_max_batch_size} ]
|
37 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
38 |
+
default_queue_policy: { max_queue_size: ${max_queue_size} }
|
39 |
+
}
|
40 |
+
|
41 |
+
input [
|
42 |
+
{
|
43 |
+
name: "input_ids"
|
44 |
+
data_type: TYPE_INT32
|
45 |
+
dims: [ -1 ]
|
46 |
+
allow_ragged_batch: true
|
47 |
+
optional: true
|
48 |
+
},
|
49 |
+
{
|
50 |
+
name: "encoder_input_features"
|
51 |
+
data_type: ${encoder_input_features_data_type}
|
52 |
+
dims: [ -1, -1 ]
|
53 |
+
allow_ragged_batch: true
|
54 |
+
optional: true
|
55 |
+
},
|
56 |
+
{
|
57 |
+
name: "encoder_output_lengths"
|
58 |
+
data_type: TYPE_INT32
|
59 |
+
dims: [ 1 ]
|
60 |
+
reshape: { shape: [ ] }
|
61 |
+
optional: true
|
62 |
+
},
|
63 |
+
{
|
64 |
+
name: "input_lengths"
|
65 |
+
data_type: TYPE_INT32
|
66 |
+
dims: [ 1 ]
|
67 |
+
reshape: { shape: [ ] }
|
68 |
+
},
|
69 |
+
{
|
70 |
+
name: "request_output_len"
|
71 |
+
data_type: TYPE_INT32
|
72 |
+
dims: [ 1 ]
|
73 |
+
reshape: { shape: [ ] }
|
74 |
+
},
|
75 |
+
{
|
76 |
+
name: "num_return_sequences"
|
77 |
+
data_type: TYPE_INT32
|
78 |
+
dims: [ 1 ]
|
79 |
+
reshape: { shape: [ ] }
|
80 |
+
optional: true
|
81 |
+
},
|
82 |
+
{
|
83 |
+
name: "draft_input_ids"
|
84 |
+
data_type: TYPE_INT32
|
85 |
+
dims: [ -1 ]
|
86 |
+
optional: true
|
87 |
+
allow_ragged_batch: true
|
88 |
+
},
|
89 |
+
{
|
90 |
+
name: "decoder_input_ids"
|
91 |
+
data_type: TYPE_INT32
|
92 |
+
dims: [ -1 ]
|
93 |
+
optional: true
|
94 |
+
allow_ragged_batch: true
|
95 |
+
},
|
96 |
+
{
|
97 |
+
name: "decoder_input_lengths"
|
98 |
+
data_type: TYPE_INT32
|
99 |
+
dims: [ 1 ]
|
100 |
+
optional: true
|
101 |
+
reshape: { shape: [ ] }
|
102 |
+
},
|
103 |
+
{
|
104 |
+
name: "draft_logits"
|
105 |
+
data_type: ${logits_datatype}
|
106 |
+
dims: [ -1, -1 ]
|
107 |
+
optional: true
|
108 |
+
allow_ragged_batch: true
|
109 |
+
},
|
110 |
+
{
|
111 |
+
name: "draft_acceptance_threshold"
|
112 |
+
data_type: TYPE_FP32
|
113 |
+
dims: [ 1 ]
|
114 |
+
reshape: { shape: [ ] }
|
115 |
+
optional: true
|
116 |
+
},
|
117 |
+
{
|
118 |
+
name: "end_id"
|
119 |
+
data_type: TYPE_INT32
|
120 |
+
dims: [ 1 ]
|
121 |
+
reshape: { shape: [ ] }
|
122 |
+
optional: true
|
123 |
+
},
|
124 |
+
{
|
125 |
+
name: "pad_id"
|
126 |
+
data_type: TYPE_INT32
|
127 |
+
dims: [ 1 ]
|
128 |
+
reshape: { shape: [ ] }
|
129 |
+
optional: true
|
130 |
+
},
|
131 |
+
{
|
132 |
+
name: "stop_words_list"
|
133 |
+
data_type: TYPE_INT32
|
134 |
+
dims: [ 2, -1 ]
|
135 |
+
optional: true
|
136 |
+
allow_ragged_batch: true
|
137 |
+
},
|
138 |
+
{
|
139 |
+
name: "bad_words_list"
|
140 |
+
data_type: TYPE_INT32
|
141 |
+
dims: [ 2, -1 ]
|
142 |
+
optional: true
|
143 |
+
allow_ragged_batch: true
|
144 |
+
},
|
145 |
+
{
|
146 |
+
name: "embedding_bias"
|
147 |
+
data_type: TYPE_FP32
|
148 |
+
dims: [ -1 ]
|
149 |
+
optional: true
|
150 |
+
allow_ragged_batch: true
|
151 |
+
},
|
152 |
+
{
|
153 |
+
name: "beam_width"
|
154 |
+
data_type: TYPE_INT32
|
155 |
+
dims: [ 1 ]
|
156 |
+
reshape: { shape: [ ] }
|
157 |
+
optional: true
|
158 |
+
},
|
159 |
+
{
|
160 |
+
name: "temperature"
|
161 |
+
data_type: TYPE_FP32
|
162 |
+
dims: [ 1 ]
|
163 |
+
reshape: { shape: [ ] }
|
164 |
+
optional: true
|
165 |
+
},
|
166 |
+
{
|
167 |
+
name: "runtime_top_k"
|
168 |
+
data_type: TYPE_INT32
|
169 |
+
dims: [ 1 ]
|
170 |
+
reshape: { shape: [ ] }
|
171 |
+
optional: true
|
172 |
+
},
|
173 |
+
{
|
174 |
+
name: "runtime_top_p"
|
175 |
+
data_type: TYPE_FP32
|
176 |
+
dims: [ 1 ]
|
177 |
+
reshape: { shape: [ ] }
|
178 |
+
optional: true
|
179 |
+
},
|
180 |
+
{
|
181 |
+
name: "runtime_top_p_min"
|
182 |
+
data_type: TYPE_FP32
|
183 |
+
dims: [ 1 ]
|
184 |
+
reshape: { shape: [ ] }
|
185 |
+
optional: true
|
186 |
+
},
|
187 |
+
{
|
188 |
+
name: "runtime_top_p_decay"
|
189 |
+
data_type: TYPE_FP32
|
190 |
+
dims: [ 1 ]
|
191 |
+
reshape: { shape: [ ] }
|
192 |
+
optional: true
|
193 |
+
},
|
194 |
+
{
|
195 |
+
name: "runtime_top_p_reset_ids"
|
196 |
+
data_type: TYPE_INT32
|
197 |
+
dims: [ 1 ]
|
198 |
+
reshape: { shape: [ ] }
|
199 |
+
optional: true
|
200 |
+
},
|
201 |
+
{
|
202 |
+
name: "len_penalty"
|
203 |
+
data_type: TYPE_FP32
|
204 |
+
dims: [ 1 ]
|
205 |
+
reshape: { shape: [ ] }
|
206 |
+
optional: true
|
207 |
+
},
|
208 |
+
{
|
209 |
+
name: "early_stopping"
|
210 |
+
data_type: TYPE_BOOL
|
211 |
+
dims: [ 1 ]
|
212 |
+
reshape: { shape: [ ] }
|
213 |
+
optional: true
|
214 |
+
},
|
215 |
+
{
|
216 |
+
name: "repetition_penalty"
|
217 |
+
data_type: TYPE_FP32
|
218 |
+
dims: [ 1 ]
|
219 |
+
reshape: { shape: [ ] }
|
220 |
+
optional: true
|
221 |
+
},
|
222 |
+
{
|
223 |
+
name: "min_length"
|
224 |
+
data_type: TYPE_INT32
|
225 |
+
dims: [ 1 ]
|
226 |
+
reshape: { shape: [ ] }
|
227 |
+
optional: true
|
228 |
+
},
|
229 |
+
{
|
230 |
+
name: "beam_search_diversity_rate"
|
231 |
+
data_type: TYPE_FP32
|
232 |
+
dims: [ 1 ]
|
233 |
+
reshape: { shape: [ ] }
|
234 |
+
optional: true
|
235 |
+
},
|
236 |
+
{
|
237 |
+
name: "presence_penalty"
|
238 |
+
data_type: TYPE_FP32
|
239 |
+
dims: [ 1 ]
|
240 |
+
reshape: { shape: [ ] }
|
241 |
+
optional: true
|
242 |
+
},
|
243 |
+
{
|
244 |
+
name: "frequency_penalty"
|
245 |
+
data_type: TYPE_FP32
|
246 |
+
dims: [ 1 ]
|
247 |
+
reshape: { shape: [ ] }
|
248 |
+
optional: true
|
249 |
+
},
|
250 |
+
{
|
251 |
+
name: "random_seed"
|
252 |
+
data_type: TYPE_UINT64
|
253 |
+
dims: [ 1 ]
|
254 |
+
reshape: { shape: [ ] }
|
255 |
+
optional: true
|
256 |
+
},
|
257 |
+
{
|
258 |
+
name: "return_log_probs"
|
259 |
+
data_type: TYPE_BOOL
|
260 |
+
dims: [ 1 ]
|
261 |
+
reshape: { shape: [ ] }
|
262 |
+
optional: true
|
263 |
+
},
|
264 |
+
{
|
265 |
+
name: "return_context_logits"
|
266 |
+
data_type: TYPE_BOOL
|
267 |
+
dims: [ 1 ]
|
268 |
+
reshape: { shape: [ ] }
|
269 |
+
optional: true
|
270 |
+
},
|
271 |
+
{
|
272 |
+
name: "return_generation_logits"
|
273 |
+
data_type: TYPE_BOOL
|
274 |
+
dims: [ 1 ]
|
275 |
+
reshape: { shape: [ ] }
|
276 |
+
optional: true
|
277 |
+
},
|
278 |
+
{
|
279 |
+
name: "return_perf_metrics"
|
280 |
+
data_type: TYPE_BOOL
|
281 |
+
dims: [ 1 ]
|
282 |
+
reshape: { shape: [ ] }
|
283 |
+
optional: true
|
284 |
+
},
|
285 |
+
{
|
286 |
+
name: "exclude_input_in_output"
|
287 |
+
data_type: TYPE_BOOL
|
288 |
+
dims: [ 1 ]
|
289 |
+
reshape: { shape: [ ] }
|
290 |
+
optional: true
|
291 |
+
},
|
292 |
+
{
|
293 |
+
name: "stop"
|
294 |
+
data_type: TYPE_BOOL
|
295 |
+
dims: [ 1 ]
|
296 |
+
reshape: { shape: [ ] }
|
297 |
+
optional: true
|
298 |
+
},
|
299 |
+
{
|
300 |
+
name: "streaming"
|
301 |
+
data_type: TYPE_BOOL
|
302 |
+
dims: [ 1 ]
|
303 |
+
reshape: { shape: [ ] }
|
304 |
+
optional: true
|
305 |
+
},
|
306 |
+
{
|
307 |
+
name: "prompt_embedding_table"
|
308 |
+
data_type: TYPE_FP16
|
309 |
+
dims: [ -1, -1 ]
|
310 |
+
optional: true
|
311 |
+
allow_ragged_batch: true
|
312 |
+
},
|
313 |
+
{
|
314 |
+
name: "prompt_table_extra_ids"
|
315 |
+
data_type: TYPE_UINT64
|
316 |
+
dims: [ -1 ]
|
317 |
+
optional: true
|
318 |
+
allow_ragged_batch: true
|
319 |
+
},
|
320 |
+
{
|
321 |
+
name: "prompt_vocab_size"
|
322 |
+
data_type: TYPE_INT32
|
323 |
+
dims: [ 1 ]
|
324 |
+
reshape: { shape: [ ] }
|
325 |
+
optional: true
|
326 |
+
},
|
327 |
+
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
|
328 |
+
{
|
329 |
+
name: "cross_attention_mask"
|
330 |
+
data_type: TYPE_BOOL
|
331 |
+
dims: [ -1, -1 ]
|
332 |
+
optional: true
|
333 |
+
allow_ragged_batch: true
|
334 |
+
},
|
335 |
+
# Mrope param when mrope is used
|
336 |
+
{
|
337 |
+
name: "mrope_rotary_cos_sin"
|
338 |
+
data_type: TYPE_FP32
|
339 |
+
dims: [ -1 ]
|
340 |
+
optional: true
|
341 |
+
},
|
342 |
+
{
|
343 |
+
name: "mrope_position_deltas"
|
344 |
+
data_type: TYPE_INT64
|
345 |
+
dims: [ 1 ]
|
346 |
+
optional: true
|
347 |
+
},
|
348 |
+
# the unique task ID for the given LoRA.
|
349 |
+
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
|
350 |
+
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
|
351 |
+
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
|
352 |
+
{
|
353 |
+
name: "lora_task_id"
|
354 |
+
data_type: TYPE_UINT64
|
355 |
+
dims: [ 1 ]
|
356 |
+
reshape: { shape: [ ] }
|
357 |
+
optional: true
|
358 |
+
},
|
359 |
+
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
|
360 |
+
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
|
361 |
+
# each of the in / out tensors are first flattened and then concatenated together in the format above.
|
362 |
+
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
|
363 |
+
{
|
364 |
+
name: "lora_weights"
|
365 |
+
data_type: TYPE_FP16
|
366 |
+
dims: [ -1, -1 ]
|
367 |
+
optional: true
|
368 |
+
allow_ragged_batch: true
|
369 |
+
},
|
370 |
+
# module identifier (same size a first dimension of lora_weights)
|
371 |
+
# See LoraModule::ModuleType for model id mapping
|
372 |
+
#
|
373 |
+
# "attn_qkv": 0 # compbined qkv adapter
|
374 |
+
# "attn_q": 1 # q adapter
|
375 |
+
# "attn_k": 2 # k adapter
|
376 |
+
# "attn_v": 3 # v adapter
|
377 |
+
# "attn_dense": 4 # adapter for the dense layer in attention
|
378 |
+
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
|
379 |
+
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
|
380 |
+
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
|
381 |
+
#
|
382 |
+
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
|
383 |
+
{
|
384 |
+
name: "lora_config"
|
385 |
+
data_type: TYPE_INT32
|
386 |
+
dims: [ -1, 3 ]
|
387 |
+
optional: true
|
388 |
+
allow_ragged_batch: true
|
389 |
+
},
|
390 |
+
{
|
391 |
+
name: "context_phase_params"
|
392 |
+
data_type: TYPE_UINT8
|
393 |
+
dims: [ -1 ]
|
394 |
+
optional: true
|
395 |
+
allow_ragged_batch: true
|
396 |
+
},
|
397 |
+
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
|
398 |
+
{
|
399 |
+
name: "skip_cross_attn_blocks"
|
400 |
+
data_type: TYPE_BOOL
|
401 |
+
dims: [ 1 ]
|
402 |
+
optional: true
|
403 |
+
allow_ragged_batch: true
|
404 |
+
},
|
405 |
+
{
|
406 |
+
name: "retention_token_range_starts"
|
407 |
+
data_type: TYPE_INT32
|
408 |
+
dims: [ -1 ]
|
409 |
+
optional: true
|
410 |
+
allow_ragged_batch: true
|
411 |
+
},
|
412 |
+
{
|
413 |
+
name: "retention_token_range_ends"
|
414 |
+
data_type: TYPE_INT32
|
415 |
+
dims: [ -1 ]
|
416 |
+
optional: true
|
417 |
+
allow_ragged_batch: true
|
418 |
+
},
|
419 |
+
{
|
420 |
+
name: "retention_token_range_priorities"
|
421 |
+
data_type: TYPE_INT32
|
422 |
+
dims: [ -1 ]
|
423 |
+
optional: true
|
424 |
+
allow_ragged_batch: true
|
425 |
+
},
|
426 |
+
{
|
427 |
+
name: "retention_token_range_durations_ms"
|
428 |
+
data_type: TYPE_INT32
|
429 |
+
dims: [ -1 ]
|
430 |
+
optional: true
|
431 |
+
allow_ragged_batch: true
|
432 |
+
},
|
433 |
+
{
|
434 |
+
name: "retention_decode_priority"
|
435 |
+
data_type: TYPE_INT32
|
436 |
+
dims: [ 1 ]
|
437 |
+
optional: true
|
438 |
+
allow_ragged_batch: true
|
439 |
+
},
|
440 |
+
{
|
441 |
+
name: "retention_decode_duration_ms"
|
442 |
+
data_type: TYPE_INT32
|
443 |
+
dims: [ 1 ]
|
444 |
+
optional: true
|
445 |
+
allow_ragged_batch: true
|
446 |
+
},
|
447 |
+
{
|
448 |
+
name: "guided_decoding_guide_type"
|
449 |
+
data_type: TYPE_STRING
|
450 |
+
dims: [ 1 ]
|
451 |
+
optional: true
|
452 |
+
allow_ragged_batch: true
|
453 |
+
},
|
454 |
+
{
|
455 |
+
name: "guided_decoding_guide"
|
456 |
+
data_type: TYPE_STRING
|
457 |
+
dims: [ 1 ]
|
458 |
+
optional: true
|
459 |
+
allow_ragged_batch: true
|
460 |
+
},
|
461 |
+
{
|
462 |
+
name: "lookahead_window_size"
|
463 |
+
data_type: TYPE_INT32
|
464 |
+
dims: [ 1 ]
|
465 |
+
optional: true
|
466 |
+
allow_ragged_batch: true
|
467 |
+
},
|
468 |
+
{
|
469 |
+
name: "lookahead_ngram_size"
|
470 |
+
data_type: TYPE_INT32
|
471 |
+
dims: [ 1 ]
|
472 |
+
optional: true
|
473 |
+
allow_ragged_batch: true
|
474 |
+
},
|
475 |
+
{
|
476 |
+
name: "lookahead_verification_set_size"
|
477 |
+
data_type: TYPE_INT32
|
478 |
+
dims: [ 1 ]
|
479 |
+
optional: true
|
480 |
+
allow_ragged_batch: true
|
481 |
+
}
|
482 |
+
]
|
483 |
+
output [
|
484 |
+
{
|
485 |
+
name: "output_ids"
|
486 |
+
data_type: TYPE_INT32
|
487 |
+
dims: [ -1, -1 ]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
name: "sequence_length"
|
491 |
+
data_type: TYPE_INT32
|
492 |
+
dims: [ -1 ]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
name: "cum_log_probs"
|
496 |
+
data_type: TYPE_FP32
|
497 |
+
dims: [ -1 ]
|
498 |
+
},
|
499 |
+
{
|
500 |
+
name: "output_log_probs"
|
501 |
+
data_type: TYPE_FP32
|
502 |
+
dims: [ -1, -1 ]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
name: "context_logits"
|
506 |
+
data_type: ${logits_datatype}
|
507 |
+
dims: [ -1, -1 ]
|
508 |
+
},
|
509 |
+
{
|
510 |
+
name: "generation_logits"
|
511 |
+
data_type: ${logits_datatype}
|
512 |
+
dims: [ -1, -1, -1 ]
|
513 |
+
},
|
514 |
+
{
|
515 |
+
name: "batch_index"
|
516 |
+
data_type: TYPE_INT32
|
517 |
+
dims: [ 1 ]
|
518 |
+
},
|
519 |
+
{
|
520 |
+
name: "sequence_index"
|
521 |
+
data_type: TYPE_INT32
|
522 |
+
dims: [ 1 ]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
name: "context_phase_params"
|
526 |
+
data_type: TYPE_UINT8
|
527 |
+
dims: [ -1 ]
|
528 |
+
},
|
529 |
+
{
|
530 |
+
name: "kv_cache_alloc_new_blocks"
|
531 |
+
data_type: TYPE_INT32
|
532 |
+
dims: [ 1 ]
|
533 |
+
},
|
534 |
+
{
|
535 |
+
name: "kv_cache_reused_blocks"
|
536 |
+
data_type: TYPE_INT32
|
537 |
+
dims: [ 1 ]
|
538 |
+
},
|
539 |
+
{
|
540 |
+
name: "kv_cache_alloc_total_blocks"
|
541 |
+
data_type: TYPE_INT32
|
542 |
+
dims: [ 1 ]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
name: "arrival_time_ns"
|
546 |
+
data_type: TYPE_INT64
|
547 |
+
dims: [ 1 ]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
name: "first_scheduled_time_ns"
|
551 |
+
data_type: TYPE_INT64
|
552 |
+
dims: [ 1 ]
|
553 |
+
},
|
554 |
+
{
|
555 |
+
name: "first_token_time_ns"
|
556 |
+
data_type: TYPE_INT64
|
557 |
+
dims: [ 1 ]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
name: "last_token_time_ns"
|
561 |
+
data_type: TYPE_INT64
|
562 |
+
dims: [ 1 ]
|
563 |
+
},
|
564 |
+
{
|
565 |
+
name: "acceptance_rate"
|
566 |
+
data_type: TYPE_FP32
|
567 |
+
dims: [ 1 ]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
name: "total_accepted_draft_tokens"
|
571 |
+
data_type: TYPE_INT32
|
572 |
+
dims: [ 1 ]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
name: "total_draft_tokens"
|
576 |
+
data_type: TYPE_INT32
|
577 |
+
dims: [ 1 ]
|
578 |
+
}
|
579 |
+
]
|
580 |
+
instance_group [
|
581 |
+
{
|
582 |
+
count: 1
|
583 |
+
kind : KIND_CPU
|
584 |
+
}
|
585 |
+
]
|
586 |
+
parameters: {
|
587 |
+
key: "max_beam_width"
|
588 |
+
value: {
|
589 |
+
string_value: "${max_beam_width}"
|
590 |
+
}
|
591 |
+
}
|
592 |
+
parameters: {
|
593 |
+
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
|
594 |
+
value: {
|
595 |
+
string_value: "no"
|
596 |
+
}
|
597 |
+
}
|
598 |
+
parameters: {
|
599 |
+
key: "gpt_model_type"
|
600 |
+
value: {
|
601 |
+
string_value: "${batching_strategy}"
|
602 |
+
}
|
603 |
+
}
|
604 |
+
parameters: {
|
605 |
+
key: "gpt_model_path"
|
606 |
+
value: {
|
607 |
+
string_value: "${engine_dir}"
|
608 |
+
}
|
609 |
+
}
|
610 |
+
parameters: {
|
611 |
+
key: "encoder_model_path"
|
612 |
+
value: {
|
613 |
+
string_value: "${encoder_engine_dir}"
|
614 |
+
}
|
615 |
+
}
|
616 |
+
parameters: {
|
617 |
+
key: "max_tokens_in_paged_kv_cache"
|
618 |
+
value: {
|
619 |
+
string_value: "${max_tokens_in_paged_kv_cache}"
|
620 |
+
}
|
621 |
+
}
|
622 |
+
parameters: {
|
623 |
+
key: "max_attention_window_size"
|
624 |
+
value: {
|
625 |
+
string_value: "${max_attention_window_size}"
|
626 |
+
}
|
627 |
+
}
|
628 |
+
parameters: {
|
629 |
+
key: "sink_token_length"
|
630 |
+
value: {
|
631 |
+
string_value: "${sink_token_length}"
|
632 |
+
}
|
633 |
+
}
|
634 |
+
parameters: {
|
635 |
+
key: "batch_scheduler_policy"
|
636 |
+
value: {
|
637 |
+
string_value: "${batch_scheduler_policy}"
|
638 |
+
}
|
639 |
+
}
|
640 |
+
parameters: {
|
641 |
+
key: "kv_cache_free_gpu_mem_fraction"
|
642 |
+
value: {
|
643 |
+
string_value: "${kv_cache_free_gpu_mem_fraction}"
|
644 |
+
}
|
645 |
+
}
|
646 |
+
parameters: {
|
647 |
+
key: "cross_kv_cache_fraction"
|
648 |
+
value: {
|
649 |
+
string_value: "${cross_kv_cache_fraction}"
|
650 |
+
}
|
651 |
+
}
|
652 |
+
parameters: {
|
653 |
+
key: "kv_cache_host_memory_bytes"
|
654 |
+
value: {
|
655 |
+
string_value: "${kv_cache_host_memory_bytes}"
|
656 |
+
}
|
657 |
+
}
|
658 |
+
# kv_cache_onboard_blocks is for internal implementation.
|
659 |
+
parameters: {
|
660 |
+
key: "kv_cache_onboard_blocks"
|
661 |
+
value: {
|
662 |
+
string_value: "${kv_cache_onboard_blocks}"
|
663 |
+
}
|
664 |
+
}
|
665 |
+
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
|
666 |
+
# parameters: {
|
667 |
+
# key: "enable_trt_overlap"
|
668 |
+
# value: {
|
669 |
+
# string_value: "${enable_trt_overlap}"
|
670 |
+
# }
|
671 |
+
# }
|
672 |
+
parameters: {
|
673 |
+
key: "exclude_input_in_output"
|
674 |
+
value: {
|
675 |
+
string_value: "${exclude_input_in_output}"
|
676 |
+
}
|
677 |
+
}
|
678 |
+
parameters: {
|
679 |
+
key: "cancellation_check_period_ms"
|
680 |
+
value: {
|
681 |
+
string_value: "${cancellation_check_period_ms}"
|
682 |
+
}
|
683 |
+
}
|
684 |
+
parameters: {
|
685 |
+
key: "stats_check_period_ms"
|
686 |
+
value: {
|
687 |
+
string_value: "${stats_check_period_ms}"
|
688 |
+
}
|
689 |
+
}
|
690 |
+
parameters: {
|
691 |
+
key: "iter_stats_max_iterations"
|
692 |
+
value: {
|
693 |
+
string_value: "${iter_stats_max_iterations}"
|
694 |
+
}
|
695 |
+
}
|
696 |
+
parameters: {
|
697 |
+
key: "request_stats_max_iterations"
|
698 |
+
value: {
|
699 |
+
string_value: "${request_stats_max_iterations}"
|
700 |
+
}
|
701 |
+
}
|
702 |
+
parameters: {
|
703 |
+
key: "enable_kv_cache_reuse"
|
704 |
+
value: {
|
705 |
+
string_value: "${enable_kv_cache_reuse}"
|
706 |
+
}
|
707 |
+
}
|
708 |
+
parameters: {
|
709 |
+
key: "normalize_log_probs"
|
710 |
+
value: {
|
711 |
+
string_value: "${normalize_log_probs}"
|
712 |
+
}
|
713 |
+
}
|
714 |
+
parameters: {
|
715 |
+
key: "enable_chunked_context"
|
716 |
+
value: {
|
717 |
+
string_value: "${enable_chunked_context}"
|
718 |
+
}
|
719 |
+
}
|
720 |
+
parameters: {
|
721 |
+
key: "gpu_device_ids"
|
722 |
+
value: {
|
723 |
+
string_value: "${gpu_device_ids}"
|
724 |
+
}
|
725 |
+
}
|
726 |
+
parameters: {
|
727 |
+
key: "participant_ids"
|
728 |
+
value: {
|
729 |
+
string_value: "${participant_ids}"
|
730 |
+
}
|
731 |
+
}
|
732 |
+
parameters: {
|
733 |
+
key: "lora_cache_optimal_adapter_size"
|
734 |
+
value: {
|
735 |
+
string_value: "${lora_cache_optimal_adapter_size}"
|
736 |
+
}
|
737 |
+
}
|
738 |
+
parameters: {
|
739 |
+
key: "lora_cache_max_adapter_size"
|
740 |
+
value: {
|
741 |
+
string_value: "${lora_cache_max_adapter_size}"
|
742 |
+
}
|
743 |
+
}
|
744 |
+
parameters: {
|
745 |
+
key: "lora_cache_gpu_memory_fraction"
|
746 |
+
value: {
|
747 |
+
string_value: "${lora_cache_gpu_memory_fraction}"
|
748 |
+
}
|
749 |
+
}
|
750 |
+
parameters: {
|
751 |
+
key: "lora_cache_host_memory_bytes"
|
752 |
+
value: {
|
753 |
+
string_value: "${lora_cache_host_memory_bytes}"
|
754 |
+
}
|
755 |
+
}
|
756 |
+
parameters: {
|
757 |
+
key: "lora_prefetch_dir"
|
758 |
+
value: {
|
759 |
+
string_value: "${lora_prefetch_dir}"
|
760 |
+
}
|
761 |
+
}
|
762 |
+
parameters: {
|
763 |
+
key: "decoding_mode"
|
764 |
+
value: {
|
765 |
+
string_value: "${decoding_mode}"
|
766 |
+
}
|
767 |
+
}
|
768 |
+
parameters: {
|
769 |
+
key: "executor_worker_path"
|
770 |
+
value: {
|
771 |
+
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
|
772 |
+
}
|
773 |
+
}
|
774 |
+
parameters: {
|
775 |
+
key: "lookahead_window_size"
|
776 |
+
value: {
|
777 |
+
string_value: "${lookahead_window_size}"
|
778 |
+
}
|
779 |
+
}
|
780 |
+
parameters: {
|
781 |
+
key: "lookahead_ngram_size"
|
782 |
+
value: {
|
783 |
+
string_value: "${lookahead_ngram_size}"
|
784 |
+
}
|
785 |
+
}
|
786 |
+
parameters: {
|
787 |
+
key: "lookahead_verification_set_size"
|
788 |
+
value: {
|
789 |
+
string_value: "${lookahead_verification_set_size}"
|
790 |
+
}
|
791 |
+
}
|
792 |
+
parameters: {
|
793 |
+
key: "medusa_choices"
|
794 |
+
value: {
|
795 |
+
string_value: "${medusa_choices}"
|
796 |
+
}
|
797 |
+
}
|
798 |
+
parameters: {
|
799 |
+
key: "eagle_choices"
|
800 |
+
value: {
|
801 |
+
string_value: "${eagle_choices}"
|
802 |
+
}
|
803 |
+
}
|
804 |
+
parameters: {
|
805 |
+
key: "gpu_weights_percent"
|
806 |
+
value: {
|
807 |
+
string_value: "${gpu_weights_percent}"
|
808 |
+
}
|
809 |
+
}
|
810 |
+
parameters: {
|
811 |
+
key: "enable_context_fmha_fp32_acc"
|
812 |
+
value: {
|
813 |
+
string_value: "${enable_context_fmha_fp32_acc}"
|
814 |
+
}
|
815 |
+
}
|
816 |
+
parameters: {
|
817 |
+
key: "multi_block_mode"
|
818 |
+
value: {
|
819 |
+
string_value: "${multi_block_mode}"
|
820 |
+
}
|
821 |
+
}
|
822 |
+
parameters: {
|
823 |
+
key: "cuda_graph_mode"
|
824 |
+
value: {
|
825 |
+
string_value: "${cuda_graph_mode}"
|
826 |
+
}
|
827 |
+
}
|
828 |
+
parameters: {
|
829 |
+
key: "cuda_graph_cache_size"
|
830 |
+
value: {
|
831 |
+
string_value: "${cuda_graph_cache_size}"
|
832 |
+
}
|
833 |
+
}
|
834 |
+
parameters: {
|
835 |
+
key: "speculative_decoding_fast_logits"
|
836 |
+
value: {
|
837 |
+
string_value: "${speculative_decoding_fast_logits}"
|
838 |
+
}
|
839 |
+
}
|
840 |
+
parameters: {
|
841 |
+
key: "tokenizer_dir"
|
842 |
+
value: {
|
843 |
+
string_value: "${tokenizer_dir}"
|
844 |
+
}
|
845 |
+
}
|
846 |
+
parameters: {
|
847 |
+
key: "guided_decoding_backend"
|
848 |
+
value: {
|
849 |
+
string_value: "${guided_decoding_backend}"
|
850 |
+
}
|
851 |
+
}
|
852 |
+
parameters: {
|
853 |
+
key: "xgrammar_tokenizer_info_path"
|
854 |
+
value: {
|
855 |
+
string_value: "${xgrammar_tokenizer_info_path}"
|
856 |
+
}
|
857 |
+
}
|
runtime/triton_trtllm/model_repo/vocoder/1/model.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Redistribution and use in source and binary forms, with or without
|
4 |
+
# modification, are permitted provided that the following conditions
|
5 |
+
# are met:
|
6 |
+
# * Redistributions of source code must retain the above copyright
|
7 |
+
# notice, this list of conditions and the following disclaimer.
|
8 |
+
# * Redistributions in binary form must reproduce the above copyright
|
9 |
+
# notice, this list of conditions and the following disclaimer in the
|
10 |
+
# documentation and/or other materials provided with the distribution.
|
11 |
+
# * Neither the name of NVIDIA CORPORATION nor the names of its
|
12 |
+
# contributors may be used to endorse or promote products derived
|
13 |
+
# from this software without specific prior written permission.
|
14 |
+
#
|
15 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
16 |
+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
18 |
+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
19 |
+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
20 |
+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
21 |
+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
22 |
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
23 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
24 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
25 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 |
+
|
27 |
+
import json
|
28 |
+
import os
|
29 |
+
import logging
|
30 |
+
from typing import List, Dict
|
31 |
+
|
32 |
+
import torch
|
33 |
+
from torch.utils.dlpack import to_dlpack
|
34 |
+
|
35 |
+
import triton_python_backend_utils as pb_utils
|
36 |
+
|
37 |
+
from sparktts.models.bicodec import BiCodec
|
38 |
+
|
39 |
+
# Configure logging
|
40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
41 |
+
logger = logging.getLogger(__name__)
|
42 |
+
|
43 |
+
class TritonPythonModel:
|
44 |
+
"""Triton Python model for vocoder.
|
45 |
+
|
46 |
+
This model takes global and semantic tokens as input and generates audio waveforms
|
47 |
+
using the BiCodec vocoder.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def initialize(self, args):
|
51 |
+
"""Initialize the model.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
args: Dictionary containing model configuration
|
55 |
+
"""
|
56 |
+
# Parse model parameters
|
57 |
+
parameters = json.loads(args['model_config'])['parameters']
|
58 |
+
model_params = {key: value["string_value"] for key, value in parameters.items()}
|
59 |
+
model_dir = model_params["model_dir"]
|
60 |
+
|
61 |
+
# Initialize device and vocoder
|
62 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
63 |
+
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
|
64 |
+
|
65 |
+
self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
|
66 |
+
del self.vocoder.encoder, self.vocoder.postnet
|
67 |
+
self.vocoder.eval().to(self.device) # Set model to evaluation mode
|
68 |
+
|
69 |
+
logger.info("Vocoder initialized successfully")
|
70 |
+
|
71 |
+
|
72 |
+
def execute(self, requests):
|
73 |
+
"""Execute inference on the batched requests.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
requests: List of inference requests
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
List of inference responses containing generated waveforms
|
80 |
+
"""
|
81 |
+
global_tokens_list, semantic_tokens_list = [], []
|
82 |
+
|
83 |
+
# Process each request in batch
|
84 |
+
for request in requests:
|
85 |
+
global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
|
86 |
+
semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
|
87 |
+
global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
|
88 |
+
semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
|
89 |
+
|
90 |
+
# Concatenate tokens for batch processing
|
91 |
+
global_tokens = torch.cat(global_tokens_list, dim=0)
|
92 |
+
semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
|
93 |
+
|
94 |
+
|
95 |
+
# Generate waveforms
|
96 |
+
with torch.no_grad():
|
97 |
+
wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
|
98 |
+
|
99 |
+
# Prepare responses
|
100 |
+
responses = []
|
101 |
+
for i in range(len(requests)):
|
102 |
+
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
|
103 |
+
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
|
104 |
+
responses.append(inference_response)
|
105 |
+
|
106 |
+
return responses
|
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
name: "vocoder"
|
16 |
+
backend: "python"
|
17 |
+
max_batch_size: ${triton_max_batch_size}
|
18 |
+
dynamic_batching {
|
19 |
+
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
|
20 |
+
}
|
21 |
+
parameters [
|
22 |
+
{
|
23 |
+
key: "model_dir",
|
24 |
+
value: {string_value:"${model_dir}"}
|
25 |
+
}
|
26 |
+
]
|
27 |
+
|
28 |
+
input [
|
29 |
+
{
|
30 |
+
name: "global_tokens"
|
31 |
+
data_type: TYPE_INT32
|
32 |
+
dims: [-1]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
name: "semantic_tokens"
|
36 |
+
data_type: TYPE_INT32
|
37 |
+
dims: [-1]
|
38 |
+
}
|
39 |
+
]
|
40 |
+
output [
|
41 |
+
{
|
42 |
+
name: "waveform"
|
43 |
+
data_type: TYPE_FP32
|
44 |
+
dims: [ -1 ]
|
45 |
+
}
|
46 |
+
]
|
47 |
+
|
48 |
+
instance_group [
|
49 |
+
{
|
50 |
+
count: 1
|
51 |
+
kind: KIND_CPU
|
52 |
+
}
|
53 |
+
]
|
runtime/triton_trtllm/run.sh
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export PYTHONPATH=../../../Spark-TTS/
|
2 |
+
export CUDA_VISIBLE_DEVICES=0
|
3 |
+
stage=$1
|
4 |
+
stop_stage=$2
|
5 |
+
service_type=$3
|
6 |
+
echo "Start stage: $stage, Stop stage: $stop_stage service_type: $service_type"
|
7 |
+
|
8 |
+
huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
|
9 |
+
trt_dtype=bfloat16
|
10 |
+
trt_weights_dir=./tllm_checkpoint_${trt_dtype}
|
11 |
+
trt_engines_dir=./trt_engines_${trt_dtype}
|
12 |
+
|
13 |
+
model_repo=./model_repo_test
|
14 |
+
|
15 |
+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
16 |
+
echo "Downloading Spark-TTS-0.5B from HuggingFace"
|
17 |
+
huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
|
18 |
+
fi
|
19 |
+
|
20 |
+
|
21 |
+
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
22 |
+
echo "Converting checkpoint to TensorRT weights"
|
23 |
+
python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
|
24 |
+
--output_dir $trt_weights_dir \
|
25 |
+
--dtype $trt_dtype || exit 1
|
26 |
+
|
27 |
+
echo "Building TensorRT engines"
|
28 |
+
trtllm-build --checkpoint_dir $trt_weights_dir \
|
29 |
+
--output_dir $trt_engines_dir \
|
30 |
+
--max_batch_size 16 \
|
31 |
+
--max_num_tokens 32768 \
|
32 |
+
--gemm_plugin $trt_dtype || exit 1
|
33 |
+
fi
|
34 |
+
|
35 |
+
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
36 |
+
echo "Creating model repository"
|
37 |
+
rm -rf $model_repo
|
38 |
+
mkdir -p $model_repo
|
39 |
+
spark_tts_dir="spark_tts"
|
40 |
+
|
41 |
+
cp -r ./model_repo/${spark_tts_dir} $model_repo
|
42 |
+
cp -r ./model_repo/audio_tokenizer $model_repo
|
43 |
+
cp -r ./model_repo/tensorrt_llm $model_repo
|
44 |
+
cp -r ./model_repo/vocoder $model_repo
|
45 |
+
|
46 |
+
ENGINE_PATH=$trt_engines_dir
|
47 |
+
MAX_QUEUE_DELAY_MICROSECONDS=0
|
48 |
+
MODEL_DIR=$huggingface_model_local_dir
|
49 |
+
LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
|
50 |
+
BLS_INSTANCE_NUM=4
|
51 |
+
TRITON_MAX_BATCH_SIZE=16
|
52 |
+
# streaming TTS parameters
|
53 |
+
AUDIO_CHUNK_DURATION=1.0
|
54 |
+
MAX_AUDIO_CHUNK_DURATION=30.0
|
55 |
+
AUDIO_CHUNK_SIZE_SCALE_FACTOR=8.0
|
56 |
+
AUDIO_CHUNK_OVERLAP_DURATION=0.1
|
57 |
+
python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
58 |
+
python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
|
59 |
+
if [ "$service_type" == "streaming" ]; then
|
60 |
+
DECOUPLED_MODE=True
|
61 |
+
else
|
62 |
+
DECOUPLED_MODE=False
|
63 |
+
fi
|
64 |
+
python3 scripts/fill_template.py -i ${model_repo}/${spark_tts_dir}/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},audio_chunk_duration:${AUDIO_CHUNK_DURATION},max_audio_chunk_duration:${MAX_AUDIO_CHUNK_DURATION},audio_chunk_size_scale_factor:${AUDIO_CHUNK_SIZE_SCALE_FACTOR},audio_chunk_overlap_duration:${AUDIO_CHUNK_OVERLAP_DURATION}
|
65 |
+
python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
|
66 |
+
|
67 |
+
fi
|
68 |
+
|
69 |
+
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
70 |
+
echo "Starting Triton server"
|
71 |
+
tritonserver --model-repository ${model_repo}
|
72 |
+
fi
|
73 |
+
|
74 |
+
|
75 |
+
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
76 |
+
echo "Running benchmark client"
|
77 |
+
num_task=2
|
78 |
+
if [ "$service_type" == "streaming" ]; then
|
79 |
+
mode="streaming"
|
80 |
+
else
|
81 |
+
mode="offline"
|
82 |
+
fi
|
83 |
+
python3 client_grpc.py \
|
84 |
+
--server-addr localhost \
|
85 |
+
--model-name spark_tts \
|
86 |
+
--num-tasks $num_task \
|
87 |
+
--mode $mode \
|
88 |
+
--log-dir ./log_concurrent_tasks_${num_task}_${mode}_new
|
89 |
+
fi
|
90 |
+
|
91 |
+
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
92 |
+
echo "Running single utterance client"
|
93 |
+
if [ "$service_type" == "streaming" ]; then
|
94 |
+
python client_grpc.py \
|
95 |
+
--server-addr localhost \
|
96 |
+
--reference-audio ../../example/prompt_audio.wav \
|
97 |
+
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
98 |
+
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
99 |
+
--model-name spark_tts \
|
100 |
+
--chunk-overlap-duration 0.1 \
|
101 |
+
--mode streaming
|
102 |
+
else
|
103 |
+
python client_http.py \
|
104 |
+
--reference-audio ../../example/prompt_audio.wav \
|
105 |
+
--reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
106 |
+
--target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
|
107 |
+
--model-name spark_tts
|
108 |
+
fi
|
109 |
+
fi
|
runtime/triton_trtllm/scripts/convert_checkpoint.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import traceback
|
5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
6 |
+
|
7 |
+
from transformers import AutoConfig
|
8 |
+
|
9 |
+
import tensorrt_llm
|
10 |
+
from tensorrt_llm._utils import release_gc
|
11 |
+
from tensorrt_llm.logger import logger
|
12 |
+
from tensorrt_llm.mapping import Mapping
|
13 |
+
from tensorrt_llm.models import QWenForCausalLM
|
14 |
+
from tensorrt_llm.models.modeling_utils import QuantConfig
|
15 |
+
from tensorrt_llm.quantization import QuantAlgo
|
16 |
+
|
17 |
+
|
18 |
+
def parse_arguments():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('--model_dir', type=str, default=None, required=True)
|
21 |
+
parser.add_argument('--tp_size',
|
22 |
+
type=int,
|
23 |
+
default=1,
|
24 |
+
help='N-way tensor parallelism size')
|
25 |
+
parser.add_argument('--pp_size',
|
26 |
+
type=int,
|
27 |
+
default=1,
|
28 |
+
help='N-way pipeline parallelism size')
|
29 |
+
parser.add_argument(
|
30 |
+
'--dtype',
|
31 |
+
type=str,
|
32 |
+
default='auto',
|
33 |
+
choices=['auto', 'float16', 'bfloat16', 'float32'],
|
34 |
+
help=
|
35 |
+
"The data type for the model weights and activations if not quantized. "
|
36 |
+
"If 'auto', the data type is automatically inferred from the source model; "
|
37 |
+
"however, if the source dtype is float32, it is converted to float16.")
|
38 |
+
parser.add_argument(
|
39 |
+
'--use_weight_only',
|
40 |
+
default=False,
|
41 |
+
action="store_true",
|
42 |
+
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
43 |
+
'See --weight_only_precision to set the precision')
|
44 |
+
parser.add_argument(
|
45 |
+
'--disable_weight_only_quant_plugin',
|
46 |
+
default=False,
|
47 |
+
action="store_true",
|
48 |
+
help=
|
49 |
+
'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
|
50 |
+
'You must also use --use_weight_only for that argument to have an impact.'
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
'--weight_only_precision',
|
54 |
+
const='int8',
|
55 |
+
type=str,
|
56 |
+
nargs='?',
|
57 |
+
default='int8',
|
58 |
+
choices=['int8', 'int4', 'int4_gptq'],
|
59 |
+
help=
|
60 |
+
'Define the precision for the weights when using weight-only quantization.'
|
61 |
+
'You must also use --use_weight_only for that argument to have an impact.'
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
'--calib_dataset',
|
65 |
+
type=str,
|
66 |
+
default='ccdv/cnn_dailymail',
|
67 |
+
help=
|
68 |
+
"The huggingface dataset name or the local directory of the dataset for calibration."
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--smoothquant",
|
72 |
+
"-sq",
|
73 |
+
type=float,
|
74 |
+
default=None,
|
75 |
+
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
76 |
+
" to Smoothquant the model, and output int8 weights."
|
77 |
+
" A good first try is 0.5. Must be in [0, 1]")
|
78 |
+
parser.add_argument(
|
79 |
+
'--per_channel',
|
80 |
+
action="store_true",
|
81 |
+
default=False,
|
82 |
+
help=
|
83 |
+
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
84 |
+
'per_channel instead uses a different static scaling factor for each channel. '
|
85 |
+
'The latter is usually more accurate, but a little slower.')
|
86 |
+
parser.add_argument(
|
87 |
+
'--per_token',
|
88 |
+
action="store_true",
|
89 |
+
default=False,
|
90 |
+
help=
|
91 |
+
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
92 |
+
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
93 |
+
'The latter is usually more accurate, but a little slower.')
|
94 |
+
parser.add_argument(
|
95 |
+
'--int8_kv_cache',
|
96 |
+
default=False,
|
97 |
+
action="store_true",
|
98 |
+
help=
|
99 |
+
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
'--per_group',
|
103 |
+
default=False,
|
104 |
+
action="store_true",
|
105 |
+
help=
|
106 |
+
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
107 |
+
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
108 |
+
'The flag is built for GPTQ/AWQ quantization.')
|
109 |
+
|
110 |
+
parser.add_argument('--group_size',
|
111 |
+
type=int,
|
112 |
+
default=128,
|
113 |
+
help='Group size used in GPTQ quantization.')
|
114 |
+
|
115 |
+
parser.add_argument("--load_model_on_cpu", action="store_true")
|
116 |
+
parser.add_argument(
|
117 |
+
'--use_parallel_embedding',
|
118 |
+
action="store_true",
|
119 |
+
default=False,
|
120 |
+
help=
|
121 |
+
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
'--embedding_sharding_dim',
|
125 |
+
type=int,
|
126 |
+
default=0,
|
127 |
+
choices=[0, 1],
|
128 |
+
help=
|
129 |
+
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
130 |
+
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
131 |
+
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
132 |
+
)
|
133 |
+
parser.add_argument('--output_dir',
|
134 |
+
type=str,
|
135 |
+
default='tllm_checkpoint',
|
136 |
+
help='The path to save the TensorRT-LLM checkpoint')
|
137 |
+
parser.add_argument(
|
138 |
+
'--workers',
|
139 |
+
type=int,
|
140 |
+
default=1,
|
141 |
+
help='The number of workers for converting checkpoint in parallel')
|
142 |
+
parser.add_argument(
|
143 |
+
'--moe_tp_size',
|
144 |
+
type=int,
|
145 |
+
default=-1,
|
146 |
+
help=
|
147 |
+
'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
|
148 |
+
)
|
149 |
+
parser.add_argument(
|
150 |
+
'--moe_ep_size',
|
151 |
+
type=int,
|
152 |
+
default=-1,
|
153 |
+
help=
|
154 |
+
'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
|
155 |
+
)
|
156 |
+
args = parser.parse_args()
|
157 |
+
return args
|
158 |
+
|
159 |
+
|
160 |
+
def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
|
161 |
+
'''return config dict with quantization info based on the command line args
|
162 |
+
'''
|
163 |
+
quant_config = QuantConfig()
|
164 |
+
if args.use_weight_only:
|
165 |
+
if args.weight_only_precision == 'int8':
|
166 |
+
quant_config.quant_algo = QuantAlgo.W8A16
|
167 |
+
elif args.weight_only_precision == 'int4':
|
168 |
+
quant_config.quant_algo = QuantAlgo.W4A16
|
169 |
+
elif args.smoothquant:
|
170 |
+
quant_config.smoothquant_val = args.smoothquant
|
171 |
+
if args.per_channel:
|
172 |
+
if args.per_token:
|
173 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
|
174 |
+
else:
|
175 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
|
176 |
+
else:
|
177 |
+
if args.per_token:
|
178 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
|
179 |
+
else:
|
180 |
+
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
|
181 |
+
|
182 |
+
if args.int8_kv_cache:
|
183 |
+
quant_config.kv_cache_quant_algo = QuantAlgo.INT8
|
184 |
+
|
185 |
+
if args.weight_only_precision == 'int4_gptq':
|
186 |
+
quant_config.group_size = args.group_size
|
187 |
+
quant_config.has_zero_point = True
|
188 |
+
quant_config.pre_quant_scale = False
|
189 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
190 |
+
|
191 |
+
return quant_config
|
192 |
+
|
193 |
+
|
194 |
+
def update_quant_config_from_hf(quant_config, hf_config,
|
195 |
+
override_fields) -> tuple[QuantConfig, dict]:
|
196 |
+
hf_config_dict = hf_config.to_dict()
|
197 |
+
if hf_config_dict.get('quantization_config'):
|
198 |
+
# update the quant_algo, and clamp_val.
|
199 |
+
if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
|
200 |
+
logger.info(
|
201 |
+
"Load quantization configs from huggingface model_config.")
|
202 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
203 |
+
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
204 |
+
'group_size', 128)
|
205 |
+
quant_config.has_zero_point = hf_config_dict[
|
206 |
+
'quantization_config'].get('zero_point', False)
|
207 |
+
override_fields.update({"use_autoawq": True})
|
208 |
+
elif hf_config_dict['quantization_config'].get(
|
209 |
+
'quant_method') == 'gptq':
|
210 |
+
logger.info(
|
211 |
+
"Load quantization configs from huggingface model_config.")
|
212 |
+
desc_act = hf_config_dict['quantization_config'].get(
|
213 |
+
'desc_act', False)
|
214 |
+
if desc_act:
|
215 |
+
raise ValueError("GPTQ with desc_act=True is not implemented!")
|
216 |
+
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
|
217 |
+
quant_config.group_size = hf_config_dict['quantization_config'].get(
|
218 |
+
'group_size', 128)
|
219 |
+
quant_config.has_zero_point = hf_config_dict[
|
220 |
+
'quantization_config'].get('sym', False)
|
221 |
+
return quant_config, override_fields
|
222 |
+
|
223 |
+
|
224 |
+
def args_to_build_options(args):
|
225 |
+
return {
|
226 |
+
'use_parallel_embedding': args.use_parallel_embedding,
|
227 |
+
'embedding_sharding_dim': args.embedding_sharding_dim,
|
228 |
+
'disable_weight_only_quant_plugin':
|
229 |
+
args.disable_weight_only_quant_plugin
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
def convert_and_save_hf(args):
|
234 |
+
model_dir = args.model_dir
|
235 |
+
world_size = args.tp_size * args.pp_size
|
236 |
+
# Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
|
237 |
+
# Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
|
238 |
+
# before the refactor is done.
|
239 |
+
override_fields = {}
|
240 |
+
override_fields.update(args_to_build_options(args))
|
241 |
+
quant_config = args_to_quant_config(args)
|
242 |
+
|
243 |
+
try:
|
244 |
+
hf_config = AutoConfig.from_pretrained(model_dir,
|
245 |
+
trust_remote_code=True)
|
246 |
+
quant_config, override_fields = update_quant_config_from_hf(
|
247 |
+
quant_config, hf_config, override_fields)
|
248 |
+
except:
|
249 |
+
logger.warning("AutoConfig cannot load the huggingface config.")
|
250 |
+
|
251 |
+
if args.smoothquant is not None or args.int8_kv_cache:
|
252 |
+
mapping = Mapping(
|
253 |
+
world_size=world_size,
|
254 |
+
tp_size=args.tp_size,
|
255 |
+
pp_size=args.pp_size,
|
256 |
+
moe_tp_size=args.moe_tp_size,
|
257 |
+
moe_ep_size=args.moe_ep_size,
|
258 |
+
)
|
259 |
+
QWenForCausalLM.quantize(args.model_dir,
|
260 |
+
args.output_dir,
|
261 |
+
dtype=args.dtype,
|
262 |
+
mapping=mapping,
|
263 |
+
quant_config=quant_config,
|
264 |
+
calib_dataset=args.calib_dataset,
|
265 |
+
**override_fields)
|
266 |
+
else:
|
267 |
+
|
268 |
+
def convert_and_save_rank(args, rank):
|
269 |
+
mapping = Mapping(world_size=world_size,
|
270 |
+
rank=rank,
|
271 |
+
tp_size=args.tp_size,
|
272 |
+
pp_size=args.pp_size,
|
273 |
+
moe_tp_size=args.moe_tp_size,
|
274 |
+
moe_ep_size=args.moe_ep_size)
|
275 |
+
qwen = QWenForCausalLM.from_hugging_face(model_dir,
|
276 |
+
args.dtype,
|
277 |
+
mapping=mapping,
|
278 |
+
quant_config=quant_config,
|
279 |
+
**override_fields)
|
280 |
+
qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
|
281 |
+
del qwen
|
282 |
+
|
283 |
+
execute(args.workers, [convert_and_save_rank] * world_size, args)
|
284 |
+
release_gc()
|
285 |
+
|
286 |
+
|
287 |
+
def execute(workers, func, args):
|
288 |
+
if workers == 1:
|
289 |
+
for rank, f in enumerate(func):
|
290 |
+
f(args, rank)
|
291 |
+
else:
|
292 |
+
with ThreadPoolExecutor(max_workers=workers) as p:
|
293 |
+
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
|
294 |
+
exceptions = []
|
295 |
+
for future in as_completed(futures):
|
296 |
+
try:
|
297 |
+
future.result()
|
298 |
+
except Exception as e:
|
299 |
+
traceback.print_exc()
|
300 |
+
exceptions.append(e)
|
301 |
+
assert len(
|
302 |
+
exceptions
|
303 |
+
) == 0, "Checkpoint conversion failed, please check error log."
|
304 |
+
|
305 |
+
|
306 |
+
def main():
|
307 |
+
print(tensorrt_llm.__version__)
|
308 |
+
args = parse_arguments()
|
309 |
+
|
310 |
+
if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
|
311 |
+
# moe default to tp-only
|
312 |
+
args.moe_tp_size = args.tp_size
|
313 |
+
args.moe_ep_size = 1
|
314 |
+
elif (args.moe_tp_size == -1):
|
315 |
+
args.moe_tp_size = args.tp_size // args.moe_ep_size
|
316 |
+
elif (args.moe_ep_size == -1):
|
317 |
+
args.moe_ep_size = args.tp_size // args.moe_tp_size
|
318 |
+
assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
|
319 |
+
), "moe_tp_size * moe_ep_size must equal to tp_size"
|
320 |
+
|
321 |
+
tik = time.time()
|
322 |
+
|
323 |
+
if not os.path.exists(args.output_dir):
|
324 |
+
os.makedirs(args.output_dir)
|
325 |
+
|
326 |
+
assert args.model_dir is not None
|
327 |
+
convert_and_save_hf(args)
|
328 |
+
|
329 |
+
tok = time.time()
|
330 |
+
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
331 |
+
print(f'Total time of converting checkpoints: {t}')
|
332 |
+
|
333 |
+
|
334 |
+
if __name__ == '__main__':
|
335 |
+
main()
|
runtime/triton_trtllm/scripts/fill_template.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from string import Template
|
4 |
+
|
5 |
+
|
6 |
+
def split(string, delimiter):
|
7 |
+
"""Split a string using delimiter. Supports escaping.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
string (str): The string to split.
|
11 |
+
delimiter (str): The delimiter to split the string with.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
list: A list of strings.
|
15 |
+
"""
|
16 |
+
result = []
|
17 |
+
current = ""
|
18 |
+
escape = False
|
19 |
+
for char in string:
|
20 |
+
if escape:
|
21 |
+
current += char
|
22 |
+
escape = False
|
23 |
+
elif char == delimiter:
|
24 |
+
result.append(current)
|
25 |
+
current = ""
|
26 |
+
elif char == "\\":
|
27 |
+
escape = True
|
28 |
+
else:
|
29 |
+
current += char
|
30 |
+
result.append(current)
|
31 |
+
return result
|
32 |
+
|
33 |
+
|
34 |
+
def main(file_path, substitutions, in_place):
|
35 |
+
with open(file_path) as f:
|
36 |
+
pbtxt = Template(f.read())
|
37 |
+
|
38 |
+
sub_dict = {
|
39 |
+
"max_queue_size": 0,
|
40 |
+
'max_queue_delay_microseconds': 0,
|
41 |
+
}
|
42 |
+
for sub in split(substitutions, ","):
|
43 |
+
key, value = split(sub, ":")
|
44 |
+
sub_dict[key] = value
|
45 |
+
|
46 |
+
assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
|
47 |
+
|
48 |
+
pbtxt = pbtxt.safe_substitute(sub_dict)
|
49 |
+
|
50 |
+
if in_place:
|
51 |
+
with open(file_path, "w") as f:
|
52 |
+
f.write(pbtxt)
|
53 |
+
else:
|
54 |
+
print(pbtxt)
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
parser = ArgumentParser()
|
59 |
+
parser.add_argument("file_path", help="path of the .pbtxt to modify")
|
60 |
+
parser.add_argument(
|
61 |
+
"substitutions",
|
62 |
+
help=
|
63 |
+
"substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
|
64 |
+
)
|
65 |
+
parser.add_argument("--in_place",
|
66 |
+
"-i",
|
67 |
+
action="store_true",
|
68 |
+
help="do the operation in-place")
|
69 |
+
args = parser.parse_args()
|
70 |
+
main(**vars(args))
|
sparktts/models/audio_tokenizer.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from pathlib import Path
|
21 |
+
from typing import Any, Dict, Tuple
|
22 |
+
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
|
23 |
+
|
24 |
+
from sparktts.utils.file import load_config
|
25 |
+
from sparktts.utils.audio import load_audio
|
26 |
+
from sparktts.models.bicodec import BiCodec
|
27 |
+
|
28 |
+
|
29 |
+
class BiCodecTokenizer:
|
30 |
+
"""BiCodec tokenizer for handling audio input and tokenization."""
|
31 |
+
|
32 |
+
def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
|
33 |
+
super().__init__()
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
model_dir: Path to the model directory.
|
37 |
+
device: Device to run the model on (default is GPU if available).
|
38 |
+
"""
|
39 |
+
self.device = "cpu"
|
40 |
+
self.model_dir = model_dir
|
41 |
+
self.config = load_config(f"{model_dir}/config.yaml")
|
42 |
+
self._initialize_model()
|
43 |
+
|
44 |
+
def _initialize_model(self):
|
45 |
+
"""Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
|
46 |
+
self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
|
47 |
+
self.device
|
48 |
+
)
|
49 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
50 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
51 |
+
)
|
52 |
+
self.feature_extractor = Wav2Vec2Model.from_pretrained(
|
53 |
+
f"{self.model_dir}/wav2vec2-large-xlsr-53"
|
54 |
+
).to(self.device)
|
55 |
+
self.feature_extractor.config.output_hidden_states = True
|
56 |
+
|
57 |
+
def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
|
58 |
+
"""Get reference audio clip for speaker embedding."""
|
59 |
+
ref_segment_length = (
|
60 |
+
int(self.config["sample_rate"] * self.config["ref_segment_duration"])
|
61 |
+
// self.config["latent_hop_length"]
|
62 |
+
* self.config["latent_hop_length"]
|
63 |
+
)
|
64 |
+
wav_length = len(wav)
|
65 |
+
|
66 |
+
if ref_segment_length > wav_length:
|
67 |
+
# Repeat and truncate to handle insufficient length
|
68 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
69 |
+
|
70 |
+
return wav[:ref_segment_length]
|
71 |
+
|
72 |
+
def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
|
73 |
+
"""load auido and get reference audio from wav path"""
|
74 |
+
wav = load_audio(
|
75 |
+
wav_path,
|
76 |
+
sampling_rate=self.config["sample_rate"],
|
77 |
+
volume_normalize=self.config["volume_normalize"],
|
78 |
+
)
|
79 |
+
|
80 |
+
wav_ref = self.get_ref_clip(wav)
|
81 |
+
|
82 |
+
wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
|
83 |
+
return wav, wav_ref
|
84 |
+
|
85 |
+
def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
|
86 |
+
"""extract wav2vec2 features"""
|
87 |
+
inputs = self.processor(
|
88 |
+
wavs,
|
89 |
+
sampling_rate=16000,
|
90 |
+
return_tensors="pt",
|
91 |
+
padding=True,
|
92 |
+
output_hidden_states=True,
|
93 |
+
).input_values
|
94 |
+
feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
|
95 |
+
feats_mix = (
|
96 |
+
feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
|
97 |
+
) / 3
|
98 |
+
|
99 |
+
return feats_mix
|
100 |
+
|
101 |
+
def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
|
102 |
+
"""tokenize the batch of audio
|
103 |
+
|
104 |
+
Args:
|
105 |
+
batch:
|
106 |
+
wavs (List[np.ndarray]): batch of audio
|
107 |
+
ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
|
111 |
+
global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
|
112 |
+
"""
|
113 |
+
feats = self.extract_wav2vec2_features(batch["wav"])
|
114 |
+
batch["feat"] = feats
|
115 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
116 |
+
|
117 |
+
return global_tokens, semantic_tokens
|
118 |
+
|
119 |
+
def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
120 |
+
"""tokenize the audio"""
|
121 |
+
wav, ref_wav = self.process_audio(audio_path)
|
122 |
+
feat = self.extract_wav2vec2_features(wav)
|
123 |
+
batch = {
|
124 |
+
"wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
|
125 |
+
"ref_wav": ref_wav.to(self.device),
|
126 |
+
"feat": feat.to(self.device),
|
127 |
+
}
|
128 |
+
semantic_tokens, global_tokens = self.model.tokenize(batch)
|
129 |
+
|
130 |
+
return global_tokens, semantic_tokens
|
131 |
+
|
132 |
+
def detokenize(
|
133 |
+
self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
|
134 |
+
) -> np.array:
|
135 |
+
"""detokenize the tokens to waveform
|
136 |
+
|
137 |
+
Args:
|
138 |
+
global_tokens: global tokens. shape: (batch_size, global_dim)
|
139 |
+
semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
|
143 |
+
"""
|
144 |
+
global_tokens = global_tokens.unsqueeze(1)
|
145 |
+
wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
|
146 |
+
return wav_rec.detach().squeeze().cpu().numpy()
|
147 |
+
|
148 |
+
|
149 |
+
# test
|
150 |
+
if __name__ == "__main__":
|
151 |
+
import soundfile as sf
|
152 |
+
|
153 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
154 |
+
tokenizer = BiCodecTokenizer(
|
155 |
+
model_dir="pretrained_models/Spark-TTS-0.5B",
|
156 |
+
device=device,
|
157 |
+
)
|
158 |
+
wav_path = "example/prompt_audio.wav"
|
159 |
+
|
160 |
+
global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
|
161 |
+
|
162 |
+
wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
|
163 |
+
sf.write("example/prompt_recon.wav", wav_rec, 16000)
|
sparktts/models/bicodec.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import Dict, Any
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from sparktts.utils.file import load_config
|
24 |
+
from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
|
25 |
+
from sparktts.modules.encoder_decoder.feat_encoder import Encoder
|
26 |
+
from sparktts.modules.encoder_decoder.feat_decoder import Decoder
|
27 |
+
from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
|
28 |
+
from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
|
29 |
+
|
30 |
+
|
31 |
+
class BiCodec(nn.Module):
|
32 |
+
"""
|
33 |
+
BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
|
34 |
+
quantizer, and wave generator.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
mel_params: Dict[str, Any],
|
40 |
+
encoder: nn.Module,
|
41 |
+
decoder: nn.Module,
|
42 |
+
quantizer: nn.Module,
|
43 |
+
speaker_encoder: nn.Module,
|
44 |
+
prenet: nn.Module,
|
45 |
+
postnet: nn.Module,
|
46 |
+
**kwargs
|
47 |
+
) -> None:
|
48 |
+
"""
|
49 |
+
Initializes the BiCodec model with the required components.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
mel_params (dict): Parameters for the mel-spectrogram transformer.
|
53 |
+
encoder (nn.Module): Encoder module.
|
54 |
+
decoder (nn.Module): Decoder module.
|
55 |
+
quantizer (nn.Module): Quantizer module.
|
56 |
+
speaker_encoder (nn.Module): Speaker encoder module.
|
57 |
+
prenet (nn.Module): Prenet network.
|
58 |
+
postnet (nn.Module): Postnet network.
|
59 |
+
"""
|
60 |
+
super().__init__()
|
61 |
+
self.encoder = encoder
|
62 |
+
self.decoder = decoder
|
63 |
+
self.quantizer = quantizer
|
64 |
+
self.speaker_encoder = speaker_encoder
|
65 |
+
self.prenet = prenet
|
66 |
+
self.postnet = postnet
|
67 |
+
self.init_mel_transformer(mel_params)
|
68 |
+
|
69 |
+
@classmethod
|
70 |
+
def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
|
71 |
+
"""
|
72 |
+
Loads the model from a checkpoint.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
model_dir (Path): Path to the model directory containing checkpoint and config.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
BiCodec: The initialized BiCodec model.
|
79 |
+
"""
|
80 |
+
ckpt_path = f'{model_dir}/model.safetensors'
|
81 |
+
config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
|
82 |
+
mel_params = config["mel_params"]
|
83 |
+
encoder = Encoder(**config["encoder"])
|
84 |
+
quantizer = FactorizedVectorQuantize(**config["quantizer"])
|
85 |
+
prenet = Decoder(**config["prenet"])
|
86 |
+
postnet = Decoder(**config["postnet"])
|
87 |
+
decoder = WaveGenerator(**config["decoder"])
|
88 |
+
speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
|
89 |
+
|
90 |
+
model = cls(
|
91 |
+
mel_params=mel_params,
|
92 |
+
encoder=encoder,
|
93 |
+
decoder=decoder,
|
94 |
+
quantizer=quantizer,
|
95 |
+
speaker_encoder=speaker_encoder,
|
96 |
+
prenet=prenet,
|
97 |
+
postnet=postnet,
|
98 |
+
)
|
99 |
+
|
100 |
+
state_dict = load_file(ckpt_path)
|
101 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
102 |
+
|
103 |
+
for key in missing_keys:
|
104 |
+
print(f"Missing tensor: {key}")
|
105 |
+
for key in unexpected_keys:
|
106 |
+
print(f"Unexpected tensor: {key}")
|
107 |
+
|
108 |
+
model.eval()
|
109 |
+
model.remove_weight_norm()
|
110 |
+
|
111 |
+
return model
|
112 |
+
|
113 |
+
def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
114 |
+
"""
|
115 |
+
Performs a forward pass through the model.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
batch (dict): A dictionary containing features, reference waveform, and target waveform.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
dict: A dictionary containing the reconstruction, features, and other metrics.
|
122 |
+
"""
|
123 |
+
feat = batch["feat"]
|
124 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
125 |
+
|
126 |
+
z = self.encoder(feat.transpose(1, 2))
|
127 |
+
vq_outputs = self.quantizer(z)
|
128 |
+
|
129 |
+
x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
|
130 |
+
|
131 |
+
conditions = d_vector
|
132 |
+
with_speaker_loss = False
|
133 |
+
|
134 |
+
x = self.prenet(vq_outputs["z_q"], conditions)
|
135 |
+
pred_feat = self.postnet(x)
|
136 |
+
x = x + conditions.unsqueeze(-1)
|
137 |
+
wav_recon = self.decoder(x)
|
138 |
+
|
139 |
+
return {
|
140 |
+
"vq_loss": vq_outputs["vq_loss"],
|
141 |
+
"perplexity": vq_outputs["perplexity"],
|
142 |
+
"cluster_size": vq_outputs["active_num"],
|
143 |
+
"recons": wav_recon,
|
144 |
+
"pred_feat": pred_feat,
|
145 |
+
"x_vector": x_vector,
|
146 |
+
"d_vector": d_vector,
|
147 |
+
"audios": batch["wav"].unsqueeze(1),
|
148 |
+
"with_speaker_loss": with_speaker_loss,
|
149 |
+
}
|
150 |
+
|
151 |
+
@torch.no_grad()
|
152 |
+
def tokenize(self, batch: Dict[str, Any]):
|
153 |
+
"""
|
154 |
+
Tokenizes the input audio into semantic and global tokens.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
batch (dict): The input audio features and reference waveform.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
tuple: Semantic tokens and global tokens.
|
161 |
+
"""
|
162 |
+
feat = batch["feat"]
|
163 |
+
mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
|
164 |
+
|
165 |
+
z = self.encoder(feat.transpose(1, 2))
|
166 |
+
semantic_tokens = self.quantizer.tokenize(z)
|
167 |
+
global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
|
168 |
+
|
169 |
+
return semantic_tokens, global_tokens
|
170 |
+
|
171 |
+
@torch.no_grad()
|
172 |
+
def detokenize(self, semantic_tokens, global_tokens):
|
173 |
+
"""
|
174 |
+
Detokenizes the semantic and global tokens into a waveform.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
semantic_tokens (tensor): Semantic tokens.
|
178 |
+
global_tokens (tensor): Global tokens.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
tensor: Reconstructed waveform.
|
182 |
+
"""
|
183 |
+
z_q = self.quantizer.detokenize(semantic_tokens)
|
184 |
+
d_vector = self.speaker_encoder.detokenize(global_tokens)
|
185 |
+
x = self.prenet(z_q, d_vector)
|
186 |
+
x = x + d_vector.unsqueeze(-1)
|
187 |
+
wav_recon = self.decoder(x)
|
188 |
+
|
189 |
+
return wav_recon
|
190 |
+
|
191 |
+
def init_mel_transformer(self, config: Dict[str, Any]):
|
192 |
+
"""
|
193 |
+
Initializes the MelSpectrogram transformer based on the provided configuration.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
config (dict): Configuration parameters for MelSpectrogram.
|
197 |
+
"""
|
198 |
+
import torchaudio.transforms as TT
|
199 |
+
|
200 |
+
self.mel_transformer = TT.MelSpectrogram(
|
201 |
+
config["sample_rate"],
|
202 |
+
config["n_fft"],
|
203 |
+
config["win_length"],
|
204 |
+
config["hop_length"],
|
205 |
+
config["mel_fmin"],
|
206 |
+
config["mel_fmax"],
|
207 |
+
n_mels=config["num_mels"],
|
208 |
+
power=1,
|
209 |
+
norm="slaney",
|
210 |
+
mel_scale="slaney",
|
211 |
+
)
|
212 |
+
|
213 |
+
def remove_weight_norm(self):
|
214 |
+
"""Removes weight normalization from all layers."""
|
215 |
+
def _remove_weight_norm(m):
|
216 |
+
try:
|
217 |
+
torch.nn.utils.remove_weight_norm(m)
|
218 |
+
except ValueError:
|
219 |
+
pass # The module didn't have weight norm
|
220 |
+
|
221 |
+
self.apply(_remove_weight_norm)
|
222 |
+
|
223 |
+
|
224 |
+
# Test the model
|
225 |
+
if __name__ == "__main__":
|
226 |
+
|
227 |
+
config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
|
228 |
+
model = BiCodec.load_from_checkpoint(
|
229 |
+
model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
|
230 |
+
)
|
231 |
+
|
232 |
+
# Generate random inputs for testing
|
233 |
+
duration = 0.96
|
234 |
+
x = torch.randn(20, 1, int(duration * 16000))
|
235 |
+
feat = torch.randn(20, int(duration * 50), 1024)
|
236 |
+
inputs = {"feat": feat, "wav": x, "ref_wav": x}
|
237 |
+
|
238 |
+
# Forward pass
|
239 |
+
outputs = model(inputs)
|
240 |
+
semantic_tokens, global_tokens = model.tokenize(inputs)
|
241 |
+
wav_recon = model.detokenize(semantic_tokens, global_tokens)
|
242 |
+
|
243 |
+
# Verify if the reconstruction matches
|
244 |
+
if torch.allclose(outputs["recons"].detach(), wav_recon):
|
245 |
+
print("Test successful")
|
246 |
+
else:
|
247 |
+
print("Test failed")
|
sparktts/modules/blocks/layers.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch.nn.utils import weight_norm
|
22 |
+
|
23 |
+
|
24 |
+
def WNConv1d(*args, **kwargs):
|
25 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
26 |
+
|
27 |
+
|
28 |
+
def WNConvTranspose1d(*args, **kwargs):
|
29 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
30 |
+
|
31 |
+
|
32 |
+
# Scripting this brings model speed up 1.4x
|
33 |
+
@torch.jit.script
|
34 |
+
def snake(x, alpha):
|
35 |
+
shape = x.shape
|
36 |
+
x = x.reshape(shape[0], shape[1], -1)
|
37 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
38 |
+
x = x.reshape(shape)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
class Snake1d(nn.Module):
|
43 |
+
def __init__(self, channels):
|
44 |
+
super().__init__()
|
45 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return snake(x, self.alpha)
|
49 |
+
|
50 |
+
|
51 |
+
class ResidualUnit(nn.Module):
|
52 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
53 |
+
super().__init__()
|
54 |
+
pad = ((7 - 1) * dilation) // 2
|
55 |
+
self.block = nn.Sequential(
|
56 |
+
Snake1d(dim),
|
57 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
58 |
+
Snake1d(dim),
|
59 |
+
WNConv1d(dim, dim, kernel_size=1),
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
y = self.block(x)
|
64 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
65 |
+
if pad > 0:
|
66 |
+
x = x[..., pad:-pad]
|
67 |
+
return x + y
|
68 |
+
|
69 |
+
|
70 |
+
def init_weights(m):
|
71 |
+
if isinstance(m, nn.Conv1d):
|
72 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
73 |
+
nn.init.constant_(m.bias, 0)
|
sparktts/modules/blocks/samper.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
|
22 |
+
class SamplingBlock(nn.Module):
|
23 |
+
"""Sampling block for upsampling or downsampling"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
dim: int,
|
28 |
+
groups: int = 1,
|
29 |
+
upsample_scale: int = 1,
|
30 |
+
downsample_scale: int = 1,
|
31 |
+
) -> None:
|
32 |
+
"""
|
33 |
+
Args:
|
34 |
+
dim: input dimension
|
35 |
+
groups: number of groups
|
36 |
+
upsample_scale: upsampling scale
|
37 |
+
downsample_scale: downsampling scale
|
38 |
+
"""
|
39 |
+
super(SamplingBlock, self).__init__()
|
40 |
+
|
41 |
+
self.upsample_scale = upsample_scale
|
42 |
+
self.downsample_scale = downsample_scale
|
43 |
+
|
44 |
+
if self.upsample_scale > 1:
|
45 |
+
self.de_conv_upsampler = nn.Sequential(
|
46 |
+
nn.LeakyReLU(0.2),
|
47 |
+
nn.ConvTranspose1d(
|
48 |
+
dim,
|
49 |
+
dim,
|
50 |
+
kernel_size=upsample_scale * 2,
|
51 |
+
stride=upsample_scale,
|
52 |
+
padding=upsample_scale // 2 + upsample_scale % 2,
|
53 |
+
output_padding=upsample_scale % 2,
|
54 |
+
groups=groups,
|
55 |
+
),
|
56 |
+
)
|
57 |
+
|
58 |
+
if self.downsample_scale > 1:
|
59 |
+
self.conv_downsampler = nn.Sequential(
|
60 |
+
nn.LeakyReLU(0.2),
|
61 |
+
nn.Conv1d(
|
62 |
+
dim,
|
63 |
+
dim,
|
64 |
+
kernel_size=2 * downsample_scale,
|
65 |
+
stride=downsample_scale,
|
66 |
+
padding=downsample_scale // 2 + downsample_scale % 2,
|
67 |
+
groups=groups,
|
68 |
+
),
|
69 |
+
)
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def repeat_upsampler(x, upsample_scale):
|
73 |
+
return x.repeat_interleave(upsample_scale, dim=2)
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def skip_downsampler(x, downsample_scale):
|
77 |
+
return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
x = x.transpose(1, 2)
|
81 |
+
if self.upsample_scale > 1:
|
82 |
+
repeat_res = self.repeat_upsampler(x, self.upsample_scale)
|
83 |
+
deconv_res = self.de_conv_upsampler(x)
|
84 |
+
upmerge_res = repeat_res + deconv_res
|
85 |
+
else:
|
86 |
+
upmerge_res = x
|
87 |
+
repeat_res = x
|
88 |
+
|
89 |
+
if self.downsample_scale > 1:
|
90 |
+
conv_res = self.conv_downsampler(upmerge_res)
|
91 |
+
skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
|
92 |
+
skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
|
93 |
+
else:
|
94 |
+
conv_res = upmerge_res
|
95 |
+
skip2_res = upmerge_res
|
96 |
+
skip1_res = repeat_res
|
97 |
+
|
98 |
+
final_res = conv_res + skip1_res + skip2_res
|
99 |
+
|
100 |
+
return final_res
|
101 |
+
|
102 |
+
|
103 |
+
# test
|
104 |
+
if __name__ == "__main__":
|
105 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
106 |
+
model = SamplingBlock(1024, 1024, upsample_scale=2)
|
107 |
+
model_down = SamplingBlock(1024, 1024, downsample_scale=2)
|
108 |
+
output = model(test_input)
|
109 |
+
output_down = model_down(test_input)
|
110 |
+
print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
|
111 |
+
print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
|
112 |
+
if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
|
113 |
+
[8, 1024, 25]
|
114 |
+
):
|
115 |
+
print("test successful")
|
sparktts/modules/blocks/vocos.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from typing import Tuple
|
21 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
22 |
+
|
23 |
+
from typing import Optional
|
24 |
+
|
25 |
+
|
26 |
+
class ConvNeXtBlock(nn.Module):
|
27 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
dim (int): Number of input channels.
|
31 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
32 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
33 |
+
Defaults to None.
|
34 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
35 |
+
None means non-conditional LayerNorm. Defaults to None.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
dim: int,
|
41 |
+
intermediate_dim: int,
|
42 |
+
layer_scale_init_value: float,
|
43 |
+
condition_dim: Optional[int] = None,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.dwconv = nn.Conv1d(
|
47 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
48 |
+
) # depthwise conv
|
49 |
+
self.adanorm = condition_dim is not None
|
50 |
+
if condition_dim:
|
51 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
52 |
+
else:
|
53 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
54 |
+
self.pwconv1 = nn.Linear(
|
55 |
+
dim, intermediate_dim
|
56 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
57 |
+
self.act = nn.GELU()
|
58 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
59 |
+
self.gamma = (
|
60 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
61 |
+
if layer_scale_init_value > 0
|
62 |
+
else None
|
63 |
+
)
|
64 |
+
|
65 |
+
def forward(
|
66 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
67 |
+
) -> torch.Tensor:
|
68 |
+
residual = x
|
69 |
+
x = self.dwconv(x)
|
70 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
71 |
+
if self.adanorm:
|
72 |
+
assert cond_embedding_id is not None
|
73 |
+
x = self.norm(x, cond_embedding_id)
|
74 |
+
else:
|
75 |
+
x = self.norm(x)
|
76 |
+
x = self.pwconv1(x)
|
77 |
+
x = self.act(x)
|
78 |
+
x = self.pwconv2(x)
|
79 |
+
if self.gamma is not None:
|
80 |
+
x = self.gamma * x
|
81 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
82 |
+
|
83 |
+
x = residual + x
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class AdaLayerNorm(nn.Module):
|
88 |
+
"""
|
89 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
90 |
+
|
91 |
+
Args:
|
92 |
+
condition_dim (int): Dimension of the condition.
|
93 |
+
embedding_dim (int): Dimension of the embeddings.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
|
97 |
+
super().__init__()
|
98 |
+
self.eps = eps
|
99 |
+
self.dim = embedding_dim
|
100 |
+
self.scale = nn.Linear(condition_dim, embedding_dim)
|
101 |
+
self.shift = nn.Linear(condition_dim, embedding_dim)
|
102 |
+
torch.nn.init.ones_(self.scale.weight)
|
103 |
+
torch.nn.init.zeros_(self.shift.weight)
|
104 |
+
|
105 |
+
def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
|
106 |
+
scale = self.scale(cond_embedding)
|
107 |
+
shift = self.shift(cond_embedding)
|
108 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
109 |
+
x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class ResBlock1(nn.Module):
|
114 |
+
"""
|
115 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
116 |
+
but without upsampling layers.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
dim (int): Number of input channels.
|
120 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
121 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
122 |
+
Defaults to (1, 3, 5).
|
123 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
124 |
+
Defaults to 0.1.
|
125 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
126 |
+
Defaults to None.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
dim: int,
|
132 |
+
kernel_size: int = 3,
|
133 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
134 |
+
lrelu_slope: float = 0.1,
|
135 |
+
layer_scale_init_value: Optional[float] = None,
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
self.lrelu_slope = lrelu_slope
|
139 |
+
self.convs1 = nn.ModuleList(
|
140 |
+
[
|
141 |
+
weight_norm(
|
142 |
+
nn.Conv1d(
|
143 |
+
dim,
|
144 |
+
dim,
|
145 |
+
kernel_size,
|
146 |
+
1,
|
147 |
+
dilation=dilation[0],
|
148 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
149 |
+
)
|
150 |
+
),
|
151 |
+
weight_norm(
|
152 |
+
nn.Conv1d(
|
153 |
+
dim,
|
154 |
+
dim,
|
155 |
+
kernel_size,
|
156 |
+
1,
|
157 |
+
dilation=dilation[1],
|
158 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
159 |
+
)
|
160 |
+
),
|
161 |
+
weight_norm(
|
162 |
+
nn.Conv1d(
|
163 |
+
dim,
|
164 |
+
dim,
|
165 |
+
kernel_size,
|
166 |
+
1,
|
167 |
+
dilation=dilation[2],
|
168 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
169 |
+
)
|
170 |
+
),
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
self.convs2 = nn.ModuleList(
|
175 |
+
[
|
176 |
+
weight_norm(
|
177 |
+
nn.Conv1d(
|
178 |
+
dim,
|
179 |
+
dim,
|
180 |
+
kernel_size,
|
181 |
+
1,
|
182 |
+
dilation=1,
|
183 |
+
padding=self.get_padding(kernel_size, 1),
|
184 |
+
)
|
185 |
+
),
|
186 |
+
weight_norm(
|
187 |
+
nn.Conv1d(
|
188 |
+
dim,
|
189 |
+
dim,
|
190 |
+
kernel_size,
|
191 |
+
1,
|
192 |
+
dilation=1,
|
193 |
+
padding=self.get_padding(kernel_size, 1),
|
194 |
+
)
|
195 |
+
),
|
196 |
+
weight_norm(
|
197 |
+
nn.Conv1d(
|
198 |
+
dim,
|
199 |
+
dim,
|
200 |
+
kernel_size,
|
201 |
+
1,
|
202 |
+
dilation=1,
|
203 |
+
padding=self.get_padding(kernel_size, 1),
|
204 |
+
)
|
205 |
+
),
|
206 |
+
]
|
207 |
+
)
|
208 |
+
|
209 |
+
self.gamma = nn.ParameterList(
|
210 |
+
[
|
211 |
+
(
|
212 |
+
nn.Parameter(
|
213 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
214 |
+
)
|
215 |
+
if layer_scale_init_value is not None
|
216 |
+
else None
|
217 |
+
),
|
218 |
+
(
|
219 |
+
nn.Parameter(
|
220 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
221 |
+
)
|
222 |
+
if layer_scale_init_value is not None
|
223 |
+
else None
|
224 |
+
),
|
225 |
+
(
|
226 |
+
nn.Parameter(
|
227 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
228 |
+
)
|
229 |
+
if layer_scale_init_value is not None
|
230 |
+
else None
|
231 |
+
),
|
232 |
+
]
|
233 |
+
)
|
234 |
+
|
235 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
236 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
237 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
238 |
+
xt = c1(xt)
|
239 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
240 |
+
xt = c2(xt)
|
241 |
+
if gamma is not None:
|
242 |
+
xt = gamma * xt
|
243 |
+
x = xt + x
|
244 |
+
return x
|
245 |
+
|
246 |
+
def remove_weight_norm(self):
|
247 |
+
for l in self.convs1:
|
248 |
+
remove_weight_norm(l)
|
249 |
+
for l in self.convs2:
|
250 |
+
remove_weight_norm(l)
|
251 |
+
|
252 |
+
@staticmethod
|
253 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
254 |
+
return int((kernel_size * dilation - dilation) / 2)
|
255 |
+
|
256 |
+
|
257 |
+
class Backbone(nn.Module):
|
258 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
259 |
+
|
260 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
261 |
+
"""
|
262 |
+
Args:
|
263 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
264 |
+
C denotes output features, and L is the sequence length.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
268 |
+
and H denotes the model dimension.
|
269 |
+
"""
|
270 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
271 |
+
|
272 |
+
|
273 |
+
class VocosBackbone(Backbone):
|
274 |
+
"""
|
275 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
276 |
+
|
277 |
+
Args:
|
278 |
+
input_channels (int): Number of input features channels.
|
279 |
+
dim (int): Hidden dimension of the model.
|
280 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
281 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
282 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
283 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
284 |
+
None means non-conditional model. Defaults to None.
|
285 |
+
"""
|
286 |
+
|
287 |
+
def __init__(
|
288 |
+
self,
|
289 |
+
input_channels: int,
|
290 |
+
dim: int,
|
291 |
+
intermediate_dim: int,
|
292 |
+
num_layers: int,
|
293 |
+
layer_scale_init_value: Optional[float] = None,
|
294 |
+
condition_dim: Optional[int] = None,
|
295 |
+
):
|
296 |
+
super().__init__()
|
297 |
+
self.input_channels = input_channels
|
298 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
299 |
+
self.adanorm = condition_dim is not None
|
300 |
+
if condition_dim:
|
301 |
+
self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
|
302 |
+
else:
|
303 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
304 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
305 |
+
self.convnext = nn.ModuleList(
|
306 |
+
[
|
307 |
+
ConvNeXtBlock(
|
308 |
+
dim=dim,
|
309 |
+
intermediate_dim=intermediate_dim,
|
310 |
+
layer_scale_init_value=layer_scale_init_value,
|
311 |
+
condition_dim=condition_dim,
|
312 |
+
)
|
313 |
+
for _ in range(num_layers)
|
314 |
+
]
|
315 |
+
)
|
316 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
317 |
+
self.apply(self._init_weights)
|
318 |
+
|
319 |
+
def _init_weights(self, m):
|
320 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
321 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
322 |
+
nn.init.constant_(m.bias, 0)
|
323 |
+
|
324 |
+
def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
|
325 |
+
x = self.embed(x)
|
326 |
+
if self.adanorm:
|
327 |
+
assert condition is not None
|
328 |
+
x = self.norm(x.transpose(1, 2), condition)
|
329 |
+
else:
|
330 |
+
x = self.norm(x.transpose(1, 2))
|
331 |
+
x = x.transpose(1, 2)
|
332 |
+
for conv_block in self.convnext:
|
333 |
+
x = conv_block(x, condition)
|
334 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
335 |
+
return x
|
336 |
+
|
337 |
+
|
338 |
+
class VocosResNetBackbone(Backbone):
|
339 |
+
"""
|
340 |
+
Vocos backbone module built with ResBlocks.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
input_channels (int): Number of input features channels.
|
344 |
+
dim (int): Hidden dimension of the model.
|
345 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
346 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
347 |
+
"""
|
348 |
+
|
349 |
+
def __init__(
|
350 |
+
self,
|
351 |
+
input_channels,
|
352 |
+
dim,
|
353 |
+
num_blocks,
|
354 |
+
layer_scale_init_value=None,
|
355 |
+
):
|
356 |
+
super().__init__()
|
357 |
+
self.input_channels = input_channels
|
358 |
+
self.embed = weight_norm(
|
359 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
360 |
+
)
|
361 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
362 |
+
self.resnet = nn.Sequential(
|
363 |
+
*[
|
364 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
365 |
+
for _ in range(num_blocks)
|
366 |
+
]
|
367 |
+
)
|
368 |
+
|
369 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
370 |
+
x = self.embed(x)
|
371 |
+
x = self.resnet(x)
|
372 |
+
x = x.transpose(1, 2)
|
373 |
+
return x
|
sparktts/modules/encoder_decoder/feat_decoder.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from typing import List
|
21 |
+
|
22 |
+
from sparktts.modules.blocks.vocos import VocosBackbone
|
23 |
+
from sparktts.modules.blocks.samper import SamplingBlock
|
24 |
+
|
25 |
+
|
26 |
+
class Decoder(nn.Module):
|
27 |
+
"""Decoder module with convnext and upsampling blocks
|
28 |
+
|
29 |
+
Args:
|
30 |
+
sample_ratios (List[int]): sample ratios
|
31 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
input_channels: int,
|
37 |
+
vocos_dim: int,
|
38 |
+
vocos_intermediate_dim: int,
|
39 |
+
vocos_num_layers: int,
|
40 |
+
out_channels: int,
|
41 |
+
condition_dim: int = None,
|
42 |
+
sample_ratios: List[int] = [1, 1],
|
43 |
+
use_tanh_at_final: bool = False,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.linear_pre = nn.Linear(input_channels, vocos_dim)
|
48 |
+
modules = [
|
49 |
+
nn.Sequential(
|
50 |
+
SamplingBlock(
|
51 |
+
dim=vocos_dim,
|
52 |
+
groups=vocos_dim,
|
53 |
+
upsample_scale=ratio,
|
54 |
+
),
|
55 |
+
VocosBackbone(
|
56 |
+
input_channels=vocos_dim,
|
57 |
+
dim=vocos_dim,
|
58 |
+
intermediate_dim=vocos_intermediate_dim,
|
59 |
+
num_layers=2,
|
60 |
+
condition_dim=None,
|
61 |
+
),
|
62 |
+
)
|
63 |
+
for ratio in sample_ratios
|
64 |
+
]
|
65 |
+
|
66 |
+
self.downsample = nn.Sequential(*modules)
|
67 |
+
|
68 |
+
self.vocos_backbone = VocosBackbone(
|
69 |
+
input_channels=vocos_dim,
|
70 |
+
dim=vocos_dim,
|
71 |
+
intermediate_dim=vocos_intermediate_dim,
|
72 |
+
num_layers=vocos_num_layers,
|
73 |
+
condition_dim=condition_dim,
|
74 |
+
)
|
75 |
+
self.linear = nn.Linear(vocos_dim, out_channels)
|
76 |
+
self.use_tanh_at_final = use_tanh_at_final
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor = None):
|
79 |
+
"""encoder forward.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
86 |
+
"""
|
87 |
+
x = self.linear_pre(x.transpose(1, 2))
|
88 |
+
x = self.downsample(x).transpose(1, 2)
|
89 |
+
x = self.vocos_backbone(x, condition=c)
|
90 |
+
x = self.linear(x).transpose(1, 2)
|
91 |
+
if self.use_tanh_at_final:
|
92 |
+
x = torch.tanh(x)
|
93 |
+
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
# test
|
98 |
+
if __name__ == "__main__":
|
99 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
100 |
+
condition = torch.randn(8, 256)
|
101 |
+
decoder = Decoder(
|
102 |
+
input_channels=1024,
|
103 |
+
vocos_dim=384,
|
104 |
+
vocos_intermediate_dim=2048,
|
105 |
+
vocos_num_layers=12,
|
106 |
+
out_channels=256,
|
107 |
+
condition_dim=256,
|
108 |
+
sample_ratios=[2, 2],
|
109 |
+
)
|
110 |
+
output = decoder(test_input, condition)
|
111 |
+
print(output.shape) # torch.Size([8, 256, 200])
|
112 |
+
if output.shape == torch.Size([8, 256, 200]):
|
113 |
+
print("Decoder test passed")
|
114 |
+
else:
|
115 |
+
print("Decoder test failed")
|
sparktts/modules/encoder_decoder/feat_encoder.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from typing import List
|
21 |
+
|
22 |
+
from sparktts.modules.blocks.vocos import VocosBackbone
|
23 |
+
from sparktts.modules.blocks.samper import SamplingBlock
|
24 |
+
|
25 |
+
|
26 |
+
class Encoder(nn.Module):
|
27 |
+
"""Encoder module with convnext and downsampling blocks"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
input_channels: int,
|
32 |
+
vocos_dim: int,
|
33 |
+
vocos_intermediate_dim: int,
|
34 |
+
vocos_num_layers: int,
|
35 |
+
out_channels: int,
|
36 |
+
sample_ratios: List[int] = [1, 1],
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
"""
|
40 |
+
Encoder module with VocosBackbone and sampling blocks.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
sample_ratios (List[int]): sample ratios
|
44 |
+
example: [2, 2] means downsample by 2x and then upsample by 2x
|
45 |
+
"""
|
46 |
+
self.encoder = VocosBackbone(
|
47 |
+
input_channels=input_channels,
|
48 |
+
dim=vocos_dim,
|
49 |
+
intermediate_dim=vocos_intermediate_dim,
|
50 |
+
num_layers=vocos_num_layers,
|
51 |
+
condition_dim=None,
|
52 |
+
)
|
53 |
+
|
54 |
+
modules = [
|
55 |
+
nn.Sequential(
|
56 |
+
SamplingBlock(
|
57 |
+
dim=vocos_dim,
|
58 |
+
groups=vocos_dim,
|
59 |
+
downsample_scale=ratio,
|
60 |
+
),
|
61 |
+
VocosBackbone(
|
62 |
+
input_channels=vocos_dim,
|
63 |
+
dim=vocos_dim,
|
64 |
+
intermediate_dim=vocos_intermediate_dim,
|
65 |
+
num_layers=2,
|
66 |
+
condition_dim=None,
|
67 |
+
),
|
68 |
+
)
|
69 |
+
for ratio in sample_ratios
|
70 |
+
]
|
71 |
+
|
72 |
+
self.downsample = nn.Sequential(*modules)
|
73 |
+
|
74 |
+
self.project = nn.Linear(vocos_dim, out_channels)
|
75 |
+
|
76 |
+
def forward(self, x: torch.Tensor, *args):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
x (torch.Tensor): (batch_size, input_channels, length)
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
x (torch.Tensor): (batch_size, encode_channels, length)
|
83 |
+
"""
|
84 |
+
x = self.encoder(x)
|
85 |
+
x = self.downsample(x)
|
86 |
+
x = self.project(x)
|
87 |
+
return x.transpose(1, 2)
|
88 |
+
|
89 |
+
|
90 |
+
# test
|
91 |
+
if __name__ == "__main__":
|
92 |
+
test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
|
93 |
+
encoder = Encoder(
|
94 |
+
input_channels=1024,
|
95 |
+
vocos_dim=384,
|
96 |
+
vocos_intermediate_dim=2048,
|
97 |
+
vocos_num_layers=12,
|
98 |
+
out_channels=256,
|
99 |
+
sample_ratios=[2, 2],
|
100 |
+
)
|
101 |
+
|
102 |
+
output = encoder(test_input)
|
103 |
+
print(output.shape) # torch.Size([8, 256, 12])
|
104 |
+
if output.shape == torch.Size([8, 256, 12]):
|
105 |
+
print("test successful")
|
sparktts/modules/encoder_decoder/wave_generator.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Xinsheng Wang ([email protected])
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
|
16 |
+
|
17 |
+
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from sparktts.modules.blocks.layers import (
|
21 |
+
Snake1d,
|
22 |
+
WNConv1d,
|
23 |
+
ResidualUnit,
|
24 |
+
WNConvTranspose1d,
|
25 |
+
init_weights,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class DecoderBlock(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
input_dim: int = 16,
|
33 |
+
output_dim: int = 8,
|
34 |
+
kernel_size: int = 2,
|
35 |
+
stride: int = 1,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
self.block = nn.Sequential(
|
39 |
+
Snake1d(input_dim),
|
40 |
+
WNConvTranspose1d(
|
41 |
+
input_dim,
|
42 |
+
output_dim,
|
43 |
+
kernel_size=kernel_size,
|
44 |
+
stride=stride,
|
45 |
+
padding=(kernel_size - stride) // 2,
|
46 |
+
),
|
47 |
+
ResidualUnit(output_dim, dilation=1),
|
48 |
+
ResidualUnit(output_dim, dilation=3),
|
49 |
+
ResidualUnit(output_dim, dilation=9),
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
return self.block(x)
|
54 |
+
|
55 |
+
|
56 |
+
class WaveGenerator(nn.Module):
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
input_channel,
|
60 |
+
channels,
|
61 |
+
rates,
|
62 |
+
kernel_sizes,
|
63 |
+
d_out: int = 1,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
# Add first conv layer
|
68 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
69 |
+
|
70 |
+
# Add upsampling + MRF blocks
|
71 |
+
for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
|
72 |
+
input_dim = channels // 2**i
|
73 |
+
output_dim = channels // 2 ** (i + 1)
|
74 |
+
layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
|
75 |
+
|
76 |
+
# Add final conv layer
|
77 |
+
layers += [
|
78 |
+
Snake1d(output_dim),
|
79 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
80 |
+
nn.Tanh(),
|
81 |
+
]
|
82 |
+
|
83 |
+
self.model = nn.Sequential(*layers)
|
84 |
+
|
85 |
+
self.apply(init_weights)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
return self.model(x)
|
sparktts/modules/fsq/finite_scalar_quantization.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
3 |
+
Code adapted from Jax version in Appendix A.1
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
from functools import wraps, partial
|
8 |
+
from contextlib import nullcontext
|
9 |
+
from typing import List, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import Module
|
14 |
+
from torch import Tensor, int32
|
15 |
+
from torch.amp import autocast
|
16 |
+
|
17 |
+
from einops import rearrange, pack, unpack
|
18 |
+
|
19 |
+
# helper functions
|
20 |
+
|
21 |
+
|
22 |
+
def exists(v):
|
23 |
+
return v is not None
|
24 |
+
|
25 |
+
|
26 |
+
def default(*args):
|
27 |
+
for arg in args:
|
28 |
+
if exists(arg):
|
29 |
+
return arg
|
30 |
+
return None
|
31 |
+
|
32 |
+
|
33 |
+
def maybe(fn):
|
34 |
+
@wraps(fn)
|
35 |
+
def inner(x, *args, **kwargs):
|
36 |
+
if not exists(x):
|
37 |
+
return x
|
38 |
+
return fn(x, *args, **kwargs)
|
39 |
+
|
40 |
+
return inner
|
41 |
+
|
42 |
+
|
43 |
+
def pack_one(t, pattern):
|
44 |
+
return pack([t], pattern)
|
45 |
+
|
46 |
+
|
47 |
+
def unpack_one(t, ps, pattern):
|
48 |
+
return unpack(t, ps, pattern)[0]
|
49 |
+
|
50 |
+
|
51 |
+
# tensor helpers
|
52 |
+
|
53 |
+
|
54 |
+
def round_ste(z: Tensor) -> Tensor:
|
55 |
+
"""Round with straight through gradients."""
|
56 |
+
zhat = z.round()
|
57 |
+
return z + (zhat - z).detach()
|
58 |
+
|
59 |
+
|
60 |
+
# main class
|
61 |
+
|
62 |
+
|
63 |
+
class FSQ(Module):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
levels: List[int],
|
67 |
+
dim: int | None = None,
|
68 |
+
num_codebooks=1,
|
69 |
+
keep_num_codebooks_dim: bool | None = None,
|
70 |
+
scale: float | None = None,
|
71 |
+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
72 |
+
channel_first: bool = False,
|
73 |
+
projection_has_bias: bool = True,
|
74 |
+
return_indices=True,
|
75 |
+
force_quantization_f32=True,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
_levels = torch.tensor(levels, dtype=int32)
|
79 |
+
self.register_buffer("_levels", _levels, persistent=False)
|
80 |
+
|
81 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
82 |
+
self.register_buffer("_basis", _basis, persistent=False)
|
83 |
+
|
84 |
+
self.scale = scale
|
85 |
+
|
86 |
+
codebook_dim = len(levels)
|
87 |
+
self.codebook_dim = codebook_dim
|
88 |
+
|
89 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
90 |
+
self.num_codebooks = num_codebooks
|
91 |
+
self.effective_codebook_dim = effective_codebook_dim
|
92 |
+
|
93 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
94 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
95 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
96 |
+
|
97 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
98 |
+
|
99 |
+
self.channel_first = channel_first
|
100 |
+
|
101 |
+
has_projections = self.dim != effective_codebook_dim
|
102 |
+
self.project_in = (
|
103 |
+
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
|
104 |
+
if has_projections
|
105 |
+
else nn.Identity()
|
106 |
+
)
|
107 |
+
self.project_out = (
|
108 |
+
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
|
109 |
+
if has_projections
|
110 |
+
else nn.Identity()
|
111 |
+
)
|
112 |
+
|
113 |
+
self.has_projections = has_projections
|
114 |
+
|
115 |
+
self.return_indices = return_indices
|
116 |
+
if return_indices:
|
117 |
+
self.codebook_size = self._levels.prod().item()
|
118 |
+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
119 |
+
self.register_buffer(
|
120 |
+
"implicit_codebook", implicit_codebook, persistent=False
|
121 |
+
)
|
122 |
+
|
123 |
+
self.allowed_dtypes = allowed_dtypes
|
124 |
+
self.force_quantization_f32 = force_quantization_f32
|
125 |
+
|
126 |
+
def bound(self, z, eps: float = 1e-3):
|
127 |
+
"""Bound `z`, an array of shape (..., d)."""
|
128 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
129 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
130 |
+
shift = (offset / half_l).atanh()
|
131 |
+
return (z + shift).tanh() * half_l - offset
|
132 |
+
|
133 |
+
def quantize(self, z):
|
134 |
+
"""Quantizes z, returns quantized zhat, same shape as z."""
|
135 |
+
quantized = round_ste(self.bound(z))
|
136 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
137 |
+
return quantized / half_width
|
138 |
+
|
139 |
+
def _scale_and_shift(self, zhat_normalized):
|
140 |
+
half_width = self._levels // 2
|
141 |
+
return (zhat_normalized * half_width) + half_width
|
142 |
+
|
143 |
+
def _scale_and_shift_inverse(self, zhat):
|
144 |
+
half_width = self._levels // 2
|
145 |
+
return (zhat - half_width) / half_width
|
146 |
+
|
147 |
+
def _indices_to_codes(self, indices):
|
148 |
+
level_indices = self.indices_to_level_indices(indices)
|
149 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
150 |
+
return codes
|
151 |
+
|
152 |
+
def codes_to_indices(self, zhat):
|
153 |
+
"""Converts a `code` to an index in the codebook."""
|
154 |
+
assert zhat.shape[-1] == self.codebook_dim
|
155 |
+
zhat = self._scale_and_shift(zhat)
|
156 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
157 |
+
|
158 |
+
def indices_to_level_indices(self, indices):
|
159 |
+
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
|
160 |
+
indices = rearrange(indices, "... -> ... 1")
|
161 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
162 |
+
return codes_non_centered
|
163 |
+
|
164 |
+
def indices_to_codes(self, indices):
|
165 |
+
"""Inverse of `codes_to_indices`."""
|
166 |
+
assert exists(indices)
|
167 |
+
|
168 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
169 |
+
|
170 |
+
codes = self._indices_to_codes(indices)
|
171 |
+
|
172 |
+
if self.keep_num_codebooks_dim:
|
173 |
+
codes = rearrange(codes, "... c d -> ... (c d)")
|
174 |
+
|
175 |
+
codes = self.project_out(codes)
|
176 |
+
|
177 |
+
if is_img_or_video or self.channel_first:
|
178 |
+
codes = rearrange(codes, "b ... d -> b d ...")
|
179 |
+
|
180 |
+
return codes
|
181 |
+
|
182 |
+
def forward(self, z):
|
183 |
+
"""
|
184 |
+
einstein notation
|
185 |
+
b - batch
|
186 |
+
n - sequence (or flattened spatial dimensions)
|
187 |
+
d - feature dimension
|
188 |
+
c - number of codebook dim
|
189 |
+
"""
|
190 |
+
|
191 |
+
is_img_or_video = z.ndim >= 4
|
192 |
+
need_move_channel_last = is_img_or_video or self.channel_first
|
193 |
+
|
194 |
+
# standardize image or video into (batch, seq, dimension)
|
195 |
+
|
196 |
+
if need_move_channel_last:
|
197 |
+
z = rearrange(z, "b d ... -> b ... d")
|
198 |
+
z, ps = pack_one(z, "b * d")
|
199 |
+
|
200 |
+
assert (
|
201 |
+
z.shape[-1] == self.dim
|
202 |
+
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
203 |
+
|
204 |
+
z = self.project_in(z)
|
205 |
+
|
206 |
+
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
207 |
+
|
208 |
+
# whether to force quantization step to be full precision or not
|
209 |
+
|
210 |
+
force_f32 = self.force_quantization_f32
|
211 |
+
quantization_context = (
|
212 |
+
partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
|
213 |
+
)
|
214 |
+
|
215 |
+
with quantization_context():
|
216 |
+
orig_dtype = z.dtype
|
217 |
+
|
218 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
219 |
+
z = z.float()
|
220 |
+
|
221 |
+
codes = self.quantize(z)
|
222 |
+
|
223 |
+
# returning indices could be optional
|
224 |
+
|
225 |
+
indices = None
|
226 |
+
|
227 |
+
if self.return_indices:
|
228 |
+
indices = self.codes_to_indices(codes)
|
229 |
+
|
230 |
+
codes = rearrange(codes, "b n c d -> b n (c d)")
|
231 |
+
|
232 |
+
codes = codes.type(orig_dtype)
|
233 |
+
|
234 |
+
# project out
|
235 |
+
|
236 |
+
out = self.project_out(codes)
|
237 |
+
|
238 |
+
# reconstitute image or video dimensions
|
239 |
+
|
240 |
+
if need_move_channel_last:
|
241 |
+
out = unpack_one(out, ps, "b * d")
|
242 |
+
out = rearrange(out, "b ... d -> b d ...")
|
243 |
+
|
244 |
+
indices = maybe(unpack_one)(indices, ps, "b * c")
|
245 |
+
|
246 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
247 |
+
indices = maybe(rearrange)(indices, "... 1 -> ...")
|
248 |
+
|
249 |
+
# return quantized output and indices
|
250 |
+
|
251 |
+
return out, indices
|
sparktts/modules/fsq/residual_fsq.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.distributed as dist
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import Module
|
9 |
+
from torch.amp import autocast
|
10 |
+
from einx import get_at
|
11 |
+
from einops import rearrange, reduce, pack, unpack
|
12 |
+
|
13 |
+
from sparktts.modules.fsq.finite_scalar_quantization import FSQ
|
14 |
+
|
15 |
+
|
16 |
+
def exists(val):
|
17 |
+
return val is not None
|
18 |
+
|
19 |
+
|
20 |
+
def first(l):
|
21 |
+
return l[0]
|
22 |
+
|
23 |
+
|
24 |
+
def default(val, d):
|
25 |
+
return val if exists(val) else d
|
26 |
+
|
27 |
+
|
28 |
+
def round_up_multiple(num, mult):
|
29 |
+
return ceil(num / mult) * mult
|
30 |
+
|
31 |
+
|
32 |
+
# distributed helpers
|
33 |
+
|
34 |
+
|
35 |
+
def is_distributed():
|
36 |
+
return dist.is_initialized() and dist.get_world_size() > 1
|
37 |
+
|
38 |
+
|
39 |
+
def get_maybe_sync_seed(device, max_size=10_000):
|
40 |
+
rand_int = torch.randint(0, max_size, (), device=device)
|
41 |
+
|
42 |
+
if is_distributed():
|
43 |
+
dist.all_reduce(rand_int)
|
44 |
+
|
45 |
+
return rand_int.item()
|
46 |
+
|
47 |
+
|
48 |
+
class ResidualFSQ(Module):
|
49 |
+
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
*,
|
54 |
+
levels: List[int],
|
55 |
+
num_quantizers,
|
56 |
+
dim=None,
|
57 |
+
is_channel_first=False,
|
58 |
+
quantize_dropout=False,
|
59 |
+
quantize_dropout_cutoff_index=0,
|
60 |
+
quantize_dropout_multiple_of=1,
|
61 |
+
**kwargs,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
codebook_dim = len(levels)
|
65 |
+
dim = default(dim, codebook_dim)
|
66 |
+
|
67 |
+
requires_projection = codebook_dim != dim
|
68 |
+
self.project_in = (
|
69 |
+
nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
|
70 |
+
)
|
71 |
+
self.project_out = (
|
72 |
+
nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
|
73 |
+
)
|
74 |
+
self.has_projections = requires_projection
|
75 |
+
|
76 |
+
self.is_channel_first = is_channel_first
|
77 |
+
self.num_quantizers = num_quantizers
|
78 |
+
|
79 |
+
self.levels = levels
|
80 |
+
self.layers = nn.ModuleList([])
|
81 |
+
|
82 |
+
levels_tensor = torch.Tensor(levels)
|
83 |
+
|
84 |
+
scales = []
|
85 |
+
|
86 |
+
for ind in range(num_quantizers):
|
87 |
+
scales.append((levels_tensor - 1) ** -ind)
|
88 |
+
|
89 |
+
fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
|
90 |
+
|
91 |
+
self.layers.append(fsq)
|
92 |
+
|
93 |
+
assert all([not fsq.has_projections for fsq in self.layers])
|
94 |
+
|
95 |
+
self.codebook_size = self.layers[0].codebook_size
|
96 |
+
|
97 |
+
self.register_buffer("scales", torch.stack(scales), persistent=False)
|
98 |
+
|
99 |
+
self.quantize_dropout = quantize_dropout and num_quantizers > 1
|
100 |
+
|
101 |
+
assert quantize_dropout_cutoff_index >= 0
|
102 |
+
|
103 |
+
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
104 |
+
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
|
105 |
+
|
106 |
+
@property
|
107 |
+
def codebooks(self):
|
108 |
+
codebooks = [layer.implicit_codebook for layer in self.layers]
|
109 |
+
codebooks = torch.stack(codebooks, dim=0)
|
110 |
+
return codebooks
|
111 |
+
|
112 |
+
def get_codes_from_indices(self, indices):
|
113 |
+
|
114 |
+
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
115 |
+
|
116 |
+
# may also receive indices in the shape of 'b h w q' (accept_image_fmap)
|
117 |
+
|
118 |
+
indices, ps = pack([indices], "b * q")
|
119 |
+
|
120 |
+
# because of quantize dropout, one can pass in indices that are coarse
|
121 |
+
# and the network should be able to reconstruct
|
122 |
+
|
123 |
+
if quantize_dim < self.num_quantizers:
|
124 |
+
assert (
|
125 |
+
self.quantize_dropout > 0.0
|
126 |
+
), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
|
127 |
+
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
|
128 |
+
|
129 |
+
# take care of quantizer dropout
|
130 |
+
|
131 |
+
mask = indices == -1
|
132 |
+
indices = indices.masked_fill(
|
133 |
+
mask, 0
|
134 |
+
) # have it fetch a dummy code to be masked out later
|
135 |
+
|
136 |
+
all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
|
137 |
+
|
138 |
+
# mask out any codes that were dropout-ed
|
139 |
+
|
140 |
+
all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
|
141 |
+
|
142 |
+
# scale the codes
|
143 |
+
|
144 |
+
scales = rearrange(self.scales, "q d -> q 1 1 d")
|
145 |
+
all_codes = all_codes * scales
|
146 |
+
|
147 |
+
# if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
|
148 |
+
|
149 |
+
(all_codes,) = unpack(all_codes, ps, "q b * d")
|
150 |
+
|
151 |
+
return all_codes
|
152 |
+
|
153 |
+
def get_output_from_indices(self, indices):
|
154 |
+
codes = self.get_codes_from_indices(indices)
|
155 |
+
codes_summed = reduce(codes, "q ... -> ...", "sum")
|
156 |
+
return self.project_out(codes_summed)
|
157 |
+
|
158 |
+
def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
|
159 |
+
num_quant, quant_dropout_multiple_of, device = (
|
160 |
+
self.num_quantizers,
|
161 |
+
self.quantize_dropout_multiple_of,
|
162 |
+
x.device,
|
163 |
+
)
|
164 |
+
|
165 |
+
# handle channel first
|
166 |
+
|
167 |
+
if self.is_channel_first:
|
168 |
+
x = rearrange(x, "b d ... -> b ... d")
|
169 |
+
x, ps = pack([x], "b * d")
|
170 |
+
|
171 |
+
# maybe project in
|
172 |
+
|
173 |
+
x = self.project_in(x)
|
174 |
+
|
175 |
+
quantized_out = 0.0
|
176 |
+
residual = x
|
177 |
+
|
178 |
+
all_indices = []
|
179 |
+
|
180 |
+
should_quantize_dropout = self.training and self.quantize_dropout
|
181 |
+
|
182 |
+
# sample a layer index at which to dropout further residual quantization
|
183 |
+
# also prepare null indices
|
184 |
+
|
185 |
+
if should_quantize_dropout:
|
186 |
+
|
187 |
+
# check if seed is manually passed in
|
188 |
+
|
189 |
+
if not exists(rand_quantize_dropout_fixed_seed):
|
190 |
+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
|
191 |
+
|
192 |
+
rand = random.Random(rand_quantize_dropout_fixed_seed)
|
193 |
+
|
194 |
+
rand_quantize_dropout_index = rand.randrange(
|
195 |
+
self.quantize_dropout_cutoff_index, num_quant
|
196 |
+
)
|
197 |
+
|
198 |
+
if quant_dropout_multiple_of != 1:
|
199 |
+
rand_quantize_dropout_index = (
|
200 |
+
round_up_multiple(
|
201 |
+
rand_quantize_dropout_index + 1, quant_dropout_multiple_of
|
202 |
+
)
|
203 |
+
- 1
|
204 |
+
)
|
205 |
+
|
206 |
+
null_indices = torch.full(
|
207 |
+
x.shape[:2], -1.0, device=device, dtype=torch.long
|
208 |
+
)
|
209 |
+
|
210 |
+
# go through the layers
|
211 |
+
|
212 |
+
with autocast("cuda", enabled=False):
|
213 |
+
for quantizer_index, (layer, scale) in enumerate(
|
214 |
+
zip(self.layers, self.scales)
|
215 |
+
):
|
216 |
+
|
217 |
+
if (
|
218 |
+
should_quantize_dropout
|
219 |
+
and quantizer_index > rand_quantize_dropout_index
|
220 |
+
):
|
221 |
+
all_indices.append(null_indices)
|
222 |
+
continue
|
223 |
+
|
224 |
+
quantized, indices = layer(residual / scale)
|
225 |
+
|
226 |
+
quantized = quantized * scale
|
227 |
+
|
228 |
+
residual = residual - quantized.detach()
|
229 |
+
quantized_out = quantized_out + quantized
|
230 |
+
|
231 |
+
all_indices.append(indices)
|
232 |
+
|
233 |
+
# project out, if needed
|
234 |
+
|
235 |
+
quantized_out = self.project_out(quantized_out)
|
236 |
+
|
237 |
+
# stack all indices
|
238 |
+
|
239 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
240 |
+
|
241 |
+
# channel first out
|
242 |
+
|
243 |
+
if self.is_channel_first:
|
244 |
+
(quantized_out,) = unpack(quantized_out, ps, "b * d")
|
245 |
+
(all_indices,) = unpack(all_indices, ps, "b * d")
|
246 |
+
|
247 |
+
quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
|
248 |
+
all_indices = rearrange(all_indices, "b ... d -> b d ...")
|
249 |
+
|
250 |
+
# return
|
251 |
+
|
252 |
+
ret = (quantized_out, all_indices)
|
253 |
+
|
254 |
+
if not return_all_codes:
|
255 |
+
return ret
|
256 |
+
|
257 |
+
# whether to return all codes from all codebooks across layers
|
258 |
+
|
259 |
+
all_codes = self.get_codes_from_indices(all_indices)
|
260 |
+
|
261 |
+
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
262 |
+
|
263 |
+
return (*ret, all_codes)
|
264 |
+
|
265 |
+
|
266 |
+
# grouped residual fsq
|
267 |
+
|
268 |
+
|
269 |
+
class GroupedResidualFSQ(Module):
|
270 |
+
def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
|
271 |
+
super().__init__()
|
272 |
+
self.dim = dim
|
273 |
+
self.groups = groups
|
274 |
+
assert (dim % groups) == 0
|
275 |
+
dim_per_group = dim // groups
|
276 |
+
|
277 |
+
self.accept_image_fmap = accept_image_fmap
|
278 |
+
|
279 |
+
self.rvqs = nn.ModuleList([])
|
280 |
+
|
281 |
+
for _ in range(groups):
|
282 |
+
self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
|
283 |
+
|
284 |
+
self.codebook_size = self.rvqs[0].codebook_size
|
285 |
+
|
286 |
+
@property
|
287 |
+
def codebooks(self):
|
288 |
+
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
|
289 |
+
|
290 |
+
@property
|
291 |
+
def split_dim(self):
|
292 |
+
return 1 if self.accept_image_fmap else -1
|
293 |
+
|
294 |
+
def get_codes_from_indices(self, indices):
|
295 |
+
codes = tuple(
|
296 |
+
rvq.get_codes_from_indices(chunk_indices)
|
297 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
298 |
+
)
|
299 |
+
return torch.stack(codes)
|
300 |
+
|
301 |
+
def get_output_from_indices(self, indices):
|
302 |
+
outputs = tuple(
|
303 |
+
rvq.get_output_from_indices(chunk_indices)
|
304 |
+
for rvq, chunk_indices in zip(self.rvqs, indices)
|
305 |
+
)
|
306 |
+
return torch.cat(outputs, dim=self.split_dim)
|
307 |
+
|
308 |
+
def forward(self, x, return_all_codes=False):
|
309 |
+
shape, split_dim, device = x.shape, self.split_dim, x.device
|
310 |
+
assert shape[split_dim] == self.dim
|
311 |
+
|
312 |
+
# split the feature dimension into groups
|
313 |
+
|
314 |
+
x = x.chunk(self.groups, dim=split_dim)
|
315 |
+
|
316 |
+
forward_kwargs = dict(
|
317 |
+
return_all_codes=return_all_codes,
|
318 |
+
rand_quantize_dropout_fixed_seed=(
|
319 |
+
get_maybe_sync_seed(device) if self.training else None
|
320 |
+
),
|
321 |
+
)
|
322 |
+
|
323 |
+
# invoke residual vq on each group
|
324 |
+
|
325 |
+
out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
|
326 |
+
out = tuple(zip(*out))
|
327 |
+
|
328 |
+
# otherwise, get all the zipped outputs and combine them
|
329 |
+
|
330 |
+
quantized, all_indices, *maybe_all_codes = out
|
331 |
+
|
332 |
+
quantized = torch.cat(quantized, dim=split_dim)
|
333 |
+
all_indices = torch.stack(all_indices)
|
334 |
+
|
335 |
+
ret = (quantized, all_indices, *maybe_all_codes)
|
336 |
+
return ret
|
337 |
+
|
338 |
+
|
339 |
+
if __name__ == "__main__":
|
340 |
+
model = ResidualFSQ(
|
341 |
+
levels=[4, 4, 4, 4, 4, 4],
|
342 |
+
num_quantizers=1,
|
343 |
+
dim=30,
|
344 |
+
is_channel_first=True,
|
345 |
+
quantize_dropout=False,
|
346 |
+
)
|
347 |
+
x = torch.randn(2, 30, 10)
|
348 |
+
quantize, embed_ind = model(x)
|
349 |
+
|
350 |
+
emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
|
351 |
+
|
352 |
+
print(quantize == emb_from_ind.transpose(1, 2))
|
353 |
+
|
354 |
+
print("quantize shape", quantize.shape)
|
355 |
+
print("embed_ind", embed_ind)
|
sparktts/modules/speaker/ecapa_tdnn.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Zhengyang Chen ([email protected])
|
2 |
+
# 2022 Hongji Wang ([email protected])
|
3 |
+
# 2023 Bing Han ([email protected])
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
""" This implementation is adapted from github repo:
|
18 |
+
https://github.com/lawlict/ECAPA-TDNN.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
import sparktts.modules.speaker.pooling_layers as pooling_layers
|
26 |
+
|
27 |
+
|
28 |
+
class Res2Conv1dReluBn(nn.Module):
|
29 |
+
"""
|
30 |
+
in_channels == out_channels == channels
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
channels,
|
36 |
+
kernel_size=1,
|
37 |
+
stride=1,
|
38 |
+
padding=0,
|
39 |
+
dilation=1,
|
40 |
+
bias=True,
|
41 |
+
scale=4,
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
45 |
+
self.scale = scale
|
46 |
+
self.width = channels // scale
|
47 |
+
self.nums = scale if scale == 1 else scale - 1
|
48 |
+
|
49 |
+
self.convs = []
|
50 |
+
self.bns = []
|
51 |
+
for i in range(self.nums):
|
52 |
+
self.convs.append(
|
53 |
+
nn.Conv1d(
|
54 |
+
self.width,
|
55 |
+
self.width,
|
56 |
+
kernel_size,
|
57 |
+
stride,
|
58 |
+
padding,
|
59 |
+
dilation,
|
60 |
+
bias=bias,
|
61 |
+
)
|
62 |
+
)
|
63 |
+
self.bns.append(nn.BatchNorm1d(self.width))
|
64 |
+
self.convs = nn.ModuleList(self.convs)
|
65 |
+
self.bns = nn.ModuleList(self.bns)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
out = []
|
69 |
+
spx = torch.split(x, self.width, 1)
|
70 |
+
sp = spx[0]
|
71 |
+
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
|
72 |
+
# Order: conv -> relu -> bn
|
73 |
+
if i >= 1:
|
74 |
+
sp = sp + spx[i]
|
75 |
+
sp = conv(sp)
|
76 |
+
sp = bn(F.relu(sp))
|
77 |
+
out.append(sp)
|
78 |
+
if self.scale != 1:
|
79 |
+
out.append(spx[self.nums])
|
80 |
+
out = torch.cat(out, dim=1)
|
81 |
+
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
""" Conv1d + BatchNorm1d + ReLU
|
86 |
+
"""
|
87 |
+
|
88 |
+
|
89 |
+
class Conv1dReluBn(nn.Module):
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
in_channels,
|
94 |
+
out_channels,
|
95 |
+
kernel_size=1,
|
96 |
+
stride=1,
|
97 |
+
padding=0,
|
98 |
+
dilation=1,
|
99 |
+
bias=True,
|
100 |
+
):
|
101 |
+
super().__init__()
|
102 |
+
self.conv = nn.Conv1d(
|
103 |
+
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
104 |
+
)
|
105 |
+
self.bn = nn.BatchNorm1d(out_channels)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return self.bn(F.relu(self.conv(x)))
|
109 |
+
|
110 |
+
|
111 |
+
""" The SE connection of 1D case.
|
112 |
+
"""
|
113 |
+
|
114 |
+
|
115 |
+
class SE_Connect(nn.Module):
|
116 |
+
|
117 |
+
def __init__(self, channels, se_bottleneck_dim=128):
|
118 |
+
super().__init__()
|
119 |
+
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
120 |
+
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
out = x.mean(dim=2)
|
124 |
+
out = F.relu(self.linear1(out))
|
125 |
+
out = torch.sigmoid(self.linear2(out))
|
126 |
+
out = x * out.unsqueeze(2)
|
127 |
+
|
128 |
+
return out
|
129 |
+
|
130 |
+
|
131 |
+
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
132 |
+
"""
|
133 |
+
|
134 |
+
|
135 |
+
class SE_Res2Block(nn.Module):
|
136 |
+
|
137 |
+
def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
|
138 |
+
super().__init__()
|
139 |
+
self.se_res2block = nn.Sequential(
|
140 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
141 |
+
Res2Conv1dReluBn(
|
142 |
+
channels, kernel_size, stride, padding, dilation, scale=scale
|
143 |
+
),
|
144 |
+
Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
|
145 |
+
SE_Connect(channels),
|
146 |
+
)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
return x + self.se_res2block(x)
|
150 |
+
|
151 |
+
|
152 |
+
class ECAPA_TDNN(nn.Module):
|
153 |
+
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
channels=512,
|
157 |
+
feat_dim=80,
|
158 |
+
embed_dim=192,
|
159 |
+
pooling_func="ASTP",
|
160 |
+
global_context_att=False,
|
161 |
+
emb_bn=False,
|
162 |
+
):
|
163 |
+
super().__init__()
|
164 |
+
|
165 |
+
self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
|
166 |
+
self.layer2 = SE_Res2Block(
|
167 |
+
channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
|
168 |
+
)
|
169 |
+
self.layer3 = SE_Res2Block(
|
170 |
+
channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
|
171 |
+
)
|
172 |
+
self.layer4 = SE_Res2Block(
|
173 |
+
channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
|
174 |
+
)
|
175 |
+
|
176 |
+
cat_channels = channels * 3
|
177 |
+
out_channels = 512 * 3
|
178 |
+
self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
|
179 |
+
self.pool = getattr(pooling_layers, pooling_func)(
|
180 |
+
in_dim=out_channels, global_context_att=global_context_att
|
181 |
+
)
|
182 |
+
self.pool_out_dim = self.pool.get_out_dim()
|
183 |
+
self.bn = nn.BatchNorm1d(self.pool_out_dim)
|
184 |
+
self.linear = nn.Linear(self.pool_out_dim, embed_dim)
|
185 |
+
self.emb_bn = emb_bn
|
186 |
+
if emb_bn: # better in SSL for SV
|
187 |
+
self.bn2 = nn.BatchNorm1d(embed_dim)
|
188 |
+
else:
|
189 |
+
self.bn2 = nn.Identity()
|
190 |
+
|
191 |
+
def forward(self, x, return_latent=False):
|
192 |
+
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
|
193 |
+
|
194 |
+
out1 = self.layer1(x)
|
195 |
+
out2 = self.layer2(out1)
|
196 |
+
out3 = self.layer3(out2)
|
197 |
+
out4 = self.layer4(out3)
|
198 |
+
|
199 |
+
out = torch.cat([out2, out3, out4], dim=1)
|
200 |
+
latent = F.relu(self.conv(out))
|
201 |
+
out = self.bn(self.pool(latent))
|
202 |
+
out = self.linear(out)
|
203 |
+
if self.emb_bn:
|
204 |
+
out = self.bn2(out)
|
205 |
+
|
206 |
+
if return_latent:
|
207 |
+
return out, latent
|
208 |
+
return out
|
209 |
+
|
210 |
+
|
211 |
+
def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
212 |
+
return ECAPA_TDNN(
|
213 |
+
channels=1024,
|
214 |
+
feat_dim=feat_dim,
|
215 |
+
embed_dim=embed_dim,
|
216 |
+
pooling_func=pooling_func,
|
217 |
+
emb_bn=emb_bn,
|
218 |
+
)
|
219 |
+
|
220 |
+
|
221 |
+
def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
222 |
+
return ECAPA_TDNN(
|
223 |
+
channels=1024,
|
224 |
+
feat_dim=feat_dim,
|
225 |
+
embed_dim=embed_dim,
|
226 |
+
pooling_func=pooling_func,
|
227 |
+
global_context_att=True,
|
228 |
+
emb_bn=emb_bn,
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
233 |
+
return ECAPA_TDNN(
|
234 |
+
channels=512,
|
235 |
+
feat_dim=feat_dim,
|
236 |
+
embed_dim=embed_dim,
|
237 |
+
pooling_func=pooling_func,
|
238 |
+
emb_bn=emb_bn,
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
|
243 |
+
return ECAPA_TDNN(
|
244 |
+
channels=512,
|
245 |
+
feat_dim=feat_dim,
|
246 |
+
embed_dim=embed_dim,
|
247 |
+
pooling_func=pooling_func,
|
248 |
+
global_context_att=True,
|
249 |
+
emb_bn=emb_bn,
|
250 |
+
)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
x = torch.zeros(1, 200, 100)
|
255 |
+
model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
|
256 |
+
model.eval()
|
257 |
+
out, latent = model(x, True)
|
258 |
+
print(out.shape)
|
259 |
+
print(latent.shape)
|
260 |
+
|
261 |
+
num_params = sum(param.numel() for param in model.parameters())
|
262 |
+
print("{} M".format(num_params / 1e6))
|
263 |
+
|
264 |
+
# from thop import profile
|
265 |
+
# x_np = torch.randn(1, 200, 80)
|
266 |
+
# flops, params = profile(model, inputs=(x_np, ))
|
267 |
+
# print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
|
sparktts/modules/speaker/perceiver_encoder.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
17 |
+
|
18 |
+
from collections import namedtuple
|
19 |
+
from functools import wraps
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from einops import rearrange, repeat
|
24 |
+
from einops.layers.torch import Rearrange
|
25 |
+
from packaging import version
|
26 |
+
from torch import einsum, nn
|
27 |
+
|
28 |
+
|
29 |
+
def exists(val):
|
30 |
+
return val is not None
|
31 |
+
|
32 |
+
|
33 |
+
def once(fn):
|
34 |
+
called = False
|
35 |
+
|
36 |
+
@wraps(fn)
|
37 |
+
def inner(x):
|
38 |
+
nonlocal called
|
39 |
+
if called:
|
40 |
+
return
|
41 |
+
called = True
|
42 |
+
return fn(x)
|
43 |
+
|
44 |
+
return inner
|
45 |
+
|
46 |
+
|
47 |
+
print_once = once(print)
|
48 |
+
|
49 |
+
# main class
|
50 |
+
|
51 |
+
|
52 |
+
class Attend(nn.Module):
|
53 |
+
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
54 |
+
super().__init__()
|
55 |
+
self.dropout = dropout
|
56 |
+
self.attn_dropout = nn.Dropout(dropout)
|
57 |
+
|
58 |
+
self.causal = causal
|
59 |
+
self.register_buffer("mask", None, persistent=False)
|
60 |
+
|
61 |
+
self.use_flash = use_flash
|
62 |
+
assert not (
|
63 |
+
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
64 |
+
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
65 |
+
|
66 |
+
# determine efficient attention configs for cuda and cpu
|
67 |
+
self.config = namedtuple(
|
68 |
+
"EfficientAttentionConfig",
|
69 |
+
["enable_flash", "enable_math", "enable_mem_efficient"],
|
70 |
+
)
|
71 |
+
self.cpu_config = self.config(True, True, True)
|
72 |
+
self.cuda_config = None
|
73 |
+
|
74 |
+
if not torch.cuda.is_available() or not use_flash:
|
75 |
+
return
|
76 |
+
|
77 |
+
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
78 |
+
|
79 |
+
if device_properties.major == 8 and device_properties.minor == 0:
|
80 |
+
print_once(
|
81 |
+
"A100 GPU detected, using flash attention if input tensor is on cuda"
|
82 |
+
)
|
83 |
+
self.cuda_config = self.config(True, False, False)
|
84 |
+
else:
|
85 |
+
print_once(
|
86 |
+
"Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
|
87 |
+
)
|
88 |
+
self.cuda_config = self.config(False, True, True)
|
89 |
+
|
90 |
+
def get_mask(self, n, device):
|
91 |
+
if exists(self.mask) and self.mask.shape[-1] >= n:
|
92 |
+
return self.mask[:n, :n]
|
93 |
+
|
94 |
+
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
95 |
+
self.register_buffer("mask", mask, persistent=False)
|
96 |
+
return mask
|
97 |
+
|
98 |
+
def flash_attn(self, q, k, v, mask=None):
|
99 |
+
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
100 |
+
|
101 |
+
# Recommended for multi-query single-key-value attention by Tri Dao
|
102 |
+
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
103 |
+
|
104 |
+
if k.ndim == 3:
|
105 |
+
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
106 |
+
|
107 |
+
if v.ndim == 3:
|
108 |
+
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
109 |
+
|
110 |
+
# Check if mask exists and expand to compatible shape
|
111 |
+
# The mask is B L, so it would have to be expanded to B H N L
|
112 |
+
|
113 |
+
if exists(mask):
|
114 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
115 |
+
mask = mask.expand(-1, heads, q_len, -1)
|
116 |
+
|
117 |
+
# Check if there is a compatible device for flash attention
|
118 |
+
|
119 |
+
config = self.cuda_config if is_cuda else self.cpu_config
|
120 |
+
|
121 |
+
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
122 |
+
|
123 |
+
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
124 |
+
out = F.scaled_dot_product_attention(
|
125 |
+
q,
|
126 |
+
k,
|
127 |
+
v,
|
128 |
+
attn_mask=mask,
|
129 |
+
dropout_p=self.dropout if self.training else 0.0,
|
130 |
+
is_causal=self.causal,
|
131 |
+
)
|
132 |
+
|
133 |
+
return out
|
134 |
+
|
135 |
+
def forward(self, q, k, v, mask=None):
|
136 |
+
"""
|
137 |
+
einstein notation
|
138 |
+
b - batch
|
139 |
+
h - heads
|
140 |
+
n, i, j - sequence length (base sequence length, source, target)
|
141 |
+
d - feature dimension
|
142 |
+
"""
|
143 |
+
|
144 |
+
n, device = q.shape[-2], q.device
|
145 |
+
|
146 |
+
scale = q.shape[-1] ** -0.5
|
147 |
+
|
148 |
+
if self.use_flash:
|
149 |
+
return self.flash_attn(q, k, v, mask=mask)
|
150 |
+
|
151 |
+
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
152 |
+
|
153 |
+
# similarity
|
154 |
+
|
155 |
+
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
156 |
+
|
157 |
+
# key padding mask
|
158 |
+
|
159 |
+
if exists(mask):
|
160 |
+
mask = rearrange(mask, "b j -> b 1 1 j")
|
161 |
+
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
162 |
+
|
163 |
+
# causal mask
|
164 |
+
|
165 |
+
if self.causal:
|
166 |
+
causal_mask = self.get_mask(n, device)
|
167 |
+
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
168 |
+
|
169 |
+
# attention
|
170 |
+
|
171 |
+
attn = sim.softmax(dim=-1)
|
172 |
+
attn = self.attn_dropout(attn)
|
173 |
+
|
174 |
+
# aggregate values
|
175 |
+
|
176 |
+
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
177 |
+
|
178 |
+
return out
|
179 |
+
|
180 |
+
|
181 |
+
def Sequential(*mods):
|
182 |
+
return nn.Sequential(*filter(exists, mods))
|
183 |
+
|
184 |
+
|
185 |
+
def exists(x):
|
186 |
+
return x is not None
|
187 |
+
|
188 |
+
|
189 |
+
def default(val, d):
|
190 |
+
if exists(val):
|
191 |
+
return val
|
192 |
+
return d() if callable(d) else d
|
193 |
+
|
194 |
+
|
195 |
+
class RMSNorm(nn.Module):
|
196 |
+
def __init__(self, dim, scale=True, dim_cond=None):
|
197 |
+
super().__init__()
|
198 |
+
self.cond = exists(dim_cond)
|
199 |
+
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
200 |
+
|
201 |
+
self.scale = dim**0.5
|
202 |
+
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
203 |
+
|
204 |
+
def forward(self, x, cond=None):
|
205 |
+
gamma = default(self.gamma, 1)
|
206 |
+
out = F.normalize(x, dim=-1) * self.scale * gamma
|
207 |
+
|
208 |
+
if not self.cond:
|
209 |
+
return out
|
210 |
+
|
211 |
+
assert exists(cond)
|
212 |
+
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
213 |
+
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
214 |
+
return out * gamma + beta
|
215 |
+
|
216 |
+
|
217 |
+
class CausalConv1d(nn.Conv1d):
|
218 |
+
def __init__(self, *args, **kwargs):
|
219 |
+
super().__init__(*args, **kwargs)
|
220 |
+
(kernel_size,) = self.kernel_size
|
221 |
+
(dilation,) = self.dilation
|
222 |
+
(stride,) = self.stride
|
223 |
+
|
224 |
+
assert stride == 1
|
225 |
+
self.causal_padding = dilation * (kernel_size - 1)
|
226 |
+
|
227 |
+
def forward(self, x):
|
228 |
+
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
229 |
+
return super().forward(causal_padded_x)
|
230 |
+
|
231 |
+
|
232 |
+
class GEGLU(nn.Module):
|
233 |
+
def forward(self, x):
|
234 |
+
x, gate = x.chunk(2, dim=-1)
|
235 |
+
return F.gelu(gate) * x
|
236 |
+
|
237 |
+
|
238 |
+
def FeedForward(dim, mult=4, causal_conv=False):
|
239 |
+
dim_inner = int(dim * mult * 2 / 3)
|
240 |
+
|
241 |
+
conv = None
|
242 |
+
if causal_conv:
|
243 |
+
conv = nn.Sequential(
|
244 |
+
Rearrange("b n d -> b d n"),
|
245 |
+
CausalConv1d(dim_inner, dim_inner, 3),
|
246 |
+
Rearrange("b d n -> b n d"),
|
247 |
+
)
|
248 |
+
|
249 |
+
return Sequential(
|
250 |
+
nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
class Attention(nn.Module):
|
255 |
+
def __init__(
|
256 |
+
self,
|
257 |
+
dim,
|
258 |
+
*,
|
259 |
+
dim_context=None,
|
260 |
+
causal=False,
|
261 |
+
dim_head=64,
|
262 |
+
heads=8,
|
263 |
+
dropout=0.0,
|
264 |
+
use_flash=False,
|
265 |
+
cross_attn_include_queries=False,
|
266 |
+
):
|
267 |
+
super().__init__()
|
268 |
+
self.scale = dim_head**-0.5
|
269 |
+
self.heads = heads
|
270 |
+
self.cross_attn_include_queries = cross_attn_include_queries
|
271 |
+
|
272 |
+
dim_inner = dim_head * heads
|
273 |
+
dim_context = default(dim_context, dim)
|
274 |
+
|
275 |
+
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
276 |
+
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
277 |
+
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
278 |
+
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
279 |
+
|
280 |
+
def forward(self, x, context=None, mask=None):
|
281 |
+
h, has_context = self.heads, exists(context)
|
282 |
+
|
283 |
+
context = default(context, x)
|
284 |
+
|
285 |
+
if has_context and self.cross_attn_include_queries:
|
286 |
+
context = torch.cat((x, context), dim=-2)
|
287 |
+
|
288 |
+
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
289 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
290 |
+
|
291 |
+
out = self.attend(q, k, v, mask=mask)
|
292 |
+
|
293 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
294 |
+
return self.to_out(out)
|
295 |
+
|
296 |
+
|
297 |
+
class PerceiverResampler(nn.Module):
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
*,
|
301 |
+
dim,
|
302 |
+
depth=2,
|
303 |
+
dim_context=None,
|
304 |
+
num_latents=32,
|
305 |
+
dim_head=64,
|
306 |
+
heads=8,
|
307 |
+
ff_mult=4,
|
308 |
+
use_flash_attn=False,
|
309 |
+
):
|
310 |
+
super().__init__()
|
311 |
+
dim_context = default(dim_context, dim)
|
312 |
+
|
313 |
+
self.proj_context = (
|
314 |
+
nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
315 |
+
)
|
316 |
+
|
317 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
318 |
+
nn.init.normal_(self.latents, std=0.02)
|
319 |
+
|
320 |
+
self.layers = nn.ModuleList([])
|
321 |
+
for _ in range(depth):
|
322 |
+
self.layers.append(
|
323 |
+
nn.ModuleList(
|
324 |
+
[
|
325 |
+
Attention(
|
326 |
+
dim=dim,
|
327 |
+
dim_head=dim_head,
|
328 |
+
heads=heads,
|
329 |
+
use_flash=use_flash_attn,
|
330 |
+
cross_attn_include_queries=True,
|
331 |
+
),
|
332 |
+
FeedForward(dim=dim, mult=ff_mult),
|
333 |
+
]
|
334 |
+
)
|
335 |
+
)
|
336 |
+
|
337 |
+
self.norm = RMSNorm(dim)
|
338 |
+
|
339 |
+
def forward(self, x, mask=None):
|
340 |
+
batch = x.shape[0]
|
341 |
+
|
342 |
+
x = self.proj_context(x)
|
343 |
+
|
344 |
+
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
345 |
+
|
346 |
+
for attn, ff in self.layers:
|
347 |
+
latents = attn(latents, x, mask=mask) + latents
|
348 |
+
latents = ff(latents) + latents
|
349 |
+
|
350 |
+
return self.norm(latents)
|
351 |
+
|
352 |
+
|
353 |
+
if __name__ == "__main__":
|
354 |
+
model = PerceiverResampler(dim=256, dim_context=80)
|
355 |
+
x = torch.randn(8, 200, 80)
|
356 |
+
out = model(x)
|
357 |
+
print(out.shape) # [8, 32, 80]
|
358 |
+
|
359 |
+
num_params = sum(param.numel() for param in model.parameters())
|
360 |
+
print("{} M".format(num_params / 1e6))
|
sparktts/modules/speaker/pooling_layers.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 Shuai Wang ([email protected])
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Pooling functions to aggregate frame-level deep features
|
16 |
+
into segment-level speaker embeddings
|
17 |
+
|
18 |
+
High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
|
19 |
+
even though we remove the mean statistic, on Voxceleb.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
|
27 |
+
class TAP(nn.Module):
|
28 |
+
"""
|
29 |
+
Temporal average pooling, only first-order mean is considered
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, in_dim=0, **kwargs):
|
33 |
+
super(TAP, self).__init__()
|
34 |
+
self.in_dim = in_dim
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
pooling_mean = x.mean(dim=-1)
|
38 |
+
# To be compatable with 2D input
|
39 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
40 |
+
return pooling_mean
|
41 |
+
|
42 |
+
def get_out_dim(self):
|
43 |
+
self.out_dim = self.in_dim
|
44 |
+
return self.out_dim
|
45 |
+
|
46 |
+
|
47 |
+
class TSDP(nn.Module):
|
48 |
+
"""
|
49 |
+
Temporal standard deviation pooling, only second-order std is considered
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, in_dim=0, **kwargs):
|
53 |
+
super(TSDP, self).__init__()
|
54 |
+
self.in_dim = in_dim
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
# The last dimension is the temporal axis
|
58 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
59 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
60 |
+
return pooling_std
|
61 |
+
|
62 |
+
def get_out_dim(self):
|
63 |
+
self.out_dim = self.in_dim
|
64 |
+
return self.out_dim
|
65 |
+
|
66 |
+
|
67 |
+
class TSTP(nn.Module):
|
68 |
+
"""
|
69 |
+
Temporal statistics pooling, concatenate mean and std, which is used in
|
70 |
+
x-vector
|
71 |
+
Comment: simple concatenation can not make full use of both statistics
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self, in_dim=0, **kwargs):
|
75 |
+
super(TSTP, self).__init__()
|
76 |
+
self.in_dim = in_dim
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
# The last dimension is the temporal axis
|
80 |
+
pooling_mean = x.mean(dim=-1)
|
81 |
+
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
|
82 |
+
pooling_mean = pooling_mean.flatten(start_dim=1)
|
83 |
+
pooling_std = pooling_std.flatten(start_dim=1)
|
84 |
+
stats = torch.cat((pooling_mean, pooling_std), 1)
|
85 |
+
return stats
|
86 |
+
|
87 |
+
def get_out_dim(self):
|
88 |
+
self.out_dim = self.in_dim * 2
|
89 |
+
return self.out_dim
|
90 |
+
|
91 |
+
|
92 |
+
class ASTP(nn.Module):
|
93 |
+
""" Attentive statistics pooling: Channel- and context-dependent
|
94 |
+
statistics pooling, first used in ECAPA_TDNN.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self,
|
98 |
+
in_dim,
|
99 |
+
bottleneck_dim=128,
|
100 |
+
global_context_att=False,
|
101 |
+
**kwargs):
|
102 |
+
super(ASTP, self).__init__()
|
103 |
+
self.in_dim = in_dim
|
104 |
+
self.global_context_att = global_context_att
|
105 |
+
|
106 |
+
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
107 |
+
# need to transpose inputs.
|
108 |
+
if global_context_att:
|
109 |
+
self.linear1 = nn.Conv1d(
|
110 |
+
in_dim * 3, bottleneck_dim,
|
111 |
+
kernel_size=1) # equals W and b in the paper
|
112 |
+
else:
|
113 |
+
self.linear1 = nn.Conv1d(
|
114 |
+
in_dim, bottleneck_dim,
|
115 |
+
kernel_size=1) # equals W and b in the paper
|
116 |
+
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
117 |
+
kernel_size=1) # equals V and k in the paper
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
"""
|
121 |
+
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
122 |
+
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
123 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
124 |
+
"""
|
125 |
+
if len(x.shape) == 4:
|
126 |
+
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
127 |
+
assert len(x.shape) == 3
|
128 |
+
|
129 |
+
if self.global_context_att:
|
130 |
+
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
131 |
+
context_std = torch.sqrt(
|
132 |
+
torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
|
133 |
+
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
134 |
+
else:
|
135 |
+
x_in = x
|
136 |
+
|
137 |
+
# DON'T use ReLU here! ReLU may be hard to converge.
|
138 |
+
alpha = torch.tanh(
|
139 |
+
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
140 |
+
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
141 |
+
mean = torch.sum(alpha * x, dim=2)
|
142 |
+
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
143 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
144 |
+
return torch.cat([mean, std], dim=1)
|
145 |
+
|
146 |
+
def get_out_dim(self):
|
147 |
+
self.out_dim = 2 * self.in_dim
|
148 |
+
return self.out_dim
|
149 |
+
|
150 |
+
|
151 |
+
class MHASTP(torch.nn.Module):
|
152 |
+
""" Multi head attentive statistics pooling
|
153 |
+
Reference:
|
154 |
+
Self Multi-Head Attention for Speaker Recognition
|
155 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self,
|
159 |
+
in_dim,
|
160 |
+
layer_num=2,
|
161 |
+
head_num=2,
|
162 |
+
d_s=1,
|
163 |
+
bottleneck_dim=64,
|
164 |
+
**kwargs):
|
165 |
+
super(MHASTP, self).__init__()
|
166 |
+
assert (in_dim % head_num
|
167 |
+
) == 0 # make sure that head num can be divided by input_dim
|
168 |
+
self.in_dim = in_dim
|
169 |
+
self.head_num = head_num
|
170 |
+
d_model = int(in_dim / head_num)
|
171 |
+
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
|
172 |
+
if d_s > 1:
|
173 |
+
d_s = d_model
|
174 |
+
else:
|
175 |
+
d_s = 1
|
176 |
+
self.d_s = d_s
|
177 |
+
channel_dims[0], channel_dims[-1] = d_model, d_s
|
178 |
+
heads_att_trans = []
|
179 |
+
for i in range(self.head_num):
|
180 |
+
att_trans = nn.Sequential()
|
181 |
+
for i in range(layer_num - 1):
|
182 |
+
att_trans.add_module(
|
183 |
+
'att_' + str(i),
|
184 |
+
nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
|
185 |
+
att_trans.add_module('tanh' + str(i), nn.Tanh())
|
186 |
+
att_trans.add_module(
|
187 |
+
'att_' + str(layer_num - 1),
|
188 |
+
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
|
189 |
+
1, 1))
|
190 |
+
heads_att_trans.append(att_trans)
|
191 |
+
self.heads_att_trans = nn.ModuleList(heads_att_trans)
|
192 |
+
|
193 |
+
def forward(self, input):
|
194 |
+
"""
|
195 |
+
input: a 3-dimensional tensor in xvector architecture
|
196 |
+
or a 4-dimensional tensor in resnet architecture
|
197 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
198 |
+
"""
|
199 |
+
if len(input.shape) == 4: # B x F x T
|
200 |
+
input = input.reshape(input.shape[0],
|
201 |
+
input.shape[1] * input.shape[2],
|
202 |
+
input.shape[3])
|
203 |
+
assert len(input.shape) == 3
|
204 |
+
bs, f_dim, t_dim = input.shape
|
205 |
+
chunks = torch.chunk(input, self.head_num, 1)
|
206 |
+
# split
|
207 |
+
chunks_out = []
|
208 |
+
# for i in range(self.head_num):
|
209 |
+
# att_score = self.heads_att_trans[i](chunks[i])
|
210 |
+
for i, layer in enumerate(self.heads_att_trans):
|
211 |
+
att_score = layer(chunks[i])
|
212 |
+
alpha = F.softmax(att_score, dim=-1)
|
213 |
+
mean = torch.sum(alpha * chunks[i], dim=2)
|
214 |
+
var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
|
215 |
+
std = torch.sqrt(var.clamp(min=1e-7))
|
216 |
+
chunks_out.append(torch.cat((mean, std), dim=1))
|
217 |
+
out = torch.cat(chunks_out, dim=1)
|
218 |
+
return out
|
219 |
+
|
220 |
+
def get_out_dim(self):
|
221 |
+
self.out_dim = 2 * self.in_dim
|
222 |
+
return self.out_dim
|
223 |
+
|
224 |
+
|
225 |
+
class MQMHASTP(torch.nn.Module):
|
226 |
+
""" An attentive pooling
|
227 |
+
Reference:
|
228 |
+
multi query multi head attentive statistics pooling
|
229 |
+
https://arxiv.org/pdf/2110.05042.pdf
|
230 |
+
Args:
|
231 |
+
in_dim: the feature dimension of input
|
232 |
+
layer_num: the number of layer in the pooling layer
|
233 |
+
query_num: the number of querys
|
234 |
+
head_num: the number of heads
|
235 |
+
bottleneck_dim: the bottleneck dimension
|
236 |
+
|
237 |
+
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
|
238 |
+
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
|
239 |
+
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
|
240 |
+
https://arxiv.org/pdf/1906.09890.pdf
|
241 |
+
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
|
242 |
+
https://arxiv.org/pdf/1803.10963.pdf
|
243 |
+
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
|
244 |
+
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
|
245 |
+
"""
|
246 |
+
|
247 |
+
def __init__(self,
|
248 |
+
in_dim,
|
249 |
+
layer_num=2,
|
250 |
+
query_num=2,
|
251 |
+
head_num=8,
|
252 |
+
d_s=2,
|
253 |
+
bottleneck_dim=64,
|
254 |
+
**kwargs):
|
255 |
+
super(MQMHASTP, self).__init__()
|
256 |
+
self.n_query = nn.ModuleList([
|
257 |
+
MHASTP(in_dim,
|
258 |
+
layer_num=layer_num,
|
259 |
+
head_num=head_num,
|
260 |
+
d_s=d_s,
|
261 |
+
bottleneck_dim=bottleneck_dim) for i in range(query_num)
|
262 |
+
])
|
263 |
+
self.query_num = query_num
|
264 |
+
self.in_dim = in_dim
|
265 |
+
|
266 |
+
def forward(self, input):
|
267 |
+
"""
|
268 |
+
input: a 3-dimensional tensor in xvector architecture
|
269 |
+
or a 4-dimensional tensor in resnet architecture
|
270 |
+
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
271 |
+
"""
|
272 |
+
if len(input.shape) == 4: # B x F x T
|
273 |
+
input = input.reshape(input.shape[0],
|
274 |
+
input.shape[1] * input.shape[2],
|
275 |
+
input.shape[3])
|
276 |
+
assert len(input.shape) == 3
|
277 |
+
res = []
|
278 |
+
for i, layer in enumerate(self.n_query):
|
279 |
+
res.append(layer(input))
|
280 |
+
out = torch.cat(res, dim=-1)
|
281 |
+
return out
|
282 |
+
|
283 |
+
def get_out_dim(self):
|
284 |
+
self.out_dim = self.in_dim * 2 * self.query_num
|
285 |
+
return self.out_dim
|
286 |
+
|
287 |
+
|
288 |
+
if __name__ == '__main__':
|
289 |
+
data = torch.randn(16, 512, 10, 35)
|
290 |
+
# model = StatisticsPooling()
|
291 |
+
model = MQMHASTP(512 * 10)
|
292 |
+
model = MHASTP(512 * 10)
|
293 |
+
model = MQMHASTP(512 * 10, context=False)
|
294 |
+
print(model)
|
295 |
+
|
296 |
+
out = model(data)
|
297 |
+
print(out.shape)
|
298 |
+
print(model.get_out_dim())
|
sparktts/modules/speaker/speaker_encoder.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from typing import List, Tuple
|
20 |
+
from sparktts.modules.fsq.residual_fsq import ResidualFSQ
|
21 |
+
from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
|
22 |
+
from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
|
23 |
+
|
24 |
+
"""
|
25 |
+
x-vector + d-vector
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
class SpeakerEncoder(nn.Module):
|
30 |
+
"""
|
31 |
+
|
32 |
+
Args:
|
33 |
+
input_dim (int): acoustic feature dimension
|
34 |
+
out_dim (int): output dimension of x-vector and d-vector
|
35 |
+
latent_dim (int): latent dimension before quantization
|
36 |
+
token_num (int): sequence length of speaker tokens
|
37 |
+
fsq_levels (List[int]): number of levels for each quantizer
|
38 |
+
fsq_num_quantizers (int): number of quantizers
|
39 |
+
|
40 |
+
Return:
|
41 |
+
speaker_embs: (B, T2, out_dim)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
input_dim: int = 100,
|
47 |
+
out_dim: int = 512,
|
48 |
+
latent_dim: int = 128,
|
49 |
+
token_num: int = 32,
|
50 |
+
fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
|
51 |
+
fsq_num_quantizers: int = 1,
|
52 |
+
):
|
53 |
+
super(SpeakerEncoder, self).__init__()
|
54 |
+
|
55 |
+
self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
|
56 |
+
feat_dim=input_dim, embed_dim=out_dim
|
57 |
+
)
|
58 |
+
self.perceiver_sampler = PerceiverResampler(
|
59 |
+
dim=latent_dim, dim_context=512 * 3, num_latents=token_num
|
60 |
+
)
|
61 |
+
self.quantizer = ResidualFSQ(
|
62 |
+
levels=fsq_levels,
|
63 |
+
num_quantizers=fsq_num_quantizers,
|
64 |
+
dim=latent_dim,
|
65 |
+
is_channel_first=True,
|
66 |
+
quantize_dropout=False,
|
67 |
+
)
|
68 |
+
|
69 |
+
self.project = nn.Linear(latent_dim * token_num, out_dim)
|
70 |
+
|
71 |
+
def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
|
72 |
+
zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
|
73 |
+
return zq.transpose(1, 2)
|
74 |
+
|
75 |
+
def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
|
76 |
+
mels = mels.transpose(1, 2)
|
77 |
+
x = self.perceiver_sampler(mels).transpose(1, 2)
|
78 |
+
zq, indices = self.quantizer(x)
|
79 |
+
return indices
|
80 |
+
|
81 |
+
def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
82 |
+
"""
|
83 |
+
Args:
|
84 |
+
mels: (B, D_mel, T1)
|
85 |
+
|
86 |
+
Return:
|
87 |
+
x_vector: (B, out_dim)
|
88 |
+
d_vector: (B, out_dim)
|
89 |
+
"""
|
90 |
+
# mels = mels.transpose(1,2)
|
91 |
+
|
92 |
+
x_vector, features = self.speaker_encoder(mels, True)
|
93 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
94 |
+
zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim)
|
95 |
+
x = zq.reshape(zq.shape[0], -1)
|
96 |
+
d_vector = self.project(x)
|
97 |
+
|
98 |
+
return x_vector, d_vector
|
99 |
+
|
100 |
+
def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
|
101 |
+
"""tokenize the input mel spectrogram"""
|
102 |
+
_, features = self.speaker_encoder(mels, True)
|
103 |
+
x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
|
104 |
+
zq, indices = self.quantizer(x)
|
105 |
+
return indices
|
106 |
+
|
107 |
+
def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
|
108 |
+
"""detokenize the input indices to d-vector"""
|
109 |
+
zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
|
110 |
+
x = zq.reshape(zq.shape[0], -1)
|
111 |
+
d_vector = self.project(x)
|
112 |
+
return d_vector
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
model = SpeakerEncoder(
|
116 |
+
input_dim=100,
|
117 |
+
latent_dim=128,
|
118 |
+
token_num=32,
|
119 |
+
fsq_levels=[4, 4, 4, 4, 4, 4],
|
120 |
+
fsq_num_quantizers=1,
|
121 |
+
)
|
122 |
+
mel = torch.randn(8, 200, 100)
|
123 |
+
x_vector, d_vector = model(mel)
|
124 |
+
print("x-vector shape", x_vector.shape)
|
125 |
+
print("d-vector shape", d_vector.shape)
|
126 |
+
|
127 |
+
indices = model.tokenize(mel)
|
128 |
+
print("indices shape", indices.shape)
|
129 |
+
d_vector_post = model.detokenize(indices)
|
130 |
+
print("d-vector shape", d_vector_post.shape)
|
131 |
+
if d_vector_post.all() == d_vector.all():
|
132 |
+
print("d-vector post and d-vector are the same")
|
133 |
+
else:
|
134 |
+
print("d-vector post and d-vector are different")
|
135 |
+
num_params = sum(param.numel() for param in model.parameters())
|
136 |
+
print("{} M".format(num_params / 1e6))
|
sparktts/modules/vq/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
|
17 |
+
|
18 |
+
|
19 |
+
from typing import Any, Dict
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from einops import rearrange
|
25 |
+
from torch.nn.utils import weight_norm
|
26 |
+
|
27 |
+
|
28 |
+
def WNConv1d(*args, **kwargs):
|
29 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
30 |
+
|
31 |
+
|
32 |
+
def ema_inplace(moving_avg, new, decay):
|
33 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
34 |
+
|
35 |
+
|
36 |
+
class FactorizedVectorQuantize(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
input_dim: int,
|
40 |
+
codebook_size: int,
|
41 |
+
codebook_dim: int,
|
42 |
+
commitment: float,
|
43 |
+
codebook_loss_weight: float = 1.0,
|
44 |
+
decay: float = 0.99,
|
45 |
+
threshold_ema_dead_code: float = 2,
|
46 |
+
momentum: float = 0.99,
|
47 |
+
**kwargs,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
self.input_dim = input_dim
|
51 |
+
self.codebook_size = codebook_size
|
52 |
+
self.codebook_dim = codebook_dim
|
53 |
+
self.commitment = commitment
|
54 |
+
self.codebook_loss_weight = codebook_loss_weight
|
55 |
+
self.decay = decay
|
56 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
57 |
+
self.momentum = momentum
|
58 |
+
|
59 |
+
if input_dim != self.codebook_dim:
|
60 |
+
self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
|
61 |
+
self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
|
62 |
+
|
63 |
+
else:
|
64 |
+
self.in_project = nn.Identity()
|
65 |
+
self.out_project = nn.Identity()
|
66 |
+
|
67 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
68 |
+
self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
|
69 |
+
|
70 |
+
def forward(self, z: torch.Tensor) -> Dict[str, Any]:
|
71 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
72 |
+
the corresponding codebook vectors
|
73 |
+
|
74 |
+
Parameters
|
75 |
+
----------
|
76 |
+
z : Tensor[B x D x T]
|
77 |
+
|
78 |
+
Returns
|
79 |
+
-------
|
80 |
+
Tensor[B x D x T]
|
81 |
+
Quantized continuous representation of input
|
82 |
+
Tensor[1]
|
83 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
84 |
+
entries
|
85 |
+
Tensor[1]
|
86 |
+
Codebook loss to update the codebook
|
87 |
+
Tensor[B x T]
|
88 |
+
Codebook indices (quantized discrete representation of input)
|
89 |
+
Tensor[B x D x T]
|
90 |
+
Projected latents (continuous representation of input before quantization)
|
91 |
+
"""
|
92 |
+
# transpose since we use linear
|
93 |
+
|
94 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
95 |
+
z_e = self.in_project(z)
|
96 |
+
z_q, indices, dists = self.decode_latents(z_e)
|
97 |
+
|
98 |
+
# statistic the usage of codes
|
99 |
+
embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
|
100 |
+
avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
|
101 |
+
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
102 |
+
|
103 |
+
active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
|
104 |
+
if self.training:
|
105 |
+
# We do the expiry of code at that point as buffers are in sync
|
106 |
+
# and all the workers will take the same decision.
|
107 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
|
108 |
+
active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
|
109 |
+
|
110 |
+
if self.training:
|
111 |
+
commit_loss = (
|
112 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
113 |
+
* self.commitment
|
114 |
+
)
|
115 |
+
|
116 |
+
codebook_loss = (
|
117 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
118 |
+
* self.codebook_loss_weight
|
119 |
+
)
|
120 |
+
|
121 |
+
else:
|
122 |
+
commit_loss = torch.zeros(0, device=z.device)
|
123 |
+
codebook_loss = torch.zeros(0, device=z.device)
|
124 |
+
|
125 |
+
z_q = (
|
126 |
+
z_e + (z_q - z_e).detach()
|
127 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
128 |
+
|
129 |
+
z_q = self.out_project(z_q)
|
130 |
+
|
131 |
+
vq_loss = (commit_loss + codebook_loss).mean()
|
132 |
+
|
133 |
+
return {
|
134 |
+
"z_q": z_q,
|
135 |
+
"indices": indices,
|
136 |
+
"dists": dists,
|
137 |
+
"vq_loss": vq_loss,
|
138 |
+
"perplexity": perplexity,
|
139 |
+
"active_num": active_num.float(),
|
140 |
+
}
|
141 |
+
|
142 |
+
def vq2emb(self, vq, out_proj=True):
|
143 |
+
emb = self.embed_code(vq)
|
144 |
+
if out_proj:
|
145 |
+
emb = self.out_project(emb)
|
146 |
+
return emb
|
147 |
+
|
148 |
+
def tokenize(self, z: torch.Tensor) -> torch.Tensor:
|
149 |
+
"""tokenize the input tensor"""
|
150 |
+
z_e = self.in_project(z)
|
151 |
+
_, indices, _ = self.decode_latents(z_e)
|
152 |
+
return indices
|
153 |
+
|
154 |
+
def detokenize(self, indices):
|
155 |
+
"""detokenize the input indices"""
|
156 |
+
z_q = self.decode_code(indices)
|
157 |
+
z_q = self.out_project(z_q)
|
158 |
+
return z_q
|
159 |
+
|
160 |
+
def get_emb(self):
|
161 |
+
return self.codebook.weight
|
162 |
+
|
163 |
+
def embed_code(self, embed_id):
|
164 |
+
return F.embedding(embed_id, self.codebook.weight)
|
165 |
+
|
166 |
+
def decode_code(self, embed_id):
|
167 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
168 |
+
|
169 |
+
def decode_latents(self, latents):
|
170 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
171 |
+
codebook = self.codebook.weight
|
172 |
+
|
173 |
+
# L2 normalize encodings and codebook
|
174 |
+
encodings = F.normalize(encodings)
|
175 |
+
codebook = F.normalize(codebook)
|
176 |
+
|
177 |
+
# Compute euclidean distance between encodings and codebook,
|
178 |
+
# with L2 normalization, the distance is equal to cosine distance
|
179 |
+
dist = (
|
180 |
+
encodings.pow(2).sum(1, keepdim=True)
|
181 |
+
- 2 * encodings @ codebook.t()
|
182 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
183 |
+
)
|
184 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
185 |
+
z_q = self.decode_code(indices)
|
186 |
+
|
187 |
+
return z_q, indices, dist
|
sparktts/utils/__init__.py
ADDED
File without changes
|
sparktts/utils/audio.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Description:
|
17 |
+
This script contains a collection of functions designed to handle various
|
18 |
+
audio processing.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import random
|
22 |
+
import soxr
|
23 |
+
import soundfile
|
24 |
+
import torch
|
25 |
+
import torchaudio
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Tuple
|
30 |
+
from numpy.lib.stride_tricks import sliding_window_view
|
31 |
+
|
32 |
+
|
33 |
+
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
|
34 |
+
"""
|
35 |
+
Normalize the volume of an audio signal.
|
36 |
+
|
37 |
+
Parameters:
|
38 |
+
audio (numpy array): Input audio signal array.
|
39 |
+
coeff (float): Target coefficient for normalization, default is 0.2.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
numpy array: The volume-normalized audio signal.
|
43 |
+
"""
|
44 |
+
# Sort the absolute values of the audio signal
|
45 |
+
temp = np.sort(np.abs(audio))
|
46 |
+
|
47 |
+
# If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
|
48 |
+
if temp[-1] < 0.1:
|
49 |
+
scaling_factor = max(
|
50 |
+
temp[-1], 1e-3
|
51 |
+
) # Prevent division by zero with a small constant
|
52 |
+
audio = audio / scaling_factor * 0.1
|
53 |
+
|
54 |
+
# Filter out values less than 0.01 from temp
|
55 |
+
temp = temp[temp > 0.01]
|
56 |
+
L = temp.shape[0] # Length of the filtered array
|
57 |
+
|
58 |
+
# If there are fewer than or equal to 10 significant values, return the audio without further processing
|
59 |
+
if L <= 10:
|
60 |
+
return audio
|
61 |
+
|
62 |
+
# Compute the average of the top 10% to 1% of values in temp
|
63 |
+
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
|
64 |
+
|
65 |
+
# Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
|
66 |
+
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
|
67 |
+
|
68 |
+
# Ensure the maximum absolute value in the audio does not exceed 1
|
69 |
+
max_value = np.max(np.abs(audio))
|
70 |
+
if max_value > 1:
|
71 |
+
audio = audio / max_value
|
72 |
+
|
73 |
+
return audio
|
74 |
+
|
75 |
+
|
76 |
+
def load_audio(
|
77 |
+
adfile: Path,
|
78 |
+
sampling_rate: int = None,
|
79 |
+
length: int = None,
|
80 |
+
volume_normalize: bool = False,
|
81 |
+
segment_duration: int = None,
|
82 |
+
) -> np.ndarray:
|
83 |
+
r"""Load audio file with target sampling rate and lsength
|
84 |
+
|
85 |
+
Args:
|
86 |
+
adfile (Path): path to audio file.
|
87 |
+
sampling_rate (int, optional): target sampling rate. Defaults to None.
|
88 |
+
length (int, optional): target audio length. Defaults to None.
|
89 |
+
volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
|
90 |
+
segment_duration (int): random select a segment with duration of {segment_duration}s.
|
91 |
+
Defualt to None which means the whole audio will be used.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
audio (np.ndarray): audio
|
95 |
+
"""
|
96 |
+
|
97 |
+
audio, sr = soundfile.read(adfile)
|
98 |
+
if len(audio.shape) > 1:
|
99 |
+
audio = audio[:, 0]
|
100 |
+
|
101 |
+
if sampling_rate is not None and sr != sampling_rate:
|
102 |
+
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
|
103 |
+
sr = sampling_rate
|
104 |
+
|
105 |
+
if segment_duration is not None:
|
106 |
+
seg_length = int(sr * segment_duration)
|
107 |
+
audio = random_select_audio_segment(audio, seg_length)
|
108 |
+
|
109 |
+
# Audio volume normalize
|
110 |
+
if volume_normalize:
|
111 |
+
audio = audio_volume_normalize(audio)
|
112 |
+
# check the audio length
|
113 |
+
if length is not None:
|
114 |
+
assert abs(audio.shape[0] - length) < 1000
|
115 |
+
if audio.shape[0] > length:
|
116 |
+
audio = audio[:length]
|
117 |
+
else:
|
118 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
119 |
+
return audio
|
120 |
+
|
121 |
+
|
122 |
+
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
|
123 |
+
"""get an audio segment given the length
|
124 |
+
|
125 |
+
Args:
|
126 |
+
audio (np.ndarray):
|
127 |
+
length (int): audio length = sampling_rate * duration
|
128 |
+
"""
|
129 |
+
if audio.shape[0] < length:
|
130 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
131 |
+
start_index = random.randint(0, audio.shape[0] - length)
|
132 |
+
end_index = int(start_index + length)
|
133 |
+
|
134 |
+
return audio[start_index:end_index]
|
135 |
+
|
136 |
+
|
137 |
+
def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
|
138 |
+
"""apply highpass fileter to audio
|
139 |
+
|
140 |
+
Args:
|
141 |
+
audio (np.ndarray):
|
142 |
+
sample_rate (ind):
|
143 |
+
highpass_cutoff_freq (int):
|
144 |
+
"""
|
145 |
+
|
146 |
+
audio = torchaudio.functional.highpass_biquad(
|
147 |
+
torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
|
148 |
+
)
|
149 |
+
return audio.numpy()
|
150 |
+
|
151 |
+
|
152 |
+
def stft(
|
153 |
+
x: torch.Tensor,
|
154 |
+
fft_size: int,
|
155 |
+
hop_size: int,
|
156 |
+
win_length: int,
|
157 |
+
window: str,
|
158 |
+
use_complex: bool = False,
|
159 |
+
) -> torch.Tensor:
|
160 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
161 |
+
Args:
|
162 |
+
x (Tensor): Input signal tensor (B, T).
|
163 |
+
fft_size (int): FFT size.
|
164 |
+
hop_size (int): Hop size.
|
165 |
+
win_length (int): Window length.
|
166 |
+
window (str): Window function type.
|
167 |
+
Returns:
|
168 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
169 |
+
"""
|
170 |
+
|
171 |
+
x_stft = torch.stft(
|
172 |
+
x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
|
173 |
+
)
|
174 |
+
|
175 |
+
# clamp is needed to avoid nan or inf
|
176 |
+
if not use_complex:
|
177 |
+
return torch.sqrt(
|
178 |
+
torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
|
179 |
+
).transpose(2, 1)
|
180 |
+
else:
|
181 |
+
res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
|
182 |
+
res = res.transpose(2, 3) # [B, 2, T, F]
|
183 |
+
return res
|
184 |
+
|
185 |
+
|
186 |
+
def detect_speech_boundaries(
|
187 |
+
wav: np.ndarray,
|
188 |
+
sample_rate: int,
|
189 |
+
window_duration: float = 0.1,
|
190 |
+
energy_threshold: float = 0.01,
|
191 |
+
margin_factor: int = 2
|
192 |
+
) -> Tuple[int, int]:
|
193 |
+
"""Detect the start and end points of speech in an audio signal using RMS energy.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
wav: Input audio signal array with values in [-1, 1]
|
197 |
+
sample_rate: Audio sample rate in Hz
|
198 |
+
window_duration: Duration of detection window in seconds
|
199 |
+
energy_threshold: RMS energy threshold for speech detection
|
200 |
+
margin_factor: Factor to determine extra margin around detected boundaries
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
tuple: (start_index, end_index) of speech segment
|
204 |
+
|
205 |
+
Raises:
|
206 |
+
ValueError: If the audio contains only silence
|
207 |
+
"""
|
208 |
+
window_size = int(window_duration * sample_rate)
|
209 |
+
margin = margin_factor * window_size
|
210 |
+
step_size = window_size // 10
|
211 |
+
|
212 |
+
# Create sliding windows using stride tricks to avoid loops
|
213 |
+
windows = sliding_window_view(wav, window_size)[::step_size]
|
214 |
+
|
215 |
+
# Calculate RMS energy for each window
|
216 |
+
energy = np.sqrt(np.mean(windows ** 2, axis=1))
|
217 |
+
speech_mask = energy >= energy_threshold
|
218 |
+
|
219 |
+
if not np.any(speech_mask):
|
220 |
+
raise ValueError("No speech detected in audio (only silence)")
|
221 |
+
|
222 |
+
start = max(0, np.argmax(speech_mask) * step_size - margin)
|
223 |
+
end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
|
224 |
+
|
225 |
+
return start, end
|
226 |
+
|
227 |
+
|
228 |
+
def remove_silence_on_both_ends(
|
229 |
+
wav: np.ndarray,
|
230 |
+
sample_rate: int,
|
231 |
+
window_duration: float = 0.1,
|
232 |
+
volume_threshold: float = 0.01
|
233 |
+
) -> np.ndarray:
|
234 |
+
"""Remove silence from both ends of an audio signal.
|
235 |
+
|
236 |
+
Args:
|
237 |
+
wav: Input audio signal array
|
238 |
+
sample_rate: Audio sample rate in Hz
|
239 |
+
window_duration: Duration of detection window in seconds
|
240 |
+
volume_threshold: Amplitude threshold for silence detection
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
np.ndarray: Audio signal with silence removed from both ends
|
244 |
+
|
245 |
+
Raises:
|
246 |
+
ValueError: If the audio contains only silence
|
247 |
+
"""
|
248 |
+
start, end = detect_speech_boundaries(
|
249 |
+
wav,
|
250 |
+
sample_rate,
|
251 |
+
window_duration,
|
252 |
+
volume_threshold
|
253 |
+
)
|
254 |
+
return wav[start:end]
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
def hertz_to_mel(pitch: float) -> float:
|
259 |
+
"""
|
260 |
+
Converts a frequency from the Hertz scale to the Mel scale.
|
261 |
+
|
262 |
+
Parameters:
|
263 |
+
- pitch: float or ndarray
|
264 |
+
Frequency in Hertz.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
- mel: float or ndarray
|
268 |
+
Frequency in Mel scale.
|
269 |
+
"""
|
270 |
+
mel = 2595 * np.log10(1 + pitch / 700)
|
271 |
+
return mel
|
sparktts/utils/file.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2025 SparkAudio
|
2 |
+
# 2025 Xinsheng Wang ([email protected])
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Description:
|
17 |
+
This script contains a collection of functions designed to handle various
|
18 |
+
file reading and writing operations. It provides utilities to read from files,
|
19 |
+
write data to files, and perform file manipulation tasks.
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
import os
|
24 |
+
import json
|
25 |
+
import json
|
26 |
+
import csv
|
27 |
+
|
28 |
+
from tqdm import tqdm
|
29 |
+
from typing import List, Dict, Any, Set, Union
|
30 |
+
from pathlib import Path
|
31 |
+
from omegaconf import OmegaConf, DictConfig
|
32 |
+
|
33 |
+
|
34 |
+
def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
|
35 |
+
"""
|
36 |
+
Resolves the absolute path of a symbolic link.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
symbolic_link_path (Path): The path to the symbolic link.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Path: The absolute path that the symbolic link points to.
|
43 |
+
"""
|
44 |
+
|
45 |
+
link_directory = os.path.dirname(symbolic_link_path)
|
46 |
+
target_path_relative = os.readlink(symbolic_link_path)
|
47 |
+
return os.path.join(link_directory, target_path_relative)
|
48 |
+
|
49 |
+
|
50 |
+
def write_jsonl(metadata: List[dict], file_path: Path) -> None:
|
51 |
+
"""Writes a list of dictionaries to a JSONL file.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
metadata : List[dict]
|
55 |
+
A list of dictionaries, each representing a piece of meta.
|
56 |
+
file_path : Path
|
57 |
+
The file path to save the JSONL file
|
58 |
+
|
59 |
+
This function writes each dictionary in the list to a new line in the specified file.
|
60 |
+
"""
|
61 |
+
with open(file_path, "w", encoding="utf-8") as f:
|
62 |
+
for meta in tqdm(metadata, desc="writing jsonl"):
|
63 |
+
# Convert dictionary to JSON string and write it to the file with a newline
|
64 |
+
json_str = json.dumps(meta, ensure_ascii=False) + "\n"
|
65 |
+
f.write(json_str)
|
66 |
+
print(f"jsonl saved to {file_path}")
|
67 |
+
|
68 |
+
|
69 |
+
def read_jsonl(file_path: Path) -> List[dict]:
|
70 |
+
"""
|
71 |
+
Reads a JSONL file and returns a list of dictionaries.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
file_path : Path
|
75 |
+
The path to the JSONL file to be read.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
List[dict]
|
79 |
+
A list of dictionaries parsed from each line of the JSONL file.
|
80 |
+
"""
|
81 |
+
metadata = []
|
82 |
+
# Open the file for reading
|
83 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
84 |
+
# Split the file into lines
|
85 |
+
lines = f.read().splitlines()
|
86 |
+
# Process each line
|
87 |
+
for line in lines:
|
88 |
+
# Convert JSON string back to dictionary and append to list
|
89 |
+
meta = json.loads(line)
|
90 |
+
metadata.append(meta)
|
91 |
+
# Return the list of metadata
|
92 |
+
return metadata
|
93 |
+
|
94 |
+
def read_json_as_jsonl(file_path: Path) -> List[dict]:
|
95 |
+
metadata = []
|
96 |
+
with open(file_path, 'r', encoding='utf-8') as infile:
|
97 |
+
data = json.load(infile)
|
98 |
+
for k in sorted(data.keys()):
|
99 |
+
meta = {'index': k}
|
100 |
+
meta.update(data[k])
|
101 |
+
metadata.append(meta)
|
102 |
+
return metadata
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
|
107 |
+
processed_meta = {}
|
108 |
+
for k, v in meta.items():
|
109 |
+
if isinstance(v, str):
|
110 |
+
processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
|
111 |
+
else:
|
112 |
+
processed_meta[k] = v
|
113 |
+
return processed_meta
|
114 |
+
|
115 |
+
|
116 |
+
def load_config(config_path: Path) -> DictConfig:
|
117 |
+
"""Loads a configuration file and optionally merges it with a base configuration.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
config_path (Path): Path to the configuration file.
|
121 |
+
"""
|
122 |
+
# Load the initial configuration from the given path
|
123 |
+
config = OmegaConf.load(config_path)
|
124 |
+
|
125 |
+
# Check if there is a base configuration specified and merge if necessary
|
126 |
+
if config.get("base_config", None) is not None:
|
127 |
+
base_config = OmegaConf.load(config["base_config"])
|
128 |
+
config = OmegaConf.merge(base_config, config)
|
129 |
+
|
130 |
+
return config
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
|
135 |
+
"""
|
136 |
+
Converts a JSONL file to a CSV file.
|
137 |
+
|
138 |
+
This function reads a JSONL file, determines all unique keys present in the file,
|
139 |
+
and writes the data to a CSV file with columns for all these keys.
|
140 |
+
"""
|
141 |
+
|
142 |
+
all_keys = set()
|
143 |
+
data_rows = []
|
144 |
+
|
145 |
+
# Read the JSONL file once to extract keys and collect data
|
146 |
+
with open(jsonl_file_path, 'r') as file:
|
147 |
+
for line in file:
|
148 |
+
data = json.loads(line.strip())
|
149 |
+
data_rows.append(data)
|
150 |
+
all_keys.update(data.keys())
|
151 |
+
|
152 |
+
# Convert the set of keys to a sorted list for consistent column order
|
153 |
+
sorted_keys = sorted(all_keys)
|
154 |
+
|
155 |
+
# Write the data to a CSV file
|
156 |
+
with open(csv_file_path, 'w', newline='') as csvfile:
|
157 |
+
writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
|
158 |
+
|
159 |
+
# Write the header row
|
160 |
+
writer.writeheader()
|
161 |
+
|
162 |
+
# Write each row of data
|
163 |
+
for data in data_rows:
|
164 |
+
writer.writerow(data)
|
165 |
+
|
166 |
+
print(f"CSV file has been created at {csv_file_path}")
|
167 |
+
|
168 |
+
|
169 |
+
def save_metadata(data, filename, headers=None):
|
170 |
+
"""
|
171 |
+
Save metadata to a file.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
data (list of dict): Metadata to be saved.
|
175 |
+
filename (str): Name of the file to save the metadata.
|
176 |
+
headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
|
177 |
+
"""
|
178 |
+
# Set headers to keys from the first dictionary in data if not explicitly provided
|
179 |
+
if headers is None:
|
180 |
+
headers = list(data[0].keys())
|
181 |
+
|
182 |
+
with open(filename, "w", encoding="utf-8") as file:
|
183 |
+
# Write the headers to the file
|
184 |
+
file.write("|".join(headers) + "\n")
|
185 |
+
for entry in data:
|
186 |
+
# Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
|
187 |
+
formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
|
188 |
+
# Write the formatted values to the file
|
189 |
+
file.write("|".join(formatted_values) + "\n")
|
190 |
+
|
191 |
+
|
192 |
+
def read_metadata(filename, headers=None):
|
193 |
+
"""
|
194 |
+
Read metadata from a file.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
filename (str): The file from which to read the metadata.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
list of dict: The metadata read from the file.
|
201 |
+
list of str: The headers used in the file.
|
202 |
+
"""
|
203 |
+
with open(filename, "r", encoding="utf-8") as file:
|
204 |
+
lines = file.readlines()
|
205 |
+
|
206 |
+
data = []
|
207 |
+
# Set headers from the first line of the file if not provided
|
208 |
+
if headers is None:
|
209 |
+
headers = lines[0].strip().split("|")
|
210 |
+
lines = lines[1:]
|
211 |
+
|
212 |
+
for line in lines:
|
213 |
+
line = line.strip()
|
214 |
+
# Skip empty lines
|
215 |
+
if not line:
|
216 |
+
continue
|
217 |
+
# Split the line by '|' and pair with headers to form a dictionary
|
218 |
+
entry_data = dict(zip(headers, line.split("|")))
|
219 |
+
data.append(entry_data)
|
220 |
+
|
221 |
+
return data, headers
|
sparktts/utils/parse_options.sh
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
|
4 |
+
# Arnab Ghoshal, Karel Vesely
|
5 |
+
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
13 |
+
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
14 |
+
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
15 |
+
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
16 |
+
# See the Apache 2 License for the specific language governing permissions and
|
17 |
+
# limitations under the License.
|
18 |
+
|
19 |
+
|
20 |
+
# Parse command-line options.
|
21 |
+
# To be sourced by another script (as in ". parse_options.sh").
|
22 |
+
# Option format is: --option-name arg
|
23 |
+
# and shell variable "option_name" gets set to value "arg."
|
24 |
+
# The exception is --help, which takes no arguments, but prints the
|
25 |
+
# $help_message variable (if defined).
|
26 |
+
|
27 |
+
|
28 |
+
###
|
29 |
+
### The --config file options have lower priority to command line
|
30 |
+
### options, so we need to import them first...
|
31 |
+
###
|
32 |
+
|
33 |
+
# Now import all the configs specified by command-line, in left-to-right order
|
34 |
+
# for ((argpos=1; argpos<$#; argpos++)); do
|
35 |
+
# if [ "${!argpos}" == "--config" ]; then
|
36 |
+
# argpos_plus1=$((argpos+1))
|
37 |
+
# config=${!argpos_plus1}
|
38 |
+
# [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
|
39 |
+
# . $config # source the config file.
|
40 |
+
# fi
|
41 |
+
# done
|
42 |
+
|
43 |
+
|
44 |
+
###
|
45 |
+
### No we process the command line options
|
46 |
+
###
|
47 |
+
while true; do
|
48 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
49 |
+
case "$1" in
|
50 |
+
# If the enclosing script is called with --help option, print the help
|
51 |
+
# message and exit. Scripts should put help messages in $help_message
|
52 |
+
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
|
53 |
+
else printf "$help_message\n" 1>&2 ; fi;
|
54 |
+
exit 0 ;;
|
55 |
+
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
|
56 |
+
exit 1 ;;
|
57 |
+
# If the first command-line argument begins with "--" (e.g. --foo-bar),
|
58 |
+
# then work out the variable name as $name, which will equal "foo_bar".
|
59 |
+
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
|
60 |
+
# Next we test whether the variable in question is undefned-- if so it's
|
61 |
+
# an invalid option and we die. Note: $0 evaluates to the name of the
|
62 |
+
# enclosing script.
|
63 |
+
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
|
64 |
+
# is undefined. We then have to wrap this test inside "eval" because
|
65 |
+
# foo_bar is itself inside a variable ($name).
|
66 |
+
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
67 |
+
|
68 |
+
oldval="`eval echo \\$$name`";
|
69 |
+
# Work out whether we seem to be expecting a Boolean argument.
|
70 |
+
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
|
71 |
+
was_bool=true;
|
72 |
+
else
|
73 |
+
was_bool=false;
|
74 |
+
fi
|
75 |
+
|
76 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
77 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
78 |
+
eval $name=\"$2\";
|
79 |
+
|
80 |
+
# Check that Boolean-valued arguments are really Boolean.
|
81 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
82 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
83 |
+
exit 1;
|
84 |
+
fi
|
85 |
+
shift 2;
|
86 |
+
;;
|
87 |
+
*) break;
|
88 |
+
esac
|
89 |
+
done
|
90 |
+
|
91 |
+
|
92 |
+
# Check for an empty argument to the --cmd option, which can easily occur as a
|
93 |
+
# result of scripting errors.
|
94 |
+
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
|
95 |
+
|
96 |
+
|
97 |
+
true; # so this script returns exit code 0.
|
sparktts/utils/token_parser.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TASK_TOKEN_MAP = {
|
2 |
+
"vc": "<|task_vc|>",
|
3 |
+
"tts": "<|task_tts|>",
|
4 |
+
"asr": "<|task_asr|>",
|
5 |
+
"s2s": "<|task_s2s|>",
|
6 |
+
"t2s": "<|task_t2s|>",
|
7 |
+
"understand": "<|task_understand|>",
|
8 |
+
"caption": "<|task_cap|>",
|
9 |
+
"controllable_tts": "<|task_controllable_tts|>",
|
10 |
+
"prompt_tts": "<|task_prompt_tts|>",
|
11 |
+
"speech_edit": "<|task_edit|>",
|
12 |
+
}
|
13 |
+
|
14 |
+
LEVELS_MAP = {
|
15 |
+
"very_low": 0,
|
16 |
+
"low": 1,
|
17 |
+
"moderate": 2,
|
18 |
+
"high": 3,
|
19 |
+
"very_high": 4,
|
20 |
+
}
|
21 |
+
|
22 |
+
LEVELS_MAP_UI = {
|
23 |
+
1: 'very_low',
|
24 |
+
2: 'low',
|
25 |
+
3: 'moderate',
|
26 |
+
4: 'high',
|
27 |
+
5: 'very_high'
|
28 |
+
}
|
29 |
+
|
30 |
+
GENDER_MAP = {
|
31 |
+
"female": 0,
|
32 |
+
"male": 1,
|
33 |
+
}
|
34 |
+
|
35 |
+
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
|
36 |
+
|
37 |
+
EMO_MAP = {
|
38 |
+
"UNKNOWN": 0,
|
39 |
+
"NEUTRAL": 1,
|
40 |
+
"ANGRY": 2,
|
41 |
+
"HAPPY": 3,
|
42 |
+
"SAD": 4,
|
43 |
+
"FEARFUL": 5,
|
44 |
+
"DISGUSTED": 6,
|
45 |
+
"SURPRISED": 7,
|
46 |
+
"SARCASTIC": 8,
|
47 |
+
"EXCITED": 9,
|
48 |
+
"SLEEPY": 10,
|
49 |
+
"CONFUSED": 11,
|
50 |
+
"EMPHASIS": 12,
|
51 |
+
"LAUGHING": 13,
|
52 |
+
"SINGING": 14,
|
53 |
+
"WORRIED": 15,
|
54 |
+
"WHISPER": 16,
|
55 |
+
"ANXIOUS": 17,
|
56 |
+
"NO-AGREEMENT": 18,
|
57 |
+
"APOLOGETIC": 19,
|
58 |
+
"CONCERNED": 20,
|
59 |
+
"ENUNCIATED": 21,
|
60 |
+
"ASSERTIVE": 22,
|
61 |
+
"ENCOURAGING": 23,
|
62 |
+
"CONTEMPT": 24,
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
class TokenParser:
|
67 |
+
"""Turn label to special token"""
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
pass
|
71 |
+
|
72 |
+
"""Parse the attributes of a person."""
|
73 |
+
|
74 |
+
def __init__(self):
|
75 |
+
pass
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def age(age: str) -> str:
|
79 |
+
"""Turn age token."""
|
80 |
+
age_id = AGE_MAP[age]
|
81 |
+
return f"<|age_{age_id}|>"
|
82 |
+
|
83 |
+
@staticmethod
|
84 |
+
def gender(gender: str) -> str:
|
85 |
+
"""Turn gender token."""
|
86 |
+
gender_id = GENDER_MAP[gender]
|
87 |
+
return f"<|gender_{gender_id}|>"
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def mel_value(mel: int):
|
91 |
+
"""Turn special token of mel scale pitch."""
|
92 |
+
mel = max(0, int(mel))
|
93 |
+
mel = min(1000, int(mel))
|
94 |
+
return f"<|pitch_value_{mel}|>"
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def mel_level(level: str):
|
98 |
+
"""Turn special token of mel level."""
|
99 |
+
level_tag = LEVELS_MAP[level]
|
100 |
+
return f"<|pitch_label_{level_tag}|>"
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def pitch_var_value(pitch_std: int):
|
104 |
+
"""Turn special token of pitch_std value."""
|
105 |
+
assert isinstance(pitch_std, int)
|
106 |
+
pitch_std = max(0, int(pitch_std))
|
107 |
+
pitch_std = min(10, int(pitch_std))
|
108 |
+
return f"<|pitch_var_value_{pitch_std}|>"
|
109 |
+
|
110 |
+
@staticmethod
|
111 |
+
def pitch_var_level(level: str):
|
112 |
+
"""Turn special token of pitch std level."""
|
113 |
+
level_tag = LEVELS_MAP[level]
|
114 |
+
return f"<|pitch_var_label_{level_tag}|>"
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def loudness_value(loudness: int):
|
118 |
+
"""Turn special toak of loudness value [0, 30]"""
|
119 |
+
assert loudness >= 0
|
120 |
+
loudness = max(0, int(loudness))
|
121 |
+
loudness = min(30, int(loudness))
|
122 |
+
return f"<|loudness_value_{loudness}|>"
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def loudness_level(level: str):
|
126 |
+
"""Turn special token of loudness level."""
|
127 |
+
level_tag = LEVELS_MAP[level]
|
128 |
+
return f"<|loudness_label_{level_tag}|>"
|
129 |
+
|
130 |
+
@staticmethod
|
131 |
+
def speed_value(speed: int):
|
132 |
+
"""Turn special token of speed value."""
|
133 |
+
speed = max(0, int(speed))
|
134 |
+
speed = min(10, int(speed))
|
135 |
+
return f"<|speed_value_{speed}|>"
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def speed_level(level: str):
|
139 |
+
"""Turn special token of speed level."""
|
140 |
+
level_tag = LEVELS_MAP[level]
|
141 |
+
return f"<|speed_label_{level_tag}|>"
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def task(task: str) -> str:
|
145 |
+
"""Turn special token of task."""
|
146 |
+
assert task in TASK_TOKEN_MAP.keys()
|
147 |
+
|
148 |
+
return TASK_TOKEN_MAP[task]
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def emotion(emotion: str):
|
152 |
+
emo_id = EMO_MAP[emotion]
|
153 |
+
|
154 |
+
return f"<|emotion_{emo_id}|>"
|
155 |
+
|
156 |
+
|
157 |
+
# test
|
158 |
+
if __name__ == "__main__":
|
159 |
+
from transformers import AutoTokenizer
|
160 |
+
|
161 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
162 |
+
"/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
|
163 |
+
)
|
164 |
+
|
165 |
+
tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
|
166 |
+
ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
|
167 |
+
genders = ["female", "female", "female", "male", "male"]
|
168 |
+
mels = [100, 200, 300, 400, 500]
|
169 |
+
mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
170 |
+
loudnesses = [1, 10, 23, 19, 30]
|
171 |
+
loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
|
172 |
+
emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
|
173 |
+
|
174 |
+
for i in range(5):
|
175 |
+
task = TokenParser.task(tasks[i])
|
176 |
+
age = TokenParser.age(ages[i])
|
177 |
+
gender = TokenParser.gender(genders[i])
|
178 |
+
mel = TokenParser.mel_value(mels[i])
|
179 |
+
mel_level = TokenParser.mel_level(mel_levels[i])
|
180 |
+
loudness = TokenParser.loudness_value(loudnesses[i])
|
181 |
+
loudness_level = TokenParser.loudness_level(loudness_levels[i])
|
182 |
+
emotion = TokenParser.emotion(emotions[i])
|
183 |
+
inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
|
184 |
+
inputs = "".join(inputs)
|
185 |
+
ids = tokenizer.encode(inputs, add_special_tokens=False)
|
186 |
+
print(ids)
|
187 |
+
print("decode", tokenizer.decode(ids))
|