akulubala commited on
Commit
7cde53f
·
1 Parent(s): ea5fa21

start deploy

Browse files
Files changed (44) hide show
  1. .gitignore +6 -0
  2. LICENSE +201 -0
  3. app.py +278 -0
  4. cli/SparkTTS.py +236 -0
  5. cli/inference.py +116 -0
  6. datasets/.gitkeep +0 -0
  7. pretrained_models/.gitkeep +0 -0
  8. requirements.txt +14 -0
  9. runtime/triton_trtllm/Dockerfile.server +5 -0
  10. runtime/triton_trtllm/README.md +94 -0
  11. runtime/triton_trtllm/client_grpc.py +831 -0
  12. runtime/triton_trtllm/client_http.py +165 -0
  13. runtime/triton_trtllm/docker-compose.yml +20 -0
  14. runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py +137 -0
  15. runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt +58 -0
  16. runtime/triton_trtllm/model_repo/spark_tts/1/model.py +404 -0
  17. runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt +86 -0
  18. runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep +0 -0
  19. runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt +857 -0
  20. runtime/triton_trtllm/model_repo/vocoder/1/model.py +106 -0
  21. runtime/triton_trtllm/model_repo/vocoder/config.pbtxt +53 -0
  22. runtime/triton_trtllm/run.sh +109 -0
  23. runtime/triton_trtllm/scripts/convert_checkpoint.py +335 -0
  24. runtime/triton_trtllm/scripts/fill_template.py +70 -0
  25. sparktts/models/audio_tokenizer.py +163 -0
  26. sparktts/models/bicodec.py +247 -0
  27. sparktts/modules/blocks/layers.py +73 -0
  28. sparktts/modules/blocks/samper.py +115 -0
  29. sparktts/modules/blocks/vocos.py +373 -0
  30. sparktts/modules/encoder_decoder/feat_decoder.py +115 -0
  31. sparktts/modules/encoder_decoder/feat_encoder.py +105 -0
  32. sparktts/modules/encoder_decoder/wave_generator.py +88 -0
  33. sparktts/modules/fsq/finite_scalar_quantization.py +251 -0
  34. sparktts/modules/fsq/residual_fsq.py +355 -0
  35. sparktts/modules/speaker/ecapa_tdnn.py +267 -0
  36. sparktts/modules/speaker/perceiver_encoder.py +360 -0
  37. sparktts/modules/speaker/pooling_layers.py +298 -0
  38. sparktts/modules/speaker/speaker_encoder.py +136 -0
  39. sparktts/modules/vq/factorized_vector_quantize.py +187 -0
  40. sparktts/utils/__init__.py +0 -0
  41. sparktts/utils/audio.py +271 -0
  42. sparktts/utils/file.py +221 -0
  43. sparktts/utils/parse_options.sh +97 -0
  44. sparktts/utils/token_parser.py +187 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ example/
