Steveeeeeeen (HF Staff) committed
Commit ad798d2 · 1 Parent(s): 1ccbf40

Files changed (42):
  1. app.py +147 -0
  2. models/codec/amphion_codec/__pycache__/vocos.cpython-310.pyc +0 -0
  3. models/codec/amphion_codec/codec.py +422 -0
  4. models/codec/amphion_codec/loss.py +401 -0
  5. models/codec/amphion_codec/quantize/__init__.py +6 -0
  6. models/codec/amphion_codec/quantize/__pycache__/__init__.cpython-310.pyc +0 -0
  7. models/codec/amphion_codec/quantize/__pycache__/bsq.cpython-310.pyc +0 -0
  8. models/codec/amphion_codec/quantize/__pycache__/factorized_vector_quantize.cpython-310.pyc +0 -0
  9. models/codec/amphion_codec/quantize/__pycache__/lookup_free_quantize.cpython-310.pyc +0 -0
  10. models/codec/amphion_codec/quantize/__pycache__/residual_vq.cpython-310.pyc +0 -0
  11. models/codec/amphion_codec/quantize/__pycache__/vector_quantize.cpython-310.pyc +0 -0
  12. models/codec/amphion_codec/quantize/bsq.py +373 -0
  13. models/codec/amphion_codec/quantize/factorized_vector_quantize.py +145 -0
  14. models/codec/amphion_codec/quantize/lookup_free_quantize.py +72 -0
  15. models/codec/amphion_codec/quantize/residual_vq.py +172 -0
  16. models/codec/amphion_codec/quantize/vector_quantize.py +396 -0
  17. models/codec/amphion_codec/vocos.py +909 -0
  18. models/codec/melvqgan/__pycache__/melspec.cpython-310.pyc +0 -0
  19. models/codec/melvqgan/melspec.py +153 -0
  20. models/tts/llm_tts/__pycache__/chat_template.cpython-310.pyc +0 -0
  21. models/tts/llm_tts/__pycache__/inference_llm_tts.cpython-310.pyc +0 -0
  22. models/tts/llm_tts/__pycache__/inference_mgm_tts.cpython-310.pyc +0 -0
  23. models/tts/llm_tts/__pycache__/llama_nar_prefix.cpython-310.pyc +0 -0
  24. models/tts/llm_tts/__pycache__/mgm.cpython-310.pyc +0 -0
  25. models/tts/llm_tts/chat_template.py +96 -0
  26. models/tts/llm_tts/inference_llm_tts.py +265 -0
  27. models/tts/llm_tts/inference_mgm_tts.py +338 -0
  28. models/tts/llm_tts/llama_nar_prefix.py +457 -0
  29. models/tts/llm_tts/mgm.py +385 -0
  30. models/tts/tadicodec/__pycache__/infer_utils.cpython-310.pyc +0 -0
  31. models/tts/tadicodec/__pycache__/inference_tadicodec.cpython-310.pyc +0 -0
  32. models/tts/tadicodec/__pycache__/llama_nar_prefix.cpython-310.pyc +0 -0
  33. models/tts/tadicodec/__pycache__/modeling_tadicodec.cpython-310.pyc +0 -0
  34. models/tts/tadicodec/infer_utils.py +24 -0
  35. models/tts/tadicodec/inference_tadicodec.py +279 -0
  36. models/tts/tadicodec/llama_nar_prefix.py +572 -0
  37. models/tts/tadicodec/modeling_tadicodec.py +641 -0
  38. requirements.txt +9 -0
  39. utils/__pycache__/hparam.cpython-310.pyc +0 -0
  40. utils/__pycache__/util.cpython-310.pyc +0 -0
  41. utils/hparam.py +659 -0
  42. utils/util.py +687 -0
app.py ADDED
@@ -0,0 +1,147 @@
+ import gradio as gr
+ import torch
+ import soundfile as sf
+ import tempfile
+ import os
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ try:
+     from models.tts.llm_tts.inference_llm_tts import TTSInferencePipeline
+     MODEL_AVAILABLE = True
+ except ImportError:
+     print("Warning: TaDiCodec models not found. Running in demo mode.")
+     MODEL_AVAILABLE = False
+
+ class TaDiCodecTTSDemo:
+     def __init__(self):
+         self.pipeline = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.load_model()
+
+     def load_model(self):
+         try:
+             if MODEL_AVAILABLE:
+                 print("Loading TaDiCodec-TTS-AR-Qwen2.5-0.5B model...")
+                 self.pipeline = TTSInferencePipeline.from_pretrained(
+                     tadicodec_path="amphion/TaDiCodec",
+                     llm_path="amphion/TaDiCodec-TTS-AR-Qwen2.5-0.5B",
+                     device=self.device,
+                 )
+                 print("Model loaded successfully!")
+             else:
+                 print("Running in demo mode - model files not available")
+                 self.pipeline = None
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             self.pipeline = None
+
+     def synthesize_speech(self, text, reference_audio=None, reference_text=""):
+         """
+         Synthesize speech from text using TaDiCodec TTS
+         """
+         if not text.strip():
+             return None, "Please enter some text to synthesize."
+
+         try:
+             if self.pipeline is not None:
+                 # Use actual TaDiCodec inference
+                 if reference_audio and reference_text.strip():
+                     audio = self.pipeline(
+                         text=text,
+                         prompt_text=reference_text,
+                         prompt_speech_path=reference_audio,
+                     )
+                 else:
+                     audio = self.pipeline(text=text)
+
+                 # Save to temporary file
+                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                     sf.write(tmp_file.name, audio, 24000)
+                     return tmp_file.name, "Speech synthesized successfully!"
+             else:
+                 # Fallback demo mode - return None to indicate no audio generated
+                 return None, "Demo mode - TaDiCodec model not available. Please install the required models."
+
+         except Exception as e:
+             return None, f"Error during synthesis: {str(e)}"
+
+ # Initialize the demo
+ demo_instance = TaDiCodecTTSDemo()
+
+ def tts_interface(text, reference_audio, reference_text):
+     """Interface function for Gradio"""
+     audio_path, message = demo_instance.synthesize_speech(text, reference_audio, reference_text)
+     return audio_path, message
+
+ # Create Gradio interface
+ with gr.Blocks(title="TaDiCodec-TTS Demo", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # TaDiCodec-TTS-AR-Qwen2.5-0.5B Demo
+
+     This is a demo of the TaDiCodec Text-to-Speech model with Qwen2.5-0.5B backbone.
+
+     **Features:**
+     - Voice cloning with reference audio
+     - Code-switching support (e.g., mixing English and Chinese)
+     - Extremely low bitrate (0.0875 kbps)
+     - High-quality speech generation
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             text_input = gr.Textbox(
+                 label="Text to Synthesize",
+                 placeholder="Enter the text you want to convert to speech...",
+                 lines=3,
+                 value="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit."
+             )
+
+             reference_text = gr.Textbox(
+                 label="Reference Text",
+                 placeholder="Text corresponding to the reference audio for voice cloning...",
+                 lines=2,
+                 value="In short, we embarked on a mission to make America great again, for all Americans."
+             )
+
+             reference_audio = gr.Audio(
+                 label="Reference Audio",
+                 type="filepath",
+             )
+
+             synthesize_btn = gr.Button("Synthesize Speech", variant="primary")
+
+         with gr.Column():
+             output_audio = gr.Audio(
+                 label="Generated Speech",
+                 type="filepath"
+             )
+
+             status_message = gr.Textbox(
+                 label="Status",
+                 interactive=False
+             )
+
+     # Example inputs
+     gr.Markdown("### Example Inputs")
+     examples = [
+         ["Yes, usually people choose to face life with more positive emotions, after all, happy times are always yearning. However, sometimes slowing down and experiencing the details of life can bring deeper joy and satisfaction. What do you think?", "Jittery Jack's jam jars jiggled jauntily, jolting Jack's jumbled jelly-filled jars joyously.", "sample/tongueTwisters_en_018.wav"],
+         ["You think you can just waltz in here and cause chaos? Well, I've got news for you. This time, there's no escaping the consequences. So, one by one, step forward, and let's see who's bold enough to face the music. It's time for a little dose of reality—prepare to be dealt with!","Get in line trouble makers, and I will take care of you.", "sample/en_013.wav"],
+     ]
+
+     gr.Examples(
+         examples=examples,
+         inputs=[text_input, reference_text, reference_audio],
+         outputs=[output_audio, status_message],
+         fn=tts_interface
+     )
+
+     # Connect the interface
+     synthesize_btn.click(
+         fn=tts_interface,
+         inputs=[text_input, reference_audio, reference_text],
+         outputs=[output_audio, status_message]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True, server_name="0.0.0.0")
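
For reference, the pipeline that app.py wraps can also be called directly outside Gradio. The sketch below is illustrative and not part of the commit; it only reuses calls visible in app.py (TTSInferencePipeline.from_pretrained, the text/prompt_text/prompt_speech_path keyword arguments, and the 24 kHz rate used when writing output), and the output path is a placeholder.

# Minimal usage sketch (assumptions noted above; paths are placeholders)
import torch
import soundfile as sf
from models.tts.llm_tts.inference_llm_tts import TTSInferencePipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = TTSInferencePipeline.from_pretrained(
    tadicodec_path="amphion/TaDiCodec",
    llm_path="amphion/TaDiCodec-TTS-AR-Qwen2.5-0.5B",
    device=device,
)

# Zero-shot voice cloning: pass a reference clip and its transcript.
audio = pipeline(
    text="Hello from the TaDiCodec demo.",
    prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
    prompt_speech_path="sample/en_013.wav",  # reference clip from the demo examples
)
sf.write("output.wav", audio, 24000)  # app.py writes generated audio at 24 kHz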
models/codec/amphion_codec/__pycache__/vocos.cpython-310.pyc ADDED
Binary file (26.3 kB).
 
models/codec/amphion_codec/codec.py ADDED
@@ -0,0 +1,422 @@
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+ from models.codec.amphion_codec.quantize import (
+     ResidualVQ,
+     VectorQuantize,
+     FactorizedVectorQuantize,
+     LookupFreeQuantize,
+ )
+
+ from models.codec.amphion_codec.vocos import Vocos
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ # Scripting this brings model speed up 1.4x
+ @torch.jit.script
+ def snake(x, alpha):
+     shape = x.shape
+     x = x.reshape(shape[0], shape[1], -1)
+     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+     x = x.reshape(shape)
+     return x
+
+
+ class Snake1d(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+     def forward(self, x):
+         return snake(x, self.alpha)
+
+
+ def init_weights(m):
+     if isinstance(m, nn.Conv1d):
+         nn.init.trunc_normal_(m.weight, std=0.02)
+         nn.init.constant_(m.bias, 0)
+     if isinstance(m, nn.Linear):
+         nn.init.trunc_normal_(m.weight, std=0.02)
+         nn.init.constant_(m.bias, 0)
+
+
+ class ResidualUnit(nn.Module):
+     def __init__(self, dim: int = 16, dilation: int = 1):
+         super().__init__()
+         pad = ((7 - 1) * dilation) // 2
+         self.block = nn.Sequential(
+             Snake1d(dim),
+             WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+             Snake1d(dim),
+             WNConv1d(dim, dim, kernel_size=1),
+         )
+
+     def forward(self, x):
+         y = self.block(x)
+         pad = (x.shape[-1] - y.shape[-1]) // 2
+         if pad > 0:
+             x = x[..., pad:-pad]
+         return x + y
+
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, dim: int = 16, stride: int = 1):
+         super().__init__()
+         self.block = nn.Sequential(
+             ResidualUnit(dim // 2, dilation=1),
+             ResidualUnit(dim // 2, dilation=3),
+             ResidualUnit(dim // 2, dilation=9),
+             Snake1d(dim // 2),
+             WNConv1d(
+                 dim // 2,
+                 dim,
+                 kernel_size=2 * stride,
+                 stride=stride,
+                 padding=math.ceil(stride / 2),
+             ),
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class CodecEncoder(nn.Module):
+     def __init__(
+         self,
+         d_model: int = 64,
+         up_ratios: list = [4, 5, 5, 6],
+         out_channels: int = 256,
+         use_tanh: bool = False,
+         cfg=None,
+     ):
+         super().__init__()
+
+         d_model = cfg.d_model if cfg is not None else d_model
+         up_ratios = cfg.up_ratios if cfg is not None else up_ratios
+         out_channels = cfg.out_channels if cfg is not None else out_channels
+         use_tanh = cfg.use_tanh if cfg is not None else use_tanh
+
+         # Create first convolution
+         self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+         # Create EncoderBlocks that double channels as they downsample by `stride`
+         for stride in up_ratios:
+             d_model *= 2
+             self.block += [EncoderBlock(d_model, stride=stride)]
+
+         # Create last convolution
+         self.block += [
+             Snake1d(d_model),
+             WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
+         ]
+
+         if use_tanh:
+             self.block += [nn.Tanh()]
+
+         # Wrap block into nn.Sequential
+         self.block = nn.Sequential(*self.block)
+         self.enc_dim = d_model
+
+         self.reset_parameters()
+
+     def forward(self, x):
+         return self.block(x)
+
+     def reset_parameters(self):
+         self.apply(init_weights)
+
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+         super().__init__()
+         self.block = nn.Sequential(
+             Snake1d(input_dim),
+             WNConvTranspose1d(
+                 input_dim,
+                 output_dim,
+                 kernel_size=2 * stride,
+                 stride=stride,
+                 padding=stride // 2 + stride % 2,
+                 output_padding=stride % 2,
+             ),
+             ResidualUnit(output_dim, dilation=1),
+             ResidualUnit(output_dim, dilation=3),
+             ResidualUnit(output_dim, dilation=9),
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class CodecDecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels: int = 256,
+         upsample_initial_channel: int = 1536,
+         up_ratios: list = [5, 5, 4, 2],
+         num_quantizers: int = 8,
+         codebook_size: int = 1024,
+         codebook_dim: int = 256,
+         quantizer_type: str = "vq",
+         quantizer_dropout: float = 0.5,
+         commitment: float = 0.25,
+         codebook_loss_weight: float = 1.0,
+         use_l2_normlize: bool = False,
+         codebook_type: str = "euclidean",
+         kmeans_init: bool = False,
+         kmeans_iters: int = 10,
+         decay: float = 0.8,
+         eps: float = 1e-5,
+         threshold_ema_dead_code: int = 2,
+         weight_init: bool = False,
+         use_vocos: bool = False,
+         vocos_dim: int = 384,
+         vocos_intermediate_dim: int = 1152,
+         vocos_num_layers: int = 8,
+         n_fft: int = 800,
+         hop_size: int = 200,
+         padding: str = "same",
+         cfg=None,
+     ):
+         super().__init__()
+
+         in_channels = (
+             cfg.in_channels
+             if cfg is not None and hasattr(cfg, "in_channels")
+             else in_channels
+         )
+         upsample_initial_channel = (
+             cfg.upsample_initial_channel
+             if cfg is not None and hasattr(cfg, "upsample_initial_channel")
+             else upsample_initial_channel
+         )
+         up_ratios = (
+             cfg.up_ratios
+             if cfg is not None and hasattr(cfg, "up_ratios")
+             else up_ratios
+         )
+         num_quantizers = (
+             cfg.num_quantizers
+             if cfg is not None and hasattr(cfg, "num_quantizers")
+             else num_quantizers
+         )
+         codebook_size = (
+             cfg.codebook_size
+             if cfg is not None and hasattr(cfg, "codebook_size")
+             else codebook_size
+         )
+         codebook_dim = (
+             cfg.codebook_dim
+             if cfg is not None and hasattr(cfg, "codebook_dim")
+             else codebook_dim
+         )
+         quantizer_type = (
+             cfg.quantizer_type
+             if cfg is not None and hasattr(cfg, "quantizer_type")
+             else quantizer_type
+         )
+         quantizer_dropout = (
+             cfg.quantizer_dropout
+             if cfg is not None and hasattr(cfg, "quantizer_dropout")
+             else quantizer_dropout
+         )
+         commitment = (
+             cfg.commitment
+             if cfg is not None and hasattr(cfg, "commitment")
+             else commitment
+         )
+         codebook_loss_weight = (
+             cfg.codebook_loss_weight
+             if cfg is not None and hasattr(cfg, "codebook_loss_weight")
+             else codebook_loss_weight
+         )
+         use_l2_normlize = (
+             cfg.use_l2_normlize
+             if cfg is not None and hasattr(cfg, "use_l2_normlize")
+             else use_l2_normlize
+         )
+         codebook_type = (
+             cfg.codebook_type
+             if cfg is not None and hasattr(cfg, "codebook_type")
+             else codebook_type
+         )
+         kmeans_init = (
+             cfg.kmeans_init
+             if cfg is not None and hasattr(cfg, "kmeans_init")
+             else kmeans_init
+         )
+         kmeans_iters = (
+             cfg.kmeans_iters
+             if cfg is not None and hasattr(cfg, "kmeans_iters")
+             else kmeans_iters
+         )
+         decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
+         eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
+         threshold_ema_dead_code = (
+             cfg.threshold_ema_dead_code
+             if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
+             else threshold_ema_dead_code
+         )
+         weight_init = (
+             cfg.weight_init
+             if cfg is not None and hasattr(cfg, "weight_init")
+             else weight_init
+         )
+         use_vocos = (
+             cfg.use_vocos
+             if cfg is not None and hasattr(cfg, "use_vocos")
+             else use_vocos
+         )
+         vocos_dim = (
+             cfg.vocos_dim
+             if cfg is not None and hasattr(cfg, "vocos_dim")
+             else vocos_dim
+         )
+         vocos_intermediate_dim = (
+             cfg.vocos_intermediate_dim
+             if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
+             else vocos_intermediate_dim
+         )
+         vocos_num_layers = (
+             cfg.vocos_num_layers
+             if cfg is not None and hasattr(cfg, "vocos_num_layers")
+             else vocos_num_layers
+         )
+         n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
+         hop_size = (
+             cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
+         )
+         padding = (
+             cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
+         )
+
+         if quantizer_type == "vq":
+             self.quantizer = ResidualVQ(
+                 input_dim=in_channels,
+                 num_quantizers=num_quantizers,
+                 codebook_size=codebook_size,
+                 codebook_dim=codebook_dim,
+                 quantizer_type=quantizer_type,
+                 quantizer_dropout=quantizer_dropout,
+                 commitment=commitment,
+                 codebook_loss_weight=codebook_loss_weight,
+                 use_l2_normlize=use_l2_normlize,
+                 codebook_type=codebook_type,
+                 kmeans_init=kmeans_init,
+                 kmeans_iters=kmeans_iters,
+                 decay=decay,
+                 eps=eps,
+                 threshold_ema_dead_code=threshold_ema_dead_code,
+                 weight_init=weight_init,
+             )
+         elif quantizer_type == "fvq":
+             self.quantizer = ResidualVQ(
+                 input_dim=in_channels,
+                 num_quantizers=num_quantizers,
+                 codebook_size=codebook_size,
+                 codebook_dim=codebook_dim,
+                 quantizer_type=quantizer_type,
+                 quantizer_dropout=quantizer_dropout,
+                 commitment=commitment,
+                 codebook_loss_weight=codebook_loss_weight,
+                 use_l2_normlize=use_l2_normlize,
+             )
+         elif quantizer_type == "lfq":
+             self.quantizer = ResidualVQ(
+                 input_dim=in_channels,
+                 num_quantizers=num_quantizers,
+                 codebook_size=codebook_size,
+                 codebook_dim=codebook_dim,
+                 quantizer_type=quantizer_type,
+             )
+         else:
+             raise ValueError(f"Unknown quantizer type {quantizer_type}")
+
+         if not use_vocos:
+             # Add first conv layer
+             channels = upsample_initial_channel
+             layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
+
+             # Add upsampling + MRF blocks
+             for i, stride in enumerate(up_ratios):
+                 input_dim = channels // 2**i
+                 output_dim = channels // 2 ** (i + 1)
+                 layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+             # Add final conv layer
+             layers += [
+                 Snake1d(output_dim),
+                 WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+                 nn.Tanh(),
+             ]
+
+             self.model = nn.Sequential(*layers)
+
+         if use_vocos:
+             self.model = Vocos(
+                 input_channels=in_channels,
+                 dim=vocos_dim,
+                 intermediate_dim=vocos_intermediate_dim,
+                 num_layers=vocos_num_layers,
+                 adanorm_num_embeddings=None,
+                 n_fft=n_fft,
+                 hop_size=hop_size,
+                 padding=padding,
+             )
+
+         self.reset_parameters()
+
+     def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
+         """
+         if vq is True, x = encoder output, then return quantized output;
+         else, x = quantized output, then return decoder output
+         """
+         if vq is True:
+             if eval_vq:
+                 self.quantizer.eval()
+             (
+                 quantized_out,
+                 all_indices,
+                 all_commit_losses,
+                 all_codebook_losses,
+                 all_quantized,
+             ) = self.quantizer(x, n_quantizers=n_quantizers)
+             return (
+                 quantized_out,
+                 all_indices,
+                 all_commit_losses,
+                 all_codebook_losses,
+                 all_quantized,
+             )
+
+         return self.model(x)
+
+     def quantize(self, x, n_quantizers=None):
+         self.quantizer.eval()
+         quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
+         return quantized_out, vq
+
+     # TODO: check consistency of vq2emb and quantize
+     def vq2emb(self, vq, n_quantizers=None):
+         return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)
+
+     def decode(self, x):
+         return self.model(x)
+
+     def latent2dist(self, x, n_quantizers=None):
+         return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)
+
+     def reset_parameters(self):
+         self.apply(init_weights)
models/codec/amphion_codec/loss.py ADDED
@@ -0,0 +1,401 @@
+ import librosa
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio
+ from torchaudio.transforms import MelSpectrogram
+ from einops import rearrange
+ from typing import List
+
+
+ def stft(x, fft_size, hop_size, win_length, window, use_complex=False):
+     """Perform STFT and convert to magnitude spectrogram.
+     Args:
+         x (Tensor): Input signal tensor (B, T).
+         fft_size (int): FFT size.
+         hop_size (int): Hop size.
+         win_length (int): Window length.
+         window (str): Window function type.
+     Returns:
+         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+     """
+
+     x_stft = torch.stft(
+         x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
+     )
+
+     # clamp is needed to avoid nan or inf
+     if not use_complex:
+         return torch.sqrt(
+             torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
+         ).transpose(2, 1)
+     else:
+         res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
+         res = res.transpose(2, 3)  # [B, 2, T, F]
+         return res
+
+
+ def compute_mag_scale(n_fft, sampling_rate):
+     frequencies = librosa.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
+     frequencies = np.where(frequencies > 1e-10, frequencies, -10)
+     db_scale = librosa.frequency_weighting(frequencies).reshape(1, 1, -1)
+     mag_scale = np.sqrt(librosa.db_to_power(db_scale)).astype(np.float32)
+     return torch.from_numpy(mag_scale)
+
+
+ class SpectralConvergenceLoss(torch.nn.Module):
+     """Spectral convergence loss module."""
+
+     def __init__(self):
+         """Initialize spectral convergence loss module."""
+         super(SpectralConvergenceLoss, self).__init__()
+
+     def forward(self, x_mag, y_mag):
+         """Calculate forward propagation.
+         Args:
+             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+             y_mag (Tensor): Magnitude spectrogram of ground-truth signal (B, #frames, #freq_bins).
+         Returns:
+             Tensor: Spectral convergence loss value.
+         """
+         return torch.norm(y_mag - x_mag) / torch.norm(y_mag)
+
+
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
+     """Log STFT magnitude loss module."""
+
+     def __init__(self):
+         """Initialize log STFT magnitude loss module."""
+         super(LogSTFTMagnitudeLoss, self).__init__()
+
+     def forward(self, x_mag, y_mag):
+         """Calculate forward propagation.
+         Args:
+             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+             y_mag (Tensor): Magnitude spectrogram of ground-truth signal (B, #frames, #freq_bins).
+         Returns:
+             Tensor: Log STFT magnitude loss value.
+         """
+         return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
+
+
+ class STFTLoss(torch.nn.Module):
+     """STFT loss module."""
+
+     def __init__(
+         self,
+         fft_size=1024,
+         hop_length=120,
+         win_length=600,
+         sampling_rate=16000,
+         window="hann_window",
+         cfg=None,
+     ):
+         """Initialize STFT loss module."""
+         super(STFTLoss, self).__init__()
+
+         fft_size = (
+             cfg.fft_size if cfg is not None and hasattr(cfg, "fft_size") else fft_size
+         )
+         hop_length = (
+             cfg.hop_length
+             if cfg is not None and hasattr(cfg, "hop_length")
+             else hop_length
+         )
+         win_length = (
+             cfg.win_length
+             if cfg is not None and hasattr(cfg, "win_length")
+             else win_length
+         )
+         window = cfg.window if cfg is not None and hasattr(cfg, "window") else window
+
+         self.fft_size = fft_size
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = getattr(torch, window)(win_length)
+         self.spectral_convergence_loss = SpectralConvergenceLoss()
+         self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+         self.register_buffer("mag_scale", compute_mag_scale(fft_size, sampling_rate))
+
+     def forward(self, x, y):
+         """Calculate forward propagation.
+         Args:
+             x (Tensor): Predicted signal (B, T).
+             y (Tensor): Ground truth signal (B, T).
+         Returns:
+             Tensor: Spectral convergence loss value.
+             Tensor: Log STFT magnitude loss value.
+         """
+         x_mag = (
+             stft(x, self.fft_size, self.hop_length, self.win_length, self.window)
+             * self.mag_scale
+         )
+         y_mag = (
+             stft(y, self.fft_size, self.hop_length, self.win_length, self.window)
+             * self.mag_scale
+         )
+         sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
+         log_mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+         return sc_loss, log_mag_loss
+
+
+ class MultiResolutionSTFTLoss(torch.nn.Module):
+     """Multi resolution STFT loss module."""
+
+     def __init__(
+         self,
+         fft_sizes=(1024, 2048, 512),
+         hop_sizes=(120, 240, 50),
+         win_lengths=(600, 1200, 240),
+         window="hann_window",
+         sampling_rate=16000,
+         cfg=None,
+     ):
+         """Initialize Multi resolution STFT loss module.
+         Args:
+             fft_sizes (list): List of FFT sizes.
+             hop_sizes (list): List of hop sizes.
+             win_lengths (list): List of window lengths.
+             window (str): Window function type.
+         """
+         super(MultiResolutionSTFTLoss, self).__init__()
+
+         fft_sizes = (
+             cfg.fft_sizes
+             if cfg is not None and hasattr(cfg, "fft_sizes")
+             else fft_sizes
+         )
+         hop_sizes = (
+             cfg.hop_sizes
+             if cfg is not None and hasattr(cfg, "hop_sizes")
+             else hop_sizes
+         )
+         win_lengths = (
+             cfg.win_lengths
+             if cfg is not None and hasattr(cfg, "win_lengths")
+             else win_lengths
+         )
+         window = cfg.window if cfg is not None and hasattr(cfg, "window") else window
+         sampling_rate = (
+             cfg.sampling_rate
+             if cfg is not None and hasattr(cfg, "sampling_rate")
+             else sampling_rate
+         )
+
+         assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+         self.stft_losses = torch.nn.ModuleList()
+         for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+             self.stft_losses += [
+                 STFTLoss(fs, ss, wl, window=window, sampling_rate=sampling_rate)
+             ]
+
+     def forward(self, x, y):
+         """Calculate forward propagation.
+         Args:
+             x (Tensor): Predicted signal (B, T).
+             y (Tensor): GroundTruth signal (B, T).
+         Returns:
+             Tensor: Multi resolution spectral convergence loss value.
+             Tensor: Multi resolution log STFT magnitude loss value.
+         """
+         sc_loss = 0.0
+         mag_loss = 0.0
+         for f in self.stft_losses:
+             sc_l, mag_l = f(x, y)
+             sc_loss += sc_l
+             mag_loss += mag_l
+         sc_loss /= len(self.stft_losses)
+         mag_loss /= len(self.stft_losses)
+
+         return sc_loss, mag_loss
+
+
+ class MultiResolutionMelSpectrogramLoss(nn.Module):
+     """Compute distance between mel spectrograms. Can be used
+     in a multi-scale way.
+
+     Parameters
+     ----------
+     n_mels : List[int]
+         Number of mels per STFT, by default [150, 80],
+     window_lengths : List[int], optional
+         Length of each window of each STFT, by default [2048, 512]
+     loss_fn : typing.Callable, optional
+         How to compare each loss, by default nn.L1Loss()
+     clamp_eps : float, optional
+         Clamp on the log magnitude, below, by default 1e-5
+     mag_weight : float, optional
+         Weight of raw magnitude portion of loss, by default 1.0
+     log_weight : float, optional
+         Weight of log magnitude portion of loss, by default 1.0
+     pow : float, optional
+         Power to raise magnitude to before taking log, by default 2.0
+     weight : float, optional
+         Weight of this loss, by default 1.0
+     match_stride : bool, optional
+         Whether to match the stride of convolutional layers, by default False
+
+     Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+     """
+
+     def __init__(
+         self,
+         sample_rate=16000,
+         n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
+         window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
+         clamp_eps: float = 1e-5,
+         mag_weight: float = 0.0,
+         log_weight: float = 1.0,
+         pow: float = 1.0,
+         mel_fmin: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+         mel_fmax: List[float] = [None, None, None, None, None, None, None],
+         cfg=None,
+     ):
+         super().__init__()
+
+         sample_rate = (
+             cfg.sample_rate
+             if cfg is not None and hasattr(cfg, "sample_rate")
+             else sample_rate
+         )
+         n_mels = cfg.n_mels if cfg is not None and hasattr(cfg, "n_mels") else n_mels
+         window_lengths = (
+             cfg.window_lengths
+             if cfg is not None and hasattr(cfg, "window_lengths")
+             else window_lengths
+         )
+         clamp_eps = (
+             cfg.clamp_eps
+             if cfg is not None and hasattr(cfg, "clamp_eps")
+             else clamp_eps
+         )
+         mag_weight = (
+             cfg.mag_weight
+             if cfg is not None and hasattr(cfg, "mag_weight")
+             else mag_weight
+         )
+         log_weight = (
+             cfg.log_weight
+             if cfg is not None and hasattr(cfg, "log_weight")
+             else log_weight
+         )
+         pow = cfg.pow if cfg is not None and hasattr(cfg, "pow") else pow
+         mel_fmin = (
+             cfg.mel_fmin if cfg is not None and hasattr(cfg, "mel_fmin") else mel_fmin
+         )
+         mel_fmax = (
+             cfg.mel_fmax if cfg is not None and hasattr(cfg, "mel_fmax") else mel_fmax
+         )
+
+         self.mel_transforms = nn.ModuleList(
+             [
+                 MelSpectrogram(
+                     sample_rate=sample_rate,
+                     n_fft=window_length,
+                     hop_length=window_length // 4,
+                     n_mels=n_mel,
+                     power=1.0,
+                     center=True,
+                     norm="slaney",
+                     mel_scale="slaney",
+                 )
+                 for n_mel, window_length in zip(n_mels, window_lengths)
+             ]
+         )
+         self.n_mels = n_mels
+         self.loss_fn = nn.L1Loss()
+         self.clamp_eps = clamp_eps
+         self.log_weight = log_weight
+         self.mag_weight = mag_weight
+         self.mel_fmin = mel_fmin
+         self.mel_fmax = mel_fmax
+         self.pow = pow
+
+     def delta(self, x, k):
+         l = x.shape[1]
+         return x[:, 0 : l - k] - x[:, k:l]
+
+     def forward(self, x, y, mask=None):
+         """Computes mel loss between an estimate and a reference
+         signal.
+
+         Parameters
+         ----------
+         x : AudioSignal
+             Estimate signal
+         y : AudioSignal
+             Reference signal
+
+         Returns
+         -------
+         torch.Tensor
+             Mel loss.
+         """
+         loss = 0.0
+         for mel_transform in self.mel_transforms:
+             x_mel = mel_transform(x)
+             y_mel = mel_transform(y)
+             log_x_mel = x_mel.clamp(self.clamp_eps).pow(self.pow).log10()
+             log_y_mel = y_mel.clamp(self.clamp_eps).pow(self.pow).log10()
+             loss += self.log_weight * self.loss_fn(log_x_mel, log_y_mel)
+             loss += self.mag_weight * self.loss_fn(x_mel, y_mel)
+             # loss += self.loss_fn(self.delta(log_x_mel, 1), self.delta(log_y_mel, 1))
+             # log_x_mel = rearrange(log_x_mel, 'b c t -> b t c')
+             # log_y_mel = rearrange(log_y_mel, 'b c t -> b t c')
+             # for i in range(3):
+             #     loss += self.loss_fn(self.delta(log_x_mel, i), self.delta(log_y_mel, i))
+         # loss /= len(self.mel_transforms)
+         return loss
+
+
+ class GANLoss(nn.Module):
+     def __init__(self, mode="lsgan"):
+         super(GANLoss, self).__init__()
+         assert mode in ["lsgan", "lsgan_std", "hinge"]
+         self.mode = mode
+
+     def disc_loss(self, real, fake):
+         if self.mode == "lsgan":
+             real_loss = F.mse_loss(real, torch.ones_like(real))
+             fake_loss = F.mse_loss(fake, torch.zeros_like(fake))
+         elif self.mode == "lsgan_std":
+             real = (real - 1.0).pow(2)
+             fake = (fake - 0.0).pow(2)
+             real_loss = real.mean() + real.std()
+             fake_loss = fake.mean() + fake.std()
+         elif self.mode == "hinge":
+             real_loss = torch.relu(1.0 - real).mean()
+             fake_loss = torch.relu(1.0 + fake).mean()
+         else:
+             raise ValueError(f"no such mode {self.mode}")
+
+         return real_loss, fake_loss
+
+     def disc_loss2(self, fake):
+         if self.mode == "lsgan":
+             fake_loss = F.mse_loss(fake, torch.zeros_like(fake))
+         elif self.mode == "lsgan_std":
+             fake = (fake - 0.0).pow(2)
+             fake_loss = fake.mean() + fake.std()
+         elif self.mode == "hinge":
+             fake_loss = torch.relu(1.0 + fake).mean()
+         else:
+             raise ValueError(f"no such mode {self.mode}")
+
+         return fake_loss
+
+     def gen_loss(self, fake):
+         if self.mode == "lsgan":
+             gen_loss = F.mse_loss(fake, torch.ones_like(fake))
+         elif self.mode == "lsgan_std":
+             fake = (fake - 1.0).pow(2)
+             gen_loss = fake.mean() + fake.std()
+         elif self.mode == "hinge":
+             gen_loss = -fake.mean()
+         else:
+             raise ValueError(f"no such mode {self.mode}")
+
+         return gen_loss
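
To orient readers on how these losses are consumed, the snippet below is a small illustrative sketch (not part of the commit) that calls MultiResolutionMelSpectrogramLoss on dummy waveforms, following the forward(x, y) signature defined above; the tensor shapes are assumptions for illustration.

# Illustrative sketch: comparing two (B, T) waveforms with the multi-resolution mel loss
import torch
from models.codec.amphion_codec.loss import MultiResolutionMelSpectrogramLoss

mel_loss = MultiResolutionMelSpectrogramLoss(sample_rate=16000)
pred = torch.randn(2, 16000)    # predicted waveform batch (dummy data)
target = torch.randn(2, 16000)  # reference waveform batch (dummy data)
loss = mel_loss(pred, target)   # weighted sum of log-mel (and optional magnitude) L1 terms over all resolutions
print(loss.item())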
models/codec/amphion_codec/quantize/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
+     FactorizedVectorQuantize,
+ )
+ from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
+ from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
+ from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
models/codec/amphion_codec/quantize/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (551 Bytes).
 
models/codec/amphion_codec/quantize/__pycache__/bsq.cpython-310.pyc ADDED
Binary file (10.6 kB).
 
models/codec/amphion_codec/quantize/__pycache__/factorized_vector_quantize.cpython-310.pyc ADDED
Binary file (4.08 kB).
 
models/codec/amphion_codec/quantize/__pycache__/lookup_free_quantize.cpython-310.pyc ADDED
Binary file (2.14 kB).
 
models/codec/amphion_codec/quantize/__pycache__/residual_vq.cpython-310.pyc ADDED
Binary file (4.4 kB).
 
models/codec/amphion_codec/quantize/__pycache__/vector_quantize.cpython-310.pyc ADDED
Binary file (10.9 kB).
 
models/codec/amphion_codec/quantize/bsq.py ADDED
@@ -0,0 +1,373 @@
+ # copy from https://github.com/zhaoyue-zephyrus/bsq-vit/blob/main/transcoder/models/quantizer/bsq.py
+
+ from einops import rearrange, reduce
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Function
+ import torch.nn.functional as F
+
+
+ class DifferentiableEntropyFunction(Function):
+     @staticmethod
+     def forward(ctx, zq, basis, K, eps):
+         zb = (zq + 1) / 2
+         zi = ((zb * basis).sum(-1)).to(torch.int64)
+         cnt = torch.scatter_reduce(
+             torch.zeros(2**K, device=zq.device, dtype=zq.dtype),
+             0,
+             zi.flatten(),
+             torch.ones_like(zi.flatten()).to(zq.dtype),
+             "sum",
+         )
+         prob = (cnt + eps) / (cnt + eps).sum()
+         H = -(prob * torch.log(prob)).sum()
+         ctx.save_for_backward(zq, zi, prob)
+         ctx.K = K
+         return H
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         zq, zi, prob = ctx.saved_tensors
+         grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
+         reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
+         grad_input = reord_grad.unsqueeze(-1) * zq
+         return grad_input, None, None, None, None
+
+
+ def codebook_entropy(zq, basis, K, eps=1e-4):
+     return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
+
+
+ class SimpleQuantizer(nn.Module):
+     def __init__(
+         self,
+         embed_dim=14,
+         codebook_size=16384,
+         commitment=0.25,
+         codebook_loss_weight=1.0,
+         use_l2_normlize=True,
+     ):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.codebook_dim = embed_dim
+         self.codebook_size = codebook_size
+         self.commitment = commitment
+         self.codebook_loss_weight = codebook_loss_weight
+         self.use_l2_normlize = use_l2_normlize
+
+         self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
+
+     def forward(self, z):
+
+         z_e = z
+
+         z_q, indices = self.decode_latents(z_e)
+
+         if self.training:
+             commit_loss = (
+                 F.mse_loss(z_e, z_q.detach(), reduction="mean") * self.commitment
+             )
+             codebook_loss = (
+                 F.mse_loss(z_q, z_e.detach(), reduction="mean")
+                 * self.codebook_loss_weight
+             )
+         else:
+             commit_loss = torch.zeros(z.shape[0], device=z.device)
+             codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+         z_q = z_e + (z_q - z_e).detach()
+
+         return (
+             z_q,
+             commit_loss + codebook_loss,
+             {
+                 "commit_loss": commit_loss,
+                 "codebook_loss": codebook_loss,
+                 "indices": indices,
+             },
+         )
+
+     def embed_code(self, embed_id):
+         return F.embedding(embed_id, self.codebook.weight)
+
+     def decode_code(self, embed_id):
+         return self.embed_code(embed_id)
+
+     def get_codebook_entry(self, indices):
+         return self.embed_code(indices)
+
+     def decode_latents(self, latents):
+         encodings = rearrange(latents, "b t d -> (b t) d")
+
+         codebook = self.codebook.weight
+
+         if self.use_l2_normlize:
+             # encodings = F.normalize(encodings) we have already normalized the latent before vq
+             codebook = F.normalize(codebook)
+
+         # Compute euclidean distance between encodings and codebook,
+         # if use_l2_normlize is True, the distance is equal to cosine distance
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         z_q = self.decode_code(indices)
+
+         return z_q, indices
+
+     def vq2emb(self, vq, out_proj=True):
+         emb = self.decode_code(vq)
+         return emb
+
+     def latent2dist(self, latents):
+
+         encodings = latents
+         codebook = self.codebook.weight
+
+         # L2 normalize encodings and codebook
+         if self.use_l2_normlize:
+             # encodings = F.normalize(encodings)
+             codebook = F.normalize(codebook)
+
+         # Compute euclidean distance between encodings and codebook,
+         # if use_l2_normlize is True, the distance is equal to cosine distance
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )  # (b*t, k)
+
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
+         z_q = self.decode_code(indices)
+
+         return -dist, indices, z_q
+
+
+ class BinarySphericalQuantizer(nn.Module):
+     def __init__(
+         self,
+         embed_dim=14,
+         beta=0.0,
+         gamma0=1.0,
+         gamma=1.0,
+         zeta=1.0,
+         input_format="blc",
+         soft_entropy=True,
+         group_size=1,
+         persample_entropy_compute="group",
+         cb_entropy_compute="group",
+         l2_norm=True,
+         inv_temperature=1,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.beta = beta  # loss weight for commit loss
+         self.gamma0 = gamma0  # loss weight for entropy penalty
+         self.gamma = gamma  # loss weight for entropy penalty
+         self.zeta = zeta  # loss weight for entire entropy penalty
+         self.input_format = input_format
+         assert (
+             self.embed_dim % group_size == 0
+         ), "embed_dim must be divisible by group_size"
+         self.num_groups = self.embed_dim // group_size
+         self.group_size = group_size
+         assert persample_entropy_compute in [
+             "group",
+             "analytical",
+         ], "persample_entropy_compute must be either 'group' or 'analytical'"
+         assert cb_entropy_compute in [
+             "group",
+             "nce",
+         ], "cb_entropy_compute must be either 'group' or 'nce'"
+         self.persample_entropy_compute = persample_entropy_compute
+         self.cb_entropy_compute = cb_entropy_compute
+         self.l2_norm = l2_norm
+         self.inv_temperature = inv_temperature
+
+         self.register_buffer("basis", 2 ** torch.arange(embed_dim - 1, -1, -1))
+         self.register_buffer("group_basis", 2 ** torch.arange(group_size - 1, -1, -1))
+
+         self.num_dimensions = 2**embed_dim
+         self.bits_per_index = embed_dim
+
+         # we only need to keep the codebook portion up to the group size
+         # because we approximate the H loss with this subcode
+         group_codes = torch.arange(2**self.group_size)
+         group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
+         self.register_buffer("group_codebook", group_codebook, persistent=False)
+
+         self.soft_entropy = soft_entropy  # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
+
+     def quantize(self, z):
+         assert (
+             z.shape[-1] == self.embed_dim
+         ), f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
+
+         zhat = torch.where(
+             z > 0,
+             torch.tensor(1, dtype=z.dtype, device=z.device),
+             torch.tensor(-1, dtype=z.dtype, device=z.device),
+         )
+         return z + (zhat - z).detach()
+
+     def forward(self, z):
+         if self.input_format == "bchw":
+             z = rearrange(z, "b c h w -> b h w c")
+         zq = self.quantize(z)
+
+         indices = self.codes_to_indexes(zq.detach())
+         group_indices = self.codes_to_group_indexes(zq.detach())
+         if not self.training:
+             used_codes = torch.unique(indices, return_counts=False)
+         else:
+             used_codes = None
+
+         q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
+
+         if self.soft_entropy:
+             persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
+             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
+         else:
+             zb_by_sample = (
+                 ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
+             )
+             persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
+             cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
+             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
+
+         zq = zq * q_scale
+
+         # commit loss
+         # commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
+         commit_loss = torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
+
+         if self.input_format == "bchw":
+             zq = rearrange(zq, "b h w c -> b c h w")
+
+         return (
+             zq,
+             commit_loss * self.beta
+             + self.zeta * entropy_penalty / self.inv_temperature,
+             {
+                 "H": cb_entropy,
+                 "used_codes": used_codes,
+                 "indices": indices,
+                 "group_indices": group_indices,
+                 "avg_prob": avg_prob,
+                 "commit_loss": commit_loss,
+             },
+         )
+
+     def soft_entropy_loss(self, z):
+         # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
+         # the sub-code is the last group_size bits of the full code
+         group_code_book = self.group_codebook / (
+             self.embed_dim**0.5 if self.l2_norm else 1
+         )
+         divided_z = rearrange(z, "... (g c) -> ... g c", c=self.group_size)
+
+         # we calculate the distance between the divided_z and the codebook for each subgroup
+         distance = -2 * torch.einsum(
+             "... g c, d c ->... g d", divided_z, group_code_book
+         )
+         prob = (-distance * self.inv_temperature).softmax(dim=-1)
+         if self.persample_entropy_compute == "analytical":
+             if self.l2_norm:
+                 p = torch.sigmoid(-4 * z / (self.embed_dim**0.5) * self.inv_temperature)
+             else:
+                 p = torch.sigmoid(-4 * z * self.inv_temperature)
+             prob = torch.stack([p, 1 - p], dim=-1)
+             per_sample_entropy = (
+                 self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
+             )
+         else:
+             per_sample_entropy = (
+                 self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
+             )
+
+         # macro average of the probability of each subgroup
+         avg_prob = reduce(prob, "... g d ->g d", "mean")
+         codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
+
+         # the approximation of the entropy is the sum of the entropy of each subgroup
+         return per_sample_entropy, codebook_entropy.sum(), avg_prob
+
+     def get_hard_per_sample_entropy(self, zb_by_sample):
+         probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
+         persample_entropy = -probs_per_dim * torch.log(probs_per_dim + 1e-8) - (
+             1 - probs_per_dim
+         ) * torch.log(1 - probs_per_dim + 1e-8)
+         persample_entropy = persample_entropy.sum(-1)
+         return persample_entropy.mean()
+
+     def codes_to_indexes(self, zhat):
+         """Converts a `code` to an index in the codebook.
+         Args:
+             zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
+         """
+         assert (
+             zhat.shape[-1] == self.embed_dim
+         ), f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
+         return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
+
+     def codes_to_group_indexes(self, zhat):
+         """Converts a `code` to a list of indexes (in groups) in the codebook.
+         Args:
+             zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
+         """
+         zhat_in_group = rearrange(zhat, "b ... (g c) -> b ... g c", c=self.group_size)
+         return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
+
+     def indexes_to_codes(self, indices):
+         """Inverse of `codes_to_indexes`."""
+         indices = indices.unsqueeze(-1)
+         codes_non_centered = torch.remainder(torch.floor_divide(indices, self.basis), 2)
+         return codes_non_centered * 2 - 1
+
+     def group_indexes_to_codes(self, group_indices):
+         """Inverse of `codes_to_group_indexes`."""
+         group_indices = group_indices.unsqueeze(-1)
+         codes_non_centered = torch.remainder(
+             torch.floor_divide(group_indices, self.group_basis), 2
+         )
+         codes_non_centered = rearrange(codes_non_centered, "b ... g c -> b ... (g c)")
+         return codes_non_centered * 2 - 1
+
+     def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
+         if normalize:
+             probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
+         else:
+             probs = count
+         H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
+         return H
+
+     def get_group_codebook_entry(self, group_indices):
+         z_q = self.group_indexes_to_codes(group_indices)
+         q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
+         z_q = z_q * q_scale
+         if self.input_format == "bchw":
+             h = w = int(z_q.shape[1] ** 0.5)
+             assert h * w == z_q.shape[1], "Invalid sequence length"
+             z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h)
+         return z_q
+
+     def get_codebook_entry(self, indices):
+         z_q = self.indexes_to_codes(indices)
+         q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
+         z_q = z_q * q_scale
+         if self.input_format == "bchw":
+             h = w = int(z_q.shape[1] ** 0.5)
+             assert h * w == z_q.shape[1], "Invalid sequence length"
+             z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h)
+         return z_q
+
+
+ if __name__ == "__main__":
+     bsq = BinarySphericalQuantizer()
+     z = torch.randn(3, 20, 14)
+     zq, loss, info = bsq(z)
+     print(zq.shape, loss, info)
models/codec/amphion_codec/quantize/factorized_vector_quantize.py ADDED
@@ -0,0 +1,145 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ class FactorizedVectorQuantize(nn.Module):
+     def __init__(
+         self,
+         input_dim,
+         codebook_size,
+         codebook_dim,
+         commitment=0.005,
+         codebook_loss_weight=1.0,
+         use_l2_normlize=True,
+     ):
+         super().__init__()
+         self.input_dim = input_dim
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.commitment = commitment
+         self.codebook_loss_weight = codebook_loss_weight
+         self.use_l2_normlize = use_l2_normlize
+
+         if self.input_dim != self.codebook_dim:
+             self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
+             self.out_project = WNConv1d(
+                 self.codebook_dim, self.input_dim, kernel_size=1
+             )
+
+         else:
+             self.in_project = nn.Identity()
+             self.out_project = nn.Identity()
+
+         self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
+
+     def forward(self, z):
+         """
+         Parameters
+         ----------
+         z: torch.Tensor[B x D x T]
+
+         Returns
+         -------
+         z_q: torch.Tensor[B x D x T]
+             Quantized continuous representation of input
+         commit_loss: Tensor[B]
+             Commitment loss to train encoder to predict vectors closer to codebook entries
+         codebook_loss: Tensor[B]
+             Codebook loss to update the codebook
+         indices: torch.Tensor[B x T]
+             Codebook indices (quantized discrete representation of input)
+         z_e: torch.Tensor[B x D x T]
+             Projected latents (continuous representation of input before quantization)
+         """
+
+         # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
+         z_e = self.in_project(z)
+         z_q, indices = self.decode_latents(z_e)
+
+         # Compute commitment loss and codebook loss
+         if self.training:
+             commit_loss = (
+                 F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+                 * self.commitment
+             )
+             codebook_loss = (
+                 F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+                 * self.codebook_loss_weight
+             )
+         else:
+             commit_loss = torch.zeros(z.shape[0], device=z.device)
+             codebook_loss = torch.zeros(z.shape[0], device=z.device)
+
+         z_q = z_e + (z_q - z_e).detach()
+
+         z_q = self.out_project(z_q)
+
+         return z_q, commit_loss, codebook_loss, indices, z_e
+
+     def embed_code(self, embed_id):
+         return F.embedding(embed_id, self.codebook.weight)
+
+     def decode_code(self, embed_id):
+         return self.embed_code(embed_id).transpose(1, 2)
+
+     def decode_latents(self, latents):
+         encodings = rearrange(latents, "b d t -> (b t) d")
+         codebook = self.codebook.weight
+
+         # L2 normalize encodings and codebook
+         if self.use_l2_normlize:
+             encodings = F.normalize(encodings)
+             codebook = F.normalize(codebook)
+
+         # Compute euclidean distance between encodings and codebook,
+         # if use_l2_normlize is True, the distance is equal to cosine distance
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         z_q = self.decode_code(indices)
+
+         return z_q, indices
+
+     def vq2emb(self, vq, out_proj=True):
+         emb = self.decode_code(vq)
+         if out_proj:
+             emb = self.out_project(emb)
+         return emb
+
+     def latent2dist(self, latents):
+         encodings = rearrange(latents, "b d t -> (b t) d")
+         codebook = self.codebook.weight
+
+         # L2 normalize encodings and codebook
+         if self.use_l2_normlize:
+             encodings = F.normalize(encodings)
+             codebook = F.normalize(codebook)
+
+         # Compute euclidean distance between encodings and codebook,
+         # if use_l2_normlize is True, the distance is equal to cosine distance
+         dist = (
+             encodings.pow(2).sum(1, keepdim=True)
+             - 2 * encodings @ codebook.t()
+             + codebook.pow(2).sum(1, keepdim=True).t()
+         )  # (b*t, k)
+
+         indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+         dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
+         z_q = self.decode_code(indices)
+
+         return -dist, indices, z_q
models/codec/amphion_codec/quantize/lookup_free_quantize.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ class LookupFreeQuantize(nn.Module):
18
+ def __init__(
19
+ self,
20
+ input_dim,
21
+ codebook_size,
22
+ codebook_dim,
23
+ ):
24
+ super().__init__()
25
+ self.input_dim = input_dim
26
+ self.codebook_size = codebook_size
27
+ self.codebook_dim = codebook_dim
28
+
29
+ assert 2**codebook_dim == codebook_size
30
+
31
+ if self.input_dim != self.codebook_dim:
32
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
33
+ self.out_project = WNConv1d(
34
+ self.codebook_dim, self.input_dim, kernel_size=1
35
+ )
36
+
37
+ else:
38
+ self.in_project = nn.Identity()
39
+ self.out_project = nn.Identity()
40
+
41
+ def forward(self, z):
42
+ z_e = self.in_project(z)
43
+ z_e = F.sigmoid(z_e)
44
+
45
+ z_q = z_e + (torch.round(z_e) - z_e).detach()
46
+
47
+ z_q = self.out_project(z_q)
48
+
49
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
50
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
51
+
52
+ bits = (
53
+ 2
54
+ ** torch.arange(self.codebook_dim, device=z.device)
55
+ .unsqueeze(0)
56
+ .unsqueeze(-1)
57
+ .long()
58
+ ) # (1, d, 1)
59
+ indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
60
+
61
+ return z_q, commit_loss, codebook_loss, indices, z_e
62
+
63
+ def vq2emb(self, vq, out_proj=True):
64
+ emb = torch.zeros(
65
+ vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
66
+ ) # (B, d, T)
67
+ for i in range(self.codebook_dim):
68
+ emb[:, i, :] = (vq % 2).float()
69
+ vq = vq // 2
70
+ if out_proj:
71
+ emb = self.out_project(emb)
72
+ return emb
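A small sketch (not part of the commit) of how the lookup-free quantizer above maps latents to integer codes: each of the codebook_dim channels is squashed with a sigmoid and rounded to a bit via a straight-through estimator, and the index is the binary number formed by those bits, which is why codebook_size must equal 2**codebook_dim. The sizes below are hypothetical.

import torch
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize

lfq = LookupFreeQuantize(input_dim=256, codebook_size=16, codebook_dim=4)  # 2**4 == 16
z = torch.randn(1, 256, 10)  # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = lfq(z)
print(indices.shape)         # torch.Size([1, 10]); each entry is a 4-bit code in [0, 15]
emb = lfq.vq2emb(indices)    # unpacks the bits and projects back to input_dim
print(z_q.shape, emb.shape)  # both torch.Size([1, 256, 10])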
models/codec/amphion_codec/quantize/residual_vq.py ADDED
@@ -0,0 +1,172 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
11
+ FactorizedVectorQuantize,
12
+ )
13
+ from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
14
+ from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
15
+
16
+
17
+ class ResidualVQ(nn.Module):
18
+ """
19
+ Introduced in SoundStream: An end2end neural audio codec
20
+ https://arxiv.org/abs/2107.03312
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ input_dim: int = 256,
26
+ num_quantizers: int = 8,
27
+ codebook_size: int = 1024,
28
+ codebook_dim: int = 256,
29
+ quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
30
+ quantizer_dropout: float = 0.5,
31
+ **kwargs,
32
+ ):
33
+ super().__init__()
34
+
35
+ self.input_dim = input_dim
36
+ self.num_quantizers = num_quantizers
37
+ self.codebook_size = codebook_size
38
+ self.codebook_dim = codebook_dim
39
+ self.quantizer_type = quantizer_type
40
+ self.quantizer_dropout = quantizer_dropout
41
+
42
+ if quantizer_type == "vq":
43
+ VQ = VectorQuantize
44
+ elif quantizer_type == "fvq":
45
+ VQ = FactorizedVectorQuantize
46
+ elif quantizer_type == "lfq":
47
+ VQ = LookupFreeQuantize
48
+ else:
49
+ raise ValueError(f"Unknown quantizer type {quantizer_type}")
50
+
51
+ self.quantizers = nn.ModuleList(
52
+ [
53
+ VQ(
54
+ input_dim=input_dim,
55
+ codebook_size=codebook_size,
56
+ codebook_dim=codebook_dim,
57
+ **kwargs,
58
+ )
59
+ for _ in range(num_quantizers)
60
+ ]
61
+ )
62
+
63
+ def forward(self, z, n_quantizers: int = None):
64
+ """
65
+ Parameters
66
+ ----------
67
+ z : Tensor[B x D x T]
68
+ n_quantizers : int, optional
69
+ No. of quantizers to use
70
+ (n_quantizers < self.num_quantizers, e.g. for quantizer dropout)
71
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
72
+ when in training mode, and a random number of quantizers is used.
73
+ Returns
74
+ -------
75
+ "quantized_out" : Tensor[B x D x T]
76
+ Quantized continuous representation of input
77
+ "all_indices" : Tensor[N x B x T]
78
+ Codebook indices for each codebook
79
+ (quantized discrete representation of input)
80
+ "all_commit_losses" : Tensor[N]
81
+ "all_codebook_losses" : Tensor[N]
82
+ "all_quantized" : Tensor[N x B x D x T]
83
+ """
84
+
85
+ quantized_out = 0.0
86
+ residual = z
87
+
88
+ all_commit_losses = []
89
+ all_codebook_losses = []
90
+ all_indices = []
91
+ all_quantized = []
92
+
93
+ if n_quantizers is None:
94
+ n_quantizers = self.num_quantizers
95
+
96
+ if self.training:
97
+ n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
98
+ dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
99
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
100
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
101
+ n_quantizers = n_quantizers.to(z.device)
102
+
103
+ for i, quantizer in enumerate(self.quantizers):
104
+ if self.training is False and i >= n_quantizers:
105
+ break
106
+
107
+ z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
108
+ residual
109
+ )
110
+
111
+ # Create mask to apply quantizer dropout
112
+ mask = (
113
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
114
+ )
115
+ quantized_out = quantized_out + z_q_i * mask[:, None, None]
116
+ residual = residual - z_q_i
117
+
118
+ commit_loss_i = (commit_loss_i * mask).mean()
119
+ codebook_loss_i = (codebook_loss_i * mask).mean()
120
+
121
+ all_commit_losses.append(commit_loss_i)
122
+ all_codebook_losses.append(codebook_loss_i)
123
+ all_indices.append(indices_i)
124
+ all_quantized.append(z_q_i)
125
+
126
+ all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
127
+ torch.stack,
128
+ (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
129
+ )
130
+
131
+ return (
132
+ quantized_out,
133
+ all_indices,
134
+ all_commit_losses,
135
+ all_codebook_losses,
136
+ all_quantized,
137
+ )
138
+
139
+ def vq2emb(self, vq, n_quantizers=None):
140
+ quantized_out = 0.0
141
+ if n_quantizers is None:
142
+ n_quantizers = self.num_quantizers
143
+ for idx, quantizer in enumerate(self.quantizers):
144
+ if idx >= n_quantizers:
145
+ break
146
+ quantized_out += quantizer.vq2emb(vq[idx])
147
+ return quantized_out
148
+
149
+ def latent2dist(self, z, n_quantizers=None):
150
+ quantized_out = 0.0
151
+ residual = z
152
+
153
+ all_dists = []
154
+ all_indices = []
155
+
156
+ if n_quantizers is None:
157
+ n_quantizers = self.num_quantizers
158
+
159
+ for i, quantizer in enumerate(self.quantizers):
160
+ if self.training is False and i >= n_quantizers:
161
+ break
162
+ dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
163
+ all_dists.append(dist_i)
164
+ all_indices.append(indices_i)
165
+
166
+ quantized_out = quantized_out + z_q_i
167
+ residual = residual - z_q_i
168
+
169
+ all_dists = torch.stack(all_dists)
170
+ all_indices = torch.stack(all_indices)
171
+
172
+ return all_dists, all_indices
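A minimal eval-mode sketch (not part of the commit) for the residual stack above: quantizer dropout only applies during training, and n_quantizers caps how many codebooks are applied at inference. The sizes below are hypothetical.

import torch
from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ

rvq = ResidualVQ(
    input_dim=256, num_quantizers=8, codebook_size=1024,
    codebook_dim=256, quantizer_type="vq",
)
rvq.eval()  # during training a random number of quantizers is used per sample
z = torch.randn(2, 256, 50)  # (B, D, T)
quantized, codes, commit_losses, codebook_losses, per_layer = rvq(z, n_quantizers=4)
print(quantized.shape)  # torch.Size([2, 256, 50])
print(codes.shape)      # torch.Size([4, 2, 50]): one code stream per used quantizer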
models/codec/amphion_codec/quantize/vector_quantize.py ADDED
@@ -0,0 +1,396 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ def l2norm(t):
18
+ return F.normalize(t, p=2, dim=-1)
19
+
20
+
21
+ def ema_inplace(moving_avg, new, decay):
22
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
23
+
24
+
25
+ def laplace_smoothing(x, n_categories, eps=1e-5):
26
+ return (x + eps) / (x.sum() + n_categories * eps)
27
+
28
+
29
+ def sample_vectors(samples, num):
30
+ num_samples, device = samples.shape[0], samples.device
31
+
32
+ if num_samples >= num:
33
+ indices = torch.randperm(num_samples, device=device)[:num]
34
+ else:
35
+ indices = torch.randint(0, num_samples, (num,), device=device)
36
+
37
+ return samples[indices]
38
+
39
+
40
+ def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
41
+ dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
42
+
43
+ means = sample_vectors(samples, num_clusters)
44
+
45
+ for _ in range(num_iters):
46
+ if use_cosine_sim:
47
+ dists = samples @ means.t()
48
+ else:
49
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
50
+ means, "c d -> () c d"
51
+ )
52
+ dists = -(diffs**2).sum(dim=-1)
53
+
54
+ buckets = dists.max(dim=-1).indices
55
+ bins = torch.bincount(buckets, minlength=num_clusters)
56
+ zero_mask = bins == 0
57
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
58
+
59
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
60
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
61
+ new_means = new_means / bins_min_clamped[..., None]
62
+
63
+ if use_cosine_sim:
64
+ new_means = l2norm(new_means)
65
+
66
+ means = torch.where(zero_mask[..., None], means, new_means)
67
+
68
+ return means, bins
69
+
70
+
71
+ class EuclideanCodebook(nn.Module):
72
+ def __init__(
73
+ self,
74
+ dim,
75
+ codebook_size,
76
+ kmeans_init=False,
77
+ kmeans_iters=10,
78
+ decay=0.8,
79
+ eps=1e-5,
80
+ threshold_ema_dead_code=2,
81
+ weight_init=False,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.decay = decay
86
+ init_fn = torch.randn if not weight_init else torch.zeros
87
+ embed = init_fn(codebook_size, dim)
88
+
89
+ if weight_init:
90
+ nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
91
+
92
+ self.codebook_size = codebook_size
93
+ self.kmeans_iters = kmeans_iters
94
+ self.eps = eps
95
+ self.threshold_ema_dead_code = threshold_ema_dead_code
96
+
97
+ self.register_buffer(
98
+ "initted", torch.Tensor([not kmeans_init])
99
+ ) # if kmeans_init is True, then initted is False; otherwise, initted is True
100
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
101
+ self.register_buffer("embed", embed)
102
+ self.register_buffer("embed_avg", embed.clone())
103
+
104
+ def init_embed_(self, data):
105
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
106
+ self.embed.data.copy_(embed)
107
+ self.embed_avg.data.copy_(embed)
108
+ self.cluster_size.data.copy_(cluster_size)
109
+ self.initted.data.copy_(torch.Tensor([True]))
110
+
111
+ def replace(self, samples, mask):
112
+ modified_codebook = torch.where(
113
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
114
+ )
115
+ self.embed.data.copy_(modified_codebook)
116
+
117
+ def expire_codes_(self, batch_samples):
118
+ if self.threshold_ema_dead_code == 0:
119
+ return
120
+
121
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
122
+ if not torch.any(expired_codes):
123
+ return
124
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
125
+ self.replace(batch_samples, mask=expired_codes)
126
+
127
+ def forward(self, x):
128
+ shape, dtype = x.shape, x.dtype
129
+ flatten = rearrange(x, "... d -> (...) d")
130
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
131
+
132
+ if not self.initted:
133
+ self.init_embed_(flatten)
134
+
135
+ dist = -(
136
+ flatten.pow(2).sum(1, keepdim=True)
137
+ - 2 * flatten @ embed
138
+ + embed.pow(2).sum(0, keepdim=True)
139
+ )
140
+
141
+ embed_ind = dist.max(dim=-1).indices
142
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
143
+ embed_ind = embed_ind.view(*shape[:-1])
144
+ quantize = F.embedding(embed_ind, self.embed)
145
+
146
+ if self.training:
147
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
148
+ embed_sum = (
149
+ flatten.t() @ embed_onehot
150
+ ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
151
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
152
+ cluster_size = (
153
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
154
+ * self.cluster_size.sum()
155
+ )
156
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
157
+ self.embed.data.copy_(embed_normalized)
158
+ self.expire_codes_(x)
159
+
160
+ return quantize, embed_ind
161
+
162
+ def vq2emb(self, vq):
163
+ quantize = F.embedding(vq, self.embed)
164
+ return quantize
165
+
166
+ def latent2dist(self, x):
167
+ shape, dtype = x.shape, x.dtype
168
+ flatten = rearrange(x, "... d -> (...) d")
169
+ embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
170
+
171
+ if not self.initted:
172
+ self.init_embed_(flatten)
173
+
174
+ dist = -(
175
+ flatten.pow(2).sum(1, keepdim=True)
176
+ - 2 * flatten @ embed
177
+ + embed.pow(2).sum(0, keepdim=True)
178
+ )
179
+
180
+ embed_ind = dist.max(dim=-1).indices
181
+ embed_ind = embed_ind.view(*shape[:-1])
182
+ quantize = F.embedding(embed_ind, self.embed)
183
+
184
+ dist = dist.view(*shape[:-1], -1)
185
+
186
+ return dist, embed_ind, quantize
187
+
188
+
189
+ class SimpleCodebook(nn.Module):
190
+ def __init__(
191
+ self,
192
+ dim,
193
+ codebook_size,
194
+ use_l2_normlize=False,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.dim = dim
199
+ self.codebook_size = codebook_size
200
+ self.use_l2_normlize = use_l2_normlize
201
+
202
+ self.embed = nn.Embedding(self.codebook_size, self.dim)
203
+
204
+ def forward(self, x):
205
+ shape, dtype = x.shape, x.dtype
206
+ flatten = rearrange(x, "... d -> (...) d")
207
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
208
+
209
+ if self.use_l2_normlize:
210
+ flatten = F.normalize(flatten)
211
+ embed = F.normalize(embed)
212
+
213
+ dist = -(
214
+ flatten.pow(2).sum(1, keepdim=True)
215
+ - 2 * flatten @ embed
216
+ + embed.pow(2).sum(0, keepdim=True)
217
+ )
218
+
219
+ embed_ind = dist.max(dim=-1).indices
220
+ embed_ind = embed_ind.view(*shape[:-1])
221
+ quantize = F.embedding(embed_ind, self.embed.weight)
222
+
223
+ return quantize, embed_ind
224
+
225
+ def vq2emb(self, vq):
226
+ quantize = F.embedding(vq, self.embed.weight)
227
+ return quantize
228
+
229
+ def latent2dist(self, x):
230
+ shape, dtype = x.shape, x.dtype
231
+ flatten = rearrange(x, "... d -> (...) d")
232
+ embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
233
+
234
+ if self.use_l2_normlize:
235
+ flatten = F.normalize(flatten)
236
+ embed = F.normalize(embed)
237
+
238
+ dist = -(
239
+ flatten.pow(2).sum(1, keepdim=True)
240
+ - 2 * flatten @ embed
241
+ + embed.pow(2).sum(0, keepdim=True)
242
+ )
243
+
244
+ embed_ind = dist.max(dim=-1).indices
245
+ embed_ind = embed_ind.view(*shape[:-1])
246
+ quantize = F.embedding(embed_ind, self.embed.weight)
247
+
248
+ dist = dist.view(*shape[:-1], -1)
249
+
250
+ return dist, embed_ind, quantize
251
+
252
+
253
+ class VectorQuantize(nn.Module):
254
+ """Vector quantization and factorized vector quantization implementation
255
+ Args:
256
+ input_dim (int): Dimension of input.
257
+ codebook_size (int): Codebook size.
258
+ codebook_dim (int): Codebook dimension. We suggest codebook_dim = input_dim
259
+ if codebook_type == "euclidean"; otherwise, if you want to use
260
+ factorized vector quantization, use a small codebook_dim (e.g. 8 or 32).
261
+ commitment (float): Weight for commitment loss.
262
+ use_l2_normlize (bool): Whether to use l2-normalized codes for factorized vector quantization;
263
+ we suggest setting it to True if you want to use factorized vector quantization.
264
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
265
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
266
+ decay (float): Decay for exponential moving average over the codebooks.
267
+ epsilon (float): Epsilon value for numerical stability.
268
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
269
+ that have an exponential moving average cluster size less than the specified threshold with
270
+ randomly selected vector from the current batch.
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ input_dim,
276
+ codebook_size,
277
+ codebook_dim,
278
+ commitment=0.005,
279
+ codebook_loss_weight=1.0,
280
+ use_l2_normlize=False,
281
+ codebook_type="euclidean", # "euclidean" or "simple"
282
+ kmeans_init=False,
283
+ kmeans_iters=10,
284
+ decay=0.8,
285
+ eps=1e-5,
286
+ threshold_ema_dead_code=2,
287
+ weight_init=False,
288
+ ):
289
+ super().__init__()
290
+ self.input_dim = input_dim
291
+ self.codebook_size = codebook_size
292
+ self.codebook_dim = codebook_dim
293
+ self.commitment = commitment
294
+ self.codebook_loss_weight = codebook_loss_weight
295
+ self.use_l2_normlize = use_l2_normlize
296
+ self.codebook_type = codebook_type
297
+ self.kmeans_init = kmeans_init
298
+ self.kmeans_iters = kmeans_iters
299
+ self.decay = decay
300
+ self.eps = eps
301
+ self.threshold_ema_dead_code = threshold_ema_dead_code
302
+ self.weight_init = weight_init
303
+
304
+ if self.input_dim != self.codebook_dim:
305
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
306
+ self.out_project = WNConv1d(
307
+ self.codebook_dim, self.input_dim, kernel_size=1
308
+ )
309
+
310
+ else:
311
+ self.in_project = nn.Identity()
312
+ self.out_project = nn.Identity()
313
+
314
+ if self.codebook_type == "euclidean":
315
+ self.codebook = EuclideanCodebook(
316
+ self.codebook_dim,
317
+ codebook_size=self.codebook_size,
318
+ kmeans_init=self.kmeans_init,
319
+ kmeans_iters=self.kmeans_iters,
320
+ decay=self.decay,
321
+ eps=self.eps,
322
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
323
+ weight_init=self.weight_init,
324
+ )
325
+ elif self.codebook_type == "simple":
326
+ self.codebook = SimpleCodebook(
327
+ self.codebook_dim,
328
+ codebook_size=self.codebook_size,
329
+ use_l2_normlize=self.use_l2_normlize,
330
+ )
331
+ else:
332
+ raise NotImplementedError(
333
+ f"codebook_type {self.codebook_type} is not implemented!"
334
+ )
335
+
336
+ def forward(self, z):
337
+ """
338
+ Parameters
339
+ ----------
340
+ z: torch.Tensor[B x D x T]
341
+
342
+ Returns
343
+ -------
344
+ z_q: torch.Tensor[B x D x T]
345
+ Quantized continuous representation of input
346
+ commit_loss: Tensor[B]
347
+ Commitment loss to train encoder to predict vectors closer to codebook entries
348
+ codebook_loss: Tensor[B]
349
+ Codebook loss to update the codebook
350
+ indices: torch.Tensor[B x T]
351
+ Codebook indices (quantized discrete representation of input)
352
+ z_e: torch.Tensor[B x D x T]
353
+ Projected latents (continuous representation of input before quantization)
354
+ """
355
+
356
+ # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
357
+ z_e = self.in_project(z)
358
+ z_q, indices = self.decode_latents(z_e)
359
+
360
+ # Compute commitment loss and codebook loss
361
+ if self.training:
362
+ commit_loss = (
363
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
364
+ * self.commitment
365
+ )
366
+ codebook_loss = (
367
+ F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
368
+ * self.codebook_loss_weight
369
+ )
370
+ else:
371
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
372
+ codebook_loss = torch.zeros(z.shape[0], device=z.device)
373
+
374
+ z_q = z_e + (z_q - z_e).detach()
375
+
376
+ z_q = self.out_project(z_q)
377
+
378
+ return z_q, commit_loss, codebook_loss, indices, z_e
379
+
380
+ def decode_latents(self, latents):
381
+ encodings = rearrange(latents, "b d t -> b t d")
382
+ z_q, indices = self.codebook(encodings)
383
+ z_q = z_q.transpose(1, 2)
384
+ return z_q, indices
385
+
386
+ def vq2emb(self, vq, out_proj=True):
387
+ emb = self.codebook.vq2emb(vq)
388
+ emb = emb.transpose(1, 2)
389
+ if out_proj:
390
+ emb = self.out_project(emb)
391
+ return emb
392
+
393
+ def latent2dist(self, latents):
394
+ latents = rearrange(latents, "b d t -> b t d")
395
+ dist, embed_ind, quantize = self.codebook.latent2dist(latents)
396
+ return dist, embed_ind, quantize.transpose(1, 2)
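A minimal sketch (not part of the commit) of the plain VectorQuantize above. The straight-through line z_q = z_e + (z_q - z_e).detach() lets gradients bypass the non-differentiable codebook lookup, and the per-item commitment/codebook losses are meant to be averaged into the training objective. The sizes below are hypothetical.

import torch
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize

vq = VectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=256)
vq.train()  # losses are only computed in training mode
z = torch.randn(2, 256, 50)  # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
loss = commit_loss.mean() + codebook_loss.mean()
print(z_q.shape, indices.shape)  # torch.Size([2, 256, 50]), torch.Size([2, 50])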
models/codec/amphion_codec/vocos.py ADDED
@@ -0,0 +1,909 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn, view_as_real, view_as_complex
7
+ from torch import nn
8
+ from torch.nn.utils import weight_norm, remove_weight_norm
9
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
10
+ from models.codec.melvqgan.melspec import MelSpectrogram
11
+ import librosa
12
+
13
+
14
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
15
+ """
16
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
17
+
18
+ Args:
19
+ x (Tensor): Input tensor.
20
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
21
+
22
+ Returns:
23
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
24
+ """
25
+ return torch.log(torch.clip(x, min=clip_val))
26
+
27
+
28
+ def symlog(x: torch.Tensor) -> torch.Tensor:
29
+ return torch.sign(x) * torch.log1p(x.abs())
30
+
31
+
32
+ def symexp(x: torch.Tensor) -> torch.Tensor:
33
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
34
+
35
+
36
+ class STFT(nn.Module):
37
+ def __init__(
38
+ self,
39
+ n_fft: int,
40
+ hop_length: int,
41
+ win_length: int,
42
+ center=True,
43
+ ):
44
+ super().__init__()
45
+ self.center = center
46
+ self.n_fft = n_fft
47
+ self.hop_length = hop_length
48
+ self.win_length = win_length
49
+ window = torch.hann_window(win_length)
50
+ self.register_buffer("window", window)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ # x: (B, T * hop_length)
54
+
55
+ if not self.center:
56
+ pad = self.win_length - self.hop_length
57
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
58
+
59
+ stft_spec = torch.stft(
60
+ x,
61
+ self.n_fft,
62
+ hop_length=self.hop_length,
63
+ win_length=self.win_length,
64
+ window=self.window,
65
+ center=self.center,
66
+ return_complex=False,
67
+ ) # (B, n_fft // 2 + 1, T, 2)
68
+
69
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T)
70
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T)
71
+
72
+ log_mag = torch.log(
73
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
74
+ ) # (B, n_fft // 2 + 1, T)
75
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
76
+
77
+ return log_mag, phase
78
+
79
+
80
+ class ISTFT(nn.Module):
81
+ """
82
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
83
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
84
+ See issue: https://github.com/pytorch/pytorch/issues/62323
85
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
86
+ The NOLA constraint is met as we trim padded samples anyway.
87
+
88
+ Args:
89
+ n_fft (int): Size of Fourier transform.
90
+ hop_length (int): The distance between neighboring sliding window frames.
91
+ win_length (int): The size of window frame and STFT filter.
92
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
93
+ """
94
+
95
+ def __init__(
96
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
97
+ ):
98
+ super().__init__()
99
+ if padding not in ["center", "same"]:
100
+ raise ValueError("Padding must be 'center' or 'same'.")
101
+ self.padding = padding
102
+ self.n_fft = n_fft
103
+ self.hop_length = hop_length
104
+ self.win_length = win_length
105
+ window = torch.hann_window(win_length)
106
+ self.register_buffer("window", window)
107
+
108
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
109
+ """
110
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
111
+
112
+ Args:
113
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
114
+ N is the number of frequency bins, and T is the number of time frames.
115
+
116
+ Returns:
117
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
118
+ """
119
+ if self.padding == "center":
120
+ # Fallback to pytorch native implementation
121
+ return torch.istft(
122
+ spec,
123
+ self.n_fft,
124
+ self.hop_length,
125
+ self.win_length,
126
+ self.window,
127
+ center=True,
128
+ )
129
+ elif self.padding == "same":
130
+ pad = (self.win_length - self.hop_length) // 2
131
+ else:
132
+ raise ValueError("Padding must be 'center' or 'same'.")
133
+
134
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
135
+ B, N, T = spec.shape
136
+
137
+ # Inverse FFT
138
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
139
+ ifft = ifft * self.window[None, :, None]
140
+
141
+ # Overlap and Add
142
+ output_size = (T - 1) * self.hop_length + self.win_length
143
+ y = torch.nn.functional.fold(
144
+ ifft,
145
+ output_size=(1, output_size),
146
+ kernel_size=(1, self.win_length),
147
+ stride=(1, self.hop_length),
148
+ )[:, 0, 0, pad:-pad]
149
+
150
+ # Window envelope
151
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
152
+ window_envelope = torch.nn.functional.fold(
153
+ window_sq,
154
+ output_size=(1, output_size),
155
+ kernel_size=(1, self.win_length),
156
+ stride=(1, self.hop_length),
157
+ ).squeeze()[pad:-pad]
158
+
159
+ # Normalize
160
+ assert (window_envelope > 1e-11).all()
161
+ y = y / window_envelope
162
+
163
+ return y
164
+
165
+
166
+ class MDCT(nn.Module):
167
+ """
168
+ Modified Discrete Cosine Transform (MDCT) module.
169
+
170
+ Args:
171
+ frame_len (int): Length of the MDCT frame.
172
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
173
+ """
174
+
175
+ def __init__(self, frame_len: int, padding: str = "same"):
176
+ super().__init__()
177
+ if padding not in ["center", "same"]:
178
+ raise ValueError("Padding must be 'center' or 'same'.")
179
+ self.padding = padding
180
+ self.frame_len = frame_len
181
+ N = frame_len // 2
182
+ n0 = (N + 1) / 2
183
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
184
+ self.register_buffer("window", window)
185
+
186
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
187
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
188
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
189
+ # https://github.com/pytorch/pytorch/issues/71613
190
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
191
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
192
+
193
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
194
+ """
195
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
196
+
197
+ Args:
198
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
199
+ and T is the length of the audio.
200
+
201
+ Returns:
202
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
203
+ and N is the number of frequency bins.
204
+ """
205
+ if self.padding == "center":
206
+ audio = torch.nn.functional.pad(
207
+ audio, (self.frame_len // 2, self.frame_len // 2)
208
+ )
209
+ elif self.padding == "same":
210
+ # hop_length is 1/2 frame_len
211
+ audio = torch.nn.functional.pad(
212
+ audio, (self.frame_len // 4, self.frame_len // 4)
213
+ )
214
+ else:
215
+ raise ValueError("Padding must be 'center' or 'same'.")
216
+
217
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
218
+ N = self.frame_len // 2
219
+ x = x * self.window.expand(x.shape)
220
+ X = torch.fft.fft(
221
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
222
+ )[..., :N]
223
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
224
+ return torch.real(res) * np.sqrt(2)
225
+
226
+
227
+ class IMDCT(nn.Module):
228
+ """
229
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
230
+
231
+ Args:
232
+ frame_len (int): Length of the MDCT frame.
233
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
234
+ """
235
+
236
+ def __init__(self, frame_len: int, padding: str = "same"):
237
+ super().__init__()
238
+ if padding not in ["center", "same"]:
239
+ raise ValueError("Padding must be 'center' or 'same'.")
240
+ self.padding = padding
241
+ self.frame_len = frame_len
242
+ N = frame_len // 2
243
+ n0 = (N + 1) / 2
244
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
245
+ self.register_buffer("window", window)
246
+
247
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
248
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
249
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
250
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
251
+
252
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
253
+ """
254
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
255
+
256
+ Args:
257
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
258
+ L is the number of frames, and N is the number of frequency bins.
259
+
260
+ Returns:
261
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
262
+ """
263
+ B, L, N = X.shape
264
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
265
+ Y[..., :N] = X
266
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
267
+ y = torch.fft.ifft(
268
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
269
+ )
270
+ y = (
271
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
272
+ * np.sqrt(N)
273
+ * np.sqrt(2)
274
+ )
275
+ result = y * self.window.expand(y.shape)
276
+ output_size = (1, (L + 1) * N)
277
+ audio = torch.nn.functional.fold(
278
+ result.transpose(1, 2),
279
+ output_size=output_size,
280
+ kernel_size=(1, self.frame_len),
281
+ stride=(1, self.frame_len // 2),
282
+ )[:, 0, 0, :]
283
+
284
+ if self.padding == "center":
285
+ pad = self.frame_len // 2
286
+ elif self.padding == "same":
287
+ pad = self.frame_len // 4
288
+ else:
289
+ raise ValueError("Padding must be 'center' or 'same'.")
290
+
291
+ audio = audio[:, pad:-pad]
292
+ return audio
293
+
294
+
295
+ class FourierHead(nn.Module):
296
+ """Base class for inverse fourier modules."""
297
+
298
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
299
+ """
300
+ Args:
301
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
302
+ L is the sequence length, and H denotes the model dimension.
303
+
304
+ Returns:
305
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
306
+ """
307
+ raise NotImplementedError("Subclasses must implement the forward method.")
308
+
309
+
310
+ class ISTFTHead(FourierHead):
311
+ """
312
+ ISTFT Head module for predicting STFT complex coefficients.
313
+
314
+ Args:
315
+ dim (int): Hidden dimension of the model.
316
+ n_fft (int): Size of Fourier transform.
317
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
318
+ the resolution of the input features.
319
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
320
+ """
321
+
322
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
323
+ super().__init__()
324
+ out_dim = n_fft + 2
325
+ self.out = torch.nn.Linear(dim, out_dim)
326
+ self.istft = ISTFT(
327
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
328
+ )
329
+
330
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
331
+ """
332
+ Forward pass of the ISTFTHead module.
333
+
334
+ Args:
335
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
336
+ L is the sequence length, and H denotes the model dimension.
337
+
338
+ Returns:
339
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
340
+ """
341
+ x = self.out(x).transpose(1, 2)
342
+ mag, p = x.chunk(2, dim=1)
343
+ mag = torch.exp(mag)
344
+ mag = torch.clip(
345
+ mag, max=1e2
346
+ ) # safeguard to prevent excessively large magnitudes
347
+ # wrapping happens here. These two lines produce real and imaginary value
348
+ x = torch.cos(p)
349
+ y = torch.sin(p)
350
+ # recalculating phase here does not produce anything new
351
+ # only costs time
352
+ # phase = torch.atan2(y, x)
353
+ # S = mag * torch.exp(phase * 1j)
354
+ # better directly produce the complex value
355
+ S = mag * (x + 1j * y)
356
+ audio = self.istft(S)
357
+ return audio
358
+
359
+
360
+ class IMDCTSymExpHead(FourierHead):
361
+ """
362
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
363
+
364
+ Args:
365
+ dim (int): Hidden dimension of the model.
366
+ mdct_frame_len (int): Length of the MDCT frame.
367
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
368
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
369
+ based on perceptual scaling. Defaults to None.
370
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ dim: int,
376
+ mdct_frame_len: int,
377
+ padding: str = "same",
378
+ sample_rate: Optional[int] = None,
379
+ clip_audio: bool = False,
380
+ ):
381
+ super().__init__()
382
+ out_dim = mdct_frame_len // 2
383
+ self.out = nn.Linear(dim, out_dim)
384
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
385
+ self.clip_audio = clip_audio
386
+
387
+ if sample_rate is not None:
388
+ # optionally init the last layer following mel-scale
389
+ m_max = _hz_to_mel(sample_rate // 2)
390
+ m_pts = torch.linspace(0, m_max, out_dim)
391
+ f_pts = _mel_to_hz(m_pts)
392
+ scale = 1 - (f_pts / f_pts.max())
393
+
394
+ with torch.no_grad():
395
+ self.out.weight.mul_(scale.view(-1, 1))
396
+
397
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
398
+ """
399
+ Forward pass of the IMDCTSymExpHead module.
400
+
401
+ Args:
402
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
403
+ L is the sequence length, and H denotes the model dimension.
404
+
405
+ Returns:
406
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
407
+ """
408
+ x = self.out(x)
409
+ x = symexp(x)
410
+ x = torch.clip(
411
+ x, min=-1e2, max=1e2
412
+ ) # safeguard to prevent excessively large magnitudes
413
+ audio = self.imdct(x)
414
+ if self.clip_audio:
415
+ audio = torch.clip(x, min=-1.0, max=1.0)
416
+
417
+ return audio
418
+
419
+
420
+ class IMDCTCosHead(FourierHead):
421
+ """
422
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
423
+
424
+ Args:
425
+ dim (int): Hidden dimension of the model.
426
+ mdct_frame_len (int): Length of the MDCT frame.
427
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
428
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
429
+ """
430
+
431
+ def __init__(
432
+ self,
433
+ dim: int,
434
+ mdct_frame_len: int,
435
+ padding: str = "same",
436
+ clip_audio: bool = False,
437
+ ):
438
+ super().__init__()
439
+ self.clip_audio = clip_audio
440
+ self.out = nn.Linear(dim, mdct_frame_len)
441
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
442
+
443
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
444
+ """
445
+ Forward pass of the IMDCTCosHead module.
446
+
447
+ Args:
448
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
449
+ L is the sequence length, and H denotes the model dimension.
450
+
451
+ Returns:
452
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
453
+ """
454
+ x = self.out(x)
455
+ m, p = x.chunk(2, dim=2)
456
+ m = torch.exp(m).clip(
457
+ max=1e2
458
+ ) # safeguard to prevent excessively large magnitudes
459
+ audio = self.imdct(m * torch.cos(p))
460
+ if self.clip_audio:
461
+ audio = torch.clip(audio, min=-1.0, max=1.0)
462
+ return audio
463
+
464
+
465
+ class ConvNeXtBlock(nn.Module):
466
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
467
+
468
+ Args:
469
+ dim (int): Number of input channels.
470
+ intermediate_dim (int): Dimensionality of the intermediate layer.
471
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
472
+ Defaults to None.
473
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
474
+ None means non-conditional LayerNorm. Defaults to None.
475
+ """
476
+
477
+ def __init__(
478
+ self,
479
+ dim: int,
480
+ intermediate_dim: int,
481
+ layer_scale_init_value: float,
482
+ adanorm_num_embeddings: Optional[int] = None,
483
+ ):
484
+ super().__init__()
485
+ self.dwconv = nn.Conv1d(
486
+ dim, dim, kernel_size=7, padding=3, groups=dim
487
+ ) # depthwise conv
488
+ self.adanorm = adanorm_num_embeddings is not None
489
+ if adanorm_num_embeddings:
490
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
491
+ else:
492
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
493
+ self.pwconv1 = nn.Linear(
494
+ dim, intermediate_dim
495
+ ) # pointwise/1x1 convs, implemented with linear layers
496
+ self.act = nn.GELU()
497
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
498
+ self.gamma = (
499
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
500
+ if layer_scale_init_value > 0
501
+ else None
502
+ )
503
+
504
+ def forward(
505
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
506
+ ) -> torch.Tensor:
507
+ residual = x
508
+ x = self.dwconv(x)
509
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
510
+ if self.adanorm:
511
+ assert cond_embedding_id is not None
512
+ x = self.norm(x, cond_embedding_id)
513
+ else:
514
+ x = self.norm(x)
515
+ x = self.pwconv1(x)
516
+ x = self.act(x)
517
+ x = self.pwconv2(x)
518
+ if self.gamma is not None:
519
+ x = self.gamma * x
520
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
521
+
522
+ x = residual + x
523
+ return x
524
+
525
+
526
+ class AdaLayerNorm(nn.Module):
527
+ """
528
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
529
+
530
+ Args:
531
+ num_embeddings (int): Number of embeddings.
532
+ embedding_dim (int): Dimension of the embeddings.
533
+ """
534
+
535
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
536
+ super().__init__()
537
+ self.eps = eps
538
+ self.dim = embedding_dim
539
+ self.scale = nn.Embedding(
540
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
541
+ )
542
+ self.shift = nn.Embedding(
543
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
544
+ )
545
+ torch.nn.init.ones_(self.scale.weight)
546
+ torch.nn.init.zeros_(self.shift.weight)
547
+
548
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
549
+ scale = self.scale(cond_embedding_id)
550
+ shift = self.shift(cond_embedding_id)
551
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
552
+ x = x * scale + shift
553
+ return x
554
+
555
+
556
+ class ResBlock1(nn.Module):
557
+ """
558
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
559
+ but without upsampling layers.
560
+
561
+ Args:
562
+ dim (int): Number of input channels.
563
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
564
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
565
+ Defaults to (1, 3, 5).
566
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
567
+ Defaults to 0.1.
568
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
569
+ Defaults to None.
570
+ """
571
+
572
+ def __init__(
573
+ self,
574
+ dim: int,
575
+ kernel_size: int = 3,
576
+ dilation: Tuple[int, int, int] = (1, 3, 5),
577
+ lrelu_slope: float = 0.1,
578
+ layer_scale_init_value: Optional[float] = None,
579
+ ):
580
+ super().__init__()
581
+ self.lrelu_slope = lrelu_slope
582
+ self.convs1 = nn.ModuleList(
583
+ [
584
+ weight_norm(
585
+ nn.Conv1d(
586
+ dim,
587
+ dim,
588
+ kernel_size,
589
+ 1,
590
+ dilation=dilation[0],
591
+ padding=self.get_padding(kernel_size, dilation[0]),
592
+ )
593
+ ),
594
+ weight_norm(
595
+ nn.Conv1d(
596
+ dim,
597
+ dim,
598
+ kernel_size,
599
+ 1,
600
+ dilation=dilation[1],
601
+ padding=self.get_padding(kernel_size, dilation[1]),
602
+ )
603
+ ),
604
+ weight_norm(
605
+ nn.Conv1d(
606
+ dim,
607
+ dim,
608
+ kernel_size,
609
+ 1,
610
+ dilation=dilation[2],
611
+ padding=self.get_padding(kernel_size, dilation[2]),
612
+ )
613
+ ),
614
+ ]
615
+ )
616
+
617
+ self.convs2 = nn.ModuleList(
618
+ [
619
+ weight_norm(
620
+ nn.Conv1d(
621
+ dim,
622
+ dim,
623
+ kernel_size,
624
+ 1,
625
+ dilation=1,
626
+ padding=self.get_padding(kernel_size, 1),
627
+ )
628
+ ),
629
+ weight_norm(
630
+ nn.Conv1d(
631
+ dim,
632
+ dim,
633
+ kernel_size,
634
+ 1,
635
+ dilation=1,
636
+ padding=self.get_padding(kernel_size, 1),
637
+ )
638
+ ),
639
+ weight_norm(
640
+ nn.Conv1d(
641
+ dim,
642
+ dim,
643
+ kernel_size,
644
+ 1,
645
+ dilation=1,
646
+ padding=self.get_padding(kernel_size, 1),
647
+ )
648
+ ),
649
+ ]
650
+ )
651
+
652
+ self.gamma = nn.ParameterList(
653
+ [
654
+ (
655
+ nn.Parameter(
656
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
657
+ )
658
+ if layer_scale_init_value is not None
659
+ else None
660
+ ),
661
+ (
662
+ nn.Parameter(
663
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
664
+ )
665
+ if layer_scale_init_value is not None
666
+ else None
667
+ ),
668
+ (
669
+ nn.Parameter(
670
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
671
+ )
672
+ if layer_scale_init_value is not None
673
+ else None
674
+ ),
675
+ ]
676
+ )
677
+
678
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
679
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
680
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
681
+ xt = c1(xt)
682
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
683
+ xt = c2(xt)
684
+ if gamma is not None:
685
+ xt = gamma * xt
686
+ x = xt + x
687
+ return x
688
+
689
+ def remove_weight_norm(self):
690
+ for l in self.convs1:
691
+ remove_weight_norm(l)
692
+ for l in self.convs2:
693
+ remove_weight_norm(l)
694
+
695
+ @staticmethod
696
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
697
+ return int((kernel_size * dilation - dilation) / 2)
698
+
699
+
700
+ class Backbone(nn.Module):
701
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
702
+
703
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
704
+ """
705
+ Args:
706
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
707
+ C denotes output features, and L is the sequence length.
708
+
709
+ Returns:
710
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
711
+ and H denotes the model dimension.
712
+ """
713
+ raise NotImplementedError("Subclasses must implement the forward method.")
714
+
715
+
716
+ class VocosBackbone(Backbone):
717
+ """
718
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
719
+
720
+ Args:
721
+ input_channels (int): Number of input features channels.
722
+ dim (int): Hidden dimension of the model.
723
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
724
+ num_layers (int): Number of ConvNeXtBlock layers.
725
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
726
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
727
+ None means non-conditional model. Defaults to None.
728
+ """
729
+
730
+ def __init__(
731
+ self,
732
+ input_channels: int,
733
+ dim: int,
734
+ intermediate_dim: int,
735
+ num_layers: int,
736
+ layer_scale_init_value: Optional[float] = None,
737
+ adanorm_num_embeddings: Optional[int] = None,
738
+ ):
739
+ super().__init__()
740
+ self.input_channels = input_channels
741
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
742
+ self.adanorm = adanorm_num_embeddings is not None
743
+ if adanorm_num_embeddings:
744
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
745
+ else:
746
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
747
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
748
+ self.convnext = nn.ModuleList(
749
+ [
750
+ ConvNeXtBlock(
751
+ dim=dim,
752
+ intermediate_dim=intermediate_dim,
753
+ layer_scale_init_value=layer_scale_init_value,
754
+ adanorm_num_embeddings=adanorm_num_embeddings,
755
+ )
756
+ for _ in range(num_layers)
757
+ ]
758
+ )
759
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
760
+ self.apply(self._init_weights)
761
+
762
+ def _init_weights(self, m):
763
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
764
+ nn.init.trunc_normal_(m.weight, std=0.02)
765
+ nn.init.constant_(m.bias, 0)
766
+
767
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
768
+ bandwidth_id = kwargs.get("bandwidth_id", None)
769
+ x = self.embed(x)
770
+ if self.adanorm:
771
+ assert bandwidth_id is not None
772
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
773
+ else:
774
+ x = self.norm(x.transpose(1, 2))
775
+ x = x.transpose(1, 2)
776
+ for conv_block in self.convnext:
777
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
778
+ x = self.final_layer_norm(x.transpose(1, 2))
779
+ return x
780
+
781
+
782
+ class VocosResNetBackbone(Backbone):
783
+ """
784
+ Vocos backbone module built with ResBlocks.
785
+
786
+ Args:
787
+ input_channels (int): Number of input features channels.
788
+ dim (int): Hidden dimension of the model.
789
+ num_blocks (int): Number of ResBlock1 blocks.
790
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
791
+ """
792
+
793
+ def __init__(
794
+ self,
795
+ input_channels,
796
+ dim,
797
+ num_blocks,
798
+ layer_scale_init_value=None,
799
+ ):
800
+ super().__init__()
801
+ self.input_channels = input_channels
802
+ self.embed = weight_norm(
803
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
804
+ )
805
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
806
+ self.resnet = nn.Sequential(
807
+ *[
808
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
809
+ for _ in range(num_blocks)
810
+ ]
811
+ )
812
+
813
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
814
+ x = self.embed(x)
815
+ x = self.resnet(x)
816
+ x = x.transpose(1, 2)
817
+ return x
818
+
819
+
820
+ class Vocos(nn.Module):
821
+ def __init__(
822
+ self,
823
+ input_channels: int = 256,
824
+ dim: int = 384,
825
+ intermediate_dim: int = 1152,
826
+ num_layers: int = 8,
827
+ n_fft: int = 800,
828
+ hop_size: int = 200,
829
+ padding: str = "same",
830
+ adanorm_num_embeddings=None,
831
+ cfg=None,
832
+ ):
833
+ super().__init__()
834
+
835
+ input_channels = (
836
+ cfg.input_channels
837
+ if cfg is not None and hasattr(cfg, "input_channels")
838
+ else input_channels
839
+ )
840
+ dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
841
+ intermediate_dim = (
842
+ cfg.intermediate_dim
843
+ if cfg is not None and hasattr(cfg, "intermediate_dim")
844
+ else intermediate_dim
845
+ )
846
+ num_layers = (
847
+ cfg.num_layers
848
+ if cfg is not None and hasattr(cfg, "num_layers")
849
+ else num_layers
850
+ )
851
+ adanorm_num_embeddings = (
852
+ cfg.adanorm_num_embeddings
853
+ if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
854
+ else adanorm_num_embeddings
855
+ )
856
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
857
+ hop_size = (
858
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
859
+ )
860
+ padding = (
861
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
862
+ )
863
+
864
+ self.backbone = VocosBackbone(
865
+ input_channels=input_channels,
866
+ dim=dim,
867
+ intermediate_dim=intermediate_dim,
868
+ num_layers=num_layers,
869
+ adanorm_num_embeddings=adanorm_num_embeddings,
870
+ )
871
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
872
+
873
+ def forward(self, x):
874
+ x = self.backbone(x)
875
+ x = self.head(x)
876
+
877
+ return x[:, None, :]
878
+
879
+
880
+ # if __name__ == "__main__":
881
+ # vocos_model = Vocos(
882
+ # input_channels=128,
883
+ # dim=1024,
884
+ # intermediate_dim=4096,
885
+ # adanorm_num_embeddings=None,
886
+ # num_layers=40,
887
+ # n_fft=1920,
888
+ # hop_size=480,
889
+ # padding="same"
890
+ # )
891
+ # mel_model = MelSpectrogram(
892
+ # 1920,
893
+ # 128,
894
+ # 24000,
895
+ # 480,
896
+ # 1920,
897
+ # 0,
898
+ # 12000
899
+ # )
900
+ # print(sum(p.numel() for p in vocos_model.parameters())/1e6)
901
+
902
+ # speech = librosa.load("/mnt/bn/yuacnwang-speech/dataset/yuanshen/ganyu_en.wav", sr=24000)[0][:36000]
903
+ # speech = torch.tensor(speech).unsqueeze(0)
904
+ # print(speech.shape)
905
+ # mel_feat = mel_model(speech)
906
+ # print(mel_feat.shape)
907
+
908
+ # rec_speech = vocos_model(mel_feat)
909
+ # print(rec_speech.shape)
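A minimal inference sketch (not part of the commit) for the Vocos vocoder above. The 128-mel / n_fft=1920 / hop_size=480 settings are taken from the commented-out example at the end of the file (24 kHz audio); the backbone sizes here are just the constructor defaults. With the "same"-padding ISTFT head, every input frame becomes hop_size output samples.

import torch
from models.codec.amphion_codec.vocos import Vocos

vocoder = Vocos(
    input_channels=128, dim=384, intermediate_dim=1152, num_layers=8,
    n_fft=1920, hop_size=480, padding="same",
)
vocoder.eval()
mel = torch.randn(1, 128, 100)  # (B, input_channels, frames)
with torch.no_grad():
    audio = vocoder(mel)        # ConvNeXt backbone -> ISTFTHead
print(audio.shape)              # torch.Size([1, 1, 48000]) == frames * hop_size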
models/codec/melvqgan/__pycache__/melspec.cpython-310.pyc ADDED
Binary file (2.78 kB).
 
models/codec/melvqgan/melspec.py ADDED
@@ -0,0 +1,153 @@
1
+ import torch
2
+ import pyworld as pw
3
+ import numpy as np
4
+ import soundfile as sf
5
+ import os
6
+ from torchaudio.functional import pitch_shift
7
+ import librosa
8
+ from librosa.filters import mel as librosa_mel_fn
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import tqdm
12
+
13
+
14
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
15
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
16
+
17
+
18
+ def dynamic_range_decompression(x, C=1):
19
+ return np.exp(x) / C
20
+
21
+
22
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
23
+ return torch.log(torch.clamp(x, min=clip_val) * C)
24
+
25
+
26
+ def dynamic_range_decompression_torch(x, C=1):
27
+ return torch.exp(x) / C
28
+
29
+
30
+ def spectral_normalize_torch(magnitudes):
31
+ output = dynamic_range_compression_torch(magnitudes)
32
+ return output
33
+
34
+
35
+ def spectral_de_normalize_torch(magnitudes):
36
+ output = dynamic_range_decompression_torch(magnitudes)
37
+ return output
38
+
39
+
40
+ class MelSpectrogram(nn.Module):
41
+ def __init__(
42
+ self,
43
+ n_fft,
44
+ num_mels,
45
+ sampling_rate,
46
+ hop_size,
47
+ win_size,
48
+ fmin,
49
+ fmax,
50
+ center=False,
51
+ ):
52
+ super(MelSpectrogram, self).__init__()
53
+ self.n_fft = n_fft
54
+ self.hop_size = hop_size
55
+ self.win_size = win_size
56
+ self.sampling_rate = sampling_rate
57
+ self.num_mels = num_mels
58
+ self.fmin = fmin
59
+ self.fmax = fmax
60
+ self.center = center
61
+
62
+ mel_basis = {}
63
+ hann_window = {}
64
+
65
+ mel = librosa_mel_fn(
66
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
67
+ )
68
+ mel_basis = torch.from_numpy(mel).float()
69
+ hann_window = torch.hann_window(win_size)
70
+
71
+ self.register_buffer("mel_basis", mel_basis)
72
+ self.register_buffer("hann_window", hann_window)
73
+
74
+ def forward(self, y):
75
+ y = torch.nn.functional.pad(
76
+ y.unsqueeze(1),
77
+ (
78
+ int((self.n_fft - self.hop_size) / 2),
79
+ int((self.n_fft - self.hop_size) / 2),
80
+ ),
81
+ mode="reflect",
82
+ )
83
+ y = y.squeeze(1)
84
+ spec = torch.stft(
85
+ y,
86
+ self.n_fft,
87
+ hop_length=self.hop_size,
88
+ win_length=self.win_size,
89
+ window=self.hann_window,
90
+ center=self.center,
91
+ pad_mode="reflect",
92
+ normalized=False,
93
+ onesided=True,
94
+ return_complex=True,
95
+ )
96
+ spec = torch.view_as_real(spec)
97
+
98
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
99
+
100
+ spec = torch.matmul(self.mel_basis, spec)
101
+ spec = spectral_normalize_torch(spec)
102
+
103
+ return spec
104
+
105
+
106
+ # if __name__ == "__main__":
107
+ # mel_model = MelSpectrogram(
108
+ # sampling_rate=24000,
109
+ # num_mels=128,
110
+ # hop_size=480,
111
+ # n_fft=1920,
112
+ # win_size=1920,
113
+ # fmin=0,
114
+ # fmax=12000,
115
+ # )
116
+ # mel_model = mel_model.to("cuda:0")
117
+
118
+ # x_len = 0
119
+ # var = 0
120
+ # mean = 0
121
+
122
+
123
+ # from utils.util import load_config
124
+ # cfg = load_config("/storage/wyc/SpeechGeneration/egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json")
125
+ # print(cfg)
126
+ # dataset = SoundStormDataset(AK, SK, bucket_name, cfg=cfg)
127
+ # print(dataset.__getitem__(0))
128
+
129
+ # idx_list = list(range(0, len(dataset)))
130
+ # np.random.shuffle(idx_list)
131
+ # for i in tqdm.tqdm(range(10000)):
132
+ # idx = idx_list[i]
133
+ # # data_path = dataset[idx]
134
+ # # speech, _ = librosa.load(data_path, sr=24000)
135
+ # speech = dataset[idx]["speech"]
136
+ # speech = torch.tensor(speech).unsqueeze(0).to("cuda")
137
+ # mel = mel_model(speech)
138
+ # temp_len = mel.shape[-1]
139
+ # temp_mean = mel.mean()
140
+ # temp_var = mel.var()
141
+
142
+ # new_mean = (mean * x_len + temp_mean * temp_len) / (x_len + temp_len)
143
+ # new_var = (var * (x_len - 1) + temp_var * (temp_len - 1) + x_len * (new_mean - mean)**2 + temp_len * (new_mean - temp_mean)**2) / (x_len + temp_len -1)
144
+
145
+ # x_len += temp_len
146
+ # mean = new_mean
147
+ # var = new_var
148
+
149
+ # if i % 100 == 0:
150
+ # print(mean, var)
151
+
152
+ # print(mean)
153
+ # print(var)
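A minimal sketch (not part of the commit) of extracting features with the MelSpectrogram module above, using the 24 kHz settings from its commented-out example. With center=False and the reflect padding applied in forward, a waveform of T samples yields T // hop_size frames.

import torch
from models.codec.melvqgan.melspec import MelSpectrogram

mel_model = MelSpectrogram(
    n_fft=1920, num_mels=128, sampling_rate=24000,
    hop_size=480, win_size=1920, fmin=0, fmax=12000,
)
wav = torch.randn(1, 48000)  # (B, T): two seconds of 24 kHz audio
with torch.no_grad():
    mel = mel_model(wav)     # log-compressed mel spectrogram
print(mel.shape)             # torch.Size([1, 128, 100])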
models/tts/llm_tts/__pycache__/chat_template.cpython-310.pyc ADDED
Binary file (2.36 kB).
 
models/tts/llm_tts/__pycache__/inference_llm_tts.cpython-310.pyc ADDED
Binary file (6.95 kB).
 
models/tts/llm_tts/__pycache__/inference_mgm_tts.cpython-310.pyc ADDED
Binary file (8.31 kB). View file
 
models/tts/llm_tts/__pycache__/llama_nar_prefix.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
models/tts/llm_tts/__pycache__/mgm.cpython-310.pyc ADDED
Binary file (8.05 kB). View file
 
models/tts/llm_tts/chat_template.py ADDED
@@ -0,0 +1,96 @@
+ def format_chat_prompt_phi3(messages, add_assistant_token=True):
+     """
+     Convert the messages list into the phi-3 chat template format.
+
+     Args:
+         messages: A list of messages containing role and content.
+
+     Returns:
+         str: The formatted prompt string.
+     """
+     prompt = ""
+     for msg in messages:
+         role = msg["role"]
+         content = msg["content"]
+         # Add corresponding tags for system and user messages
+         if role in ["system", "user"]:
+             prompt += f"<|{role}|>\n{content}<|end|>\n"
+         # For assistant messages, add only the start tag if it's the last one
+         elif role == "assistant" and msg != messages[-1]:
+             prompt += f"<|{role}|>\n{content}<|end|>\n"
+         elif role == "assistant" and msg == messages[-1]:
+             prompt += f"<|{role}|>\n{content}"
+
+     # If the last message is not from the assistant, add the assistant tag
+     if messages[-1]["role"] != "assistant" and add_assistant_token:
+         prompt += "<|assistant|>"
+     return prompt
+
+
+ def format_chat_prompt_qwen2(messages, add_assistant_token=True):
+     """
+     Custom function to format chat prompts without tool-related logic.
+
+     Args:
+         messages: A list of messages containing role and content.
+         add_assistant_token: Boolean to add a generation prompt at the end.
+
+     Returns:
+         str: The formatted prompt string.
+     """
+     prompt = ""
+
+     if messages[0]["role"] == "system":
+         prompt += f"<|im_start|>system\n{messages[0]['content']}<|im_end|>\n"
+     else:
+         prompt += "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
+
+     for message in messages:
+         if (
+             (message["role"] == "user")
+             or (message["role"] == "system" and not message == messages[0])
+             or (message["role"] == "assistant" and not message.get("tool_calls"))
+         ):
+             prompt += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
+         elif message["role"] == "assistant":
+             prompt += f"<|im_start|>{message['role']}"
+             if message.get("content"):
+                 prompt += f"\n{message['content']}"
+             prompt += "<|im_end|>\n"
+
+     if add_assistant_token:
+         prompt += "<|im_start|>assistant\n"
+
+     return prompt
+
+
+ def gen_chat_prompt_for_tts(text, model_name="phi-3", caption=None):
+     if caption is None:
+         template = [
+             {
+                 "role": "system",
+                 "content": "You are a powerful AI assistant for speech understanding and generation.",
+             },
+             {
+                 "role": "user",
+                 "content": f"Please speak the following text out loud: {text}",
+             },
+         ]
+     else:
+         template = [
+             {
+                 "role": "system",
+                 "content": "You are a powerful AI assistant for speech understanding and generation.",
+             },
+             {
+                 "role": "user",
+                 "content": f"Please follow the caption: <|start_of_caption|>{caption}<|end_of_caption|> and speak the following text out loud: {text}",
+             },
+         ]
+
+     if model_name == "phi-3":
+         return format_chat_prompt_phi3(template)
+     elif model_name == "qwen2":
+         return format_chat_prompt_qwen2(template)
+     else:
+         raise ValueError(f"Unsupported model: {model_name}")
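A quick sketch of what these helpers produce; the expected output is spelled out in the comments, and "Hello world" is only an illustrative input:

from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts

prompt = gen_chat_prompt_for_tts("Hello world", model_name="qwen2")
# <|im_start|>system
# You are a powerful AI assistant for speech understanding and generation.<|im_end|>
# <|im_start|>user
# Please speak the following text out loud: Hello world<|im_end|>
# <|im_start|>assistant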
models/tts/llm_tts/inference_llm_tts.py ADDED
@@ -0,0 +1,265 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import os
6
+ from huggingface_hub import snapshot_download
7
+
8
+ from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline
9
+ from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts
10
+
11
+
12
+ class TTSInferencePipeline(nn.Module):
13
+ """
14
+ TTS inference pipeline that integrates TaDiCodec and LLM models
15
+ Uses standard LLM for autoregressive generation
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ tadicodec_path: str,
21
+ llm_path: str,
22
+ device: torch.device,
23
+ ):
24
+ super().__init__()
25
+ self.device = device
26
+ self.llm_path = llm_path
27
+
28
+ # Load TaDiCodec pipeline
29
+ self.tadicodec = TaDiCodecPipline.from_pretrained(
30
+ ckpt_dir=tadicodec_path, device=device
31
+ )
32
+
33
+ # Load LLM directly from pretrained
34
+ # Try to use flash attention 2, fallback to default if not available
35
+ try:
36
+ self.llm = AutoModelForCausalLM.from_pretrained(
37
+ llm_path,
38
+ device_map=device,
39
+ torch_dtype="auto",
40
+ trust_remote_code=True,
41
+ attn_implementation="flash_attention_2",
42
+ )
43
+ except Exception as e:
44
+ print(f"Flash attention 2 not available, using default attention: {e}")
45
+ self.llm = AutoModelForCausalLM.from_pretrained(
46
+ llm_path,
47
+ device_map=device,
48
+ torch_dtype="auto",
49
+ trust_remote_code=True,
50
+ )
51
+
52
+ # Load tokenizer directly from pretrained
53
+ self.tokenizer = AutoTokenizer.from_pretrained(
54
+ llm_path,
55
+ trust_remote_code=True,
56
+ )
57
+
58
+ def tensor_to_audio_string(self, tensor):
59
+ """Convert tensor to audio string format"""
60
+ if isinstance(tensor, list) and isinstance(tensor[0], list):
61
+ values = tensor[0]
62
+ else:
63
+ values = tensor[0].tolist() if hasattr(tensor, "tolist") else tensor[0]
64
+
65
+ result = "<|start_of_audio|>"
66
+ for value in values:
67
+ result += f"<|audio_{value}|>"
68
+ return result
69
+
70
+ def extract_audio_ids(self, text):
71
+ """Extract audio IDs from string containing audio tokens"""
72
+ import re
73
+
74
+ pattern = r"<\|audio_(\d+)\|>"
75
+ audio_ids = re.findall(pattern, text)
76
+ return [int(id) for id in audio_ids]
77
+
78
+ @classmethod
79
+ def from_pretrained(
80
+ cls,
81
+ model_id: str = None,
82
+ tadicodec_path: str = None,
83
+ llm_path: str = None,
84
+ device: Optional[torch.device] = None,
85
+ auto_download: bool = True,
86
+ ):
87
+ """
88
+ Create pipeline from pretrained models
89
+
90
+ Args:
91
+ model_id: Hugging Face model ID for the LLM (e.g., "amphion/TaDiCodec-TTS-AR-Qwen2.5-3B")
92
+ tadicodec_path: Path to TaDiCodec model or Hugging Face model ID (defaults to "amphion/TaDiCodec")
93
+ llm_path: Path to LLM model or Hugging Face model ID (overrides model_id if provided)
94
+ device: Device to run on
95
+ auto_download: Whether to automatically download models from Hugging Face if not found locally
96
+
97
+ Returns:
98
+ TTSInferencePipeline instance
99
+ """
100
+ resolved_device = (
101
+ device
102
+ if device is not None
103
+ else torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
+ )
105
+
106
+ # Set default paths if not provided
107
+ if tadicodec_path is None:
108
+ tadicodec_path = "amphion/TaDiCodec"
109
+
110
+ if llm_path is None:
111
+ if model_id is not None:
112
+ llm_path = model_id
113
+ else:
114
+ llm_path = "./ckpt/TaDiCodec-TTS-AR-Qwen2.5-0.5B"
115
+
116
+ # Handle TaDiCodec path
117
+ resolved_tadicodec_path = cls._resolve_model_path(
118
+ tadicodec_path, auto_download=auto_download, model_type="tadicodec"
119
+ )
120
+
121
+ # Handle LLM path
122
+ resolved_llm_path = cls._resolve_model_path(
123
+ llm_path, auto_download=auto_download, model_type="llm"
124
+ )
125
+
126
+ return cls(
127
+ tadicodec_path=resolved_tadicodec_path,
128
+ llm_path=resolved_llm_path,
129
+ device=resolved_device,
130
+ )
131
+
132
+ @staticmethod
133
+ def _resolve_model_path(
134
+ model_path: str, auto_download: bool = True, model_type: str = "llm"
135
+ ) -> str:
136
+ """
137
+ Resolve model path, downloading from Hugging Face if necessary
138
+
139
+ Args:
140
+ model_path: Local path or Hugging Face model ID
141
+ auto_download: Whether to auto-download from HF
142
+ model_type: Type of model ("llm" or "tadicodec")
143
+
144
+ Returns:
145
+ Resolved local path
146
+ """
147
+ # If it's already a local path and exists, return as is
148
+ if os.path.exists(model_path):
149
+ return model_path
150
+
151
+ # If it looks like a Hugging Face model ID (contains '/')
152
+ if "/" in model_path and auto_download:
153
+ print(f"Downloading {model_type} model from Hugging Face: {model_path}")
154
+ try:
155
+ # Download to cache directory
156
+ cache_dir = os.path.join(
157
+ os.path.expanduser("~"), ".cache", "huggingface", "hub"
158
+ )
159
+ downloaded_path = snapshot_download(
160
+ repo_id=model_path,
161
+ cache_dir=cache_dir,
162
+ local_dir_use_symlinks=False,
163
+ )
164
+ print(
165
+ f"Successfully downloaded {model_type} model to: {downloaded_path}"
166
+ )
167
+ return downloaded_path
168
+ except Exception as e:
169
+ print(f"Failed to download {model_type} model from Hugging Face: {e}")
170
+ raise ValueError(
171
+ f"Could not download {model_type} model from {model_path}"
172
+ )
173
+
174
+ # If it's a local path that doesn't exist
175
+ if not os.path.exists(model_path):
176
+ if auto_download:
177
+ raise ValueError(
178
+ f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
179
+ )
180
+ else:
181
+ raise FileNotFoundError(f"Model path does not exist: {model_path}")
182
+
183
+ return model_path
184
+
185
+ @torch.no_grad()
186
+ def __call__(
187
+ self,
188
+ text: str,
189
+ prompt_text: Optional[str] = None,
190
+ prompt_speech_path: Optional[str] = None,
191
+ top_k: int = 50,
192
+ top_p: float = 0.98,
193
+ temperature: float = 1.0,
194
+ n_timesteps: int = 25,
195
+ return_code: bool = False,
196
+ ):
197
+ """
198
+ Perform TTS inference
199
+
200
+ Args:
201
+ text: Target text to synthesize
202
+ prompt_text: Prompt text for conditioning
203
+ prompt_speech_path: Path to prompt audio file
204
+ top_k: Top-k sampling parameter
205
+ top_p: Top-p sampling parameter
206
+ temperature: Temperature for sampling
207
+ n_timesteps: Number of diffusion timesteps
208
+ return_code: Whether to return audio codes instead of audio
209
+
210
+ Returns:
211
+ Generated audio array or audio codes
212
+ """
213
+ # Get prompt audio codes
214
+ if prompt_speech_path:
215
+ prompt_speech_code = self.tadicodec(
216
+ speech_path=prompt_speech_path, return_code=True, text=""
217
+ )
218
+ else:
219
+ raise ValueError("prompt_speech_path is required")
220
+
221
+ # Use standard LLM for autoregressive generation
222
+ # TODO: add a json style chat prompt template
223
+ prompt = gen_chat_prompt_for_tts(
224
+ (prompt_text or "") + text,
225
+ "phi-3" if "Phi" in self.llm_path else "qwen2",
226
+ ) + self.tensor_to_audio_string(prompt_speech_code)
227
+
228
+ input_ids = self.tokenizer.encode(prompt)
229
+ generate_ids = self.llm.generate(
230
+ input_ids=torch.tensor(input_ids).unsqueeze(0).to(self.device),
231
+ min_new_tokens=12,
232
+ max_new_tokens=400,
233
+ do_sample=True,
234
+ top_k=top_k,
235
+ top_p=top_p,
236
+ temperature=temperature,
237
+ )
238
+
239
+ output = self.tokenizer.decode(generate_ids[0], skip_special_tokens=False)
240
+
241
+ combine_speech_code = self.extract_audio_ids(output)
242
+ indices = torch.tensor(combine_speech_code).unsqueeze(0).long().to(self.device)
243
+
244
+ if return_code:
245
+ return indices
246
+
247
+ # Decode audio
248
+ text_token_ids = self.tadicodec.tokenize_text(text, prompt_text)
249
+ prompt_mel = self.tadicodec.extract_mel_feature(prompt_speech_path)
250
+
251
+ rec_mel = self.tadicodec.decode(
252
+ indices=indices,
253
+ text_token_ids=text_token_ids,
254
+ prompt_mel=prompt_mel,
255
+ n_timesteps=n_timesteps,
256
+ )
257
+
258
+ rec_audio = (
259
+ self.tadicodec.vocoder_model(rec_mel.transpose(1, 2))
260
+ .detach()
261
+ .cpu()
262
+ .numpy()[0][0]
263
+ )
264
+
265
+ return rec_audio
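This pipeline ships without a __main__ usage block, unlike inference_mgm_tts.py below; a hedged sketch of the equivalent call, reusing the prompt audio and texts from the MGM example (the output path and the 0.5B checkpoint directory are placeholders that match the defaults in from_pretrained):

import torch
import soundfile as sf
from models.tts.llm_tts.inference_llm_tts import TTSInferencePipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = TTSInferencePipeline.from_pretrained(
    tadicodec_path="./ckpt/TaDiCodec",
    llm_path="./ckpt/TaDiCodec-TTS-AR-Qwen2.5-0.5B",
    device=device,
)
audio = pipeline(
    text="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit.",
    prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
    prompt_speech_path="./use_examples/test_audio/trump_0.wav",
)
sf.write("./use_examples/test_audio/llm_tts_output.wav", audio, 24000)  # 24 kHz output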
models/tts/llm_tts/inference_mgm_tts.py ADDED
@@ -0,0 +1,338 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import librosa
5
+ from huggingface_hub import snapshot_download
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import Optional
10
+ import safetensors
11
+ from transformers import AutoTokenizer
12
+ from utils.util import load_config
13
+
14
+ from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline
15
+ from models.tts.llm_tts.mgm import MGMT2S
16
+
17
+ from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts
18
+
19
+
20
+ class MGMInferencePipeline(nn.Module):
21
+ """
22
+ MGM TTS inference pipeline that integrates TaDiCodec and MGM models
23
+ Uses diffusion-based generation with mask-guided modeling
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ tadicodec_path: str,
29
+ mgm_path: str,
30
+ device: torch.device,
31
+ ):
32
+ super().__init__()
33
+ self.device = device
34
+ self.mgm_path = mgm_path
35
+
36
+ # Load TaDiCodec pipeline
37
+ self.tadicodec = TaDiCodecPipline.from_pretrained(
38
+ ckpt_dir=tadicodec_path, device=device
39
+ )
40
+
41
+ # Load tokenizer directly from pretrained
42
+ self.tokenizer = AutoTokenizer.from_pretrained(
43
+ mgm_path,
44
+ trust_remote_code=True,
45
+ )
46
+
47
+ config_path = os.path.join(mgm_path, "config.json")
48
+ if not os.path.exists(config_path):
49
+ raise FileNotFoundError(f"Config file not found at {config_path}")
50
+
51
+ self.cfg = load_config(config_path)
52
+
53
+ # Extract MGM config from the loaded config
54
+ mgm_config = self.cfg.model.mgmt2s
55
+ if not mgm_config:
56
+ raise ValueError("MGM config not found in config.json")
57
+
58
+ # Load MGM model with config - using the same pattern as llm_infer_eval.py
59
+ self.mgm = MGMT2S(
60
+ hidden_size=mgm_config.hidden_size,
61
+ num_layers=mgm_config.num_layers,
62
+ num_heads=mgm_config.num_heads,
63
+ cfg_scale=mgm_config.cfg_scale,
64
+ cond_codebook_size=mgm_config.cond_codebook_size,
65
+ cond_dim=mgm_config.cond_dim,
66
+ phone_vocab_size=mgm_config.phone_vocab_size,
67
+ )
68
+
69
+ # Load model weights
70
+ model_path = os.path.join(mgm_path, "model.safetensors")
71
+
72
+ if os.path.exists(model_path):
73
+ safetensors.torch.load_model(self.mgm, model_path, strict=True)
74
+ else:
75
+ # Try loading from the directory directly
76
+ safetensors.torch.load_model(self.mgm, mgm_path, strict=True)
77
+
78
+ self.mgm.to(device)
79
+ self.mgm.eval()
80
+
81
+ def tensor_to_audio_string(self, tensor):
82
+ """Convert tensor to audio string format"""
83
+ if isinstance(tensor, list) and isinstance(tensor[0], list):
84
+ values = tensor[0]
85
+ else:
86
+ values = tensor[0].tolist() if hasattr(tensor, "tolist") else tensor[0]
87
+
88
+ result = "<|start_of_audio|>"
89
+ for value in values:
90
+ result += f"<|audio_{value}|>"
91
+ return result
92
+
93
+ def extract_audio_ids(self, text):
94
+ """Extract audio IDs from string containing audio tokens"""
95
+ import re
96
+
97
+ pattern = r"<\|audio_(\d+)\|>"
98
+ audio_ids = re.findall(pattern, text)
99
+ return [int(id) for id in audio_ids]
100
+
101
+ @classmethod
102
+ def from_pretrained(
103
+ cls,
104
+ model_id: str = None,
105
+ tadicodec_path: str = None,
106
+ mgm_path: str = None,
107
+ device: Optional[torch.device] = None,
108
+ auto_download: bool = True,
109
+ ):
110
+ """
111
+ Create pipeline from pretrained models
112
+
113
+ Args:
114
+ model_id: Hugging Face model ID for the MGM model (e.g., "amphion/TaDiCodec-TTS-MGM")
115
+ tadicodec_path: Path to TaDiCodec model or Hugging Face model ID (defaults to "amphion/TaDiCodec")
116
+ mgm_path: Path to MGM model directory or Hugging Face model ID (overrides model_id if provided)
117
+ device: Device to run on
118
+ auto_download: Whether to automatically download models from Hugging Face if not found locally
119
+
120
+ Returns:
121
+ MGMInferencePipeline instance
122
+ """
123
+ resolved_device = (
124
+ device
125
+ if device is not None
126
+ else torch.device("cuda" if torch.cuda.is_available() else "cpu")
127
+ )
128
+
129
+ # Set default paths if not provided
130
+ if tadicodec_path is None:
131
+ tadicodec_path = "amphion/TaDiCodec"
132
+
133
+ if mgm_path is None:
134
+ if model_id is not None:
135
+ mgm_path = model_id
136
+ else:
137
+ mgm_path = "./ckpt/TaDiCodec-TTS-MGM"
138
+
139
+ # Handle TaDiCodec path
140
+ resolved_tadicodec_path = cls._resolve_model_path(
141
+ tadicodec_path, auto_download=auto_download, model_type="tadicodec"
142
+ )
143
+
144
+ # Handle MGM path
145
+ resolved_mgm_path = cls._resolve_model_path(
146
+ mgm_path, auto_download=auto_download, model_type="mgm"
147
+ )
148
+
149
+ return cls(
150
+ tadicodec_path=resolved_tadicodec_path,
151
+ mgm_path=resolved_mgm_path,
152
+ device=resolved_device,
153
+ )
154
+
155
+ @staticmethod
156
+ def _resolve_model_path(
157
+ model_path: str, auto_download: bool = True, model_type: str = "mgm"
158
+ ) -> str:
159
+ """
160
+ Resolve model path, downloading from Hugging Face if necessary
161
+
162
+ Args:
163
+ model_path: Local path or Hugging Face model ID
164
+ auto_download: Whether to auto-download from HF
165
+ model_type: Type of model ("mgm" or "tadicodec")
166
+
167
+ Returns:
168
+ Resolved local path
169
+ """
170
+ # If it's already a local path and exists, return as is
171
+ if os.path.exists(model_path):
172
+ return model_path
173
+
174
+ # If it looks like a Hugging Face model ID (contains '/')
175
+ if "/" in model_path and auto_download:
176
+ print(f"Downloading {model_type} model from Hugging Face: {model_path}")
177
+ try:
178
+ # Download to cache directory
179
+ cache_dir = os.path.join(
180
+ os.path.expanduser("~"), ".cache", "huggingface", "hub"
181
+ )
182
+ downloaded_path = snapshot_download(
183
+ repo_id=model_path,
184
+ cache_dir=cache_dir,
185
+ local_dir_use_symlinks=False,
186
+ )
187
+ print(
188
+ f"Successfully downloaded {model_type} model to: {downloaded_path}"
189
+ )
190
+ return downloaded_path
191
+ except Exception as e:
192
+ print(f"Failed to download {model_type} model from Hugging Face: {e}")
193
+ raise ValueError(
194
+ f"Could not download {model_type} model from {model_path}"
195
+ )
196
+
197
+ # If it's a local path that doesn't exist
198
+ if not os.path.exists(model_path):
199
+ if auto_download:
200
+ raise ValueError(
201
+ f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
202
+ )
203
+ else:
204
+ raise FileNotFoundError(f"Model path does not exist: {model_path}")
205
+
206
+ return model_path
207
+
208
+ @torch.no_grad()
209
+ def __call__(
210
+ self,
211
+ text: str,
212
+ prompt_text: Optional[str] = None,
213
+ prompt_speech_path: Optional[str] = None,
214
+ n_timesteps_mgm: int = 25,
215
+ n_timesteps: int = 25,
216
+ target_len: Optional[int] = None,
217
+ return_code: bool = False,
218
+ ):
219
+ """
220
+ Perform MGM TTS inference
221
+
222
+ Args:
223
+ text: Target text to synthesize
224
+ prompt_text: Prompt text for conditioning
225
+ prompt_speech_path: Path to prompt audio file
226
+ n_timesteps_mgm: Number of diffusion timesteps for MGM
227
+ n_timesteps: Number of diffusion timesteps
228
+ target_len: Target length for audio generation
229
+ return_code: Whether to return audio codes instead of audio
230
+
231
+ Returns:
232
+ Generated audio array or audio codes
233
+ """
234
+ # Get prompt audio codes
235
+ if prompt_speech_path:
236
+ prompt_speech_code = self.tadicodec(
237
+ speech_path=prompt_speech_path, return_code=True, text=""
238
+ )
239
+ else:
240
+ raise ValueError("prompt_speech_path is required")
241
+
242
+ # Convert prompt codes to tensor
243
+ prompt_codes = torch.tensor(prompt_speech_code).to(self.device)
244
+ prompt_len = prompt_codes.shape[1]
245
+
246
+ # Tokenize text for phone conditioning
247
+ input_text = gen_chat_prompt_for_tts(
248
+ prompt_text + " " + text,
249
+ "phi-3" if "phi" in self.cfg.preprocess.tokenizer_path else "qwen2",
250
+ )
251
+
252
+ ##### debug #####
253
+ print("input_text: ", input_text)
254
+ ##### debug #####
255
+
256
+ text_token_ids = self.tokenizer.encode(input_text)
257
+ text_token_ids = torch.tensor(text_token_ids).unsqueeze(0).to(self.device)
258
+
259
+ # Estimate target length based on text length
260
+ frame_rate = getattr(self.cfg.preprocess, "frame_rate", 6.25)
261
+
262
+ if target_len is None:
263
+ # If no target_len, estimate based on prompt speech length and text ratio
264
+ prompt_text_len = len(prompt_text.encode("utf-8"))
265
+ target_text_len = len(text.encode("utf-8"))
266
+ prompt_speech_len = librosa.get_duration(filename=prompt_speech_path)
267
+ target_speech_len = prompt_speech_len * target_text_len / prompt_text_len
268
+ target_len = int(target_speech_len * frame_rate)
269
+ else:
270
+ # If target_len is provided, use it directly
271
+ target_len = int(target_len * frame_rate)
272
+
273
+ ##### debug #####
274
+ print(f"Prompt length: {prompt_len}, Target length: {target_len}")
275
+ print(f"Text: {text}")
276
+ print(f"Prompt text: {prompt_text}")
277
+ ##### debug #####
278
+
279
+ # Generate audio codes using MGM reverse diffusion
280
+ generated_codes = self.mgm.reverse_diffusion(
281
+ prompt=prompt_codes,
282
+ target_len=target_len,
283
+ phone_id=text_token_ids,
284
+ n_timesteps=n_timesteps_mgm,
285
+ cfg=1.5,
286
+ rescale_cfg=0.75,
287
+ )
288
+
289
+ print(f"Generated codes shape: {generated_codes.shape}")
290
+
291
+ combine_codes = torch.cat([prompt_codes, generated_codes], dim=1)
292
+
293
+ if return_code:
294
+ return combine_codes
295
+
296
+ # Decode audio using TaDiCodec
297
+ prompt_mel = self.tadicodec.extract_mel_feature(prompt_speech_path)
298
+
299
+ text_token_ids = self.tadicodec.tokenize_text(text, prompt_text)
300
+ rec_mel = self.tadicodec.decode(
301
+ indices=combine_codes,
302
+ text_token_ids=text_token_ids,
303
+ prompt_mel=prompt_mel,
304
+ n_timesteps=n_timesteps,
305
+ )
306
+
307
+ rec_audio = (
308
+ self.tadicodec.vocoder_model(rec_mel.transpose(1, 2))
309
+ .detach()
310
+ .cpu()
311
+ .numpy()[0][0]
312
+ )
313
+
314
+ return rec_audio
315
+
316
+
317
+ # Usage example
318
+ if __name__ == "__main__":
319
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
320
+
321
+ # Create pipeline
322
+ pipeline = MGMInferencePipeline.from_pretrained(
323
+ tadicodec_path="./ckpt/TaDiCodec",
324
+ mgm_path="./ckpt/TaDiCodec-TTS-MGM",
325
+ device=device,
326
+ )
327
+
328
+ # Inference on single sample
329
+ audio = pipeline(
330
+ text="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit.",
331
+ prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
332
+ prompt_speech_path="./use_examples/test_audio/trump_0.wav",
333
+ )
334
+
335
+ # Save audio
336
+ import soundfile as sf
337
+
338
+ sf.write("./use_examples/test_audio/mgm_tts_output.wav", audio, 24000)
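The target-length heuristic in __call__ scales the prompt audio duration by the byte-length ratio of target text to prompt text, then converts seconds to codec tokens at the configured frame rate (6.25 tokens per second by default). A worked example with hypothetical numbers:

prompt_speech_len = 4.0                    # seconds of prompt audio (hypothetical)
prompt_text_len, target_text_len = 80, 40  # UTF-8 byte lengths (hypothetical)
frame_rate = 6.25                          # codec tokens per second of speech
target_len = int(prompt_speech_len * target_text_len / prompt_text_len * frame_rate)  # 12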
models/tts/llm_tts/llama_nar_prefix.py ADDED
@@ -0,0 +1,457 @@
1
+ from transformers import LlamaConfig, LlamaModel
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import List, Optional, Tuple, Union
5
+ import math
6
+
7
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
8
+ from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
9
+
10
+
11
+ class SinusoidalPosEmb(nn.Module):
12
+ """
13
+ Sinusoidal Positional Embedding module.
14
+
15
+ This module generates sinusoidal positional embeddings for a given 1D input tensor,
16
+ which is commonly used for representing timesteps in diffusion models.
17
+ """
18
+
19
+ def __init__(self, dim: int):
20
+ """
21
+ Initializes the SinusoidalPosEmb module.
22
+
23
+ Args:
24
+ dim (int): The dimension of the embedding.
25
+ """
26
+ super().__init__()
27
+ self.dim = dim
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ """
31
+ Generates the positional embedding.
32
+
33
+ Args:
34
+ x (torch.Tensor): A 1D tensor of positions (e.g., timesteps), shape `(batch_size,)`.
35
+
36
+ Returns:
37
+ torch.Tensor: The positional embeddings, shape `(batch_size, dim)`.
38
+ """
39
+ device = x.device
40
+ half_dim = self.dim // 2
41
+ # Calculate the embedding frequencies based on the log-space formula
42
+ emb = math.log(10000) / (half_dim - 1)
43
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
44
+ # Create the embedding matrix by multiplying positions with frequencies
45
+ emb = x[:, None] * emb[None, :] * 1.0
46
+ # Concatenate sine and cosine components to form the final embedding
47
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
48
+ return emb
49
+
50
+
51
+ class LlamaAdaptiveRMSNorm(nn.Module):
52
+ """
53
+ Adaptive Root Mean Square Layer Normalization.
54
+
55
+ This is a variant of RMSNorm where the scaling factor (weight) is adaptively
56
+ predicted from a conditional embedding, allowing for conditional modulation
57
+ of the normalized hidden states.
58
+ """
59
+
60
+ def __init__(
61
+ self, hidden_size: int = 1024, eps: float = 1e-6, dim_cond: int = 1024
62
+ ):
63
+ """
64
+ Initializes the LlamaAdaptiveRMSNorm module.
65
+
66
+ Args:
67
+ hidden_size (int): The dimension of the hidden states to be normalized.
68
+ eps (float): A small value added to the variance for numerical stability.
69
+ dim_cond (int): The dimension of the conditional embedding.
70
+ """
71
+ super().__init__()
72
+ # Linear layer to project the conditional embedding to the hidden size
73
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
74
+ # Initialize weights to zero and bias to one for an identity transformation at the start
75
+ nn.init.zeros_(self.to_weight.weight)
76
+ nn.init.ones_(self.to_weight.bias)
77
+ self.variance_epsilon = eps
78
+ # Disable automatic Hugging Face initialization for this custom module
79
+ self._is_hf_initialized = True
80
+
81
+ def forward(
82
+ self, hidden_states: torch.Tensor, cond_embedding: torch.Tensor
83
+ ) -> torch.Tensor:
84
+ """
85
+ Applies the adaptive RMS normalization.
86
+
87
+ Args:
88
+ hidden_states (torch.Tensor): The input tensor, shape `(batch, seq_len, hidden_size)`.
89
+ cond_embedding (torch.Tensor): The conditional embedding, shape `(batch, dim_cond)`.
90
+
91
+ Returns:
92
+ torch.Tensor: The normalized and modulated hidden states.
93
+ """
94
+ input_dtype = hidden_states.dtype
95
+ # Calculate variance and normalize the hidden states
96
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
97
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
98
+
99
+ # Predict the scaling factor from the conditional embedding
100
+ weight = self.to_weight(cond_embedding)
101
+ # Unsqueeze if the conditional embedding is per-batch instead of per-token
102
+ if len(weight.shape) == 2:
103
+ weight = weight.unsqueeze(1)
104
+
105
+ # Apply the learned scaling factor
106
+ return (weight * hidden_states).to(input_dtype)
107
+
108
+
109
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
110
+ """
111
+ A Non-Autoregressive (NAR) Llama Decoder Layer using adaptive layer normalization.
112
+
113
+ This class overrides the standard LlamaDecoderLayer to replace its RMSNorm
114
+ modules with LlamaAdaptiveRMSNorm, allowing it to be conditioned on an external embedding.
115
+ """
116
+
117
+ def __init__(self, config: LlamaConfig, layer_idx: int):
118
+ """Overrides the LlamaDecoderLayer to use adaptive layer normalization."""
119
+ super().__init__(config, layer_idx) # init attention, mlp, etc. from parent
120
+ self.layer_idx = layer_idx
121
+ # Override the standard layer norms with our adaptive versions
122
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
123
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
124
+ )
125
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
126
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
127
+ )
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.Tensor,
132
+ cond_embedding: torch.Tensor,
133
+ attention_mask: Optional[torch.Tensor] = None,
134
+ position_ids: Optional[torch.LongTensor] = None,
135
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
136
+ output_attentions: Optional[bool] = False,
137
+ use_cache: Optional[bool] = False,
138
+ ) -> Tuple[
139
+ torch.Tensor,
140
+ Optional[torch.Tensor],
141
+ Optional[Tuple[torch.Tensor, torch.Tensor]],
142
+ ]:
143
+ """
144
+ Forward pass for the NAR decoder layer, including conditional embedding.
145
+
146
+ Args:
147
+ hidden_states (torch.Tensor): Input to the layer of shape `(batch, seq_len, embed_dim)`.
148
+ cond_embedding (torch.Tensor): Conditional embedding for adaptive normalization.
149
+ attention_mask (Optional[torch.Tensor]): Attention mask of size `(batch, 1, tgt_len, src_len)`.
150
+ position_ids (Optional[torch.LongTensor]): Indices of positions of each input sequence tokens.
151
+ past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached past key and value projection states.
152
+ output_attentions (Optional[bool]): Whether to return the attention tensors.
153
+ use_cache (Optional[bool]): If True, past key values are returned to speed up decoding.
154
+
155
+ Returns:
156
+ Tuple containing the output hidden states, and optionally attention weights and past key/value states.
157
+ """
158
+ residual = hidden_states
159
+
160
+ # Apply adaptive pre-attention layer norm
161
+ hidden_states = self.input_layernorm(
162
+ hidden_states, cond_embedding=cond_embedding
163
+ )
164
+
165
+ # Self Attention block
166
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
167
+ hidden_states=hidden_states,
168
+ attention_mask=attention_mask,
169
+ position_ids=position_ids,
170
+ past_key_value=past_key_value,
171
+ output_attentions=output_attentions,
172
+ use_cache=use_cache,
173
+ )
174
+ hidden_states = residual + hidden_states
175
+
176
+ # Fully Connected block
177
+ residual = hidden_states
178
+ # Apply adaptive post-attention layer norm
179
+ hidden_states = self.post_attention_layernorm(
180
+ hidden_states, cond_embedding=cond_embedding
181
+ )
182
+ hidden_states = self.mlp(hidden_states)
183
+ hidden_states = residual + hidden_states
184
+
185
+ outputs = (hidden_states,)
186
+
187
+ if output_attentions:
188
+ outputs += (self_attn_weights,)
189
+
190
+ if use_cache:
191
+ outputs += (present_key_value,)
192
+
193
+ return outputs
194
+
195
+
196
+ class DiffLlamaPrefix(LlamaModel):
197
+ """
198
+ A Llama-based non-autoregressive transformer model for diffusion (masked generative modeling) tasks.
199
+
200
+ This model uses a Llama architecture but modifies it for non-autoregressive generation.
201
+ Key features:
202
+ 1. Non-causal (fully-visible) attention mask.
203
+ 2. Adaptive layer normalization conditioned on diffusion timesteps.
204
+ 3. Ability to be conditioned on phoneme (text) embeddings, which are prepended as a prefix.
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ hidden_size: int = 1024,
210
+ num_heads: int = 16,
211
+ num_layers: int = 16,
212
+ use_phone_cond: bool = True,
213
+ config: LlamaConfig = LlamaConfig(
214
+ vocab_size=0, hidden_size=256, num_attention_heads=1, num_hidden_layers=1
215
+ ),
216
+ ):
217
+ """
218
+ Initializes the DiffLlamaPrefix model.
219
+
220
+ Args:
221
+ hidden_size (int): The hidden dimension of the transformer.
222
+ num_heads (int): The number of attention heads.
223
+ num_layers (int): The number of transformer layers.
224
+ use_phone_cond (bool): Whether to use phoneme embeddings as a conditional prefix.
225
+ config (LlamaConfig): A LlamaConfig object. A default is provided for convenience.
226
+ """
227
+ super().__init__(config)
228
+
229
+ self.use_phone_cond = use_phone_cond
230
+
231
+ # Create a stack of non-autoregressive Llama layers
232
+ self.layers = nn.ModuleList(
233
+ [
234
+ LlamaNARDecoderLayer(
235
+ LlamaConfig(
236
+ hidden_size=hidden_size,
237
+ num_attention_heads=num_heads,
238
+ max_position_embeddings=4096,
239
+ intermediate_size=hidden_size * 4,
240
+ ),
241
+ layer_idx=i,
242
+ )
243
+ for i in range(num_layers)
244
+ ]
245
+ )
246
+
247
+ # Final adaptive layer norm
248
+ self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
249
+
250
+ # Modules for diffusion step conditioning
251
+ self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
252
+ self.diff_step_mlp = nn.Sequential(
253
+ nn.Linear(hidden_size, hidden_size * 4),
254
+ nn.SiLU(),
255
+ nn.Linear(hidden_size * 4, hidden_size),
256
+ )
257
+
258
+ # MLP for processing phoneme embedding condition
259
+ if self.use_phone_cond:
260
+ self.cond_mlp = nn.Sequential(
261
+ nn.Linear(hidden_size, hidden_size * 4),
262
+ nn.SiLU(),
263
+ nn.Linear(hidden_size * 4, hidden_size),
264
+ )
265
+
266
+ # This loop is redundant if layers are initialized correctly, but ensures config consistency.
267
+ for layer in self.layers:
268
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
269
+ hidden_size, dim_cond=hidden_size
270
+ )
271
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
272
+ hidden_size, dim_cond=hidden_size
273
+ )
274
+
275
+ # We handle embeddings manually, so disable the default token embedder
276
+ self.embed_tokens = None
277
+
278
+ self.post_init()
279
+
280
+ def _prepare_decoder_attention_mask(
281
+ self,
282
+ attention_mask: torch.Tensor,
283
+ input_shape: Tuple[int, int],
284
+ inputs_embeds: torch.Tensor,
285
+ past_key_values_length: int,
286
+ ) -> Optional[torch.Tensor]:
287
+ """
288
+ Creates a non-causal (fully-visible) attention mask.
289
+
290
+ This method overrides the default causal mask creation. It converts a 2D padding mask
291
+ `[bsz, seq_len]` into a 4D attention mask `[bsz, 1, tgt_seq_len, src_seq_len]`
292
+ suitable for self-attention, without applying a causal triangle.
293
+
294
+ Args:
295
+ attention_mask (torch.Tensor): The 2D padding mask.
296
+ input_shape (Tuple[int, int]): The shape of the input (`batch_size`, `seq_len`).
297
+ inputs_embeds (torch.Tensor): The input embeddings tensor.
298
+ past_key_values_length (int): The length of any cached key-values.
299
+
300
+ Returns:
301
+ Optional[torch.Tensor]: The 4D attention mask, or None if the input mask is None.
302
+ """
303
+ combined_attention_mask = None
304
+
305
+ def _expand_mask(
306
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
307
+ ) -> torch.Tensor:
308
+ """Expands a 2D attention mask to a 4D attention mask."""
309
+ bsz, src_len = mask.size()
310
+ tgt_len = tgt_len if tgt_len is not None else src_len
311
+ expanded_mask = (
312
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
313
+ )
314
+ # Invert the mask and convert to additive format (-inf for masked positions)
315
+ inverted_mask = 1.0 - expanded_mask
316
+ return inverted_mask.masked_fill(
317
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
318
+ )
319
+
320
+ if attention_mask is not None:
321
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
322
+ expanded_attn_mask = _expand_mask(
323
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
324
+ ).to(inputs_embeds.device)
325
+ combined_attention_mask = (
326
+ expanded_attn_mask
327
+ if combined_attention_mask is None
328
+ else expanded_attn_mask + combined_attention_mask
329
+ )
330
+
331
+ return combined_attention_mask
332
+
333
+ def forward(
334
+ self,
335
+ x: torch.Tensor,
336
+ diffusion_step: torch.Tensor,
337
+ x_mask: torch.Tensor,
338
+ phone_embedding: Optional[torch.Tensor] = None,
339
+ phone_mask: Optional[torch.Tensor] = None,
340
+ input_ids: Optional[torch.LongTensor] = None,
341
+ attention_mask: Optional[torch.Tensor] = None,
342
+ position_ids: Optional[torch.LongTensor] = None,
343
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
344
+ inputs_embeds: Optional[torch.FloatTensor] = None,
345
+ use_cache: Optional[bool] = None,
346
+ output_attentions: Optional[bool] = None,
347
+ output_hidden_states: Optional[bool] = None,
348
+ return_dict: Optional[bool] = None,
349
+ ) -> torch.Tensor:
350
+ """
351
+ Forward pass of the DiffLlamaPrefix model.
352
+
353
+ Args:
354
+ x (torch.Tensor): The primary input tensor, e.g., noisy data (`batch, seq_len, hidden_size`).
355
+ diffusion_step (torch.Tensor): Diffusion timesteps, shape `(batch,)`.
356
+ x_mask (torch.Tensor): The padding mask for `x`, shape `(batch, seq_len)`.
357
+ phone_embedding (Optional[torch.Tensor]): Phoneme embeddings prefix, shape `(batch, phone_len, hidden_size)`.
358
+ phone_mask (Optional[torch.Tensor]): The padding mask for `phone_embedding`, shape `(batch, phone_len)`.
359
+ input_ids, etc.: Standard Hugging Face arguments, mostly for compatibility.
360
+
361
+ Returns:
362
+ torch.Tensor: The final output tensor of shape `(batch, seq_len, hidden_size)`.
363
+ """
364
+ # 1. Prepend conditional prefix (phoneme embeddings)
365
+ if self.use_phone_cond and phone_embedding is not None:
366
+ # Process condition through an MLP
367
+ phone_embedding = self.cond_mlp(phone_embedding) # (B, T_phone, C)
368
+ phone_length = phone_embedding.shape[1]
369
+ # Concatenate prefix and main input
370
+ inputs_embeds = torch.cat([phone_embedding, x], dim=1)
371
+ attention_mask = torch.cat([phone_mask, x_mask], dim=1)
372
+ else:
373
+ inputs_embeds = x
374
+ attention_mask = x_mask
375
+ phone_length = 0
376
+
377
+ # 2. Process diffusion step embedding for adaptive normalization
378
+ diffusion_step_emb = self.diff_step_embedding(diffusion_step).to(x.device)
379
+ diffusion_step_emb = self.diff_step_mlp(diffusion_step_emb) # (B, C)
380
+
381
+ # 3. Standard Transformer Preamble (adapted from LlamaModel)
382
+ batch_size, seq_length, _ = inputs_embeds.shape
383
+
384
+ output_attentions = (
385
+ output_attentions
386
+ if output_attentions is not None
387
+ else self.config.output_attentions
388
+ )
389
+ output_hidden_states = (
390
+ output_hidden_states
391
+ if output_hidden_states is not None
392
+ else self.config.output_hidden_states
393
+ )
394
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
395
+ return_dict = (
396
+ return_dict if return_dict is not None else self.config.use_return_dict
397
+ )
398
+
399
+ past_key_values_length = 0
400
+ if past_key_values is not None:
401
+ past_key_values_length = past_key_values[0][0].shape[2]
402
+
403
+ if position_ids is None:
404
+ device = inputs_embeds.device
405
+ position_ids = torch.arange(
406
+ past_key_values_length,
407
+ seq_length + past_key_values_length,
408
+ dtype=torch.long,
409
+ device=device,
410
+ )
411
+ position_ids = position_ids.unsqueeze(0)
412
+
413
+ # Create the non-causal attention mask
414
+ attention_mask = self._prepare_decoder_attention_mask(
415
+ attention_mask,
416
+ (batch_size, seq_length),
417
+ inputs_embeds,
418
+ past_key_values_length,
419
+ )
420
+
421
+ hidden_states = inputs_embeds
422
+
423
+ # 4. Transformer Decoder Layers
424
+ all_hidden_states = () if output_hidden_states else None
425
+ all_self_attns = () if output_attentions else None
426
+
427
+ for idx, decoder_layer in enumerate(self.layers):
428
+ if output_hidden_states:
429
+ all_hidden_states += (hidden_states,)
430
+
431
+ past_key_value = (
432
+ past_key_values[idx] if past_key_values is not None else None
433
+ )
434
+
435
+ # Pass the processed diffusion step embedding to the adaptive layer
436
+ layer_outputs = decoder_layer(
437
+ hidden_states,
438
+ cond_embedding=diffusion_step_emb,
439
+ attention_mask=attention_mask,
440
+ position_ids=position_ids,
441
+ past_key_value=past_key_value,
442
+ output_attentions=output_attentions,
443
+ use_cache=use_cache,
444
+ )
445
+ hidden_states = layer_outputs[0]
446
+
447
+ if output_attentions:
448
+ all_self_attns += (layer_outputs[1],)
449
+
450
+ # 5. Final Normalization and Output Processing
451
+ hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step_emb)
452
+
453
+ if output_hidden_states:
454
+ all_hidden_states += (hidden_states,)
455
+
456
+ # Remove the conditional prefix from the final output sequence
457
+ return hidden_states[:, phone_length:]
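Two interface details matter when reusing this module: the diffusion timestep is embedded with a plain sinusoidal table before the conditioning MLP, and forward() slices off the phoneme prefix so the output length always matches x. A small shape check for the timestep embedding (illustrative values only):

import torch
from models.tts.llm_tts.llama_nar_prefix import SinusoidalPosEmb

emb = SinusoidalPosEmb(dim=1024)
t = torch.tensor([0.1, 0.5, 0.9])  # diffusion timesteps in [0, 1], one per batch item
print(emb(t).shape)                # torch.Size([3, 1024])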
models/tts/llm_tts/mgm.py ADDED
@@ -0,0 +1,385 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import math
5
+ from einops import rearrange
6
+ from models.tts.llm_tts.llama_nar_prefix import DiffLlamaPrefix
7
+
8
+
9
+ def top_k(logits, thres=0.9):
10
+ k = math.ceil((1 - thres) * logits.shape[-1])
11
+ val, ind = logits.topk(k, dim=-1)
12
+ probs = torch.full_like(logits, float("-inf"))
13
+ probs.scatter_(2, ind, val)
14
+ return probs
15
+
16
+
17
+ def log(t, eps=1e-10):
18
+ return torch.log(t + eps)
19
+
20
+
21
+ def gumbel_noise(t):
22
+ noise = torch.zeros_like(t).uniform_(0, 1)
23
+ return -log(-log(noise))
24
+
25
+
26
+ def gumbel_sample(t, temperature=1.0, dim=-1):
27
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
28
+
29
+
30
+ class MGMT2S(nn.Module):
31
+ def __init__(
32
+ self,
33
+ hidden_size=1024,
34
+ num_layers=16,
35
+ num_heads=16,
36
+ cfg_scale=0.2,
37
+ cond_codebook_size=8192,
38
+ cond_dim=1024,
39
+ use_phone_cond=True,
40
+ phone_vocab_size=32100,
41
+ cfg=None,
42
+ ):
43
+ super().__init__()
44
+
45
+ hidden_size = (
46
+ cfg.hidden_size
47
+ if cfg is not None and hasattr(cfg, "hidden_size")
48
+ else hidden_size
49
+ )
50
+ num_layers = (
51
+ cfg.num_layers
52
+ if cfg is not None and hasattr(cfg, "num_layers")
53
+ else num_layers
54
+ )
55
+ num_heads = (
56
+ cfg.num_heads
57
+ if cfg is not None and hasattr(cfg, "num_heads")
58
+ else num_heads
59
+ )
60
+ cfg_scale = (
61
+ cfg.cfg_scale
62
+ if cfg is not None and hasattr(cfg, "cfg_scale")
63
+ else cfg_scale
64
+ )
65
+ cond_codebook_size = (
66
+ cfg.cond_codebook_size
67
+ if cfg is not None and hasattr(cfg, "cond_codebook_size")
68
+ else cond_codebook_size
69
+ )
70
+ cond_dim = (
71
+ cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
72
+ )
73
+ use_phone_cond = (
74
+ cfg.use_phone_cond
75
+ if cfg is not None and hasattr(cfg, "use_phone_cond")
76
+ else use_phone_cond
77
+ )
78
+ phone_vocab_size = (
79
+ cfg.phone_vocab_size
80
+ if cfg is not None and hasattr(cfg, "phone_vocab_size")
81
+ else phone_vocab_size
82
+ )
83
+
84
+ self.hidden_size = hidden_size
85
+ self.num_layers = num_layers
86
+ self.num_heads = num_heads
87
+ self.cfg_scale = cfg_scale
88
+ self.cond_codebook_size = cond_codebook_size
89
+ self.cond_dim = cond_dim
90
+ self.use_phone_cond = use_phone_cond
91
+ self.phone_vocab_size = phone_vocab_size
92
+
93
+ self.mask_emb = nn.Embedding(1, self.hidden_size)
94
+
95
+ self.to_logit = nn.Linear(self.hidden_size, self.cond_codebook_size)
96
+
97
+ self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
98
+
99
+ if self.use_phone_cond:
100
+ self.phone_emb = nn.Embedding(self.phone_vocab_size, hidden_size)
101
+ torch.nn.init.normal_(self.phone_emb.weight, mean=0.0, std=0.02)
102
+
103
+ self.reset_parameters()
104
+
105
+ self.diff_estimator = DiffLlamaPrefix(
106
+ hidden_size=hidden_size,
107
+ num_heads=num_heads,
108
+ num_layers=num_layers,
109
+ use_phone_cond=use_phone_cond,
110
+ )
111
+
112
+ def mask_prob(self, t):
113
+ return torch.sin(t * np.pi / 2).to(t.device)
114
+
115
+ def forward_diffusion(self, x0, t):
116
+ # x0: semantic tokens (B, T)
117
+ new_t = t
118
+ mask_prob = self.mask_prob(new_t) # (B,)
119
+ # if mask_prob[i] < 0.2, mask_prob[i] = 0.2
120
+ mask_prob = torch.where(
121
+ mask_prob < 0.2, torch.ones_like(mask_prob) * 0.2, mask_prob
122
+ )
123
+ mask_token = self.mask_emb(
124
+ torch.LongTensor([0]).to(x0.device)
125
+ ) # (1, hidden_size)
126
+
127
+ xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)
128
+
129
+ cfg_scale = self.cfg_scale
130
+
131
+ # a segment of r% sequence length is masked, where r ~ U[60, 100]
132
+ if torch.rand(1) > cfg_scale:
133
+ prompt_len = torch.randint(
134
+ min(x0.shape[1] // 4, 5), int(x0.shape[1] * 0.4), (x0.shape[0],)
135
+ ).to(
136
+ x0.device
137
+ ) # (B,)
138
+ else:
139
+ prompt_len = torch.zeros(x0.shape[0]).to(x0) # (B,)
140
+
141
+ # get is prompt
142
+ is_prompt = torch.zeros_like(x0[:, :]) # (B, T)
143
+ col_indices = (
144
+ torch.arange(is_prompt.shape[1])
145
+ .repeat(is_prompt.shape[0], 1)
146
+ .to(prompt_len)
147
+ ) # (B, T)
148
+ is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt
149
+
150
+ # Add mask
151
+ mask = torch.bernoulli(torch.ones_like(x0[:, :]) * mask_prob[..., None])
152
+ mask[is_prompt.bool()] = 0
153
+ mask_num = mask[:,].sum(dim=1, keepdim=False)
154
+ all_zero_mask = (mask_num == 0).bool()
155
+ row_indices_to_modify = torch.nonzero(all_zero_mask)
156
+ mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1
157
+ mask = mask[..., None] # (B, T, 1)
158
+ xt = (
159
+ xt + mask * mask_token[:, None, :] + (1 - mask) * self.cond_emb(x0[:, :])
160
+ ) # (B, T, hidden_size)
161
+
162
+ return xt, new_t, mask, prompt_len, mask_prob
163
+
164
+ def loss_t(self, x0, x_mask, t, phone_embedding=None, phone_mask=None):
165
+ xt, new_t, mask, prompt_len, mask_prob = self.forward_diffusion(x0, t)
166
+ # xt: (B, T, hidden_size)
167
+ # new_t: (B,)
168
+ # mask: (B, T, 1) mask if 1, not mask if 0
169
+ # prompt_len: (B,)
170
+ # mask_prob: (B,)
171
+
172
+ # # drop all condition for cfg, so if prompt_len is 0, we also drop phone_embedding
173
+ # if self.use_phone_cond and phone_embedding != None:
174
+ # phone_embedding = phone_embedding * torch.where(prompt_len > 0, torch.ones_like(prompt_len), torch.zeros_like(prompt_len)).to(x0.device).unsqueeze(-1).unsqueeze(-1)
175
+
176
+ embeds = self.diff_estimator(
177
+ xt, new_t, x_mask, phone_embedding=phone_embedding, phone_mask=phone_mask
178
+ ) # (B, T, hidden_size)
179
+ logits = self.to_logit(embeds) # (B, T, codebook_size)
180
+
181
+ # final mask used for loss calculation
182
+ final_mask = mask * x_mask[..., None] # (B, T, 1)
183
+
184
+ return logits, final_mask, x0, prompt_len, mask_prob
185
+
186
+ def compute_loss(self, x0, x_mask, phone_embedding=None, phone_mask=None):
187
+ # x0: (B, T)
188
+ # x_mask: (B, T) mask is 0 for padding
189
+ t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
190
+ t = torch.clamp(t, 1e-5, 1.0)
191
+ return self.loss_t(x0, x_mask, t, phone_embedding, phone_mask)
192
+
193
+ def reset_parameters(self):
194
+ def _reset_parameters(m):
195
+ if isinstance(m, nn.MultiheadAttention):
196
+ if m._qkv_same_embed_dim:
197
+ nn.init.normal_(m.in_proj_weight, std=0.02)
198
+ else:
199
+ nn.init.normal_(m.q_proj_weight, std=0.02)
200
+ nn.init.normal_(m.k_proj_weight, std=0.02)
201
+ nn.init.normal_(m.v_proj_weight, std=0.02)
202
+
203
+ if m.in_proj_bias is not None:
204
+ nn.init.constant_(m.in_proj_bias, 0.0)
205
+ nn.init.constant_(m.out_proj.bias, 0.0)
206
+ if m.bias_k is not None:
207
+ nn.init.xavier_normal_(m.bias_k)
208
+ if m.bias_v is not None:
209
+ nn.init.xavier_normal_(m.bias_v)
210
+
211
+ elif (
212
+ isinstance(m, nn.Conv1d)
213
+ or isinstance(m, nn.ConvTranspose1d)
214
+ or isinstance(m, nn.Conv2d)
215
+ or isinstance(m, nn.ConvTranspose2d)
216
+ ):
217
+ m.weight.data.normal_(0.0, 0.02)
218
+
219
+ elif isinstance(m, nn.Linear):
220
+ m.weight.data.normal_(mean=0.0, std=0.02)
221
+ if m.bias is not None:
222
+ m.bias.data.zero_()
223
+
224
+ elif isinstance(m, nn.Embedding):
225
+ m.weight.data.normal_(mean=0.0, std=0.02)
226
+ if m.padding_idx is not None:
227
+ m.weight.data[m.padding_idx].zero_()
228
+
229
+ self.apply(_reset_parameters)
230
+
231
+ @torch.no_grad()
232
+ def reverse_diffusion(
233
+ self,
234
+ prompt,
235
+ target_len,
236
+ phone_id,
237
+ prompt_mask=None,
238
+ temp=0.9,
239
+ filter_thres=0.98,
240
+ n_timesteps=25,
241
+ cfg=1.0,
242
+ rescale_cfg=1.0,
243
+ ):
244
+ # prompt: (B, T)
245
+ if self.use_phone_cond and phone_id != None:
246
+ phone_embedding = self.phone_emb(phone_id)
247
+ else:
248
+ phone_embedding = None
249
+
250
+ prompt_code = prompt # (B, prompt_len)
251
+ prompt_len = prompt_code.shape[1]
252
+
253
+ x_mask = torch.ones(prompt_code.shape[0], target_len).to(
254
+ prompt_code.device
255
+ ) # (B, target_len)
256
+ phone_mask = torch.ones_like(phone_id)
257
+
258
+ if prompt_mask == None:
259
+ prompt_mask = torch.ones(prompt_code.shape[0], prompt_len).to(
260
+ prompt_code.device
261
+ ) # (B, prompt_len)
262
+
263
+ cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
264
+ x_mask.device
265
+ ) # (B, T, hidden_size)
266
+
267
+ bsz, seq_len, _ = cum.shape
268
+
269
+ choice_temp = 1.0
270
+ start_temp = temp # temperature for sampling
271
+ start_choice_temp = choice_temp # temperature for choicing mask tokens
272
+
273
+ xt = torch.LongTensor(bsz, seq_len).to(x_mask.device)
274
+
275
+ steps = n_timesteps
276
+ to_logit = self.to_logit
277
+ cond_emb = self.cond_emb
278
+
279
+ mask_token = self.mask_emb(torch.LongTensor([0]).to(xt.device))
280
+ mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device) # (B, T, 1)
281
+ seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
282
+ h = 1.0 / steps
283
+
284
+ cur_prompt = 0
285
+ cur_prompt = cur_prompt + cond_emb(prompt_code)
286
+
287
+ t_list = [1.0 - i * h for i in range(steps)]
288
+ t_list.append(0.0)
289
+ for i in range(steps):
290
+ t = t_list[i] * torch.ones(bsz).to(x_mask.device)
291
+ token = cond_emb(seq) # (B, T, hidden_size)
292
+ cur = cum + mask * mask_token[:, None, :] + (~mask) * token
293
+
294
+ xt_input = torch.cat([cur_prompt, cur], dim=1) # (B, T, hidden_size)
295
+ xt_mask = torch.cat(
296
+ [prompt_mask, x_mask], dim=1
297
+ ) # (B, T), mask is 0 for padding
298
+
299
+ embeds = self.diff_estimator(
300
+ xt_input,
301
+ t,
302
+ xt_mask,
303
+ phone_embedding=phone_embedding,
304
+ phone_mask=phone_mask,
305
+ )
306
+ embeds = embeds[:, prompt_len:, :]
307
+
308
+ # classifier free guidance
309
+ # phone_embedding=phone_embedding[:,phone_embedding.shape[1]:,:] means phone_embedding is None
310
+ if cfg > 0:
311
+ mask_embeds = self.diff_estimator(
312
+ cur,
313
+ t,
314
+ x_mask,
315
+ phone_embedding=phone_embedding[:, phone_embedding.shape[1] :, :],
316
+ phone_mask=phone_mask[:, prompt_len:],
317
+ )
318
+ pos_emb_std = embeds.std() # std(g_cond)
319
+ embeds = embeds + cfg * (embeds - mask_embeds) # g_cfg
320
+ rescale_embeds = embeds * pos_emb_std / embeds.std() # g_final
321
+ embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds
322
+
323
+ logits = to_logit(embeds) # (B, T, codebook_size)
324
+ annealing_scale = t_list[i]
325
+
326
+ choice_temp = start_choice_temp * annealing_scale
327
+ temp = start_temp * annealing_scale
328
+ logits = top_k(logits, filter_thres)
329
+
330
+ if i == steps - 1:
331
+ # greedy
332
+ if steps == 1:
333
+ temp = 0.2
334
+ sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
335
+ else:
336
+ sampled_ids = logits.argmax(dim=-1)
337
+
338
+ else:
339
+ # sampling
340
+ sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
341
+
342
+ seq = torch.where(mask.squeeze(-1), sampled_ids, seq)
343
+
344
+ scores = logits.softmax(dim=-1)
345
+ scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
346
+ scores = rearrange(scores, "b n 1 -> b n")
347
+
348
+ scores = choice_temp * gumbel_noise(scores) + scores
349
+ scores = 1 - scores
350
+
351
+ next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)
352
+
353
+ next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()
354
+
355
+ if next_mask_num == 0:
356
+ break
357
+ scores = scores.masked_fill(
358
+ ~mask.squeeze(-1), -torch.finfo(scores.dtype).max
359
+ )
360
+
361
+ mask_indices = scores.topk(next_mask_num, dim=-1).indices
362
+ mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
363
+ 1, mask_indices, True
364
+ )
365
+ seq = seq.masked_fill(mask, 0)
366
+
367
+ mask = mask.unsqueeze(-1)
368
+
369
+ cum = cum + cond_emb(seq)
370
+ xt = seq
371
+
372
+ return xt
373
+
374
+ def forward(self, x0, x_mask, phone_id=None, phone_mask=None):
375
+ # x0: (B, T)
376
+ # x_mask: (B, T) mask is 0 for padding
377
+ if self.use_phone_cond and phone_id != None:
378
+ phone_embedding = self.phone_emb(phone_id)
379
+ else:
380
+ phone_embedding = None
381
+
382
+ logits, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
383
+ x0, x_mask, phone_embedding, phone_mask=phone_mask
384
+ )
385
+ return logits, final_mask, x0, prompt_len, mask_prob
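Both training (forward_diffusion) and inference (reverse_diffusion) share the masking schedule mask_prob(t) = sin(pi * t / 2), with training additionally clamping it to at least 0.2. A quick numeric check with illustrative timesteps:

import torch
import numpy as np

t = torch.tensor([0.1, 0.5, 1.0])
print(torch.sin(t * np.pi / 2))  # tensor([0.1564, 0.7071, 1.0000]): fraction of positions masked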
models/tts/tadicodec/__pycache__/infer_utils.cpython-310.pyc ADDED
Binary file (848 Bytes). View file
 
models/tts/tadicodec/__pycache__/inference_tadicodec.cpython-310.pyc ADDED
Binary file (7.14 kB). View file
 
models/tts/tadicodec/__pycache__/llama_nar_prefix.cpython-310.pyc ADDED
Binary file (15.4 kB). View file
 
models/tts/tadicodec/__pycache__/modeling_tadicodec.cpython-310.pyc ADDED
Binary file (15 kB). View file
 
models/tts/tadicodec/infer_utils.py ADDED
@@ -0,0 +1,24 @@
+ from models.codec.melvqgan.melspec import MelSpectrogram
+ from models.codec.amphion_codec.vocos import Vocos
+
+
+ def build_vocoder_model(cfg, device):
+     vocoder_model = Vocos(cfg=cfg.model.vocos)
+     vocoder_model.eval()
+     vocoder_model.to(device)
+     return vocoder_model
+
+
+ def build_mel_model(cfg, device):
+     mel_model = MelSpectrogram(
+         sampling_rate=cfg.preprocess.sample_rate,
+         n_fft=cfg.preprocess.n_fft,
+         num_mels=cfg.preprocess.num_mels,
+         hop_size=cfg.preprocess.hop_size,
+         win_size=cfg.preprocess.win_size,
+         fmin=cfg.preprocess.fmin,
+         fmax=cfg.preprocess.fmax,
+     )
+     mel_model.eval()
+     mel_model.to(device)
+     return mel_model
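Both builders take the checkpoint's parsed config.json, whose preprocess and model.vocos fields mirror the keyword arguments above. A hedged sketch assuming the ./ckpt/TaDiCodec layout described in inference_tadicodec.py; note that build_vocoder_model only constructs the architecture, with the vocoder weights loaded separately by the pipeline:

import torch
from utils.util import load_config
from models.tts.tadicodec.infer_utils import build_mel_model, build_vocoder_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = load_config("./ckpt/TaDiCodec/config.json", lowercase=False)

mel_model = build_mel_model(cfg, device)          # waveform -> mel features
vocoder_model = build_vocoder_model(cfg, device)  # mel features -> waveform (weights loaded later)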
models/tts/tadicodec/inference_tadicodec.py ADDED
@@ -0,0 +1,279 @@
1
+ import os
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import librosa
8
+ import safetensors
9
+ import accelerate
10
+ from huggingface_hub import snapshot_download
11
+
12
+ from transformers import AutoTokenizer
13
+
14
+ from models.tts.tadicodec.infer_utils import build_vocoder_model, build_mel_model
15
+ from models.tts.tadicodec.modeling_tadicodec import TaDiCodec
16
+
17
+
18
+ class TaDiCodecPipline(nn.Module):
19
+ def __init__(
20
+ self,
21
+ cfg,
22
+ model_path: str,
23
+ device: torch.device,
24
+ tokenizer_path: Optional[str] = None,
25
+ vocoder_ckpt_path: Optional[str] = None,
26
+ ):
27
+ super().__init__()
28
+ self.cfg = cfg
29
+ self.model_path = model_path
30
+ self.device = device
31
+
32
+ # tokenizer
33
+ self.tokenizer = (
34
+ AutoTokenizer.from_pretrained(tokenizer_path)
35
+ if tokenizer_path is not None
36
+ else os.path.join(model_path, "text_tokenizer")
37
+ )
38
+
39
+ # mel feature extractor
40
+ self.mel_model = build_mel_model(cfg, device)
41
+
42
+ # main model
43
+ tadiconfig = cfg.model.tadicodec
44
+
45
+ self.tadicodec = TaDiCodec(cfg=tadiconfig)
46
+ safetensors.torch.load_model(self.tadicodec, model_path, strict=False)
47
+ self.tadicodec.to(torch.float32)
48
+ self.tadicodec.to(device)
49
+ self.tadicodec.eval()
50
+
51
+ # vocoder
52
+ self.vocoder_model = build_vocoder_model(cfg, device)
53
+ v_path = (
54
+ vocoder_ckpt_path
55
+ if vocoder_ckpt_path
56
+ else os.path.join(model_path, "vocoder")
57
+ )
58
+ accelerate.load_checkpoint_and_dispatch(self.vocoder_model, v_path)
59
+
60
+ @classmethod
61
+ def from_pretrained(
62
+ cls,
63
+ ckpt_dir: str = "./ckpt/TaDiCodec",
64
+ device: Optional[torch.device] = None,
65
+ auto_download: bool = True,
66
+ ):
67
+ """Create a pipeline from a checkpoint directory or Hugging Face model ID.
68
+
69
+ Expected structure under `ckpt_dir`:
70
+ - config.json # model and preprocess config
71
+ - model.safetensors # TaDiCodec weights
72
+ - vocoder/ # directory containing vocoder weights
73
+ model.safetensors or other *.safetensors
74
+ - text_tokenizer/ # directory containing text tokenizer
75
+ tokenizer.json
76
+ tokenizer_config.json
77
+
78
+ Args:
79
+ ckpt_dir: Directory containing `config.json`, `model.safetensors`, and `vocoder/`, or Hugging Face model ID.
80
+ device: Device to place models on. Defaults to CUDA if available else CPU.
81
+ auto_download: Whether to automatically download models from Hugging Face if not found locally
82
+
83
+ Returns:
84
+ TaDiCodecPipline
85
+ """
86
+ import os
87
+ import glob
88
+ from utils.util import load_config
89
+
90
+ resolved_device = (
91
+ device
92
+ if device is not None
93
+ else torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+ )
95
+
96
+ # Resolve checkpoint directory
97
+ resolved_ckpt_dir = cls._resolve_model_path(
98
+ ckpt_dir, auto_download=auto_download, model_type="tadicodec"
99
+ )
100
+
101
+ # Load config
102
+ config_path = os.path.join(resolved_ckpt_dir, "config.json")
103
+ if not os.path.exists(config_path):
104
+ raise FileNotFoundError(f"Config not found: {config_path}")
105
+ cfg = load_config(config_path, lowercase=False)
106
+
107
+ # Resolve main model weights
108
+ model_weights_path = os.path.join(resolved_ckpt_dir, "model.safetensors")
109
+
110
+ # Resolve vocoder weights
111
+ vocoder_ckpt_path = os.path.join(resolved_ckpt_dir, "vocoder")
112
+
113
+ text_tokenizer_dir = os.path.join(resolved_ckpt_dir, "text_tokenizer")
114
+
115
+ return cls(
116
+ cfg=cfg,
117
+ model_path=model_weights_path,
118
+ device=resolved_device,
119
+ vocoder_ckpt_path=vocoder_ckpt_path,
120
+ tokenizer_path=text_tokenizer_dir,
121
+ )
122
+
123
+ @staticmethod
124
+ def _resolve_model_path(
125
+ model_path: str, auto_download: bool = True, model_type: str = "tadicodec"
126
+ ) -> str:
127
+ """
128
+ Resolve model path, downloading from Hugging Face if necessary
129
+
130
+ Args:
131
+ model_path: Local path or Hugging Face model ID
132
+ auto_download: Whether to auto-download from HF
133
+ model_type: Type of model ("tadicodec")
134
+
135
+ Returns:
136
+ Resolved local path
137
+ """
138
+ # If it's already a local path and exists, return as is
139
+ if os.path.exists(model_path):
140
+ return model_path
141
+
142
+ # If it looks like a Hugging Face model ID (contains '/')
143
+ if "/" in model_path and auto_download:
144
+ print(f"Downloading {model_type} model from Hugging Face: {model_path}")
145
+ try:
146
+ # Download to cache directory
147
+ cache_dir = os.path.join(
148
+ os.path.expanduser("~"), ".cache", "huggingface", "hub"
149
+ )
150
+ downloaded_path = snapshot_download(
151
+ repo_id=model_path,
152
+ cache_dir=cache_dir,
153
+ local_dir_use_symlinks=False,
154
+ )
155
+ print(
156
+ f"Successfully downloaded {model_type} model to: {downloaded_path}"
157
+ )
158
+ return downloaded_path
159
+ except Exception as e:
160
+ print(f"Failed to download {model_type} model from Hugging Face: {e}")
161
+ raise ValueError(
162
+ f"Could not download {model_type} model from {model_path}"
163
+ )
164
+
165
+ # If it's a local path that doesn't exist
166
+ if not os.path.exists(model_path):
167
+ if auto_download:
168
+ raise ValueError(
169
+ f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
170
+ )
171
+ else:
172
+ raise FileNotFoundError(f"Model path does not exist: {model_path}")
173
+
174
+ return model_path
175
+
176
+ @torch.no_grad()
177
+ def __call__(
178
+ self,
179
+ text: Optional[str] = None,
180
+ speech_path: Optional[str] = None,
181
+ prompt_text: Optional[str] = None,
182
+ prompt_speech_path: Optional[str] = None,
183
+ n_timesteps: int = 32,
184
+ return_code: bool = False,
185
+ cfg_scale: float = 2.0,
186
+ ):
187
+ # tokenize text
188
+ text_input_ids = self.tokenize_text(text, prompt_text)
189
+
190
+ # extract mel features
191
+ target_mel = self.extract_mel_feature(speech_path)
192
+ prompt_mel = (
193
+ self.extract_mel_feature(prompt_speech_path) if prompt_speech_path else None
194
+ )
195
+
196
+ # encode to codes from mel
197
+ if prompt_mel is not None:
198
+ vq_emb, indices = self.encode(torch.cat([prompt_mel, target_mel], dim=1))
199
+ else:
200
+ vq_emb, indices = self.encode(target_mel)
201
+
202
+ if return_code:
203
+ return indices
204
+
205
+ # decode mel from codes + optional text/prompt
206
+ rec_mel = self.decode(
207
+ vq_emb=vq_emb,
208
+ text_token_ids=text_input_ids,
209
+ prompt_mel=(
210
+ prompt_mel if prompt_mel is not None else target_mel[:, : 50 * 3]
211
+ ),
212
+ n_timesteps=n_timesteps,
213
+ cfg=cfg_scale,
214
+ rescale_cfg=0.75,
215
+ )
216
+
217
+ # vocoder
218
+ rec_audio = (
219
+ self.vocoder_model(rec_mel.transpose(1, 2)).detach().cpu().numpy()[0][0]
220
+ )
221
+ return rec_audio
222
+
223
+ def tokenize_text(
224
+ self, text: Optional[str] = None, prompt_text: Optional[str] = None
225
+ ):
226
+ if self.tokenizer is None or text is None:
227
+ return None
228
+ if prompt_text is not None:
229
+ text_token_ids = self.tokenizer(
230
+ prompt_text + text, return_tensors="pt", add_special_tokens=False
231
+ ).input_ids.to(self.device)
232
+ else:
233
+ text_token_ids = self.tokenizer(
234
+ text, return_tensors="pt", add_special_tokens=False
235
+ ).input_ids.to(self.device)
236
+ return text_token_ids
237
+
238
+ @torch.no_grad()
239
+ def extract_mel_feature(self, speech_path: Optional[str]):
240
+ assert speech_path is not None and os.path.exists(
241
+ speech_path
242
+ ), f"Invalid speech_path: {speech_path}"
243
+ speech = librosa.load(speech_path, sr=24000)[0]
244
+ speech = torch.tensor(speech).to(self.device).unsqueeze(0)
245
+ mel_feature = self.mel_model(speech) # (B, n_mels, T)
246
+ mel_feature = mel_feature.transpose(1, 2) # (B, T, n_mels)
247
+ mel_feature = (mel_feature - self.cfg.preprocess.mel_mean) / math.sqrt(
248
+ self.cfg.preprocess.mel_var
249
+ )
250
+ return mel_feature
251
+
252
+ @torch.no_grad()
253
+ def encode(self, mel_feat: torch.Tensor):
254
+ vq_emb, indices = self.tadicodec.encode(
255
+ mel_feat, torch.ones(mel_feat.shape[0], mel_feat.shape[1]).to(self.device)
256
+ )
257
+ return vq_emb, indices
258
+
259
+ @torch.no_grad()
260
+ def decode(
261
+ self,
262
+ vq_emb: Optional[torch.Tensor] = None,
263
+ indices: Optional[torch.Tensor] = None,
264
+ text_token_ids: Optional[torch.Tensor] = None,
265
+ prompt_mel: Optional[torch.Tensor] = None,
266
+ n_timesteps: int = 32,
267
+ cfg: float = 1.0,
268
+ rescale_cfg: float = 0.75,
269
+ ):
270
+ rec_mel = self.tadicodec.reverse_diffusion(
271
+ vq_emb=vq_emb,
272
+ indices=indices,
273
+ text_ids=text_token_ids,
274
+ prompt_mel=prompt_mel,
275
+ n_timesteps=n_timesteps,
276
+ cfg=cfg,
277
+ rescale_cfg=rescale_cfg,
278
+ )
279
+ return rec_mel
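A hedged end-to-end sketch of how this pipeline might be called, based only on the from_pretrained and __call__ signatures above; the checkpoint directory, transcripts, and wav paths are placeholders.

import soundfile as sf
from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline

# Assumes ./ckpt/TaDiCodec follows the layout documented in from_pretrained
# (config.json, model.safetensors, vocoder/, text_tokenizer/); a Hugging Face
# repo id should also work when auto_download=True.
pipe = TaDiCodecPipline.from_pretrained(ckpt_dir="./ckpt/TaDiCodec")

# Discrete codes only
codes = pipe(
    text="Target transcript.",
    speech_path="target.wav",
    prompt_text="Prompt transcript.",
    prompt_speech_path="prompt.wav",
    return_code=True,
)

# Full round trip: encode to codes, diffusion-decode to mel, then vocode to audio
audio = pipe(
    text="Target transcript.",
    speech_path="target.wav",
    prompt_text="Prompt transcript.",
    prompt_speech_path="prompt.wav",
    n_timesteps=32,
    cfg_scale=2.0,
)
sf.write("reconstructed.wav", audio, 24000)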
models/tts/tadicodec/llama_nar_prefix.py ADDED
@@ -0,0 +1,572 @@
1
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import os
6
+ import torch.nn as nn
7
+ from typing import List, Optional, Tuple, Union
8
+ import math
9
+
10
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
11
+ from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
12
+
13
+
14
+ class SinusoidalPosEmb(nn.Module):
15
+ """
16
+ Sinusoidal Positional Embedding module.
17
+
18
+ This module generates sinusoidal positional embeddings for a given 1D input tensor,
19
+ which is commonly used for representing timesteps in diffusion models.
20
+ """
21
+
22
+ def __init__(self, dim: int):
23
+ """
24
+ Initializes the SinusoidalPosEmb module.
25
+
26
+ Args:
27
+ dim (int): The dimension of the embedding.
28
+ """
29
+ super().__init__()
30
+ self.dim = dim
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+ """
34
+ Generates the positional embedding.
35
+
36
+ Args:
37
+ x (torch.Tensor): A 1D tensor of positions (e.g., timesteps), shape `(batch_size,)`.
38
+
39
+ Returns:
40
+ torch.Tensor: The positional embeddings, shape `(batch_size, dim)`.
41
+ """
42
+ device = x.device
43
+ half_dim = self.dim // 2
44
+ # Calculate the embedding frequencies
45
+ emb = math.log(10000) / (half_dim - 1)
46
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
47
+ # Create the embedding matrix by multiplying positions with frequencies
48
+ emb = x[:, None] * emb[None, :] * 1.0
49
+ # Concatenate sine and cosine components to form the final embedding
50
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
51
+ return emb
52
+
53
+
54
+ class LlamaAdaptiveRMSNorm(nn.Module):
55
+ """
56
+ Adaptive Root Mean Square Layer Normalization.
57
+
58
+ This is a variant of RMSNorm where the scaling factor (weight) is adaptively
59
+ predicted from a conditional embedding, allowing for conditional modulation
60
+ of the normalized hidden states.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ hidden_size: int = 1024,
66
+ eps: float = 1e-6,
67
+ dim_cond: int = 1024,
68
+ use_cond: bool = True,
69
+ ):
70
+ """
71
+ Initializes the LlamaAdaptiveRMSNorm module.
72
+
73
+ Args:
74
+ hidden_size (int): The dimension of the hidden states to be normalized.
75
+ eps (float): A small value added to the variance for numerical stability.
76
+ dim_cond (int): The dimension of the conditional embedding.
77
+ use_cond (bool): If True, use conditional embedding to modulate the output.
78
+ """
79
+ super().__init__()
80
+ self.use_cond = use_cond
81
+ if self.use_cond:
82
+ # Linear layer to project the conditional embedding to the hidden size
83
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
84
+ # Initialize weights to zero and bias to one for an identity transformation at the start
85
+ nn.init.zeros_(self.to_weight.weight)
86
+ nn.init.ones_(self.to_weight.bias)
87
+ self.variance_epsilon = eps
88
+ self._is_hf_initialized = True # Disable automatic Hugging Face initialization
89
+
90
+ def forward(
91
+ self, hidden_states: torch.Tensor, cond_embedding: Optional[torch.Tensor] = None
92
+ ) -> torch.Tensor:
93
+ """
94
+ Applies the adaptive RMS normalization.
95
+
96
+ Args:
97
+ hidden_states (torch.Tensor): The input tensor, shape `(batch, seq_len, hidden_size)`.
98
+ cond_embedding (Optional[torch.Tensor]): The conditional embedding, shape `(batch, dim_cond)` or `(batch, seq_len, dim_cond)`.
99
+
100
+ Returns:
101
+ torch.Tensor: The normalized and modulated hidden states.
102
+ """
103
+ input_dtype = hidden_states.dtype
104
+ # Calculate variance and normalize the hidden states
105
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
106
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
107
+
108
+ # Apply conditional modulation if enabled
109
+ if self.use_cond:
110
+ # Project conditional embedding to get the scaling weight
111
+ weight = self.to_weight(cond_embedding)
112
+ # Unsqueeze if the conditional embedding is per-batch instead of per-token
113
+ if len(weight.shape) == 2:
114
+ weight = weight.unsqueeze(1)
115
+
116
+ # Apply the learned scaling factor
117
+ return (weight * hidden_states).to(input_dtype)
118
+ else:
119
+ return hidden_states
120
+
121
+
122
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
123
+ """
124
+ A Non-Autoregressive (NAR) Llama Decoder Layer using adaptive layer normalization.
125
+
126
+ This class overrides the standard LlamaDecoderLayer to replace its RMSNorm
127
+ modules with LlamaAdaptiveRMSNorm, allowing it to be conditioned on an external embedding
128
+ (e.g., from diffusion timesteps).
129
+ """
130
+
131
+ def __init__(self, config: LlamaConfig, layer_idx: int, use_cond: bool = True):
132
+ """
133
+ Overrides the LlamaDecoderLayer to use adaptive layer normalization.
134
+
135
+ Args:
136
+ config (LlamaConfig): The configuration object for the Llama model.
137
+ layer_idx (int): The index of the layer.
138
+ use_cond (bool): Whether to use adaptive conditioning in the layer norms.
139
+ """
140
+ super().__init__(config, layer_idx) # init attention, mlp, etc.
141
+ # Override the standard layer norms with our adaptive versions
142
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
143
+ config.hidden_size,
144
+ eps=config.rms_norm_eps,
145
+ dim_cond=config.hidden_size,
146
+ use_cond=use_cond,
147
+ )
148
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
149
+ config.hidden_size,
150
+ eps=config.rms_norm_eps,
151
+ dim_cond=config.hidden_size,
152
+ use_cond=use_cond,
153
+ )
154
+
155
+ def forward(
156
+ self,
157
+ hidden_states: torch.Tensor,
158
+ cond_embedding: torch.Tensor,
159
+ attention_mask: Optional[torch.Tensor] = None,
160
+ position_ids: Optional[torch.LongTensor] = None,
161
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
162
+ output_attentions: Optional[bool] = False,
163
+ use_cache: Optional[bool] = False,
164
+ ) -> Tuple[
165
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
166
+ ]:
167
+ """
168
+ Forward pass for the NAR decoder layer, including conditional embedding.
169
+
170
+ Args:
171
+ hidden_states (torch.FloatTensor): Input to the layer of shape `(batch, seq_len, embed_dim)`.
172
+ cond_embedding (torch.Tensor): Conditional embedding for adaptive normalization.
173
+ attention_mask (Optional[torch.Tensor]): Attention mask of size
174
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
175
+ position_ids (Optional[torch.LongTensor]): Indices of positions of each input sequence tokens.
176
+ past_key_value (Optional[Tuple[torch.Tensor]]): Cached past key and value projection states.
177
+ output_attentions (Optional[bool]): Whether or not to return the attentions tensors of all attention layers.
178
+ use_cache (Optional[bool]): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.
179
+
180
+ Returns:
181
+ Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
182
+ A tuple containing the output hidden states and optionally the attention weights and past key/value states.
183
+ """
184
+ residual = hidden_states
185
+
186
+ # Apply adaptive pre-attention layer norm
187
+ hidden_states = self.input_layernorm(
188
+ hidden_states, cond_embedding=cond_embedding
189
+ )
190
+
191
+ # Self Attention block
192
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
193
+ hidden_states=hidden_states,
194
+ attention_mask=attention_mask,
195
+ position_ids=position_ids,
196
+ past_key_value=past_key_value,
197
+ output_attentions=output_attentions,
198
+ use_cache=use_cache,
199
+ )
200
+ hidden_states = residual + hidden_states
201
+
202
+ # Fully Connected block
203
+ residual = hidden_states
204
+ # Apply adaptive post-attention layer norm
205
+ hidden_states = self.post_attention_layernorm(
206
+ hidden_states, cond_embedding=cond_embedding
207
+ )
208
+ hidden_states = self.mlp(hidden_states)
209
+ hidden_states = residual + hidden_states
210
+
211
+ outputs = (hidden_states,)
212
+
213
+ if output_attentions:
214
+ outputs += (self_attn_weights,)
215
+
216
+ if use_cache:
217
+ outputs += (present_key_value,)
218
+
219
+ return outputs
220
+
221
+
222
+ class DiffLlamaPrefix(LlamaModel):
223
+ """
224
+ A Llama-based non-autoregressive transformer model for diffusion tasks.
225
+
226
+ This model uses a Llama architecture but modifies it for non-autoregressive generation.
227
+ Key features:
228
+ 1. Non-causal (fully-visible) attention mask.
229
+ 2. Adaptive layer normalization conditioned on diffusion timesteps.
230
+ 3. Ability to be conditioned on text embeddings, which are prepended as a prefix.
231
+ 4. Input and output linear projection layers for feature mapping.
232
+ """
233
+
234
+ def __init__(
235
+ self,
236
+ hidden_size: int = 1024,
237
+ num_heads: int = 16,
238
+ num_layers: int = 16,
239
+ in_dim: Optional[int] = None,
240
+ out_dim: Optional[int] = None,
241
+ use_text_emb: bool = True,
242
+ use_diff_step: bool = True,
243
+ use_cond: bool = True,
244
+ config: LlamaConfig = LlamaConfig(
245
+ vocab_size=0, hidden_size=256, num_attention_heads=1, num_hidden_layers=1
246
+ ),
247
+ ):
248
+ """
249
+ Initializes the DiffLlamaPrefix model.
250
+
251
+ Args:
252
+ hidden_size (int): The hidden dimension of the transformer.
253
+ num_heads (int): The number of attention heads.
254
+ num_layers (int): The number of transformer layers.
255
+ in_dim (Optional[int]): Dimension of the input features. If set, an input linear layer is created.
256
+ out_dim (Optional[int]): Dimension of the output features. If set, an output linear layer is created.
257
+ use_text_emb (bool): Whether to use text embeddings as a conditional prefix.
258
+ use_diff_step (bool): Whether to use diffusion timestep conditioning.
259
+ use_cond (bool): Whether to use an additional per-token conditional input `cond`.
260
+ config (LlamaConfig): A LlamaConfig object. A default is provided for convenience.
261
+ """
262
+ super().__init__(config)
263
+
264
+ self.use_text_emb = use_text_emb
265
+ self.use_diff_step = use_diff_step
266
+ self.use_cond = use_cond
267
+ self.in_dim = in_dim
268
+ self.out_dim = out_dim
269
+
270
+ # Create a stack of non-autoregressive Llama layers
271
+ self.layers = nn.ModuleList(
272
+ [
273
+ LlamaNARDecoderLayer(
274
+ LlamaConfig(
275
+ hidden_size=hidden_size,
276
+ num_attention_heads=num_heads,
277
+ max_position_embeddings=4096,
278
+ intermediate_size=hidden_size * 4,
279
+ ),
280
+ layer_idx=i,
281
+ use_cond=use_cond,
282
+ )
283
+ for i in range(num_layers)
284
+ ]
285
+ )
286
+
287
+ # Final adaptive layer norm
288
+ self.norm = LlamaAdaptiveRMSNorm(
289
+ hidden_size,
290
+ dim_cond=hidden_size,
291
+ use_cond=(use_cond or use_diff_step),
292
+ )
293
+
294
+ # MLP for processing text embedding condition
295
+ if self.use_text_emb:
296
+ self.text_mlp = nn.Sequential(
297
+ nn.Linear(hidden_size, hidden_size * 4),
298
+ nn.SiLU(),
299
+ nn.Linear(hidden_size * 4, hidden_size),
300
+ )
301
+
302
+ # Modules for processing diffusion timestep condition
303
+ if self.use_diff_step:
304
+ self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
305
+ self.diff_step_mlp = nn.Sequential(
306
+ nn.Linear(hidden_size, hidden_size * 4),
307
+ nn.SiLU(),
308
+ nn.Linear(hidden_size * 4, hidden_size),
309
+ )
310
+
311
+ # MLP for processing the general per-token `cond` condition
312
+ if self.use_cond:
313
+ self.cond_mlp = nn.Sequential(
314
+ nn.Linear(hidden_size, hidden_size * 4),
315
+ nn.SiLU(),
316
+ nn.Linear(hidden_size * 4, hidden_size),
317
+ )
318
+
319
+ # Ensure all layers use the correct adaptive norm configuration
320
+ for layer in self.layers:
321
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
322
+ hidden_size,
323
+ dim_cond=hidden_size,
324
+ use_cond=(use_cond or use_diff_step),
325
+ )
326
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
327
+ hidden_size,
328
+ dim_cond=hidden_size,
329
+ use_cond=(use_cond or use_diff_step),
330
+ )
331
+
332
+ # We handle embeddings manually, so disable the default token embedder
333
+ self.embed_tokens = None
334
+
335
+ # Optional input and output projection layers
336
+ if self.in_dim is not None:
337
+ self.in_linear = nn.Linear(self.in_dim, hidden_size)
338
+ if self.out_dim is not None:
339
+ self.out_linear = nn.Linear(hidden_size, self.out_dim)
340
+
341
+ self.post_init()
342
+
343
+ def _prepare_decoder_attention_mask(
344
+ self,
345
+ attention_mask: torch.Tensor,
346
+ input_shape: Tuple[int, int, int],
347
+ inputs_embeds: torch.Tensor,
348
+ past_key_values_length: int,
349
+ ) -> Optional[torch.Tensor]:
350
+ """
351
+ Creates a non-causal (fully-visible) attention mask.
352
+
353
+ This method overrides the default causal mask creation in LlamaModel.
354
+ It converts a 2D padding mask `[bsz, seq_len]` into a 4D attention mask
355
+ `[bsz, 1, tgt_seq_len, src_seq_len]` suitable for self-attention,
356
+ without applying a causal triangle.
357
+
358
+ Args:
359
+ attention_mask (torch.Tensor): The 2D padding mask.
360
+ input_shape (Tuple[int, int, int]): The shape of the input embeddings.
361
+ inputs_embeds (torch.Tensor): The input embeddings tensor.
362
+ past_key_values_length (int): The length of any cached key-values.
363
+
364
+ Returns:
365
+ Optional[torch.Tensor]: The 4D attention mask, or None if the input mask is None.
366
+ """
367
+ combined_attention_mask = None
368
+
369
+ def _expand_mask(
370
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
371
+ ) -> torch.Tensor:
372
+ """
373
+ Expands a 2D attention mask to a 4D attention mask.
374
+ """
375
+ bsz, src_len = mask.size()
376
+ tgt_len = tgt_len if tgt_len is not None else src_len
377
+
378
+ expanded_mask = (
379
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
380
+ )
381
+
382
+ # Invert the mask: 1.0 for valid tokens, 0.0 for padded tokens
383
+ # and then convert to additive mask format (-inf for masked positions)
384
+ inverted_mask = 1.0 - expanded_mask
385
+ return inverted_mask.masked_fill(
386
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
387
+ )
388
+
389
+ if attention_mask is not None:
390
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
391
+ expanded_attn_mask = _expand_mask(
392
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
393
+ ).to(inputs_embeds.device)
394
+ combined_attention_mask = (
395
+ expanded_attn_mask
396
+ if combined_attention_mask is None
397
+ else expanded_attn_mask + combined_attention_mask
398
+ )
399
+
400
+ return combined_attention_mask
401
+
402
+ def forward(
403
+ self,
404
+ x: torch.Tensor,
405
+ x_mask: torch.Tensor,
406
+ cond: Optional[torch.Tensor] = None,
407
+ diffusion_step: Optional[torch.Tensor] = None,
408
+ text_embedding: Optional[torch.Tensor] = None,
409
+ text_mask: Optional[torch.Tensor] = None,
410
+ input_ids: Optional[torch.LongTensor] = None,
411
+ attention_mask: Optional[torch.LongTensor] = None,
412
+ position_ids: Optional[torch.LongTensor] = None,
413
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
414
+ inputs_embeds: Optional[torch.FloatTensor] = None,
415
+ use_cache: Optional[bool] = None,
416
+ output_attentions: Optional[bool] = None,
417
+ output_hidden_states: Optional[bool] = None,
418
+ return_dict: Optional[bool] = None,
419
+ ) -> Union[torch.Tensor, Tuple]:
420
+ """
421
+ Forward pass of the DiffLlamaPrefix model.
422
+
423
+ Args:
424
+ x (torch.Tensor): The primary input tensor, e.g., noisy data in diffusion `(batch, seq_len, in_dim)`.
425
+ x_mask (torch.Tensor): The padding mask for `x`, shape `(batch, seq_len)`.
426
+ cond (Optional[torch.Tensor]): Additional per-token conditional input, shape `(batch, seq_len, hidden_size)`.
427
+ diffusion_step (Optional[torch.Tensor]): Diffusion timesteps, shape `(batch,)`.
428
+ text_embedding (Optional[torch.Tensor]): Text embeddings to be used as a prefix, shape `(batch, text_len, hidden_size)`.
429
+ text_mask (Optional[torch.Tensor]): The padding mask for `text_embedding`, shape `(batch, text_len)`.
430
+ input_ids, attention_mask, etc.: Standard Hugging Face arguments (mostly unused here but kept for compatibility).
431
+
432
+ Returns:
433
+ Union[torch.Tensor, Tuple]: The final output tensor `(batch, seq_len, out_dim)`.
434
+ If `output_hidden_states` is True, returns a tuple `(output_tensor, all_hidden_states)`.
435
+ """
436
+ # 1. Input and Conditional Projections
437
+ if self.in_dim is not None:
438
+ x = self.in_linear(x)
439
+
440
+ if self.use_cond and cond is not None:
441
+ cond_embedding = self.cond_mlp(cond) # (B, T, C)
442
+ x = x + cond_embedding
443
+
444
+ # 2. Prepend Text Embedding Prefix
445
+ if self.use_text_emb and text_embedding is not None:
446
+ text_embedding = self.text_mlp(text_embedding) # (B, T, C)
447
+ text_length = text_embedding.shape[1]
448
+ # Concatenate text prefix and main input
449
+ inputs_embeds = torch.cat([text_embedding, x], dim=1)
450
+ attention_mask = torch.cat([text_mask, x_mask], dim=1)
451
+ else:
452
+ inputs_embeds = x
453
+ attention_mask = x_mask
454
+ text_length = 0
455
+
456
+ # 3. Process Diffusion Step Embedding for Adaptive Norm
457
+ if self.use_diff_step and diffusion_step is not None:
458
+ # Convert scalar timesteps to vector embeddings and project with an MLP
459
+ diffusion_step_emb = self.diff_step_embedding(diffusion_step).to(x.device)
460
+ diffusion_step_emb = self.diff_step_mlp(diffusion_step_emb) # (B, C)
461
+ else:
462
+ diffusion_step_emb = None
463
+
464
+ # 4. Standard Transformer Preamble (adapted from LlamaModel)
465
+ batch_size, seq_length, _ = inputs_embeds.shape
466
+
467
+ output_attentions = (
468
+ output_attentions
469
+ if output_attentions is not None
470
+ else self.config.output_attentions
471
+ )
472
+ output_hidden_states = (
473
+ output_hidden_states
474
+ if output_hidden_states is not None
475
+ else self.config.output_hidden_states
476
+ )
477
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
478
+ return_dict = (
479
+ return_dict if return_dict is not None else self.config.use_return_dict
480
+ )
481
+
482
+ seq_length_with_past = seq_length
483
+ past_key_values_length = 0
484
+
485
+ if past_key_values is not None:
486
+ past_key_values_length = past_key_values[0][0].shape[2]
487
+ seq_length_with_past = seq_length_with_past + past_key_values_length
488
+
489
+ if position_ids is None:
490
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
491
+ position_ids = torch.arange(
492
+ past_key_values_length,
493
+ seq_length + past_key_values_length,
494
+ dtype=torch.long,
495
+ device=device,
496
+ )
497
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
498
+ else:
499
+ position_ids = position_ids.view(-1, seq_length).long()
500
+
501
+ if attention_mask is None:
502
+ attention_mask = torch.ones(
503
+ (batch_size, seq_length_with_past),
504
+ dtype=torch.bool,
505
+ device=inputs_embeds.device,
506
+ )
507
+ # Create the non-causal attention mask
508
+ attention_mask = self._prepare_decoder_attention_mask(
509
+ attention_mask,
510
+ (batch_size, seq_length),
511
+ inputs_embeds,
512
+ past_key_values_length,
513
+ )
514
+
515
+ hidden_states = inputs_embeds
516
+
517
+ # 5. Transformer Decoder Layers
518
+ all_hidden_states = () if output_hidden_states else None
519
+ all_self_attns = () if output_attentions else None
520
+ next_decoder_cache = () if use_cache else None
521
+
522
+ for idx, decoder_layer in enumerate(self.layers):
523
+ if output_hidden_states:
524
+ # Store hidden states before the layer, excluding the text prefix part
525
+ all_hidden_states += (hidden_states[:, text_length:],)
526
+
527
+ past_key_value = (
528
+ past_key_values[idx] if past_key_values is not None else None
529
+ )
530
+
531
+ # Pass the processed diffusion step embedding to the adaptive layer
532
+ layer_outputs = decoder_layer(
533
+ hidden_states,
534
+ attention_mask=attention_mask,
535
+ position_ids=position_ids,
536
+ past_key_value=past_key_value,
537
+ output_attentions=output_attentions,
538
+ use_cache=use_cache,
539
+ cond_embedding=diffusion_step_emb,
540
+ )
541
+ hidden_states = layer_outputs[0]
542
+
543
+ if use_cache:
544
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
545
+
546
+ if output_attentions:
547
+ all_self_attns += (layer_outputs[1],)
548
+
549
+ # 6. Final Normalization and Output Processing
550
+ hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step_emb)
551
+
552
+ if output_hidden_states:
553
+ # Add hidden states from the last decoder layer
554
+ all_hidden_states += (hidden_states[:, text_length:],)
555
+
556
+ next_cache = next_decoder_cache if use_cache else None
557
+
558
+ # Remove the text prefix from the final output sequence
559
+ hidden_states = hidden_states[
560
+ :,
561
+ text_length:,
562
+ ]
563
+
564
+ # Apply final output projection if specified
565
+ if self.out_dim is not None:
566
+ hidden_states = self.out_linear(hidden_states)
567
+
568
+ # 7. Return results (simplified for this snippet)
569
+ if output_hidden_states:
570
+ return hidden_states, all_hidden_states
571
+ else:
572
+ return hidden_states
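To make the non-causal attention handling above concrete, here is a standalone sketch of how a 2D padding mask becomes the additive 4D bias these layers consume; the toy tensor is illustrative only.

import torch

def expand_padding_mask(mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """(B, S) padding mask with 1 = keep, 0 = pad -> (B, 1, S, S) additive mask.

    Valid positions get 0.0, padded positions get the most negative finite value,
    and no causal triangle is applied, so every token attends to every valid token.
    """
    bsz, src_len = mask.size()
    expanded = mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

pad_mask = torch.tensor([[1, 1, 1, 0]])    # one sequence, last position is padding
attn_bias = expand_padding_mask(pad_mask)  # (1, 1, 4, 4); column 3 is masked everywhere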
models/tts/tadicodec/modeling_tadicodec.py ADDED
@@ -0,0 +1,641 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import math
5
+ from einops import rearrange
6
+ from typing import Optional, Dict, Any
7
+
8
+ from models.tts.tadicodec.llama_nar_prefix import DiffLlamaPrefix
9
+ from models.codec.amphion_codec.quantize.bsq import (
10
+ BinarySphericalQuantizer,
11
+ SimpleQuantizer,
12
+ )
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class TaDiCodec(nn.Module):
17
+ """
18
+ TaDiCodec: A diffusion-based codec model for Text-to-Speech (TTS)
19
+ that uses a non-autoregressive Llama-style transformer architecture.
20
+
21
+ It consists of:
22
+ 1. An Encoder that processes input features (e.g., mel-spectrograms or SSL features)
23
+ into a latent representation.
24
+ 2. A Vector Quantizer (VQ) that discretizes the latent representation into codes.
25
+ 3. A Decoder that generates the output feature (e.g., mel-spectrogram) from the codes,
26
+ optional text conditioning, and a prompt, using a flow-matching diffusion process.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ mel_dim: int = 128,
32
+ in_dim: int = 128,
33
+ hidden_size: int = 1024,
34
+ encoder_num_layers: int = 8,
35
+ decoder_num_layers: int = 16,
36
+ num_heads: int = 16,
37
+ cond_drop_p: float = 0.2, # drop code for decoder
38
+ context_drop_p: float = 0.2, # drop context (mel) for decoder
39
+ down_sample_factor: int = 8, # down sample factor for vq
40
+ vq_emb_dim: int = 14, # codebook size 2^vq_emb_dim, 2^14 = 16384
41
+ use_text_cond: bool = True, # use text cond for decoder
42
+ text_vocab_size: int = 32100, # vocab size
43
+ cond_dim: int = 1024,
44
+ cond_scale_factor: int = 1,
45
+ sigma: float = 1e-5,
46
+ time_scheduler: str = "linear",
47
+ use_vq: bool = True,
48
+ vq_type: str = "bsq",
49
+ use_repa_loss: bool = False,
50
+ cfg: Optional[Any] = None,
51
+ ):
52
+ """
53
+ Initializes the TaDiCodec model.
54
+
55
+ Args:
56
+ mel_dim (int): Dimension of the mel-spectrogram.
57
+ in_dim (int): Dimension of the encoder's input features.
58
+ hidden_size (int): Hidden size of the transformer models.
59
+ encoder_num_layers (int): Number of layers in the encoder transformer.
60
+ decoder_num_layers (int): Number of layers in the decoder transformer.
61
+ num_heads (int): Number of attention heads in the transformers.
62
+ cond_drop_p (float): Dropout probability for the VQ code condition in the decoder.
63
+ context_drop_p (float): Dropout probability for the prompt context in the decoder.
64
+ down_sample_factor (int): Factor by which to downsample the latent representation before VQ.
65
+ vq_emb_dim (int): Dimension of the vector quantizer's embedding space.
66
+ use_text_cond (bool): Whether to use text embeddings as a condition in the decoder.
67
+ text_vocab_size (int): Size of the text vocabulary.
68
+ cond_dim (int): Dimension of the conditional input.
69
+ cond_scale_factor (int): Scaling factor for the condition.
70
+ sigma (float): Small constant used in the flow matching formula.
71
+ time_scheduler (str): Type of time scheduler for diffusion (e.g., 'linear').
72
+ use_vq (bool): Whether to use vector quantization.
73
+ vq_type (str): Type of vector quantizer ('bsq' or 'simple').
74
+ use_repa_loss (bool): Whether to use the representational alignment loss.
75
+ cfg (Optional[Any]): A configuration object that can override the default parameters.
76
+ """
77
+ super().__init__()
78
+
79
+ # Override parameters with values from the config object if provided
80
+ mel_dim = (
81
+ cfg.mel_dim if cfg is not None and hasattr(cfg, "mel_dim") else mel_dim
82
+ )
83
+ in_dim = cfg.in_dim if cfg is not None and hasattr(cfg, "in_dim") else in_dim
84
+ hidden_size = (
85
+ cfg.hidden_size
86
+ if cfg is not None and hasattr(cfg, "hidden_size")
87
+ else hidden_size
88
+ )
89
+ encoder_num_layers = (
90
+ cfg.encoder_num_layers
91
+ if cfg is not None and hasattr(cfg, "encoder_num_layers")
92
+ else encoder_num_layers
93
+ )
94
+ decoder_num_layers = (
95
+ cfg.decoder_num_layers
96
+ if cfg is not None and hasattr(cfg, "decoder_num_layers")
97
+ else decoder_num_layers
98
+ )
99
+ num_heads = (
100
+ cfg.num_heads
101
+ if cfg is not None and hasattr(cfg, "num_heads")
102
+ else num_heads
103
+ )
104
+ cond_drop_p = (
105
+ cfg.cond_drop_p
106
+ if cfg is not None and hasattr(cfg, "cond_drop_p")
107
+ else cond_drop_p
108
+ )
109
+ context_drop_p = (
110
+ cfg.context_drop_p
111
+ if cfg is not None and hasattr(cfg, "context_drop_p")
112
+ else context_drop_p
113
+ )
114
+ down_sample_factor = (
115
+ cfg.down_sample_factor
116
+ if cfg is not None and hasattr(cfg, "down_sample_factor")
117
+ else down_sample_factor
118
+ )
119
+ vq_emb_dim = (
120
+ cfg.vq_emb_dim
121
+ if cfg is not None and hasattr(cfg, "vq_emb_dim")
122
+ else vq_emb_dim
123
+ )
124
+ use_text_cond = (
125
+ cfg.use_text_cond
126
+ if cfg is not None and hasattr(cfg, "use_text_cond")
127
+ else use_text_cond
128
+ )
129
+ text_vocab_size = (
130
+ cfg.text_vocab_size
131
+ if cfg is not None and hasattr(cfg, "text_vocab_size")
132
+ else text_vocab_size
133
+ )
134
+ cond_dim = (
135
+ cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
136
+ )
137
+ cond_scale_factor = (
138
+ cfg.cond_scale_factor
139
+ if cfg is not None and hasattr(cfg, "cond_scale_factor")
140
+ else cond_scale_factor
141
+ )
142
+ sigma = cfg.sigma if cfg is not None and hasattr(cfg, "sigma") else sigma
143
+ time_scheduler = (
144
+ cfg.time_scheduler
145
+ if cfg is not None and hasattr(cfg, "time_scheduler")
146
+ else time_scheduler
147
+ )
148
+ use_vq = cfg.use_vq if cfg is not None and hasattr(cfg, "use_vq") else use_vq
149
+ vq_type = (
150
+ cfg.vq_type if cfg is not None and hasattr(cfg, "vq_type") else vq_type
151
+ )
152
+ use_repa_loss = (
153
+ cfg.use_repa_loss
154
+ if cfg is not None and hasattr(cfg, "use_repa_loss")
155
+ else use_repa_loss
156
+ )
157
+
158
+ self.mel_dim = mel_dim
159
+ self.in_dim = in_dim
160
+ self.hidden_size = hidden_size
161
+ self.encoder_num_layers = encoder_num_layers
162
+ self.decoder_num_layers = decoder_num_layers
163
+ self.num_heads = num_heads
164
+ self.cond_drop_p = cond_drop_p
165
+ self.context_drop_p = context_drop_p
166
+ self.vq_emb_dim = vq_emb_dim
167
+ self.down_sample_factor = down_sample_factor
168
+ self.use_text_cond = use_text_cond
169
+ self.text_vocab_size = text_vocab_size
170
+ self.cond_dim = cond_dim
171
+ self.cond_scale_factor = cond_scale_factor
172
+ self.sigma = sigma
173
+ self.time_scheduler = time_scheduler
174
+ self.use_vq = use_vq
175
+ self.vq_type = vq_type
176
+ self.use_repa_loss = use_repa_loss
177
+
178
+ # Text embedding layer
179
+ if self.use_text_cond:
180
+ self.text_emb = nn.Embedding(text_vocab_size, hidden_size)
181
+
182
+ # VQ related layers
183
+ self.vq_in_linear = nn.Linear(hidden_size, vq_emb_dim)
184
+ if self.use_vq:
185
+ if self.vq_type == "bsq":
186
+ self.bsq = BinarySphericalQuantizer(embed_dim=vq_emb_dim)
187
+ else:
188
+ self.bsq = SimpleQuantizer(embed_dim=vq_emb_dim)
189
+ self.vq_out_linear = nn.Linear(vq_emb_dim, hidden_size)
190
+
191
+ # Repa (Representational Alignment) MLP for auxiliary loss
192
+ if self.use_repa_loss:
193
+ self.repa_mlp = nn.Sequential(
194
+ nn.Linear(hidden_size, hidden_size * 2),
195
+ nn.GELU(),
196
+ nn.Linear(hidden_size * 2, 1024),
197
+ )
198
+ self.repa_layer_idx = (
199
+ 6 # The decoder layer from which to extract hidden states
200
+ )
201
+
202
+ self.reset_parameters()
203
+
204
+ # Encoder: A non-autoregressive Llama-style model without time/text conditioning
205
+ self.encoder = DiffLlamaPrefix(
206
+ hidden_size=hidden_size,
207
+ num_layers=encoder_num_layers,
208
+ num_heads=num_heads,
209
+ in_dim=self.in_dim,
210
+ out_dim=None, # Outputs hidden states for VQ
211
+ use_text_emb=False,
212
+ use_diff_step=False,
213
+ use_cond=False,
214
+ )
215
+
216
+ # Decoder: A non-autoregressive Llama-style model with time, text, and code conditioning
217
+ self.decoder = DiffLlamaPrefix(
218
+ hidden_size=hidden_size,
219
+ num_layers=decoder_num_layers,
220
+ num_heads=num_heads,
221
+ in_dim=self.mel_dim,
222
+ out_dim=self.mel_dim,
223
+ use_text_emb=use_text_cond,
224
+ use_diff_step=True,
225
+ use_cond=True,
226
+ )
227
+
228
+ @torch.no_grad()
229
+ def forward_diffusion(
230
+ self, x: torch.Tensor, t: torch.Tensor
231
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
232
+ """
233
+ Performs the forward diffusion process based on flow matching.
234
+ It takes the clean data `x` and a timestep `t` to produce a noisy sample `xt`.
235
+ It also creates a prompt/target mask.
236
+
237
+ Args:
238
+ x (torch.Tensor): The clean input data (e.g., mel-spectrogram), shape `(B, T, mel_dim)`.
239
+ t (torch.Tensor): The diffusion timestep for each sample in the batch, shape `(B,)`.
240
+
241
+ Returns:
242
+ Tuple[torch.Tensor, ...]:
243
+ - xt (torch.Tensor): The noisy sample at time `t`, shape `(B, T, mel_dim)`.
244
+ - z (torch.Tensor): The noise vector used, drawn from N(0, I).
245
+ - new_t (torch.Tensor): The original `t` tensor.
246
+ - prompt_len (torch.Tensor): The length of the prompt for each sample.
247
+ - mask (torch.Tensor): A mask where 1 indicates the target (noisy) region and 0 indicates the prompt (clean) region.
248
+ """
249
+ new_t = t
250
+ t = t.unsqueeze(-1).unsqueeze(-1) # Reshape for broadcasting
251
+ z = torch.randn(
252
+ x.shape, dtype=x.dtype, device=x.device, requires_grad=False
253
+ ) # (B, T, mel_dim)
254
+
255
+ context_drop_p = self.context_drop_p
256
+
257
+ # Randomly decide the length of the prompt (un-noised context)
258
+ if torch.rand(1) > context_drop_p:
259
+ prompt_len = torch.randint(
260
+ min(x.shape[1] // 4, 5), int(x.shape[1] * 0.4), (x.shape[0],)
261
+ ).to(x.device)
262
+ else:
263
+ # Drop the context entirely by setting prompt length to 0
264
+ prompt_len = torch.zeros(x.shape[0], device=x.device)
265
+
266
+ # Create a mask to distinguish prompt from target
267
+ is_prompt = torch.zeros_like(x[:, :, 0]) # (B, T)
268
+ col_indices = torch.arange(is_prompt.shape[1], device=prompt_len.device).repeat(
269
+ is_prompt.shape[0], 1
270
+ ) # (B, T)
271
+ is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # 1 if it's a prompt frame
272
+
273
+ mask = torch.ones_like(x[:, :, 0]) # Mask is 1 for target, 0 for prompt
274
+ mask[is_prompt.bool()] = 0
275
+ mask = mask.unsqueeze(-1) # (B, T, 1)
276
+
277
+ # Flow matching formula: xt = (1 - (1 - sigma) * t) * x0 + t * x
278
+ # where x0 ~ N(0, 1) and x is the clean data sample.
279
+ # The equation is applied only to the target region.
280
+ xt = ((1 - (1 - self.sigma) * t) * z + t * x) * mask + x * (1 - mask)
281
+
282
+ return xt, z, new_t, prompt_len, mask
283
+
284
+ def forward(
285
+ self,
286
+ x: torch.Tensor,
287
+ x_mask: torch.Tensor,
288
+ x_in: Optional[torch.Tensor] = None,
289
+ text_ids: Optional[torch.Tensor] = None,
290
+ text_mask: Optional[torch.Tensor] = None,
291
+ ) -> Dict[str, Optional[torch.Tensor]]:
292
+ """
293
+ The main training-time forward pass of the model.
294
+
295
+ Args:
296
+ x (torch.Tensor): The target mel-spectrogram, shape `(B, T, mel_dim)`.
297
+ x_mask (torch.Tensor): Padding mask for `x`, shape `(B, T)`.
298
+ x_in (Optional[torch.Tensor]): Optional input for the encoder (e.g., SSL features). If None, `x` is used.
299
+ text_ids (Optional[torch.Tensor]): Input text token IDs for conditioning, shape `(B, text_len)`.
300
+ text_mask (Optional[torch.Tensor]): Padding mask for `text_ids`.
301
+
302
+ Returns:
303
+ Dict[str, Optional[torch.Tensor]]: A dictionary containing various tensors for loss computation,
304
+ such as the predicted flow, the final predicted mel, the noise target, etc.
305
+ """
306
+ # 1. Encoder pass
307
+ if x_in is None:
308
+ # Use target mel as encoder input
309
+ vq_emb_pre = self.encoder(x=x, x_mask=x_mask)
310
+ else:
311
+ # Use provided SSL features as encoder input
312
+ vq_emb_pre = self.encoder(x=x_in, x_mask=x_mask)
313
+
314
+ # 2. Downsampling before VQ
315
+ _, T, _ = vq_emb_pre.shape
316
+ vq_emb_pre = vq_emb_pre.transpose(1, 2)
317
+ vq_emb_pre = F.interpolate(
318
+ vq_emb_pre, size=T // self.down_sample_factor, mode="linear"
319
+ )
320
+ vq_emb_pre = vq_emb_pre.transpose(1, 2)
321
+
322
+ # 3. Vector Quantization
323
+ vq_emb_pre = self.vq_in_linear(vq_emb_pre)
324
+ vq_emb_pre = F.normalize(vq_emb_pre, dim=-1) # L2 normalize before quantization
325
+
326
+ if self.use_vq:
327
+ vq_emb, vq_loss, info = self.bsq(vq_emb_pre)
328
+ commit_loss = info["commit_loss"]
329
+ else:
330
+ vq_emb = vq_emb_pre
331
+ vq_loss = torch.tensor(0.0, device=x.device)
332
+ commit_loss = torch.tensor(0.0, device=x.device)
333
+ info = None
334
+
335
+ vq_emb_post = self.vq_out_linear(vq_emb)
336
+
337
+ # 4. Upsampling after VQ
338
+ vq_emb_post = vq_emb_post.transpose(1, 2)
339
+ vq_emb_post = F.interpolate(vq_emb_post, size=T, mode="linear")
340
+ vq_emb_post = vq_emb_post.transpose(1, 2)
341
+
342
+ # 5. Decoder with flow matching
343
+ # Sample a random timestep t for each item in the batch
344
+ t = torch.rand(x.shape[0], device=x.device, requires_grad=False)
345
+ t = torch.clamp(t, 1e-5, 1.0) # Clamp to avoid numerical issues at boundaries
346
+
347
+ # Perform forward diffusion to get the noisy input `xt` and the noise `z`
348
+ xt, z, new_t, prompt_len, mask = self.forward_diffusion(x, t)
349
+ noise = z
350
+
351
+ # 6. Prepare conditions for the decoder
352
+ if self.use_text_cond:
353
+ text_emb = self.text_emb(text_ids)
354
+ else:
355
+ text_emb = None
356
+
357
+ # Use the upsampled VQ embedding as the primary condition
358
+ cond_emb = vq_emb_post
359
+ # Apply condition dropout for classifier-free guidance training
360
+ if torch.rand(1) < self.cond_drop_p:
361
+ cond_emb = torch.zeros_like(cond_emb)
362
+
363
+ # 7. Decoder pass
364
+ if self.use_repa_loss:
365
+ # If using Repa loss, we need to output intermediate hidden states
366
+ flow_pred, hidden_states = self.decoder(
367
+ x=xt,
368
+ x_mask=x_mask,
369
+ text_embedding=text_emb,
370
+ text_mask=text_mask,
371
+ cond=cond_emb,
372
+ diffusion_step=new_t,
373
+ output_hidden_states=True,
374
+ )
375
+ ssl_feat_pred = self.repa_mlp(hidden_states[self.repa_layer_idx])
376
+ else:
377
+ flow_pred = self.decoder(
378
+ x=xt,
379
+ x_mask=x_mask,
380
+ text_embedding=text_emb,
381
+ text_mask=text_mask,
382
+ cond=cond_emb,
383
+ diffusion_step=new_t,
384
+ )
385
+ ssl_feat_pred = None
386
+
387
+ # Predict the clean data `x0_pred` from the noisy input `xt` and the predicted flow
388
+ # x_pred = xt + (1 - t) * flow_pred
389
+ x_pred = xt + (1 - t.unsqueeze(-1).unsqueeze(-1)) * flow_pred
390
+
391
+ # Final mask should consider both the prompt/target mask and the original padding mask
392
+ final_mask = mask * x_mask.unsqueeze(-1)
393
+
394
+ return {
395
+ "noise": noise,
396
+ "x": x,
397
+ "flow_pred": flow_pred,
398
+ "x_pred": x_pred,
399
+ "final_mask": final_mask,
400
+ "prompt_len": prompt_len,
401
+ "vq_loss": vq_loss,
402
+ "commit_loss": commit_loss,
403
+ "ssl_feat_pred": ssl_feat_pred,
404
+ }
405
+
406
+ @torch.no_grad()
407
+ def encode(
408
+ self, x: torch.Tensor, x_mask: torch.Tensor
409
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
410
+ """
411
+ Encodes an input feature `x` into discrete VQ codes. (Inference)
412
+
413
+ Args:
414
+ x (torch.Tensor): Input feature, shape `(B, T, in_dim)`.
415
+ x_mask (torch.Tensor): Padding mask for `x`, shape `(B, T)`.
416
+
417
+ Returns:
418
+ Tuple[torch.Tensor, Optional[torch.Tensor]]:
419
+ - vq_emb (torch.Tensor): The quantized continuous embeddings.
420
+ - indices (Optional[torch.Tensor]): The discrete code indices, if VQ is used.
421
+ """
422
+ # Encoder pass
423
+ vq_emb_pre = self.encoder(x=x, x_mask=x_mask)
424
+
425
+ # Downsampling
426
+ _, T, _ = vq_emb_pre.shape
427
+ vq_emb_pre = vq_emb_pre.transpose(1, 2)
428
+ vq_emb_pre = F.interpolate(
429
+ vq_emb_pre, size=T // self.down_sample_factor, mode="linear"
430
+ )
431
+ vq_emb_pre = vq_emb_pre.transpose(1, 2)
432
+
433
+ # VQ
434
+ vq_emb_pre = self.vq_in_linear(vq_emb_pre)
435
+ vq_emb_pre = F.normalize(vq_emb_pre, dim=-1) # L2 norm
436
+
437
+ if self.use_vq:
438
+ vq_emb, _, info = self.bsq(vq_emb_pre)
439
+ indices = info["indices"]
440
+ else:
441
+ vq_emb = vq_emb_pre
442
+ indices = None
443
+
444
+ return vq_emb, indices
445
+
446
+ @torch.no_grad()
447
+ def index2vq(self, indices: torch.Tensor) -> torch.Tensor:
448
+ """
449
+ Converts VQ code indices back to continuous embeddings.
450
+
451
+ Args:
452
+ indices (torch.Tensor): The discrete code indices.
453
+
454
+ Returns:
455
+ torch.Tensor: The corresponding continuous codebook embeddings.
456
+ """
457
+ return self.bsq.get_codebook_entry(indices).float()
458
+
459
+ @torch.no_grad()
460
+ def reverse_diffusion(
461
+ self,
462
+ vq_emb: Optional[torch.Tensor] = None,
463
+ indices: Optional[torch.Tensor] = None,
464
+ text_ids: Optional[torch.Tensor] = None,
465
+ prompt_mel: Optional[torch.Tensor] = None,
466
+ x_mask: Optional[torch.Tensor] = None,
467
+ prompt_mask: Optional[torch.Tensor] = None,
468
+ text_mask: Optional[torch.Tensor] = None,
469
+ n_timesteps: int = 32,
470
+ cfg: float = 1.0,
471
+ rescale_cfg: float = 0.75,
472
+ ) -> torch.Tensor:
473
+ """
474
+ Performs the reverse diffusion process to generate mel-spectrograms from conditions. (Inference)
475
+
476
+ Args:
477
+ vq_emb (Optional[torch.Tensor]): Pre-quantized embeddings.
478
+ indices (Optional[torch.Tensor]): Discrete VQ code indices. If provided, `vq_emb` is ignored.
479
+ text_ids (Optional[torch.Tensor]): Text token IDs for conditioning.
480
+ prompt_mel (Optional[torch.Tensor]): A mel-spectrogram prompt.
481
+ x_mask (Optional[torch.Tensor]): Padding mask for the target generation length.
482
+ prompt_mask (Optional[torch.Tensor]): Padding mask for the prompt.
483
+ text_mask (Optional[torch.Tensor]): Padding mask for the text.
484
+ n_timesteps (int): Number of steps in the reverse diffusion process.
485
+ cfg (float): Classifier-Free Guidance scale.
486
+ rescale_cfg (float): Rescaling factor for CFG to prevent saturation.
487
+
488
+ Returns:
489
+ torch.Tensor: The generated mel-spectrogram.
490
+ """
491
+ if vq_emb is None:
492
+ assert indices is not None, "Either vq_emb or indices must be provided"
493
+ vq_emb = self.index2vq(indices.long())
494
+
495
+ # Upsample VQ embeddings to match the target mel length
496
+ vq_emb_post = self.vq_out_linear(vq_emb)
497
+ vq_emb_post = vq_emb_post.transpose(1, 2)
498
+ vq_emb_post = F.interpolate(
499
+ vq_emb_post, scale_factor=self.down_sample_factor, mode="linear"
500
+ )
501
+ vq_emb_post = vq_emb_post.transpose(1, 2)
502
+
503
+ # Prepare text embeddings
504
+ if self.use_text_cond:
505
+ text_emb = self.text_emb(text_ids)
506
+ if text_mask is None:
507
+ text_mask = torch.ones_like(text_ids)
508
+ else:
509
+ text_emb, text_mask = None, None
510
+
511
+ cond_emb = vq_emb_post
512
+
513
+ # Handle prompt
514
+ if prompt_mel is None:
515
+ prompt_mel = torch.zeros(
516
+ cond_emb.shape[0], 0, self.mel_dim, device=cond_emb.device
517
+ )
518
+
519
+ prompt_len = prompt_mel.shape[1]
520
+ target_len = cond_emb.shape[1] - prompt_len
521
+
522
+ # Prepare masks
523
+ if x_mask is None:
524
+ x_mask = torch.ones(cond_emb.shape[0], target_len, device=cond_emb.device)
525
+ if prompt_mask is None:
526
+ prompt_mask = torch.ones(
527
+ cond_emb.shape[0], prompt_len, device=cond_emb.device
528
+ )
529
+
530
+ xt_mask = torch.cat([prompt_mask, x_mask], dim=1)
531
+
532
+ # Initialize with random noise
533
+ z = torch.randn(
534
+ (cond_emb.shape[0], target_len, self.mel_dim),
535
+ dtype=cond_emb.dtype,
536
+ device=cond_emb.device,
537
+ )
538
+ xt = z
539
+ h = 1.0 / n_timesteps
540
+
541
+ # Iterative denoising loop (Euler method)
542
+ for i in range(n_timesteps):
543
+ # Concatenate prompt and current noisy sample
544
+ xt_input = torch.cat([prompt_mel, xt], dim=1)
545
+ # Calculate current timestep
546
+ t = (0 + (i + 0.5) * h) * torch.ones(
547
+ z.shape[0], dtype=z.dtype, device=z.device
548
+ )
549
+
550
+ # Get conditional flow prediction
551
+ flow_pred = self.decoder(
552
+ x=xt_input,
553
+ x_mask=xt_mask,
554
+ text_embedding=text_emb,
555
+ text_mask=text_mask,
556
+ cond=cond_emb,
557
+ diffusion_step=t,
558
+ )
559
+ flow_pred = flow_pred[
560
+ :, prompt_len:, :
561
+ ] # Extract flow for the target region
562
+
563
+ # Classifier-Free Guidance (CFG)
564
+ if cfg > 0 and self.use_text_cond:
565
+ # Get unconditional flow prediction by dropping conditions
566
+ uncond_flow_pred = self.decoder(
567
+ x=xt_input,
568
+ x_mask=xt_mask,
569
+ text_embedding=None, # Drop text
570
+ text_mask=None,
571
+ cond=torch.zeros_like(cond_emb), # Drop code
572
+ diffusion_step=t,
573
+ )
574
+ uncond_flow_pred = uncond_flow_pred[:, prompt_len:, :]
575
+
576
+ # Combine conditional and unconditional predictions
577
+ flow_pred_cfg = uncond_flow_pred + cfg * (flow_pred - uncond_flow_pred)
578
+
579
+ # Rescale to prevent saturation, as in Stable Diffusion
580
+ if rescale_cfg > 0:
581
+ flow_pred_std = flow_pred.std()
582
+ cfg_std = flow_pred_cfg.std()
583
+ # Avoid division by zero
584
+ if cfg_std > 1e-6:
585
+ rescale_flow_pred = flow_pred_cfg * (flow_pred_std / cfg_std)
586
+ flow_pred = (
587
+ rescale_cfg * rescale_flow_pred
588
+ + (1 - rescale_cfg) * flow_pred_cfg
589
+ )
590
+ else:
591
+ flow_pred = flow_pred_cfg
592
+ else:
593
+ flow_pred = flow_pred_cfg
594
+
595
+ # Update the noisy sample
596
+ dxt = flow_pred * h
597
+ xt = xt + dxt
598
+
599
+ return xt
600
+
601
+ def reset_parameters(self):
602
+ """
603
+ Applies custom weight initialization to the model's submodules.
604
+ """
605
+
606
+ def _reset_parameters(m: nn.Module):
607
+ if isinstance(m, nn.MultiheadAttention):
608
+ if m._qkv_same_embed_dim:
609
+ nn.init.normal_(m.in_proj_weight, std=0.02)
610
+ else:
611
+ nn.init.normal_(m.q_proj_weight, std=0.02)
612
+ nn.init.normal_(m.k_proj_weight, std=0.02)
613
+ nn.init.normal_(m.v_proj_weight, std=0.02)
614
+
615
+ if m.in_proj_bias is not None:
616
+ nn.init.constant_(m.in_proj_bias, 0.0)
617
+ nn.init.constant_(m.out_proj.bias, 0.0)
618
+ if m.bias_k is not None:
619
+ nn.init.xavier_normal_(m.bias_k)
620
+ if m.bias_v is not None:
621
+ nn.init.xavier_normal_(m.bias_v)
622
+
623
+ elif (
624
+ isinstance(m, nn.Conv1d)
625
+ or isinstance(m, nn.ConvTranspose1d)
626
+ or isinstance(m, nn.Conv2d)
627
+ or isinstance(m, nn.ConvTranspose2d)
628
+ ):
629
+ m.weight.data.normal_(0.0, 0.02)
630
+
631
+ elif isinstance(m, nn.Linear):
632
+ m.weight.data.normal_(mean=0.0, std=0.02)
633
+ if m.bias is not None:
634
+ m.bias.data.zero_()
635
+
636
+ elif isinstance(m, nn.Embedding):
637
+ m.weight.data.normal_(mean=0.0, std=0.02)
638
+ if m.padding_idx is not None:
639
+ m.weight.data[m.padding_idx].zero_()
640
+
641
+ self.apply(_reset_parameters)
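A compact numerical sketch of the flow-matching schedule used by forward_diffusion and reverse_diffusion above: the interpolation xt = (1 - (1 - sigma) * t) * z + t * x and the Euler update xt = xt + flow * h. The tensors are toy data and the decoder is replaced by the analytic target velocity, so this only illustrates the schedule, not the network.

import torch

sigma = 1e-5
x = torch.randn(1, 10, 128)   # stand-in for a clean mel segment
z = torch.randn_like(x)       # Gaussian noise

def interpolate(x, z, t):
    # forward_diffusion path: pure noise at t=0, (almost) clean data at t=1
    return (1 - (1 - sigma) * t) * z + t * x

n_timesteps = 8
h = 1.0 / n_timesteps
xt = z.clone()                # start the reverse process from the noise sample
for i in range(n_timesteps):
    t = (i + 0.5) * h
    # For this linear path the true velocity d(xt)/dt is constant: x - (1 - sigma) * z.
    # reverse_diffusion asks the decoder for this quantity; here we use the analytic value.
    flow_pred = x - (1 - sigma) * z
    xt = xt + flow_pred * h   # Euler step, exactly as in the denoising loop above

# Integrating from t=0 to t=1 lands on interpolate(x, z, 1.0) = x + sigma * z, i.e. approximately x.
assert torch.allclose(xt, x, atol=1e-3)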
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio
+ torch
+ torchaudio
+ numpy
+ huggingface-hub
+ transformers
+ datasets
+ librosa
+ soundfile
utils/__pycache__/hparam.cpython-310.pyc ADDED
Binary file (21.2 kB). View file
 
utils/__pycache__/util.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
utils/hparam.py ADDED
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
7
+ """Hyperparameter values."""
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+
12
+ import json
13
+ import numbers
14
+ import re
15
+ import six
16
+
17
+ # Define the regular expression for parsing a single clause of the input
18
+ # (delimited by commas). A legal clause looks like:
19
+ # <variable name>[<index>]? = <rhs>
20
+ # where <rhs> is either a single token or [] enclosed list of tokens.
21
+ # For example: "var[1] = a" or "x = [1,2,3]"
22
+ PARAM_RE = re.compile(
23
+ r"""
24
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
25
+ (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
26
+ \s*=\s*
27
+ ((?P<val>[^,\[]*) # single value: "a" or None
28
+ |
29
+ \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
30
+ ($|,\s*)""",
31
+ re.VERBOSE,
32
+ )
33
+
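A quick illustrative check of what PARAM_RE captures for a single clause (the clause string is made up; the expected group dictionary is shown as a comment):

m = PARAM_RE.match("x=[1,2,3]")
print(m.groupdict())  # {'name': 'x', 'index': None, 'val': None, 'vals': '1,2,3'}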
34
+
35
+ def _parse_fail(name, var_type, value, values):
36
+ """Helper function for raising a value error for bad assignment."""
37
+ raise ValueError(
38
+ "Could not parse hparam '%s' of type '%s' with value '%s' in %s"
39
+ % (name, var_type.__name__, value, values)
40
+ )
41
+
42
+
43
+ def _reuse_fail(name, values):
44
+ """Helper function for raising a value error for reuse of name."""
45
+ raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values))
46
+
47
+
48
+ def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
49
+ """Update results_dictionary with a scalar value.
50
+
51
+ Used to update the results_dictionary to be returned by parse_values when
52
+ encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
53
+
54
+ Mutates results_dictionary.
55
+
56
+ Args:
57
+ name: Name of variable in assignment ("s" or "arr").
58
+ parse_fn: Function for parsing the actual value.
59
+ var_type: Type of named variable.
60
+ m_dict: Dictionary constructed from regex parsing.
61
+ m_dict['val']: RHS value (scalar)
62
+ m_dict['index']: List index value (or None)
63
+ values: Full expression being parsed
64
+ results_dictionary: The dictionary being updated for return by the parsing
65
+ function.
66
+
67
+ Raises:
68
+ ValueError: If the name has already been used.
69
+ """
70
+ try:
71
+ parsed_value = parse_fn(m_dict["val"])
72
+ except ValueError:
73
+ _parse_fail(name, var_type, m_dict["val"], values)
74
+
75
+ # If no index is provided
76
+ if not m_dict["index"]:
77
+ if name in results_dictionary:
78
+ _reuse_fail(name, values)
79
+ results_dictionary[name] = parsed_value
80
+ else:
81
+ if name in results_dictionary:
82
+ # The name has already been used as a scalar, then it
83
+ # will be in this dictionary and map to a non-dictionary.
84
+ if not isinstance(results_dictionary.get(name), dict):
85
+ _reuse_fail(name, values)
86
+ else:
87
+ results_dictionary[name] = {}
88
+
89
+ index = int(m_dict["index"])
90
+ # Make sure the index position hasn't already been assigned a value.
91
+ if index in results_dictionary[name]:
92
+ _reuse_fail("{}[{}]".format(name, index), values)
93
+ results_dictionary[name][index] = parsed_value
94
+
95
+
96
+ def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
97
+ """Update results_dictionary from a list of values.
98
+
99
+ Used to update results_dictionary to be returned by parse_values when
100
+ encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
101
+
102
+ Mutates results_dictionary.
103
+
104
+ Args:
105
+ name: Name of variable in assignment ("arr").
106
+ parse_fn: Function for parsing individual values.
107
+ var_type: Type of named variable.
108
+ m_dict: Dictionary constructed from regex parsing.
109
+ m_dict['val']: RHS value (scalar)
110
+ values: Full expression being parsed
111
+ results_dictionary: The dictionary being updated for return by the parsing
112
+ function.
113
+
114
+ Raises:
115
+ ValueError: If the name has an index or the values cannot be parsed.
116
+ """
117
+ if m_dict["index"] is not None:
118
+ raise ValueError("Assignment of a list to a list index.")
119
+ elements = filter(None, re.split("[ ,]", m_dict["vals"]))
120
+ # Make sure the name hasn't already been assigned a value
121
+ if name in results_dictionary:
122
+ raise _reuse_fail(name, values)
123
+ try:
124
+ results_dictionary[name] = [parse_fn(e) for e in elements]
125
+ except ValueError:
126
+ _parse_fail(name, var_type, m_dict["vals"], values)
127
+
128
+
129
+ def _cast_to_type_if_compatible(name, param_type, value):
130
+ """Cast hparam to the provided type, if compatible.
131
+
132
+ Args:
133
+ name: Name of the hparam to be cast.
134
+ param_type: The type of the hparam.
135
+ value: The value to be cast, if compatible.
136
+
137
+ Returns:
138
+ The result of casting `value` to `param_type`.
139
+
140
+ Raises:
141
+ ValueError: If the type of `value` is not compatible with param_type.
142
+ * If `param_type` is a string type, but `value` is not.
143
+ * If `param_type` is a boolean, but `value` is not, or vice versa.
144
+ * If `param_type` is an integer type, but `value` is not.
145
+ * If `param_type` is a float type, but `value` is not a numeric type.
146
+ """
147
+ fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % (
148
+ name,
149
+ param_type,
150
+ value,
151
+ )
152
+
153
+ # Some callers use None, for which we can't do any casting/checking. :(
154
+ if issubclass(param_type, type(None)):
155
+ return value
156
+
157
+ # Avoid converting a non-string type to a string.
158
+ if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance(
159
+ value, (six.string_types, six.binary_type)
160
+ ):
161
+ raise ValueError(fail_msg)
162
+
163
+ # Avoid converting a number or string type to a boolean or vice versa.
164
+ if issubclass(param_type, bool) != isinstance(value, bool):
165
+ raise ValueError(fail_msg)
166
+
167
+ # Avoid converting float to an integer (the reverse is fine).
168
+ if issubclass(param_type, numbers.Integral) and not isinstance(
169
+ value, numbers.Integral
170
+ ):
171
+ raise ValueError(fail_msg)
172
+
173
+ # Avoid converting a non-numeric type to a numeric type.
174
+ if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number):
175
+ raise ValueError(fail_msg)
176
+
177
+ return param_type(value)
178
+
179
+
180
+ def parse_values(values, type_map, ignore_unknown=False):
181
+ """Parses hyperparameter values from a string into a python map.
182
+
183
+ `values` is a string containing comma-separated `name=value` pairs.
184
+ For each pair, the value of the hyperparameter named `name` is set to
185
+ `value`.
186
+
187
+ If a hyperparameter name appears multiple times in `values`, a ValueError
188
+ is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
189
+
190
+ If a hyperparameter name appears in both an index assignment and a scalar assignment,
191
+ a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
192
+
193
+ The hyperparameter name may contain '.' symbols, which will result in an
194
+ attribute name that is only accessible through the getattr and setattr
195
+ functions. (And it must first be explicitly added through add_hparam.)
196
+
197
+ WARNING: Use of '.' in your variable names is allowed, but is not well
198
+ supported and not recommended.
199
+
200
+ The `value` in `name=value` must follow the syntax according to the
201
+ type of the parameter:
202
+
203
+ * Scalar integer: A Python-parsable integer value. E.g.: 1,
204
+ 100, -12.
205
+ * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
206
+ -.54e89.
207
+ * Boolean: Either true or false.
208
+ * Scalar string: A non-empty sequence of characters, excluding comma,
209
+ spaces, and square brackets. E.g.: foo, bar_1.
210
+ * List: A comma separated list of scalar values of the parameter type
211
+ enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
212
+
213
+ When index assignment is used, the corresponding type_map key should be the
214
+ list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
215
+ "arr[1]").
216
+
217
+ Args:
218
+ values: String. Comma separated list of `name=value` pairs where
219
+ 'value' must follow the syntax described above.
220
+ type_map: A dictionary mapping hyperparameter names to types. Note every
221
+ parameter name in values must be a key in type_map. The values must
222
+ conform to the types indicated, where a value V is said to conform to a
223
+ type T if either V has type T, or V is a list of elements of type T.
224
+ Hence, for a multidimensional parameter 'x' taking float values,
225
+ 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
226
+ ignore_unknown: Bool. Whether values that are missing a type in type_map
227
+ should be ignored. If set to True, a ValueError will not be raised for
228
+ unknown hyperparameter type.
229
+
230
+ Returns:
231
+ A python map mapping each name to either:
232
+ * A scalar value.
233
+ * A list of scalar values.
234
+ * A dictionary mapping index numbers to scalar values.
235
+ (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
236
+
237
+ Raises:
238
+ ValueError: If there is a problem with input.
239
+ * If `values` cannot be parsed.
240
+ * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
241
+ * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
242
+ 'a[1]=1,a[1]=2', or 'a=1,a=[1]')
243
+ """
244
+ results_dictionary = {}
245
+ pos = 0
246
+ while pos < len(values):
247
+ m = PARAM_RE.match(values, pos)
248
+ if not m:
249
+ raise ValueError("Malformed hyperparameter value: %s" % values[pos:])
250
+ # Check that there is a comma between parameters and move past it.
251
+ pos = m.end()
252
+ # Parse the values.
253
+ m_dict = m.groupdict()
254
+ name = m_dict["name"]
255
+ if name not in type_map:
256
+ if ignore_unknown:
257
+ continue
258
+ raise ValueError("Unknown hyperparameter type for %s" % name)
259
+ type_ = type_map[name]
260
+
261
+ # Set up correct parsing function (depending on whether type_ is a bool)
262
+ if type_ == bool:
263
+
264
+ def parse_bool(value):
265
+ if value in ["true", "True"]:
266
+ return True
267
+ elif value in ["false", "False"]:
268
+ return False
269
+ else:
270
+ try:
271
+ return bool(int(value))
272
+ except ValueError:
273
+ _parse_fail(name, type_, value, values)
274
+
275
+ parse = parse_bool
276
+ else:
277
+ parse = type_
278
+
279
+ # If a single value is provided
280
+ if m_dict["val"] is not None:
281
+ _process_scalar_value(
282
+ name, parse, type_, m_dict, values, results_dictionary
283
+ )
284
+
285
+ # If the assigned value is a list:
286
+ elif m_dict["vals"] is not None:
287
+ _process_list_value(name, parse, type_, m_dict, values, results_dictionary)
288
+
289
+ else: # Not assigned a list or value
290
+ _parse_fail(name, type_, "", values)
291
+
292
+ return results_dictionary
293
+
294
+
295
+ class HParams(object):
296
+ """Class to hold a set of hyperparameters as name-value pairs.
297
+
298
+ A `HParams` object holds hyperparameters used to build and train a model,
299
+ such as the number of hidden units in a neural net layer or the learning rate
300
+ to use when training.
301
+
302
+ You first create a `HParams` object by specifying the names and values of the
303
+ hyperparameters.
304
+
305
+ To make them easily accessible the parameter names are added as direct
306
+ attributes of the class. A typical usage is as follows:
307
+
308
+ ```python
309
+ # Create a HParams object specifying names and values of the model
310
+ # hyperparameters:
311
+ hparams = HParams(learning_rate=0.1, num_hidden_units=100)
312
+
313
+ # The hyperparameters are available as attributes of the HParams object:
314
+ hparams.learning_rate ==> 0.1
315
+ hparams.num_hidden_units ==> 100
316
+ ```
317
+
318
+ Hyperparameters have type, which is inferred from the type of their value
319
+ passed at construction time. The currently supported types are: integer,
320
+ float, boolean, string, and list of integer, float, boolean, or string.
321
+
322
+ You can override hyperparameter values by calling the
323
+ [`parse()`](#HParams.parse) method, passing a string of comma separated
324
+ `name=value` pairs. This is intended to make it possible to override
325
+ any hyperparameter values from a single command-line flag to which
326
+ the user passes 'hyper-param=value' pairs. It avoids having to define
327
+ one flag for each hyperparameter.
328
+
329
+ The syntax expected for each value depends on the type of the parameter.
330
+ See `parse()` for a description of the syntax.
331
+
332
+ Example:
333
+
334
+ ```python
335
+ # Define a command line flag to pass name=value pairs.
336
+ # For example using argparse:
337
+ import argparse
338
+ parser = argparse.ArgumentParser(description='Train my model.')
339
+ parser.add_argument('--hparams', type=str,
340
+ help='Comma separated list of "name=value" pairs.')
341
+ args = parser.parse_args()
342
+ ...
343
+ def my_program():
344
+ # Create a HParams object specifying the names and values of the
345
+ # model hyperparameters:
346
+ hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
347
+ activations=['relu', 'tanh'])
348
+
349
+ # Override hyperparameters values by parsing the command line
350
+ hparams.parse(args.hparams)
351
+
352
+ # If the user passed `--hparams=learning_rate=0.3` on the command line
353
+ # then 'hparams' has the following attributes:
354
+ hparams.learning_rate ==> 0.3
355
+ hparams.num_hidden_units ==> 100
356
+ hparams.activations ==> ['relu', 'tanh']
357
+
358
+ # If the hyperparameters are in json format use parse_json:
359
+ hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
360
+ ```
361
+ """
362
+
363
+ _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
364
+
365
+ def __init__(self, model_structure=None, **kwargs):
366
+ """Create an instance of `HParams` from keyword arguments.
367
+
368
+ The keyword arguments specify name-values pairs for the hyperparameters.
369
+ The parameter types are inferred from the type of the values passed.
370
+
371
+ The parameter names are added as attributes of `HParams` object, so they
372
+ can be accessed directly with the dot notation `hparams._name_`.
373
+
374
+ Example:
375
+
376
+ ```python
377
+ # Define 3 hyperparameters: 'learning_rate' is a float parameter,
378
+ # 'num_hidden_units' an integer parameter, and 'activation' a string
379
+ # parameter.
380
+ hparams = tf.HParams(
381
+ learning_rate=0.1, num_hidden_units=100, activation='relu')
382
+
383
+ hparams.activation ==> 'relu'
384
+ ```
385
+
386
+ Note that a few names are reserved and cannot be used as hyperparameter
387
+ names. If you use one of the reserved name the constructor raises a
388
+ `ValueError`.
389
+
390
+ Args:
391
+ model_structure: An instance of ModelStructure, defining the feature
392
+ crosses to be used in the Trial.
393
+ **kwargs: Key-value pairs where the key is the hyperparameter name and
394
+ the value is the value for the parameter.
395
+
396
+ Raises:
397
+ ValueError: If both `hparam_def` and initialization values are provided,
398
+ or if one of the arguments is invalid.
399
+
400
+ """
401
+ # Register the hyperparameters and their type in _hparam_types.
402
+ # This simplifies the implementation of parse().
403
+ # _hparam_types maps the parameter name to a tuple (type, bool).
404
+ # The type value is the type of the parameter for scalar hyperparameters,
405
+ # or the type of the list elements for multidimensional hyperparameters.
406
+ # The bool value is True if the value is a list, False otherwise.
407
+ self._hparam_types = {}
408
+ self._model_structure = model_structure
409
+ for name, value in six.iteritems(kwargs):
410
+ self.add_hparam(name, value)
411
+
412
+ def add_hparam(self, name, value):
413
+ """Adds {name, value} pair to hyperparameters.
414
+
415
+ Args:
416
+ name: Name of the hyperparameter.
417
+ value: Value of the hyperparameter. Can be one of the following types:
418
+ int, float, string, int list, float list, or string list.
419
+
420
+ Raises:
421
+ ValueError: if one of the arguments is invalid.
422
+ """
423
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
424
+ # attribute of this object. In that case we refuse to use it as a
425
+ # hyperparameter name.
426
+ if getattr(self, name, None) is not None:
427
+ raise ValueError("Hyperparameter name is reserved: %s" % name)
428
+ if isinstance(value, (list, tuple)):
429
+ if not value:
430
+ raise ValueError(
431
+ "Multi-valued hyperparameters cannot be empty: %s" % name
432
+ )
433
+ self._hparam_types[name] = (type(value[0]), True)
434
+ else:
435
+ self._hparam_types[name] = (type(value), False)
436
+ setattr(self, name, value)
437
+
438
+ def set_hparam(self, name, value):
439
+ """Set the value of an existing hyperparameter.
440
+
441
+ This function verifies that the type of the value matches the type of the
442
+ existing hyperparameter.
443
+
444
+ Args:
445
+ name: Name of the hyperparameter.
446
+ value: New value of the hyperparameter.
447
+
448
+ Raises:
449
+ KeyError: If the hyperparameter doesn't exist.
450
+ ValueError: If there is a type mismatch.
451
+ """
452
+ param_type, is_list = self._hparam_types[name]
453
+ if isinstance(value, list):
454
+ if not is_list:
455
+ raise ValueError(
456
+ "Must not pass a list for single-valued parameter: %s" % name
457
+ )
458
+ setattr(
459
+ self,
460
+ name,
461
+ [_cast_to_type_if_compatible(name, param_type, v) for v in value],
462
+ )
463
+ else:
464
+ if is_list:
465
+ raise ValueError(
466
+ "Must pass a list for multi-valued parameter: %s." % name
467
+ )
468
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
469
+
470
+ def del_hparam(self, name):
471
+ """Removes the hyperparameter with key 'name'.
472
+
473
+ Does nothing if it isn't present.
474
+
475
+ Args:
476
+ name: Name of the hyperparameter.
477
+ """
478
+ if hasattr(self, name):
479
+ delattr(self, name)
480
+ del self._hparam_types[name]
481
+
482
+ def parse(self, values):
483
+ """Override existing hyperparameter values, parsing new values from a string.
484
+
485
+ See parse_values for more detail on the allowed format for values.
486
+
487
+ Args:
488
+ values: String. Comma separated list of `name=value` pairs where 'value'
489
+ must follow the syntax described above.
490
+
491
+ Returns:
492
+ The `HParams` instance.
493
+
494
+ Raises:
495
+ ValueError: If `values` cannot be parsed or a hyperparameter in `values`
496
+ doesn't exist.
497
+ """
498
+ type_map = {}
499
+ for name, t in self._hparam_types.items():
500
+ param_type, _ = t
501
+ type_map[name] = param_type
502
+
503
+ values_map = parse_values(values, type_map)
504
+ return self.override_from_dict(values_map)
505
+
506
+ def override_from_dict(self, values_dict):
507
+ """Override existing hyperparameter values, parsing new values from a dictionary.
508
+
509
+ Args:
510
+ values_dict: Dictionary of name:value pairs.
511
+
512
+ Returns:
513
+ The `HParams` instance.
514
+
515
+ Raises:
516
+ KeyError: If a hyperparameter in `values_dict` doesn't exist.
517
+ ValueError: If `values_dict` cannot be parsed.
518
+ """
519
+ for name, value in values_dict.items():
520
+ self.set_hparam(name, value)
521
+ return self
522
+
523
+ def set_model_structure(self, model_structure):
524
+ self._model_structure = model_structure
525
+
526
+ def get_model_structure(self):
527
+ return self._model_structure
528
+
529
+ def to_json(self, indent=None, separators=None, sort_keys=False):
530
+ """Serializes the hyperparameters into JSON.
531
+
532
+ Args:
533
+ indent: If a non-negative integer, JSON array elements and object members
534
+ will be pretty-printed with that indent level. An indent level of 0, or
535
+ negative, will only insert newlines. `None` (the default) selects the
536
+ most compact representation.
537
+ separators: Optional `(item_separator, key_separator)` tuple. Default is
538
+ `(', ', ': ')`.
539
+ sort_keys: If `True`, the output dictionaries will be sorted by key.
540
+
541
+ Returns:
542
+ A JSON string.
543
+ """
544
+
545
+ def remove_callables(x):
546
+ """Omit callable elements from input with arbitrary nesting."""
547
+ if isinstance(x, dict):
548
+ return {
549
+ k: remove_callables(v)
550
+ for k, v in six.iteritems(x)
551
+ if not callable(v)
552
+ }
553
+ elif isinstance(x, list):
554
+ return [remove_callables(i) for i in x if not callable(i)]
555
+ return x
556
+
557
+ return json.dumps(
558
+ remove_callables(self.values()),
559
+ indent=indent,
560
+ separators=separators,
561
+ sort_keys=sort_keys,
562
+ )
563
+
564
+ def parse_json(self, values_json):
565
+ """Override existing hyperparameter values, parsing new values from a json object.
566
+
567
+ Args:
568
+ values_json: String containing a json object of name:value pairs.
569
+
570
+ Returns:
571
+ The `HParams` instance.
572
+
573
+ Raises:
574
+ KeyError: If a hyperparameter in `values_json` doesn't exist.
575
+ ValueError: If `values_json` cannot be parsed.
576
+ """
577
+ values_map = json.loads(values_json)
578
+ return self.override_from_dict(values_map)
579
+
580
+ def values(self):
581
+ """Return the hyperparameter values as a Python dictionary.
582
+
583
+ Returns:
584
+ A dictionary with hyperparameter names as keys. The values are the
585
+ hyperparameter values.
586
+ """
587
+ return {n: getattr(self, n) for n in self._hparam_types.keys()}
588
+
589
+ def get(self, key, default=None):
590
+ """Returns the value of `key` if it exists, else `default`."""
591
+ if key in self._hparam_types:
592
+ # Ensure that default is compatible with the parameter type.
593
+ if default is not None:
594
+ param_type, is_param_list = self._hparam_types[key]
595
+ type_str = "list<%s>" % param_type if is_param_list else str(param_type)
596
+ fail_msg = (
597
+ "Hparam '%s' of type '%s' is incompatible with "
598
+ "default=%s" % (key, type_str, default)
599
+ )
600
+
601
+ is_default_list = isinstance(default, list)
602
+ if is_param_list != is_default_list:
603
+ raise ValueError(fail_msg)
604
+
605
+ try:
606
+ if is_default_list:
607
+ for value in default:
608
+ _cast_to_type_if_compatible(key, param_type, value)
609
+ else:
610
+ _cast_to_type_if_compatible(key, param_type, default)
611
+ except ValueError as e:
612
+ raise ValueError("%s. %s" % (fail_msg, e))
613
+
614
+ return getattr(self, key)
615
+
616
+ return default
617
+
618
+ def __contains__(self, key):
619
+ return key in self._hparam_types
620
+
621
+ def __str__(self):
622
+ return str(sorted(self.values().items()))
623
+
624
+ def __repr__(self):
625
+ return "%s(%s)" % (type(self).__name__, self.__str__())
626
+
627
+ @staticmethod
628
+ def _get_kind_name(param_type, is_list):
629
+ """Returns the field name given parameter type and is_list.
630
+
631
+ Args:
632
+ param_type: Data type of the hparam.
633
+ is_list: Whether this is a list.
634
+
635
+ Returns:
636
+ A string representation of the field name.
637
+
638
+ Raises:
639
+ ValueError: If parameter type is not recognized.
640
+ """
641
+ if issubclass(param_type, bool):
642
+ # This check must happen before issubclass(param_type, six.integer_types),
643
+ # since Python considers bool to be a subclass of int.
644
+ typename = "bool"
645
+ elif issubclass(param_type, six.integer_types):
646
+ # Setting 'int' and 'long' types to be 'int64' to ensure the type is
647
+ # compatible with both Python2 and Python3.
648
+ typename = "int64"
649
+ elif issubclass(param_type, (six.string_types, six.binary_type)):
650
+ # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
651
+ # compatible with both Python2 and Python3.
652
+ typename = "bytes"
653
+ elif issubclass(param_type, float):
654
+ typename = "float"
655
+ else:
656
+ raise ValueError("Unsupported parameter type: %s" % str(param_type))
657
+
658
+ suffix = "list" if is_list else "value"
659
+ return "_".join([typename, suffix])
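A minimal usage sketch for the HParams class defined above; the hyperparameter names and override string are assumed for illustration.

hparams = HParams(learning_rate=0.1, num_hidden_units=100, activations=["relu", "tanh"])
hparams.parse("learning_rate=0.3,activations=[relu,gelu]")
print(hparams.learning_rate)  # 0.3
print(hparams.activations)    # ['relu', 'gelu']
print(hparams.to_json())      # {"learning_rate": 0.3, "num_hidden_units": 100, "activations": ["relu", "gelu"]}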
utils/util.py ADDED
@@ -0,0 +1,687 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import collections
8
+ import glob
9
+ import os
10
+ import random
11
+ import time
12
+ import argparse
13
+ from collections import OrderedDict
14
+
15
+ import json5
16
+ import numpy as np
17
+ import glob
18
+ from torch.nn import functional as F
19
+
20
+
21
+ try:
22
+ from ruamel.yaml import YAML as yaml
23
+ except:
24
+ from ruamel_yaml import YAML as yaml # type: ignore
25
+
26
+ import torch
27
+
28
+ from utils.hparam import HParams
29
+ import logging
30
+ from logging import handlers
31
+
32
+
33
+ def str2bool(v):
34
+ """Used in argparse.ArgumentParser.add_argument to indicate
35
+ that a type is a bool type and the user can enter
36
+
37
+ - yes, true, t, y, 1, to represent True
38
+ - no, false, f, n, 0, to represent False
39
+
40
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
41
+ """
42
+ if isinstance(v, bool):
43
+ return v
44
+ if v.lower() in ("yes", "true", "t", "y", "1"):
45
+ return True
46
+ elif v.lower() in ("no", "false", "f", "n", "0"):
47
+ return False
48
+ else:
49
+ raise argparse.ArgumentTypeError("Boolean value expected.")
50
+
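An illustrative way to wire str2bool into argparse (the flag name here is made up):

parser = argparse.ArgumentParser()
parser.add_argument("--use_cfg", type=str2bool, default=False)
args = parser.parse_args(["--use_cfg", "yes"])
print(args.use_cfg)  # True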
51
+
52
+ def find_checkpoint_of_mapper(mapper_ckpt_dir):
53
+ mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt"))
54
+
55
+ # Select the max steps
56
+ mapper_ckpts.sort()
57
+ mapper_weights_file = mapper_ckpts[-1]
58
+ return mapper_weights_file
59
+
60
+
61
+ def pad_f0_to_tensors(f0s, batched=None):
62
+ # Initialize
63
+ tensors = []
64
+
65
+ if batched == None:
66
+ # Get the max frame for padding
67
+ size = -1
68
+ for f0 in f0s:
69
+ size = max(size, f0.shape[-1])
70
+
71
+ tensor = torch.zeros(len(f0s), size)
72
+
73
+ for i, f0 in enumerate(f0s):
74
+ tensor[i, : f0.shape[-1]] = f0[:]
75
+
76
+ tensors.append(tensor)
77
+ else:
78
+ start = 0
79
+ while start + batched - 1 < len(f0s):
80
+ end = start + batched - 1
81
+
82
+ # Get the max frame for padding
83
+ size = -1
84
+ for i in range(start, end + 1):
85
+ size = max(size, f0s[i].shape[-1])
86
+
87
+ tensor = torch.zeros(batched, size)
88
+
89
+ for i in range(start, end + 1):
90
+ tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
91
+
92
+ tensors.append(tensor)
93
+
94
+ start = start + batched
95
+
96
+ if start != len(f0s):
97
+ end = len(f0s)
98
+
99
+ # Get the max frame for padding
100
+ size = -1
101
+ for i in range(start, end):
102
+ size = max(size, f0s[i].shape[-1])
103
+
104
+ tensor = torch.zeros(len(f0s) - start, size)
105
+
106
+ for i in range(start, end):
107
+ tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
108
+
109
+ tensors.append(tensor)
110
+
111
+ return tensors
112
+
113
+
114
+ def pad_mels_to_tensors(mels, batched=None):
115
+ """
116
+ Args:
117
+ mels: A list of mel-specs
118
+ Returns:
119
+ tensors: A list of tensors containing the batched mel-specs
120
+ mel_frames: A list of tensors containing the frames of the original mel-specs
121
+ """
122
+ # Initialize
123
+ tensors = []
124
+ mel_frames = []
125
+
126
+ # Split mel-specs into batches to avoid cuda memory exceed
127
+ if batched == None:
128
+ # Get the max frame for padding
129
+ size = -1
130
+ for mel in mels:
131
+ size = max(size, mel.shape[-1])
132
+
133
+ tensor = torch.zeros(len(mels), mels[0].shape[0], size)
134
+ mel_frame = torch.zeros(len(mels), dtype=torch.int32)
135
+
136
+ for i, mel in enumerate(mels):
137
+ tensor[i, :, : mel.shape[-1]] = mel[:]
138
+ mel_frame[i] = mel.shape[-1]
139
+
140
+ tensors.append(tensor)
141
+ mel_frames.append(mel_frame)
142
+ else:
143
+ start = 0
144
+ while start + batched - 1 < len(mels):
145
+ end = start + batched - 1
146
+
147
+ # Get the max frame for padding
148
+ size = -1
149
+ for i in range(start, end + 1):
150
+ size = max(size, mels[i].shape[-1])
151
+
152
+ tensor = torch.zeros(batched, mels[0].shape[0], size)
153
+ mel_frame = torch.zeros(batched, dtype=torch.int32)
154
+
155
+ for i in range(start, end + 1):
156
+ tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
157
+ mel_frame[i - start] = mels[i].shape[-1]
158
+
159
+ tensors.append(tensor)
160
+ mel_frames.append(mel_frame)
161
+
162
+ start = start + batched
163
+
164
+ if start != len(mels):
165
+ end = len(mels)
166
+
167
+ # Get the max frame for padding
168
+ size = -1
169
+ for i in range(start, end):
170
+ size = max(size, mels[i].shape[-1])
171
+
172
+ tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size)
173
+ mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32)
174
+
175
+ for i in range(start, end):
176
+ tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
177
+ mel_frame[i - start] = mels[i].shape[-1]
178
+
179
+ tensors.append(tensor)
180
+ mel_frames.append(mel_frame)
181
+
182
+ return tensors, mel_frames
183
+
184
+
185
+ def load_model_config(args):
186
+ """Load model configurations (in args.json under checkpoint directory)
187
+
188
+ Args:
189
+ args (ArgumentParser): arguments to run bins/preprocess.py
190
+
191
+ Returns:
192
+ dict: dictionary that stores model configurations
193
+ """
194
+ if args.checkpoint_dir is None:
195
+ assert args.checkpoint_file is not None
196
+ checkpoint_dir = os.path.split(args.checkpoint_file)[0]
197
+ else:
198
+ checkpoint_dir = args.checkpoint_dir
199
+ config_path = os.path.join(checkpoint_dir, "args.json")
200
+ print("config_path: ", config_path)
201
+
202
+ config = load_config(config_path)
203
+ return config
204
+
205
+
206
+ def remove_and_create(dir):
207
+ if os.path.exists(dir):
208
+ os.system("rm -r {}".format(dir))
209
+ os.makedirs(dir, exist_ok=True)
210
+
211
+
212
+ def has_existed(path, warning=False):
213
+ if not warning:
214
+ return os.path.exists(path)
215
+
216
+ if os.path.exists(path):
217
+ answer = input(
218
+ "The path {} already exists. \nInput 'y' (or hit Enter) to skip it, and input 'n' to rewrite it [y/n]\n".format(
219
+ path
220
+ )
221
+ )
222
+ if not answer == "n":
223
+ return True
224
+
225
+ return False
226
+
227
+
228
+ def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5):
229
+ if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
230
+ with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f:
231
+ ckpts = [x.strip() for x in f.readlines()]
232
+ else:
233
+ ckpts = []
234
+ ckpts.append(saved_model_name)
235
+ for item in ckpts[:-max_to_keep]:
236
+ if os.path.exists(os.path.join(checkpoint_dir, item)):
237
+ os.remove(os.path.join(checkpoint_dir, item))
238
+ with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
239
+ for item in ckpts[-max_to_keep:]:
240
+ f.write("{}\n".format(item))
241
+
242
+
243
+ def set_all_random_seed(seed: int):
244
+ random.seed(seed)
245
+ np.random.seed(seed)
246
+ torch.random.manual_seed(seed)
247
+
248
+
249
+ def save_checkpoint(
250
+ args,
251
+ generator,
252
+ g_optimizer,
253
+ step,
254
+ discriminator=None,
255
+ d_optimizer=None,
256
+ max_to_keep=5,
257
+ ):
258
+ saved_model_name = "model.ckpt-{}.pt".format(step)
259
+ checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name)
260
+
261
+ if discriminator and d_optimizer:
262
+ torch.save(
263
+ {
264
+ "generator": generator.state_dict(),
265
+ "discriminator": discriminator.state_dict(),
266
+ "g_optimizer": g_optimizer.state_dict(),
267
+ "d_optimizer": d_optimizer.state_dict(),
268
+ "global_step": step,
269
+ },
270
+ checkpoint_path,
271
+ )
272
+ else:
273
+ torch.save(
274
+ {
275
+ "generator": generator.state_dict(),
276
+ "g_optimizer": g_optimizer.state_dict(),
277
+ "global_step": step,
278
+ },
279
+ checkpoint_path,
280
+ )
281
+
282
+ print("Saved checkpoint: {}".format(checkpoint_path))
283
+
284
+ if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")):
285
+ with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f:
286
+ ckpts = [x.strip() for x in f.readlines()]
287
+ else:
288
+ ckpts = []
289
+ ckpts.append(saved_model_name)
290
+ for item in ckpts[:-max_to_keep]:
291
+ if os.path.exists(os.path.join(args.checkpoint_dir, item)):
292
+ os.remove(os.path.join(args.checkpoint_dir, item))
293
+ with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f:
294
+ for item in ckpts[-max_to_keep:]:
295
+ f.write("{}\n".format(item))
296
+
297
+
298
+ def attempt_to_restore(
299
+ generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None
300
+ ):
301
+ checkpoint_list = os.path.join(checkpoint_dir, "checkpoint")
302
+ if os.path.exists(checkpoint_list):
303
+ checkpoint_filename = open(checkpoint_list).readlines()[-1].strip()
304
+ checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename))
305
+ print("Restore from {}".format(checkpoint_path))
306
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
307
+ if generator:
308
+ if not list(generator.state_dict().keys())[0].startswith("module."):
309
+ raw_dict = checkpoint["generator"]
310
+ clean_dict = OrderedDict()
311
+ for k, v in raw_dict.items():
312
+ if k.startswith("module."):
313
+ clean_dict[k[7:]] = v
314
+ else:
315
+ clean_dict[k] = v
316
+ generator.load_state_dict(clean_dict)
317
+ else:
318
+ generator.load_state_dict(checkpoint["generator"])
319
+ if g_optimizer:
320
+ g_optimizer.load_state_dict(checkpoint["g_optimizer"])
321
+ global_step = 100000
322
+ if discriminator and "discriminator" in checkpoint.keys():
323
+ discriminator.load_state_dict(checkpoint["discriminator"])
324
+ global_step = checkpoint["global_step"]
325
+ print("restore discriminator")
326
+ if d_optimizer and "d_optimizer" in checkpoint.keys():
327
+ d_optimizer.load_state_dict(checkpoint["d_optimizer"])
328
+ print("restore d_optimizer...")
329
+ else:
330
+ global_step = 0
331
+ return global_step
332
+
333
+
334
+ class ExponentialMovingAverage(object):
335
+ def __init__(self, decay):
336
+ self.decay = decay
337
+ self.shadow = {}
338
+
339
+ def register(self, name, val):
340
+ self.shadow[name] = val.clone()
341
+
342
+ def update(self, name, x):
343
+ assert name in self.shadow
344
+ update_delta = self.shadow[name] - x
345
+ self.shadow[name] -= (1.0 - self.decay) * update_delta
346
+
347
+
348
+ def apply_moving_average(model, ema):
349
+ for name, param in model.named_parameters():
350
+ if name in ema.shadow:
351
+ ema.update(name, param.data)
352
+
353
+
354
+ def register_model_to_ema(model, ema):
355
+ for name, param in model.named_parameters():
356
+ if param.requires_grad:
357
+ ema.register(name, param.data)
358
+
359
+
360
+ class YParams(HParams):
361
+ def __init__(self, yaml_file):
362
+ if not os.path.exists(yaml_file):
363
+ raise IOError("yaml file: {} does not exist".format(yaml_file))
364
+ super().__init__()
365
+ self.d = collections.OrderedDict()
366
+ with open(yaml_file) as fp:
367
+ for _, v in yaml().load(fp).items():
368
+ for k1, v1 in v.items():
369
+ try:
370
+ if self.get(k1):
371
+ self.set_hparam(k1, v1)
372
+ else:
373
+ self.add_hparam(k1, v1)
374
+ self.d[k1] = v1
375
+ except Exception:
376
+ import traceback
377
+
378
+ print(traceback.format_exc())
379
+
380
+ # @property
381
+ def get_elements(self):
382
+ return self.d.items()
383
+
384
+
385
+ def override_config(base_config, new_config):
386
+ """Update new configurations in the original dict with the new dict
387
+
388
+ Args:
389
+ base_config (dict): original dict to be overridden
390
+ new_config (dict): dict with new configurations
391
+
392
+ Returns:
393
+ dict: updated configuration dict
394
+ """
395
+ for k, v in new_config.items():
396
+ if type(v) == dict:
397
+ if k not in base_config.keys():
398
+ base_config[k] = {}
399
+ base_config[k] = override_config(base_config[k], v)
400
+ else:
401
+ base_config[k] = v
402
+ return base_config
403
+
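An illustrative merge with override_config, using made-up keys; nested dictionaries are merged recursively while other values are replaced:

base = {"train": {"lr": 1e-4, "epochs": 10}, "model": "tadicodec"}
new = {"train": {"lr": 2e-4}}
print(override_config(base, new))
# {'train': {'lr': 0.0002, 'epochs': 10}, 'model': 'tadicodec'}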
404
+
405
+ def get_lowercase_keys_config(cfg):
406
+ """Change all keys in cfg to lower case
407
+
408
+ Args:
409
+ cfg (dict): dictionary that stores configurations
410
+
411
+ Returns:
412
+ dict: dictionary that stores configurations
413
+ """
414
+ updated_cfg = dict()
415
+ for k, v in cfg.items():
416
+ if type(v) == dict:
417
+ v = get_lowercase_keys_config(v)
418
+ updated_cfg[k.lower()] = v
419
+ return updated_cfg
420
+
421
+
422
+ def _load_config(config_fn, lowercase=False):
423
+ """Load configurations into a dictionary
424
+
425
+ Args:
426
+ config_fn (str): path to configuration file
427
+ lowercase (bool, optional): whether changing keys to lower case. Defaults to False.
428
+
429
+ Returns:
430
+ dict: dictionary that stores configurations
431
+ """
432
+ with open(config_fn, "r") as f:
433
+ data = f.read()
434
+ config_ = json5.loads(data)
435
+ if "base_config" in config_:
436
+ # load configurations from new path
437
+ p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
438
+ p_config_ = _load_config(p_config_path)
439
+ config_ = override_config(p_config_, config_)
440
+ if lowercase:
441
+ # change keys in config_ to lower case
442
+ config_ = get_lowercase_keys_config(config_)
443
+ return config_
444
+
445
+
446
+ def load_config(config_fn, lowercase=False):
447
+ """Load configurations into a dictionary
448
+
449
+ Args:
450
+ config_fn (str): path to configuration file
451
+ lowercase (bool, optional): whether to change keys to lower case. Defaults to False.
452
+
453
+ Returns:
454
+ JsonHParams: an object that stores configurations
455
+ """
456
+ config_ = _load_config(config_fn, lowercase=lowercase)
457
+ # create an JsonHParams object with configuration dict
458
+ cfg = JsonHParams(**config_)
459
+ return cfg
460
+
461
+
462
+ def save_config(save_path, cfg):
463
+ """Save configurations into a json file
464
+
465
+ Args:
466
+ save_path (str): path to save configurations
467
+ cfg (dict): dictionary that stores configurations
468
+ """
469
+ with open(save_path, "w") as f:
470
+ json5.dump(
471
+ cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
472
+ )
473
+
474
+
475
+ class JsonHParams:
476
+ def __init__(self, **kwargs):
477
+ for k, v in kwargs.items():
478
+ if type(v) == dict:
479
+ v = JsonHParams(**v)
480
+ self[k] = v
481
+
482
+ def keys(self):
483
+ return self.__dict__.keys()
484
+
485
+ def items(self):
486
+ return self.__dict__.items()
487
+
488
+ def values(self):
489
+ return self.__dict__.values()
490
+
491
+ def __len__(self):
492
+ return len(self.__dict__)
493
+
494
+ def __getitem__(self, key):
495
+ return getattr(self, key)
496
+
497
+ def __setitem__(self, key, value):
498
+ return setattr(self, key, value)
499
+
500
+ def __contains__(self, key):
501
+ return key in self.__dict__
502
+
503
+ def __repr__(self):
504
+ return self.__dict__.__repr__()
505
+
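An illustrative sketch of JsonHParams turning a nested config dict into attribute access (the keys and values are assumed):

cfg = JsonHParams(**{"model": {"hidden_size": 1024}, "sample_rate": 24000})
print(cfg.model.hidden_size)  # 1024
print("sample_rate" in cfg)   # True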
506
+
507
+ class ValueWindow:
508
+ def __init__(self, window_size=100):
509
+ self._window_size = window_size
510
+ self._values = []
511
+
512
+ def append(self, x):
513
+ self._values = self._values[-(self._window_size - 1) :] + [x]
514
+
515
+ @property
516
+ def sum(self):
517
+ return sum(self._values)
518
+
519
+ @property
520
+ def count(self):
521
+ return len(self._values)
522
+
523
+ @property
524
+ def average(self):
525
+ return self.sum / max(1, self.count)
526
+
527
+ def reset(self):
528
+ self._values = []
529
+
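An illustrative use of ValueWindow as a running window over recent loss values (the numbers are made up):

window = ValueWindow(window_size=3)
for loss in [4.0, 3.0, 2.0, 1.0]:
    window.append(loss)
print(window.average)  # 2.0, the mean of the last three values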
530
+
531
+ class Logger(object):
532
+ def __init__(
533
+ self,
534
+ filename,
535
+ level="info",
536
+ when="D",
537
+ backCount=10,
538
+ fmt="%(asctime)s : %(message)s",
539
+ ):
540
+ self.level_relations = {
541
+ "debug": logging.DEBUG,
542
+ "info": logging.INFO,
543
+ "warning": logging.WARNING,
544
+ "error": logging.ERROR,
545
+ "crit": logging.CRITICAL,
546
+ }
547
+ if level == "debug":
548
+ fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
549
+ self.logger = logging.getLogger(filename)
550
+ format_str = logging.Formatter(fmt)
551
+ self.logger.setLevel(self.level_relations.get(level))
552
+ sh = logging.StreamHandler()
553
+ sh.setFormatter(format_str)
554
+ th = handlers.TimedRotatingFileHandler(
555
+ filename=filename, when=when, backupCount=backCount, encoding="utf-8"
556
+ )
557
+ th.setFormatter(format_str)
558
+ self.logger.addHandler(sh)
559
+ self.logger.addHandler(th)
560
+ self.logger.info(
561
+ "==========================New Starting Here=============================="
562
+ )
563
+
564
+
565
+ def init_weights(m, mean=0.0, std=0.01):
566
+ classname = m.__class__.__name__
567
+ if classname.find("Conv") != -1:
568
+ m.weight.data.normal_(mean, std)
569
+
570
+
571
+ def get_padding(kernel_size, dilation=1):
572
+ return int((kernel_size * dilation - dilation) / 2)
573
+
574
+
575
+ def slice_segments(x, ids_str, segment_size=4):
576
+ ret = torch.zeros_like(x[:, :, :segment_size])
577
+ for i in range(x.size(0)):
578
+ idx_str = ids_str[i]
579
+ idx_end = idx_str + segment_size
580
+ ret[i] = x[i, :, idx_str:idx_end]
581
+ return ret
582
+
583
+
584
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
585
+ b, d, t = x.size()
586
+ if x_lengths is None:
587
+ x_lengths = t
588
+ ids_str_max = x_lengths - segment_size + 1
589
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
590
+ ret = slice_segments(x, ids_str, segment_size)
591
+ return ret, ids_str
592
+
593
+
594
+ def subsequent_mask(length):
595
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
596
+ return mask
597
+
598
+
599
+ @torch.jit.script
600
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
601
+ n_channels_int = n_channels[0]
602
+ in_act = input_a + input_b
603
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
604
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
605
+ acts = t_act * s_act
606
+ return acts
607
+
608
+
609
+ def convert_pad_shape(pad_shape):
610
+ l = pad_shape[::-1]
611
+ pad_shape = [item for sublist in l for item in sublist]
612
+ return pad_shape
613
+
614
+
615
+ def sequence_mask(length, max_length=None):
616
+ if max_length is None:
617
+ max_length = length.max()
618
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
619
+ return x.unsqueeze(0) < length.unsqueeze(1)
620
+
621
+
622
+ def generate_path(duration, mask):
623
+ """
624
+ duration: [b, 1, t_x]
625
+ mask: [b, 1, t_y, t_x]
626
+ """
627
+ device = duration.device
628
+
629
+ b, _, t_y, t_x = mask.shape
630
+ cum_duration = torch.cumsum(duration, -1)
631
+
632
+ cum_duration_flat = cum_duration.view(b * t_x)
633
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
634
+ path = path.view(b, t_x, t_y)
635
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
636
+ path = path.unsqueeze(1).transpose(2, 3) * mask
637
+ return path
638
+
639
+
640
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
641
+ if isinstance(parameters, torch.Tensor):
642
+ parameters = [parameters]
643
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
644
+ norm_type = float(norm_type)
645
+ if clip_value is not None:
646
+ clip_value = float(clip_value)
647
+
648
+ total_norm = 0
649
+ for p in parameters:
650
+ param_norm = p.grad.data.norm(norm_type)
651
+ total_norm += param_norm.item() ** norm_type
652
+ if clip_value is not None:
653
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
654
+ total_norm = total_norm ** (1.0 / norm_type)
655
+ return total_norm
656
+
657
+
658
+ def get_current_time():
659
+ pass
660
+
661
+
662
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
663
+ """
664
+ Args:
665
+ lengths:
666
+ A 1-D tensor containing sentence lengths.
667
+ max_len:
668
+ The length of masks.
669
+ Returns:
670
+ Return a 2-D bool tensor, where masked positions
671
+ are filled with `True` and non-masked positions are
672
+ filled with `False`.
673
+
674
+ >>> lengths = torch.tensor([1, 3, 2, 5])
675
+ >>> make_pad_mask(lengths)
676
+ tensor([[False, True, True, True, True],
677
+ [False, False, False, True, True],
678
+ [False, False, True, True, True],
679
+ [False, False, False, False, False]])
680
+ """
681
+ assert lengths.ndim == 1, lengths.ndim
682
+ max_len = max(max_len, lengths.max())
683
+ n = lengths.size(0)
684
+ seq_range = torch.arange(0, max_len, device=lengths.device)
685
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
686
+
687
+ return expanded_lengths >= lengths.unsqueeze(-1)