2
+ src/
3
+ pretrained_models/*
4
+ !pretrained_models/.gitkeep
5
+ datasets/*
6
+ !datasets/.gitkeep
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import torch
18
+ import soundfile as sf
19
+ import logging
20
+ import argparse
21
+ import gradio as gr
22
+ import platform
23
+
24
+ from datetime import datetime
25
+ from cli.SparkTTS import SparkTTS
26
+ from sparktts.utils.token_parser import LEVELS_MAP_UI
27
+ from huggingface_hub import snapshot_download
28
+
29
+
30
+ def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
31
+ """Load the model once at the beginning."""
32
+ logging.info(f"Loading model from: {model_dir}")
33
+
34
+ # Determine appropriate device based on platform and availability
35
+ if platform.system() == "Darwin":
36
+ # macOS with MPS support (Apple Silicon)
37
+ device = torch.device(f"mps:{device}")
38
+ logging.info(f"Using MPS device: {device}")
39
+ elif torch.cuda.is_available():
40
+ # System with CUDA support
41
+ device = torch.device(f"cuda:{device}")
42
+ logging.info(f"Using CUDA device: {device}")
43
+ else:
44
+ # Fall back to CPU
45
+ device = torch.device("cpu")
46
+ logging.info("GPU acceleration not available, using CPU")
47
+
48
+ model = SparkTTS(model_dir, device)
49
+ return model
50
+
51
+
52
+ def run_tts(
53
+ text,
54
+ model,
55
+ prompt_text=None,
56
+ prompt_speech=None,
57
+ gender=None,
58
+ pitch=None,
59
+ speed=None,
60
+ save_dir="example/results",
61
+ ):
62
+ """Perform TTS inference and save the generated audio."""
63
+ logging.info(f"Saving audio to: {save_dir}")
64
+
65
+ if prompt_text is not None:
66
+ prompt_text = None if len(prompt_text) <= 1 else prompt_text
67
+
68
+ # Ensure the save directory exists
69
+ os.makedirs(save_dir, exist_ok=True)
70
+
71
+ # Generate unique filename using timestamp
72
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
73
+ save_path = os.path.join(save_dir, f"{timestamp}.wav")
74
+
75
+ logging.info("Starting inference...")
76
+
77
+ # Perform inference and save the output audio
78
+ with torch.no_grad():
79
+ wav = model.inference(
80
+ text,
81
+ prompt_speech,
82
+ prompt_text,
83
+ gender,
84
+ pitch,
85
+ speed,
86
+ )
87
+
88
+ sf.write(save_path, wav, samplerate=16000)
89
+
90
+ logging.info(f"Audio saved at: {save_path}")
91
+
92
+ return save_path
93
+
94
+
95
+ def build_ui(model_dir, device=0):
96
+
97
+ # Initialize model
98
+ model = initialize_model(model_dir, device=device)
99
+
100
+ # Define callback function for voice cloning
101
+ def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
102
+ """
103
+ Gradio callback to clone voice using text and optional prompt speech.
104
+ - text: The input text to be synthesised.
105
+ - prompt_text: Additional textual info for the prompt (optional).
106
+ - prompt_wav_upload/prompt_wav_record: Audio files used as reference.
107
+ """
108
+ prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
109
+ prompt_text_clean = None if len(prompt_text) < 2 else prompt_text
110
+
111
+ audio_output_path = run_tts(
112
+ text,
113
+ model,
114
+ prompt_text=prompt_text_clean,
115
+ prompt_speech=prompt_speech
116
+ )
117
+ return audio_output_path
118
+
119
+ # Define callback function for creating new voices
120
+ def voice_creation(text, gender, pitch, speed):
121
+ """
122
+ Gradio callback to create a synthetic voice with adjustable parameters.
123
+ - text: The input text for synthesis.
124
+ - gender: 'male' or 'female'.
125
+ - pitch/speed: Ranges mapped by LEVELS_MAP_UI.
126
+ """
127
+ pitch_val = LEVELS_MAP_UI[int(pitch)]
128
+ speed_val = LEVELS_MAP_UI[int(speed)]
129
+ audio_output_path = run_tts(
130
+ text,
131
+ model,
132
+ gender=gender,
133
+ pitch=pitch_val,
134
+ speed=speed_val
135
+ )
136
+ return audio_output_path
137
+
138
+ with gr.Blocks() as demo:
139
+ # Use HTML for centered title
140
+ gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
141
+ with gr.Tabs():
142
+ # Voice Clone Tab
143
+ with gr.TabItem("Voice Clone"):
144
+ gr.Markdown(
145
+ "### Upload reference audio or recording (上传参考音频或者录音)"
146
+ )
147
+
148
+ with gr.Row():
149
+ prompt_wav_upload = gr.Audio(
150
+ sources="upload",
151
+ type="filepath",
152
+ label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
153
+ )
154
+ prompt_wav_record = gr.Audio(
155
+ sources="microphone",
156
+ type="filepath",
157
+ label="Record the prompt audio file.",
158
+ )
159
+
160
+ with gr.Row():
161
+ text_input = gr.Textbox(
162
+ label="Text", lines=3, placeholder="Enter text here"
163
+ )
164
+ prompt_text_input = gr.Textbox(
165
+ label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
166
+ lines=3,
167
+ placeholder="Enter text of the prompt speech.",
168
+ )
169
+
170
+ audio_output = gr.Audio(
171
+ label="Generated Audio", autoplay=True, streaming=True
172
+ )
173
+
174
+ generate_buttom_clone = gr.Button("Generate")
175
+
176
+ generate_buttom_clone.click(
177
+ voice_clone,
178
+ inputs=[
179
+ text_input,
180
+ prompt_text_input,
181
+ prompt_wav_upload,
182
+ prompt_wav_record,
183
+ ],
184
+ outputs=[audio_output],
185
+ )
186
+
187
+ # Voice Creation Tab
188
+ with gr.TabItem("Voice Creation"):
189
+ gr.Markdown(
190
+ "### Create your own voice based on the following parameters"
191
+ )
192
+
193
+ with gr.Row():
194
+ with gr.Column():
195
+ gender = gr.Radio(
196
+ choices=["male", "female"], value="male", label="Gender"
197
+ )
198
+ pitch = gr.Slider(
199
+ minimum=1, maximum=5, step=1, value=3, label="Pitch"
200
+ )
201
+ speed = gr.Slider(
202
+ minimum=1, maximum=5, step=1, value=3, label="Speed"
203
+ )
204
+ with gr.Column():
205
+ text_input_creation = gr.Textbox(
206
+ label="Input Text",
207
+ lines=3,
208
+ placeholder="Enter text here",
209
+ value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
210
+ )
211
+ create_button = gr.Button("Create Voice")
212
+
213
+ audio_output = gr.Audio(
214
+ label="Generated Audio", autoplay=True, streaming=True
215
+ )
216
+ create_button.click(
217
+ voice_creation,
218
+ inputs=[text_input_creation, gender, pitch, speed],
219
+ outputs=[audio_output],
220
+ )
221
+
222
+ return demo
223
+
224
+
225
+ def parse_arguments():
226
+ """
227
+ Parse command-line arguments such as model directory and device ID.
228
+ """
229
+ parser = argparse.ArgumentParser(description="Spark TTS Gradio server.")
230
+ parser.add_argument(
231
+ "--model_dir",
232
+ type=str,
233
+ default="pretrained_models/Spark-TTS-0.5B",
234
+ help="Path to the model directory."
235
+ )
236
+ parser.add_argument(
237
+ "--device",
238
+ type=int,
239
+ default=0,
240
+ help="ID of the GPU device to use (e.g., 0 for cuda:0)."
241
+ )
242
+ parser.add_argument(
243
+ "--server_name",
244
+ type=str,
245
+ default="0.0.0.0",
246
+ help="Server host/IP for Gradio app."
247
+ )
248
+ parser.add_argument(
249
+ "--server_port",
250
+ type=int,
251
+ default=7860,
252
+ help="Server port for Gradio app."
253
+ )
254
+ return parser.parse_args()
255
+
256
+ if __name__ == "__main__":
257
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
258
+ torch.backends.mps.is_available = lambda: False
259
+
260
+
261
+ ## if model not downloaded, download it
262
+ if not os.path.exists("pretrained_models/Spark-TTS-0.5B"):
263
+ snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
264
+
265
+ # Parse command-line arguments
266
+ args = parse_arguments()
267
+
268
+ # Build the Gradio demo by specifying the model directory and GPU device
269
+ demo = build_ui(
270
+ model_dir=args.model_dir,
271
+ device=args.device
272
+ )
273
+
274
+ # Launch Gradio with the specified server name and port
275
+ demo.launch(
276
+ server_name=args.server_name,
277
+ server_port=args.server_port
278
+ )
cli/SparkTTS.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ import torch
18
+ from typing import Tuple
19
+ from pathlib import Path
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM
21
+
22
+ from sparktts.utils.file import load_config
23
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
24
+ from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
25
+
26
+
27
+ class SparkTTS:
28
+ """
29
+ Spark-TTS for text-to-speech generation.
30
+ """
31
+
32
+ def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
33
+ """
34
+ Initializes the SparkTTS model with the provided configurations and device.
35
+
36
+ Args:
37
+ model_dir (Path): Directory containing the model and config files.
38
+ device (torch.device): The device (CPU/GPU) to run the model on.
39
+ """
40
+ self.device = "cpu"
41
+ self.model_dir = model_dir
42
+ self.configs = load_config(f"{model_dir}/config.yaml")
43
+ self.sample_rate = self.configs["sample_rate"]
44
+ self._initialize_inference()
45
+
46
+ def _initialize_inference(self):
47
+ """Initializes the tokenizer, model, and audio tokenizer for inference."""
48
+ self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
49
+ self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
50
+ self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
51
+ self.model.to(self.device)
52
+
53
+ def process_prompt(
54
+ self,
55
+ text: str,
56
+ prompt_speech_path: Path,
57
+ prompt_text: str = None,
58
+ ) -> Tuple[str, torch.Tensor]:
59
+ """
60
+ Process input for voice cloning.
61
+
62
+ Args:
63
+ text (str): The text input to be converted to speech.
64
+ prompt_speech_path (Path): Path to the audio file used as a prompt.
65
+ prompt_text (str, optional): Transcript of the prompt audio.
66
+
67
+ Return:
68
+ Tuple[str, torch.Tensor]: Input prompt; global tokens
69
+ """
70
+
71
+ global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
72
+ prompt_speech_path
73
+ )
74
+ global_tokens = "".join(
75
+ [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
76
+ )
77
+
78
+ # Prepare the input tokens for the model
79
+ if prompt_text is not None:
80
+ semantic_tokens = "".join(
81
+ [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
82
+ )
83
+ inputs = [
84
+ TASK_TOKEN_MAP["tts"],
85
+ "<|start_content|>",
86
+ prompt_text,
87
+ text,
88
+ "<|end_content|>",
89
+ "<|start_global_token|>",
90
+ global_tokens,
91
+ "<|end_global_token|>",
92
+ "<|start_semantic_token|>",
93
+ semantic_tokens,
94
+ ]
95
+ else:
96
+ inputs = [
97
+ TASK_TOKEN_MAP["tts"],
98
+ "<|start_content|>",
99
+ text,
100
+ "<|end_content|>",
101
+ "<|start_global_token|>",
102
+ global_tokens,
103
+ "<|end_global_token|>",
104
+ ]
105
+
106
+ inputs = "".join(inputs)
107
+
108
+ return inputs, global_token_ids
109
+
110
+ def process_prompt_control(
111
+ self,
112
+ gender: str,
113
+ pitch: str,
114
+ speed: str,
115
+ text: str,
116
+ ):
117
+ """
118
+ Process input for voice creation.
119
+
120
+ Args:
121
+ gender (str): female | male.
122
+ pitch (str): very_low | low | moderate | high | very_high
123
+ speed (str): very_low | low | moderate | high | very_high
124
+ text (str): The text input to be converted to speech.
125
+
126
+ Return:
127
+ str: Input prompt
128
+ """
129
+ assert gender in GENDER_MAP.keys()
130
+ assert pitch in LEVELS_MAP.keys()
131
+ assert speed in LEVELS_MAP.keys()
132
+
133
+ gender_id = GENDER_MAP[gender]
134
+ pitch_level_id = LEVELS_MAP[pitch]
135
+ speed_level_id = LEVELS_MAP[speed]
136
+
137
+ pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
138
+ speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
139
+ gender_tokens = f"<|gender_{gender_id}|>"
140
+
141
+ attribte_tokens = "".join(
142
+ [gender_tokens, pitch_label_tokens, speed_label_tokens]
143
+ )
144
+
145
+ control_tts_inputs = [
146
+ TASK_TOKEN_MAP["controllable_tts"],
147
+ "<|start_content|>",
148
+ text,
149
+ "<|end_content|>",
150
+ "<|start_style_label|>",
151
+ attribte_tokens,
152
+ "<|end_style_label|>",
153
+ ]
154
+
155
+ return "".join(control_tts_inputs)
156
+
157
+ @torch.no_grad()
158
+ def inference(
159
+ self,
160
+ text: str,
161
+ prompt_speech_path: Path = None,
162
+ prompt_text: str = None,
163
+ gender: str = None,
164
+ pitch: str = None,
165
+ speed: str = None,
166
+ temperature: float = 0.8,
167
+ top_k: float = 50,
168
+ top_p: float = 0.95,
169
+ ) -> torch.Tensor:
170
+ """
171
+ Performs inference to generate speech from text, incorporating prompt audio and/or text.
172
+
173
+ Args:
174
+ text (str): The text input to be converted to speech.
175
+ prompt_speech_path (Path): Path to the audio file used as a prompt.
176
+ prompt_text (str, optional): Transcript of the prompt audio.
177
+ gender (str): female | male.
178
+ pitch (str): very_low | low | moderate | high | very_high
179
+ speed (str): very_low | low | moderate | high | very_high
180
+ temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
181
+ top_k (float, optional): Top-k sampling parameter. Default is 50.
182
+ top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
183
+
184
+ Returns:
185
+ torch.Tensor: Generated waveform as a tensor.
186
+ """
187
+ if gender is not None:
188
+ prompt = self.process_prompt_control(gender, pitch, speed, text)
189
+
190
+ else:
191
+ prompt, global_token_ids = self.process_prompt(
192
+ text, prompt_speech_path, prompt_text
193
+ )
194
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
195
+
196
+ # Generate speech using the model
197
+ generated_ids = self.model.generate(
198
+ **model_inputs,
199
+ max_new_tokens=3000,
200
+ do_sample=True,
201
+ top_k=top_k,
202
+ top_p=top_p,
203
+ temperature=temperature,
204
+ )
205
+
206
+ # Trim the output tokens to remove the input tokens
207
+ generated_ids = [
208
+ output_ids[len(input_ids) :]
209
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
210
+ ]
211
+
212
+ # Decode the generated tokens into text
213
+ predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
214
+
215
+ # Extract semantic token IDs from the generated text
216
+ pred_semantic_ids = (
217
+ torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
218
+ .long()
219
+ .unsqueeze(0)
220
+ )
221
+
222
+ if gender is not None:
223
+ global_token_ids = (
224
+ torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
225
+ .long()
226
+ .unsqueeze(0)
227
+ .unsqueeze(0)
228
+ )
229
+
230
+ # Convert semantic tokens back to waveform
231
+ wav = self.audio_tokenizer.detokenize(
232
+ global_token_ids.to(self.device).squeeze(0),
233
+ pred_semantic_ids.to(self.device),
234
+ )
235
+
236
+ return wav
cli/inference.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import argparse
19
+ import torch
20
+ import soundfile as sf
21
+ import logging
22
+ from datetime import datetime
23
+ import platform
24
+
25
+ from cli.SparkTTS import SparkTTS
26
+
27
+
28
+ def parse_args():
29
+ """Parse command-line arguments."""
30
+ parser = argparse.ArgumentParser(description="Run TTS inference.")
31
+
32
+ parser.add_argument(
33
+ "--model_dir",
34
+ type=str,
35
+ default="pretrained_models/Spark-TTS-0.5B",
36
+ help="Path to the model directory",
37
+ )
38
+ parser.add_argument(
39
+ "--save_dir",
40
+ type=str,
41
+ default="example/results",
42
+ help="Directory to save generated audio files",
43
+ )
44
+ parser.add_argument("--device", type=int, default=0, help="CUDA device number")
45
+ parser.add_argument(
46
+ "--text", type=str, required=True, help="Text for TTS generation"
47
+ )
48
+ parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
49
+ parser.add_argument(
50
+ "--prompt_speech_path",
51
+ type=str,
52
+ help="Path to the prompt audio file",
53
+ )
54
+ parser.add_argument("--gender", choices=["male", "female"])
55
+ parser.add_argument(
56
+ "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
57
+ )
58
+ parser.add_argument(
59
+ "--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
60
+ )
61
+ return parser.parse_args()
62
+
63
+
64
+ def run_tts(args):
65
+ """Perform TTS inference and save the generated audio."""
66
+ logging.info(f"Using model from: {args.model_dir}")
67
+ logging.info(f"Saving audio to: {args.save_dir}")
68
+
69
+ # Ensure the save directory exists
70
+ os.makedirs(args.save_dir, exist_ok=True)
71
+
72
+ # Convert device argument to torch.device
73
+ if platform.system() == "Darwin" and torch.backends.mps.is_available():
74
+ # macOS with MPS support (Apple Silicon)
75
+ device = torch.device(f"mps:{args.device}")
76
+ logging.info(f"Using MPS device: {device}")
77
+ elif torch.cuda.is_available():
78
+ # System with CUDA support
79
+ device = torch.device(f"cuda:{args.device}")
80
+ logging.info(f"Using CUDA device: {device}")
81
+ else:
82
+ # Fall back to CPU
83
+ device = torch.device("cpu")
84
+ logging.info("GPU acceleration not available, using CPU")
85
+
86
+ # Initialize the model
87
+ model = SparkTTS(args.model_dir, device)
88
+
89
+ # Generate unique filename using timestamp
90
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
91
+ save_path = os.path.join(args.save_dir, f"{timestamp}.wav")
92
+
93
+ logging.info("Starting inference...")
94
+
95
+ # Perform inference and save the output audio
96
+ with torch.no_grad():
97
+ wav = model.inference(
98
+ args.text,
99
+ args.prompt_speech_path,
100
+ prompt_text=args.prompt_text,
101
+ gender=args.gender,
102
+ pitch=args.pitch,
103
+ speed=args.speed,
104
+ )
105
+ sf.write(save_path, wav, samplerate=16000)
106
+
107
+ logging.info(f"Audio saved at: {save_path}")
108
+
109
+
110
+ if __name__ == "__main__":
111
+ logging.basicConfig(
112
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
113
+ )
114
+
115
+ args = parse_args()
116
+ run_tts(args)
datasets/.gitkeep ADDED
File without changes
pretrained_models/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops==0.8.1
2
+ einx==0.3.0
3
+ numpy==2.2.3
4
+ omegaconf==2.3.0
5
+ packaging==24.2
6
+ safetensors==0.5.2
7
+ soundfile==0.12.1
8
+ soxr==0.5.0.post1
9
+ torch==2.5.1
10
+ torchaudio==2.5.1
11
+ tqdm==4.66.5
12
+ transformers==4.46.2
13
+ gradio==5.18.0
14
+ huggingface_hub
runtime/triton_trtllm/Dockerfile.server ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/tritonserver:25.02-trtllm-python-py3
2
+ RUN apt-get update && apt-get install -y cmake
3
+ RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
4
+ RUN pip install einx==0.3.0 omegaconf==2.3.0 soundfile==0.12.1 soxr==0.5.0.post1 gradio tritonclient librosa
5
+ WORKDIR /workspace
runtime/triton_trtllm/README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Nvidia Triton Inference Serving Best Practice for Spark TTS
2
+
3
+ ### Quick Start
4
+ Directly launch the service using docker compose.
5
+ ```sh
6
+ docker compose up
7
+ ```
8
+
9
+ ### Build Image
10
+ Build the docker image from scratch.
11
+ ```sh
12
+ docker build . -f Dockerfile.server -t soar97/triton-spark-tts:25.02
13
+ ```
14
+
15
+ ### Create Docker Container
16
+ ```sh
17
+ your_mount_dir=/mnt:/mnt
18
+ docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02
19
+ ```
20
+
21
+ ### Understanding `run.sh`
22
+
23
+ The `run.sh` script automates various steps using stages. You can run specific stages using:
24
+ ```sh
25
+ bash run.sh <start_stage> <stop_stage> [service_type]
26
+ ```
27
+ - `<start_stage>`: The stage to begin execution from (0-5).
28
+ - `<stop_stage>`: The stage to end execution at (0-5).
29
+ - `[service_type]`: Optional, specifies the service type ('streaming' or 'offline', defaults may apply based on script logic). Required for stages 4 and 5.
30
+
31
+ Stages:
32
+ - **Stage 0**: Download Spark-TTS-0.5B model from HuggingFace.
33
+ - **Stage 1**: Convert HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines.
34
+ - **Stage 2**: Create the Triton model repository structure and configure model files (adjusts for streaming/offline).
35
+ - **Stage 3**: Launch the Triton Inference Server.
36
+ - **Stage 4**: Run the gRPC benchmark client.
37
+ - **Stage 5**: Run the single utterance client (gRPC for streaming, HTTP for offline).
38
+
39
+ ### Export Models to TensorRT-LLM and Launch Server
40
+ Inside the docker container, you can prepare the models and launch the Triton server by running stages 0 through 3. This involves downloading the original model, converting it to TensorRT-LLM format, building the optimized TensorRT engines, creating the necessary model repository structure for Triton, and finally starting the server.
41
+ ```sh
42
+ # This runs stages 0, 1, 2, and 3
43
+ bash run.sh 0 3
44
+ ```
45
+ *Note: Stage 2 prepares the model repository differently based on whether you intend to run streaming or offline inference later. You might need to re-run stage 2 if switching service types.*
46
+
47
+
48
+ ### Single Utterance Client
49
+ Run a single inference request. Specify `streaming` or `offline` as the third argument.
50
+
51
+ **Streaming Mode (gRPC):**
52
+ ```sh
53
+ bash run.sh 5 5 streaming
54
+ ```
55
+ This executes the `client_grpc.py` script with predefined example text and prompt audio in streaming mode.
56
+
57
+ **Offline Mode (HTTP):**
58
+ ```sh
59
+ bash run.sh 5 5 offline
60
+ ```
61
+
62
+ ### Benchmark using Dataset
63
+ Run the benchmark client against the running Triton server. Specify `streaming` or `offline` as the third argument.
64
+ ```sh
65
+ # Run benchmark in streaming mode
66
+ bash run.sh 4 4 streaming
67
+
68
+ # Run benchmark in offline mode
69
+ bash run.sh 4 4 offline
70
+
71
+ # You can also customize parameters like num_task directly in client_grpc.py or via args if supported
72
+ # Example from run.sh (streaming):
73
+ # python3 client_grpc.py \
74
+ # --server-addr localhost \
75
+ # --model-name spark_tts \
76
+ # --num-tasks 2 \
77
+ # --mode streaming \
78
+ # --log-dir ./log_concurrent_tasks_2_streaming_new
79
+
80
+ # Example customizing dataset (requires modifying client_grpc.py or adding args):
81
+ # python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --mode [streaming|offline]
82
+ ```
83
+
84
+ ### Benchmark Results
85
+ Decoding on a single L20 GPU, using 26 different prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts), total audio duration 169 secs.
86
+
87
+ | Mode | Note | Concurrency | Avg Latency | First Chunk Latency (P50) | RTF |
88
+ |-------|-----------|-----------------------|---------|----------------|-|
89
+ | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms |-| 0.1362|
90
+ | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms |-|0.0737|
91
+ | Offline | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms |-| 0.0704|
92
+ | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 1 | 913.28 ms |210.42 ms| 0.1501 |
93
+ | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 2 | 1009.23 ms |226.08 ms |0.0862 |
94
+ | Streaming | [Code Commit](https://github.com/yuekaizhang/Spark-TTS/commit/0e978a327f99aa49f0735f86eb09372f16410d86) | 4 | 1793.86 ms |1017.70 ms| 0.0824 |
runtime/triton_trtllm/client_grpc.py ADDED
@@ -0,0 +1,831 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
3
+ # 2023 Nvidia (authors: Yuekai Zhang)
4
+ # 2023 Recurrent.ai (authors: Songtao Shi)
5
+ # See LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ This script supports to load dataset from huggingface and sends it to the server
20
+ for decoding, in parallel.
21
+
22
+ Usage:
23
+ num_task=2
24
+
25
+ # For offline F5-TTS
26
+ python3 client_grpc.py \
27
+ --server-addr localhost \
28
+ --model-name f5_tts \
29
+ --num-tasks $num_task \
30
+ --huggingface-dataset yuekai/seed_tts \
31
+ --split-name test_zh \
32
+ --log-dir ./log_concurrent_tasks_${num_task}
33
+
34
+ # For offline Spark-TTS-0.5B
35
+ python3 client_grpc.py \
36
+ --server-addr localhost \
37
+ --model-name spark_tts \
38
+ --num-tasks $num_task \
39
+ --huggingface-dataset yuekai/seed_tts \
40
+ --split-name wenetspeech4tts \
41
+ --log-dir ./log_concurrent_tasks_${num_task}
42
+ """
43
+
44
+ import argparse
45
+ import asyncio
46
+ import json
47
+ import queue # Added
48
+ import uuid # Added
49
+ import functools # Added
50
+
51
+ import os
52
+ import time
53
+ import types
54
+ from pathlib import Path
55
+
56
+ import numpy as np
57
+ import soundfile as sf
58
+ import tritonclient
59
+ import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
60
+ import tritonclient.grpc as grpcclient_sync # Added sync client import
61
+ from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
62
+
63
+
64
+ # --- Added UserData and callback ---
65
+ class UserData:
66
+ def __init__(self):
67
+ self._completed_requests = queue.Queue()
68
+ self._first_chunk_time = None
69
+ self._start_time = None
70
+
71
+ def record_start_time(self):
72
+ self._start_time = time.time()
73
+
74
+ def get_first_chunk_latency(self):
75
+ if self._first_chunk_time and self._start_time:
76
+ return self._first_chunk_time - self._start_time
77
+ return None
78
+
79
+ def callback(user_data, result, error):
80
+ if user_data._first_chunk_time is None and not error:
81
+ user_data._first_chunk_time = time.time() # Record time of first successful chunk
82
+ if error:
83
+ user_data._completed_requests.put(error)
84
+ else:
85
+ user_data._completed_requests.put(result)
86
+ # --- End Added UserData and callback ---
87
+
88
+
89
+ def write_triton_stats(stats, summary_file):
90
+ with open(summary_file, "w") as summary_f:
91
+ model_stats = stats["model_stats"]
92
+ # write a note, the log is from triton_client.get_inference_statistics(), to better human readability
93
+ summary_f.write(
94
+ "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
95
+ )
96
+ summary_f.write("To learn more about the log, please refer to: \n")
97
+ summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
98
+ summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
99
+ summary_f.write(
100
+ "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
101
+ )
102
+ summary_f.write(
103
+ "However, there is a trade-off between the increased queue time and the increased batch size. \n"
104
+ )
105
+ summary_f.write(
106
+ "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
107
+ )
108
+ summary_f.write(
109
+ "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
110
+ )
111
+ for model_state in model_stats:
112
+ if "last_inference" not in model_state:
113
+ continue
114
+ summary_f.write(f"model name is {model_state['name']} \n")
115
+ model_inference_stats = model_state["inference_stats"]
116
+ total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
117
+ total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
118
+ total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
119
+ total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
120
+ summary_f.write(
121
+ f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
122
+ )
123
+ model_batch_stats = model_state["batch_stats"]
124
+ for batch in model_batch_stats:
125
+ batch_size = int(batch["batch_size"])
126
+ compute_input = batch["compute_input"]
127
+ compute_output = batch["compute_output"]
128
+ compute_infer = batch["compute_infer"]
129
+ batch_count = int(compute_infer["count"])
130
+ assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
131
+ compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
132
+ compute_input_time_ms = int(compute_input["ns"]) / 1e6
133
+ compute_output_time_ms = int(compute_output["ns"]) / 1e6
134
+ summary_f.write(
135
+ f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
136
+ )
137
+ summary_f.write(
138
+ f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
139
+ )
140
+ summary_f.write(
141
+ f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
142
+ )
143
+
144
+
145
+ def get_args():
146
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
147
+
148
+ parser.add_argument(
149
+ "--server-addr",
150
+ type=str,
151
+ default="localhost",
152
+ help="Address of the server",
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--server-port",
157
+ type=int,
158
+ default=8001,
159
+ help="Grpc port of the triton server, default is 8001",
160
+ )
161
+
162
+ parser.add_argument(
163
+ "--reference-audio",
164
+ type=str,
165
+ default=None,
166
+ help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--reference-text",
171
+ type=str,
172
+ default="",
173
+ help="",
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--target-text",
178
+ type=str,
179
+ default="",
180
+ help="",
181
+ )
182
+
183
+ parser.add_argument(
184
+ "--huggingface-dataset",
185
+ type=str,
186
+ default="yuekai/seed_tts",
187
+ help="dataset name in huggingface dataset hub",
188
+ )
189
+
190
+ parser.add_argument(
191
+ "--split-name",
192
+ type=str,
193
+ default="wenetspeech4tts",
194
+ choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
195
+ help="dataset split name, default is 'test'",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--manifest-path",
200
+ type=str,
201
+ default=None,
202
+ help="Path to the manifest dir which includes wav.scp trans.txt files.",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--model-name",
207
+ type=str,
208
+ default="f5_tts",
209
+ choices=["f5_tts", "spark_tts"],
210
+ help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--num-tasks",
215
+ type=int,
216
+ default=1,
217
+ help="Number of concurrent tasks for sending",
218
+ )
219
+
220
+ parser.add_argument(
221
+ "--log-interval",
222
+ type=int,
223
+ default=5,
224
+ help="Controls how frequently we print the log.",
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--compute-wer",
229
+ action="store_true",
230
+ default=False,
231
+ help="""True to compute WER.
232
+ """,
233
+ )
234
+
235
+ parser.add_argument(
236
+ "--log-dir",
237
+ type=str,
238
+ required=False,
239
+ default="./tmp",
240
+ help="log directory",
241
+ )
242
+
243
+ # --- Added arguments ---
244
+ parser.add_argument(
245
+ "--mode",
246
+ type=str,
247
+ default="offline",
248
+ choices=["offline", "streaming"],
249
+ help="Select offline or streaming benchmark mode."
250
+ )
251
+ parser.add_argument(
252
+ "--chunk-overlap-duration",
253
+ type=float,
254
+ default=0.1,
255
+ help="Chunk overlap duration for streaming reconstruction (in seconds)."
256
+ )
257
+ # --- End Added arguments ---
258
+
259
+ return parser.parse_args()
260
+
261
+
262
+ def load_audio(wav_path, target_sample_rate=16000):
263
+ assert target_sample_rate == 16000, "hard coding in server"
264
+ if isinstance(wav_path, dict):
265
+ waveform = wav_path["array"]
266
+ sample_rate = wav_path["sampling_rate"]
267
+ else:
268
+ waveform, sample_rate = sf.read(wav_path)
269
+ if sample_rate != target_sample_rate:
270
+ from scipy.signal import resample
271
+
272
+ num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
273
+ waveform = resample(waveform, num_samples)
274
+ return waveform, target_sample_rate
275
+
276
+ def prepare_request_input_output(
277
+ protocol_client, # Can be grpcclient_aio or grpcclient_sync
278
+ waveform,
279
+ reference_text,
280
+ target_text,
281
+ sample_rate=16000,
282
+ padding_duration: int = None # Optional padding for offline mode
283
+ ):
284
+ """Prepares inputs for Triton inference (offline or streaming)."""
285
+ assert len(waveform.shape) == 1, "waveform should be 1D"
286
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
287
+
288
+ # Apply padding only if padding_duration is provided (for offline)
289
+ if padding_duration:
290
+ duration = len(waveform) / sample_rate
291
+ # Estimate target duration based on text length ratio (crude estimation)
292
+ # Avoid division by zero if reference_text is empty
293
+ if reference_text:
294
+ estimated_target_duration = duration / len(reference_text) * len(target_text)
295
+ else:
296
+ estimated_target_duration = duration # Assume target duration similar to reference if no text
297
+
298
+ # Calculate required samples based on estimated total duration
299
+ required_total_samples = padding_duration * sample_rate * (
300
+ (int(estimated_target_duration + duration) // padding_duration) + 1
301
+ )
302
+ samples = np.zeros((1, required_total_samples), dtype=np.float32)
303
+ samples[0, : len(waveform)] = waveform
304
+ else:
305
+ # No padding for streaming or if padding_duration is None
306
+ samples = waveform.reshape(1, -1).astype(np.float32)
307
+
308
+ # Common input creation logic
309
+ inputs = [
310
+ protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
311
+ protocol_client.InferInput(
312
+ "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
313
+ ),
314
+ protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
315
+ protocol_client.InferInput("target_text", [1, 1], "BYTES"),
316
+ ]
317
+ inputs[0].set_data_from_numpy(samples)
318
+ inputs[1].set_data_from_numpy(lengths)
319
+
320
+ input_data_numpy = np.array([reference_text], dtype=object)
321
+ input_data_numpy = input_data_numpy.reshape((1, 1))
322
+ inputs[2].set_data_from_numpy(input_data_numpy)
323
+
324
+ input_data_numpy = np.array([target_text], dtype=object)
325
+ input_data_numpy = input_data_numpy.reshape((1, 1))
326
+ inputs[3].set_data_from_numpy(input_data_numpy)
327
+
328
+ outputs = [protocol_client.InferRequestedOutput("waveform")]
329
+
330
+ return inputs, outputs
331
+
332
+ def run_sync_streaming_inference(
333
+ sync_triton_client: tritonclient.grpc.InferenceServerClient,
334
+ model_name: str,
335
+ inputs: list,
336
+ outputs: list,
337
+ request_id: str,
338
+ user_data: UserData,
339
+ chunk_overlap_duration: float,
340
+ save_sample_rate: int,
341
+ audio_save_path: str,
342
+ ):
343
+ """Helper function to run the blocking sync streaming call."""
344
+ start_time_total = time.time()
345
+ user_data.record_start_time() # Record start time for first chunk latency calculation
346
+
347
+ # Establish stream
348
+ sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
349
+
350
+ # Send request
351
+ sync_triton_client.async_stream_infer(
352
+ model_name,
353
+ inputs,
354
+ request_id=request_id,
355
+ outputs=outputs,
356
+ enable_empty_final_response=True,
357
+ )
358
+
359
+ # Process results
360
+ audios = []
361
+ while True:
362
+ try:
363
+ result = user_data._completed_requests.get() # Add timeout
364
+ if isinstance(result, InferenceServerException):
365
+ print(f"Received InferenceServerException: {result}")
366
+ sync_triton_client.stop_stream()
367
+ return None, None, None # Indicate error
368
+ # Get response metadata
369
+ response = result.get_response()
370
+ final = response.parameters["triton_final_response"].bool_param
371
+ if final is True:
372
+ break
373
+
374
+ audio_chunk = result.as_numpy("waveform").reshape(-1)
375
+ if audio_chunk.size > 0: # Only append non-empty chunks
376
+ audios.append(audio_chunk)
377
+ else:
378
+ print("Warning: received empty audio chunk.")
379
+
380
+ except queue.Empty:
381
+ print(f"Timeout waiting for response for request id {request_id}")
382
+ sync_triton_client.stop_stream()
383
+ return None, None, None # Indicate error
384
+
385
+ sync_triton_client.stop_stream()
386
+ end_time_total = time.time()
387
+ total_request_latency = end_time_total - start_time_total
388
+ first_chunk_latency = user_data.get_first_chunk_latency()
389
+
390
+ # Reconstruct audio using cross-fade (from client_grpc_streaming.py)
391
+ actual_duration = 0
392
+ if audios:
393
+ cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
394
+ fade_out = np.linspace(1, 0, cross_fade_samples)
395
+ fade_in = np.linspace(0, 1, cross_fade_samples)
396
+ reconstructed_audio = None
397
+
398
+ # Simplified reconstruction based on client_grpc_streaming.py
399
+ if not audios:
400
+ print("Warning: No audio chunks received.")
401
+ reconstructed_audio = np.array([], dtype=np.float32) # Empty array
402
+ elif len(audios) == 1:
403
+ reconstructed_audio = audios[0]
404
+ else:
405
+ reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
406
+ for i in range(1, len(audios)):
407
+ # Cross-fade section
408
+ cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
409
+ audios[i - 1][-cross_fade_samples:] * fade_out)
410
+ # Middle section of the current chunk
411
+ middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
412
+ # Concatenate
413
+ reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
414
+ # Add the last part of the final chunk
415
+ reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
416
+
417
+ if reconstructed_audio is not None and reconstructed_audio.size > 0:
418
+ actual_duration = len(reconstructed_audio) / save_sample_rate
419
+ # Save reconstructed audio
420
+ os.makedirs(os.path.dirname(audio_save_path), exist_ok=True)
421
+ sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
422
+ else:
423
+ print("Warning: No audio chunks received or reconstructed.")
424
+ actual_duration = 0 # Set duration to 0 if no audio
425
+
426
+ else:
427
+ print("Warning: No audio chunks received.")
428
+ actual_duration = 0
429
+
430
+ return total_request_latency, first_chunk_latency, actual_duration
431
+
432
+
433
+ async def send_streaming(
434
+ manifest_item_list: list,
435
+ name: str,
436
+ server_url: str, # Changed from sync_triton_client
437
+ protocol_client: types.ModuleType,
438
+ log_interval: int,
439
+ model_name: str,
440
+ audio_save_dir: str = "./",
441
+ save_sample_rate: int = 16000,
442
+ chunk_overlap_duration: float = 0.1,
443
+ padding_duration: int = None,
444
+ ):
445
+ total_duration = 0.0
446
+ latency_data = []
447
+ task_id = int(name[5:])
448
+ sync_triton_client = None # Initialize client variable
449
+
450
+ try: # Wrap in try...finally to ensure client closing
451
+ print(f"{name}: Initializing sync client for streaming...")
452
+ sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
453
+
454
+ print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
455
+ for i, item in enumerate(manifest_item_list):
456
+ if i % log_interval == 0:
457
+ print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
458
+
459
+ try:
460
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
461
+ reference_text, target_text = item["reference_text"], item["target_text"]
462
+
463
+ inputs, outputs = prepare_request_input_output(
464
+ protocol_client,
465
+ waveform,
466
+ reference_text,
467
+ target_text,
468
+ sample_rate,
469
+ padding_duration=padding_duration
470
+ )
471
+ request_id = str(uuid.uuid4())
472
+ user_data = UserData()
473
+
474
+ audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
475
+
476
+ total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread(
477
+ run_sync_streaming_inference,
478
+ sync_triton_client,
479
+ model_name,
480
+ inputs,
481
+ outputs,
482
+ request_id,
483
+ user_data,
484
+ chunk_overlap_duration,
485
+ save_sample_rate,
486
+ audio_save_path
487
+ )
488
+
489
+ if total_request_latency is not None:
490
+ print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s")
491
+ latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
492
+ total_duration += actual_duration
493
+ else:
494
+ print(f"{name}: Item {i} failed.")
495
+
496
+
497
+ except FileNotFoundError:
498
+ print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
499
+ except Exception as e:
500
+ print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
501
+ import traceback
502
+ traceback.print_exc()
503
+
504
+
505
+ finally: # Ensure client is closed
506
+ if sync_triton_client:
507
+ try:
508
+ print(f"{name}: Closing sync client...")
509
+ sync_triton_client.close()
510
+ except Exception as e:
511
+ print(f"{name}: Error closing sync client: {e}")
512
+
513
+
514
+ print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
515
+ return total_duration, latency_data
516
+
517
+ async def send(
518
+ manifest_item_list: list,
519
+ name: str,
520
+ triton_client: tritonclient.grpc.aio.InferenceServerClient,
521
+ protocol_client: types.ModuleType,
522
+ log_interval: int,
523
+ model_name: str,
524
+ padding_duration: int = None,
525
+ audio_save_dir: str = "./",
526
+ save_sample_rate: int = 16000,
527
+ ):
528
+ total_duration = 0.0
529
+ latency_data = []
530
+ task_id = int(name[5:])
531
+
532
+ print(f"manifest_item_list: {manifest_item_list}")
533
+ for i, item in enumerate(manifest_item_list):
534
+ if i % log_interval == 0:
535
+ print(f"{name}: {i}/{len(manifest_item_list)}")
536
+ waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
537
+ reference_text, target_text = item["reference_text"], item["target_text"]
538
+
539
+ inputs, outputs = prepare_request_input_output(
540
+ protocol_client,
541
+ waveform,
542
+ reference_text,
543
+ target_text,
544
+ sample_rate,
545
+ padding_duration=padding_duration
546
+ )
547
+ sequence_id = 100000000 + i + task_id * 10
548
+ start = time.time()
549
+ response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
550
+
551
+ audio = response.as_numpy("waveform").reshape(-1)
552
+ actual_duration = len(audio) / save_sample_rate
553
+
554
+ end = time.time() - start
555
+
556
+ audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
557
+ sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
558
+
559
+ latency_data.append((end, actual_duration))
560
+ total_duration += actual_duration
561
+
562
+ return total_duration, latency_data
563
+
564
+
565
+ def load_manifests(manifest_path):
566
+ with open(manifest_path, "r") as f:
567
+ manifest_list = []
568
+ for line in f:
569
+ assert len(line.strip().split("|")) == 4
570
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
571
+ utt = Path(utt).stem
572
+ # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
573
+ if not os.path.isabs(prompt_wav):
574
+ prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
575
+ manifest_list.append(
576
+ {
577
+ "audio_filepath": prompt_wav,
578
+ "reference_text": prompt_text,
579
+ "target_text": gt_text,
580
+ "target_audio_path": utt,
581
+ }
582
+ )
583
+ return manifest_list
584
+
585
+
586
+ def split_data(data, k):
587
+ n = len(data)
588
+ if n < k:
589
+ print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
590
+ k = n
591
+
592
+ quotient = n // k
593
+ remainder = n % k
594
+
595
+ result = []
596
+ start = 0
597
+ for i in range(k):
598
+ if i < remainder:
599
+ end = start + quotient + 1
600
+ else:
601
+ end = start + quotient
602
+
603
+ result.append(data[start:end])
604
+ start = end
605
+
606
+ return result
607
+
608
+ async def main():
609
+ args = get_args()
610
+ url = f"{args.server_addr}:{args.server_port}"
611
+
612
+ # --- Client Initialization based on mode ---
613
+ triton_client = None
614
+ protocol_client = None
615
+ if args.mode == "offline":
616
+ print("Initializing gRPC client for offline mode...")
617
+ # Use the async client for offline tasks
618
+ triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
619
+ protocol_client = grpcclient_aio
620
+ elif args.mode == "streaming":
621
+ print("Initializing gRPC client for streaming mode...")
622
+ # Use the sync client for streaming tasks, handled via asyncio.to_thread
623
+ # We will create one sync client instance PER TASK inside send_streaming.
624
+ # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
625
+ protocol_client = grpcclient_sync # protocol client for input prep
626
+ else:
627
+ raise ValueError(f"Invalid mode: {args.mode}")
628
+ # --- End Client Initialization ---
629
+
630
+ if args.reference_audio:
631
+ args.num_tasks = 1
632
+ args.log_interval = 1
633
+ manifest_item_list = [
634
+ {
635
+ "reference_text": args.reference_text,
636
+ "target_text": args.target_text,
637
+ "audio_filepath": args.reference_audio,
638
+ "target_audio_path": "test",
639
+ }
640
+ ]
641
+ elif args.huggingface_dataset:
642
+ import datasets
643
+
644
+ dataset = datasets.load_dataset(
645
+ args.huggingface_dataset,
646
+ split=args.split_name,
647
+ trust_remote_code=True,
648
+ )
649
+ manifest_item_list = []
650
+ for i in range(len(dataset)):
651
+ manifest_item_list.append(
652
+ {
653
+ "audio_filepath": dataset[i]["prompt_audio"],
654
+ "reference_text": dataset[i]["prompt_text"],
655
+ "target_audio_path": dataset[i]["id"],
656
+ "target_text": dataset[i]["target_text"],
657
+ }
658
+ )
659
+ else:
660
+ manifest_item_list = load_manifests(args.manifest_path)
661
+
662
+ num_tasks = min(args.num_tasks, len(manifest_item_list))
663
+ manifest_item_list = split_data(manifest_item_list, num_tasks)
664
+
665
+ os.makedirs(args.log_dir, exist_ok=True)
666
+ tasks = []
667
+ start_time = time.time()
668
+ for i in range(num_tasks):
669
+ # --- Task Creation based on mode ---
670
+ if args.mode == "offline":
671
+ task = asyncio.create_task(
672
+ send(
673
+ manifest_item_list[i],
674
+ name=f"task-{i}",
675
+ triton_client=triton_client,
676
+ protocol_client=protocol_client,
677
+ log_interval=args.log_interval,
678
+ model_name=args.model_name,
679
+ audio_save_dir=args.log_dir,
680
+ padding_duration=1,
681
+ save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
682
+ )
683
+ )
684
+ elif args.mode == "streaming":
685
+ task = asyncio.create_task(
686
+ send_streaming(
687
+ manifest_item_list[i],
688
+ name=f"task-{i}",
689
+ server_url=url, # Pass URL instead of client
690
+ protocol_client=protocol_client,
691
+ log_interval=args.log_interval,
692
+ model_name=args.model_name,
693
+ audio_save_dir=args.log_dir,
694
+ padding_duration=10,
695
+ save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
696
+ chunk_overlap_duration=args.chunk_overlap_duration,
697
+ )
698
+ )
699
+ # --- End Task Creation ---
700
+ tasks.append(task)
701
+
702
+ ans_list = await asyncio.gather(*tasks)
703
+
704
+ end_time = time.time()
705
+ elapsed = end_time - start_time
706
+
707
+ total_duration = 0.0
708
+ latency_data = []
709
+ for ans in ans_list:
710
+ if ans:
711
+ total_duration += ans[0]
712
+ latency_data.extend(ans[1]) # Use extend for list of lists
713
+ else:
714
+ print("Warning: A task returned None, possibly due to an error.")
715
+
716
+
717
+ if total_duration == 0:
718
+ print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
719
+ rtf = float('inf')
720
+ else:
721
+ rtf = elapsed / total_duration
722
+
723
+ s = f"Mode: {args.mode}\n"
724
+ s += f"RTF: {rtf:.4f}\n"
725
+ s += f"total_duration: {total_duration:.3f} seconds\n"
726
+ s += f"({total_duration / 3600:.2f} hours)\n"
727
+ s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
728
+
729
+ # --- Statistics Reporting based on mode ---
730
+ if latency_data:
731
+ if args.mode == "offline":
732
+ # Original offline latency calculation
733
+ latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
734
+ if latency_list:
735
+ latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
736
+ latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
737
+ s += f"latency_variance: {latency_variance:.2f}\n"
738
+ s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
739
+ s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
740
+ s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
741
+ s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
742
+ s += f"average_latency_ms: {latency_ms:.2f}\n"
743
+ else:
744
+ s += "No latency data collected for offline mode.\n"
745
+
746
+ elif args.mode == "streaming":
747
+ # Calculate stats for total request latency and first chunk latency
748
+ total_latency_list = [total for (total, first, duration) in latency_data if total is not None]
749
+ first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None]
750
+
751
+ s += "\n--- Total Request Latency ---\n"
752
+ if total_latency_list:
753
+ avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
754
+ variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
755
+ s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
756
+ s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
757
+ s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
758
+ s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
759
+ s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
760
+ s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
761
+ else:
762
+ s += "No total request latency data collected.\n"
763
+
764
+ s += "\n--- First Chunk Latency ---\n"
765
+ if first_chunk_latency_list:
766
+ avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
767
+ variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
768
+ s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
769
+ s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
770
+ s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
771
+ s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
772
+ s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
773
+ s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
774
+ else:
775
+ s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
776
+ else:
777
+ s += "No latency data collected.\n"
778
+ # --- End Statistics Reporting ---
779
+
780
+ print(s)
781
+ if args.manifest_path:
782
+ name = Path(args.manifest_path).stem
783
+ elif args.split_name:
784
+ name = args.split_name
785
+ elif args.reference_audio:
786
+ name = Path(args.reference_audio).stem
787
+ else:
788
+ name = "results" # Default name if no manifest/split/audio provided
789
+ with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
790
+ f.write(s)
791
+
792
+ # --- Statistics Fetching using temporary Async Client ---
793
+ # Use a separate async client for fetching stats regardless of mode
794
+ stats_client = None
795
+ try:
796
+ print("Initializing temporary async client for fetching stats...")
797
+ stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
798
+ print("Fetching inference statistics...")
799
+ # Fetching for all models, filtering might be needed depending on server setup
800
+ stats = await stats_client.get_inference_statistics(model_name="", as_json=True)
801
+ print("Fetching model config...")
802
+ metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
803
+
804
+ write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
805
+
806
+ with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
807
+ json.dump(metadata, f, indent=4)
808
+
809
+ except Exception as e:
810
+ print(f"Could not retrieve statistics or config: {e}")
811
+ finally:
812
+ if stats_client:
813
+ try:
814
+ print("Closing temporary async stats client...")
815
+ await stats_client.close()
816
+ except Exception as e:
817
+ print(f"Error closing async stats client: {e}")
818
+ # --- End Statistics Fetching ---
819
+
820
+
821
+ if __name__ == "__main__":
822
+ # asyncio.run(main()) # Use TaskGroup for better exception handling if needed
823
+ async def run_main():
824
+ try:
825
+ await main()
826
+ except Exception as e:
827
+ print(f"An error occurred in main: {e}")
828
+ import traceback
829
+ traceback.print_exc()
830
+
831
+ asyncio.run(run_main())
runtime/triton_trtllm/client_http.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import requests
27
+ import soundfile as sf
28
+ import json
29
+ import numpy as np
30
+ import argparse
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser(
34
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--server-url",
39
+ type=str,
40
+ default="localhost:8000",
41
+ help="Address of the server",
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--reference-audio",
46
+ type=str,
47
+ default="../../example/prompt_audio.wav",
48
+ help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--reference-text",
53
+ type=str,
54
+ default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
55
+ help="",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--target-text",
60
+ type=str,
61
+ default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
62
+ help="",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--model-name",
67
+ type=str,
68
+ default="spark_tts",
69
+ choices=[
70
+ "f5_tts", "spark_tts"
71
+ ],
72
+ help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--output-audio",
77
+ type=str,
78
+ default="output.wav",
79
+ help="Path to save the output audio",
80
+ )
81
+ return parser.parse_args()
82
+
83
+ def prepare_request(
84
+ waveform,
85
+ reference_text,
86
+ target_text,
87
+ sample_rate=16000,
88
+ padding_duration: int = None,
89
+ audio_save_dir: str = "./",
90
+ ):
91
+ assert len(waveform.shape) == 1, "waveform should be 1D"
92
+ lengths = np.array([[len(waveform)]], dtype=np.int32)
93
+ if padding_duration:
94
+ # padding to nearset 10 seconds
95
+ samples = np.zeros(
96
+ (
97
+ 1,
98
+ padding_duration
99
+ * sample_rate
100
+ * ((int(duration) // padding_duration) + 1),
101
+ ),
102
+ dtype=np.float32,
103
+ )
104
+
105
+ samples[0, : len(waveform)] = waveform
106
+ else:
107
+ samples = waveform
108
+
109
+ samples = samples.reshape(1, -1).astype(np.float32)
110
+
111
+ data = {
112
+ "inputs":[
113
+ {
114
+ "name": "reference_wav",
115
+ "shape": samples.shape,
116
+ "datatype": "FP32",
117
+ "data": samples.tolist()
118
+ },
119
+ {
120
+ "name": "reference_wav_len",
121
+ "shape": lengths.shape,
122
+ "datatype": "INT32",
123
+ "data": lengths.tolist(),
124
+ },
125
+ {
126
+ "name": "reference_text",
127
+ "shape": [1, 1],
128
+ "datatype": "BYTES",
129
+ "data": [reference_text]
130
+ },
131
+ {
132
+ "name": "target_text",
133
+ "shape": [1, 1],
134
+ "datatype": "BYTES",
135
+ "data": [target_text]
136
+ }
137
+ ]
138
+ }
139
+
140
+ return data
141
+
142
+ if __name__ == "__main__":
143
+ args = get_args()
144
+ server_url = args.server_url
145
+ if not server_url.startswith(("http://", "https://")):
146
+ server_url = f"http://{server_url}"
147
+
148
+ url = f"{server_url}/v2/models/{args.model_name}/infer"
149
+ waveform, sr = sf.read(args.reference_audio)
150
+ assert sr == 16000, "sample rate hardcoded in server"
151
+
152
+ samples = np.array(waveform, dtype=np.float32)
153
+ data = prepare_request(samples, args.reference_text, args.target_text)
154
+
155
+ rsp = requests.post(
156
+ url,
157
+ headers={"Content-Type": "application/json"},
158
+ json=data,
159
+ verify=False,
160
+ params={"request_id": '0'}
161
+ )
162
+ result = rsp.json()
163
+ audio = result["outputs"][0]["data"]
164
+ audio = np.array(audio, dtype=np.float32)
165
+ sf.write(args.output_audio, audio, 16000, "PCM_16")
runtime/triton_trtllm/docker-compose.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ tts:
3
+ image: soar97/triton-spark-tts:25.02
4
+ shm_size: '1gb'
5
+ ports:
6
+ - "8000:8000"
7
+ - "8001:8001"
8
+ - "8002:8002"
9
+ environment:
10
+ - PYTHONIOENCODING=utf-8
11
+ - MODEL_ID=${MODEL_ID}
12
+ deploy:
13
+ resources:
14
+ reservations:
15
+ devices:
16
+ - driver: nvidia
17
+ device_ids: ['0']
18
+ capabilities: [gpu]
19
+ command: >
20
+ /bin/bash -c "rm -rf Spark-TTS && git clone https://github.com/SparkAudio/Spark-TTS.git && cd Spark-TTS/runtime/triton_trtllm && bash run.sh 0 3"
runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+ import json
27
+ import torch
28
+ from torch.utils.dlpack import to_dlpack
29
+
30
+ import triton_python_backend_utils as pb_utils
31
+
32
+ import os
33
+ import numpy as np
34
+
35
+ from sparktts.models.audio_tokenizer import BiCodecTokenizer
36
+
37
+ class TritonPythonModel:
38
+ """Triton Python model for audio tokenization.
39
+
40
+ This model takes reference audio input and extracts semantic and global tokens
41
+ using BiCodec tokenizer.
42
+ """
43
+
44
+ def initialize(self, args):
45
+ """Initialize the model.
46
+
47
+ Args:
48
+ args: Dictionary containing model configuration
49
+ """
50
+ # Parse model parameters
51
+ parameters = json.loads(args['model_config'])['parameters']
52
+ model_params = {k: v["string_value"] for k, v in parameters.items()}
53
+
54
+ # Initialize tokenizer
55
+ self.device = torch.device("cuda")
56
+ self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"],
57
+ device=self.device)
58
+
59
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
60
+ """Extract reference audio clip for speaker embedding.
61
+
62
+ Args:
63
+ wav: Input waveform array
64
+
65
+ Returns:
66
+ Reference clip of fixed duration
67
+ """
68
+ SAMPLE_RATE = 16000
69
+ REF_SEGMENT_DURATION = 6 # seconds
70
+ LATENT_HOP_LENGTH = 320
71
+
72
+ ref_segment_length = (
73
+ int(SAMPLE_RATE * REF_SEGMENT_DURATION)
74
+ // LATENT_HOP_LENGTH
75
+ * LATENT_HOP_LENGTH
76
+ )
77
+ wav_length = len(wav)
78
+
79
+ if ref_segment_length > wav_length:
80
+ # Repeat and truncate if input is too short
81
+ repeat_times = ref_segment_length // wav_length + 1
82
+ wav = np.tile(wav, repeat_times)
83
+
84
+ return wav[:ref_segment_length]
85
+
86
+ def execute(self, requests):
87
+ """Execute inference on the batched requests.
88
+
89
+ Args:
90
+ requests: List of inference requests
91
+
92
+ Returns:
93
+ List of inference responses containing tokenized outputs
94
+ """
95
+ reference_wav_list = []
96
+ reference_wav_ref_clip_list = []
97
+
98
+ # Process each request in batch
99
+ for request in requests:
100
+ # Extract input tensors
101
+ wav_array = pb_utils.get_input_tensor_by_name(
102
+ request, "reference_wav").as_numpy()
103
+ wav_len = pb_utils.get_input_tensor_by_name(
104
+ request, "reference_wav_len").as_numpy().item()
105
+
106
+ # Prepare inputs
107
+ wav = wav_array[:, :wav_len].squeeze(0)
108
+ reference_wav_list.append(wav)
109
+
110
+ wav_ref_clip = self.get_ref_clip(wav)
111
+ reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip))
112
+
113
+ # Batch process through tokenizer
114
+ ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0)
115
+ wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features(
116
+ reference_wav_list)
117
+
118
+ audio_tokenizer_input = {
119
+ "ref_wav": ref_wav_clip_tensor.to(self.device),
120
+ "feat": wav2vec2_features.to(self.device),
121
+ }
122
+ semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize(
123
+ audio_tokenizer_input)
124
+
125
+ # Prepare responses
126
+ responses = []
127
+ for i in range(len(requests)):
128
+ global_tokens_tensor = pb_utils.Tensor.from_dlpack(
129
+ "global_tokens", to_dlpack(global_tokens[i]))
130
+ semantic_tokens_tensor = pb_utils.Tensor.from_dlpack(
131
+ "semantic_tokens", to_dlpack(semantic_tokens[i]))
132
+
133
+ inference_response = pb_utils.InferenceResponse(
134
+ output_tensors=[global_tokens_tensor, semantic_tokens_tensor])
135
+ responses.append(inference_response)
136
+
137
+ return responses
runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "audio_tokenizer"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "model_dir",
24
+ value: {string_value:"${model_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "reference_wav"
31
+ data_type: TYPE_FP32
32
+ dims: [-1]
33
+ },
34
+ {
35
+ name: "reference_wav_len"
36
+ data_type: TYPE_INT32
37
+ dims: [1]
38
+ }
39
+ ]
40
+ output [
41
+ {
42
+ name: "global_tokens"
43
+ data_type: TYPE_INT32
44
+ dims: [-1]
45
+ },
46
+ {
47
+ name: "semantic_tokens"
48
+ data_type: TYPE_INT32
49
+ dims: [-1]
50
+ }
51
+ ]
52
+
53
+ instance_group [
54
+ {
55
+ count: 1
56
+ kind: KIND_CPU
57
+ }
58
+ ]
runtime/triton_trtllm/model_repo/spark_tts/1/model.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import json
28
+ import math
29
+ import os
30
+ import re
31
+ from typing import Dict, List, Tuple, Optional, Union
32
+
33
+ import numpy as np
34
+ import torch
35
+ from torch.utils.dlpack import from_dlpack, to_dlpack
36
+ import triton_python_backend_utils as pb_utils
37
+ from transformers import AutoTokenizer
38
+
39
+ from sparktts.utils.token_parser import TASK_TOKEN_MAP
40
+
41
+ def process_prompt(
42
+ text: str,
43
+ prompt_text: Optional[str] = None,
44
+ global_token_ids: torch.Tensor = None,
45
+ semantic_token_ids: torch.Tensor = None,
46
+ ) -> Tuple[str, torch.Tensor]:
47
+ """
48
+ Process input for voice cloning.
49
+
50
+ Args:
51
+ text: The text input to be converted to speech.
52
+ prompt_text: Transcript of the prompt audio.
53
+ global_token_ids: Global token IDs extracted from reference audio.
54
+ semantic_token_ids: Semantic token IDs extracted from reference audio.
55
+
56
+ Returns:
57
+ Tuple containing the formatted input prompt and global token IDs.
58
+ """
59
+ # Convert global tokens to string format
60
+ global_tokens = "".join(
61
+ [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
62
+ )
63
+
64
+
65
+ # Prepare the input tokens for the model
66
+ if prompt_text is not None:
67
+ # Include semantic tokens when prompt text is provided
68
+ semantic_tokens = "".join(
69
+ [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
70
+ )
71
+
72
+ inputs = [
73
+ TASK_TOKEN_MAP["tts"],
74
+ "<|start_content|>",
75
+ prompt_text,
76
+ text,
77
+ "<|end_content|>",
78
+ "<|start_global_token|>",
79
+ global_tokens,
80
+ "<|end_global_token|>",
81
+ "<|start_semantic_token|>",
82
+ semantic_tokens,
83
+ ]
84
+ else:
85
+ # Without prompt text, exclude semantic tokens
86
+ inputs = [
87
+ TASK_TOKEN_MAP["tts"],
88
+ "<|start_content|>",
89
+ text,
90
+ "<|end_content|>",
91
+ "<|start_global_token|>",
92
+ global_tokens,
93
+ "<|end_global_token|>",
94
+ ]
95
+
96
+ # Join all input components into a single string
97
+ inputs = "".join(inputs)
98
+ return inputs, global_token_ids
99
+
100
+
101
+ class TritonPythonModel:
102
+ """Triton Python model for Spark TTS.
103
+
104
+ This model orchestrates the end-to-end TTS pipeline by coordinating
105
+ between audio tokenizer, LLM, and vocoder components.
106
+ """
107
+
108
+ def initialize(self, args):
109
+ """Initialize the model.
110
+
111
+ Args:
112
+ args: Dictionary containing model configuration
113
+ """
114
+ self.logger = pb_utils.Logger
115
+ # Parse model parameters
116
+ self.model_config = json.loads(args['model_config'])
117
+ parameters = self.model_config['parameters']
118
+ model_params = {k: v["string_value"] for k, v in parameters.items()}
119
+ self.logger.log_info(f"model_params:{model_params}")
120
+ # streaming TTS parameters
121
+ assert (
122
+ float(model_params["audio_chunk_duration"]) >= 0.5
123
+ ), f"audio_chunk_duration at least 0.5 seconds"
124
+ self.audio_chunk_duration = float(model_params["audio_chunk_duration"])
125
+ self.max_audio_chunk_duration = float(model_params["max_audio_chunk_duration"])
126
+ assert (
127
+ float(model_params["audio_chunk_size_scale_factor"]) >= 1.0
128
+ ), "audio_chunk_size_scale_factor should be greater than 1, change it according to your actual rtf"
129
+ self.audio_chunk_size_scale_factor = float(model_params["audio_chunk_size_scale_factor"]) # scale speed
130
+ self.audio_chunk_overlap_duration = float(model_params["audio_chunk_overlap_duration"])
131
+ self.audio_tokenizer_frame_rate = int(model_params["audio_tokenizer_frame_rate"])
132
+
133
+ # Initialize tokenizer
134
+ llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
135
+ self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
136
+ self.device = torch.device("cuda")
137
+ self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
138
+
139
+ def forward_llm(self, input_ids):
140
+ """
141
+ Prepares the response from the language model based on the provided
142
+ inputs. Creates a `pb_utils.InferenceRequest` object with passed
143
+ `llm_request_inputs` to send to a decoupled TensorRTLLM model.
144
+ For each response from the language model:
145
+ - Checks for errors and raise an exception if any are found.
146
+ - Extracts the "output_ids" tensor from the response.
147
+ - Determines the finish reason based on the presence of the
148
+ end-of-sequence token or reaching the maximum length.
149
+ - Appends the generated token IDs to `output_ids`.
150
+ - If the finish reason is determined, decodes the output IDs to text
151
+ and prepares the final response.
152
+
153
+ The final response includes the generated text, finish reason,
154
+ completion tokens, prompt tokens, and total tokens.
155
+
156
+ Parameters
157
+ ----------
158
+ - llm_request_inputs (dict): A dictionary containing the inputs for the language model.
159
+
160
+ Returns
161
+ -------
162
+ - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
163
+ """
164
+ # convert input_ids to numpy, with shape [1, sequence_length]
165
+ input_ids = input_ids.cpu().numpy()
166
+ max_tokens = 512
167
+ input_dict = {
168
+ "request_output_len": np.array([[max_tokens]], dtype=np.int32),
169
+ "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32),
170
+ "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32),
171
+ "streaming": np.array([[self.decoupled]], dtype=np.bool_),
172
+ "runtime_top_p": np.array([[0.95]], dtype=np.float32),
173
+ "runtime_top_k": np.array([[50]], dtype=np.int32),
174
+ "temperature": np.array([[0.8]], dtype=np.float32),
175
+ "input_ids": input_ids,
176
+ "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
177
+ }
178
+
179
+ # Convert inputs to Triton tensors
180
+ input_tensor_list = [
181
+ pb_utils.Tensor(k, v) for k, v in input_dict.items()
182
+ ]
183
+
184
+ # Create and execute inference request
185
+ llm_request = pb_utils.InferenceRequest(
186
+ model_name="tensorrt_llm",
187
+ requested_output_names=["output_ids", "sequence_length"],
188
+ inputs=input_tensor_list,
189
+ )
190
+
191
+ llm_responses = llm_request.exec(decoupled=self.decoupled)
192
+ if self.decoupled:
193
+ for llm_response in llm_responses:
194
+ if llm_response.has_error():
195
+ raise pb_utils.TritonModelException(llm_response.error().message())
196
+
197
+ # Extract and process output
198
+ output_ids = pb_utils.get_output_tensor_by_name(
199
+ llm_response, "output_ids").as_numpy()
200
+ seq_lens = pb_utils.get_output_tensor_by_name(
201
+ llm_response, "sequence_length").as_numpy()
202
+
203
+ # Get actual output IDs up to the sequence length
204
+ actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
205
+
206
+ yield actual_output_ids
207
+ else:
208
+ llm_response = llm_responses
209
+ if llm_response.has_error():
210
+ raise pb_utils.TritonModelException(llm_response.error().message())
211
+
212
+ # Extract and process output
213
+ output_ids = pb_utils.get_output_tensor_by_name(
214
+ llm_response, "output_ids").as_numpy()
215
+ seq_lens = pb_utils.get_output_tensor_by_name(
216
+ llm_response, "sequence_length").as_numpy()
217
+
218
+ # Get actual output IDs up to the sequence length
219
+ actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
220
+
221
+ yield actual_output_ids
222
+
223
+ def forward_audio_tokenizer(self, wav, wav_len):
224
+ """Forward pass through the audio tokenizer component.
225
+
226
+ Args:
227
+ wav: Input waveform tensor
228
+ wav_len: Waveform length tensor
229
+
230
+ Returns:
231
+ Tuple of global and semantic tokens
232
+ """
233
+ inference_request = pb_utils.InferenceRequest(
234
+ model_name='audio_tokenizer',
235
+ requested_output_names=['global_tokens', 'semantic_tokens'],
236
+ inputs=[wav, wav_len]
237
+ )
238
+
239
+ inference_response = inference_request.exec()
240
+ if inference_response.has_error():
241
+ raise pb_utils.TritonModelException(inference_response.error().message())
242
+
243
+ # Extract and convert output tensors
244
+ global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens')
245
+ global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu()
246
+
247
+ semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens')
248
+ semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu()
249
+
250
+ return global_tokens, semantic_tokens
251
+
252
+ def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor:
253
+ """Forward pass through the vocoder component.
254
+
255
+ Args:
256
+ global_token_ids: Global token IDs tensor
257
+ pred_semantic_ids: Predicted semantic token IDs tensor
258
+
259
+ Returns:
260
+ Generated waveform tensor
261
+ """
262
+ # Convert tensors to Triton format
263
+ global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids))
264
+ pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids))
265
+
266
+ # Create and execute inference request
267
+ inference_request = pb_utils.InferenceRequest(
268
+ model_name='vocoder',
269
+ requested_output_names=['waveform'],
270
+ inputs=[global_token_ids_tensor, pred_semantic_ids_tensor]
271
+ )
272
+
273
+ inference_response = inference_request.exec()
274
+ if inference_response.has_error():
275
+ raise pb_utils.TritonModelException(inference_response.error().message())
276
+
277
+ # Extract and convert output waveform
278
+ waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
279
+ waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
280
+
281
+ return waveform
282
+
283
+ def token2wav(self, generated_token_ids, global_token_ids):
284
+ # Decode and extract semantic token IDs from generated text
285
+ predicted_text = self.tokenizer.batch_decode(
286
+ [generated_token_ids],
287
+ skip_special_tokens=True,
288
+ )[0]
289
+ pred_semantic_ids = (
290
+ torch.tensor(
291
+ [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]
292
+ )
293
+ .unsqueeze(0)
294
+ .to(torch.int32)
295
+ )
296
+
297
+ # Generate audio with vocoder
298
+ audio = self.forward_vocoder(
299
+ global_token_ids.to(self.device),
300
+ pred_semantic_ids.to(self.device),
301
+ )
302
+
303
+ return audio
304
+
305
+ def execute(self, requests):
306
+ """Execute inference on the batched requests.
307
+
308
+ Args:
309
+ requests: List of inference requests
310
+
311
+ Returns:
312
+ List of inference responses containing generated audio
313
+ """
314
+ responses = []
315
+
316
+ for request in requests:
317
+ # Extract input tensors
318
+ wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
319
+ wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
320
+
321
+ # Process reference audio through audio tokenizer
322
+ global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len)
323
+
324
+ # Extract text inputs
325
+ reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
326
+ reference_text = reference_text[0][0].decode('utf-8')
327
+
328
+ target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
329
+ target_text = target_text[0][0].decode('utf-8')
330
+
331
+ # Prepare prompt for LLM
332
+ prompt, global_token_ids = process_prompt(
333
+ text=target_text,
334
+ prompt_text=reference_text,
335
+ global_token_ids=global_tokens,
336
+ semantic_token_ids=semantic_tokens,
337
+ )
338
+
339
+
340
+ # Tokenize prompt for LLM
341
+ model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
342
+ input_ids = model_inputs.input_ids.to(torch.int32)
343
+
344
+ # Generate semantic tokens with LLM
345
+ generated_ids_iter = self.forward_llm(input_ids)
346
+
347
+ if self.decoupled:
348
+ response_sender = request.get_response_sender()
349
+ request_id = request.request_id()
350
+ semantic_token_ids_arr = []
351
+ max_chunk_size = math.ceil(self.max_audio_chunk_duration * self.audio_tokenizer_frame_rate)
352
+ chunk_size = math.ceil(self.audio_chunk_duration * self.audio_tokenizer_frame_rate)
353
+ overlap_chunk_size = math.ceil(self.audio_chunk_overlap_duration * self.audio_tokenizer_frame_rate)
354
+ self.logger.log_info(
355
+ f"[{request_id}] init chunk_size: {chunk_size} max_chunk_size: {max_chunk_size}"
356
+ )
357
+ for generated_ids in generated_ids_iter:
358
+ if generated_ids is None or len(generated_ids) == 0:
359
+ break
360
+
361
+ semantic_token_ids_arr.append(generated_ids)
362
+ if len(semantic_token_ids_arr) >= chunk_size:
363
+ chunk = semantic_token_ids_arr[:chunk_size]
364
+ generated_semantic_token_ids = np.hstack(chunk)
365
+ # Process each chunk
366
+ sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
367
+ # Prepare response to send
368
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
369
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
370
+ response_sender.send(inference_response)
371
+
372
+ semantic_token_ids_arr = semantic_token_ids_arr[chunk_size - overlap_chunk_size:]
373
+ # increase chunk size for better speech quality
374
+ chunk_size = min(max_chunk_size, int(chunk_size * self.audio_chunk_size_scale_factor))
375
+ self.logger.log_info(f"[{request_id}] increase chunk_size: {chunk_size}")
376
+
377
+ if len(semantic_token_ids_arr) > 0: # end to finalize
378
+ generated_semantic_token_ids = np.hstack(semantic_token_ids_arr)
379
+ # Process each chunk
380
+ sub_tts_speech = self.token2wav(generated_semantic_token_ids, global_token_ids)
381
+ # Prepare response to send
382
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
383
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
384
+ response_sender.send(inference_response)
385
+ self.logger.log_info(f"[{request_id}] last chunk len: {len(semantic_token_ids_arr)}")
386
+ else:
387
+ generated_ids = next(generated_ids_iter)
388
+ if generated_ids is None or len(generated_ids) == 0:
389
+ raise pb_utils.TritonModelException("Generated IDs is None or empty")
390
+
391
+ audio = self.token2wav(generated_ids, global_token_ids)
392
+
393
+ # Prepare response
394
+ audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
395
+ inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
396
+ responses.append(inference_response)
397
+
398
+ if self.decoupled:
399
+ response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
400
+ self.logger.log_info(f"send tritonserver_response_complete_final to end")
401
+
402
+ if not self.decoupled:
403
+ return responses
404
+
runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "spark_tts"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ model_transaction_policy {
22
+ decoupled: ${decoupled_mode}
23
+ }
24
+ parameters [
25
+ {
26
+ key: "llm_tokenizer_dir",
27
+ value: {string_value:"${llm_tokenizer_dir}"}
28
+ },
29
+ {
30
+ key: "audio_chunk_duration",
31
+ value: {string_value:"${audio_chunk_duration}"}
32
+ },
33
+ {
34
+ key: "audio_chunk_size_scale_factor",
35
+ value: {string_value:"${audio_chunk_size_scale_factor}"}
36
+ },
37
+ {
38
+ key: "max_audio_chunk_duration",
39
+ value: {string_value:"${max_audio_chunk_duration}"}
40
+ },
41
+ {
42
+ key: "audio_chunk_overlap_duration",
43
+ value: {string_value:"${audio_chunk_overlap_duration}"}
44
+ },
45
+ {
46
+ key: "audio_tokenizer_frame_rate",
47
+ value: {string_value:"50"}
48
+ }
49
+ ]
50
+
51
+ input [
52
+ {
53
+ name: "reference_wav"
54
+ data_type: TYPE_FP32
55
+ dims: [-1]
56
+ },
57
+ {
58
+ name: "reference_wav_len"
59
+ data_type: TYPE_INT32
60
+ dims: [1]
61
+ },
62
+ {
63
+ name: "reference_text"
64
+ data_type: TYPE_STRING
65
+ dims: [1]
66
+ },
67
+ {
68
+ name: "target_text"
69
+ data_type: TYPE_STRING
70
+ dims: [1]
71
+ }
72
+ ]
73
+ output [
74
+ {
75
+ name: "waveform"
76
+ data_type: TYPE_FP32
77
+ dims: [ -1 ]
78
+ }
79
+ ]
80
+
81
+ instance_group [
82
+ {
83
+ count: ${bls_instance_num}
84
+ kind: KIND_CPU
85
+ }
86
+ ]
runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep ADDED
File without changes
runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "tensorrt_llm"
28
+ backend: "${triton_backend}"
29
+ max_batch_size: ${triton_max_batch_size}
30
+
31
+ model_transaction_policy {
32
+ decoupled: ${decoupled_mode}
33
+ }
34
+
35
+ dynamic_batching {
36
+ preferred_batch_size: [ ${triton_max_batch_size} ]
37
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
38
+ default_queue_policy: { max_queue_size: ${max_queue_size} }
39
+ }
40
+
41
+ input [
42
+ {
43
+ name: "input_ids"
44
+ data_type: TYPE_INT32
45
+ dims: [ -1 ]
46
+ allow_ragged_batch: true
47
+ optional: true
48
+ },
49
+ {
50
+ name: "encoder_input_features"
51
+ data_type: ${encoder_input_features_data_type}
52
+ dims: [ -1, -1 ]
53
+ allow_ragged_batch: true
54
+ optional: true
55
+ },
56
+ {
57
+ name: "encoder_output_lengths"
58
+ data_type: TYPE_INT32
59
+ dims: [ 1 ]
60
+ reshape: { shape: [ ] }
61
+ optional: true
62
+ },
63
+ {
64
+ name: "input_lengths"
65
+ data_type: TYPE_INT32
66
+ dims: [ 1 ]
67
+ reshape: { shape: [ ] }
68
+ },
69
+ {
70
+ name: "request_output_len"
71
+ data_type: TYPE_INT32
72
+ dims: [ 1 ]
73
+ reshape: { shape: [ ] }
74
+ },
75
+ {
76
+ name: "num_return_sequences"
77
+ data_type: TYPE_INT32
78
+ dims: [ 1 ]
79
+ reshape: { shape: [ ] }
80
+ optional: true
81
+ },
82
+ {
83
+ name: "draft_input_ids"
84
+ data_type: TYPE_INT32
85
+ dims: [ -1 ]
86
+ optional: true
87
+ allow_ragged_batch: true
88
+ },
89
+ {
90
+ name: "decoder_input_ids"
91
+ data_type: TYPE_INT32
92
+ dims: [ -1 ]
93
+ optional: true
94
+ allow_ragged_batch: true
95
+ },
96
+ {
97
+ name: "decoder_input_lengths"
98
+ data_type: TYPE_INT32
99
+ dims: [ 1 ]
100
+ optional: true
101
+ reshape: { shape: [ ] }
102
+ },
103
+ {
104
+ name: "draft_logits"
105
+ data_type: ${logits_datatype}
106
+ dims: [ -1, -1 ]
107
+ optional: true
108
+ allow_ragged_batch: true
109
+ },
110
+ {
111
+ name: "draft_acceptance_threshold"
112
+ data_type: TYPE_FP32
113
+ dims: [ 1 ]
114
+ reshape: { shape: [ ] }
115
+ optional: true
116
+ },
117
+ {
118
+ name: "end_id"
119
+ data_type: TYPE_INT32
120
+ dims: [ 1 ]
121
+ reshape: { shape: [ ] }
122
+ optional: true
123
+ },
124
+ {
125
+ name: "pad_id"
126
+ data_type: TYPE_INT32
127
+ dims: [ 1 ]
128
+ reshape: { shape: [ ] }
129
+ optional: true
130
+ },
131
+ {
132
+ name: "stop_words_list"
133
+ data_type: TYPE_INT32
134
+ dims: [ 2, -1 ]
135
+ optional: true
136
+ allow_ragged_batch: true
137
+ },
138
+ {
139
+ name: "bad_words_list"
140
+ data_type: TYPE_INT32
141
+ dims: [ 2, -1 ]
142
+ optional: true
143
+ allow_ragged_batch: true
144
+ },
145
+ {
146
+ name: "embedding_bias"
147
+ data_type: TYPE_FP32
148
+ dims: [ -1 ]
149
+ optional: true
150
+ allow_ragged_batch: true
151
+ },
152
+ {
153
+ name: "beam_width"
154
+ data_type: TYPE_INT32
155
+ dims: [ 1 ]
156
+ reshape: { shape: [ ] }
157
+ optional: true
158
+ },
159
+ {
160
+ name: "temperature"
161
+ data_type: TYPE_FP32
162
+ dims: [ 1 ]
163
+ reshape: { shape: [ ] }
164
+ optional: true
165
+ },
166
+ {
167
+ name: "runtime_top_k"
168
+ data_type: TYPE_INT32
169
+ dims: [ 1 ]
170
+ reshape: { shape: [ ] }
171
+ optional: true
172
+ },
173
+ {
174
+ name: "runtime_top_p"
175
+ data_type: TYPE_FP32
176
+ dims: [ 1 ]
177
+ reshape: { shape: [ ] }
178
+ optional: true
179
+ },
180
+ {
181
+ name: "runtime_top_p_min"
182
+ data_type: TYPE_FP32
183
+ dims: [ 1 ]
184
+ reshape: { shape: [ ] }
185
+ optional: true
186
+ },
187
+ {
188
+ name: "runtime_top_p_decay"
189
+ data_type: TYPE_FP32
190
+ dims: [ 1 ]
191
+ reshape: { shape: [ ] }
192
+ optional: true
193
+ },
194
+ {
195
+ name: "runtime_top_p_reset_ids"
196
+ data_type: TYPE_INT32
197
+ dims: [ 1 ]
198
+ reshape: { shape: [ ] }
199
+ optional: true
200
+ },
201
+ {
202
+ name: "len_penalty"
203
+ data_type: TYPE_FP32
204
+ dims: [ 1 ]
205
+ reshape: { shape: [ ] }
206
+ optional: true
207
+ },
208
+ {
209
+ name: "early_stopping"
210
+ data_type: TYPE_BOOL
211
+ dims: [ 1 ]
212
+ reshape: { shape: [ ] }
213
+ optional: true
214
+ },
215
+ {
216
+ name: "repetition_penalty"
217
+ data_type: TYPE_FP32
218
+ dims: [ 1 ]
219
+ reshape: { shape: [ ] }
220
+ optional: true
221
+ },
222
+ {
223
+ name: "min_length"
224
+ data_type: TYPE_INT32
225
+ dims: [ 1 ]
226
+ reshape: { shape: [ ] }
227
+ optional: true
228
+ },
229
+ {
230
+ name: "beam_search_diversity_rate"
231
+ data_type: TYPE_FP32
232
+ dims: [ 1 ]
233
+ reshape: { shape: [ ] }
234
+ optional: true
235
+ },
236
+ {
237
+ name: "presence_penalty"
238
+ data_type: TYPE_FP32
239
+ dims: [ 1 ]
240
+ reshape: { shape: [ ] }
241
+ optional: true
242
+ },
243
+ {
244
+ name: "frequency_penalty"
245
+ data_type: TYPE_FP32
246
+ dims: [ 1 ]
247
+ reshape: { shape: [ ] }
248
+ optional: true
249
+ },
250
+ {
251
+ name: "random_seed"
252
+ data_type: TYPE_UINT64
253
+ dims: [ 1 ]
254
+ reshape: { shape: [ ] }
255
+ optional: true
256
+ },
257
+ {
258
+ name: "return_log_probs"
259
+ data_type: TYPE_BOOL
260
+ dims: [ 1 ]
261
+ reshape: { shape: [ ] }
262
+ optional: true
263
+ },
264
+ {
265
+ name: "return_context_logits"
266
+ data_type: TYPE_BOOL
267
+ dims: [ 1 ]
268
+ reshape: { shape: [ ] }
269
+ optional: true
270
+ },
271
+ {
272
+ name: "return_generation_logits"
273
+ data_type: TYPE_BOOL
274
+ dims: [ 1 ]
275
+ reshape: { shape: [ ] }
276
+ optional: true
277
+ },
278
+ {
279
+ name: "return_perf_metrics"
280
+ data_type: TYPE_BOOL
281
+ dims: [ 1 ]
282
+ reshape: { shape: [ ] }
283
+ optional: true
284
+ },
285
+ {
286
+ name: "exclude_input_in_output"
287
+ data_type: TYPE_BOOL
288
+ dims: [ 1 ]
289
+ reshape: { shape: [ ] }
290
+ optional: true
291
+ },
292
+ {
293
+ name: "stop"
294
+ data_type: TYPE_BOOL
295
+ dims: [ 1 ]
296
+ reshape: { shape: [ ] }
297
+ optional: true
298
+ },
299
+ {
300
+ name: "streaming"
301
+ data_type: TYPE_BOOL
302
+ dims: [ 1 ]
303
+ reshape: { shape: [ ] }
304
+ optional: true
305
+ },
306
+ {
307
+ name: "prompt_embedding_table"
308
+ data_type: TYPE_FP16
309
+ dims: [ -1, -1 ]
310
+ optional: true
311
+ allow_ragged_batch: true
312
+ },
313
+ {
314
+ name: "prompt_table_extra_ids"
315
+ data_type: TYPE_UINT64
316
+ dims: [ -1 ]
317
+ optional: true
318
+ allow_ragged_batch: true
319
+ },
320
+ {
321
+ name: "prompt_vocab_size"
322
+ data_type: TYPE_INT32
323
+ dims: [ 1 ]
324
+ reshape: { shape: [ ] }
325
+ optional: true
326
+ },
327
+ # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
328
+ {
329
+ name: "cross_attention_mask"
330
+ data_type: TYPE_BOOL
331
+ dims: [ -1, -1 ]
332
+ optional: true
333
+ allow_ragged_batch: true
334
+ },
335
+ # Mrope param when mrope is used
336
+ {
337
+ name: "mrope_rotary_cos_sin"
338
+ data_type: TYPE_FP32
339
+ dims: [ -1 ]
340
+ optional: true
341
+ },
342
+ {
343
+ name: "mrope_position_deltas"
344
+ data_type: TYPE_INT64
345
+ dims: [ 1 ]
346
+ optional: true
347
+ },
348
+ # the unique task ID for the given LoRA.
349
+ # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
350
+ # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
351
+ # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
352
+ {
353
+ name: "lora_task_id"
354
+ data_type: TYPE_UINT64
355
+ dims: [ 1 ]
356
+ reshape: { shape: [ ] }
357
+ optional: true
358
+ },
359
+ # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
360
+ # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
361
+ # each of the in / out tensors are first flattened and then concatenated together in the format above.
362
+ # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
363
+ {
364
+ name: "lora_weights"
365
+ data_type: TYPE_FP16
366
+ dims: [ -1, -1 ]
367
+ optional: true
368
+ allow_ragged_batch: true
369
+ },
370
+ # module identifier (same size a first dimension of lora_weights)
371
+ # See LoraModule::ModuleType for model id mapping
372
+ #
373
+ # "attn_qkv": 0 # compbined qkv adapter
374
+ # "attn_q": 1 # q adapter
375
+ # "attn_k": 2 # k adapter
376
+ # "attn_v": 3 # v adapter
377
+ # "attn_dense": 4 # adapter for the dense layer in attention
378
+ # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
379
+ # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
380
+ # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
381
+ #
382
+ # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
383
+ {
384
+ name: "lora_config"
385
+ data_type: TYPE_INT32
386
+ dims: [ -1, 3 ]
387
+ optional: true
388
+ allow_ragged_batch: true
389
+ },
390
+ {
391
+ name: "context_phase_params"
392
+ data_type: TYPE_UINT8
393
+ dims: [ -1 ]
394
+ optional: true
395
+ allow_ragged_batch: true
396
+ },
397
+ # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
398
+ {
399
+ name: "skip_cross_attn_blocks"
400
+ data_type: TYPE_BOOL
401
+ dims: [ 1 ]
402
+ optional: true
403
+ allow_ragged_batch: true
404
+ },
405
+ {
406
+ name: "retention_token_range_starts"
407
+ data_type: TYPE_INT32
408
+ dims: [ -1 ]
409
+ optional: true
410
+ allow_ragged_batch: true
411
+ },
412
+ {
413
+ name: "retention_token_range_ends"
414
+ data_type: TYPE_INT32
415
+ dims: [ -1 ]
416
+ optional: true
417
+ allow_ragged_batch: true
418
+ },
419
+ {
420
+ name: "retention_token_range_priorities"
421
+ data_type: TYPE_INT32
422
+ dims: [ -1 ]
423
+ optional: true
424
+ allow_ragged_batch: true
425
+ },
426
+ {
427
+ name: "retention_token_range_durations_ms"
428
+ data_type: TYPE_INT32
429
+ dims: [ -1 ]
430
+ optional: true
431
+ allow_ragged_batch: true
432
+ },
433
+ {
434
+ name: "retention_decode_priority"
435
+ data_type: TYPE_INT32
436
+ dims: [ 1 ]
437
+ optional: true
438
+ allow_ragged_batch: true
439
+ },
440
+ {
441
+ name: "retention_decode_duration_ms"
442
+ data_type: TYPE_INT32
443
+ dims: [ 1 ]
444
+ optional: true
445
+ allow_ragged_batch: true
446
+ },
447
+ {
448
+ name: "guided_decoding_guide_type"
449
+ data_type: TYPE_STRING
450
+ dims: [ 1 ]
451
+ optional: true
452
+ allow_ragged_batch: true
453
+ },
454
+ {
455
+ name: "guided_decoding_guide"
456
+ data_type: TYPE_STRING
457
+ dims: [ 1 ]
458
+ optional: true
459
+ allow_ragged_batch: true
460
+ },
461
+ {
462
+ name: "lookahead_window_size"
463
+ data_type: TYPE_INT32
464
+ dims: [ 1 ]
465
+ optional: true
466
+ allow_ragged_batch: true
467
+ },
468
+ {
469
+ name: "lookahead_ngram_size"
470
+ data_type: TYPE_INT32
471
+ dims: [ 1 ]
472
+ optional: true
473
+ allow_ragged_batch: true
474
+ },
475
+ {
476
+ name: "lookahead_verification_set_size"
477
+ data_type: TYPE_INT32
478
+ dims: [ 1 ]
479
+ optional: true
480
+ allow_ragged_batch: true
481
+ }
482
+ ]
483
+ output [
484
+ {
485
+ name: "output_ids"
486
+ data_type: TYPE_INT32
487
+ dims: [ -1, -1 ]
488
+ },
489
+ {
490
+ name: "sequence_length"
491
+ data_type: TYPE_INT32
492
+ dims: [ -1 ]
493
+ },
494
+ {
495
+ name: "cum_log_probs"
496
+ data_type: TYPE_FP32
497
+ dims: [ -1 ]
498
+ },
499
+ {
500
+ name: "output_log_probs"
501
+ data_type: TYPE_FP32
502
+ dims: [ -1, -1 ]
503
+ },
504
+ {
505
+ name: "context_logits"
506
+ data_type: ${logits_datatype}
507
+ dims: [ -1, -1 ]
508
+ },
509
+ {
510
+ name: "generation_logits"
511
+ data_type: ${logits_datatype}
512
+ dims: [ -1, -1, -1 ]
513
+ },
514
+ {
515
+ name: "batch_index"
516
+ data_type: TYPE_INT32
517
+ dims: [ 1 ]
518
+ },
519
+ {
520
+ name: "sequence_index"
521
+ data_type: TYPE_INT32
522
+ dims: [ 1 ]
523
+ },
524
+ {
525
+ name: "context_phase_params"
526
+ data_type: TYPE_UINT8
527
+ dims: [ -1 ]
528
+ },
529
+ {
530
+ name: "kv_cache_alloc_new_blocks"
531
+ data_type: TYPE_INT32
532
+ dims: [ 1 ]
533
+ },
534
+ {
535
+ name: "kv_cache_reused_blocks"
536
+ data_type: TYPE_INT32
537
+ dims: [ 1 ]
538
+ },
539
+ {
540
+ name: "kv_cache_alloc_total_blocks"
541
+ data_type: TYPE_INT32
542
+ dims: [ 1 ]
543
+ },
544
+ {
545
+ name: "arrival_time_ns"
546
+ data_type: TYPE_INT64
547
+ dims: [ 1 ]
548
+ },
549
+ {
550
+ name: "first_scheduled_time_ns"
551
+ data_type: TYPE_INT64
552
+ dims: [ 1 ]
553
+ },
554
+ {
555
+ name: "first_token_time_ns"
556
+ data_type: TYPE_INT64
557
+ dims: [ 1 ]
558
+ },
559
+ {
560
+ name: "last_token_time_ns"
561
+ data_type: TYPE_INT64
562
+ dims: [ 1 ]
563
+ },
564
+ {
565
+ name: "acceptance_rate"
566
+ data_type: TYPE_FP32
567
+ dims: [ 1 ]
568
+ },
569
+ {
570
+ name: "total_accepted_draft_tokens"
571
+ data_type: TYPE_INT32
572
+ dims: [ 1 ]
573
+ },
574
+ {
575
+ name: "total_draft_tokens"
576
+ data_type: TYPE_INT32
577
+ dims: [ 1 ]
578
+ }
579
+ ]
580
+ instance_group [
581
+ {
582
+ count: 1
583
+ kind : KIND_CPU
584
+ }
585
+ ]
586
+ parameters: {
587
+ key: "max_beam_width"
588
+ value: {
589
+ string_value: "${max_beam_width}"
590
+ }
591
+ }
592
+ parameters: {
593
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
594
+ value: {
595
+ string_value: "no"
596
+ }
597
+ }
598
+ parameters: {
599
+ key: "gpt_model_type"
600
+ value: {
601
+ string_value: "${batching_strategy}"
602
+ }
603
+ }
604
+ parameters: {
605
+ key: "gpt_model_path"
606
+ value: {
607
+ string_value: "${engine_dir}"
608
+ }
609
+ }
610
+ parameters: {
611
+ key: "encoder_model_path"
612
+ value: {
613
+ string_value: "${encoder_engine_dir}"
614
+ }
615
+ }
616
+ parameters: {
617
+ key: "max_tokens_in_paged_kv_cache"
618
+ value: {
619
+ string_value: "${max_tokens_in_paged_kv_cache}"
620
+ }
621
+ }
622
+ parameters: {
623
+ key: "max_attention_window_size"
624
+ value: {
625
+ string_value: "${max_attention_window_size}"
626
+ }
627
+ }
628
+ parameters: {
629
+ key: "sink_token_length"
630
+ value: {
631
+ string_value: "${sink_token_length}"
632
+ }
633
+ }
634
+ parameters: {
635
+ key: "batch_scheduler_policy"
636
+ value: {
637
+ string_value: "${batch_scheduler_policy}"
638
+ }
639
+ }
640
+ parameters: {
641
+ key: "kv_cache_free_gpu_mem_fraction"
642
+ value: {
643
+ string_value: "${kv_cache_free_gpu_mem_fraction}"
644
+ }
645
+ }
646
+ parameters: {
647
+ key: "cross_kv_cache_fraction"
648
+ value: {
649
+ string_value: "${cross_kv_cache_fraction}"
650
+ }
651
+ }
652
+ parameters: {
653
+ key: "kv_cache_host_memory_bytes"
654
+ value: {
655
+ string_value: "${kv_cache_host_memory_bytes}"
656
+ }
657
+ }
658
+ # kv_cache_onboard_blocks is for internal implementation.
659
+ parameters: {
660
+ key: "kv_cache_onboard_blocks"
661
+ value: {
662
+ string_value: "${kv_cache_onboard_blocks}"
663
+ }
664
+ }
665
+ # enable_trt_overlap is deprecated and doesn't have any effect on the runtime
666
+ # parameters: {
667
+ # key: "enable_trt_overlap"
668
+ # value: {
669
+ # string_value: "${enable_trt_overlap}"
670
+ # }
671
+ # }
672
+ parameters: {
673
+ key: "exclude_input_in_output"
674
+ value: {
675
+ string_value: "${exclude_input_in_output}"
676
+ }
677
+ }
678
+ parameters: {
679
+ key: "cancellation_check_period_ms"
680
+ value: {
681
+ string_value: "${cancellation_check_period_ms}"
682
+ }
683
+ }
684
+ parameters: {
685
+ key: "stats_check_period_ms"
686
+ value: {
687
+ string_value: "${stats_check_period_ms}"
688
+ }
689
+ }
690
+ parameters: {
691
+ key: "iter_stats_max_iterations"
692
+ value: {
693
+ string_value: "${iter_stats_max_iterations}"
694
+ }
695
+ }
696
+ parameters: {
697
+ key: "request_stats_max_iterations"
698
+ value: {
699
+ string_value: "${request_stats_max_iterations}"
700
+ }
701
+ }
702
+ parameters: {
703
+ key: "enable_kv_cache_reuse"
704
+ value: {
705
+ string_value: "${enable_kv_cache_reuse}"
706
+ }
707
+ }
708
+ parameters: {
709
+ key: "normalize_log_probs"
710
+ value: {
711
+ string_value: "${normalize_log_probs}"
712
+ }
713
+ }
714
+ parameters: {
715
+ key: "enable_chunked_context"
716
+ value: {
717
+ string_value: "${enable_chunked_context}"
718
+ }
719
+ }
720
+ parameters: {
721
+ key: "gpu_device_ids"
722
+ value: {
723
+ string_value: "${gpu_device_ids}"
724
+ }
725
+ }
726
+ parameters: {
727
+ key: "participant_ids"
728
+ value: {
729
+ string_value: "${participant_ids}"
730
+ }
731
+ }
732
+ parameters: {
733
+ key: "lora_cache_optimal_adapter_size"
734
+ value: {
735
+ string_value: "${lora_cache_optimal_adapter_size}"
736
+ }
737
+ }
738
+ parameters: {
739
+ key: "lora_cache_max_adapter_size"
740
+ value: {
741
+ string_value: "${lora_cache_max_adapter_size}"
742
+ }
743
+ }
744
+ parameters: {
745
+ key: "lora_cache_gpu_memory_fraction"
746
+ value: {
747
+ string_value: "${lora_cache_gpu_memory_fraction}"
748
+ }
749
+ }
750
+ parameters: {
751
+ key: "lora_cache_host_memory_bytes"
752
+ value: {
753
+ string_value: "${lora_cache_host_memory_bytes}"
754
+ }
755
+ }
756
+ parameters: {
757
+ key: "lora_prefetch_dir"
758
+ value: {
759
+ string_value: "${lora_prefetch_dir}"
760
+ }
761
+ }
762
+ parameters: {
763
+ key: "decoding_mode"
764
+ value: {
765
+ string_value: "${decoding_mode}"
766
+ }
767
+ }
768
+ parameters: {
769
+ key: "executor_worker_path"
770
+ value: {
771
+ string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
772
+ }
773
+ }
774
+ parameters: {
775
+ key: "lookahead_window_size"
776
+ value: {
777
+ string_value: "${lookahead_window_size}"
778
+ }
779
+ }
780
+ parameters: {
781
+ key: "lookahead_ngram_size"
782
+ value: {
783
+ string_value: "${lookahead_ngram_size}"
784
+ }
785
+ }
786
+ parameters: {
787
+ key: "lookahead_verification_set_size"
788
+ value: {
789
+ string_value: "${lookahead_verification_set_size}"
790
+ }
791
+ }
792
+ parameters: {
793
+ key: "medusa_choices"
794
+ value: {
795
+ string_value: "${medusa_choices}"
796
+ }
797
+ }
798
+ parameters: {
799
+ key: "eagle_choices"
800
+ value: {
801
+ string_value: "${eagle_choices}"
802
+ }
803
+ }
804
+ parameters: {
805
+ key: "gpu_weights_percent"
806
+ value: {
807
+ string_value: "${gpu_weights_percent}"
808
+ }
809
+ }
810
+ parameters: {
811
+ key: "enable_context_fmha_fp32_acc"
812
+ value: {
813
+ string_value: "${enable_context_fmha_fp32_acc}"
814
+ }
815
+ }
816
+ parameters: {
817
+ key: "multi_block_mode"
818
+ value: {
819
+ string_value: "${multi_block_mode}"
820
+ }
821
+ }
822
+ parameters: {
823
+ key: "cuda_graph_mode"
824
+ value: {
825
+ string_value: "${cuda_graph_mode}"
826
+ }
827
+ }
828
+ parameters: {
829
+ key: "cuda_graph_cache_size"
830
+ value: {
831
+ string_value: "${cuda_graph_cache_size}"
832
+ }
833
+ }
834
+ parameters: {
835
+ key: "speculative_decoding_fast_logits"
836
+ value: {
837
+ string_value: "${speculative_decoding_fast_logits}"
838
+ }
839
+ }
840
+ parameters: {
841
+ key: "tokenizer_dir"
842
+ value: {
843
+ string_value: "${tokenizer_dir}"
844
+ }
845
+ }
846
+ parameters: {
847
+ key: "guided_decoding_backend"
848
+ value: {
849
+ string_value: "${guided_decoding_backend}"
850
+ }
851
+ }
852
+ parameters: {
853
+ key: "xgrammar_tokenizer_info_path"
854
+ value: {
855
+ string_value: "${xgrammar_tokenizer_info_path}"
856
+ }
857
+ }
runtime/triton_trtllm/model_repo/vocoder/1/model.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import json
28
+ import os
29
+ import logging
30
+ from typing import List, Dict
31
+
32
+ import torch
33
+ from torch.utils.dlpack import to_dlpack
34
+
35
+ import triton_python_backend_utils as pb_utils
36
+
37
+ from sparktts.models.bicodec import BiCodec
38
+
39
+ # Configure logging
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
41
+ logger = logging.getLogger(__name__)
42
+
43
+ class TritonPythonModel:
44
+ """Triton Python model for vocoder.
45
+
46
+ This model takes global and semantic tokens as input and generates audio waveforms
47
+ using the BiCodec vocoder.
48
+ """
49
+
50
+ def initialize(self, args):
51
+ """Initialize the model.
52
+
53
+ Args:
54
+ args: Dictionary containing model configuration
55
+ """
56
+ # Parse model parameters
57
+ parameters = json.loads(args['model_config'])['parameters']
58
+ model_params = {key: value["string_value"] for key, value in parameters.items()}
59
+ model_dir = model_params["model_dir"]
60
+
61
+ # Initialize device and vocoder
62
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+ logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
64
+
65
+ self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
66
+ del self.vocoder.encoder, self.vocoder.postnet
67
+ self.vocoder.eval().to(self.device) # Set model to evaluation mode
68
+
69
+ logger.info("Vocoder initialized successfully")
70
+
71
+
72
+ def execute(self, requests):
73
+ """Execute inference on the batched requests.
74
+
75
+ Args:
76
+ requests: List of inference requests
77
+
78
+ Returns:
79
+ List of inference responses containing generated waveforms
80
+ """
81
+ global_tokens_list, semantic_tokens_list = [], []
82
+
83
+ # Process each request in batch
84
+ for request in requests:
85
+ global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
86
+ semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
87
+ global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
88
+ semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))
89
+
90
+ # Concatenate tokens for batch processing
91
+ global_tokens = torch.cat(global_tokens_list, dim=0)
92
+ semantic_tokens = torch.cat(semantic_tokens_list, dim=0)
93
+
94
+
95
+ # Generate waveforms
96
+ with torch.no_grad():
97
+ wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))
98
+
99
+ # Prepare responses
100
+ responses = []
101
+ for i in range(len(requests)):
102
+ wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
103
+ inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
104
+ responses.append(inference_response)
105
+
106
+ return responses
runtime/triton_trtllm/model_repo/vocoder/config.pbtxt ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "vocoder"
16
+ backend: "python"
17
+ max_batch_size: ${triton_max_batch_size}
18
+ dynamic_batching {
19
+ max_queue_delay_microseconds: ${max_queue_delay_microseconds}
20
+ }
21
+ parameters [
22
+ {
23
+ key: "model_dir",
24
+ value: {string_value:"${model_dir}"}
25
+ }
26
+ ]
27
+
28
+ input [
29
+ {
30
+ name: "global_tokens"
31
+ data_type: TYPE_INT32
32
+ dims: [-1]
33
+ },
34
+ {
35
+ name: "semantic_tokens"
36
+ data_type: TYPE_INT32
37
+ dims: [-1]
38
+ }
39
+ ]
40
+ output [
41
+ {
42
+ name: "waveform"
43
+ data_type: TYPE_FP32
44
+ dims: [ -1 ]
45
+ }
46
+ ]
47
+
48
+ instance_group [
49
+ {
50
+ count: 1
51
+ kind: KIND_CPU
52
+ }
53
+ ]
runtime/triton_trtllm/run.sh ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export PYTHONPATH=../../../Spark-TTS/
2
+ export CUDA_VISIBLE_DEVICES=0
3
+ stage=$1
4
+ stop_stage=$2
5
+ service_type=$3
6
+ echo "Start stage: $stage, Stop stage: $stop_stage service_type: $service_type"
7
+
8
+ huggingface_model_local_dir=../../pretrained_models/Spark-TTS-0.5B
9
+ trt_dtype=bfloat16
10
+ trt_weights_dir=./tllm_checkpoint_${trt_dtype}
11
+ trt_engines_dir=./trt_engines_${trt_dtype}
12
+
13
+ model_repo=./model_repo_test
14
+
15
+ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
16
+ echo "Downloading Spark-TTS-0.5B from HuggingFace"
17
+ huggingface-cli download SparkAudio/Spark-TTS-0.5B --local-dir $huggingface_model_local_dir || exit 1
18
+ fi
19
+
20
+
21
+ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
22
+ echo "Converting checkpoint to TensorRT weights"
23
+ python scripts/convert_checkpoint.py --model_dir $huggingface_model_local_dir/LLM \
24
+ --output_dir $trt_weights_dir \
25
+ --dtype $trt_dtype || exit 1
26
+
27
+ echo "Building TensorRT engines"
28
+ trtllm-build --checkpoint_dir $trt_weights_dir \
29
+ --output_dir $trt_engines_dir \
30
+ --max_batch_size 16 \
31
+ --max_num_tokens 32768 \
32
+ --gemm_plugin $trt_dtype || exit 1
33
+ fi
34
+
35
+ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
36
+ echo "Creating model repository"
37
+ rm -rf $model_repo
38
+ mkdir -p $model_repo
39
+ spark_tts_dir="spark_tts"
40
+
41
+ cp -r ./model_repo/${spark_tts_dir} $model_repo
42
+ cp -r ./model_repo/audio_tokenizer $model_repo
43
+ cp -r ./model_repo/tensorrt_llm $model_repo
44
+ cp -r ./model_repo/vocoder $model_repo
45
+
46
+ ENGINE_PATH=$trt_engines_dir
47
+ MAX_QUEUE_DELAY_MICROSECONDS=0
48
+ MODEL_DIR=$huggingface_model_local_dir
49
+ LLM_TOKENIZER_DIR=$huggingface_model_local_dir/LLM
50
+ BLS_INSTANCE_NUM=4
51
+ TRITON_MAX_BATCH_SIZE=16
52
+ # streaming TTS parameters
53
+ AUDIO_CHUNK_DURATION=1.0
54
+ MAX_AUDIO_CHUNK_DURATION=30.0
55
+ AUDIO_CHUNK_SIZE_SCALE_FACTOR=8.0
56
+ AUDIO_CHUNK_OVERLAP_DURATION=0.1
57
+ python3 scripts/fill_template.py -i ${model_repo}/vocoder/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
58
+ python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS}
59
+ if [ "$service_type" == "streaming" ]; then
60
+ DECOUPLED_MODE=True
61
+ else
62
+ DECOUPLED_MODE=False
63
+ fi
64
+ python3 scripts/fill_template.py -i ${model_repo}/${spark_tts_dir}/config.pbtxt bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},audio_chunk_duration:${AUDIO_CHUNK_DURATION},max_audio_chunk_duration:${MAX_AUDIO_CHUNK_DURATION},audio_chunk_size_scale_factor:${AUDIO_CHUNK_SIZE_SCALE_FACTOR},audio_chunk_overlap_duration:${AUDIO_CHUNK_OVERLAP_DURATION}
65
+ python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
66
+
67
+ fi
68
+
69
+ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
70
+ echo "Starting Triton server"
71
+ tritonserver --model-repository ${model_repo}
72
+ fi
73
+
74
+
75
+ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
76
+ echo "Running benchmark client"
77
+ num_task=2
78
+ if [ "$service_type" == "streaming" ]; then
79
+ mode="streaming"
80
+ else
81
+ mode="offline"
82
+ fi
83
+ python3 client_grpc.py \
84
+ --server-addr localhost \
85
+ --model-name spark_tts \
86
+ --num-tasks $num_task \
87
+ --mode $mode \
88
+ --log-dir ./log_concurrent_tasks_${num_task}_${mode}_new
89
+ fi
90
+
91
+ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
92
+ echo "Running single utterance client"
93
+ if [ "$service_type" == "streaming" ]; then
94
+ python client_grpc.py \
95
+ --server-addr localhost \
96
+ --reference-audio ../../example/prompt_audio.wav \
97
+ --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
98
+ --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
99
+ --model-name spark_tts \
100
+ --chunk-overlap-duration 0.1 \
101
+ --mode streaming
102
+ else
103
+ python client_http.py \
104
+ --reference-audio ../../example/prompt_audio.wav \
105
+ --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
106
+ --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \
107
+ --model-name spark_tts
108
+ fi
109
+ fi
runtime/triton_trtllm/scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+ import traceback
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+
7
+ from transformers import AutoConfig
8
+
9
+ import tensorrt_llm
10
+ from tensorrt_llm._utils import release_gc
11
+ from tensorrt_llm.logger import logger
12
+ from tensorrt_llm.mapping import Mapping
13
+ from tensorrt_llm.models import QWenForCausalLM
14
+ from tensorrt_llm.models.modeling_utils import QuantConfig
15
+ from tensorrt_llm.quantization import QuantAlgo
16
+
17
+
18
+ def parse_arguments():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('--model_dir', type=str, default=None, required=True)
21
+ parser.add_argument('--tp_size',
22
+ type=int,
23
+ default=1,
24
+ help='N-way tensor parallelism size')
25
+ parser.add_argument('--pp_size',
26
+ type=int,
27
+ default=1,
28
+ help='N-way pipeline parallelism size')
29
+ parser.add_argument(
30
+ '--dtype',
31
+ type=str,
32
+ default='auto',
33
+ choices=['auto', 'float16', 'bfloat16', 'float32'],
34
+ help=
35
+ "The data type for the model weights and activations if not quantized. "
36
+ "If 'auto', the data type is automatically inferred from the source model; "
37
+ "however, if the source dtype is float32, it is converted to float16.")
38
+ parser.add_argument(
39
+ '--use_weight_only',
40
+ default=False,
41
+ action="store_true",
42
+ help='Quantize weights for the various GEMMs to INT4/INT8.'
43
+ 'See --weight_only_precision to set the precision')
44
+ parser.add_argument(
45
+ '--disable_weight_only_quant_plugin',
46
+ default=False,
47
+ action="store_true",
48
+ help=
49
+ 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.'
50
+ 'You must also use --use_weight_only for that argument to have an impact.'
51
+ )
52
+ parser.add_argument(
53
+ '--weight_only_precision',
54
+ const='int8',
55
+ type=str,
56
+ nargs='?',
57
+ default='int8',
58
+ choices=['int8', 'int4', 'int4_gptq'],
59
+ help=
60
+ 'Define the precision for the weights when using weight-only quantization.'
61
+ 'You must also use --use_weight_only for that argument to have an impact.'
62
+ )
63
+ parser.add_argument(
64
+ '--calib_dataset',
65
+ type=str,
66
+ default='ccdv/cnn_dailymail',
67
+ help=
68
+ "The huggingface dataset name or the local directory of the dataset for calibration."
69
+ )
70
+ parser.add_argument(
71
+ "--smoothquant",
72
+ "-sq",
73
+ type=float,
74
+ default=None,
75
+ help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
76
+ " to Smoothquant the model, and output int8 weights."
77
+ " A good first try is 0.5. Must be in [0, 1]")
78
+ parser.add_argument(
79
+ '--per_channel',
80
+ action="store_true",
81
+ default=False,
82
+ help=
83
+ 'By default, we use a single static scaling factor for the GEMM\'s result. '
84
+ 'per_channel instead uses a different static scaling factor for each channel. '
85
+ 'The latter is usually more accurate, but a little slower.')
86
+ parser.add_argument(
87
+ '--per_token',
88
+ action="store_true",
89
+ default=False,
90
+ help=
91
+ 'By default, we use a single static scaling factor to scale activations in the int8 range. '
92
+ 'per_token chooses at run time, and for each token, a custom scaling factor. '
93
+ 'The latter is usually more accurate, but a little slower.')
94
+ parser.add_argument(
95
+ '--int8_kv_cache',
96
+ default=False,
97
+ action="store_true",
98
+ help=
99
+ 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
100
+ )
101
+ parser.add_argument(
102
+ '--per_group',
103
+ default=False,
104
+ action="store_true",
105
+ help=
106
+ 'By default, we use a single static scaling factor to scale weights in the int4 range. '
107
+ 'per_group chooses at run time, and for each group, a custom scaling factor. '
108
+ 'The flag is built for GPTQ/AWQ quantization.')
109
+
110
+ parser.add_argument('--group_size',
111
+ type=int,
112
+ default=128,
113
+ help='Group size used in GPTQ quantization.')
114
+
115
+ parser.add_argument("--load_model_on_cpu", action="store_true")
116
+ parser.add_argument(
117
+ '--use_parallel_embedding',
118
+ action="store_true",
119
+ default=False,
120
+ help=
121
+ 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
122
+ )
123
+ parser.add_argument(
124
+ '--embedding_sharding_dim',
125
+ type=int,
126
+ default=0,
127
+ choices=[0, 1],
128
+ help=
129
+ 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
130
+ 'To shard it along hidden dimension, set embedding_sharding_dim=1'
131
+ 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
132
+ )
133
+ parser.add_argument('--output_dir',
134
+ type=str,
135
+ default='tllm_checkpoint',
136
+ help='The path to save the TensorRT-LLM checkpoint')
137
+ parser.add_argument(
138
+ '--workers',
139
+ type=int,
140
+ default=1,
141
+ help='The number of workers for converting checkpoint in parallel')
142
+ parser.add_argument(
143
+ '--moe_tp_size',
144
+ type=int,
145
+ default=-1,
146
+ help=
147
+ 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE'
148
+ )
149
+ parser.add_argument(
150
+ '--moe_ep_size',
151
+ type=int,
152
+ default=-1,
153
+ help=
154
+ 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE'
155
+ )
156
+ args = parser.parse_args()
157
+ return args
158
+
159
+
160
+ def args_to_quant_config(args: argparse.Namespace) -> QuantConfig:
161
+ '''return config dict with quantization info based on the command line args
162
+ '''
163
+ quant_config = QuantConfig()
164
+ if args.use_weight_only:
165
+ if args.weight_only_precision == 'int8':
166
+ quant_config.quant_algo = QuantAlgo.W8A16
167
+ elif args.weight_only_precision == 'int4':
168
+ quant_config.quant_algo = QuantAlgo.W4A16
169
+ elif args.smoothquant:
170
+ quant_config.smoothquant_val = args.smoothquant
171
+ if args.per_channel:
172
+ if args.per_token:
173
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
174
+ else:
175
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
176
+ else:
177
+ if args.per_token:
178
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
179
+ else:
180
+ quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
181
+
182
+ if args.int8_kv_cache:
183
+ quant_config.kv_cache_quant_algo = QuantAlgo.INT8
184
+
185
+ if args.weight_only_precision == 'int4_gptq':
186
+ quant_config.group_size = args.group_size
187
+ quant_config.has_zero_point = True
188
+ quant_config.pre_quant_scale = False
189
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
190
+
191
+ return quant_config
192
+
193
+
194
+ def update_quant_config_from_hf(quant_config, hf_config,
195
+ override_fields) -> tuple[QuantConfig, dict]:
196
+ hf_config_dict = hf_config.to_dict()
197
+ if hf_config_dict.get('quantization_config'):
198
+ # update the quant_algo, and clamp_val.
199
+ if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
200
+ logger.info(
201
+ "Load quantization configs from huggingface model_config.")
202
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
203
+ quant_config.group_size = hf_config_dict['quantization_config'].get(
204
+ 'group_size', 128)
205
+ quant_config.has_zero_point = hf_config_dict[
206
+ 'quantization_config'].get('zero_point', False)
207
+ override_fields.update({"use_autoawq": True})
208
+ elif hf_config_dict['quantization_config'].get(
209
+ 'quant_method') == 'gptq':
210
+ logger.info(
211
+ "Load quantization configs from huggingface model_config.")
212
+ desc_act = hf_config_dict['quantization_config'].get(
213
+ 'desc_act', False)
214
+ if desc_act:
215
+ raise ValueError("GPTQ with desc_act=True is not implemented!")
216
+ quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
217
+ quant_config.group_size = hf_config_dict['quantization_config'].get(
218
+ 'group_size', 128)
219
+ quant_config.has_zero_point = hf_config_dict[
220
+ 'quantization_config'].get('sym', False)
221
+ return quant_config, override_fields
222
+
223
+
224
+ def args_to_build_options(args):
225
+ return {
226
+ 'use_parallel_embedding': args.use_parallel_embedding,
227
+ 'embedding_sharding_dim': args.embedding_sharding_dim,
228
+ 'disable_weight_only_quant_plugin':
229
+ args.disable_weight_only_quant_plugin
230
+ }
231
+
232
+
233
+ def convert_and_save_hf(args):
234
+ model_dir = args.model_dir
235
+ world_size = args.tp_size * args.pp_size
236
+ # Need to convert the cli args to the kay-value pairs and override them in the generate config dict.
237
+ # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
238
+ # before the refactor is done.
239
+ override_fields = {}
240
+ override_fields.update(args_to_build_options(args))
241
+ quant_config = args_to_quant_config(args)
242
+
243
+ try:
244
+ hf_config = AutoConfig.from_pretrained(model_dir,
245
+ trust_remote_code=True)
246
+ quant_config, override_fields = update_quant_config_from_hf(
247
+ quant_config, hf_config, override_fields)
248
+ except:
249
+ logger.warning("AutoConfig cannot load the huggingface config.")
250
+
251
+ if args.smoothquant is not None or args.int8_kv_cache:
252
+ mapping = Mapping(
253
+ world_size=world_size,
254
+ tp_size=args.tp_size,
255
+ pp_size=args.pp_size,
256
+ moe_tp_size=args.moe_tp_size,
257
+ moe_ep_size=args.moe_ep_size,
258
+ )
259
+ QWenForCausalLM.quantize(args.model_dir,
260
+ args.output_dir,
261
+ dtype=args.dtype,
262
+ mapping=mapping,
263
+ quant_config=quant_config,
264
+ calib_dataset=args.calib_dataset,
265
+ **override_fields)
266
+ else:
267
+
268
+ def convert_and_save_rank(args, rank):
269
+ mapping = Mapping(world_size=world_size,
270
+ rank=rank,
271
+ tp_size=args.tp_size,
272
+ pp_size=args.pp_size,
273
+ moe_tp_size=args.moe_tp_size,
274
+ moe_ep_size=args.moe_ep_size)
275
+ qwen = QWenForCausalLM.from_hugging_face(model_dir,
276
+ args.dtype,
277
+ mapping=mapping,
278
+ quant_config=quant_config,
279
+ **override_fields)
280
+ qwen.save_checkpoint(args.output_dir, save_config=(rank == 0))
281
+ del qwen
282
+
283
+ execute(args.workers, [convert_and_save_rank] * world_size, args)
284
+ release_gc()
285
+
286
+
287
+ def execute(workers, func, args):
288
+ if workers == 1:
289
+ for rank, f in enumerate(func):
290
+ f(args, rank)
291
+ else:
292
+ with ThreadPoolExecutor(max_workers=workers) as p:
293
+ futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
294
+ exceptions = []
295
+ for future in as_completed(futures):
296
+ try:
297
+ future.result()
298
+ except Exception as e:
299
+ traceback.print_exc()
300
+ exceptions.append(e)
301
+ assert len(
302
+ exceptions
303
+ ) == 0, "Checkpoint conversion failed, please check error log."
304
+
305
+
306
+ def main():
307
+ print(tensorrt_llm.__version__)
308
+ args = parse_arguments()
309
+
310
+ if (args.moe_tp_size == -1 and args.moe_ep_size == -1):
311
+ # moe default to tp-only
312
+ args.moe_tp_size = args.tp_size
313
+ args.moe_ep_size = 1
314
+ elif (args.moe_tp_size == -1):
315
+ args.moe_tp_size = args.tp_size // args.moe_ep_size
316
+ elif (args.moe_ep_size == -1):
317
+ args.moe_ep_size = args.tp_size // args.moe_tp_size
318
+ assert (args.moe_tp_size * args.moe_ep_size == args.tp_size
319
+ ), "moe_tp_size * moe_ep_size must equal to tp_size"
320
+
321
+ tik = time.time()
322
+
323
+ if not os.path.exists(args.output_dir):
324
+ os.makedirs(args.output_dir)
325
+
326
+ assert args.model_dir is not None
327
+ convert_and_save_hf(args)
328
+
329
+ tok = time.time()
330
+ t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
331
+ print(f'Total time of converting checkpoints: {t}')
332
+
333
+
334
+ if __name__ == '__main__':
335
+ main()
runtime/triton_trtllm/scripts/fill_template.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ from argparse import ArgumentParser
3
+ from string import Template
4
+
5
+
6
+ def split(string, delimiter):
7
+ """Split a string using delimiter. Supports escaping.
8
+
9
+ Args:
10
+ string (str): The string to split.
11
+ delimiter (str): The delimiter to split the string with.
12
+
13
+ Returns:
14
+ list: A list of strings.
15
+ """
16
+ result = []
17
+ current = ""
18
+ escape = False
19
+ for char in string:
20
+ if escape:
21
+ current += char
22
+ escape = False
23
+ elif char == delimiter:
24
+ result.append(current)
25
+ current = ""
26
+ elif char == "\\":
27
+ escape = True
28
+ else:
29
+ current += char
30
+ result.append(current)
31
+ return result
32
+
33
+
34
+ def main(file_path, substitutions, in_place):
35
+ with open(file_path) as f:
36
+ pbtxt = Template(f.read())
37
+
38
+ sub_dict = {
39
+ "max_queue_size": 0,
40
+ 'max_queue_delay_microseconds': 0,
41
+ }
42
+ for sub in split(substitutions, ","):
43
+ key, value = split(sub, ":")
44
+ sub_dict[key] = value
45
+
46
+ assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
47
+
48
+ pbtxt = pbtxt.safe_substitute(sub_dict)
49
+
50
+ if in_place:
51
+ with open(file_path, "w") as f:
52
+ f.write(pbtxt)
53
+ else:
54
+ print(pbtxt)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ parser = ArgumentParser()
59
+ parser.add_argument("file_path", help="path of the .pbtxt to modify")
60
+ parser.add_argument(
61
+ "substitutions",
62
+ help=
63
+ "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
64
+ )
65
+ parser.add_argument("--in_place",
66
+ "-i",
67
+ action="store_true",
68
+ help="do the operation in-place")
69
+ args = parser.parse_args()
70
+ main(**vars(args))
sparktts/models/audio_tokenizer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import numpy as np
19
+
20
+ from pathlib import Path
21
+ from typing import Any, Dict, Tuple
22
+ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
23
+
24
+ from sparktts.utils.file import load_config
25
+ from sparktts.utils.audio import load_audio
26
+ from sparktts.models.bicodec import BiCodec
27
+
28
+
29
+ class BiCodecTokenizer:
30
+ """BiCodec tokenizer for handling audio input and tokenization."""
31
+
32
+ def __init__(self, model_dir: Path, device: torch.device = None, **kwargs):
33
+ super().__init__()
34
+ """
35
+ Args:
36
+ model_dir: Path to the model directory.
37
+ device: Device to run the model on (default is GPU if available).
38
+ """
39
+ self.device = "cpu"
40
+ self.model_dir = model_dir
41
+ self.config = load_config(f"{model_dir}/config.yaml")
42
+ self._initialize_model()
43
+
44
+ def _initialize_model(self):
45
+ """Load and initialize the BiCodec model and Wav2Vec2 feature extractor."""
46
+ self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to(
47
+ self.device
48
+ )
49
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
50
+ f"{self.model_dir}/wav2vec2-large-xlsr-53"
51
+ )
52
+ self.feature_extractor = Wav2Vec2Model.from_pretrained(
53
+ f"{self.model_dir}/wav2vec2-large-xlsr-53"
54
+ ).to(self.device)
55
+ self.feature_extractor.config.output_hidden_states = True
56
+
57
+ def get_ref_clip(self, wav: np.ndarray) -> np.ndarray:
58
+ """Get reference audio clip for speaker embedding."""
59
+ ref_segment_length = (
60
+ int(self.config["sample_rate"] * self.config["ref_segment_duration"])
61
+ // self.config["latent_hop_length"]
62
+ * self.config["latent_hop_length"]
63
+ )
64
+ wav_length = len(wav)
65
+
66
+ if ref_segment_length > wav_length:
67
+ # Repeat and truncate to handle insufficient length
68
+ wav = np.tile(wav, ref_segment_length // wav_length + 1)
69
+
70
+ return wav[:ref_segment_length]
71
+
72
+ def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]:
73
+ """load auido and get reference audio from wav path"""
74
+ wav = load_audio(
75
+ wav_path,
76
+ sampling_rate=self.config["sample_rate"],
77
+ volume_normalize=self.config["volume_normalize"],
78
+ )
79
+
80
+ wav_ref = self.get_ref_clip(wav)
81
+
82
+ wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float()
83
+ return wav, wav_ref
84
+
85
+ def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor:
86
+ """extract wav2vec2 features"""
87
+ inputs = self.processor(
88
+ wavs,
89
+ sampling_rate=16000,
90
+ return_tensors="pt",
91
+ padding=True,
92
+ output_hidden_states=True,
93
+ ).input_values
94
+ feat = self.feature_extractor(inputs.to(self.feature_extractor.device))
95
+ feats_mix = (
96
+ feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16]
97
+ ) / 3
98
+
99
+ return feats_mix
100
+
101
+ def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor:
102
+ """tokenize the batch of audio
103
+
104
+ Args:
105
+ batch:
106
+ wavs (List[np.ndarray]): batch of audio
107
+ ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len)
108
+
109
+ Returns:
110
+ semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim)
111
+ global_tokens: global tokens. shape: (batch_size, seq_len, global_dim)
112
+ """
113
+ feats = self.extract_wav2vec2_features(batch["wav"])
114
+ batch["feat"] = feats
115
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
116
+
117
+ return global_tokens, semantic_tokens
118
+
119
+ def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ """tokenize the audio"""
121
+ wav, ref_wav = self.process_audio(audio_path)
122
+ feat = self.extract_wav2vec2_features(wav)
123
+ batch = {
124
+ "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device),
125
+ "ref_wav": ref_wav.to(self.device),
126
+ "feat": feat.to(self.device),
127
+ }
128
+ semantic_tokens, global_tokens = self.model.tokenize(batch)
129
+
130
+ return global_tokens, semantic_tokens
131
+
132
+ def detokenize(
133
+ self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor
134
+ ) -> np.array:
135
+ """detokenize the tokens to waveform
136
+
137
+ Args:
138
+ global_tokens: global tokens. shape: (batch_size, global_dim)
139
+ semantic_tokens: semantic tokens. shape: (batch_size, latent_dim)
140
+
141
+ Returns:
142
+ wav_rec: waveform. shape: (batch_size, seq_len) for batch or (seq_len,) for single
143
+ """
144
+ global_tokens = global_tokens.unsqueeze(1)
145
+ wav_rec = self.model.detokenize(semantic_tokens, global_tokens)
146
+ return wav_rec.detach().squeeze().cpu().numpy()
147
+
148
+
149
+ # test
150
+ if __name__ == "__main__":
151
+ import soundfile as sf
152
+
153
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
+ tokenizer = BiCodecTokenizer(
155
+ model_dir="pretrained_models/Spark-TTS-0.5B",
156
+ device=device,
157
+ )
158
+ wav_path = "example/prompt_audio.wav"
159
+
160
+ global_tokens, semantic_tokens = tokenizer.tokenize(wav_path)
161
+
162
+ wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens)
163
+ sf.write("example/prompt_recon.wav", wav_rec, 16000)
sparktts/models/bicodec.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from pathlib import Path
19
+ from typing import Dict, Any
20
+ from omegaconf import DictConfig
21
+ from safetensors.torch import load_file
22
+
23
+ from sparktts.utils.file import load_config
24
+ from sparktts.modules.speaker.speaker_encoder import SpeakerEncoder
25
+ from sparktts.modules.encoder_decoder.feat_encoder import Encoder
26
+ from sparktts.modules.encoder_decoder.feat_decoder import Decoder
27
+ from sparktts.modules.encoder_decoder.wave_generator import WaveGenerator
28
+ from sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize
29
+
30
+
31
+ class BiCodec(nn.Module):
32
+ """
33
+ BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder,
34
+ quantizer, and wave generator.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ mel_params: Dict[str, Any],
40
+ encoder: nn.Module,
41
+ decoder: nn.Module,
42
+ quantizer: nn.Module,
43
+ speaker_encoder: nn.Module,
44
+ prenet: nn.Module,
45
+ postnet: nn.Module,
46
+ **kwargs
47
+ ) -> None:
48
+ """
49
+ Initializes the BiCodec model with the required components.
50
+
51
+ Args:
52
+ mel_params (dict): Parameters for the mel-spectrogram transformer.
53
+ encoder (nn.Module): Encoder module.
54
+ decoder (nn.Module): Decoder module.
55
+ quantizer (nn.Module): Quantizer module.
56
+ speaker_encoder (nn.Module): Speaker encoder module.
57
+ prenet (nn.Module): Prenet network.
58
+ postnet (nn.Module): Postnet network.
59
+ """
60
+ super().__init__()
61
+ self.encoder = encoder
62
+ self.decoder = decoder
63
+ self.quantizer = quantizer
64
+ self.speaker_encoder = speaker_encoder
65
+ self.prenet = prenet
66
+ self.postnet = postnet
67
+ self.init_mel_transformer(mel_params)
68
+
69
+ @classmethod
70
+ def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec":
71
+ """
72
+ Loads the model from a checkpoint.
73
+
74
+ Args:
75
+ model_dir (Path): Path to the model directory containing checkpoint and config.
76
+
77
+ Returns:
78
+ BiCodec: The initialized BiCodec model.
79
+ """
80
+ ckpt_path = f'{model_dir}/model.safetensors'
81
+ config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer']
82
+ mel_params = config["mel_params"]
83
+ encoder = Encoder(**config["encoder"])
84
+ quantizer = FactorizedVectorQuantize(**config["quantizer"])
85
+ prenet = Decoder(**config["prenet"])
86
+ postnet = Decoder(**config["postnet"])
87
+ decoder = WaveGenerator(**config["decoder"])
88
+ speaker_encoder = SpeakerEncoder(**config["speaker_encoder"])
89
+
90
+ model = cls(
91
+ mel_params=mel_params,
92
+ encoder=encoder,
93
+ decoder=decoder,
94
+ quantizer=quantizer,
95
+ speaker_encoder=speaker_encoder,
96
+ prenet=prenet,
97
+ postnet=postnet,
98
+ )
99
+
100
+ state_dict = load_file(ckpt_path)
101
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
102
+
103
+ for key in missing_keys:
104
+ print(f"Missing tensor: {key}")
105
+ for key in unexpected_keys:
106
+ print(f"Unexpected tensor: {key}")
107
+
108
+ model.eval()
109
+ model.remove_weight_norm()
110
+
111
+ return model
112
+
113
+ def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
114
+ """
115
+ Performs a forward pass through the model.
116
+
117
+ Args:
118
+ batch (dict): A dictionary containing features, reference waveform, and target waveform.
119
+
120
+ Returns:
121
+ dict: A dictionary containing the reconstruction, features, and other metrics.
122
+ """
123
+ feat = batch["feat"]
124
+ mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
125
+
126
+ z = self.encoder(feat.transpose(1, 2))
127
+ vq_outputs = self.quantizer(z)
128
+
129
+ x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2))
130
+
131
+ conditions = d_vector
132
+ with_speaker_loss = False
133
+
134
+ x = self.prenet(vq_outputs["z_q"], conditions)
135
+ pred_feat = self.postnet(x)
136
+ x = x + conditions.unsqueeze(-1)
137
+ wav_recon = self.decoder(x)
138
+
139
+ return {
140
+ "vq_loss": vq_outputs["vq_loss"],
141
+ "perplexity": vq_outputs["perplexity"],
142
+ "cluster_size": vq_outputs["active_num"],
143
+ "recons": wav_recon,
144
+ "pred_feat": pred_feat,
145
+ "x_vector": x_vector,
146
+ "d_vector": d_vector,
147
+ "audios": batch["wav"].unsqueeze(1),
148
+ "with_speaker_loss": with_speaker_loss,
149
+ }
150
+
151
+ @torch.no_grad()
152
+ def tokenize(self, batch: Dict[str, Any]):
153
+ """
154
+ Tokenizes the input audio into semantic and global tokens.
155
+
156
+ Args:
157
+ batch (dict): The input audio features and reference waveform.
158
+
159
+ Returns:
160
+ tuple: Semantic tokens and global tokens.
161
+ """
162
+ feat = batch["feat"]
163
+ mel = self.mel_transformer(batch["ref_wav"]).squeeze(1)
164
+
165
+ z = self.encoder(feat.transpose(1, 2))
166
+ semantic_tokens = self.quantizer.tokenize(z)
167
+ global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2))
168
+
169
+ return semantic_tokens, global_tokens
170
+
171
+ @torch.no_grad()
172
+ def detokenize(self, semantic_tokens, global_tokens):
173
+ """
174
+ Detokenizes the semantic and global tokens into a waveform.
175
+
176
+ Args:
177
+ semantic_tokens (tensor): Semantic tokens.
178
+ global_tokens (tensor): Global tokens.
179
+
180
+ Returns:
181
+ tensor: Reconstructed waveform.
182
+ """
183
+ z_q = self.quantizer.detokenize(semantic_tokens)
184
+ d_vector = self.speaker_encoder.detokenize(global_tokens)
185
+ x = self.prenet(z_q, d_vector)
186
+ x = x + d_vector.unsqueeze(-1)
187
+ wav_recon = self.decoder(x)
188
+
189
+ return wav_recon
190
+
191
+ def init_mel_transformer(self, config: Dict[str, Any]):
192
+ """
193
+ Initializes the MelSpectrogram transformer based on the provided configuration.
194
+
195
+ Args:
196
+ config (dict): Configuration parameters for MelSpectrogram.
197
+ """
198
+ import torchaudio.transforms as TT
199
+
200
+ self.mel_transformer = TT.MelSpectrogram(
201
+ config["sample_rate"],
202
+ config["n_fft"],
203
+ config["win_length"],
204
+ config["hop_length"],
205
+ config["mel_fmin"],
206
+ config["mel_fmax"],
207
+ n_mels=config["num_mels"],
208
+ power=1,
209
+ norm="slaney",
210
+ mel_scale="slaney",
211
+ )
212
+
213
+ def remove_weight_norm(self):
214
+ """Removes weight normalization from all layers."""
215
+ def _remove_weight_norm(m):
216
+ try:
217
+ torch.nn.utils.remove_weight_norm(m)
218
+ except ValueError:
219
+ pass # The module didn't have weight norm
220
+
221
+ self.apply(_remove_weight_norm)
222
+
223
+
224
+ # Test the model
225
+ if __name__ == "__main__":
226
+
227
+ config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml")
228
+ model = BiCodec.load_from_checkpoint(
229
+ model_dir="pretrained_models/SparkTTS-0.5B/BiCodec",
230
+ )
231
+
232
+ # Generate random inputs for testing
233
+ duration = 0.96
234
+ x = torch.randn(20, 1, int(duration * 16000))
235
+ feat = torch.randn(20, int(duration * 50), 1024)
236
+ inputs = {"feat": feat, "wav": x, "ref_wav": x}
237
+
238
+ # Forward pass
239
+ outputs = model(inputs)
240
+ semantic_tokens, global_tokens = model.tokenize(inputs)
241
+ wav_recon = model.detokenize(semantic_tokens, global_tokens)
242
+
243
+ # Verify if the reconstruction matches
244
+ if torch.allclose(outputs["recons"].detach(), wav_recon):
245
+ print("Test successful")
246
+ else:
247
+ print("Test failed")
sparktts/modules/blocks/layers.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
17
+
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.utils import weight_norm
22
+
23
+
24
+ def WNConv1d(*args, **kwargs):
25
+ return weight_norm(nn.Conv1d(*args, **kwargs))
26
+
27
+
28
+ def WNConvTranspose1d(*args, **kwargs):
29
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
30
+
31
+
32
+ # Scripting this brings model speed up 1.4x
33
+ @torch.jit.script
34
+ def snake(x, alpha):
35
+ shape = x.shape
36
+ x = x.reshape(shape[0], shape[1], -1)
37
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
38
+ x = x.reshape(shape)
39
+ return x
40
+
41
+
42
+ class Snake1d(nn.Module):
43
+ def __init__(self, channels):
44
+ super().__init__()
45
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
46
+
47
+ def forward(self, x):
48
+ return snake(x, self.alpha)
49
+
50
+
51
+ class ResidualUnit(nn.Module):
52
+ def __init__(self, dim: int = 16, dilation: int = 1):
53
+ super().__init__()
54
+ pad = ((7 - 1) * dilation) // 2
55
+ self.block = nn.Sequential(
56
+ Snake1d(dim),
57
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
58
+ Snake1d(dim),
59
+ WNConv1d(dim, dim, kernel_size=1),
60
+ )
61
+
62
+ def forward(self, x):
63
+ y = self.block(x)
64
+ pad = (x.shape[-1] - y.shape[-1]) // 2
65
+ if pad > 0:
66
+ x = x[..., pad:-pad]
67
+ return x + y
68
+
69
+
70
+ def init_weights(m):
71
+ if isinstance(m, nn.Conv1d):
72
+ nn.init.trunc_normal_(m.weight, std=0.02)
73
+ nn.init.constant_(m.bias, 0)
sparktts/modules/blocks/samper.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ class SamplingBlock(nn.Module):
23
+ """Sampling block for upsampling or downsampling"""
24
+
25
+ def __init__(
26
+ self,
27
+ dim: int,
28
+ groups: int = 1,
29
+ upsample_scale: int = 1,
30
+ downsample_scale: int = 1,
31
+ ) -> None:
32
+ """
33
+ Args:
34
+ dim: input dimension
35
+ groups: number of groups
36
+ upsample_scale: upsampling scale
37
+ downsample_scale: downsampling scale
38
+ """
39
+ super(SamplingBlock, self).__init__()
40
+
41
+ self.upsample_scale = upsample_scale
42
+ self.downsample_scale = downsample_scale
43
+
44
+ if self.upsample_scale > 1:
45
+ self.de_conv_upsampler = nn.Sequential(
46
+ nn.LeakyReLU(0.2),
47
+ nn.ConvTranspose1d(
48
+ dim,
49
+ dim,
50
+ kernel_size=upsample_scale * 2,
51
+ stride=upsample_scale,
52
+ padding=upsample_scale // 2 + upsample_scale % 2,
53
+ output_padding=upsample_scale % 2,
54
+ groups=groups,
55
+ ),
56
+ )
57
+
58
+ if self.downsample_scale > 1:
59
+ self.conv_downsampler = nn.Sequential(
60
+ nn.LeakyReLU(0.2),
61
+ nn.Conv1d(
62
+ dim,
63
+ dim,
64
+ kernel_size=2 * downsample_scale,
65
+ stride=downsample_scale,
66
+ padding=downsample_scale // 2 + downsample_scale % 2,
67
+ groups=groups,
68
+ ),
69
+ )
70
+
71
+ @staticmethod
72
+ def repeat_upsampler(x, upsample_scale):
73
+ return x.repeat_interleave(upsample_scale, dim=2)
74
+
75
+ @staticmethod
76
+ def skip_downsampler(x, downsample_scale):
77
+ return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale)
78
+
79
+ def forward(self, x):
80
+ x = x.transpose(1, 2)
81
+ if self.upsample_scale > 1:
82
+ repeat_res = self.repeat_upsampler(x, self.upsample_scale)
83
+ deconv_res = self.de_conv_upsampler(x)
84
+ upmerge_res = repeat_res + deconv_res
85
+ else:
86
+ upmerge_res = x
87
+ repeat_res = x
88
+
89
+ if self.downsample_scale > 1:
90
+ conv_res = self.conv_downsampler(upmerge_res)
91
+ skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale)
92
+ skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale)
93
+ else:
94
+ conv_res = upmerge_res
95
+ skip2_res = upmerge_res
96
+ skip1_res = repeat_res
97
+
98
+ final_res = conv_res + skip1_res + skip2_res
99
+
100
+ return final_res
101
+
102
+
103
+ # test
104
+ if __name__ == "__main__":
105
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
106
+ model = SamplingBlock(1024, 1024, upsample_scale=2)
107
+ model_down = SamplingBlock(1024, 1024, downsample_scale=2)
108
+ output = model(test_input)
109
+ output_down = model_down(test_input)
110
+ print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100])
111
+ print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25])
112
+ if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size(
113
+ [8, 1024, 25]
114
+ ):
115
+ print("test successful")
sparktts/modules/blocks/vocos.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import Tuple
21
+ from torch.nn.utils import weight_norm, remove_weight_norm
22
+
23
+ from typing import Optional
24
+
25
+
26
+ class ConvNeXtBlock(nn.Module):
27
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
28
+
29
+ Args:
30
+ dim (int): Number of input channels.
31
+ intermediate_dim (int): Dimensionality of the intermediate layer.
32
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
33
+ Defaults to None.
34
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
35
+ None means non-conditional LayerNorm. Defaults to None.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ dim: int,
41
+ intermediate_dim: int,
42
+ layer_scale_init_value: float,
43
+ condition_dim: Optional[int] = None,
44
+ ):
45
+ super().__init__()
46
+ self.dwconv = nn.Conv1d(
47
+ dim, dim, kernel_size=7, padding=3, groups=dim
48
+ ) # depthwise conv
49
+ self.adanorm = condition_dim is not None
50
+ if condition_dim:
51
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
52
+ else:
53
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
54
+ self.pwconv1 = nn.Linear(
55
+ dim, intermediate_dim
56
+ ) # pointwise/1x1 convs, implemented with linear layers
57
+ self.act = nn.GELU()
58
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
59
+ self.gamma = (
60
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
61
+ if layer_scale_init_value > 0
62
+ else None
63
+ )
64
+
65
+ def forward(
66
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
67
+ ) -> torch.Tensor:
68
+ residual = x
69
+ x = self.dwconv(x)
70
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
71
+ if self.adanorm:
72
+ assert cond_embedding_id is not None
73
+ x = self.norm(x, cond_embedding_id)
74
+ else:
75
+ x = self.norm(x)
76
+ x = self.pwconv1(x)
77
+ x = self.act(x)
78
+ x = self.pwconv2(x)
79
+ if self.gamma is not None:
80
+ x = self.gamma * x
81
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
82
+
83
+ x = residual + x
84
+ return x
85
+
86
+
87
+ class AdaLayerNorm(nn.Module):
88
+ """
89
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
90
+
91
+ Args:
92
+ condition_dim (int): Dimension of the condition.
93
+ embedding_dim (int): Dimension of the embeddings.
94
+ """
95
+
96
+ def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6):
97
+ super().__init__()
98
+ self.eps = eps
99
+ self.dim = embedding_dim
100
+ self.scale = nn.Linear(condition_dim, embedding_dim)
101
+ self.shift = nn.Linear(condition_dim, embedding_dim)
102
+ torch.nn.init.ones_(self.scale.weight)
103
+ torch.nn.init.zeros_(self.shift.weight)
104
+
105
+ def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor:
106
+ scale = self.scale(cond_embedding)
107
+ shift = self.shift(cond_embedding)
108
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
109
+ x = x * scale.unsqueeze(1) + shift.unsqueeze(1)
110
+ return x
111
+
112
+
113
+ class ResBlock1(nn.Module):
114
+ """
115
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
116
+ but without upsampling layers.
117
+
118
+ Args:
119
+ dim (int): Number of input channels.
120
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
121
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
122
+ Defaults to (1, 3, 5).
123
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
124
+ Defaults to 0.1.
125
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
126
+ Defaults to None.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ dim: int,
132
+ kernel_size: int = 3,
133
+ dilation: Tuple[int, int, int] = (1, 3, 5),
134
+ lrelu_slope: float = 0.1,
135
+ layer_scale_init_value: Optional[float] = None,
136
+ ):
137
+ super().__init__()
138
+ self.lrelu_slope = lrelu_slope
139
+ self.convs1 = nn.ModuleList(
140
+ [
141
+ weight_norm(
142
+ nn.Conv1d(
143
+ dim,
144
+ dim,
145
+ kernel_size,
146
+ 1,
147
+ dilation=dilation[0],
148
+ padding=self.get_padding(kernel_size, dilation[0]),
149
+ )
150
+ ),
151
+ weight_norm(
152
+ nn.Conv1d(
153
+ dim,
154
+ dim,
155
+ kernel_size,
156
+ 1,
157
+ dilation=dilation[1],
158
+ padding=self.get_padding(kernel_size, dilation[1]),
159
+ )
160
+ ),
161
+ weight_norm(
162
+ nn.Conv1d(
163
+ dim,
164
+ dim,
165
+ kernel_size,
166
+ 1,
167
+ dilation=dilation[2],
168
+ padding=self.get_padding(kernel_size, dilation[2]),
169
+ )
170
+ ),
171
+ ]
172
+ )
173
+
174
+ self.convs2 = nn.ModuleList(
175
+ [
176
+ weight_norm(
177
+ nn.Conv1d(
178
+ dim,
179
+ dim,
180
+ kernel_size,
181
+ 1,
182
+ dilation=1,
183
+ padding=self.get_padding(kernel_size, 1),
184
+ )
185
+ ),
186
+ weight_norm(
187
+ nn.Conv1d(
188
+ dim,
189
+ dim,
190
+ kernel_size,
191
+ 1,
192
+ dilation=1,
193
+ padding=self.get_padding(kernel_size, 1),
194
+ )
195
+ ),
196
+ weight_norm(
197
+ nn.Conv1d(
198
+ dim,
199
+ dim,
200
+ kernel_size,
201
+ 1,
202
+ dilation=1,
203
+ padding=self.get_padding(kernel_size, 1),
204
+ )
205
+ ),
206
+ ]
207
+ )
208
+
209
+ self.gamma = nn.ParameterList(
210
+ [
211
+ (
212
+ nn.Parameter(
213
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
214
+ )
215
+ if layer_scale_init_value is not None
216
+ else None
217
+ ),
218
+ (
219
+ nn.Parameter(
220
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
221
+ )
222
+ if layer_scale_init_value is not None
223
+ else None
224
+ ),
225
+ (
226
+ nn.Parameter(
227
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
228
+ )
229
+ if layer_scale_init_value is not None
230
+ else None
231
+ ),
232
+ ]
233
+ )
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
237
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
238
+ xt = c1(xt)
239
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
240
+ xt = c2(xt)
241
+ if gamma is not None:
242
+ xt = gamma * xt
243
+ x = xt + x
244
+ return x
245
+
246
+ def remove_weight_norm(self):
247
+ for l in self.convs1:
248
+ remove_weight_norm(l)
249
+ for l in self.convs2:
250
+ remove_weight_norm(l)
251
+
252
+ @staticmethod
253
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
254
+ return int((kernel_size * dilation - dilation) / 2)
255
+
256
+
257
+ class Backbone(nn.Module):
258
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
259
+
260
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
261
+ """
262
+ Args:
263
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
264
+ C denotes output features, and L is the sequence length.
265
+
266
+ Returns:
267
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
268
+ and H denotes the model dimension.
269
+ """
270
+ raise NotImplementedError("Subclasses must implement the forward method.")
271
+
272
+
273
+ class VocosBackbone(Backbone):
274
+ """
275
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
276
+
277
+ Args:
278
+ input_channels (int): Number of input features channels.
279
+ dim (int): Hidden dimension of the model.
280
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
281
+ num_layers (int): Number of ConvNeXtBlock layers.
282
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
283
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
284
+ None means non-conditional model. Defaults to None.
285
+ """
286
+
287
+ def __init__(
288
+ self,
289
+ input_channels: int,
290
+ dim: int,
291
+ intermediate_dim: int,
292
+ num_layers: int,
293
+ layer_scale_init_value: Optional[float] = None,
294
+ condition_dim: Optional[int] = None,
295
+ ):
296
+ super().__init__()
297
+ self.input_channels = input_channels
298
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
299
+ self.adanorm = condition_dim is not None
300
+ if condition_dim:
301
+ self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6)
302
+ else:
303
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
304
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
305
+ self.convnext = nn.ModuleList(
306
+ [
307
+ ConvNeXtBlock(
308
+ dim=dim,
309
+ intermediate_dim=intermediate_dim,
310
+ layer_scale_init_value=layer_scale_init_value,
311
+ condition_dim=condition_dim,
312
+ )
313
+ for _ in range(num_layers)
314
+ ]
315
+ )
316
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
317
+ self.apply(self._init_weights)
318
+
319
+ def _init_weights(self, m):
320
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
321
+ nn.init.trunc_normal_(m.weight, std=0.02)
322
+ nn.init.constant_(m.bias, 0)
323
+
324
+ def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor:
325
+ x = self.embed(x)
326
+ if self.adanorm:
327
+ assert condition is not None
328
+ x = self.norm(x.transpose(1, 2), condition)
329
+ else:
330
+ x = self.norm(x.transpose(1, 2))
331
+ x = x.transpose(1, 2)
332
+ for conv_block in self.convnext:
333
+ x = conv_block(x, condition)
334
+ x = self.final_layer_norm(x.transpose(1, 2))
335
+ return x
336
+
337
+
338
+ class VocosResNetBackbone(Backbone):
339
+ """
340
+ Vocos backbone module built with ResBlocks.
341
+
342
+ Args:
343
+ input_channels (int): Number of input features channels.
344
+ dim (int): Hidden dimension of the model.
345
+ num_blocks (int): Number of ResBlock1 blocks.
346
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
347
+ """
348
+
349
+ def __init__(
350
+ self,
351
+ input_channels,
352
+ dim,
353
+ num_blocks,
354
+ layer_scale_init_value=None,
355
+ ):
356
+ super().__init__()
357
+ self.input_channels = input_channels
358
+ self.embed = weight_norm(
359
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
360
+ )
361
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
362
+ self.resnet = nn.Sequential(
363
+ *[
364
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
365
+ for _ in range(num_blocks)
366
+ ]
367
+ )
368
+
369
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
370
+ x = self.embed(x)
371
+ x = self.resnet(x)
372
+ x = x.transpose(1, 2)
373
+ return x
sparktts/modules/encoder_decoder/feat_decoder.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from sparktts.modules.blocks.vocos import VocosBackbone
23
+ from sparktts.modules.blocks.samper import SamplingBlock
24
+
25
+
26
+ class Decoder(nn.Module):
27
+ """Decoder module with convnext and upsampling blocks
28
+
29
+ Args:
30
+ sample_ratios (List[int]): sample ratios
31
+ example: [2, 2] means downsample by 2x and then upsample by 2x
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ input_channels: int,
37
+ vocos_dim: int,
38
+ vocos_intermediate_dim: int,
39
+ vocos_num_layers: int,
40
+ out_channels: int,
41
+ condition_dim: int = None,
42
+ sample_ratios: List[int] = [1, 1],
43
+ use_tanh_at_final: bool = False,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.linear_pre = nn.Linear(input_channels, vocos_dim)
48
+ modules = [
49
+ nn.Sequential(
50
+ SamplingBlock(
51
+ dim=vocos_dim,
52
+ groups=vocos_dim,
53
+ upsample_scale=ratio,
54
+ ),
55
+ VocosBackbone(
56
+ input_channels=vocos_dim,
57
+ dim=vocos_dim,
58
+ intermediate_dim=vocos_intermediate_dim,
59
+ num_layers=2,
60
+ condition_dim=None,
61
+ ),
62
+ )
63
+ for ratio in sample_ratios
64
+ ]
65
+
66
+ self.downsample = nn.Sequential(*modules)
67
+
68
+ self.vocos_backbone = VocosBackbone(
69
+ input_channels=vocos_dim,
70
+ dim=vocos_dim,
71
+ intermediate_dim=vocos_intermediate_dim,
72
+ num_layers=vocos_num_layers,
73
+ condition_dim=condition_dim,
74
+ )
75
+ self.linear = nn.Linear(vocos_dim, out_channels)
76
+ self.use_tanh_at_final = use_tanh_at_final
77
+
78
+ def forward(self, x: torch.Tensor, c: torch.Tensor = None):
79
+ """encoder forward.
80
+
81
+ Args:
82
+ x (torch.Tensor): (batch_size, input_channels, length)
83
+
84
+ Returns:
85
+ x (torch.Tensor): (batch_size, encode_channels, length)
86
+ """
87
+ x = self.linear_pre(x.transpose(1, 2))
88
+ x = self.downsample(x).transpose(1, 2)
89
+ x = self.vocos_backbone(x, condition=c)
90
+ x = self.linear(x).transpose(1, 2)
91
+ if self.use_tanh_at_final:
92
+ x = torch.tanh(x)
93
+
94
+ return x
95
+
96
+
97
+ # test
98
+ if __name__ == "__main__":
99
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
100
+ condition = torch.randn(8, 256)
101
+ decoder = Decoder(
102
+ input_channels=1024,
103
+ vocos_dim=384,
104
+ vocos_intermediate_dim=2048,
105
+ vocos_num_layers=12,
106
+ out_channels=256,
107
+ condition_dim=256,
108
+ sample_ratios=[2, 2],
109
+ )
110
+ output = decoder(test_input, condition)
111
+ print(output.shape) # torch.Size([8, 256, 200])
112
+ if output.shape == torch.Size([8, 256, 200]):
113
+ print("Decoder test passed")
114
+ else:
115
+ print("Decoder test failed")
sparktts/modules/encoder_decoder/feat_encoder.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from typing import List
21
+
22
+ from sparktts.modules.blocks.vocos import VocosBackbone
23
+ from sparktts.modules.blocks.samper import SamplingBlock
24
+
25
+
26
+ class Encoder(nn.Module):
27
+ """Encoder module with convnext and downsampling blocks"""
28
+
29
+ def __init__(
30
+ self,
31
+ input_channels: int,
32
+ vocos_dim: int,
33
+ vocos_intermediate_dim: int,
34
+ vocos_num_layers: int,
35
+ out_channels: int,
36
+ sample_ratios: List[int] = [1, 1],
37
+ ):
38
+ super().__init__()
39
+ """
40
+ Encoder module with VocosBackbone and sampling blocks.
41
+
42
+ Args:
43
+ sample_ratios (List[int]): sample ratios
44
+ example: [2, 2] means downsample by 2x and then upsample by 2x
45
+ """
46
+ self.encoder = VocosBackbone(
47
+ input_channels=input_channels,
48
+ dim=vocos_dim,
49
+ intermediate_dim=vocos_intermediate_dim,
50
+ num_layers=vocos_num_layers,
51
+ condition_dim=None,
52
+ )
53
+
54
+ modules = [
55
+ nn.Sequential(
56
+ SamplingBlock(
57
+ dim=vocos_dim,
58
+ groups=vocos_dim,
59
+ downsample_scale=ratio,
60
+ ),
61
+ VocosBackbone(
62
+ input_channels=vocos_dim,
63
+ dim=vocos_dim,
64
+ intermediate_dim=vocos_intermediate_dim,
65
+ num_layers=2,
66
+ condition_dim=None,
67
+ ),
68
+ )
69
+ for ratio in sample_ratios
70
+ ]
71
+
72
+ self.downsample = nn.Sequential(*modules)
73
+
74
+ self.project = nn.Linear(vocos_dim, out_channels)
75
+
76
+ def forward(self, x: torch.Tensor, *args):
77
+ """
78
+ Args:
79
+ x (torch.Tensor): (batch_size, input_channels, length)
80
+
81
+ Returns:
82
+ x (torch.Tensor): (batch_size, encode_channels, length)
83
+ """
84
+ x = self.encoder(x)
85
+ x = self.downsample(x)
86
+ x = self.project(x)
87
+ return x.transpose(1, 2)
88
+
89
+
90
+ # test
91
+ if __name__ == "__main__":
92
+ test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50
93
+ encoder = Encoder(
94
+ input_channels=1024,
95
+ vocos_dim=384,
96
+ vocos_intermediate_dim=2048,
97
+ vocos_num_layers=12,
98
+ out_channels=256,
99
+ sample_ratios=[2, 2],
100
+ )
101
+
102
+ output = encoder(test_input)
103
+ print(output.shape) # torch.Size([8, 256, 12])
104
+ if output.shape == torch.Size([8, 256, 12]):
105
+ print("test successful")
sparktts/modules/encoder_decoder/wave_generator.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Xinsheng Wang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0
16
+
17
+
18
+ import torch.nn as nn
19
+
20
+ from sparktts.modules.blocks.layers import (
21
+ Snake1d,
22
+ WNConv1d,
23
+ ResidualUnit,
24
+ WNConvTranspose1d,
25
+ init_weights,
26
+ )
27
+
28
+
29
+ class DecoderBlock(nn.Module):
30
+ def __init__(
31
+ self,
32
+ input_dim: int = 16,
33
+ output_dim: int = 8,
34
+ kernel_size: int = 2,
35
+ stride: int = 1,
36
+ ):
37
+ super().__init__()
38
+ self.block = nn.Sequential(
39
+ Snake1d(input_dim),
40
+ WNConvTranspose1d(
41
+ input_dim,
42
+ output_dim,
43
+ kernel_size=kernel_size,
44
+ stride=stride,
45
+ padding=(kernel_size - stride) // 2,
46
+ ),
47
+ ResidualUnit(output_dim, dilation=1),
48
+ ResidualUnit(output_dim, dilation=3),
49
+ ResidualUnit(output_dim, dilation=9),
50
+ )
51
+
52
+ def forward(self, x):
53
+ return self.block(x)
54
+
55
+
56
+ class WaveGenerator(nn.Module):
57
+ def __init__(
58
+ self,
59
+ input_channel,
60
+ channels,
61
+ rates,
62
+ kernel_sizes,
63
+ d_out: int = 1,
64
+ ):
65
+ super().__init__()
66
+
67
+ # Add first conv layer
68
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
69
+
70
+ # Add upsampling + MRF blocks
71
+ for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
72
+ input_dim = channels // 2**i
73
+ output_dim = channels // 2 ** (i + 1)
74
+ layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]
75
+
76
+ # Add final conv layer
77
+ layers += [
78
+ Snake1d(output_dim),
79
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
80
+ nn.Tanh(),
81
+ ]
82
+
83
+ self.model = nn.Sequential(*layers)
84
+
85
+ self.apply(init_weights)
86
+
87
+ def forward(self, x):
88
+ return self.model(x)
sparktts/modules/fsq/finite_scalar_quantization.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3
+ Code adapted from Jax version in Appendix A.1
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from functools import wraps, partial
8
+ from contextlib import nullcontext
9
+ from typing import List, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import Module
14
+ from torch import Tensor, int32
15
+ from torch.amp import autocast
16
+
17
+ from einops import rearrange, pack, unpack
18
+
19
+ # helper functions
20
+
21
+
22
+ def exists(v):
23
+ return v is not None
24
+
25
+
26
+ def default(*args):
27
+ for arg in args:
28
+ if exists(arg):
29
+ return arg
30
+ return None
31
+
32
+
33
+ def maybe(fn):
34
+ @wraps(fn)
35
+ def inner(x, *args, **kwargs):
36
+ if not exists(x):
37
+ return x
38
+ return fn(x, *args, **kwargs)
39
+
40
+ return inner
41
+
42
+
43
+ def pack_one(t, pattern):
44
+ return pack([t], pattern)
45
+
46
+
47
+ def unpack_one(t, ps, pattern):
48
+ return unpack(t, ps, pattern)[0]
49
+
50
+
51
+ # tensor helpers
52
+
53
+
54
+ def round_ste(z: Tensor) -> Tensor:
55
+ """Round with straight through gradients."""
56
+ zhat = z.round()
57
+ return z + (zhat - z).detach()
58
+
59
+
60
+ # main class
61
+
62
+
63
+ class FSQ(Module):
64
+ def __init__(
65
+ self,
66
+ levels: List[int],
67
+ dim: int | None = None,
68
+ num_codebooks=1,
69
+ keep_num_codebooks_dim: bool | None = None,
70
+ scale: float | None = None,
71
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
72
+ channel_first: bool = False,
73
+ projection_has_bias: bool = True,
74
+ return_indices=True,
75
+ force_quantization_f32=True,
76
+ ):
77
+ super().__init__()
78
+ _levels = torch.tensor(levels, dtype=int32)
79
+ self.register_buffer("_levels", _levels, persistent=False)
80
+
81
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
82
+ self.register_buffer("_basis", _basis, persistent=False)
83
+
84
+ self.scale = scale
85
+
86
+ codebook_dim = len(levels)
87
+ self.codebook_dim = codebook_dim
88
+
89
+ effective_codebook_dim = codebook_dim * num_codebooks
90
+ self.num_codebooks = num_codebooks
91
+ self.effective_codebook_dim = effective_codebook_dim
92
+
93
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
94
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
95
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
96
+
97
+ self.dim = default(dim, len(_levels) * num_codebooks)
98
+
99
+ self.channel_first = channel_first
100
+
101
+ has_projections = self.dim != effective_codebook_dim
102
+ self.project_in = (
103
+ nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
104
+ if has_projections
105
+ else nn.Identity()
106
+ )
107
+ self.project_out = (
108
+ nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
109
+ if has_projections
110
+ else nn.Identity()
111
+ )
112
+
113
+ self.has_projections = has_projections
114
+
115
+ self.return_indices = return_indices
116
+ if return_indices:
117
+ self.codebook_size = self._levels.prod().item()
118
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
119
+ self.register_buffer(
120
+ "implicit_codebook", implicit_codebook, persistent=False
121
+ )
122
+
123
+ self.allowed_dtypes = allowed_dtypes
124
+ self.force_quantization_f32 = force_quantization_f32
125
+
126
+ def bound(self, z, eps: float = 1e-3):
127
+ """Bound `z`, an array of shape (..., d)."""
128
+ half_l = (self._levels - 1) * (1 + eps) / 2
129
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
130
+ shift = (offset / half_l).atanh()
131
+ return (z + shift).tanh() * half_l - offset
132
+
133
+ def quantize(self, z):
134
+ """Quantizes z, returns quantized zhat, same shape as z."""
135
+ quantized = round_ste(self.bound(z))
136
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
137
+ return quantized / half_width
138
+
139
+ def _scale_and_shift(self, zhat_normalized):
140
+ half_width = self._levels // 2
141
+ return (zhat_normalized * half_width) + half_width
142
+
143
+ def _scale_and_shift_inverse(self, zhat):
144
+ half_width = self._levels // 2
145
+ return (zhat - half_width) / half_width
146
+
147
+ def _indices_to_codes(self, indices):
148
+ level_indices = self.indices_to_level_indices(indices)
149
+ codes = self._scale_and_shift_inverse(level_indices)
150
+ return codes
151
+
152
+ def codes_to_indices(self, zhat):
153
+ """Converts a `code` to an index in the codebook."""
154
+ assert zhat.shape[-1] == self.codebook_dim
155
+ zhat = self._scale_and_shift(zhat)
156
+ return (zhat * self._basis).sum(dim=-1).to(int32)
157
+
158
+ def indices_to_level_indices(self, indices):
159
+ """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
160
+ indices = rearrange(indices, "... -> ... 1")
161
+ codes_non_centered = (indices // self._basis) % self._levels
162
+ return codes_non_centered
163
+
164
+ def indices_to_codes(self, indices):
165
+ """Inverse of `codes_to_indices`."""
166
+ assert exists(indices)
167
+
168
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
169
+
170
+ codes = self._indices_to_codes(indices)
171
+
172
+ if self.keep_num_codebooks_dim:
173
+ codes = rearrange(codes, "... c d -> ... (c d)")
174
+
175
+ codes = self.project_out(codes)
176
+
177
+ if is_img_or_video or self.channel_first:
178
+ codes = rearrange(codes, "b ... d -> b d ...")
179
+
180
+ return codes
181
+
182
+ def forward(self, z):
183
+ """
184
+ einstein notation
185
+ b - batch
186
+ n - sequence (or flattened spatial dimensions)
187
+ d - feature dimension
188
+ c - number of codebook dim
189
+ """
190
+
191
+ is_img_or_video = z.ndim >= 4
192
+ need_move_channel_last = is_img_or_video or self.channel_first
193
+
194
+ # standardize image or video into (batch, seq, dimension)
195
+
196
+ if need_move_channel_last:
197
+ z = rearrange(z, "b d ... -> b ... d")
198
+ z, ps = pack_one(z, "b * d")
199
+
200
+ assert (
201
+ z.shape[-1] == self.dim
202
+ ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
203
+
204
+ z = self.project_in(z)
205
+
206
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
207
+
208
+ # whether to force quantization step to be full precision or not
209
+
210
+ force_f32 = self.force_quantization_f32
211
+ quantization_context = (
212
+ partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext
213
+ )
214
+
215
+ with quantization_context():
216
+ orig_dtype = z.dtype
217
+
218
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
219
+ z = z.float()
220
+
221
+ codes = self.quantize(z)
222
+
223
+ # returning indices could be optional
224
+
225
+ indices = None
226
+
227
+ if self.return_indices:
228
+ indices = self.codes_to_indices(codes)
229
+
230
+ codes = rearrange(codes, "b n c d -> b n (c d)")
231
+
232
+ codes = codes.type(orig_dtype)
233
+
234
+ # project out
235
+
236
+ out = self.project_out(codes)
237
+
238
+ # reconstitute image or video dimensions
239
+
240
+ if need_move_channel_last:
241
+ out = unpack_one(out, ps, "b * d")
242
+ out = rearrange(out, "b ... d -> b d ...")
243
+
244
+ indices = maybe(unpack_one)(indices, ps, "b * c")
245
+
246
+ if not self.keep_num_codebooks_dim and self.return_indices:
247
+ indices = maybe(rearrange)(indices, "... 1 -> ...")
248
+
249
+ # return quantized output and indices
250
+
251
+ return out, indices
sparktts/modules/fsq/residual_fsq.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from typing import List
7
+ from torch import nn
8
+ from torch.nn import Module
9
+ from torch.amp import autocast
10
+ from einx import get_at
11
+ from einops import rearrange, reduce, pack, unpack
12
+
13
+ from sparktts.modules.fsq.finite_scalar_quantization import FSQ
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def first(l):
21
+ return l[0]
22
+
23
+
24
+ def default(val, d):
25
+ return val if exists(val) else d
26
+
27
+
28
+ def round_up_multiple(num, mult):
29
+ return ceil(num / mult) * mult
30
+
31
+
32
+ # distributed helpers
33
+
34
+
35
+ def is_distributed():
36
+ return dist.is_initialized() and dist.get_world_size() > 1
37
+
38
+
39
+ def get_maybe_sync_seed(device, max_size=10_000):
40
+ rand_int = torch.randint(0, max_size, (), device=device)
41
+
42
+ if is_distributed():
43
+ dist.all_reduce(rand_int)
44
+
45
+ return rand_int.item()
46
+
47
+
48
+ class ResidualFSQ(Module):
49
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
50
+
51
+ def __init__(
52
+ self,
53
+ *,
54
+ levels: List[int],
55
+ num_quantizers,
56
+ dim=None,
57
+ is_channel_first=False,
58
+ quantize_dropout=False,
59
+ quantize_dropout_cutoff_index=0,
60
+ quantize_dropout_multiple_of=1,
61
+ **kwargs,
62
+ ):
63
+ super().__init__()
64
+ codebook_dim = len(levels)
65
+ dim = default(dim, codebook_dim)
66
+
67
+ requires_projection = codebook_dim != dim
68
+ self.project_in = (
69
+ nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
70
+ )
71
+ self.project_out = (
72
+ nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
73
+ )
74
+ self.has_projections = requires_projection
75
+
76
+ self.is_channel_first = is_channel_first
77
+ self.num_quantizers = num_quantizers
78
+
79
+ self.levels = levels
80
+ self.layers = nn.ModuleList([])
81
+
82
+ levels_tensor = torch.Tensor(levels)
83
+
84
+ scales = []
85
+
86
+ for ind in range(num_quantizers):
87
+ scales.append((levels_tensor - 1) ** -ind)
88
+
89
+ fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)
90
+
91
+ self.layers.append(fsq)
92
+
93
+ assert all([not fsq.has_projections for fsq in self.layers])
94
+
95
+ self.codebook_size = self.layers[0].codebook_size
96
+
97
+ self.register_buffer("scales", torch.stack(scales), persistent=False)
98
+
99
+ self.quantize_dropout = quantize_dropout and num_quantizers > 1
100
+
101
+ assert quantize_dropout_cutoff_index >= 0
102
+
103
+ self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
104
+ self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
105
+
106
+ @property
107
+ def codebooks(self):
108
+ codebooks = [layer.implicit_codebook for layer in self.layers]
109
+ codebooks = torch.stack(codebooks, dim=0)
110
+ return codebooks
111
+
112
+ def get_codes_from_indices(self, indices):
113
+
114
+ batch, quantize_dim = indices.shape[0], indices.shape[-1]
115
+
116
+ # may also receive indices in the shape of 'b h w q' (accept_image_fmap)
117
+
118
+ indices, ps = pack([indices], "b * q")
119
+
120
+ # because of quantize dropout, one can pass in indices that are coarse
121
+ # and the network should be able to reconstruct
122
+
123
+ if quantize_dim < self.num_quantizers:
124
+ assert (
125
+ self.quantize_dropout > 0.0
126
+ ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
127
+ indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)
128
+
129
+ # take care of quantizer dropout
130
+
131
+ mask = indices == -1
132
+ indices = indices.masked_fill(
133
+ mask, 0
134
+ ) # have it fetch a dummy code to be masked out later
135
+
136
+ all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)
137
+
138
+ # mask out any codes that were dropout-ed
139
+
140
+ all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)
141
+
142
+ # scale the codes
143
+
144
+ scales = rearrange(self.scales, "q d -> q 1 1 d")
145
+ all_codes = all_codes * scales
146
+
147
+ # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)
148
+
149
+ (all_codes,) = unpack(all_codes, ps, "q b * d")
150
+
151
+ return all_codes
152
+
153
+ def get_output_from_indices(self, indices):
154
+ codes = self.get_codes_from_indices(indices)
155
+ codes_summed = reduce(codes, "q ... -> ...", "sum")
156
+ return self.project_out(codes_summed)
157
+
158
+ def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
159
+ num_quant, quant_dropout_multiple_of, device = (
160
+ self.num_quantizers,
161
+ self.quantize_dropout_multiple_of,
162
+ x.device,
163
+ )
164
+
165
+ # handle channel first
166
+
167
+ if self.is_channel_first:
168
+ x = rearrange(x, "b d ... -> b ... d")
169
+ x, ps = pack([x], "b * d")
170
+
171
+ # maybe project in
172
+
173
+ x = self.project_in(x)
174
+
175
+ quantized_out = 0.0
176
+ residual = x
177
+
178
+ all_indices = []
179
+
180
+ should_quantize_dropout = self.training and self.quantize_dropout
181
+
182
+ # sample a layer index at which to dropout further residual quantization
183
+ # also prepare null indices
184
+
185
+ if should_quantize_dropout:
186
+
187
+ # check if seed is manually passed in
188
+
189
+ if not exists(rand_quantize_dropout_fixed_seed):
190
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
191
+
192
+ rand = random.Random(rand_quantize_dropout_fixed_seed)
193
+
194
+ rand_quantize_dropout_index = rand.randrange(
195
+ self.quantize_dropout_cutoff_index, num_quant
196
+ )
197
+
198
+ if quant_dropout_multiple_of != 1:
199
+ rand_quantize_dropout_index = (
200
+ round_up_multiple(
201
+ rand_quantize_dropout_index + 1, quant_dropout_multiple_of
202
+ )
203
+ - 1
204
+ )
205
+
206
+ null_indices = torch.full(
207
+ x.shape[:2], -1.0, device=device, dtype=torch.long
208
+ )
209
+
210
+ # go through the layers
211
+
212
+ with autocast("cuda", enabled=False):
213
+ for quantizer_index, (layer, scale) in enumerate(
214
+ zip(self.layers, self.scales)
215
+ ):
216
+
217
+ if (
218
+ should_quantize_dropout
219
+ and quantizer_index > rand_quantize_dropout_index
220
+ ):
221
+ all_indices.append(null_indices)
222
+ continue
223
+
224
+ quantized, indices = layer(residual / scale)
225
+
226
+ quantized = quantized * scale
227
+
228
+ residual = residual - quantized.detach()
229
+ quantized_out = quantized_out + quantized
230
+
231
+ all_indices.append(indices)
232
+
233
+ # project out, if needed
234
+
235
+ quantized_out = self.project_out(quantized_out)
236
+
237
+ # stack all indices
238
+
239
+ all_indices = torch.stack(all_indices, dim=-1)
240
+
241
+ # channel first out
242
+
243
+ if self.is_channel_first:
244
+ (quantized_out,) = unpack(quantized_out, ps, "b * d")
245
+ (all_indices,) = unpack(all_indices, ps, "b * d")
246
+
247
+ quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
248
+ all_indices = rearrange(all_indices, "b ... d -> b d ...")
249
+
250
+ # return
251
+
252
+ ret = (quantized_out, all_indices)
253
+
254
+ if not return_all_codes:
255
+ return ret
256
+
257
+ # whether to return all codes from all codebooks across layers
258
+
259
+ all_codes = self.get_codes_from_indices(all_indices)
260
+
261
+ # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
262
+
263
+ return (*ret, all_codes)
264
+
265
+
266
+ # grouped residual fsq
267
+
268
+
269
+ class GroupedResidualFSQ(Module):
270
+ def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.groups = groups
274
+ assert (dim % groups) == 0
275
+ dim_per_group = dim // groups
276
+
277
+ self.accept_image_fmap = accept_image_fmap
278
+
279
+ self.rvqs = nn.ModuleList([])
280
+
281
+ for _ in range(groups):
282
+ self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))
283
+
284
+ self.codebook_size = self.rvqs[0].codebook_size
285
+
286
+ @property
287
+ def codebooks(self):
288
+ return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
289
+
290
+ @property
291
+ def split_dim(self):
292
+ return 1 if self.accept_image_fmap else -1
293
+
294
+ def get_codes_from_indices(self, indices):
295
+ codes = tuple(
296
+ rvq.get_codes_from_indices(chunk_indices)
297
+ for rvq, chunk_indices in zip(self.rvqs, indices)
298
+ )
299
+ return torch.stack(codes)
300
+
301
+ def get_output_from_indices(self, indices):
302
+ outputs = tuple(
303
+ rvq.get_output_from_indices(chunk_indices)
304
+ for rvq, chunk_indices in zip(self.rvqs, indices)
305
+ )
306
+ return torch.cat(outputs, dim=self.split_dim)
307
+
308
+ def forward(self, x, return_all_codes=False):
309
+ shape, split_dim, device = x.shape, self.split_dim, x.device
310
+ assert shape[split_dim] == self.dim
311
+
312
+ # split the feature dimension into groups
313
+
314
+ x = x.chunk(self.groups, dim=split_dim)
315
+
316
+ forward_kwargs = dict(
317
+ return_all_codes=return_all_codes,
318
+ rand_quantize_dropout_fixed_seed=(
319
+ get_maybe_sync_seed(device) if self.training else None
320
+ ),
321
+ )
322
+
323
+ # invoke residual vq on each group
324
+
325
+ out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
326
+ out = tuple(zip(*out))
327
+
328
+ # otherwise, get all the zipped outputs and combine them
329
+
330
+ quantized, all_indices, *maybe_all_codes = out
331
+
332
+ quantized = torch.cat(quantized, dim=split_dim)
333
+ all_indices = torch.stack(all_indices)
334
+
335
+ ret = (quantized, all_indices, *maybe_all_codes)
336
+ return ret
337
+
338
+
339
+ if __name__ == "__main__":
340
+ model = ResidualFSQ(
341
+ levels=[4, 4, 4, 4, 4, 4],
342
+ num_quantizers=1,
343
+ dim=30,
344
+ is_channel_first=True,
345
+ quantize_dropout=False,
346
+ )
347
+ x = torch.randn(2, 30, 10)
348
+ quantize, embed_ind = model(x)
349
+
350
+ emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))
351
+
352
+ print(quantize == emb_from_ind.transpose(1, 2))
353
+
354
+ print("quantize shape", quantize.shape)
355
+ print("embed_ind", embed_ind)
sparktts/modules/speaker/ecapa_tdnn.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Zhengyang Chen ([email protected])
2
+ # 2022 Hongji Wang ([email protected])
3
+ # 2023 Bing Han ([email protected])
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """ This implementation is adapted from github repo:
18
+ https://github.com/lawlict/ECAPA-TDNN.
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+ import sparktts.modules.speaker.pooling_layers as pooling_layers
26
+
27
+
28
+ class Res2Conv1dReluBn(nn.Module):
29
+ """
30
+ in_channels == out_channels == channels
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ channels,
36
+ kernel_size=1,
37
+ stride=1,
38
+ padding=0,
39
+ dilation=1,
40
+ bias=True,
41
+ scale=4,
42
+ ):
43
+ super().__init__()
44
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
45
+ self.scale = scale
46
+ self.width = channels // scale
47
+ self.nums = scale if scale == 1 else scale - 1
48
+
49
+ self.convs = []
50
+ self.bns = []
51
+ for i in range(self.nums):
52
+ self.convs.append(
53
+ nn.Conv1d(
54
+ self.width,
55
+ self.width,
56
+ kernel_size,
57
+ stride,
58
+ padding,
59
+ dilation,
60
+ bias=bias,
61
+ )
62
+ )
63
+ self.bns.append(nn.BatchNorm1d(self.width))
64
+ self.convs = nn.ModuleList(self.convs)
65
+ self.bns = nn.ModuleList(self.bns)
66
+
67
+ def forward(self, x):
68
+ out = []
69
+ spx = torch.split(x, self.width, 1)
70
+ sp = spx[0]
71
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
72
+ # Order: conv -> relu -> bn
73
+ if i >= 1:
74
+ sp = sp + spx[i]
75
+ sp = conv(sp)
76
+ sp = bn(F.relu(sp))
77
+ out.append(sp)
78
+ if self.scale != 1:
79
+ out.append(spx[self.nums])
80
+ out = torch.cat(out, dim=1)
81
+
82
+ return out
83
+
84
+
85
+ """ Conv1d + BatchNorm1d + ReLU
86
+ """
87
+
88
+
89
+ class Conv1dReluBn(nn.Module):
90
+
91
+ def __init__(
92
+ self,
93
+ in_channels,
94
+ out_channels,
95
+ kernel_size=1,
96
+ stride=1,
97
+ padding=0,
98
+ dilation=1,
99
+ bias=True,
100
+ ):
101
+ super().__init__()
102
+ self.conv = nn.Conv1d(
103
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
104
+ )
105
+ self.bn = nn.BatchNorm1d(out_channels)
106
+
107
+ def forward(self, x):
108
+ return self.bn(F.relu(self.conv(x)))
109
+
110
+
111
+ """ The SE connection of 1D case.
112
+ """
113
+
114
+
115
+ class SE_Connect(nn.Module):
116
+
117
+ def __init__(self, channels, se_bottleneck_dim=128):
118
+ super().__init__()
119
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
120
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
121
+
122
+ def forward(self, x):
123
+ out = x.mean(dim=2)
124
+ out = F.relu(self.linear1(out))
125
+ out = torch.sigmoid(self.linear2(out))
126
+ out = x * out.unsqueeze(2)
127
+
128
+ return out
129
+
130
+
131
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
132
+ """
133
+
134
+
135
+ class SE_Res2Block(nn.Module):
136
+
137
+ def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
138
+ super().__init__()
139
+ self.se_res2block = nn.Sequential(
140
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
141
+ Res2Conv1dReluBn(
142
+ channels, kernel_size, stride, padding, dilation, scale=scale
143
+ ),
144
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
145
+ SE_Connect(channels),
146
+ )
147
+
148
+ def forward(self, x):
149
+ return x + self.se_res2block(x)
150
+
151
+
152
+ class ECAPA_TDNN(nn.Module):
153
+
154
+ def __init__(
155
+ self,
156
+ channels=512,
157
+ feat_dim=80,
158
+ embed_dim=192,
159
+ pooling_func="ASTP",
160
+ global_context_att=False,
161
+ emb_bn=False,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2)
166
+ self.layer2 = SE_Res2Block(
167
+ channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8
168
+ )
169
+ self.layer3 = SE_Res2Block(
170
+ channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8
171
+ )
172
+ self.layer4 = SE_Res2Block(
173
+ channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8
174
+ )
175
+
176
+ cat_channels = channels * 3
177
+ out_channels = 512 * 3
178
+ self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
179
+ self.pool = getattr(pooling_layers, pooling_func)(
180
+ in_dim=out_channels, global_context_att=global_context_att
181
+ )
182
+ self.pool_out_dim = self.pool.get_out_dim()
183
+ self.bn = nn.BatchNorm1d(self.pool_out_dim)
184
+ self.linear = nn.Linear(self.pool_out_dim, embed_dim)
185
+ self.emb_bn = emb_bn
186
+ if emb_bn: # better in SSL for SV
187
+ self.bn2 = nn.BatchNorm1d(embed_dim)
188
+ else:
189
+ self.bn2 = nn.Identity()
190
+
191
+ def forward(self, x, return_latent=False):
192
+ x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
193
+
194
+ out1 = self.layer1(x)
195
+ out2 = self.layer2(out1)
196
+ out3 = self.layer3(out2)
197
+ out4 = self.layer4(out3)
198
+
199
+ out = torch.cat([out2, out3, out4], dim=1)
200
+ latent = F.relu(self.conv(out))
201
+ out = self.bn(self.pool(latent))
202
+ out = self.linear(out)
203
+ if self.emb_bn:
204
+ out = self.bn2(out)
205
+
206
+ if return_latent:
207
+ return out, latent
208
+ return out
209
+
210
+
211
+ def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
212
+ return ECAPA_TDNN(
213
+ channels=1024,
214
+ feat_dim=feat_dim,
215
+ embed_dim=embed_dim,
216
+ pooling_func=pooling_func,
217
+ emb_bn=emb_bn,
218
+ )
219
+
220
+
221
+ def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
222
+ return ECAPA_TDNN(
223
+ channels=1024,
224
+ feat_dim=feat_dim,
225
+ embed_dim=embed_dim,
226
+ pooling_func=pooling_func,
227
+ global_context_att=True,
228
+ emb_bn=emb_bn,
229
+ )
230
+
231
+
232
+ def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
233
+ return ECAPA_TDNN(
234
+ channels=512,
235
+ feat_dim=feat_dim,
236
+ embed_dim=embed_dim,
237
+ pooling_func=pooling_func,
238
+ emb_bn=emb_bn,
239
+ )
240
+
241
+
242
+ def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False):
243
+ return ECAPA_TDNN(
244
+ channels=512,
245
+ feat_dim=feat_dim,
246
+ embed_dim=embed_dim,
247
+ pooling_func=pooling_func,
248
+ global_context_att=True,
249
+ emb_bn=emb_bn,
250
+ )
251
+
252
+
253
+ if __name__ == "__main__":
254
+ x = torch.zeros(1, 200, 100)
255
+ model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP")
256
+ model.eval()
257
+ out, latent = model(x, True)
258
+ print(out.shape)
259
+ print(latent.shape)
260
+
261
+ num_params = sum(param.numel() for param in model.parameters())
262
+ print("{} M".format(num_params / 1e6))
263
+
264
+ # from thop import profile
265
+ # x_np = torch.randn(1, 200, 80)
266
+ # flops, params = profile(model, inputs=(x_np, ))
267
+ # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
sparktts/modules/speaker/perceiver_encoder.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
17
+
18
+ from collections import namedtuple
19
+ from functools import wraps
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from einops import rearrange, repeat
24
+ from einops.layers.torch import Rearrange
25
+ from packaging import version
26
+ from torch import einsum, nn
27
+
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+
33
+ def once(fn):
34
+ called = False
35
+
36
+ @wraps(fn)
37
+ def inner(x):
38
+ nonlocal called
39
+ if called:
40
+ return
41
+ called = True
42
+ return fn(x)
43
+
44
+ return inner
45
+
46
+
47
+ print_once = once(print)
48
+
49
+ # main class
50
+
51
+
52
+ class Attend(nn.Module):
53
+ def __init__(self, dropout=0.0, causal=False, use_flash=False):
54
+ super().__init__()
55
+ self.dropout = dropout
56
+ self.attn_dropout = nn.Dropout(dropout)
57
+
58
+ self.causal = causal
59
+ self.register_buffer("mask", None, persistent=False)
60
+
61
+ self.use_flash = use_flash
62
+ assert not (
63
+ use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
64
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
65
+
66
+ # determine efficient attention configs for cuda and cpu
67
+ self.config = namedtuple(
68
+ "EfficientAttentionConfig",
69
+ ["enable_flash", "enable_math", "enable_mem_efficient"],
70
+ )
71
+ self.cpu_config = self.config(True, True, True)
72
+ self.cuda_config = None
73
+
74
+ if not torch.cuda.is_available() or not use_flash:
75
+ return
76
+
77
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
78
+
79
+ if device_properties.major == 8 and device_properties.minor == 0:
80
+ print_once(
81
+ "A100 GPU detected, using flash attention if input tensor is on cuda"
82
+ )
83
+ self.cuda_config = self.config(True, False, False)
84
+ else:
85
+ print_once(
86
+ "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
87
+ )
88
+ self.cuda_config = self.config(False, True, True)
89
+
90
+ def get_mask(self, n, device):
91
+ if exists(self.mask) and self.mask.shape[-1] >= n:
92
+ return self.mask[:n, :n]
93
+
94
+ mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
95
+ self.register_buffer("mask", mask, persistent=False)
96
+ return mask
97
+
98
+ def flash_attn(self, q, k, v, mask=None):
99
+ _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
100
+
101
+ # Recommended for multi-query single-key-value attention by Tri Dao
102
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
103
+
104
+ if k.ndim == 3:
105
+ k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
106
+
107
+ if v.ndim == 3:
108
+ v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
109
+
110
+ # Check if mask exists and expand to compatible shape
111
+ # The mask is B L, so it would have to be expanded to B H N L
112
+
113
+ if exists(mask):
114
+ mask = rearrange(mask, "b j -> b 1 1 j")
115
+ mask = mask.expand(-1, heads, q_len, -1)
116
+
117
+ # Check if there is a compatible device for flash attention
118
+
119
+ config = self.cuda_config if is_cuda else self.cpu_config
120
+
121
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
122
+
123
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=mask,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ is_causal=self.causal,
131
+ )
132
+
133
+ return out
134
+
135
+ def forward(self, q, k, v, mask=None):
136
+ """
137
+ einstein notation
138
+ b - batch
139
+ h - heads
140
+ n, i, j - sequence length (base sequence length, source, target)
141
+ d - feature dimension
142
+ """
143
+
144
+ n, device = q.shape[-2], q.device
145
+
146
+ scale = q.shape[-1] ** -0.5
147
+
148
+ if self.use_flash:
149
+ return self.flash_attn(q, k, v, mask=mask)
150
+
151
+ kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
152
+
153
+ # similarity
154
+
155
+ sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
156
+
157
+ # key padding mask
158
+
159
+ if exists(mask):
160
+ mask = rearrange(mask, "b j -> b 1 1 j")
161
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
162
+
163
+ # causal mask
164
+
165
+ if self.causal:
166
+ causal_mask = self.get_mask(n, device)
167
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
168
+
169
+ # attention
170
+
171
+ attn = sim.softmax(dim=-1)
172
+ attn = self.attn_dropout(attn)
173
+
174
+ # aggregate values
175
+
176
+ out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
177
+
178
+ return out
179
+
180
+
181
+ def Sequential(*mods):
182
+ return nn.Sequential(*filter(exists, mods))
183
+
184
+
185
+ def exists(x):
186
+ return x is not None
187
+
188
+
189
+ def default(val, d):
190
+ if exists(val):
191
+ return val
192
+ return d() if callable(d) else d
193
+
194
+
195
+ class RMSNorm(nn.Module):
196
+ def __init__(self, dim, scale=True, dim_cond=None):
197
+ super().__init__()
198
+ self.cond = exists(dim_cond)
199
+ self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
200
+
201
+ self.scale = dim**0.5
202
+ self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
203
+
204
+ def forward(self, x, cond=None):
205
+ gamma = default(self.gamma, 1)
206
+ out = F.normalize(x, dim=-1) * self.scale * gamma
207
+
208
+ if not self.cond:
209
+ return out
210
+
211
+ assert exists(cond)
212
+ gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
213
+ gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
214
+ return out * gamma + beta
215
+
216
+
217
+ class CausalConv1d(nn.Conv1d):
218
+ def __init__(self, *args, **kwargs):
219
+ super().__init__(*args, **kwargs)
220
+ (kernel_size,) = self.kernel_size
221
+ (dilation,) = self.dilation
222
+ (stride,) = self.stride
223
+
224
+ assert stride == 1
225
+ self.causal_padding = dilation * (kernel_size - 1)
226
+
227
+ def forward(self, x):
228
+ causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
229
+ return super().forward(causal_padded_x)
230
+
231
+
232
+ class GEGLU(nn.Module):
233
+ def forward(self, x):
234
+ x, gate = x.chunk(2, dim=-1)
235
+ return F.gelu(gate) * x
236
+
237
+
238
+ def FeedForward(dim, mult=4, causal_conv=False):
239
+ dim_inner = int(dim * mult * 2 / 3)
240
+
241
+ conv = None
242
+ if causal_conv:
243
+ conv = nn.Sequential(
244
+ Rearrange("b n d -> b d n"),
245
+ CausalConv1d(dim_inner, dim_inner, 3),
246
+ Rearrange("b d n -> b n d"),
247
+ )
248
+
249
+ return Sequential(
250
+ nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)
251
+ )
252
+
253
+
254
+ class Attention(nn.Module):
255
+ def __init__(
256
+ self,
257
+ dim,
258
+ *,
259
+ dim_context=None,
260
+ causal=False,
261
+ dim_head=64,
262
+ heads=8,
263
+ dropout=0.0,
264
+ use_flash=False,
265
+ cross_attn_include_queries=False,
266
+ ):
267
+ super().__init__()
268
+ self.scale = dim_head**-0.5
269
+ self.heads = heads
270
+ self.cross_attn_include_queries = cross_attn_include_queries
271
+
272
+ dim_inner = dim_head * heads
273
+ dim_context = default(dim_context, dim)
274
+
275
+ self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
276
+ self.to_q = nn.Linear(dim, dim_inner, bias=False)
277
+ self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
278
+ self.to_out = nn.Linear(dim_inner, dim, bias=False)
279
+
280
+ def forward(self, x, context=None, mask=None):
281
+ h, has_context = self.heads, exists(context)
282
+
283
+ context = default(context, x)
284
+
285
+ if has_context and self.cross_attn_include_queries:
286
+ context = torch.cat((x, context), dim=-2)
287
+
288
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
289
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
290
+
291
+ out = self.attend(q, k, v, mask=mask)
292
+
293
+ out = rearrange(out, "b h n d -> b n (h d)")
294
+ return self.to_out(out)
295
+
296
+
297
+ class PerceiverResampler(nn.Module):
298
+ def __init__(
299
+ self,
300
+ *,
301
+ dim,
302
+ depth=2,
303
+ dim_context=None,
304
+ num_latents=32,
305
+ dim_head=64,
306
+ heads=8,
307
+ ff_mult=4,
308
+ use_flash_attn=False,
309
+ ):
310
+ super().__init__()
311
+ dim_context = default(dim_context, dim)
312
+
313
+ self.proj_context = (
314
+ nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
315
+ )
316
+
317
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
318
+ nn.init.normal_(self.latents, std=0.02)
319
+
320
+ self.layers = nn.ModuleList([])
321
+ for _ in range(depth):
322
+ self.layers.append(
323
+ nn.ModuleList(
324
+ [
325
+ Attention(
326
+ dim=dim,
327
+ dim_head=dim_head,
328
+ heads=heads,
329
+ use_flash=use_flash_attn,
330
+ cross_attn_include_queries=True,
331
+ ),
332
+ FeedForward(dim=dim, mult=ff_mult),
333
+ ]
334
+ )
335
+ )
336
+
337
+ self.norm = RMSNorm(dim)
338
+
339
+ def forward(self, x, mask=None):
340
+ batch = x.shape[0]
341
+
342
+ x = self.proj_context(x)
343
+
344
+ latents = repeat(self.latents, "n d -> b n d", b=batch)
345
+
346
+ for attn, ff in self.layers:
347
+ latents = attn(latents, x, mask=mask) + latents
348
+ latents = ff(latents) + latents
349
+
350
+ return self.norm(latents)
351
+
352
+
353
+ if __name__ == "__main__":
354
+ model = PerceiverResampler(dim=256, dim_context=80)
355
+ x = torch.randn(8, 200, 80)
356
+ out = model(x)
357
+ print(out.shape) # [8, 32, 80]
358
+
359
+ num_params = sum(param.numel() for param in model.parameters())
360
+ print("{} M".format(num_params / 1e6))
sparktts/modules/speaker/pooling_layers.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Shuai Wang ([email protected])
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Pooling functions to aggregate frame-level deep features
16
+ into segment-level speaker embeddings
17
+
18
+ High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
19
+ even though we remove the mean statistic, on Voxceleb.
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ class TAP(nn.Module):
28
+ """
29
+ Temporal average pooling, only first-order mean is considered
30
+ """
31
+
32
+ def __init__(self, in_dim=0, **kwargs):
33
+ super(TAP, self).__init__()
34
+ self.in_dim = in_dim
35
+
36
+ def forward(self, x):
37
+ pooling_mean = x.mean(dim=-1)
38
+ # To be compatable with 2D input
39
+ pooling_mean = pooling_mean.flatten(start_dim=1)
40
+ return pooling_mean
41
+
42
+ def get_out_dim(self):
43
+ self.out_dim = self.in_dim
44
+ return self.out_dim
45
+
46
+
47
+ class TSDP(nn.Module):
48
+ """
49
+ Temporal standard deviation pooling, only second-order std is considered
50
+ """
51
+
52
+ def __init__(self, in_dim=0, **kwargs):
53
+ super(TSDP, self).__init__()
54
+ self.in_dim = in_dim
55
+
56
+ def forward(self, x):
57
+ # The last dimension is the temporal axis
58
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
59
+ pooling_std = pooling_std.flatten(start_dim=1)
60
+ return pooling_std
61
+
62
+ def get_out_dim(self):
63
+ self.out_dim = self.in_dim
64
+ return self.out_dim
65
+
66
+
67
+ class TSTP(nn.Module):
68
+ """
69
+ Temporal statistics pooling, concatenate mean and std, which is used in
70
+ x-vector
71
+ Comment: simple concatenation can not make full use of both statistics
72
+ """
73
+
74
+ def __init__(self, in_dim=0, **kwargs):
75
+ super(TSTP, self).__init__()
76
+ self.in_dim = in_dim
77
+
78
+ def forward(self, x):
79
+ # The last dimension is the temporal axis
80
+ pooling_mean = x.mean(dim=-1)
81
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
82
+ pooling_mean = pooling_mean.flatten(start_dim=1)
83
+ pooling_std = pooling_std.flatten(start_dim=1)
84
+ stats = torch.cat((pooling_mean, pooling_std), 1)
85
+ return stats
86
+
87
+ def get_out_dim(self):
88
+ self.out_dim = self.in_dim * 2
89
+ return self.out_dim
90
+
91
+
92
+ class ASTP(nn.Module):
93
+ """ Attentive statistics pooling: Channel- and context-dependent
94
+ statistics pooling, first used in ECAPA_TDNN.
95
+ """
96
+
97
+ def __init__(self,
98
+ in_dim,
99
+ bottleneck_dim=128,
100
+ global_context_att=False,
101
+ **kwargs):
102
+ super(ASTP, self).__init__()
103
+ self.in_dim = in_dim
104
+ self.global_context_att = global_context_att
105
+
106
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
107
+ # need to transpose inputs.
108
+ if global_context_att:
109
+ self.linear1 = nn.Conv1d(
110
+ in_dim * 3, bottleneck_dim,
111
+ kernel_size=1) # equals W and b in the paper
112
+ else:
113
+ self.linear1 = nn.Conv1d(
114
+ in_dim, bottleneck_dim,
115
+ kernel_size=1) # equals W and b in the paper
116
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
117
+ kernel_size=1) # equals V and k in the paper
118
+
119
+ def forward(self, x):
120
+ """
121
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
122
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
123
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
124
+ """
125
+ if len(x.shape) == 4:
126
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
127
+ assert len(x.shape) == 3
128
+
129
+ if self.global_context_att:
130
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
131
+ context_std = torch.sqrt(
132
+ torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x)
133
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
134
+ else:
135
+ x_in = x
136
+
137
+ # DON'T use ReLU here! ReLU may be hard to converge.
138
+ alpha = torch.tanh(
139
+ self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
140
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
141
+ mean = torch.sum(alpha * x, dim=2)
142
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
143
+ std = torch.sqrt(var.clamp(min=1e-7))
144
+ return torch.cat([mean, std], dim=1)
145
+
146
+ def get_out_dim(self):
147
+ self.out_dim = 2 * self.in_dim
148
+ return self.out_dim
149
+
150
+
151
+ class MHASTP(torch.nn.Module):
152
+ """ Multi head attentive statistics pooling
153
+ Reference:
154
+ Self Multi-Head Attention for Speaker Recognition
155
+ https://arxiv.org/pdf/1906.09890.pdf
156
+ """
157
+
158
+ def __init__(self,
159
+ in_dim,
160
+ layer_num=2,
161
+ head_num=2,
162
+ d_s=1,
163
+ bottleneck_dim=64,
164
+ **kwargs):
165
+ super(MHASTP, self).__init__()
166
+ assert (in_dim % head_num
167
+ ) == 0 # make sure that head num can be divided by input_dim
168
+ self.in_dim = in_dim
169
+ self.head_num = head_num
170
+ d_model = int(in_dim / head_num)
171
+ channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
172
+ if d_s > 1:
173
+ d_s = d_model
174
+ else:
175
+ d_s = 1
176
+ self.d_s = d_s
177
+ channel_dims[0], channel_dims[-1] = d_model, d_s
178
+ heads_att_trans = []
179
+ for i in range(self.head_num):
180
+ att_trans = nn.Sequential()
181
+ for i in range(layer_num - 1):
182
+ att_trans.add_module(
183
+ 'att_' + str(i),
184
+ nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
185
+ att_trans.add_module('tanh' + str(i), nn.Tanh())
186
+ att_trans.add_module(
187
+ 'att_' + str(layer_num - 1),
188
+ nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
189
+ 1, 1))
190
+ heads_att_trans.append(att_trans)
191
+ self.heads_att_trans = nn.ModuleList(heads_att_trans)
192
+
193
+ def forward(self, input):
194
+ """
195
+ input: a 3-dimensional tensor in xvector architecture
196
+ or a 4-dimensional tensor in resnet architecture
197
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
198
+ """
199
+ if len(input.shape) == 4: # B x F x T
200
+ input = input.reshape(input.shape[0],
201
+ input.shape[1] * input.shape[2],
202
+ input.shape[3])
203
+ assert len(input.shape) == 3
204
+ bs, f_dim, t_dim = input.shape
205
+ chunks = torch.chunk(input, self.head_num, 1)
206
+ # split
207
+ chunks_out = []
208
+ # for i in range(self.head_num):
209
+ # att_score = self.heads_att_trans[i](chunks[i])
210
+ for i, layer in enumerate(self.heads_att_trans):
211
+ att_score = layer(chunks[i])
212
+ alpha = F.softmax(att_score, dim=-1)
213
+ mean = torch.sum(alpha * chunks[i], dim=2)
214
+ var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
215
+ std = torch.sqrt(var.clamp(min=1e-7))
216
+ chunks_out.append(torch.cat((mean, std), dim=1))
217
+ out = torch.cat(chunks_out, dim=1)
218
+ return out
219
+
220
+ def get_out_dim(self):
221
+ self.out_dim = 2 * self.in_dim
222
+ return self.out_dim
223
+
224
+
225
+ class MQMHASTP(torch.nn.Module):
226
+ """ An attentive pooling
227
+ Reference:
228
+ multi query multi head attentive statistics pooling
229
+ https://arxiv.org/pdf/2110.05042.pdf
230
+ Args:
231
+ in_dim: the feature dimension of input
232
+ layer_num: the number of layer in the pooling layer
233
+ query_num: the number of querys
234
+ head_num: the number of heads
235
+ bottleneck_dim: the bottleneck dimension
236
+
237
+ SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
238
+ https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
239
+ MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
240
+ https://arxiv.org/pdf/1906.09890.pdf
241
+ AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
242
+ https://arxiv.org/pdf/1803.10963.pdf
243
+ VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
244
+ http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
245
+ """
246
+
247
+ def __init__(self,
248
+ in_dim,
249
+ layer_num=2,
250
+ query_num=2,
251
+ head_num=8,
252
+ d_s=2,
253
+ bottleneck_dim=64,
254
+ **kwargs):
255
+ super(MQMHASTP, self).__init__()
256
+ self.n_query = nn.ModuleList([
257
+ MHASTP(in_dim,
258
+ layer_num=layer_num,
259
+ head_num=head_num,
260
+ d_s=d_s,
261
+ bottleneck_dim=bottleneck_dim) for i in range(query_num)
262
+ ])
263
+ self.query_num = query_num
264
+ self.in_dim = in_dim
265
+
266
+ def forward(self, input):
267
+ """
268
+ input: a 3-dimensional tensor in xvector architecture
269
+ or a 4-dimensional tensor in resnet architecture
270
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
271
+ """
272
+ if len(input.shape) == 4: # B x F x T
273
+ input = input.reshape(input.shape[0],
274
+ input.shape[1] * input.shape[2],
275
+ input.shape[3])
276
+ assert len(input.shape) == 3
277
+ res = []
278
+ for i, layer in enumerate(self.n_query):
279
+ res.append(layer(input))
280
+ out = torch.cat(res, dim=-1)
281
+ return out
282
+
283
+ def get_out_dim(self):
284
+ self.out_dim = self.in_dim * 2 * self.query_num
285
+ return self.out_dim
286
+
287
+
288
+ if __name__ == '__main__':
289
+ data = torch.randn(16, 512, 10, 35)
290
+ # model = StatisticsPooling()
291
+ model = MQMHASTP(512 * 10)
292
+ model = MHASTP(512 * 10)
293
+ model = MQMHASTP(512 * 10, context=False)
294
+ print(model)
295
+
296
+ out = model(data)
297
+ print(out.shape)
298
+ print(model.get_out_dim())
sparktts/modules/speaker/speaker_encoder.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from typing import List, Tuple
20
+ from sparktts.modules.fsq.residual_fsq import ResidualFSQ
21
+ from sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512
22
+ from sparktts.modules.speaker.perceiver_encoder import PerceiverResampler
23
+
24
+ """
25
+ x-vector + d-vector
26
+ """
27
+
28
+
29
+ class SpeakerEncoder(nn.Module):
30
+ """
31
+
32
+ Args:
33
+ input_dim (int): acoustic feature dimension
34
+ out_dim (int): output dimension of x-vector and d-vector
35
+ latent_dim (int): latent dimension before quantization
36
+ token_num (int): sequence length of speaker tokens
37
+ fsq_levels (List[int]): number of levels for each quantizer
38
+ fsq_num_quantizers (int): number of quantizers
39
+
40
+ Return:
41
+ speaker_embs: (B, T2, out_dim)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ input_dim: int = 100,
47
+ out_dim: int = 512,
48
+ latent_dim: int = 128,
49
+ token_num: int = 32,
50
+ fsq_levels: List[int] = [4, 4, 4, 4, 4, 4],
51
+ fsq_num_quantizers: int = 1,
52
+ ):
53
+ super(SpeakerEncoder, self).__init__()
54
+
55
+ self.speaker_encoder = ECAPA_TDNN_GLOB_c512(
56
+ feat_dim=input_dim, embed_dim=out_dim
57
+ )
58
+ self.perceiver_sampler = PerceiverResampler(
59
+ dim=latent_dim, dim_context=512 * 3, num_latents=token_num
60
+ )
61
+ self.quantizer = ResidualFSQ(
62
+ levels=fsq_levels,
63
+ num_quantizers=fsq_num_quantizers,
64
+ dim=latent_dim,
65
+ is_channel_first=True,
66
+ quantize_dropout=False,
67
+ )
68
+
69
+ self.project = nn.Linear(latent_dim * token_num, out_dim)
70
+
71
+ def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
72
+ zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2))
73
+ return zq.transpose(1, 2)
74
+
75
+ def get_indices(self, mels: torch.Tensor) -> torch.Tensor:
76
+ mels = mels.transpose(1, 2)
77
+ x = self.perceiver_sampler(mels).transpose(1, 2)
78
+ zq, indices = self.quantizer(x)
79
+ return indices
80
+
81
+ def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ """
83
+ Args:
84
+ mels: (B, D_mel, T1)
85
+
86
+ Return:
87
+ x_vector: (B, out_dim)
88
+ d_vector: (B, out_dim)
89
+ """
90
+ # mels = mels.transpose(1,2)
91
+
92
+ x_vector, features = self.speaker_encoder(mels, True)
93
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
94
+ zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim)
95
+ x = zq.reshape(zq.shape[0], -1)
96
+ d_vector = self.project(x)
97
+
98
+ return x_vector, d_vector
99
+
100
+ def tokenize(self, mels: torch.Tensor) -> torch.Tensor:
101
+ """tokenize the input mel spectrogram"""
102
+ _, features = self.speaker_encoder(mels, True)
103
+ x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2)
104
+ zq, indices = self.quantizer(x)
105
+ return indices
106
+
107
+ def detokenize(self, indices: torch.Tensor) -> torch.Tensor:
108
+ """detokenize the input indices to d-vector"""
109
+ zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2)
110
+ x = zq.reshape(zq.shape[0], -1)
111
+ d_vector = self.project(x)
112
+ return d_vector
113
+
114
+ if __name__ == "__main__":
115
+ model = SpeakerEncoder(
116
+ input_dim=100,
117
+ latent_dim=128,
118
+ token_num=32,
119
+ fsq_levels=[4, 4, 4, 4, 4, 4],
120
+ fsq_num_quantizers=1,
121
+ )
122
+ mel = torch.randn(8, 200, 100)
123
+ x_vector, d_vector = model(mel)
124
+ print("x-vector shape", x_vector.shape)
125
+ print("d-vector shape", d_vector.shape)
126
+
127
+ indices = model.tokenize(mel)
128
+ print("indices shape", indices.shape)
129
+ d_vector_post = model.detokenize(indices)
130
+ print("d-vector shape", d_vector_post.shape)
131
+ if d_vector_post.all() == d_vector.all():
132
+ print("d-vector post and d-vector are the same")
133
+ else:
134
+ print("d-vector post and d-vector are different")
135
+ num_params = sum(param.numel() for param in model.parameters())
136
+ print("{} M".format(num_params / 1e6))
sparktts/modules/vq/factorized_vector_quantize.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
17
+
18
+
19
+ from typing import Any, Dict
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from einops import rearrange
25
+ from torch.nn.utils import weight_norm
26
+
27
+
28
+ def WNConv1d(*args, **kwargs):
29
+ return weight_norm(nn.Conv1d(*args, **kwargs))
30
+
31
+
32
+ def ema_inplace(moving_avg, new, decay):
33
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
34
+
35
+
36
+ class FactorizedVectorQuantize(nn.Module):
37
+ def __init__(
38
+ self,
39
+ input_dim: int,
40
+ codebook_size: int,
41
+ codebook_dim: int,
42
+ commitment: float,
43
+ codebook_loss_weight: float = 1.0,
44
+ decay: float = 0.99,
45
+ threshold_ema_dead_code: float = 2,
46
+ momentum: float = 0.99,
47
+ **kwargs,
48
+ ):
49
+ super().__init__()
50
+ self.input_dim = input_dim
51
+ self.codebook_size = codebook_size
52
+ self.codebook_dim = codebook_dim
53
+ self.commitment = commitment
54
+ self.codebook_loss_weight = codebook_loss_weight
55
+ self.decay = decay
56
+ self.threshold_ema_dead_code = threshold_ema_dead_code
57
+ self.momentum = momentum
58
+
59
+ if input_dim != self.codebook_dim:
60
+ self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
61
+ self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
62
+
63
+ else:
64
+ self.in_project = nn.Identity()
65
+ self.out_project = nn.Identity()
66
+
67
+ self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
68
+ self.register_buffer("cluster_size", torch.zeros(self.codebook_size))
69
+
70
+ def forward(self, z: torch.Tensor) -> Dict[str, Any]:
71
+ """Quantized the input tensor using a fixed codebook and returns
72
+ the corresponding codebook vectors
73
+
74
+ Parameters
75
+ ----------
76
+ z : Tensor[B x D x T]
77
+
78
+ Returns
79
+ -------
80
+ Tensor[B x D x T]
81
+ Quantized continuous representation of input
82
+ Tensor[1]
83
+ Commitment loss to train encoder to predict vectors closer to codebook
84
+ entries
85
+ Tensor[1]
86
+ Codebook loss to update the codebook
87
+ Tensor[B x T]
88
+ Codebook indices (quantized discrete representation of input)
89
+ Tensor[B x D x T]
90
+ Projected latents (continuous representation of input before quantization)
91
+ """
92
+ # transpose since we use linear
93
+
94
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
95
+ z_e = self.in_project(z)
96
+ z_q, indices, dists = self.decode_latents(z_e)
97
+
98
+ # statistic the usage of codes
99
+ embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype)
100
+ avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0)
101
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
102
+
103
+ active_num = (embed_onehot.sum(0).sum(0) > 0).sum()
104
+ if self.training:
105
+ # We do the expiry of code at that point as buffers are in sync
106
+ # and all the workers will take the same decision.
107
+ ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay)
108
+ active_num = sum(self.cluster_size > self.threshold_ema_dead_code)
109
+
110
+ if self.training:
111
+ commit_loss = (
112
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
113
+ * self.commitment
114
+ )
115
+
116
+ codebook_loss = (
117
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
118
+ * self.codebook_loss_weight
119
+ )
120
+
121
+ else:
122
+ commit_loss = torch.zeros(0, device=z.device)
123
+ codebook_loss = torch.zeros(0, device=z.device)
124
+
125
+ z_q = (
126
+ z_e + (z_q - z_e).detach()
127
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
128
+
129
+ z_q = self.out_project(z_q)
130
+
131
+ vq_loss = (commit_loss + codebook_loss).mean()
132
+
133
+ return {
134
+ "z_q": z_q,
135
+ "indices": indices,
136
+ "dists": dists,
137
+ "vq_loss": vq_loss,
138
+ "perplexity": perplexity,
139
+ "active_num": active_num.float(),
140
+ }
141
+
142
+ def vq2emb(self, vq, out_proj=True):
143
+ emb = self.embed_code(vq)
144
+ if out_proj:
145
+ emb = self.out_project(emb)
146
+ return emb
147
+
148
+ def tokenize(self, z: torch.Tensor) -> torch.Tensor:
149
+ """tokenize the input tensor"""
150
+ z_e = self.in_project(z)
151
+ _, indices, _ = self.decode_latents(z_e)
152
+ return indices
153
+
154
+ def detokenize(self, indices):
155
+ """detokenize the input indices"""
156
+ z_q = self.decode_code(indices)
157
+ z_q = self.out_project(z_q)
158
+ return z_q
159
+
160
+ def get_emb(self):
161
+ return self.codebook.weight
162
+
163
+ def embed_code(self, embed_id):
164
+ return F.embedding(embed_id, self.codebook.weight)
165
+
166
+ def decode_code(self, embed_id):
167
+ return self.embed_code(embed_id).transpose(1, 2)
168
+
169
+ def decode_latents(self, latents):
170
+ encodings = rearrange(latents, "b d t -> (b t) d")
171
+ codebook = self.codebook.weight
172
+
173
+ # L2 normalize encodings and codebook
174
+ encodings = F.normalize(encodings)
175
+ codebook = F.normalize(codebook)
176
+
177
+ # Compute euclidean distance between encodings and codebook,
178
+ # with L2 normalization, the distance is equal to cosine distance
179
+ dist = (
180
+ encodings.pow(2).sum(1, keepdim=True)
181
+ - 2 * encodings @ codebook.t()
182
+ + codebook.pow(2).sum(1, keepdim=True).t()
183
+ )
184
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
185
+ z_q = self.decode_code(indices)
186
+
187
+ return z_q, indices, dist
sparktts/utils/__init__.py ADDED
File without changes
sparktts/utils/audio.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Description:
17
+ This script contains a collection of functions designed to handle various
18
+ audio processing.
19
+ """
20
+
21
+ import random
22
+ import soxr
23
+ import soundfile
24
+ import torch
25
+ import torchaudio
26
+ import numpy as np
27
+
28
+ from pathlib import Path
29
+ from typing import Tuple
30
+ from numpy.lib.stride_tricks import sliding_window_view
31
+
32
+
33
+ def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
34
+ """
35
+ Normalize the volume of an audio signal.
36
+
37
+ Parameters:
38
+ audio (numpy array): Input audio signal array.
39
+ coeff (float): Target coefficient for normalization, default is 0.2.
40
+
41
+ Returns:
42
+ numpy array: The volume-normalized audio signal.
43
+ """
44
+ # Sort the absolute values of the audio signal
45
+ temp = np.sort(np.abs(audio))
46
+
47
+ # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
48
+ if temp[-1] < 0.1:
49
+ scaling_factor = max(
50
+ temp[-1], 1e-3
51
+ ) # Prevent division by zero with a small constant
52
+ audio = audio / scaling_factor * 0.1
53
+
54
+ # Filter out values less than 0.01 from temp
55
+ temp = temp[temp > 0.01]
56
+ L = temp.shape[0] # Length of the filtered array
57
+
58
+ # If there are fewer than or equal to 10 significant values, return the audio without further processing
59
+ if L <= 10:
60
+ return audio
61
+
62
+ # Compute the average of the top 10% to 1% of values in temp
63
+ volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
64
+
65
+ # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
66
+ audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
67
+
68
+ # Ensure the maximum absolute value in the audio does not exceed 1
69
+ max_value = np.max(np.abs(audio))
70
+ if max_value > 1:
71
+ audio = audio / max_value
72
+
73
+ return audio
74
+
75
+
76
+ def load_audio(
77
+ adfile: Path,
78
+ sampling_rate: int = None,
79
+ length: int = None,
80
+ volume_normalize: bool = False,
81
+ segment_duration: int = None,
82
+ ) -> np.ndarray:
83
+ r"""Load audio file with target sampling rate and lsength
84
+
85
+ Args:
86
+ adfile (Path): path to audio file.
87
+ sampling_rate (int, optional): target sampling rate. Defaults to None.
88
+ length (int, optional): target audio length. Defaults to None.
89
+ volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
90
+ segment_duration (int): random select a segment with duration of {segment_duration}s.
91
+ Defualt to None which means the whole audio will be used.
92
+
93
+ Returns:
94
+ audio (np.ndarray): audio
95
+ """
96
+
97
+ audio, sr = soundfile.read(adfile)
98
+ if len(audio.shape) > 1:
99
+ audio = audio[:, 0]
100
+
101
+ if sampling_rate is not None and sr != sampling_rate:
102
+ audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
103
+ sr = sampling_rate
104
+
105
+ if segment_duration is not None:
106
+ seg_length = int(sr * segment_duration)
107
+ audio = random_select_audio_segment(audio, seg_length)
108
+
109
+ # Audio volume normalize
110
+ if volume_normalize:
111
+ audio = audio_volume_normalize(audio)
112
+ # check the audio length
113
+ if length is not None:
114
+ assert abs(audio.shape[0] - length) < 1000
115
+ if audio.shape[0] > length:
116
+ audio = audio[:length]
117
+ else:
118
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
119
+ return audio
120
+
121
+
122
+ def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
123
+ """get an audio segment given the length
124
+
125
+ Args:
126
+ audio (np.ndarray):
127
+ length (int): audio length = sampling_rate * duration
128
+ """
129
+ if audio.shape[0] < length:
130
+ audio = np.pad(audio, (0, int(length - audio.shape[0])))
131
+ start_index = random.randint(0, audio.shape[0] - length)
132
+ end_index = int(start_index + length)
133
+
134
+ return audio[start_index:end_index]
135
+
136
+
137
+ def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq):
138
+ """apply highpass fileter to audio
139
+
140
+ Args:
141
+ audio (np.ndarray):
142
+ sample_rate (ind):
143
+ highpass_cutoff_freq (int):
144
+ """
145
+
146
+ audio = torchaudio.functional.highpass_biquad(
147
+ torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq
148
+ )
149
+ return audio.numpy()
150
+
151
+
152
+ def stft(
153
+ x: torch.Tensor,
154
+ fft_size: int,
155
+ hop_size: int,
156
+ win_length: int,
157
+ window: str,
158
+ use_complex: bool = False,
159
+ ) -> torch.Tensor:
160
+ """Perform STFT and convert to magnitude spectrogram.
161
+ Args:
162
+ x (Tensor): Input signal tensor (B, T).
163
+ fft_size (int): FFT size.
164
+ hop_size (int): Hop size.
165
+ win_length (int): Window length.
166
+ window (str): Window function type.
167
+ Returns:
168
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
169
+ """
170
+
171
+ x_stft = torch.stft(
172
+ x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
173
+ )
174
+
175
+ # clamp is needed to avoid nan or inf
176
+ if not use_complex:
177
+ return torch.sqrt(
178
+ torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
179
+ ).transpose(2, 1)
180
+ else:
181
+ res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
182
+ res = res.transpose(2, 3) # [B, 2, T, F]
183
+ return res
184
+
185
+
186
+ def detect_speech_boundaries(
187
+ wav: np.ndarray,
188
+ sample_rate: int,
189
+ window_duration: float = 0.1,
190
+ energy_threshold: float = 0.01,
191
+ margin_factor: int = 2
192
+ ) -> Tuple[int, int]:
193
+ """Detect the start and end points of speech in an audio signal using RMS energy.
194
+
195
+ Args:
196
+ wav: Input audio signal array with values in [-1, 1]
197
+ sample_rate: Audio sample rate in Hz
198
+ window_duration: Duration of detection window in seconds
199
+ energy_threshold: RMS energy threshold for speech detection
200
+ margin_factor: Factor to determine extra margin around detected boundaries
201
+
202
+ Returns:
203
+ tuple: (start_index, end_index) of speech segment
204
+
205
+ Raises:
206
+ ValueError: If the audio contains only silence
207
+ """
208
+ window_size = int(window_duration * sample_rate)
209
+ margin = margin_factor * window_size
210
+ step_size = window_size // 10
211
+
212
+ # Create sliding windows using stride tricks to avoid loops
213
+ windows = sliding_window_view(wav, window_size)[::step_size]
214
+
215
+ # Calculate RMS energy for each window
216
+ energy = np.sqrt(np.mean(windows ** 2, axis=1))
217
+ speech_mask = energy >= energy_threshold
218
+
219
+ if not np.any(speech_mask):
220
+ raise ValueError("No speech detected in audio (only silence)")
221
+
222
+ start = max(0, np.argmax(speech_mask) * step_size - margin)
223
+ end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin)
224
+
225
+ return start, end
226
+
227
+
228
+ def remove_silence_on_both_ends(
229
+ wav: np.ndarray,
230
+ sample_rate: int,
231
+ window_duration: float = 0.1,
232
+ volume_threshold: float = 0.01
233
+ ) -> np.ndarray:
234
+ """Remove silence from both ends of an audio signal.
235
+
236
+ Args:
237
+ wav: Input audio signal array
238
+ sample_rate: Audio sample rate in Hz
239
+ window_duration: Duration of detection window in seconds
240
+ volume_threshold: Amplitude threshold for silence detection
241
+
242
+ Returns:
243
+ np.ndarray: Audio signal with silence removed from both ends
244
+
245
+ Raises:
246
+ ValueError: If the audio contains only silence
247
+ """
248
+ start, end = detect_speech_boundaries(
249
+ wav,
250
+ sample_rate,
251
+ window_duration,
252
+ volume_threshold
253
+ )
254
+ return wav[start:end]
255
+
256
+
257
+
258
+ def hertz_to_mel(pitch: float) -> float:
259
+ """
260
+ Converts a frequency from the Hertz scale to the Mel scale.
261
+
262
+ Parameters:
263
+ - pitch: float or ndarray
264
+ Frequency in Hertz.
265
+
266
+ Returns:
267
+ - mel: float or ndarray
268
+ Frequency in Mel scale.
269
+ """
270
+ mel = 2595 * np.log10(1 + pitch / 700)
271
+ return mel
sparktts/utils/file.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Description:
17
+ This script contains a collection of functions designed to handle various
18
+ file reading and writing operations. It provides utilities to read from files,
19
+ write data to files, and perform file manipulation tasks.
20
+ """
21
+
22
+
23
+ import os
24
+ import json
25
+ import json
26
+ import csv
27
+
28
+ from tqdm import tqdm
29
+ from typing import List, Dict, Any, Set, Union
30
+ from pathlib import Path
31
+ from omegaconf import OmegaConf, DictConfig
32
+
33
+
34
+ def resolve_symbolic_link(symbolic_link_path: Path) -> Path:
35
+ """
36
+ Resolves the absolute path of a symbolic link.
37
+
38
+ Args:
39
+ symbolic_link_path (Path): The path to the symbolic link.
40
+
41
+ Returns:
42
+ Path: The absolute path that the symbolic link points to.
43
+ """
44
+
45
+ link_directory = os.path.dirname(symbolic_link_path)
46
+ target_path_relative = os.readlink(symbolic_link_path)
47
+ return os.path.join(link_directory, target_path_relative)
48
+
49
+
50
+ def write_jsonl(metadata: List[dict], file_path: Path) -> None:
51
+ """Writes a list of dictionaries to a JSONL file.
52
+
53
+ Args:
54
+ metadata : List[dict]
55
+ A list of dictionaries, each representing a piece of meta.
56
+ file_path : Path
57
+ The file path to save the JSONL file
58
+
59
+ This function writes each dictionary in the list to a new line in the specified file.
60
+ """
61
+ with open(file_path, "w", encoding="utf-8") as f:
62
+ for meta in tqdm(metadata, desc="writing jsonl"):
63
+ # Convert dictionary to JSON string and write it to the file with a newline
64
+ json_str = json.dumps(meta, ensure_ascii=False) + "\n"
65
+ f.write(json_str)
66
+ print(f"jsonl saved to {file_path}")
67
+
68
+
69
+ def read_jsonl(file_path: Path) -> List[dict]:
70
+ """
71
+ Reads a JSONL file and returns a list of dictionaries.
72
+
73
+ Args:
74
+ file_path : Path
75
+ The path to the JSONL file to be read.
76
+
77
+ Returns:
78
+ List[dict]
79
+ A list of dictionaries parsed from each line of the JSONL file.
80
+ """
81
+ metadata = []
82
+ # Open the file for reading
83
+ with open(file_path, "r", encoding="utf-8") as f:
84
+ # Split the file into lines
85
+ lines = f.read().splitlines()
86
+ # Process each line
87
+ for line in lines:
88
+ # Convert JSON string back to dictionary and append to list
89
+ meta = json.loads(line)
90
+ metadata.append(meta)
91
+ # Return the list of metadata
92
+ return metadata
93
+
94
+ def read_json_as_jsonl(file_path: Path) -> List[dict]:
95
+ metadata = []
96
+ with open(file_path, 'r', encoding='utf-8') as infile:
97
+ data = json.load(infile)
98
+ for k in sorted(data.keys()):
99
+ meta = {'index': k}
100
+ meta.update(data[k])
101
+ metadata.append(meta)
102
+ return metadata
103
+
104
+
105
+
106
+ def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]:
107
+ processed_meta = {}
108
+ for k, v in meta.items():
109
+ if isinstance(v, str):
110
+ processed_meta[k] = v.encode("utf-8").decode("unicode_escape")
111
+ else:
112
+ processed_meta[k] = v
113
+ return processed_meta
114
+
115
+
116
+ def load_config(config_path: Path) -> DictConfig:
117
+ """Loads a configuration file and optionally merges it with a base configuration.
118
+
119
+ Args:
120
+ config_path (Path): Path to the configuration file.
121
+ """
122
+ # Load the initial configuration from the given path
123
+ config = OmegaConf.load(config_path)
124
+
125
+ # Check if there is a base configuration specified and merge if necessary
126
+ if config.get("base_config", None) is not None:
127
+ base_config = OmegaConf.load(config["base_config"])
128
+ config = OmegaConf.merge(base_config, config)
129
+
130
+ return config
131
+
132
+
133
+
134
+ def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None:
135
+ """
136
+ Converts a JSONL file to a CSV file.
137
+
138
+ This function reads a JSONL file, determines all unique keys present in the file,
139
+ and writes the data to a CSV file with columns for all these keys.
140
+ """
141
+
142
+ all_keys = set()
143
+ data_rows = []
144
+
145
+ # Read the JSONL file once to extract keys and collect data
146
+ with open(jsonl_file_path, 'r') as file:
147
+ for line in file:
148
+ data = json.loads(line.strip())
149
+ data_rows.append(data)
150
+ all_keys.update(data.keys())
151
+
152
+ # Convert the set of keys to a sorted list for consistent column order
153
+ sorted_keys = sorted(all_keys)
154
+
155
+ # Write the data to a CSV file
156
+ with open(csv_file_path, 'w', newline='') as csvfile:
157
+ writer = csv.DictWriter(csvfile, fieldnames=sorted_keys)
158
+
159
+ # Write the header row
160
+ writer.writeheader()
161
+
162
+ # Write each row of data
163
+ for data in data_rows:
164
+ writer.writerow(data)
165
+
166
+ print(f"CSV file has been created at {csv_file_path}")
167
+
168
+
169
+ def save_metadata(data, filename, headers=None):
170
+ """
171
+ Save metadata to a file.
172
+
173
+ Args:
174
+ data (list of dict): Metadata to be saved.
175
+ filename (str): Name of the file to save the metadata.
176
+ headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided.
177
+ """
178
+ # Set headers to keys from the first dictionary in data if not explicitly provided
179
+ if headers is None:
180
+ headers = list(data[0].keys())
181
+
182
+ with open(filename, "w", encoding="utf-8") as file:
183
+ # Write the headers to the file
184
+ file.write("|".join(headers) + "\n")
185
+ for entry in data:
186
+ # Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors
187
+ formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers]
188
+ # Write the formatted values to the file
189
+ file.write("|".join(formatted_values) + "\n")
190
+
191
+
192
+ def read_metadata(filename, headers=None):
193
+ """
194
+ Read metadata from a file.
195
+
196
+ Args:
197
+ filename (str): The file from which to read the metadata.
198
+
199
+ Returns:
200
+ list of dict: The metadata read from the file.
201
+ list of str: The headers used in the file.
202
+ """
203
+ with open(filename, "r", encoding="utf-8") as file:
204
+ lines = file.readlines()
205
+
206
+ data = []
207
+ # Set headers from the first line of the file if not provided
208
+ if headers is None:
209
+ headers = lines[0].strip().split("|")
210
+ lines = lines[1:]
211
+
212
+ for line in lines:
213
+ line = line.strip()
214
+ # Skip empty lines
215
+ if not line:
216
+ continue
217
+ # Split the line by '|' and pair with headers to form a dictionary
218
+ entry_data = dict(zip(headers, line.split("|")))
219
+ data.append(entry_data)
220
+
221
+ return data, headers
sparktts/utils/parse_options.sh ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4
+ # Arnab Ghoshal, Karel Vesely
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
+ # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14
+ # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15
+ # MERCHANTABLITY OR NON-INFRINGEMENT.
16
+ # See the Apache 2 License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ # Parse command-line options.
21
+ # To be sourced by another script (as in ". parse_options.sh").
22
+ # Option format is: --option-name arg
23
+ # and shell variable "option_name" gets set to value "arg."
24
+ # The exception is --help, which takes no arguments, but prints the
25
+ # $help_message variable (if defined).
26
+
27
+
28
+ ###
29
+ ### The --config file options have lower priority to command line
30
+ ### options, so we need to import them first...
31
+ ###
32
+
33
+ # Now import all the configs specified by command-line, in left-to-right order
34
+ # for ((argpos=1; argpos<$#; argpos++)); do
35
+ # if [ "${!argpos}" == "--config" ]; then
36
+ # argpos_plus1=$((argpos+1))
37
+ # config=${!argpos_plus1}
38
+ # [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39
+ # . $config # source the config file.
40
+ # fi
41
+ # done
42
+
43
+
44
+ ###
45
+ ### No we process the command line options
46
+ ###
47
+ while true; do
48
+ [ -z "${1:-}" ] && break; # break if there are no arguments
49
+ case "$1" in
50
+ # If the enclosing script is called with --help option, print the help
51
+ # message and exit. Scripts should put help messages in $help_message
52
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53
+ else printf "$help_message\n" 1>&2 ; fi;
54
+ exit 0 ;;
55
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56
+ exit 1 ;;
57
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
58
+ # then work out the variable name as $name, which will equal "foo_bar".
59
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60
+ # Next we test whether the variable in question is undefned-- if so it's
61
+ # an invalid option and we die. Note: $0 evaluates to the name of the
62
+ # enclosing script.
63
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64
+ # is undefined. We then have to wrap this test inside "eval" because
65
+ # foo_bar is itself inside a variable ($name).
66
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67
+
68
+ oldval="`eval echo \\$$name`";
69
+ # Work out whether we seem to be expecting a Boolean argument.
70
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71
+ was_bool=true;
72
+ else
73
+ was_bool=false;
74
+ fi
75
+
76
+ # Set the variable to the right value-- the escaped quotes make it work if
77
+ # the option had spaces, like --cmd "queue.pl -sync y"
78
+ eval $name=\"$2\";
79
+
80
+ # Check that Boolean-valued arguments are really Boolean.
81
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83
+ exit 1;
84
+ fi
85
+ shift 2;
86
+ ;;
87
+ *) break;
88
+ esac
89
+ done
90
+
91
+
92
+ # Check for an empty argument to the --cmd option, which can easily occur as a
93
+ # result of scripting errors.
94
+ [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95
+
96
+
97
+ true; # so this script returns exit code 0.
sparktts/utils/token_parser.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK_TOKEN_MAP = {
2
+ "vc": "<|task_vc|>",
3
+ "tts": "<|task_tts|>",
4
+ "asr": "<|task_asr|>",
5
+ "s2s": "<|task_s2s|>",
6
+ "t2s": "<|task_t2s|>",
7
+ "understand": "<|task_understand|>",
8
+ "caption": "<|task_cap|>",
9
+ "controllable_tts": "<|task_controllable_tts|>",
10
+ "prompt_tts": "<|task_prompt_tts|>",
11
+ "speech_edit": "<|task_edit|>",
12
+ }
13
+
14
+ LEVELS_MAP = {
15
+ "very_low": 0,
16
+ "low": 1,
17
+ "moderate": 2,
18
+ "high": 3,
19
+ "very_high": 4,
20
+ }
21
+
22
+ LEVELS_MAP_UI = {
23
+ 1: 'very_low',
24
+ 2: 'low',
25
+ 3: 'moderate',
26
+ 4: 'high',
27
+ 5: 'very_high'
28
+ }
29
+
30
+ GENDER_MAP = {
31
+ "female": 0,
32
+ "male": 1,
33
+ }
34
+
35
+ AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
36
+
37
+ EMO_MAP = {
38
+ "UNKNOWN": 0,
39
+ "NEUTRAL": 1,
40
+ "ANGRY": 2,
41
+ "HAPPY": 3,
42
+ "SAD": 4,
43
+ "FEARFUL": 5,
44
+ "DISGUSTED": 6,
45
+ "SURPRISED": 7,
46
+ "SARCASTIC": 8,
47
+ "EXCITED": 9,
48
+ "SLEEPY": 10,
49
+ "CONFUSED": 11,
50
+ "EMPHASIS": 12,
51
+ "LAUGHING": 13,
52
+ "SINGING": 14,
53
+ "WORRIED": 15,
54
+ "WHISPER": 16,
55
+ "ANXIOUS": 17,
56
+ "NO-AGREEMENT": 18,
57
+ "APOLOGETIC": 19,
58
+ "CONCERNED": 20,
59
+ "ENUNCIATED": 21,
60
+ "ASSERTIVE": 22,
61
+ "ENCOURAGING": 23,
62
+ "CONTEMPT": 24,
63
+ }
64
+
65
+
66
+ class TokenParser:
67
+ """Turn label to special token"""
68
+
69
+ def __init__(self):
70
+ pass
71
+
72
+ """Parse the attributes of a person."""
73
+
74
+ def __init__(self):
75
+ pass
76
+
77
+ @staticmethod
78
+ def age(age: str) -> str:
79
+ """Turn age token."""
80
+ age_id = AGE_MAP[age]
81
+ return f"<|age_{age_id}|>"
82
+
83
+ @staticmethod
84
+ def gender(gender: str) -> str:
85
+ """Turn gender token."""
86
+ gender_id = GENDER_MAP[gender]
87
+ return f"<|gender_{gender_id}|>"
88
+
89
+ @staticmethod
90
+ def mel_value(mel: int):
91
+ """Turn special token of mel scale pitch."""
92
+ mel = max(0, int(mel))
93
+ mel = min(1000, int(mel))
94
+ return f"<|pitch_value_{mel}|>"
95
+
96
+ @staticmethod
97
+ def mel_level(level: str):
98
+ """Turn special token of mel level."""
99
+ level_tag = LEVELS_MAP[level]
100
+ return f"<|pitch_label_{level_tag}|>"
101
+
102
+ @staticmethod
103
+ def pitch_var_value(pitch_std: int):
104
+ """Turn special token of pitch_std value."""
105
+ assert isinstance(pitch_std, int)
106
+ pitch_std = max(0, int(pitch_std))
107
+ pitch_std = min(10, int(pitch_std))
108
+ return f"<|pitch_var_value_{pitch_std}|>"
109
+
110
+ @staticmethod
111
+ def pitch_var_level(level: str):
112
+ """Turn special token of pitch std level."""
113
+ level_tag = LEVELS_MAP[level]
114
+ return f"<|pitch_var_label_{level_tag}|>"
115
+
116
+ @staticmethod
117
+ def loudness_value(loudness: int):
118
+ """Turn special toak of loudness value [0, 30]"""
119
+ assert loudness >= 0
120
+ loudness = max(0, int(loudness))
121
+ loudness = min(30, int(loudness))
122
+ return f"<|loudness_value_{loudness}|>"
123
+
124
+ @staticmethod
125
+ def loudness_level(level: str):
126
+ """Turn special token of loudness level."""
127
+ level_tag = LEVELS_MAP[level]
128
+ return f"<|loudness_label_{level_tag}|>"
129
+
130
+ @staticmethod
131
+ def speed_value(speed: int):
132
+ """Turn special token of speed value."""
133
+ speed = max(0, int(speed))
134
+ speed = min(10, int(speed))
135
+ return f"<|speed_value_{speed}|>"
136
+
137
+ @staticmethod
138
+ def speed_level(level: str):
139
+ """Turn special token of speed level."""
140
+ level_tag = LEVELS_MAP[level]
141
+ return f"<|speed_label_{level_tag}|>"
142
+
143
+ @staticmethod
144
+ def task(task: str) -> str:
145
+ """Turn special token of task."""
146
+ assert task in TASK_TOKEN_MAP.keys()
147
+
148
+ return TASK_TOKEN_MAP[task]
149
+
150
+ @staticmethod
151
+ def emotion(emotion: str):
152
+ emo_id = EMO_MAP[emotion]
153
+
154
+ return f"<|emotion_{emo_id}|>"
155
+
156
+
157
+ # test
158
+ if __name__ == "__main__":
159
+ from transformers import AutoTokenizer
160
+
161
+ tokenizer = AutoTokenizer.from_pretrained(
162
+ "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer"
163
+ )
164
+
165
+ tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"]
166
+ ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"]
167
+ genders = ["female", "female", "female", "male", "male"]
168
+ mels = [100, 200, 300, 400, 500]
169
+ mel_levels = ["very_low", "low", "moderate", "high", "very_high"]
170
+ loudnesses = [1, 10, 23, 19, 30]
171
+ loudness_levels = ["very_low", "low", "moderate", "high", "very_high"]
172
+ emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"]
173
+
174
+ for i in range(5):
175
+ task = TokenParser.task(tasks[i])
176
+ age = TokenParser.age(ages[i])
177
+ gender = TokenParser.gender(genders[i])
178
+ mel = TokenParser.mel_value(mels[i])
179
+ mel_level = TokenParser.mel_level(mel_levels[i])
180
+ loudness = TokenParser.loudness_value(loudnesses[i])
181
+ loudness_level = TokenParser.loudness_level(loudness_levels[i])
182
+ emotion = TokenParser.emotion(emotions[i])
183
+ inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion]
184
+ inputs = "".join(inputs)
185
+ ids = tokenizer.encode(inputs, add_special_tokens=False)
186
+ print(ids)
187
+ print("decode", tokenizer.decode(ids))