Commit · ad798d2
Parent(s): 1ccbf40
add demo
- app.py +147 -0
- models/codec/amphion_codec/__pycache__/vocos.cpython-310.pyc +0 -0
- models/codec/amphion_codec/codec.py +422 -0
- models/codec/amphion_codec/loss.py +401 -0
- models/codec/amphion_codec/quantize/__init__.py +6 -0
- models/codec/amphion_codec/quantize/__pycache__/__init__.cpython-310.pyc +0 -0
- models/codec/amphion_codec/quantize/__pycache__/bsq.cpython-310.pyc +0 -0
- models/codec/amphion_codec/quantize/__pycache__/factorized_vector_quantize.cpython-310.pyc +0 -0
- models/codec/amphion_codec/quantize/__pycache__/lookup_free_quantize.cpython-310.pyc +0 -0
- models/codec/amphion_codec/quantize/__pycache__/residual_vq.cpython-310.pyc +0 -0
- models/codec/amphion_codec/quantize/__pycache__/vector_quantize.cpython-310.pyc +0 -0
- models/codec/amphion_codec/quantize/bsq.py +373 -0
- models/codec/amphion_codec/quantize/factorized_vector_quantize.py +145 -0
- models/codec/amphion_codec/quantize/lookup_free_quantize.py +72 -0
- models/codec/amphion_codec/quantize/residual_vq.py +172 -0
- models/codec/amphion_codec/quantize/vector_quantize.py +396 -0
- models/codec/amphion_codec/vocos.py +909 -0
- models/codec/melvqgan/__pycache__/melspec.cpython-310.pyc +0 -0
- models/codec/melvqgan/melspec.py +153 -0
- models/tts/llm_tts/__pycache__/chat_template.cpython-310.pyc +0 -0
- models/tts/llm_tts/__pycache__/inference_llm_tts.cpython-310.pyc +0 -0
- models/tts/llm_tts/__pycache__/inference_mgm_tts.cpython-310.pyc +0 -0
- models/tts/llm_tts/__pycache__/llama_nar_prefix.cpython-310.pyc +0 -0
- models/tts/llm_tts/__pycache__/mgm.cpython-310.pyc +0 -0
- models/tts/llm_tts/chat_template.py +96 -0
- models/tts/llm_tts/inference_llm_tts.py +265 -0
- models/tts/llm_tts/inference_mgm_tts.py +338 -0
- models/tts/llm_tts/llama_nar_prefix.py +457 -0
- models/tts/llm_tts/mgm.py +385 -0
- models/tts/tadicodec/__pycache__/infer_utils.cpython-310.pyc +0 -0
- models/tts/tadicodec/__pycache__/inference_tadicodec.cpython-310.pyc +0 -0
- models/tts/tadicodec/__pycache__/llama_nar_prefix.cpython-310.pyc +0 -0
- models/tts/tadicodec/__pycache__/modeling_tadicodec.cpython-310.pyc +0 -0
- models/tts/tadicodec/infer_utils.py +24 -0
- models/tts/tadicodec/inference_tadicodec.py +279 -0
- models/tts/tadicodec/llama_nar_prefix.py +572 -0
- models/tts/tadicodec/modeling_tadicodec.py +641 -0
- requirements.txt +9 -0
- utils/__pycache__/hparam.cpython-310.pyc +0 -0
- utils/__pycache__/util.cpython-310.pyc +0 -0
- utils/hparam.py +659 -0
- utils/util.py +687 -0
app.py
ADDED
@@ -0,0 +1,147 @@
import gradio as gr
import torch
import soundfile as sf
import tempfile
import os
import warnings
warnings.filterwarnings("ignore")

try:
    from models.tts.llm_tts.inference_llm_tts import TTSInferencePipeline
    MODEL_AVAILABLE = True
except ImportError:
    print("Warning: TaDiCodec models not found. Running in demo mode.")
    MODEL_AVAILABLE = False

class TaDiCodecTTSDemo:
    def __init__(self):
        self.pipeline = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        try:
            if MODEL_AVAILABLE:
                print("Loading TaDiCodec-TTS-AR-Qwen2.5-0.5B model...")
                self.pipeline = TTSInferencePipeline.from_pretrained(
                    tadicodec_path="amphion/TaDiCodec",
                    llm_path="amphion/TaDiCodec-TTS-AR-Qwen2.5-0.5B",
                    device=self.device,
                )
                print("Model loaded successfully!")
            else:
                print("Running in demo mode - model files not available")
                self.pipeline = None
        except Exception as e:
            print(f"Error loading model: {e}")
            self.pipeline = None

    def synthesize_speech(self, text, reference_audio=None, reference_text=""):
        """
        Synthesize speech from text using TaDiCodec TTS
        """
        if not text.strip():
            return None, "Please enter some text to synthesize."

        try:
            if self.pipeline is not None:
                # Use actual TaDiCodec inference
                if reference_audio and reference_text.strip():
                    audio = self.pipeline(
                        text=text,
                        prompt_text=reference_text,
                        prompt_speech_path=reference_audio,
                    )
                else:
                    audio = self.pipeline(text=text)

                # Save to temporary file
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    sf.write(tmp_file.name, audio, 24000)
                    return tmp_file.name, "Speech synthesized successfully!"
            else:
                # Fallback demo mode - return None to indicate no audio generated
                return None, "Demo mode - TaDiCodec model not available. Please install the required models."

        except Exception as e:
            return None, f"Error during synthesis: {str(e)}"

# Initialize the demo
demo_instance = TaDiCodecTTSDemo()

def tts_interface(text, reference_audio, reference_text):
    """Interface function for Gradio"""
    audio_path, message = demo_instance.synthesize_speech(text, reference_audio, reference_text)
    return audio_path, message

# Create Gradio interface
with gr.Blocks(title="TaDiCodec-TTS Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # TaDiCodec-TTS-AR-Qwen2.5-0.5B Demo

    This is a demo of the TaDiCodec Text-to-Speech model with Qwen2.5-0.5B backbone.

    **Features:**
    - Voice cloning with reference audio
    - Code-switching support (e.g., mixing English and Chinese)
    - Extremely low bitrate (0.0875 kbps)
    - High-quality speech generation
    """)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to Synthesize",
                placeholder="Enter the text you want to convert to speech...",
                lines=3,
                value="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit."
            )

            reference_text = gr.Textbox(
                label="Reference Text",
                placeholder="Text corresponding to the reference audio for voice cloning...",
                lines=2,
                value="In short, we embarked on a mission to make America great again, for all Americans."
            )

            reference_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
            )

            synthesize_btn = gr.Button("Synthesize Speech", variant="primary")

        with gr.Column():
            output_audio = gr.Audio(
                label="Generated Speech",
                type="filepath"
            )

            status_message = gr.Textbox(
                label="Status",
                interactive=False
            )

    # Example inputs
    gr.Markdown("### Example Inputs")
    examples = [
        ["Yes, usually people choose to face life with more positive emotions, after all, happy times are always yearning. However, sometimes slowing down and experiencing the details of life can bring deeper joy and satisfaction. What do you think?", "Jittery Jack's jam jars jiggled jauntily, jolting Jack's jumbled jelly-filled jars joyously.", "sample/tongueTwisters_en_018.wav"],
        ["You think you can just waltz in here and cause chaos? Well, I've got news for you. This time, there's no escaping the consequences. So, one by one, step forward, and let's see who's bold enough to face the music. It's time for a little dose of reality—prepare to be dealt with!", "Get in line trouble makers, and I will take care of you.", "sample/en_013.wav"],
    ]

    gr.Examples(
        examples=examples,
        inputs=[text_input, reference_text, reference_audio],
        outputs=[output_audio, status_message],
        fn=tts_interface
    )

    # Connect the interface
    synthesize_btn.click(
        fn=tts_interface,
        inputs=[text_input, reference_audio, reference_text],
        outputs=[output_audio, status_message]
    )

if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")
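For reference, here is a minimal sketch of driving the same pipeline from a plain script, using only the calls that appear in app.py above. It assumes the repository is on PYTHONPATH and that the checkpoints can be downloaded; the reference clip path and output filename are placeholders, not values mandated by the demo.

# Minimal command-line sketch of the inference flow used in app.py above.
# Assumptions: repository importable, placeholder prompt wav and output path.
import torch
import soundfile as sf
from models.tts.llm_tts.inference_llm_tts import TTSInferencePipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = TTSInferencePipeline.from_pretrained(
    tadicodec_path="amphion/TaDiCodec",
    llm_path="amphion/TaDiCodec-TTS-AR-Qwen2.5-0.5B",
    device=device,
)

# Voice cloning: pass the reference transcript and reference audio, as the demo does.
audio = pipeline(
    text="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit.",
    prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
    prompt_speech_path="sample/en_013.wav",  # placeholder reference clip
)
sf.write("output.wav", audio, 24000)  # app.py writes the generated audio at 24 kHz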
models/codec/amphion_codec/__pycache__/vocos.cpython-310.pyc
ADDED
Binary file (26.3 kB).
models/codec/amphion_codec/codec.py
ADDED
@@ -0,0 +1,422 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from models.codec.amphion_codec.quantize import (
    ResidualVQ,
    VectorQuantize,
    FactorizedVectorQuantize,
    LookupFreeQuantize,
)

from models.codec.amphion_codec.vocos import Vocos


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class CodecEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        d_model = cfg.d_model if cfg is not None else d_model
        up_ratios = cfg.up_ratios if cfg is not None else up_ratios
        out_channels = cfg.out_channels if cfg is not None else out_channels
        use_tanh = cfg.use_tanh if cfg is not None else use_tanh

        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        if use_tanh:
            self.block += [nn.Tanh()]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        return self.block(x)

    def reset_parameters(self):
        self.apply(init_weights)


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class CodecDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        in_channels = (
            cfg.in_channels
            if cfg is not None and hasattr(cfg, "in_channels")
            else in_channels
        )
        upsample_initial_channel = (
            cfg.upsample_initial_channel
            if cfg is not None and hasattr(cfg, "upsample_initial_channel")
            else upsample_initial_channel
        )
        up_ratios = (
            cfg.up_ratios
            if cfg is not None and hasattr(cfg, "up_ratios")
            else up_ratios
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        quantizer_type = (
            cfg.quantizer_type
            if cfg is not None and hasattr(cfg, "quantizer_type")
            else quantizer_type
        )
        quantizer_dropout = (
            cfg.quantizer_dropout
            if cfg is not None and hasattr(cfg, "quantizer_dropout")
            else quantizer_dropout
        )
        commitment = (
            cfg.commitment
            if cfg is not None and hasattr(cfg, "commitment")
            else commitment
        )
        codebook_loss_weight = (
            cfg.codebook_loss_weight
            if cfg is not None and hasattr(cfg, "codebook_loss_weight")
            else codebook_loss_weight
        )
        use_l2_normlize = (
            cfg.use_l2_normlize
            if cfg is not None and hasattr(cfg, "use_l2_normlize")
            else use_l2_normlize
        )
        codebook_type = (
            cfg.codebook_type
            if cfg is not None and hasattr(cfg, "codebook_type")
            else codebook_type
        )
        kmeans_init = (
            cfg.kmeans_init
            if cfg is not None and hasattr(cfg, "kmeans_init")
            else kmeans_init
        )
        kmeans_iters = (
            cfg.kmeans_iters
            if cfg is not None and hasattr(cfg, "kmeans_iters")
            else kmeans_iters
        )
        decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
        eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
        threshold_ema_dead_code = (
            cfg.threshold_ema_dead_code
            if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
            else threshold_ema_dead_code
        )
        weight_init = (
            cfg.weight_init
            if cfg is not None and hasattr(cfg, "weight_init")
            else weight_init
        )
        use_vocos = (
            cfg.use_vocos
            if cfg is not None and hasattr(cfg, "use_vocos")
            else use_vocos
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
        hop_size = (
            cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
        )
        padding = (
            cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
        )

        if quantizer_type == "vq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
                codebook_type=codebook_type,
                kmeans_init=kmeans_init,
                kmeans_iters=kmeans_iters,
                decay=decay,
                eps=eps,
                threshold_ema_dead_code=threshold_ema_dead_code,
                weight_init=weight_init,
            )
        elif quantizer_type == "fvq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
            )
        elif quantizer_type == "lfq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
            )
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        if not use_vocos:
            # Add first conv layer
            channels = upsample_initial_channel
            layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

            # Add upsampling + MRF blocks
            for i, stride in enumerate(up_ratios):
                input_dim = channels // 2**i
                output_dim = channels // 2 ** (i + 1)
                layers += [DecoderBlock(input_dim, output_dim, stride)]

            # Add final conv layer
            layers += [
                Snake1d(output_dim),
                WNConv1d(output_dim, 1, kernel_size=7, padding=3),
                nn.Tanh(),
            ]

            self.model = nn.Sequential(*layers)

        if use_vocos:
            self.model = Vocos(
                input_channels=in_channels,
                dim=vocos_dim,
                intermediate_dim=vocos_intermediate_dim,
                num_layers=vocos_num_layers,
                adanorm_num_embeddings=None,
                n_fft=n_fft,
                hop_size=hop_size,
                padding=padding,
            )

        self.reset_parameters()

    def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
        """
        if vq is True, x = encoder output, then return quantized output;
        else, x = quantized output, then return decoder output
        """
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            ) = self.quantizer(x, n_quantizers=n_quantizers)
            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            )

        return self.model(x)

    def quantize(self, x, n_quantizers=None):
        self.quantizer.eval()
        quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
        return quantized_out, vq

    # TODO: check consistency of vq2emb and quantize
    def vq2emb(self, vq, n_quantizers=None):
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def decode(self, x):
        return self.model(x)

    def latent2dist(self, x, n_quantizers=None):
        return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)

    def reset_parameters(self):
        self.apply(init_weights)
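As a quick orientation to codec.py, here is a minimal round-trip sketch under the default arguments defined above, run on a dummy waveform with untrained weights. It assumes the repository root is importable; the shapes in the comments follow from the default up_ratios.

# Untrained smoke test of the encoder / quantizer / decoder wiring in codec.py above.
# Assumption: the repository root is on PYTHONPATH; numbers follow the file's defaults.
import torch
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder

encoder = CodecEncoder()   # d_model=64, up_ratios=[4, 5, 5, 6], out_channels=256
decoder = CodecDecoder()   # in_channels=256, 8 residual VQ layers, up_ratios=[5, 5, 4, 2]

wav = torch.randn(1, 1, 24000)          # (batch, 1, samples)
latents = encoder(wav)                  # (1, 256, 40): 24000 / (4*5*5*6) frames

# vq=True routes the latents through the residual vector quantizer
quantized, indices, commit_losses, codebook_losses, per_layer = decoder(
    latents, vq=True, eval_vq=True
)

recon = decoder(quantized)              # vq=False (default): decode to a waveform, (1, 1, 8000) here
print(latents.shape, recon.shape)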
models/codec/amphion_codec/loss.py
ADDED
@@ -0,0 +1,401 @@
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torchaudio.transforms import MelSpectrogram
from einops import rearrange
from typing import List


def stft(x, fft_size, hop_size, win_length, window, use_complex=False):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """

    x_stft = torch.stft(
        x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True
    )

    # clamp is needed to avoid nan or inf
    if not use_complex:
        return torch.sqrt(
            torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3)
        ).transpose(2, 1)
    else:
        res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1)
        res = res.transpose(2, 3)  # [B, 2, T, F]
        return res


def compute_mag_scale(n_fft, sampling_rate):
    frequencies = librosa.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    frequencies = np.where(frequencies > 1e-10, frequencies, -10)
    db_scale = librosa.frequency_weighting(frequencies).reshape(1, 1, -1)
    mag_scale = np.sqrt(librosa.db_to_power(db_scale)).astype(np.float32)
    return torch.from_numpy(mag_scale)


class SpectralConvergenceLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super(SpectralConvergenceLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of ground-truth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value.
        """
        return torch.norm(y_mag - x_mag) / torch.norm(y_mag)


class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of ground-truth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))


class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(
        self,
        fft_size=1024,
        hop_length=120,
        win_length=600,
        sampling_rate=16000,
        window="hann_window",
        cfg=None,
    ):
        """Initialize STFT loss module."""
        super(STFTLoss, self).__init__()

        fft_size = (
            cfg.fft_size if cfg is not None and hasattr(cfg, "fft_size") else fft_size
        )
        hop_length = (
            cfg.hop_length
            if cfg is not None and hasattr(cfg, "hop_length")
            else hop_length
        )
        win_length = (
            cfg.win_length
            if cfg is not None and hasattr(cfg, "win_length")
            else win_length
        )
        window = cfg.window if cfg is not None and hasattr(cfg, "window") else window

        self.fft_size = fft_size
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

        self.register_buffer("mag_scale", compute_mag_scale(fft_size, sampling_rate))

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Ground truth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        x_mag = (
            stft(x, self.fft_size, self.hop_length, self.win_length, self.window)
            * self.mag_scale
        )
        y_mag = (
            stft(y, self.fft_size, self.hop_length, self.win_length, self.window)
            * self.mag_scale
        )
        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
        log_mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, log_mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(
        self,
        fft_sizes=(1024, 2048, 512),
        hop_sizes=(120, 240, 50),
        win_lengths=(600, 1200, 240),
        window="hann_window",
        sampling_rate=16000,
        cfg=None,
    ):
        """Initialize Multi resolution STFT loss module.
        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
        """
        super(MultiResolutionSTFTLoss, self).__init__()

        fft_sizes = (
            cfg.fft_sizes
            if cfg is not None and hasattr(cfg, "fft_sizes")
            else fft_sizes
        )
        hop_sizes = (
            cfg.hop_sizes
            if cfg is not None and hasattr(cfg, "hop_sizes")
            else hop_sizes
        )
        win_lengths = (
            cfg.win_lengths
            if cfg is not None and hasattr(cfg, "win_lengths")
            else win_lengths
        )
        window = cfg.window if cfg is not None and hasattr(cfg, "window") else window
        sampling_rate = (
            cfg.sampling_rate
            if cfg is not None and hasattr(cfg, "sampling_rate")
            else sampling_rate
        )

        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [
                STFTLoss(fs, ss, wl, window=window, sampling_rate=sampling_rate)
            ]

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): GroundTruth signal (B, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss


class MultiResolutionMelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        sample_rate=16000,
        n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
        window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
        clamp_eps: float = 1e-5,
        mag_weight: float = 0.0,
        log_weight: float = 1.0,
        pow: float = 1.0,
        mel_fmin: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        mel_fmax: List[float] = [None, None, None, None, None, None, None],
        cfg=None,
    ):
        super().__init__()

        sample_rate = (
            cfg.sample_rate
            if cfg is not None and hasattr(cfg, "sample_rate")
            else sample_rate
        )
        n_mels = cfg.n_mels if cfg is not None and hasattr(cfg, "n_mels") else n_mels
        window_lengths = (
            cfg.window_lengths
            if cfg is not None and hasattr(cfg, "window_lengths")
            else window_lengths
        )
        clamp_eps = (
            cfg.clamp_eps
            if cfg is not None and hasattr(cfg, "clamp_eps")
            else clamp_eps
        )
        mag_weight = (
            cfg.mag_weight
            if cfg is not None and hasattr(cfg, "mag_weight")
            else mag_weight
        )
        log_weight = (
            cfg.log_weight
            if cfg is not None and hasattr(cfg, "log_weight")
            else log_weight
        )
        pow = cfg.pow if cfg is not None and hasattr(cfg, "pow") else pow
        mel_fmin = (
            cfg.mel_fmin if cfg is not None and hasattr(cfg, "mel_fmin") else mel_fmin
        )
        mel_fmax = (
            cfg.mel_fmax if cfg is not None and hasattr(cfg, "mel_fmax") else mel_fmax
        )

        self.mel_transforms = nn.ModuleList(
            [
                MelSpectrogram(
                    sample_rate=sample_rate,
                    n_fft=window_length,
                    hop_length=window_length // 4,
                    n_mels=n_mel,
                    power=1.0,
                    center=True,
                    norm="slaney",
                    mel_scale="slaney",
                )
                for n_mel, window_length in zip(n_mels, window_lengths)
            ]
        )
        self.n_mels = n_mels
        self.loss_fn = nn.L1Loss()
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def delta(self, x, k):
        l = x.shape[1]
        return x[:, 0 : l - k] - x[:, k:l]

    def forward(self, x, y, mask=None):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        for mel_transform in self.mel_transforms:
            x_mel = mel_transform(x)
            y_mel = mel_transform(y)
            log_x_mel = x_mel.clamp(self.clamp_eps).pow(self.pow).log10()
            log_y_mel = y_mel.clamp(self.clamp_eps).pow(self.pow).log10()
            loss += self.log_weight * self.loss_fn(log_x_mel, log_y_mel)
            loss += self.mag_weight * self.loss_fn(x_mel, y_mel)
            # loss += self.loss_fn(self.delta(log_x_mel, 1), self.delta(log_y_mel, 1))
            # log_x_mel = rearrange(log_x_mel, 'b c t -> b t c')
            # log_y_mel = rearrange(log_y_mel, 'b c t -> b t c')
            # for i in range(3):
            #     loss += self.loss_fn(self.delta(log_x_mel, i), self.delta(log_y_mel, i))
        # loss /= len(self.mel_transforms)
        return loss


class GANLoss(nn.Module):
    def __init__(self, mode="lsgan"):
        super(GANLoss, self).__init__()
        assert mode in ["lsgan", "lsgan_std", "hinge"]
        self.mode = mode

    def disc_loss(self, real, fake):
        if self.mode == "lsgan":
            real_loss = F.mse_loss(real, torch.ones_like(real))
            fake_loss = F.mse_loss(fake, torch.zeros_like(fake))
        elif self.mode == "lsgan_std":
            real = (real - 1.0).pow(2)
            fake = (fake - 0.0).pow(2)
            real_loss = real.mean() + real.std()
            fake_loss = fake.mean() + fake.std()
        elif self.mode == "hinge":
            real_loss = torch.relu(1.0 - real).mean()
            fake_loss = torch.relu(1.0 + fake).mean()
        else:
            raise ValueError(f"no such mode {self.mode}")

        return real_loss, fake_loss

    def disc_loss2(self, fake):
        if self.mode == "lsgan":
            fake_loss = F.mse_loss(fake, torch.zeros_like(fake))
        elif self.mode == "lsgan_std":
            fake = (fake - 0.0).pow(2)
            fake_loss = fake.mean() + fake.std()
        elif self.mode == "hinge":
            fake_loss = torch.relu(1.0 + fake).mean()
        else:
            raise ValueError(f"no such mode {self.mode}")

        return fake_loss

    def gen_loss(self, fake):
        if self.mode == "lsgan":
            gen_loss = F.mse_loss(fake, torch.ones_like(fake))
        elif self.mode == "lsgan_std":
            fake = (fake - 1.0).pow(2)
            gen_loss = fake.mean() + fake.std()
        elif self.mode == "hinge":
            gen_loss = -fake.mean()
        else:
            raise ValueError(f"no such mode {self.mode}")

        return gen_loss
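loss.py bundles the codec's reconstruction and adversarial criteria. Here is a minimal sketch of calling them on dummy 16 kHz waveforms and dummy discriminator logits; the tensor shapes are illustrative assumptions, not values taken from a training recipe.

# Illustrative use of the loss modules defined in loss.py above, on random tensors.
import torch
from models.codec.amphion_codec.loss import (
    MultiResolutionSTFTLoss,
    MultiResolutionMelSpectrogramLoss,
    GANLoss,
)

pred = torch.randn(2, 16000)   # predicted waveforms, (B, T) at 16 kHz
ref = torch.randn(2, 16000)    # reference waveforms

stft_loss = MultiResolutionSTFTLoss(sampling_rate=16000)
sc_loss, mag_loss = stft_loss(pred, ref)        # spectral convergence + log-magnitude terms

mel_loss = MultiResolutionMelSpectrogramLoss(sample_rate=16000)
mel = mel_loss(pred, ref)

gan_loss = GANLoss(mode="lsgan")
real_logits = torch.randn(2, 10)                # stand-ins for discriminator outputs
fake_logits = torch.randn(2, 10)
d_real, d_fake = gan_loss.disc_loss(real_logits, fake_logits)
g_adv = gan_loss.gen_loss(fake_logits)
print(sc_loss.item(), mag_loss.item(), mel.item(), (d_real + d_fake).item(), g_adv.item())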
models/codec/amphion_codec/quantize/__init__.py
ADDED
@@ -0,0 +1,6 @@
from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
models/codec/amphion_codec/quantize/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (551 Bytes).
models/codec/amphion_codec/quantize/__pycache__/bsq.cpython-310.pyc
ADDED
Binary file (10.6 kB).
models/codec/amphion_codec/quantize/__pycache__/factorized_vector_quantize.cpython-310.pyc
ADDED
Binary file (4.08 kB).
models/codec/amphion_codec/quantize/__pycache__/lookup_free_quantize.cpython-310.pyc
ADDED
Binary file (2.14 kB).
models/codec/amphion_codec/quantize/__pycache__/residual_vq.cpython-310.pyc
ADDED
Binary file (4.4 kB).
models/codec/amphion_codec/quantize/__pycache__/vector_quantize.cpython-310.pyc
ADDED
Binary file (10.9 kB).
models/codec/amphion_codec/quantize/bsq.py
ADDED
@@ -0,0 +1,373 @@
1 |
+
# copy from https://github.com/zhaoyue-zephyrus/bsq-vit/blob/main/transcoder/models/quantizer/bsq.py
|
2 |
+
|
3 |
+
from einops import rearrange, reduce
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.autograd import Function
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class DifferentiableEntropyFunction(Function):
|
11 |
+
@staticmethod
|
12 |
+
def forward(ctx, zq, basis, K, eps):
|
13 |
+
zb = (zq + 1) / 2
|
14 |
+
zi = ((zb * basis).sum(-1)).to(torch.int64)
|
15 |
+
cnt = torch.scatter_reduce(
|
16 |
+
torch.zeros(2**K, device=zq.device, dtype=zq.dtype),
|
17 |
+
0,
|
18 |
+
zi.flatten(),
|
19 |
+
torch.ones_like(zi.flatten()).to(zq.dtype),
|
20 |
+
"sum",
|
21 |
+
)
|
22 |
+
prob = (cnt + eps) / (cnt + eps).sum()
|
23 |
+
H = -(prob * torch.log(prob)).sum()
|
24 |
+
ctx.save_for_backward(zq, zi, prob)
|
25 |
+
ctx.K = K
|
26 |
+
return H
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def backward(ctx, grad_output):
|
30 |
+
zq, zi, prob = ctx.saved_tensors
|
31 |
+
grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
|
32 |
+
reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
|
33 |
+
grad_input = reord_grad.unsqueeze(-1) * zq
|
34 |
+
return grad_input, None, None, None, None
|
35 |
+
|
36 |
+
|
37 |
+
def codebook_entropy(zq, basis, K, eps=1e-4):
|
38 |
+
return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
|
39 |
+
|
40 |
+
|
41 |
+
class SimpleQuantizer(nn.Module):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
embed_dim=14,
|
45 |
+
codebook_size=16384,
|
46 |
+
commitment=0.25,
|
47 |
+
codebook_loss_weight=1.0,
|
48 |
+
use_l2_normlize=True,
|
49 |
+
):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.embed_dim = embed_dim
|
53 |
+
self.codebook_dim = embed_dim
|
54 |
+
self.codebook_size = codebook_size
|
55 |
+
self.commitment = commitment
|
56 |
+
self.codebook_loss_weight = codebook_loss_weight
|
57 |
+
self.use_l2_normlize = use_l2_normlize
|
58 |
+
|
59 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
60 |
+
|
61 |
+
def forward(self, z):
|
62 |
+
|
63 |
+
z_e = z
|
64 |
+
|
65 |
+
z_q, indices = self.decode_latents(z_e)
|
66 |
+
|
67 |
+
if self.training:
|
68 |
+
commit_loss = (
|
69 |
+
F.mse_loss(z_e, z_q.detach(), reduction="mean") * self.commitment
|
70 |
+
)
|
71 |
+
codebook_loss = (
|
72 |
+
F.mse_loss(z_q, z_e.detach(), reduction="mean")
|
73 |
+
* self.codebook_loss_weight
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
77 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
78 |
+
|
79 |
+
z_q = z_e + (z_q - z_e).detach()
|
80 |
+
|
81 |
+
return (
|
82 |
+
z_q,
|
83 |
+
commit_loss + codebook_loss,
|
84 |
+
{
|
85 |
+
"commit_loss": commit_loss,
|
86 |
+
"codebook_loss": codebook_loss,
|
87 |
+
"indices": indices,
|
88 |
+
},
|
89 |
+
)
|
90 |
+
|
91 |
+
def embed_code(self, embed_id):
|
92 |
+
return F.embedding(embed_id, self.codebook.weight)
|
93 |
+
|
94 |
+
def decode_code(self, embed_id):
|
95 |
+
return self.embed_code(embed_id)
|
96 |
+
|
97 |
+
def get_codebook_entry(self, indices):
|
98 |
+
return self.embed_code(indices)
|
99 |
+
|
100 |
+
def decode_latents(self, latents):
|
101 |
+
encodings = rearrange(latents, "b t d -> (b t) d")
|
102 |
+
|
103 |
+
codebook = self.codebook.weight
|
104 |
+
|
105 |
+
if self.use_l2_normlize:
|
106 |
+
# encodings = F.normalize(encodings) we have already normalized the latent before vq
|
107 |
+
codebook = F.normalize(codebook)
|
108 |
+
|
109 |
+
# Compute euclidean distance between encodings and codebook,
|
110 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
111 |
+
dist = (
|
112 |
+
encodings.pow(2).sum(1, keepdim=True)
|
113 |
+
- 2 * encodings @ codebook.t()
|
114 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
115 |
+
)
|
116 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
117 |
+
z_q = self.decode_code(indices)
|
118 |
+
|
119 |
+
return z_q, indices
|
120 |
+
|
121 |
+
def vq2emb(self, vq, out_proj=True):
|
122 |
+
emb = self.decode_code(vq)
|
123 |
+
return emb
|
124 |
+
|
125 |
+
def latent2dist(self, latents):
|
126 |
+
|
127 |
+
encodings = latents
|
128 |
+
codebook = self.codebook.weight
|
129 |
+
|
130 |
+
# L2 normalize encodings and codebook
|
131 |
+
if self.use_l2_normlize:
|
132 |
+
# encodings = F.normalize(encodings)
|
133 |
+
codebook = F.normalize(codebook)
|
134 |
+
|
135 |
+
# Compute euclidean distance between encodings and codebook,
|
136 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
137 |
+
dist = (
|
138 |
+
encodings.pow(2).sum(1, keepdim=True)
|
139 |
+
- 2 * encodings @ codebook.t()
|
140 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
141 |
+
) # (b*t, k)
|
142 |
+
|
143 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
144 |
+
dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
|
145 |
+
z_q = self.decode_code(indices)
|
146 |
+
|
147 |
+
return -dist, indices, z_q
|
148 |
+
|
149 |
+
|
150 |
+
class BinarySphericalQuantizer(nn.Module):
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
embed_dim=14,
|
154 |
+
beta=0.0,
|
155 |
+
gamma0=1.0,
|
156 |
+
gamma=1.0,
|
157 |
+
zeta=1.0,
|
158 |
+
input_format="blc",
|
159 |
+
soft_entropy=True,
|
160 |
+
group_size=1,
|
161 |
+
persample_entropy_compute="group",
|
162 |
+
cb_entropy_compute="group",
|
163 |
+
l2_norm=True,
|
164 |
+
inv_temperature=1,
|
165 |
+
):
|
166 |
+
super().__init__()
|
167 |
+
self.embed_dim = embed_dim
|
168 |
+
self.beta = beta # loss weight for commit loss
|
169 |
+
self.gamma0 = gamma0 # loss weight for entropy penalty
|
170 |
+
self.gamma = gamma # loss weight for entropy penalty
|
171 |
+
self.zeta = zeta # loss weight for entire entropy penalty
|
172 |
+
self.input_format = input_format
|
173 |
+
assert (
|
174 |
+
self.embed_dim % group_size == 0
|
175 |
+
), "embed_dim must be divisible by group_size"
|
176 |
+
self.num_groups = self.embed_dim // group_size
|
177 |
+
self.group_size = group_size
|
178 |
+
assert persample_entropy_compute in [
|
179 |
+
"group",
|
180 |
+
"analytical",
|
181 |
+
], "persample_entropy_compute must be either 'group' or 'analytical'"
|
182 |
+
assert cb_entropy_compute in [
|
183 |
+
"group",
|
184 |
+
"nce",
|
185 |
+
], "cb_entropy_compute must be either 'group' or 'nce'"
|
186 |
+
self.persample_entropy_compute = persample_entropy_compute
|
187 |
+
self.cb_entropy_compute = cb_entropy_compute
|
188 |
+
self.l2_norm = l2_norm
|
189 |
+
self.inv_temperature = inv_temperature
|
190 |
+
|
191 |
+
self.register_buffer("basis", 2 ** torch.arange(embed_dim - 1, -1, -1))
|
192 |
+
self.register_buffer("group_basis", 2 ** torch.arange(group_size - 1, -1, -1))
|
193 |
+
|
194 |
+
self.num_dimensions = 2**embed_dim
|
195 |
+
self.bits_per_index = embed_dim
|
196 |
+
|
197 |
+
# we only need to keep the codebook portion up to the group size
|
198 |
+
# because we approximate the H loss with this subcode
|
199 |
+
group_codes = torch.arange(2**self.group_size)
|
200 |
+
group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
|
201 |
+
self.register_buffer("group_codebook", group_codebook, persistent=False)
|
202 |
+
|
203 |
+
self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
|
204 |
+
|
205 |
+
def quantize(self, z):
|
206 |
+
assert (
|
207 |
+
z.shape[-1] == self.embed_dim
|
208 |
+
), f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
|
209 |
+
|
210 |
+
zhat = torch.where(
|
211 |
+
z > 0,
|
212 |
+
torch.tensor(1, dtype=z.dtype, device=z.device),
|
213 |
+
torch.tensor(-1, dtype=z.dtype, device=z.device),
|
214 |
+
)
|
215 |
+
return z + (zhat - z).detach()
|
216 |
+
|
217 |
+
def forward(self, z):
|
218 |
+
if self.input_format == "bchw":
|
219 |
+
z = rearrange(z, "b c h w -> b h w c")
|
220 |
+
zq = self.quantize(z)
|
221 |
+
|
222 |
+
indices = self.codes_to_indexes(zq.detach())
|
223 |
+
group_indices = self.codes_to_group_indexes(zq.detach())
|
224 |
+
if not self.training:
|
225 |
+
used_codes = torch.unique(indices, return_counts=False)
|
226 |
+
else:
|
227 |
+
used_codes = None
|
228 |
+
|
229 |
+
q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
|
230 |
+
|
231 |
+
if self.soft_entropy:
|
232 |
+
persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
|
233 |
+
entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
|
234 |
+
else:
|
235 |
+
zb_by_sample = (
|
236 |
+
((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
|
237 |
+
)
|
238 |
+
persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
|
239 |
+
cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
|
240 |
+
entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
|
241 |
+
|
242 |
+
zq = zq * q_scale
|
243 |
+
|
244 |
+
# commit loss
|
245 |
+
# commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
|
246 |
+
commit_loss = torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
|
247 |
+
|
248 |
+
if self.input_format == "bchw":
|
249 |
+
zq = rearrange(zq, "b h w c -> b c h w")
|
250 |
+
|
251 |
+
return (
|
252 |
+
zq,
|
253 |
+
commit_loss * self.beta
|
254 |
+
+ self.zeta * entropy_penalty / self.inv_temperature,
|
255 |
+
{
|
256 |
+
"H": cb_entropy,
|
257 |
+
"used_codes": used_codes,
|
258 |
+
"indices": indices,
|
259 |
+
"group_indices": group_indices,
|
260 |
+
"avg_prob": avg_prob,
|
261 |
+
"commit_loss": commit_loss,
|
262 |
+
},
|
263 |
+
)
|
264 |
+
|
265 |
+
def soft_entropy_loss(self, z):
|
266 |
+
# if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
|
267 |
+
# the sub-code is the last group_size bits of the full code
|
268 |
+
group_code_book = self.group_codebook / (
|
269 |
+
self.embed_dim**0.5 if self.l2_norm else 1
|
270 |
+
)
|
271 |
+
divided_z = rearrange(z, "... (g c) -> ... g c", c=self.group_size)
|
272 |
+
|
273 |
+
# we calculate the distance between the divided_z and the codebook for each subgroup
|
274 |
+
distance = -2 * torch.einsum(
|
275 |
+
"... g c, d c ->... g d", divided_z, group_code_book
|
276 |
+
)
|
277 |
+
prob = (-distance * self.inv_temperature).softmax(dim=-1)
|
278 |
+
if self.persample_entropy_compute == "analytical":
|
279 |
+
if self.l2_norm:
|
280 |
+
p = torch.sigmoid(-4 * z / (self.embed_dim**0.5) * self.inv_temperature)
|
281 |
+
else:
|
282 |
+
p = torch.sigmoid(-4 * z * self.inv_temperature)
|
283 |
+
prob = torch.stack([p, 1 - p], dim=-1)
|
284 |
+
per_sample_entropy = (
|
285 |
+
self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
|
286 |
+
)
|
287 |
+
else:
|
288 |
+
per_sample_entropy = (
|
289 |
+
self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
|
290 |
+
)
|
291 |
+
|
292 |
+
# macro average of the probability of each subgroup
|
293 |
+
avg_prob = reduce(prob, "... g d ->g d", "mean")
|
294 |
+
codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
|
295 |
+
|
296 |
+
# the approximation of the entropy is the sum of the entropy of each subgroup
|
297 |
+
return per_sample_entropy, codebook_entropy.sum(), avg_prob
|
298 |
+
|
299 |
+
def get_hard_per_sample_entropy(self, zb_by_sample):
|
300 |
+
probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
|
301 |
+
persample_entropy = -probs_per_dim * torch.log(probs_per_dim + 1e-8) - (
|
302 |
+
1 - probs_per_dim
|
303 |
+
) * torch.log(1 - probs_per_dim + 1e-8)
|
304 |
+
persample_entropy = persample_entropy.sum(-1)
|
305 |
+
return persample_entropy.mean()
|
306 |
+
|
307 |
+
def codes_to_indexes(self, zhat):
|
308 |
+
"""Converts a `code` to an index in the codebook.
|
309 |
+
Args:
|
310 |
+
zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
|
311 |
+
"""
|
312 |
+
assert (
|
313 |
+
zhat.shape[-1] == self.embed_dim
|
314 |
+
), f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
|
315 |
+
return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
|
316 |
+
|
317 |
+
def codes_to_group_indexes(self, zhat):
|
318 |
+
"""Converts a `code` to a list of indexes (in groups) in the codebook.
|
319 |
+
Args:
|
320 |
+
zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
|
321 |
+
"""
|
322 |
+
zhat_in_group = rearrange(zhat, "b ... (g c) -> b ... g c", c=self.group_size)
|
323 |
+
return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
|
324 |
+
|
325 |
+
def indexes_to_codes(self, indices):
|
326 |
+
"""Inverse of `indexes_to_codes`."""
|
327 |
+
indices = indices.unsqueeze(-1)
|
328 |
+
codes_non_centered = torch.remainder(torch.floor_divide(indices, self.basis), 2)
|
329 |
+
return codes_non_centered * 2 - 1
|
330 |
+
|
331 |
+
def group_indexes_to_codes(self, group_indices):
|
332 |
+
"""Inverse of `group_indexes_to_codes`."""
|
333 |
+
group_indices = group_indices.unsqueeze(-1)
|
334 |
+
codes_non_centered = torch.remainder(
|
335 |
+
torch.floor_divide(group_indices, self.group_basis), 2
|
336 |
+
)
|
337 |
+
codes_non_centered = rearrange(codes_non_centered, "b ... g c -> b ... (g c)")
|
338 |
+
return codes_non_centered * 2 - 1
|
339 |
+
|
340 |
+
def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
|
341 |
+
if normalize:
|
342 |
+
probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
|
343 |
+
else:
|
344 |
+
probs = count
|
345 |
+
H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
|
346 |
+
return H
|
347 |
+
|
348 |
+
def get_group_codebook_entry(self, group_indices):
|
349 |
+
z_q = self.group_indexes_to_codes(group_indices)
|
350 |
+
q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
|
351 |
+
z_q = z_q * q_scale
|
352 |
+
if self.input_format == "bchw":
|
353 |
+
h = w = int(z_q.shape[1] ** 0.5)
|
354 |
+
assert h * w == z_q.shape[1], "Invalid sequence length"
|
355 |
+
z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h)
|
356 |
+
return z_q
|
357 |
+
|
358 |
+
def get_codebook_entry(self, indices):
|
359 |
+
z_q = self.indexes_to_codes(indices)
|
360 |
+
q_scale = 1.0 / (self.embed_dim**0.5) if self.l2_norm else 1.0
|
361 |
+
z_q = z_q * q_scale
|
362 |
+
if self.input_format == "bchw":
|
363 |
+
h = w = int(z_q.shape[1] ** 0.5)
|
364 |
+
assert h * w == z_q.shape[1], "Invalid sequence length"
|
365 |
+
z_q = rearrange(z_q, "b (h w) c -> b c h w", h=h)
|
366 |
+
return z_q
|
367 |
+
|
368 |
+
|
369 |
+
if __name__ == "__main__":
|
370 |
+
bsq = BinarySphericalQuantizer()
|
371 |
+
z = torch.randn(3, 20, 14)
|
372 |
+
zq, loss, info = bsq(z)
|
373 |
+
print(zq.shape, loss, info)
|
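The `codes_to_indexes` / `indexes_to_codes` pair above is a binary positional encoding: each {-1, 1} code maps to an integer by weighting its bits with powers of two. Below is a minimal, self-contained sketch of that round trip, assuming `basis` holds powers of two (as in typical BSQ implementations) and using shapes that mirror the `__main__` demo; it is illustrative, not taken from a repo config.

```python
import torch

embed_dim = 14                                    # assumed to match the (3, 20, 14) demo tensor above
basis = 2 ** torch.arange(embed_dim)              # powers of two, one weight per code bit (assumption)

zhat = torch.randint(0, 2, (3, 20, embed_dim)) * 2 - 1         # codes in {-1, 1}
indices = ((zhat + 1) / 2 * basis).sum(-1).long()              # codes -> integer indices
codes = torch.remainder(torch.floor_divide(indices.unsqueeze(-1), basis), 2) * 2 - 1

assert torch.equal(codes, zhat)                   # the mapping is exactly invertible
```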
models/codec/amphion_codec/quantize/factorized_vector_quantize.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
|
8 |
+
|
9 |
+
def WNConv1d(*args, **kwargs):
|
10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
11 |
+
|
12 |
+
|
13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
15 |
+
|
16 |
+
|
17 |
+
class FactorizedVectorQuantize(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
input_dim,
|
21 |
+
codebook_size,
|
22 |
+
codebook_dim,
|
23 |
+
commitment=0.005,
|
24 |
+
codebook_loss_weight=1.0,
|
25 |
+
use_l2_normlize=True,
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
self.input_dim = input_dim
|
29 |
+
self.codebook_size = codebook_size
|
30 |
+
self.codebook_dim = codebook_dim
|
31 |
+
self.commitment = commitment
|
32 |
+
self.codebook_loss_weight = codebook_loss_weight
|
33 |
+
self.use_l2_normlize = use_l2_normlize
|
34 |
+
|
35 |
+
if self.input_dim != self.codebook_dim:
|
36 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
37 |
+
self.out_project = WNConv1d(
|
38 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
39 |
+
)
|
40 |
+
|
41 |
+
else:
|
42 |
+
self.in_project = nn.Identity()
|
43 |
+
self.out_project = nn.Identity()
|
44 |
+
|
45 |
+
self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
|
46 |
+
|
47 |
+
def forward(self, z):
|
48 |
+
"""
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
z: torch.Tensor[B x D x T]
|
52 |
+
|
53 |
+
Returns
|
54 |
+
-------
|
55 |
+
z_q: torch.Tensor[B x D x T]
|
56 |
+
Quantized continuous representation of input
|
57 |
+
commit_loss: Tensor[B]
|
58 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries
|
59 |
+
codebook_loss: Tensor[B]
|
60 |
+
Codebook loss to update the codebook
|
61 |
+
indices: torch.Tensor[B x T]
|
62 |
+
Codebook indices (quantized discrete representation of input)
|
63 |
+
z_e: torch.Tensor[B x D x T]
|
64 |
+
Projected latents (continuous representation of input before quantization)
|
65 |
+
"""
|
66 |
+
|
67 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
68 |
+
z_e = self.in_project(z)
|
69 |
+
z_q, indices = self.decode_latents(z_e)
|
70 |
+
|
71 |
+
# Compute commitment loss and codebook loss
|
72 |
+
if self.training:
|
73 |
+
commit_loss = (
|
74 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
75 |
+
* self.commitment
|
76 |
+
)
|
77 |
+
codebook_loss = (
|
78 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
79 |
+
* self.codebook_loss_weight
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
83 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
84 |
+
|
85 |
+
z_q = z_e + (z_q - z_e).detach()
|
86 |
+
|
87 |
+
z_q = self.out_project(z_q)
|
88 |
+
|
89 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
90 |
+
|
91 |
+
def embed_code(self, embed_id):
|
92 |
+
return F.embedding(embed_id, self.codebook.weight)
|
93 |
+
|
94 |
+
def decode_code(self, embed_id):
|
95 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
96 |
+
|
97 |
+
def decode_latents(self, latents):
|
98 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
99 |
+
codebook = self.codebook.weight
|
100 |
+
|
101 |
+
# L2 normalize encodings and codebook
|
102 |
+
if self.use_l2_normlize:
|
103 |
+
encodings = F.normalize(encodings)
|
104 |
+
codebook = F.normalize(codebook)
|
105 |
+
|
106 |
+
# Compute euclidean distance between encodings and codebook,
|
107 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
108 |
+
dist = (
|
109 |
+
encodings.pow(2).sum(1, keepdim=True)
|
110 |
+
- 2 * encodings @ codebook.t()
|
111 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
112 |
+
)
|
113 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
114 |
+
z_q = self.decode_code(indices)
|
115 |
+
|
116 |
+
return z_q, indices
|
117 |
+
|
118 |
+
def vq2emb(self, vq, out_proj=True):
|
119 |
+
emb = self.decode_code(vq)
|
120 |
+
if out_proj:
|
121 |
+
emb = self.out_project(emb)
|
122 |
+
return emb
|
123 |
+
|
124 |
+
def latent2dist(self, latents):
|
125 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
126 |
+
codebook = self.codebook.weight
|
127 |
+
|
128 |
+
# L2 normalize encodings and codebook
|
129 |
+
if self.use_l2_normlize:
|
130 |
+
encodings = F.normalize(encodings)
|
131 |
+
codebook = F.normalize(codebook)
|
132 |
+
|
133 |
+
# Compute euclidean distance between encodings and codebook,
|
134 |
+
# if use_l2_normlize is True, the distance is equal to cosine distance
|
135 |
+
dist = (
|
136 |
+
encodings.pow(2).sum(1, keepdim=True)
|
137 |
+
- 2 * encodings @ codebook.t()
|
138 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
139 |
+
) # (b*t, k)
|
140 |
+
|
141 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
142 |
+
dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
|
143 |
+
z_q = self.decode_code(indices)
|
144 |
+
|
145 |
+
return -dist, indices, z_q
|
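A hedged usage sketch for the FactorizedVectorQuantize module defined above; the import path follows the one used by residual_vq.py later in this commit, and all dimensions here are illustrative rather than taken from any config in this repo.

```python
import torch
from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)

fvq = FactorizedVectorQuantize(
    input_dim=256,       # encoder feature dimension (illustrative)
    codebook_size=8192,
    codebook_dim=8,      # low-dimensional, L2-normalized code space
)
fvq.eval()               # commit/codebook losses are zeroed outside training

z = torch.randn(2, 256, 50)                           # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = fvq(z)
print(z_q.shape, indices.shape)                       # torch.Size([2, 256, 50]) torch.Size([2, 50])
```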
models/codec/amphion_codec/quantize/lookup_free_quantize.py
ADDED
@@ -0,0 +1,72 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
|
8 |
+
|
9 |
+
def WNConv1d(*args, **kwargs):
|
10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
11 |
+
|
12 |
+
|
13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
15 |
+
|
16 |
+
|
17 |
+
class LookupFreeQuantize(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
input_dim,
|
21 |
+
codebook_size,
|
22 |
+
codebook_dim,
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
self.input_dim = input_dim
|
26 |
+
self.codebook_size = codebook_size
|
27 |
+
self.codebook_dim = codebook_dim
|
28 |
+
|
29 |
+
assert 2**codebook_dim == codebook_size
|
30 |
+
|
31 |
+
if self.input_dim != self.codebook_dim:
|
32 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
33 |
+
self.out_project = WNConv1d(
|
34 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
35 |
+
)
|
36 |
+
|
37 |
+
else:
|
38 |
+
self.in_project = nn.Identity()
|
39 |
+
self.out_project = nn.Identity()
|
40 |
+
|
41 |
+
def forward(self, z):
|
42 |
+
z_e = self.in_project(z)
|
43 |
+
z_e = F.sigmoid(z_e)
|
44 |
+
|
45 |
+
z_q = z_e + (torch.round(z_e) - z_e).detach()
|
46 |
+
|
47 |
+
z_q = self.out_project(z_q)
|
48 |
+
|
49 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
50 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
51 |
+
|
52 |
+
bits = (
|
53 |
+
2
|
54 |
+
** torch.arange(self.codebook_dim, device=z.device)
|
55 |
+
.unsqueeze(0)
|
56 |
+
.unsqueeze(-1)
|
57 |
+
.long()
|
58 |
+
) # (1, d, 1)
|
59 |
+
indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()
|
60 |
+
|
61 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
62 |
+
|
63 |
+
def vq2emb(self, vq, out_proj=True):
|
64 |
+
emb = torch.zeros(
|
65 |
+
vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
|
66 |
+
) # (B, d, T)
|
67 |
+
for i in range(self.codebook_dim):
|
68 |
+
emb[:, i, :] = (vq % 2).float()
|
69 |
+
vq = vq // 2
|
70 |
+
if out_proj:
|
71 |
+
emb = self.out_project(emb)
|
72 |
+
return emb
|
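A hedged usage sketch for LookupFreeQuantize as defined above; note that `codebook_size` must equal `2 ** codebook_dim` (asserted in `__init__`). The values below are illustrative.

```python
import torch
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize

lfq = LookupFreeQuantize(input_dim=256, codebook_size=2**10, codebook_dim=10)

z = torch.randn(2, 256, 50)                           # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = lfq(z)
print(z_q.shape, indices.shape)                       # torch.Size([2, 256, 50]) torch.Size([2, 50])

emb = lfq.vq2emb(indices)                             # indices -> binary code -> projected embedding
print(emb.shape)                                      # torch.Size([2, 256, 50])
```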
models/codec/amphion_codec/quantize/residual_vq.py
ADDED
@@ -0,0 +1,172 @@
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
|
11 |
+
FactorizedVectorQuantize,
|
12 |
+
)
|
13 |
+
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
|
14 |
+
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
|
15 |
+
|
16 |
+
|
17 |
+
class ResidualVQ(nn.Module):
|
18 |
+
"""
|
19 |
+
Introduced in SoundStream: An end2end neural audio codec
|
20 |
+
https://arxiv.org/abs/2107.03312
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
input_dim: int = 256,
|
26 |
+
num_quantizers: int = 8,
|
27 |
+
codebook_size: int = 1024,
|
28 |
+
codebook_dim: int = 256,
|
29 |
+
quantizer_type: str = "vq", # "vq" or "fvq" or "lfq"
|
30 |
+
quantizer_dropout: float = 0.5,
|
31 |
+
**kwargs,
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.input_dim = input_dim
|
36 |
+
self.num_quantizers = num_quantizers
|
37 |
+
self.codebook_size = codebook_size
|
38 |
+
self.codebook_dim = codebook_dim
|
39 |
+
self.quantizer_type = quantizer_type
|
40 |
+
self.quantizer_dropout = quantizer_dropout
|
41 |
+
|
42 |
+
if quantizer_type == "vq":
|
43 |
+
VQ = VectorQuantize
|
44 |
+
elif quantizer_type == "fvq":
|
45 |
+
VQ = FactorizedVectorQuantize
|
46 |
+
elif quantizer_type == "lfq":
|
47 |
+
VQ = LookupFreeQuantize
|
48 |
+
else:
|
49 |
+
raise ValueError(f"Unknown quantizer type {quantizer_type}")
|
50 |
+
|
51 |
+
self.quantizers = nn.ModuleList(
|
52 |
+
[
|
53 |
+
VQ(
|
54 |
+
input_dim=input_dim,
|
55 |
+
codebook_size=codebook_size,
|
56 |
+
codebook_dim=codebook_dim,
|
57 |
+
**kwargs,
|
58 |
+
)
|
59 |
+
for _ in range(num_quantizers)
|
60 |
+
]
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, z, n_quantizers: int = None):
|
64 |
+
"""
|
65 |
+
Parameters
|
66 |
+
----------
|
67 |
+
z : Tensor[B x D x T]
|
68 |
+
n_quantizers : int, optional
|
69 |
+
No. of quantizers to use
|
70 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
71 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
72 |
+
when in training mode, and a random number of quantizers is used.
|
73 |
+
Returns
|
74 |
+
-------
|
75 |
+
"quantized_out" : Tensor[B x D x T]
|
76 |
+
Quantized continuous representation of input
|
77 |
+
"all_indices" : Tensor[N x B x T]
|
78 |
+
Codebook indices for each codebook
|
79 |
+
(quantized discrete representation of input)
|
80 |
+
"all_commit_losses" : Tensor[N]
|
81 |
+
"all_codebook_losses" : Tensor[N]
|
82 |
+
"all_quantized" : Tensor[N x B x D x T]
|
83 |
+
"""
|
84 |
+
|
85 |
+
quantized_out = 0.0
|
86 |
+
residual = z
|
87 |
+
|
88 |
+
all_commit_losses = []
|
89 |
+
all_codebook_losses = []
|
90 |
+
all_indices = []
|
91 |
+
all_quantized = []
|
92 |
+
|
93 |
+
if n_quantizers is None:
|
94 |
+
n_quantizers = self.num_quantizers
|
95 |
+
|
96 |
+
if self.training:
|
97 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
|
98 |
+
dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
|
99 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
100 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
101 |
+
n_quantizers = n_quantizers.to(z.device)
|
102 |
+
|
103 |
+
for i, quantizer in enumerate(self.quantizers):
|
104 |
+
if self.training is False and i >= n_quantizers:
|
105 |
+
break
|
106 |
+
|
107 |
+
z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
108 |
+
residual
|
109 |
+
)
|
110 |
+
|
111 |
+
# Create mask to apply quantizer dropout
|
112 |
+
mask = (
|
113 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
114 |
+
)
|
115 |
+
quantized_out = quantized_out + z_q_i * mask[:, None, None]
|
116 |
+
residual = residual - z_q_i
|
117 |
+
|
118 |
+
commit_loss_i = (commit_loss_i * mask).mean()
|
119 |
+
codebook_loss_i = (codebook_loss_i * mask).mean()
|
120 |
+
|
121 |
+
all_commit_losses.append(commit_loss_i)
|
122 |
+
all_codebook_losses.append(codebook_loss_i)
|
123 |
+
all_indices.append(indices_i)
|
124 |
+
all_quantized.append(z_q_i)
|
125 |
+
|
126 |
+
all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
|
127 |
+
torch.stack,
|
128 |
+
(all_commit_losses, all_codebook_losses, all_indices, all_quantized),
|
129 |
+
)
|
130 |
+
|
131 |
+
return (
|
132 |
+
quantized_out,
|
133 |
+
all_indices,
|
134 |
+
all_commit_losses,
|
135 |
+
all_codebook_losses,
|
136 |
+
all_quantized,
|
137 |
+
)
|
138 |
+
|
139 |
+
def vq2emb(self, vq, n_quantizers=None):
|
140 |
+
quantized_out = 0.0
|
141 |
+
if n_quantizers is None:
|
142 |
+
n_quantizers = self.num_quantizers
|
143 |
+
for idx, quantizer in enumerate(self.quantizers):
|
144 |
+
if idx >= n_quantizers:
|
145 |
+
break
|
146 |
+
quantized_out += quantizer.vq2emb(vq[idx])
|
147 |
+
return quantized_out
|
148 |
+
|
149 |
+
def latent2dist(self, z, n_quantizers=None):
|
150 |
+
quantized_out = 0.0
|
151 |
+
residual = z
|
152 |
+
|
153 |
+
all_dists = []
|
154 |
+
all_indices = []
|
155 |
+
|
156 |
+
if n_quantizers is None:
|
157 |
+
n_quantizers = self.num_quantizers
|
158 |
+
|
159 |
+
for i, quantizer in enumerate(self.quantizers):
|
160 |
+
if self.training is False and i >= n_quantizers:
|
161 |
+
break
|
162 |
+
dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
|
163 |
+
all_dists.append(dist_i)
|
164 |
+
all_indices.append(indices_i)
|
165 |
+
|
166 |
+
quantized_out = quantized_out + z_q_i
|
167 |
+
residual = residual - z_q_i
|
168 |
+
|
169 |
+
all_dists = torch.stack(all_dists)
|
170 |
+
all_indices = torch.stack(all_indices)
|
171 |
+
|
172 |
+
return all_dists, all_indices
|
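A hedged usage sketch for ResidualVQ with the factorized-VQ backend; keyword arguments beyond those shown are forwarded to each stage's quantizer. Dimensions and the choice of `n_quantizers` are illustrative.

```python
import torch
from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ

rvq = ResidualVQ(
    input_dim=256,
    num_quantizers=8,
    codebook_size=1024,
    codebook_dim=8,
    quantizer_type="fvq",        # FactorizedVectorQuantize at every residual stage
)
rvq.eval()                        # disable quantizer dropout for this sketch

z = torch.randn(2, 256, 50)       # (B, D, T)
quantized_out, all_indices, commit_losses, codebook_losses, all_quantized = rvq(
    z, n_quantizers=4             # only the first 4 codebooks are applied at inference
)
print(quantized_out.shape, all_indices.shape)   # torch.Size([2, 256, 50]) torch.Size([4, 2, 50])
```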
models/codec/amphion_codec/quantize/vector_quantize.py
ADDED
@@ -0,0 +1,396 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
|
8 |
+
|
9 |
+
def WNConv1d(*args, **kwargs):
|
10 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
11 |
+
|
12 |
+
|
13 |
+
def WNConvTranspose1d(*args, **kwargs):
|
14 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
15 |
+
|
16 |
+
|
17 |
+
def l2norm(t):
|
18 |
+
return F.normalize(t, p=2, dim=-1)
|
19 |
+
|
20 |
+
|
21 |
+
def ema_inplace(moving_avg, new, decay):
|
22 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
23 |
+
|
24 |
+
|
25 |
+
def laplace_smoothing(x, n_categories, eps=1e-5):
|
26 |
+
return (x + eps) / (x.sum() + n_categories * eps)
|
27 |
+
|
28 |
+
|
29 |
+
def sample_vectors(samples, num):
|
30 |
+
num_samples, device = samples.shape[0], samples.device
|
31 |
+
|
32 |
+
if num_samples >= num:
|
33 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
34 |
+
else:
|
35 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
36 |
+
|
37 |
+
return samples[indices]
|
38 |
+
|
39 |
+
|
40 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
41 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
42 |
+
|
43 |
+
means = sample_vectors(samples, num_clusters)
|
44 |
+
|
45 |
+
for _ in range(num_iters):
|
46 |
+
if use_cosine_sim:
|
47 |
+
dists = samples @ means.t()
|
48 |
+
else:
|
49 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
50 |
+
means, "c d -> () c d"
|
51 |
+
)
|
52 |
+
dists = -(diffs**2).sum(dim=-1)
|
53 |
+
|
54 |
+
buckets = dists.max(dim=-1).indices
|
55 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
56 |
+
zero_mask = bins == 0
|
57 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
58 |
+
|
59 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
60 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
61 |
+
new_means = new_means / bins_min_clamped[..., None]
|
62 |
+
|
63 |
+
if use_cosine_sim:
|
64 |
+
new_means = l2norm(new_means)
|
65 |
+
|
66 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
67 |
+
|
68 |
+
return means, bins
|
69 |
+
|
70 |
+
|
71 |
+
class EuclideanCodebook(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
dim,
|
75 |
+
codebook_size,
|
76 |
+
kmeans_init=False,
|
77 |
+
kmeans_iters=10,
|
78 |
+
decay=0.8,
|
79 |
+
eps=1e-5,
|
80 |
+
threshold_ema_dead_code=2,
|
81 |
+
weight_init=False,
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.decay = decay
|
86 |
+
init_fn = torch.randn if not weight_init else torch.zeros
|
87 |
+
embed = init_fn(codebook_size, dim)
|
88 |
+
|
89 |
+
if weight_init:
|
90 |
+
nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)
|
91 |
+
|
92 |
+
self.codebook_size = codebook_size
|
93 |
+
self.kmeans_iters = kmeans_iters
|
94 |
+
self.eps = eps
|
95 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
96 |
+
|
97 |
+
self.register_buffer(
|
98 |
+
"initted", torch.Tensor([not kmeans_init])
|
99 |
+
) # if kmeans_init is True, then initted is False; otherwise, initted is True
|
100 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
101 |
+
self.register_buffer("embed", embed)
|
102 |
+
self.register_buffer("embed_avg", embed.clone())
|
103 |
+
|
104 |
+
def init_embed_(self, data):
|
105 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
106 |
+
self.embed.data.copy_(embed)
|
107 |
+
self.embed_avg.data.copy_(embed)
|
108 |
+
self.cluster_size.data.copy_(cluster_size)
|
109 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
110 |
+
|
111 |
+
def replace(self, samples, mask):
|
112 |
+
modified_codebook = torch.where(
|
113 |
+
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
114 |
+
)
|
115 |
+
self.embed.data.copy_(modified_codebook)
|
116 |
+
|
117 |
+
def expire_codes_(self, batch_samples):
|
118 |
+
if self.threshold_ema_dead_code == 0:
|
119 |
+
return
|
120 |
+
|
121 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
122 |
+
if not torch.any(expired_codes):
|
123 |
+
return
|
124 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
125 |
+
self.replace(batch_samples, mask=expired_codes)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
shape, dtype = x.shape, x.dtype
|
129 |
+
flatten = rearrange(x, "... d -> (...) d")
|
130 |
+
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
131 |
+
|
132 |
+
if not self.initted:
|
133 |
+
self.init_embed_(flatten)
|
134 |
+
|
135 |
+
dist = -(
|
136 |
+
flatten.pow(2).sum(1, keepdim=True)
|
137 |
+
- 2 * flatten @ embed
|
138 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
139 |
+
)
|
140 |
+
|
141 |
+
embed_ind = dist.max(dim=-1).indices
|
142 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
143 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
144 |
+
quantize = F.embedding(embed_ind, self.embed)
|
145 |
+
|
146 |
+
if self.training:
|
147 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
148 |
+
embed_sum = (
|
149 |
+
flatten.t() @ embed_onehot
|
150 |
+
) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
|
151 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
152 |
+
cluster_size = (
|
153 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
|
154 |
+
* self.cluster_size.sum()
|
155 |
+
)
|
156 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
157 |
+
self.embed.data.copy_(embed_normalized)
|
158 |
+
self.expire_codes_(x)
|
159 |
+
|
160 |
+
return quantize, embed_ind
|
161 |
+
|
162 |
+
def vq2emb(self, vq):
|
163 |
+
quantize = F.embedding(vq, self.embed)
|
164 |
+
return quantize
|
165 |
+
|
166 |
+
def latent2dist(self, x):
|
167 |
+
shape, dtype = x.shape, x.dtype
|
168 |
+
flatten = rearrange(x, "... d -> (...) d")
|
169 |
+
embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size)
|
170 |
+
|
171 |
+
if not self.initted:
|
172 |
+
self.init_embed_(flatten)
|
173 |
+
|
174 |
+
dist = -(
|
175 |
+
flatten.pow(2).sum(1, keepdim=True)
|
176 |
+
- 2 * flatten @ embed
|
177 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
178 |
+
)
|
179 |
+
|
180 |
+
embed_ind = dist.max(dim=-1).indices
|
181 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
182 |
+
quantize = F.embedding(embed_ind, self.embed)
|
183 |
+
|
184 |
+
dist = dist.view(*shape[:-1], -1)
|
185 |
+
|
186 |
+
return dist, embed_ind, quantize
|
187 |
+
|
188 |
+
|
189 |
+
class SimpleCodebook(nn.Module):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
dim,
|
193 |
+
codebook_size,
|
194 |
+
use_l2_normlize=False,
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
self.dim = dim
|
199 |
+
self.codebook_size = codebook_size
|
200 |
+
self.use_l2_normlize = use_l2_normlize
|
201 |
+
|
202 |
+
self.embed = nn.Embedding(self.codebook_size, self.dim)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
shape, dtype = x.shape, x.dtype
|
206 |
+
flatten = rearrange(x, "... d -> (...) d")
|
207 |
+
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
208 |
+
|
209 |
+
if self.use_l2_normlize:
|
210 |
+
flatten = F.normalize(flatten)
|
211 |
+
embed = F.normalize(embed)
|
212 |
+
|
213 |
+
dist = -(
|
214 |
+
flatten.pow(2).sum(1, keepdim=True)
|
215 |
+
- 2 * flatten @ embed
|
216 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
217 |
+
)
|
218 |
+
|
219 |
+
embed_ind = dist.max(dim=-1).indices
|
220 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
221 |
+
quantize = F.embedding(embed_ind, self.embed.weight)
|
222 |
+
|
223 |
+
return quantize, embed_ind
|
224 |
+
|
225 |
+
def vq2emb(self, vq):
|
226 |
+
quantize = F.embedding(vq, self.embed.weight)
|
227 |
+
return quantize
|
228 |
+
|
229 |
+
def latent2dist(self, x):
|
230 |
+
shape, dtype = x.shape, x.dtype
|
231 |
+
flatten = rearrange(x, "... d -> (...) d")
|
232 |
+
embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size)
|
233 |
+
|
234 |
+
if self.use_l2_normlize:
|
235 |
+
flatten = F.normalize(flatten)
|
236 |
+
embed = F.normalize(embed)
|
237 |
+
|
238 |
+
dist = -(
|
239 |
+
flatten.pow(2).sum(1, keepdim=True)
|
240 |
+
- 2 * flatten @ embed
|
241 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
242 |
+
)
|
243 |
+
|
244 |
+
embed_ind = dist.max(dim=-1).indices
|
245 |
+
embed_ind = embed_ind.view(*shape[:-1])
|
246 |
+
quantize = F.embedding(embed_ind, self.embed.weight)
|
247 |
+
|
248 |
+
dist = dist.view(*shape[:-1], -1)
|
249 |
+
|
250 |
+
return dist, embed_ind, quantize
|
251 |
+
|
252 |
+
|
253 |
+
class VectorQuantize(nn.Module):
|
254 |
+
"""Vector quantization and factorized vecotor quantization implementation
|
255 |
+
Args:
|
256 |
+
input_dim (int): Dimension of input.
|
257 |
+
codebook_size (int): Codebook size.
|
258 |
+
codebook_dim (int): Codebook dimension. We suggest using codebook_dim = input_dim
|
259 |
+
when codebook_type == "euclidean"; otherwise, if you want to use
|
260 |
+
factorized vector quantization, use a small codebook_dim (e.g. 8 or 32).
|
261 |
+
commitment (float): Weight for commitment loss.
|
262 |
+
use_l2_normlize (bool): Whether to use L2-normalized codes for factorized vector quantization;
|
263 |
+
we suggest setting it to True if you want to use factorized vector quantization.
|
264 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
265 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
266 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
267 |
+
epsilon (float): Epsilon value for numerical stability.
|
268 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
269 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
270 |
+
randomly selected vector from the current batch.
|
271 |
+
"""
|
272 |
+
|
273 |
+
def __init__(
|
274 |
+
self,
|
275 |
+
input_dim,
|
276 |
+
codebook_size,
|
277 |
+
codebook_dim,
|
278 |
+
commitment=0.005,
|
279 |
+
codebook_loss_weight=1.0,
|
280 |
+
use_l2_normlize=False,
|
281 |
+
codebook_type="euclidean", # "euclidean" or "simple"
|
282 |
+
kmeans_init=False,
|
283 |
+
kmeans_iters=10,
|
284 |
+
decay=0.8,
|
285 |
+
eps=1e-5,
|
286 |
+
threshold_ema_dead_code=2,
|
287 |
+
weight_init=False,
|
288 |
+
):
|
289 |
+
super().__init__()
|
290 |
+
self.input_dim = input_dim
|
291 |
+
self.codebook_size = codebook_size
|
292 |
+
self.codebook_dim = codebook_dim
|
293 |
+
self.commitment = commitment
|
294 |
+
self.codebook_loss_weight = codebook_loss_weight
|
295 |
+
self.use_l2_normlize = use_l2_normlize
|
296 |
+
self.codebook_type = codebook_type
|
297 |
+
self.kmeans_init = kmeans_init
|
298 |
+
self.kmeans_iters = kmeans_iters
|
299 |
+
self.decay = decay
|
300 |
+
self.eps = eps
|
301 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
302 |
+
self.weight_init = weight_init
|
303 |
+
|
304 |
+
if self.input_dim != self.codebook_dim:
|
305 |
+
self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
|
306 |
+
self.out_project = WNConv1d(
|
307 |
+
self.codebook_dim, self.input_dim, kernel_size=1
|
308 |
+
)
|
309 |
+
|
310 |
+
else:
|
311 |
+
self.in_project = nn.Identity()
|
312 |
+
self.out_project = nn.Identity()
|
313 |
+
|
314 |
+
if self.codebook_type == "euclidean":
|
315 |
+
self.codebook = EuclideanCodebook(
|
316 |
+
self.codebook_dim,
|
317 |
+
codebook_size=self.codebook_size,
|
318 |
+
kmeans_init=self.kmeans_init,
|
319 |
+
kmeans_iters=self.kmeans_iters,
|
320 |
+
decay=self.decay,
|
321 |
+
eps=self.eps,
|
322 |
+
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
323 |
+
weight_init=self.weight_init,
|
324 |
+
)
|
325 |
+
elif self.codebook_type == "simple":
|
326 |
+
self.codebook = SimpleCodebook(
|
327 |
+
self.codebook_dim,
|
328 |
+
codebook_size=self.codebook_size,
|
329 |
+
use_l2_normlize=self.use_l2_normlize,
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
raise NotImplementedError(
|
333 |
+
f"codebook_type {self.codebook_type} is not implemented!"
|
334 |
+
)
|
335 |
+
|
336 |
+
def forward(self, z):
|
337 |
+
"""
|
338 |
+
Parameters
|
339 |
+
----------
|
340 |
+
z: torch.Tensor[B x D x T]
|
341 |
+
|
342 |
+
Returns
|
343 |
+
-------
|
344 |
+
z_q: torch.Tensor[B x D x T]
|
345 |
+
Quantized continuous representation of input
|
346 |
+
commit_loss: Tensor[B]
|
347 |
+
Commitment loss to train encoder to predict vectors closer to codebook entries
|
348 |
+
codebook_loss: Tensor[B]
|
349 |
+
Codebook loss to update the codebook
|
350 |
+
indices: torch.Tensor[B x T]
|
351 |
+
Codebook indices (quantized discrete representation of input)
|
352 |
+
z_e: torch.Tensor[B x D x T]
|
353 |
+
Projected latents (continuous representation of input before quantization)
|
354 |
+
"""
|
355 |
+
|
356 |
+
# Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
|
357 |
+
z_e = self.in_project(z)
|
358 |
+
z_q, indices = self.decode_latents(z_e)
|
359 |
+
|
360 |
+
# Compute commitment loss and codebook loss
|
361 |
+
if self.training:
|
362 |
+
commit_loss = (
|
363 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
364 |
+
* self.commitment
|
365 |
+
)
|
366 |
+
codebook_loss = (
|
367 |
+
F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
368 |
+
* self.codebook_loss_weight
|
369 |
+
)
|
370 |
+
else:
|
371 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
372 |
+
codebook_loss = torch.zeros(z.shape[0], device=z.device)
|
373 |
+
|
374 |
+
z_q = z_e + (z_q - z_e).detach()
|
375 |
+
|
376 |
+
z_q = self.out_project(z_q)
|
377 |
+
|
378 |
+
return z_q, commit_loss, codebook_loss, indices, z_e
|
379 |
+
|
380 |
+
def decode_latents(self, latents):
|
381 |
+
encodings = rearrange(latents, "b d t -> b t d")
|
382 |
+
z_q, indices = self.codebook(encodings)
|
383 |
+
z_q = z_q.transpose(1, 2)
|
384 |
+
return z_q, indices
|
385 |
+
|
386 |
+
def vq2emb(self, vq, out_proj=True):
|
387 |
+
emb = self.codebook.vq2emb(vq)
|
388 |
+
emb = emb.transpose(1, 2)
|
389 |
+
if out_proj:
|
390 |
+
emb = self.out_project(emb)
|
391 |
+
return emb
|
392 |
+
|
393 |
+
def latent2dist(self, latents):
|
394 |
+
latents = rearrange(latents, "b d t -> b t d")
|
395 |
+
dist, embed_ind, quantize = self.codebook.latent2dist(latents)
|
396 |
+
return dist, embed_ind, quantize.transpose(1, 2)
|
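A hedged usage sketch for VectorQuantize with the EMA-updated "euclidean" codebook described in the docstring above; dimensions are illustrative, and eval mode is used so no EMA codebook update runs in this sketch.

```python
import torch
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize

vq = VectorQuantize(
    input_dim=256,
    codebook_size=1024,
    codebook_dim=256,            # codebook_dim = input_dim, as suggested for the euclidean codebook
    codebook_type="euclidean",
    kmeans_init=False,
)
vq.eval()                        # skip EMA updates / dead-code expiration for this sketch

z = torch.randn(2, 256, 50)      # (B, D, T)
z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
print(z_q.shape, indices.shape)  # torch.Size([2, 256, 50]) torch.Size([2, 50])
```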
models/codec/amphion_codec/vocos.py
ADDED
@@ -0,0 +1,909 @@
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import scipy
|
5 |
+
import torch
|
6 |
+
from torch import nn, view_as_real, view_as_complex
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
9 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
10 |
+
from models.codec.melvqgan.melspec import MelSpectrogram
|
11 |
+
import librosa
|
12 |
+
|
13 |
+
|
14 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
15 |
+
"""
|
16 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
x (Tensor): Input tensor.
|
20 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
24 |
+
"""
|
25 |
+
return torch.log(torch.clip(x, min=clip_val))
|
26 |
+
|
27 |
+
|
28 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
29 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
30 |
+
|
31 |
+
|
32 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
33 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
34 |
+
|
35 |
+
|
36 |
+
class STFT(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
n_fft: int,
|
40 |
+
hop_length: int,
|
41 |
+
win_length: int,
|
42 |
+
center=True,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.center = center
|
46 |
+
self.n_fft = n_fft
|
47 |
+
self.hop_length = hop_length
|
48 |
+
self.win_length = win_length
|
49 |
+
window = torch.hann_window(win_length)
|
50 |
+
self.register_buffer("window", window)
|
51 |
+
|
52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
53 |
+
# x: (B, T * hop_length)
|
54 |
+
|
55 |
+
if not self.center:
|
56 |
+
pad = self.win_length - self.hop_length
|
57 |
+
x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
|
58 |
+
|
59 |
+
stft_spec = torch.stft(
|
60 |
+
x,
|
61 |
+
self.n_fft,
|
62 |
+
hop_length=self.hop_length,
|
63 |
+
win_length=self.win_length,
|
64 |
+
window=self.window,
|
65 |
+
center=self.center,
|
66 |
+
return_complex=False,
|
67 |
+
) # (B, n_fft // 2 + 1, T, 2)
|
68 |
+
|
69 |
+
rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T)
|
70 |
+
imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T)
|
71 |
+
|
72 |
+
log_mag = torch.log(
|
73 |
+
torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
|
74 |
+
) # (B, n_fft // 2 + 1, T)
|
75 |
+
phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
|
76 |
+
|
77 |
+
return log_mag, phase
|
78 |
+
|
79 |
+
|
80 |
+
class ISTFT(nn.Module):
|
81 |
+
"""
|
82 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
83 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
84 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
85 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
86 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
n_fft (int): Size of Fourier transform.
|
90 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
91 |
+
win_length (int): The size of window frame and STFT filter.
|
92 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(
|
96 |
+
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
if padding not in ["center", "same"]:
|
100 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
101 |
+
self.padding = padding
|
102 |
+
self.n_fft = n_fft
|
103 |
+
self.hop_length = hop_length
|
104 |
+
self.win_length = win_length
|
105 |
+
window = torch.hann_window(win_length)
|
106 |
+
self.register_buffer("window", window)
|
107 |
+
|
108 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
109 |
+
"""
|
110 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
114 |
+
N is the number of frequency bins, and T is the number of time frames.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
118 |
+
"""
|
119 |
+
if self.padding == "center":
|
120 |
+
# Fallback to pytorch native implementation
|
121 |
+
return torch.istft(
|
122 |
+
spec,
|
123 |
+
self.n_fft,
|
124 |
+
self.hop_length,
|
125 |
+
self.win_length,
|
126 |
+
self.window,
|
127 |
+
center=True,
|
128 |
+
)
|
129 |
+
elif self.padding == "same":
|
130 |
+
pad = (self.win_length - self.hop_length) // 2
|
131 |
+
else:
|
132 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
133 |
+
|
134 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
135 |
+
B, N, T = spec.shape
|
136 |
+
|
137 |
+
# Inverse FFT
|
138 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
139 |
+
ifft = ifft * self.window[None, :, None]
|
140 |
+
|
141 |
+
# Overlap and Add
|
142 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
143 |
+
y = torch.nn.functional.fold(
|
144 |
+
ifft,
|
145 |
+
output_size=(1, output_size),
|
146 |
+
kernel_size=(1, self.win_length),
|
147 |
+
stride=(1, self.hop_length),
|
148 |
+
)[:, 0, 0, pad:-pad]
|
149 |
+
|
150 |
+
# Window envelope
|
151 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
152 |
+
window_envelope = torch.nn.functional.fold(
|
153 |
+
window_sq,
|
154 |
+
output_size=(1, output_size),
|
155 |
+
kernel_size=(1, self.win_length),
|
156 |
+
stride=(1, self.hop_length),
|
157 |
+
).squeeze()[pad:-pad]
|
158 |
+
|
159 |
+
# Normalize
|
160 |
+
assert (window_envelope > 1e-11).all()
|
161 |
+
y = y / window_envelope
|
162 |
+
|
163 |
+
return y
|
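A hedged round-trip sketch for the STFT / ISTFT pair above, using "same"-style padding on both sides (center=False for STFT, padding="same" for ISTFT). The import path and the FFT/hop/window sizes are illustrative assumptions, not values taken from a config in this repo.

```python
import torch
from models.codec.amphion_codec.vocos import STFT, ISTFT   # the classes defined in this file

stft = STFT(n_fft=1024, hop_length=256, win_length=1024, center=False)
istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="same")

audio = torch.randn(1, 256 * 100)               # (B, T), length a multiple of hop_length
log_mag, phase = stft(audio)                    # each (B, n_fft // 2 + 1, frames)

# Rebuild a complex spectrogram and invert it; reconstruction is approximate
# because of the +1e-5 floor inside the log-magnitude computation.
spec = torch.exp(log_mag) * torch.exp(1j * phase)
recon = istft(spec)
print(audio.shape, recon.shape)                 # both (1, 25600)
```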
164 |
+
|
165 |
+
|
166 |
+
class MDCT(nn.Module):
|
167 |
+
"""
|
168 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
frame_len (int): Length of the MDCT frame.
|
172 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
176 |
+
super().__init__()
|
177 |
+
if padding not in ["center", "same"]:
|
178 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
179 |
+
self.padding = padding
|
180 |
+
self.frame_len = frame_len
|
181 |
+
N = frame_len // 2
|
182 |
+
n0 = (N + 1) / 2
|
183 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
184 |
+
self.register_buffer("window", window)
|
185 |
+
|
186 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
187 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
188 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
189 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
190 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
191 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
192 |
+
|
193 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
194 |
+
"""
|
195 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
199 |
+
and T is the length of the audio.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
203 |
+
and N is the number of frequency bins.
|
204 |
+
"""
|
205 |
+
if self.padding == "center":
|
206 |
+
audio = torch.nn.functional.pad(
|
207 |
+
audio, (self.frame_len // 2, self.frame_len // 2)
|
208 |
+
)
|
209 |
+
elif self.padding == "same":
|
210 |
+
# hop_length is 1/2 frame_len
|
211 |
+
audio = torch.nn.functional.pad(
|
212 |
+
audio, (self.frame_len // 4, self.frame_len // 4)
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
216 |
+
|
217 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
218 |
+
N = self.frame_len // 2
|
219 |
+
x = x * self.window.expand(x.shape)
|
220 |
+
X = torch.fft.fft(
|
221 |
+
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
|
222 |
+
)[..., :N]
|
223 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
224 |
+
return torch.real(res) * np.sqrt(2)
|
225 |
+
|
226 |
+
|
227 |
+
class IMDCT(nn.Module):
|
228 |
+
"""
|
229 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
frame_len (int): Length of the MDCT frame.
|
233 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
234 |
+
"""
|
235 |
+
|
236 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
237 |
+
super().__init__()
|
238 |
+
if padding not in ["center", "same"]:
|
239 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
240 |
+
self.padding = padding
|
241 |
+
self.frame_len = frame_len
|
242 |
+
N = frame_len // 2
|
243 |
+
n0 = (N + 1) / 2
|
244 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
245 |
+
self.register_buffer("window", window)
|
246 |
+
|
247 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
248 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
249 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
250 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
251 |
+
|
252 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
253 |
+
"""
|
254 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
258 |
+
L is the number of frames, and N is the number of frequency bins.
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
262 |
+
"""
|
263 |
+
B, L, N = X.shape
|
264 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
265 |
+
Y[..., :N] = X
|
266 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
267 |
+
y = torch.fft.ifft(
|
268 |
+
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
|
269 |
+
)
|
270 |
+
y = (
|
271 |
+
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
|
272 |
+
* np.sqrt(N)
|
273 |
+
* np.sqrt(2)
|
274 |
+
)
|
275 |
+
result = y * self.window.expand(y.shape)
|
276 |
+
output_size = (1, (L + 1) * N)
|
277 |
+
audio = torch.nn.functional.fold(
|
278 |
+
result.transpose(1, 2),
|
279 |
+
output_size=output_size,
|
280 |
+
kernel_size=(1, self.frame_len),
|
281 |
+
stride=(1, self.frame_len // 2),
|
282 |
+
)[:, 0, 0, :]
|
283 |
+
|
284 |
+
if self.padding == "center":
|
285 |
+
pad = self.frame_len // 2
|
286 |
+
elif self.padding == "same":
|
287 |
+
pad = self.frame_len // 4
|
288 |
+
else:
|
289 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
290 |
+
|
291 |
+
audio = audio[:, pad:-pad]
|
292 |
+
return audio
|
293 |
+
|
294 |
+
|
295 |
+
class FourierHead(nn.Module):
|
296 |
+
"""Base class for inverse fourier modules."""
|
297 |
+
|
298 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
299 |
+
"""
|
300 |
+
Args:
|
301 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
302 |
+
L is the sequence length, and H denotes the model dimension.
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
306 |
+
"""
|
307 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
308 |
+
|
309 |
+
|
310 |
+
class ISTFTHead(FourierHead):
|
311 |
+
"""
|
312 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
dim (int): Hidden dimension of the model.
|
316 |
+
n_fft (int): Size of Fourier transform.
|
317 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
318 |
+
the resolution of the input features.
|
319 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
320 |
+
"""
|
321 |
+
|
322 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
323 |
+
super().__init__()
|
324 |
+
out_dim = n_fft + 2
|
325 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
326 |
+
self.istft = ISTFT(
|
327 |
+
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
328 |
+
)
|
329 |
+
|
330 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
331 |
+
"""
|
332 |
+
Forward pass of the ISTFTHead module.
|
333 |
+
|
334 |
+
Args:
|
335 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
336 |
+
L is the sequence length, and H denotes the model dimension.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
340 |
+
"""
|
341 |
+
x = self.out(x).transpose(1, 2)
|
342 |
+
mag, p = x.chunk(2, dim=1)
|
343 |
+
mag = torch.exp(mag)
|
344 |
+
mag = torch.clip(
|
345 |
+
mag, max=1e2
|
346 |
+
) # safeguard to prevent excessively large magnitudes
|
347 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
348 |
+
x = torch.cos(p)
|
349 |
+
y = torch.sin(p)
|
350 |
+
# recalculating phase here does not produce anything new
|
351 |
+
# only costs time
|
352 |
+
# phase = torch.atan2(y, x)
|
353 |
+
# S = mag * torch.exp(phase * 1j)
|
354 |
+
# better directly produce the complex value
|
355 |
+
S = mag * (x + 1j * y)
|
356 |
+
audio = self.istft(S)
|
357 |
+
return audio
|
358 |
+
|
359 |
+
|
360 |
+
class IMDCTSymExpHead(FourierHead):
|
361 |
+
"""
|
362 |
+
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
363 |
+
|
364 |
+
Args:
|
365 |
+
dim (int): Hidden dimension of the model.
|
366 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
367 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
368 |
+
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
369 |
+
based on perceptual scaling. Defaults to None.
|
370 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
371 |
+
"""
|
372 |
+
|
373 |
+
def __init__(
|
374 |
+
self,
|
375 |
+
dim: int,
|
376 |
+
mdct_frame_len: int,
|
377 |
+
padding: str = "same",
|
378 |
+
sample_rate: Optional[int] = None,
|
379 |
+
clip_audio: bool = False,
|
380 |
+
):
|
381 |
+
super().__init__()
|
382 |
+
out_dim = mdct_frame_len // 2
|
383 |
+
self.out = nn.Linear(dim, out_dim)
|
384 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
385 |
+
self.clip_audio = clip_audio
|
386 |
+
|
387 |
+
if sample_rate is not None:
|
388 |
+
# optionally init the last layer following mel-scale
|
389 |
+
m_max = _hz_to_mel(sample_rate // 2)
|
390 |
+
m_pts = torch.linspace(0, m_max, out_dim)
|
391 |
+
f_pts = _mel_to_hz(m_pts)
|
392 |
+
scale = 1 - (f_pts / f_pts.max())
|
393 |
+
|
394 |
+
with torch.no_grad():
|
395 |
+
self.out.weight.mul_(scale.view(-1, 1))
|
396 |
+
|
397 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
398 |
+
"""
|
399 |
+
Forward pass of the IMDCTSymExpHead module.
|
400 |
+
|
401 |
+
Args:
|
402 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
403 |
+
L is the sequence length, and H denotes the model dimension.
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
407 |
+
"""
|
408 |
+
x = self.out(x)
|
409 |
+
x = symexp(x)
|
410 |
+
x = torch.clip(
|
411 |
+
x, min=-1e2, max=1e2
|
412 |
+
) # safeguard to prevent excessively large magnitudes
|
413 |
+
audio = self.imdct(x)
|
414 |
+
if self.clip_audio:
|
415 |
+
audio = torch.clip(audio, min=-1.0, max=1.0)
|
416 |
+
|
417 |
+
return audio
|
418 |
+
|
419 |
+
|
420 |
+
class IMDCTCosHead(FourierHead):
|
421 |
+
"""
|
422 |
+
IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p)
|
423 |
+
|
424 |
+
Args:
|
425 |
+
dim (int): Hidden dimension of the model.
|
426 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
427 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
428 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
429 |
+
"""
|
430 |
+
|
431 |
+
def __init__(
|
432 |
+
self,
|
433 |
+
dim: int,
|
434 |
+
mdct_frame_len: int,
|
435 |
+
padding: str = "same",
|
436 |
+
clip_audio: bool = False,
|
437 |
+
):
|
438 |
+
super().__init__()
|
439 |
+
self.clip_audio = clip_audio
|
440 |
+
self.out = nn.Linear(dim, mdct_frame_len)
|
441 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
442 |
+
|
443 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
444 |
+
"""
|
445 |
+
Forward pass of the IMDCTCosHead module.
|
446 |
+
|
447 |
+
Args:
|
448 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
449 |
+
L is the sequence length, and H denotes the model dimension.
|
450 |
+
|
451 |
+
Returns:
|
452 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
453 |
+
"""
|
454 |
+
x = self.out(x)
|
455 |
+
m, p = x.chunk(2, dim=2)
|
456 |
+
m = torch.exp(m).clip(
|
457 |
+
max=1e2
|
458 |
+
) # safeguard to prevent excessively large magnitudes
|
459 |
+
audio = self.imdct(m * torch.cos(p))
|
460 |
+
if self.clip_audio:
|
461 |
+
audio = torch.clip(audio, min=-1.0, max=1.0)
|
462 |
+
return audio
|
463 |
+
|
464 |
+
|
465 |
+
class ConvNeXtBlock(nn.Module):
|
466 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
467 |
+
|
468 |
+
Args:
|
469 |
+
dim (int): Number of input channels.
|
470 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
471 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
472 |
+
Defaults to None.
|
473 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
474 |
+
None means non-conditional LayerNorm. Defaults to None.
|
475 |
+
"""
|
476 |
+
|
477 |
+
def __init__(
|
478 |
+
self,
|
479 |
+
dim: int,
|
480 |
+
intermediate_dim: int,
|
481 |
+
layer_scale_init_value: float,
|
482 |
+
adanorm_num_embeddings: Optional[int] = None,
|
483 |
+
):
|
484 |
+
super().__init__()
|
485 |
+
self.dwconv = nn.Conv1d(
|
486 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
487 |
+
) # depthwise conv
|
488 |
+
self.adanorm = adanorm_num_embeddings is not None
|
489 |
+
if adanorm_num_embeddings:
|
490 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
491 |
+
else:
|
492 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
493 |
+
self.pwconv1 = nn.Linear(
|
494 |
+
dim, intermediate_dim
|
495 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
496 |
+
self.act = nn.GELU()
|
497 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
498 |
+
self.gamma = (
|
499 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
500 |
+
if layer_scale_init_value > 0
|
501 |
+
else None
|
502 |
+
)
|
503 |
+
|
504 |
+
def forward(
|
505 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
506 |
+
) -> torch.Tensor:
|
507 |
+
residual = x
|
508 |
+
x = self.dwconv(x)
|
509 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
510 |
+
if self.adanorm:
|
511 |
+
assert cond_embedding_id is not None
|
512 |
+
x = self.norm(x, cond_embedding_id)
|
513 |
+
else:
|
514 |
+
x = self.norm(x)
|
515 |
+
x = self.pwconv1(x)
|
516 |
+
x = self.act(x)
|
517 |
+
x = self.pwconv2(x)
|
518 |
+
if self.gamma is not None:
|
519 |
+
x = self.gamma * x
|
520 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
521 |
+
|
522 |
+
x = residual + x
|
523 |
+
return x
|
524 |
+
|
525 |
+
|
526 |
+
class AdaLayerNorm(nn.Module):
|
527 |
+
"""
|
528 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
529 |
+
|
530 |
+
Args:
|
531 |
+
num_embeddings (int): Number of embeddings.
|
532 |
+
embedding_dim (int): Dimension of the embeddings.
|
533 |
+
"""
|
534 |
+
|
535 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
536 |
+
super().__init__()
|
537 |
+
self.eps = eps
|
538 |
+
self.dim = embedding_dim
|
539 |
+
self.scale = nn.Embedding(
|
540 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
541 |
+
)
|
542 |
+
self.shift = nn.Embedding(
|
543 |
+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
|
544 |
+
)
|
545 |
+
torch.nn.init.ones_(self.scale.weight)
|
546 |
+
torch.nn.init.zeros_(self.shift.weight)
|
547 |
+
|
548 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
549 |
+
scale = self.scale(cond_embedding_id)
|
550 |
+
shift = self.shift(cond_embedding_id)
|
551 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
552 |
+
x = x * scale + shift
|
553 |
+
return x
|
554 |
+
|
555 |
+
|
556 |
+
class ResBlock1(nn.Module):
|
557 |
+
"""
|
558 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
559 |
+
but without upsampling layers.
|
560 |
+
|
561 |
+
Args:
|
562 |
+
dim (int): Number of input channels.
|
563 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
564 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
565 |
+
Defaults to (1, 3, 5).
|
566 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
567 |
+
Defaults to 0.1.
|
568 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
569 |
+
Defaults to None.
|
570 |
+
"""
|
571 |
+
|
572 |
+
def __init__(
|
573 |
+
self,
|
574 |
+
dim: int,
|
575 |
+
kernel_size: int = 3,
|
576 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
577 |
+
lrelu_slope: float = 0.1,
|
578 |
+
layer_scale_init_value: Optional[float] = None,
|
579 |
+
):
|
580 |
+
super().__init__()
|
581 |
+
self.lrelu_slope = lrelu_slope
|
582 |
+
self.convs1 = nn.ModuleList(
|
583 |
+
[
|
584 |
+
weight_norm(
|
585 |
+
nn.Conv1d(
|
586 |
+
dim,
|
587 |
+
dim,
|
588 |
+
kernel_size,
|
589 |
+
1,
|
590 |
+
dilation=dilation[0],
|
591 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
592 |
+
)
|
593 |
+
),
|
594 |
+
weight_norm(
|
595 |
+
nn.Conv1d(
|
596 |
+
dim,
|
597 |
+
dim,
|
598 |
+
kernel_size,
|
599 |
+
1,
|
600 |
+
dilation=dilation[1],
|
601 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
602 |
+
)
|
603 |
+
),
|
604 |
+
weight_norm(
|
605 |
+
nn.Conv1d(
|
606 |
+
dim,
|
607 |
+
dim,
|
608 |
+
kernel_size,
|
609 |
+
1,
|
610 |
+
dilation=dilation[2],
|
611 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
612 |
+
)
|
613 |
+
),
|
614 |
+
]
|
615 |
+
)
|
616 |
+
|
617 |
+
self.convs2 = nn.ModuleList(
|
618 |
+
[
|
619 |
+
weight_norm(
|
620 |
+
nn.Conv1d(
|
621 |
+
dim,
|
622 |
+
dim,
|
623 |
+
kernel_size,
|
624 |
+
1,
|
625 |
+
dilation=1,
|
626 |
+
padding=self.get_padding(kernel_size, 1),
|
627 |
+
)
|
628 |
+
),
|
629 |
+
weight_norm(
|
630 |
+
nn.Conv1d(
|
631 |
+
dim,
|
632 |
+
dim,
|
633 |
+
kernel_size,
|
634 |
+
1,
|
635 |
+
dilation=1,
|
636 |
+
padding=self.get_padding(kernel_size, 1),
|
637 |
+
)
|
638 |
+
),
|
639 |
+
weight_norm(
|
640 |
+
nn.Conv1d(
|
641 |
+
dim,
|
642 |
+
dim,
|
643 |
+
kernel_size,
|
644 |
+
1,
|
645 |
+
dilation=1,
|
646 |
+
padding=self.get_padding(kernel_size, 1),
|
647 |
+
)
|
648 |
+
),
|
649 |
+
]
|
650 |
+
)
|
651 |
+
|
652 |
+
self.gamma = nn.ParameterList(
|
653 |
+
[
|
654 |
+
(
|
655 |
+
nn.Parameter(
|
656 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
657 |
+
)
|
658 |
+
if layer_scale_init_value is not None
|
659 |
+
else None
|
660 |
+
),
|
661 |
+
(
|
662 |
+
nn.Parameter(
|
663 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
664 |
+
)
|
665 |
+
if layer_scale_init_value is not None
|
666 |
+
else None
|
667 |
+
),
|
668 |
+
(
|
669 |
+
nn.Parameter(
|
670 |
+
layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
|
671 |
+
)
|
672 |
+
if layer_scale_init_value is not None
|
673 |
+
else None
|
674 |
+
),
|
675 |
+
]
|
676 |
+
)
|
677 |
+
|
678 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
679 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
680 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
681 |
+
xt = c1(xt)
|
682 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
683 |
+
xt = c2(xt)
|
684 |
+
if gamma is not None:
|
685 |
+
xt = gamma * xt
|
686 |
+
x = xt + x
|
687 |
+
return x
|
688 |
+
|
689 |
+
def remove_weight_norm(self):
|
690 |
+
for l in self.convs1:
|
691 |
+
remove_weight_norm(l)
|
692 |
+
for l in self.convs2:
|
693 |
+
remove_weight_norm(l)
|
694 |
+
|
695 |
+
@staticmethod
|
696 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
697 |
+
return int((kernel_size * dilation - dilation) / 2)
|
698 |
+
|
699 |
+
|
700 |
+
class Backbone(nn.Module):
|
701 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
702 |
+
|
703 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
704 |
+
"""
|
705 |
+
Args:
|
706 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
707 |
+
C denotes output features, and L is the sequence length.
|
708 |
+
|
709 |
+
Returns:
|
710 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
711 |
+
and H denotes the model dimension.
|
712 |
+
"""
|
713 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
714 |
+
|
715 |
+
|
716 |
+
class VocosBackbone(Backbone):
|
717 |
+
"""
|
718 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
719 |
+
|
720 |
+
Args:
|
721 |
+
input_channels (int): Number of input features channels.
|
722 |
+
dim (int): Hidden dimension of the model.
|
723 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
724 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
725 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
726 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
727 |
+
None means non-conditional model. Defaults to None.
|
728 |
+
"""
|
729 |
+
|
730 |
+
def __init__(
|
731 |
+
self,
|
732 |
+
input_channels: int,
|
733 |
+
dim: int,
|
734 |
+
intermediate_dim: int,
|
735 |
+
num_layers: int,
|
736 |
+
layer_scale_init_value: Optional[float] = None,
|
737 |
+
adanorm_num_embeddings: Optional[int] = None,
|
738 |
+
):
|
739 |
+
super().__init__()
|
740 |
+
self.input_channels = input_channels
|
741 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
742 |
+
self.adanorm = adanorm_num_embeddings is not None
|
743 |
+
if adanorm_num_embeddings:
|
744 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
745 |
+
else:
|
746 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
747 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
748 |
+
self.convnext = nn.ModuleList(
|
749 |
+
[
|
750 |
+
ConvNeXtBlock(
|
751 |
+
dim=dim,
|
752 |
+
intermediate_dim=intermediate_dim,
|
753 |
+
layer_scale_init_value=layer_scale_init_value,
|
754 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
755 |
+
)
|
756 |
+
for _ in range(num_layers)
|
757 |
+
]
|
758 |
+
)
|
759 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
760 |
+
self.apply(self._init_weights)
|
761 |
+
|
762 |
+
def _init_weights(self, m):
|
763 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
764 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
765 |
+
nn.init.constant_(m.bias, 0)
|
766 |
+
|
767 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
768 |
+
bandwidth_id = kwargs.get("bandwidth_id", None)
|
769 |
+
x = self.embed(x)
|
770 |
+
if self.adanorm:
|
771 |
+
assert bandwidth_id is not None
|
772 |
+
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
773 |
+
else:
|
774 |
+
x = self.norm(x.transpose(1, 2))
|
775 |
+
x = x.transpose(1, 2)
|
776 |
+
for conv_block in self.convnext:
|
777 |
+
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
778 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
779 |
+
return x
|
780 |
+
|
781 |
+
|
782 |
+
class VocosResNetBackbone(Backbone):
|
783 |
+
"""
|
784 |
+
Vocos backbone module built with ResBlocks.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
input_channels (int): Number of input features channels.
|
788 |
+
dim (int): Hidden dimension of the model.
|
789 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
790 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
791 |
+
"""
|
792 |
+
|
793 |
+
def __init__(
|
794 |
+
self,
|
795 |
+
input_channels,
|
796 |
+
dim,
|
797 |
+
num_blocks,
|
798 |
+
layer_scale_init_value=None,
|
799 |
+
):
|
800 |
+
super().__init__()
|
801 |
+
self.input_channels = input_channels
|
802 |
+
self.embed = weight_norm(
|
803 |
+
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
|
804 |
+
)
|
805 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
806 |
+
self.resnet = nn.Sequential(
|
807 |
+
*[
|
808 |
+
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
|
809 |
+
for _ in range(num_blocks)
|
810 |
+
]
|
811 |
+
)
|
812 |
+
|
813 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
814 |
+
x = self.embed(x)
|
815 |
+
x = self.resnet(x)
|
816 |
+
x = x.transpose(1, 2)
|
817 |
+
return x
|
818 |
+
|
819 |
+
|
820 |
+
class Vocos(nn.Module):
|
821 |
+
def __init__(
|
822 |
+
self,
|
823 |
+
input_channels: int = 256,
|
824 |
+
dim: int = 384,
|
825 |
+
intermediate_dim: int = 1152,
|
826 |
+
num_layers: int = 8,
|
827 |
+
n_fft: int = 800,
|
828 |
+
hop_size: int = 200,
|
829 |
+
padding: str = "same",
|
830 |
+
adanorm_num_embeddings=None,
|
831 |
+
cfg=None,
|
832 |
+
):
|
833 |
+
super().__init__()
|
834 |
+
|
835 |
+
input_channels = (
|
836 |
+
cfg.input_channels
|
837 |
+
if cfg is not None and hasattr(cfg, "input_channels")
|
838 |
+
else input_channels
|
839 |
+
)
|
840 |
+
dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
|
841 |
+
intermediate_dim = (
|
842 |
+
cfg.intermediate_dim
|
843 |
+
if cfg is not None and hasattr(cfg, "intermediate_dim")
|
844 |
+
else intermediate_dim
|
845 |
+
)
|
846 |
+
num_layers = (
|
847 |
+
cfg.num_layers
|
848 |
+
if cfg is not None and hasattr(cfg, "num_layers")
|
849 |
+
else num_layers
|
850 |
+
)
|
851 |
+
adanorm_num_embeddings = (
|
852 |
+
cfg.adanorm_num_embeddings
|
853 |
+
if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
|
854 |
+
else adanorm_num_embeddings
|
855 |
+
)
|
856 |
+
n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
|
857 |
+
hop_size = (
|
858 |
+
cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
|
859 |
+
)
|
860 |
+
padding = (
|
861 |
+
cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
|
862 |
+
)
|
863 |
+
|
864 |
+
self.backbone = VocosBackbone(
|
865 |
+
input_channels=input_channels,
|
866 |
+
dim=dim,
|
867 |
+
intermediate_dim=intermediate_dim,
|
868 |
+
num_layers=num_layers,
|
869 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
870 |
+
)
|
871 |
+
self.head = ISTFTHead(dim, n_fft, hop_size, padding)
|
872 |
+
|
873 |
+
def forward(self, x):
|
874 |
+
x = self.backbone(x)
|
875 |
+
x = self.head(x)
|
876 |
+
|
877 |
+
return x[:, None, :]
|
878 |
+
|
879 |
+
|
880 |
+
# if __name__ == "__main__":
|
881 |
+
# vocos_model = Vocos(
|
882 |
+
# input_channels=128,
|
883 |
+
# dim=1024,
|
884 |
+
# intermediate_dim=4096,
|
885 |
+
# adanorm_num_embeddings=None,
|
886 |
+
# num_layers=40,
|
887 |
+
# n_fft=1920,
|
888 |
+
# hop_size=480,
|
889 |
+
# padding="same"
|
890 |
+
# )
|
891 |
+
# mel_model = MelSpectrogram(
|
892 |
+
# 1920,
|
893 |
+
# 128,
|
894 |
+
# 24000,
|
895 |
+
# 480,
|
896 |
+
# 1920,
|
897 |
+
# 0,
|
898 |
+
# 12000
|
899 |
+
# )
|
900 |
+
# print(sum(p.numel() for p in vocos_model.parameters())/1e6)
|
901 |
+
|
902 |
+
# speech = librosa.load("/mnt/bn/yuacnwang-speech/dataset/yuanshen/ganyu_en.wav", sr=24000)[0][:36000]
|
903 |
+
# speech = torch.tensor(speech).unsqueeze(0)
|
904 |
+
# print(speech.shape)
|
905 |
+
# mel_feat = mel_model(speech)
|
906 |
+
# print(mel_feat.shape)
|
907 |
+
|
908 |
+
# rec_speech = vocos_model(mel_feat)
|
909 |
+
# print(rec_speech.shape)
|
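As a quick sanity check for the Vocos module added above, here is a minimal usage sketch (not part of the commit). It only assumes the class defaults defined in vocos.py; the random input and printed shape are illustrative.

# Sketch: run a random mel-style feature through Vocos (assumes vocos.py defaults).
import torch
from models.codec.amphion_codec.vocos import Vocos

vocos = Vocos(input_channels=256, dim=384, intermediate_dim=1152,
              num_layers=8, n_fft=800, hop_size=200, padding="same")
mel = torch.randn(1, 256, 100)      # (B, input_channels, num_frames)
with torch.no_grad():
    wav = vocos(mel)                # (B, 1, ~num_frames * hop_size) waveform
print(wav.shape)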
models/codec/melvqgan/__pycache__/melspec.cpython-310.pyc
ADDED
Binary file (2.78 kB).
models/codec/melvqgan/melspec.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
import torch
|
2 |
+
import pyworld as pw
|
3 |
+
import numpy as np
|
4 |
+
import soundfile as sf
|
5 |
+
import os
|
6 |
+
from torchaudio.functional import pitch_shift
|
7 |
+
import librosa
|
8 |
+
from librosa.filters import mel as librosa_mel_fn
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
15 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
16 |
+
|
17 |
+
|
18 |
+
def dynamic_range_decompression(x, C=1):
|
19 |
+
return np.exp(x) / C
|
20 |
+
|
21 |
+
|
22 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
23 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
24 |
+
|
25 |
+
|
26 |
+
def dynamic_range_decompression_torch(x, C=1):
|
27 |
+
return torch.exp(x) / C
|
28 |
+
|
29 |
+
|
30 |
+
def spectral_normalize_torch(magnitudes):
|
31 |
+
output = dynamic_range_compression_torch(magnitudes)
|
32 |
+
return output
|
33 |
+
|
34 |
+
|
35 |
+
def spectral_de_normalize_torch(magnitudes):
|
36 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
37 |
+
return output
|
38 |
+
|
39 |
+
|
40 |
+
class MelSpectrogram(nn.Module):
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
n_fft,
|
44 |
+
num_mels,
|
45 |
+
sampling_rate,
|
46 |
+
hop_size,
|
47 |
+
win_size,
|
48 |
+
fmin,
|
49 |
+
fmax,
|
50 |
+
center=False,
|
51 |
+
):
|
52 |
+
super(MelSpectrogram, self).__init__()
|
53 |
+
self.n_fft = n_fft
|
54 |
+
self.hop_size = hop_size
|
55 |
+
self.win_size = win_size
|
56 |
+
self.sampling_rate = sampling_rate
|
57 |
+
self.num_mels = num_mels
|
58 |
+
self.fmin = fmin
|
59 |
+
self.fmax = fmax
|
60 |
+
self.center = center
|
61 |
+
|
62 |
+
mel_basis = {}
|
63 |
+
hann_window = {}
|
64 |
+
|
65 |
+
mel = librosa_mel_fn(
|
66 |
+
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
67 |
+
)
|
68 |
+
mel_basis = torch.from_numpy(mel).float()
|
69 |
+
hann_window = torch.hann_window(win_size)
|
70 |
+
|
71 |
+
self.register_buffer("mel_basis", mel_basis)
|
72 |
+
self.register_buffer("hann_window", hann_window)
|
73 |
+
|
74 |
+
def forward(self, y):
|
75 |
+
y = torch.nn.functional.pad(
|
76 |
+
y.unsqueeze(1),
|
77 |
+
(
|
78 |
+
int((self.n_fft - self.hop_size) / 2),
|
79 |
+
int((self.n_fft - self.hop_size) / 2),
|
80 |
+
),
|
81 |
+
mode="reflect",
|
82 |
+
)
|
83 |
+
y = y.squeeze(1)
|
84 |
+
spec = torch.stft(
|
85 |
+
y,
|
86 |
+
self.n_fft,
|
87 |
+
hop_length=self.hop_size,
|
88 |
+
win_length=self.win_size,
|
89 |
+
window=self.hann_window,
|
90 |
+
center=self.center,
|
91 |
+
pad_mode="reflect",
|
92 |
+
normalized=False,
|
93 |
+
onesided=True,
|
94 |
+
return_complex=True,
|
95 |
+
)
|
96 |
+
spec = torch.view_as_real(spec)
|
97 |
+
|
98 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
99 |
+
|
100 |
+
spec = torch.matmul(self.mel_basis, spec)
|
101 |
+
spec = spectral_normalize_torch(spec)
|
102 |
+
|
103 |
+
return spec
|
104 |
+
|
105 |
+
|
106 |
+
# if __name__ == "__main__":
|
107 |
+
# mel_model = MelSpectrogram(
|
108 |
+
# sampling_rate=24000,
|
109 |
+
# num_mels=128,
|
110 |
+
# hop_size=480,
|
111 |
+
# n_fft=1920,
|
112 |
+
# win_size=1920,
|
113 |
+
# fmin=0,
|
114 |
+
# fmax=12000,
|
115 |
+
# )
|
116 |
+
# mel_model = mel_model.to("cuda:0")
|
117 |
+
|
118 |
+
# x_len = 0
|
119 |
+
# var = 0
|
120 |
+
# mean = 0
|
121 |
+
|
122 |
+
|
123 |
+
# from utils.util import load_config
|
124 |
+
# cfg = load_config("/storage/wyc/SpeechGeneration/egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json")
|
125 |
+
# print(cfg)
|
126 |
+
# dataset = SoundStormDataset(AK, SK, bucket_name, cfg=cfg)
|
127 |
+
# print(dataset.__getitem__(0))
|
128 |
+
|
129 |
+
# idx_list = list(range(0, len(dataset)))
|
130 |
+
# np.random.shuffle(idx_list)
|
131 |
+
# for i in tqdm.tqdm(range(10000)):
|
132 |
+
# idx = idx_list[i]
|
133 |
+
# # data_path = dataset[idx]
|
134 |
+
# # speech, _ = librosa.load(data_path, sr=24000)
|
135 |
+
# speech = dataset[idx]["speech"]
|
136 |
+
# speech = torch.tensor(speech).unsqueeze(0).to("cuda")
|
137 |
+
# mel = mel_model(speech)
|
138 |
+
# temp_len = mel.shape[-1]
|
139 |
+
# temp_mean = mel.mean()
|
140 |
+
# temp_var = mel.var()
|
141 |
+
|
142 |
+
# new_mean = (mean * x_len + temp_mean * temp_len) / (x_len + temp_len)
|
143 |
+
# new_var = (var * (x_len - 1) + temp_var * (temp_len - 1) + x_len * (new_mean - mean)**2 + temp_len * (new_mean - temp_mean)**2) / (x_len + temp_len -1)
|
144 |
+
|
145 |
+
# x_len += temp_len
|
146 |
+
# mean = new_mean
|
147 |
+
# var = new_var
|
148 |
+
|
149 |
+
# if i % 100 == 0:
|
150 |
+
# print(mean, var)
|
151 |
+
|
152 |
+
# print(mean)
|
153 |
+
# print(var)
|
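The MelSpectrogram module above pads, runs an STFT, applies the mel filterbank, and log-compresses the result. A minimal sketch of feature extraction, using the 24 kHz settings from the commented-out example (the one-second random waveform is illustrative):

# Sketch: extract log-mel features with the MelSpectrogram defined above.
import torch
from models.codec.melvqgan.melspec import MelSpectrogram

mel_model = MelSpectrogram(
    n_fft=1920, num_mels=128, sampling_rate=24000,
    hop_size=480, win_size=1920, fmin=0, fmax=12000,
)
wav = torch.randn(1, 24000)         # (B, num_samples): 1 s of audio at 24 kHz
mel = mel_model(wav)                # (B, num_mels, num_frames)
print(mel.shape)                    # -> torch.Size([1, 128, 50])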
models/tts/llm_tts/__pycache__/chat_template.cpython-310.pyc
ADDED
Binary file (2.36 kB).
models/tts/llm_tts/__pycache__/inference_llm_tts.cpython-310.pyc
ADDED
Binary file (6.95 kB).
models/tts/llm_tts/__pycache__/inference_mgm_tts.cpython-310.pyc
ADDED
Binary file (8.31 kB).
models/tts/llm_tts/__pycache__/llama_nar_prefix.cpython-310.pyc
ADDED
Binary file (12.8 kB).
models/tts/llm_tts/__pycache__/mgm.cpython-310.pyc
ADDED
Binary file (8.05 kB).
models/tts/llm_tts/chat_template.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
def format_chat_prompt_phi3(messages, add_assistant_token=True):
|
2 |
+
"""
|
3 |
+
Convert the messages list into the phi-3 chat template format.
|
4 |
+
|
5 |
+
Args:
|
6 |
+
messages: A list of messages containing role and content.
|
7 |
+
|
8 |
+
Returns:
|
9 |
+
str: The formatted prompt string.
|
10 |
+
"""
|
11 |
+
prompt = ""
|
12 |
+
for msg in messages:
|
13 |
+
role = msg["role"]
|
14 |
+
content = msg["content"]
|
15 |
+
# Add corresponding tags for system and user messages
|
16 |
+
if role in ["system", "user"]:
|
17 |
+
prompt += f"<|{role}|>\n{content}<|end|>\n"
|
18 |
+
# For assistant messages, add only the start tag if it's the last one
|
19 |
+
elif role == "assistant" and msg != messages[-1]:
|
20 |
+
prompt += f"<|{role}|>\n{content}<|end|>\n"
|
21 |
+
elif role == "assistant" and msg == messages[-1]:
|
22 |
+
prompt += f"<|{role}|>\n{content}"
|
23 |
+
|
24 |
+
# If the last message is not from the assistant, add the assistant tag
|
25 |
+
if messages[-1]["role"] != "assistant" and add_assistant_token:
|
26 |
+
prompt += "<|assistant|>"
|
27 |
+
return prompt
|
28 |
+
|
29 |
+
|
30 |
+
def format_chat_prompt_qwen2(messages, add_assistant_token=True):
|
31 |
+
"""
|
32 |
+
Custom function to format chat prompts without tool-related logic.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
messages: A list of messages containing role and content.
|
36 |
+
add_generation_prompt: Boolean to add a generation prompt at the end.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
str: The formatted prompt string.
|
40 |
+
"""
|
41 |
+
prompt = ""
|
42 |
+
|
43 |
+
if messages[0]["role"] == "system":
|
44 |
+
prompt += f"<|im_start|>system\n{messages[0]['content']}<|im_end|>\n"
|
45 |
+
else:
|
46 |
+
prompt += "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
|
47 |
+
|
48 |
+
for message in messages:
|
49 |
+
if (
|
50 |
+
(message["role"] == "user")
|
51 |
+
or (message["role"] == "system" and not message == messages[0])
|
52 |
+
or (message["role"] == "assistant" and not message.get("tool_calls"))
|
53 |
+
):
|
54 |
+
prompt += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
|
55 |
+
elif message["role"] == "assistant":
|
56 |
+
prompt += f"<|im_start|>{message['role']}"
|
57 |
+
if message.get("content"):
|
58 |
+
prompt += f"\n{message['content']}"
|
59 |
+
prompt += "<|im_end|>\n"
|
60 |
+
|
61 |
+
if add_assistant_token:
|
62 |
+
prompt += "<|im_start|>assistant\n"
|
63 |
+
|
64 |
+
return prompt
|
65 |
+
|
66 |
+
|
67 |
+
def gen_chat_prompt_for_tts(text, model_name="phi-3", caption=None):
|
68 |
+
if caption is None:
|
69 |
+
template = [
|
70 |
+
{
|
71 |
+
"role": "system",
|
72 |
+
"content": "You are a powerful AI assistant for speech understanding and generation.",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"role": "user",
|
76 |
+
"content": f"Please speak the following text out loud: {text}",
|
77 |
+
},
|
78 |
+
]
|
79 |
+
else:
|
80 |
+
template = [
|
81 |
+
{
|
82 |
+
"role": "system",
|
83 |
+
"content": "You are a powerful AI assistant for speech understanding and generation.",
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"role": "user",
|
87 |
+
"content": f"Please follow the caption: <|start_of_caption|>{caption}<|end_of_caption|> and speak the following text out loud: {text}",
|
88 |
+
},
|
89 |
+
]
|
90 |
+
|
91 |
+
if model_name == "phi-3":
|
92 |
+
return format_chat_prompt_phi3(template)
|
93 |
+
elif model_name == "qwen2":
|
94 |
+
return format_chat_prompt_qwen2(template)
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Unsupported model: {model_name}")
|
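For reference, tracing format_chat_prompt_qwen2 above, gen_chat_prompt_for_tts produces the following prompt for a plain TTS request (illustration only; the text is a placeholder):

# Sketch: the qwen2-style prompt produced by gen_chat_prompt_for_tts.
from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts

prompt = gen_chat_prompt_for_tts("Hello there.", model_name="qwen2")
print(prompt)
# <|im_start|>system
# You are a powerful AI assistant for speech understanding and generation.<|im_end|>
# <|im_start|>user
# Please speak the following text out loud: Hello there.<|im_end|>
# <|im_start|>assistant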
models/tts/llm_tts/inference_llm_tts.py
ADDED
@@ -0,0 +1,265 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import Optional
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
import os
|
6 |
+
from huggingface_hub import snapshot_download
|
7 |
+
|
8 |
+
from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline
|
9 |
+
from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts
|
10 |
+
|
11 |
+
|
12 |
+
class TTSInferencePipeline(nn.Module):
|
13 |
+
"""
|
14 |
+
TTS inference pipeline that integrates TaDiCodec and LLM models
|
15 |
+
Uses standard LLM for autoregressive generation
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
tadicodec_path: str,
|
21 |
+
llm_path: str,
|
22 |
+
device: torch.device,
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
self.device = device
|
26 |
+
self.llm_path = llm_path
|
27 |
+
|
28 |
+
# Load TaDiCodec pipeline
|
29 |
+
self.tadicodec = TaDiCodecPipline.from_pretrained(
|
30 |
+
ckpt_dir=tadicodec_path, device=device
|
31 |
+
)
|
32 |
+
|
33 |
+
# Load LLM directly from pretrained
|
34 |
+
# Try to use flash attention 2, fallback to default if not available
|
35 |
+
try:
|
36 |
+
self.llm = AutoModelForCausalLM.from_pretrained(
|
37 |
+
llm_path,
|
38 |
+
device_map=device,
|
39 |
+
torch_dtype="auto",
|
40 |
+
trust_remote_code=True,
|
41 |
+
attn_implementation="flash_attention_2",
|
42 |
+
)
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Flash attention 2 not available, using default attention: {e}")
|
45 |
+
self.llm = AutoModelForCausalLM.from_pretrained(
|
46 |
+
llm_path,
|
47 |
+
device_map=device,
|
48 |
+
torch_dtype="auto",
|
49 |
+
trust_remote_code=True,
|
50 |
+
)
|
51 |
+
|
52 |
+
# Load tokenizer directly from pretrained
|
53 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
54 |
+
llm_path,
|
55 |
+
trust_remote_code=True,
|
56 |
+
)
|
57 |
+
|
58 |
+
def tensor_to_audio_string(self, tensor):
|
59 |
+
"""Convert tensor to audio string format"""
|
60 |
+
if isinstance(tensor, list) and isinstance(tensor[0], list):
|
61 |
+
values = tensor[0]
|
62 |
+
else:
|
63 |
+
values = tensor[0].tolist() if hasattr(tensor, "tolist") else tensor[0]
|
64 |
+
|
65 |
+
result = "<|start_of_audio|>"
|
66 |
+
for value in values:
|
67 |
+
result += f"<|audio_{value}|>"
|
68 |
+
return result
|
69 |
+
|
70 |
+
def extract_audio_ids(self, text):
|
71 |
+
"""Extract audio IDs from string containing audio tokens"""
|
72 |
+
import re
|
73 |
+
|
74 |
+
pattern = r"<\|audio_(\d+)\|>"
|
75 |
+
audio_ids = re.findall(pattern, text)
|
76 |
+
return [int(id) for id in audio_ids]
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def from_pretrained(
|
80 |
+
cls,
|
81 |
+
model_id: str = None,
|
82 |
+
tadicodec_path: str = None,
|
83 |
+
llm_path: str = None,
|
84 |
+
device: Optional[torch.device] = None,
|
85 |
+
auto_download: bool = True,
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
Create pipeline from pretrained models
|
89 |
+
|
90 |
+
Args:
|
91 |
+
model_id: Hugging Face model ID for the LLM (e.g., "amphion/TaDiCodec-TTS-AR-Qwen2.5-3B")
|
92 |
+
tadicodec_path: Path to TaDiCodec model or Hugging Face model ID (defaults to "amphion/TaDiCodec")
|
93 |
+
llm_path: Path to LLM model or Hugging Face model ID (overrides model_id if provided)
|
94 |
+
device: Device to run on
|
95 |
+
auto_download: Whether to automatically download models from Hugging Face if not found locally
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
TTSInferencePipeline instance
|
99 |
+
"""
|
100 |
+
resolved_device = (
|
101 |
+
device
|
102 |
+
if device is not None
|
103 |
+
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
104 |
+
)
|
105 |
+
|
106 |
+
# Set default paths if not provided
|
107 |
+
if tadicodec_path is None:
|
108 |
+
tadicodec_path = "amphion/TaDiCodec"
|
109 |
+
|
110 |
+
if llm_path is None:
|
111 |
+
if model_id is not None:
|
112 |
+
llm_path = model_id
|
113 |
+
else:
|
114 |
+
llm_path = "./ckpt/TaDiCodec-TTS-AR-Qwen2.5-0.5B"
|
115 |
+
|
116 |
+
# Handle TaDiCodec path
|
117 |
+
resolved_tadicodec_path = cls._resolve_model_path(
|
118 |
+
tadicodec_path, auto_download=auto_download, model_type="tadicodec"
|
119 |
+
)
|
120 |
+
|
121 |
+
# Handle LLM path
|
122 |
+
resolved_llm_path = cls._resolve_model_path(
|
123 |
+
llm_path, auto_download=auto_download, model_type="llm"
|
124 |
+
)
|
125 |
+
|
126 |
+
return cls(
|
127 |
+
tadicodec_path=resolved_tadicodec_path,
|
128 |
+
llm_path=resolved_llm_path,
|
129 |
+
device=resolved_device,
|
130 |
+
)
|
131 |
+
|
132 |
+
@staticmethod
|
133 |
+
def _resolve_model_path(
|
134 |
+
model_path: str, auto_download: bool = True, model_type: str = "llm"
|
135 |
+
) -> str:
|
136 |
+
"""
|
137 |
+
Resolve model path, downloading from Hugging Face if necessary
|
138 |
+
|
139 |
+
Args:
|
140 |
+
model_path: Local path or Hugging Face model ID
|
141 |
+
auto_download: Whether to auto-download from HF
|
142 |
+
model_type: Type of model ("llm" or "tadicodec")
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
Resolved local path
|
146 |
+
"""
|
147 |
+
# If it's already a local path and exists, return as is
|
148 |
+
if os.path.exists(model_path):
|
149 |
+
return model_path
|
150 |
+
|
151 |
+
# If it looks like a Hugging Face model ID (contains '/')
|
152 |
+
if "/" in model_path and auto_download:
|
153 |
+
print(f"Downloading {model_type} model from Hugging Face: {model_path}")
|
154 |
+
try:
|
155 |
+
# Download to cache directory
|
156 |
+
cache_dir = os.path.join(
|
157 |
+
os.path.expanduser("~"), ".cache", "huggingface", "hub"
|
158 |
+
)
|
159 |
+
downloaded_path = snapshot_download(
|
160 |
+
repo_id=model_path,
|
161 |
+
cache_dir=cache_dir,
|
162 |
+
local_dir_use_symlinks=False,
|
163 |
+
)
|
164 |
+
print(
|
165 |
+
f"Successfully downloaded {model_type} model to: {downloaded_path}"
|
166 |
+
)
|
167 |
+
return downloaded_path
|
168 |
+
except Exception as e:
|
169 |
+
print(f"Failed to download {model_type} model from Hugging Face: {e}")
|
170 |
+
raise ValueError(
|
171 |
+
f"Could not download {model_type} model from {model_path}"
|
172 |
+
)
|
173 |
+
|
174 |
+
# If it's a local path that doesn't exist
|
175 |
+
if not os.path.exists(model_path):
|
176 |
+
if auto_download:
|
177 |
+
raise ValueError(
|
178 |
+
f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
|
179 |
+
)
|
180 |
+
else:
|
181 |
+
raise FileNotFoundError(f"Model path does not exist: {model_path}")
|
182 |
+
|
183 |
+
return model_path
|
184 |
+
|
185 |
+
@torch.no_grad()
|
186 |
+
def __call__(
|
187 |
+
self,
|
188 |
+
text: str,
|
189 |
+
prompt_text: Optional[str] = None,
|
190 |
+
prompt_speech_path: Optional[str] = None,
|
191 |
+
top_k: int = 50,
|
192 |
+
top_p: float = 0.98,
|
193 |
+
temperature: float = 1.0,
|
194 |
+
n_timesteps: int = 25,
|
195 |
+
return_code: bool = False,
|
196 |
+
):
|
197 |
+
"""
|
198 |
+
Perform TTS inference
|
199 |
+
|
200 |
+
Args:
|
201 |
+
text: Target text to synthesize
|
202 |
+
prompt_text: Prompt text for conditioning
|
203 |
+
prompt_speech_path: Path to prompt audio file
|
204 |
+
top_k: Top-k sampling parameter
|
205 |
+
top_p: Top-p sampling parameter
|
206 |
+
temperature: Temperature for sampling
|
207 |
+
n_timesteps: Number of diffusion timesteps
|
208 |
+
return_code: Whether to return audio codes instead of audio
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
Generated audio array or audio codes
|
212 |
+
"""
|
213 |
+
# Get prompt audio codes
|
214 |
+
if prompt_speech_path:
|
215 |
+
prompt_speech_code = self.tadicodec(
|
216 |
+
speech_path=prompt_speech_path, return_code=True, text=""
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
raise ValueError("prompt_speech_path is required")
|
220 |
+
|
221 |
+
# Use standard LLM for autoregressive generation
|
222 |
+
# TODO: add a json style chat prompt template
|
223 |
+
prompt = gen_chat_prompt_for_tts(
|
224 |
+
(prompt_text or "") + text,
|
225 |
+
"phi-3" if "Phi" in self.llm_path else "qwen2",
|
226 |
+
) + self.tensor_to_audio_string(prompt_speech_code)
|
227 |
+
|
228 |
+
input_ids = self.tokenizer.encode(prompt)
|
229 |
+
generate_ids = self.llm.generate(
|
230 |
+
input_ids=torch.tensor(input_ids).unsqueeze(0).to(self.device),
|
231 |
+
min_new_tokens=12,
|
232 |
+
max_new_tokens=400,
|
233 |
+
do_sample=True,
|
234 |
+
top_k=top_k,
|
235 |
+
top_p=top_p,
|
236 |
+
temperature=temperature,
|
237 |
+
)
|
238 |
+
|
239 |
+
output = self.tokenizer.decode(generate_ids[0], skip_special_tokens=False)
|
240 |
+
|
241 |
+
combine_speech_code = self.extract_audio_ids(output)
|
242 |
+
indices = torch.tensor(combine_speech_code).unsqueeze(0).long().to(self.device)
|
243 |
+
|
244 |
+
if return_code:
|
245 |
+
return indices
|
246 |
+
|
247 |
+
# Decode audio
|
248 |
+
text_token_ids = self.tadicodec.tokenize_text(text, prompt_text)
|
249 |
+
prompt_mel = self.tadicodec.extract_mel_feature(prompt_speech_path)
|
250 |
+
|
251 |
+
rec_mel = self.tadicodec.decode(
|
252 |
+
indices=indices,
|
253 |
+
text_token_ids=text_token_ids,
|
254 |
+
prompt_mel=prompt_mel,
|
255 |
+
n_timesteps=n_timesteps,
|
256 |
+
)
|
257 |
+
|
258 |
+
rec_audio = (
|
259 |
+
self.tadicodec.vocoder_model(rec_mel.transpose(1, 2))
|
260 |
+
.detach()
|
261 |
+
.cpu()
|
262 |
+
.numpy()[0][0]
|
263 |
+
)
|
264 |
+
|
265 |
+
return rec_audio
|
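A usage sketch for the autoregressive pipeline above (not part of the commit). The model ID is the one cited in the from_pretrained docstring; the prompt text and audio path reuse the example assets referenced in the MGM script below, and the 24 kHz output rate matches that script.

# Sketch: zero-shot TTS with TTSInferencePipeline (assumes the checkpoints are reachable).
import torch
import soundfile as sf
from models.tts.llm_tts.inference_llm_tts import TTSInferencePipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = TTSInferencePipeline.from_pretrained(
    model_id="amphion/TaDiCodec-TTS-AR-Qwen2.5-3B",
    device=device,
)
audio = pipeline(
    text="Hello, this is a test of the TTS pipeline.",
    prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
    prompt_speech_path="./use_examples/test_audio/trump_0.wav",
)
sf.write("tts_output.wav", audio, 24000)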
models/tts/llm_tts/inference_mgm_tts.py
ADDED
@@ -0,0 +1,338 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import librosa
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from typing import Optional
|
10 |
+
import safetensors
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
from utils.util import load_config
|
13 |
+
|
14 |
+
from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline
|
15 |
+
from models.tts.llm_tts.mgm import MGMT2S
|
16 |
+
|
17 |
+
from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts
|
18 |
+
|
19 |
+
|
20 |
+
class MGMInferencePipeline(nn.Module):
|
21 |
+
"""
|
22 |
+
MGM TTS inference pipeline that integrates TaDiCodec and MGM models
|
23 |
+
Uses diffusion-based generation with mask-guided modeling
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
tadicodec_path: str,
|
29 |
+
mgm_path: str,
|
30 |
+
device: torch.device,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
self.device = device
|
34 |
+
self.mgm_path = mgm_path
|
35 |
+
|
36 |
+
# Load TaDiCodec pipeline
|
37 |
+
self.tadicodec = TaDiCodecPipline.from_pretrained(
|
38 |
+
ckpt_dir=tadicodec_path, device=device
|
39 |
+
)
|
40 |
+
|
41 |
+
# Load tokenizer directly from pretrained
|
42 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
43 |
+
mgm_path,
|
44 |
+
trust_remote_code=True,
|
45 |
+
)
|
46 |
+
|
47 |
+
config_path = os.path.join(mgm_path, "config.json")
|
48 |
+
if not os.path.exists(config_path):
|
49 |
+
raise FileNotFoundError(f"Config file not found at {config_path}")
|
50 |
+
|
51 |
+
self.cfg = load_config(config_path)
|
52 |
+
|
53 |
+
# Extract MGM config from the loaded config
|
54 |
+
mgm_config = self.cfg.model.mgmt2s
|
55 |
+
if not mgm_config:
|
56 |
+
raise ValueError("MGM config not found in config.json")
|
57 |
+
|
58 |
+
# Load MGM model with config - using the same pattern as llm_infer_eval.py
|
59 |
+
self.mgm = MGMT2S(
|
60 |
+
hidden_size=mgm_config.hidden_size,
|
61 |
+
num_layers=mgm_config.num_layers,
|
62 |
+
num_heads=mgm_config.num_heads,
|
63 |
+
cfg_scale=mgm_config.cfg_scale,
|
64 |
+
cond_codebook_size=mgm_config.cond_codebook_size,
|
65 |
+
cond_dim=mgm_config.cond_dim,
|
66 |
+
phone_vocab_size=mgm_config.phone_vocab_size,
|
67 |
+
)
|
68 |
+
|
69 |
+
# Load model weights
|
70 |
+
model_path = os.path.join(mgm_path, "model.safetensors")
|
71 |
+
|
72 |
+
if os.path.exists(model_path):
|
73 |
+
safetensors.torch.load_model(self.mgm, model_path, strict=True)
|
74 |
+
else:
|
75 |
+
# Try loading from the directory directly
|
76 |
+
safetensors.torch.load_model(self.mgm, mgm_path, strict=True)
|
77 |
+
|
78 |
+
self.mgm.to(device)
|
79 |
+
self.mgm.eval()
|
80 |
+
|
81 |
+
def tensor_to_audio_string(self, tensor):
|
82 |
+
"""Convert tensor to audio string format"""
|
83 |
+
if isinstance(tensor, list) and isinstance(tensor[0], list):
|
84 |
+
values = tensor[0]
|
85 |
+
else:
|
86 |
+
values = tensor[0].tolist() if hasattr(tensor, "tolist") else tensor[0]
|
87 |
+
|
88 |
+
result = "<|start_of_audio|>"
|
89 |
+
for value in values:
|
90 |
+
result += f"<|audio_{value}|>"
|
91 |
+
return result
|
92 |
+
|
93 |
+
def extract_audio_ids(self, text):
|
94 |
+
"""Extract audio IDs from string containing audio tokens"""
|
95 |
+
import re
|
96 |
+
|
97 |
+
pattern = r"<\|audio_(\d+)\|>"
|
98 |
+
audio_ids = re.findall(pattern, text)
|
99 |
+
return [int(id) for id in audio_ids]
|
100 |
+
|
101 |
+
@classmethod
|
102 |
+
def from_pretrained(
|
103 |
+
cls,
|
104 |
+
model_id: str = None,
|
105 |
+
tadicodec_path: str = None,
|
106 |
+
mgm_path: str = None,
|
107 |
+
device: Optional[torch.device] = None,
|
108 |
+
auto_download: bool = True,
|
109 |
+
):
|
110 |
+
"""
|
111 |
+
Create pipeline from pretrained models
|
112 |
+
|
113 |
+
Args:
|
114 |
+
model_id: Hugging Face model ID for the MGM model (e.g., "amphion/TaDiCodec-TTS-MGM")
|
115 |
+
tadicodec_path: Path to TaDiCodec model or Hugging Face model ID (defaults to "amphion/TaDiCodec")
|
116 |
+
mgm_path: Path to MGM model directory or Hugging Face model ID (overrides model_id if provided)
|
117 |
+
device: Device to run on
|
118 |
+
auto_download: Whether to automatically download models from Hugging Face if not found locally
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
MGMInferencePipeline instance
|
122 |
+
"""
|
123 |
+
resolved_device = (
|
124 |
+
device
|
125 |
+
if device is not None
|
126 |
+
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
127 |
+
)
|
128 |
+
|
129 |
+
# Set default paths if not provided
|
130 |
+
if tadicodec_path is None:
|
131 |
+
tadicodec_path = "amphion/TaDiCodec"
|
132 |
+
|
133 |
+
if mgm_path is None:
|
134 |
+
if model_id is not None:
|
135 |
+
mgm_path = model_id
|
136 |
+
else:
|
137 |
+
mgm_path = "./ckpt/TaDiCodec-TTS-MGM"
|
138 |
+
|
139 |
+
# Handle TaDiCodec path
|
140 |
+
resolved_tadicodec_path = cls._resolve_model_path(
|
141 |
+
tadicodec_path, auto_download=auto_download, model_type="tadicodec"
|
142 |
+
)
|
143 |
+
|
144 |
+
# Handle MGM path
|
145 |
+
resolved_mgm_path = cls._resolve_model_path(
|
146 |
+
mgm_path, auto_download=auto_download, model_type="mgm"
|
147 |
+
)
|
148 |
+
|
149 |
+
return cls(
|
150 |
+
tadicodec_path=resolved_tadicodec_path,
|
151 |
+
mgm_path=resolved_mgm_path,
|
152 |
+
device=resolved_device,
|
153 |
+
)
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def _resolve_model_path(
|
157 |
+
model_path: str, auto_download: bool = True, model_type: str = "mgm"
|
158 |
+
) -> str:
|
159 |
+
"""
|
160 |
+
Resolve model path, downloading from Hugging Face if necessary
|
161 |
+
|
162 |
+
Args:
|
163 |
+
model_path: Local path or Hugging Face model ID
|
164 |
+
auto_download: Whether to auto-download from HF
|
165 |
+
model_type: Type of model ("mgm" or "tadicodec")
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
Resolved local path
|
169 |
+
"""
|
170 |
+
# If it's already a local path and exists, return as is
|
171 |
+
if os.path.exists(model_path):
|
172 |
+
return model_path
|
173 |
+
|
174 |
+
# If it looks like a Hugging Face model ID (contains '/')
|
175 |
+
if "/" in model_path and auto_download:
|
176 |
+
print(f"Downloading {model_type} model from Hugging Face: {model_path}")
|
177 |
+
try:
|
178 |
+
# Download to cache directory
|
179 |
+
cache_dir = os.path.join(
|
180 |
+
os.path.expanduser("~"), ".cache", "huggingface", "hub"
|
181 |
+
)
|
182 |
+
downloaded_path = snapshot_download(
|
183 |
+
repo_id=model_path,
|
184 |
+
cache_dir=cache_dir,
|
185 |
+
local_dir_use_symlinks=False,
|
186 |
+
)
|
187 |
+
print(
|
188 |
+
f"Successfully downloaded {model_type} model to: {downloaded_path}"
|
189 |
+
)
|
190 |
+
return downloaded_path
|
191 |
+
except Exception as e:
|
192 |
+
print(f"Failed to download {model_type} model from Hugging Face: {e}")
|
193 |
+
raise ValueError(
|
194 |
+
f"Could not download {model_type} model from {model_path}"
|
195 |
+
)
|
196 |
+
|
197 |
+
# If it's a local path that doesn't exist
|
198 |
+
if not os.path.exists(model_path):
|
199 |
+
if auto_download:
|
200 |
+
raise ValueError(
|
201 |
+
f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
|
202 |
+
)
|
203 |
+
else:
|
204 |
+
raise FileNotFoundError(f"Model path does not exist: {model_path}")
|
205 |
+
|
206 |
+
return model_path
|
207 |
+
|
208 |
+
@torch.no_grad()
|
209 |
+
def __call__(
|
210 |
+
self,
|
211 |
+
text: str,
|
212 |
+
prompt_text: Optional[str] = None,
|
213 |
+
prompt_speech_path: Optional[str] = None,
|
214 |
+
n_timesteps_mgm: int = 25,
|
215 |
+
n_timesteps: int = 25,
|
216 |
+
target_len: Optional[int] = None,
|
217 |
+
return_code: bool = False,
|
218 |
+
):
|
219 |
+
"""
|
220 |
+
Perform MGM TTS inference
|
221 |
+
|
222 |
+
Args:
|
223 |
+
text: Target text to synthesize
|
224 |
+
prompt_text: Prompt text for conditioning
|
225 |
+
prompt_speech_path: Path to prompt audio file
|
226 |
+
n_timesteps_mgm: Number of diffusion timesteps for MGM
|
227 |
+
n_timesteps: Number of diffusion timesteps
|
228 |
+
target_len: Target length for audio generation
|
229 |
+
return_code: Whether to return audio codes instead of audio
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
Generated audio array or audio codes
|
233 |
+
"""
|
234 |
+
# Get prompt audio codes
|
235 |
+
if prompt_speech_path:
|
236 |
+
prompt_speech_code = self.tadicodec(
|
237 |
+
speech_path=prompt_speech_path, return_code=True, text=""
|
238 |
+
)
|
239 |
+
else:
|
240 |
+
raise ValueError("prompt_speech_path is required")
|
241 |
+
|
242 |
+
# Convert prompt codes to tensor
|
243 |
+
prompt_codes = torch.tensor(prompt_speech_code).to(self.device)
|
244 |
+
prompt_len = prompt_codes.shape[1]
|
245 |
+
|
246 |
+
# Tokenize text for phone conditioning
|
247 |
+
input_text = gen_chat_prompt_for_tts(
|
248 |
+
prompt_text + " " + text,
|
249 |
+
"phi-3" if "phi" in self.cfg.preprocess.tokenizer_path else "qwen2",
|
250 |
+
)
|
251 |
+
|
252 |
+
##### debug #####
|
253 |
+
print("input_text: ", input_text)
|
254 |
+
##### debug #####
|
255 |
+
|
256 |
+
text_token_ids = self.tokenizer.encode(input_text)
|
257 |
+
text_token_ids = torch.tensor(text_token_ids).unsqueeze(0).to(self.device)
|
258 |
+
|
259 |
+
# Estimate target length based on text length
|
260 |
+
frame_rate = getattr(self.cfg.preprocess, "frame_rate", 6.25)
|
261 |
+
|
262 |
+
if target_len is None:
|
263 |
+
# If no target_len, estimate based on prompt speech length and text ratio
|
264 |
+
prompt_text_len = len(prompt_text.encode("utf-8"))
|
265 |
+
target_text_len = len(text.encode("utf-8"))
|
266 |
+
prompt_speech_len = librosa.get_duration(filename=prompt_speech_path)
|
267 |
+
target_speech_len = prompt_speech_len * target_text_len / prompt_text_len
|
268 |
+
target_len = int(target_speech_len * frame_rate)
|
269 |
+
else:
|
270 |
+
# If target_len is provided, use it directly
|
271 |
+
target_len = int(target_len * frame_rate)
|
272 |
+
|
273 |
+
##### debug #####
|
274 |
+
print(f"Prompt length: {prompt_len}, Target length: {target_len}")
|
275 |
+
print(f"Text: {text}")
|
276 |
+
print(f"Prompt text: {prompt_text}")
|
277 |
+
##### debug #####
|
278 |
+
|
279 |
+
# Generate audio codes using MGM reverse diffusion
|
280 |
+
generated_codes = self.mgm.reverse_diffusion(
|
281 |
+
prompt=prompt_codes,
|
282 |
+
target_len=target_len,
|
283 |
+
phone_id=text_token_ids,
|
284 |
+
n_timesteps=n_timesteps_mgm,
|
285 |
+
cfg=1.5,
|
286 |
+
rescale_cfg=0.75,
|
287 |
+
)
|
288 |
+
|
289 |
+
print(f"Generated codes shape: {generated_codes.shape}")
|
290 |
+
|
291 |
+
combine_codes = torch.cat([prompt_codes, generated_codes], dim=1)
|
292 |
+
|
293 |
+
if return_code:
|
294 |
+
return combine_codes
|
295 |
+
|
296 |
+
# Decode audio using TaDiCodec
|
297 |
+
prompt_mel = self.tadicodec.extract_mel_feature(prompt_speech_path)
|
298 |
+
|
299 |
+
text_token_ids = self.tadicodec.tokenize_text(text, prompt_text)
|
300 |
+
rec_mel = self.tadicodec.decode(
|
301 |
+
indices=combine_codes,
|
302 |
+
text_token_ids=text_token_ids,
|
303 |
+
prompt_mel=prompt_mel,
|
304 |
+
n_timesteps=n_timesteps,
|
305 |
+
)
|
306 |
+
|
307 |
+
rec_audio = (
|
308 |
+
self.tadicodec.vocoder_model(rec_mel.transpose(1, 2))
|
309 |
+
.detach()
|
310 |
+
.cpu()
|
311 |
+
.numpy()[0][0]
|
312 |
+
)
|
313 |
+
|
314 |
+
return rec_audio
|
315 |
+
|
316 |
+
|
317 |
+
# Usage example
|
318 |
+
if __name__ == "__main__":
|
319 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
320 |
+
|
321 |
+
# Create pipeline
|
322 |
+
pipeline = MGMInferencePipeline.from_pretrained(
|
323 |
+
tadicodec_path="./ckpt/TaDiCodec",
|
324 |
+
mgm_path="./ckpt/TaDiCodec-TTS-MGM",
|
325 |
+
device=device,
|
326 |
+
)
|
327 |
+
|
328 |
+
# Inference on single sample
|
329 |
+
audio = pipeline(
|
330 |
+
text="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit.",
|
331 |
+
prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
|
332 |
+
prompt_speech_path="./use_examples/test_audio/trump_0.wav",
|
333 |
+
)
|
334 |
+
|
335 |
+
# Save audio
|
336 |
+
import soundfile as sf
|
337 |
+
|
338 |
+
sf.write("./use_examples/test_audio/mgm_tts_output.wav", audio, 24000)
|
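For reference, the length estimate in __call__ above works out as follows (illustrative figures, not from the commit): with the default frame_rate of 6.25, a 4-second prompt whose transcript is 80 bytes and a 120-byte target text give target_speech_len = 4 × 120 / 80 = 6 s, hence target_len = int(6 × 6.25) = 37 frames passed to reverse_diffusion.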
models/tts/llm_tts/llama_nar_prefix.py
ADDED
@@ -0,0 +1,457 @@
1 |
+
from transformers import LlamaConfig, LlamaModel
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import math
|
6 |
+
|
7 |
+
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
8 |
+
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
|
9 |
+
|
10 |
+
|
11 |
+
class SinusoidalPosEmb(nn.Module):
|
12 |
+
"""
|
13 |
+
Sinusoidal Positional Embedding module.
|
14 |
+
|
15 |
+
This module generates sinusoidal positional embeddings for a given 1D input tensor,
|
16 |
+
which is commonly used for representing timesteps in diffusion models.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, dim: int):
|
20 |
+
"""
|
21 |
+
Initializes the SinusoidalPosEmb module.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
dim (int): The dimension of the embedding.
|
25 |
+
"""
|
26 |
+
super().__init__()
|
27 |
+
self.dim = dim
|
28 |
+
|
29 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
30 |
+
"""
|
31 |
+
Generates the positional embedding.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
x (torch.Tensor): A 1D tensor of positions (e.g., timesteps), shape `(batch_size,)`.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
torch.Tensor: The positional embeddings, shape `(batch_size, dim)`.
|
38 |
+
"""
|
39 |
+
device = x.device
|
40 |
+
half_dim = self.dim // 2
|
41 |
+
# Calculate the embedding frequencies based on the log-space formula
|
42 |
+
emb = math.log(10000) / (half_dim - 1)
|
43 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
44 |
+
# Create the embedding matrix by multiplying positions with frequencies
|
45 |
+
emb = x[:, None] * emb[None, :] * 1.0
|
46 |
+
# Concatenate sine and cosine components to form the final embedding
|
47 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
48 |
+
return emb
|
49 |
+
|
50 |
+
|
51 |
+
class LlamaAdaptiveRMSNorm(nn.Module):
|
52 |
+
"""
|
53 |
+
Adaptive Root Mean Square Layer Normalization.
|
54 |
+
|
55 |
+
This is a variant of RMSNorm where the scaling factor (weight) is adaptively
|
56 |
+
predicted from a conditional embedding, allowing for conditional modulation
|
57 |
+
of the normalized hidden states.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self, hidden_size: int = 1024, eps: float = 1e-6, dim_cond: int = 1024
|
62 |
+
):
|
63 |
+
"""
|
64 |
+
Initializes the LlamaAdaptiveRMSNorm module.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
hidden_size (int): The dimension of the hidden states to be normalized.
|
68 |
+
eps (float): A small value added to the variance for numerical stability.
|
69 |
+
dim_cond (int): The dimension of the conditional embedding.
|
70 |
+
"""
|
71 |
+
super().__init__()
|
72 |
+
# Linear layer to project the conditional embedding to the hidden size
|
73 |
+
self.to_weight = nn.Linear(dim_cond, hidden_size)
|
74 |
+
# Initialize weights to zero and bias to one for an identity transformation at the start
|
75 |
+
nn.init.zeros_(self.to_weight.weight)
|
76 |
+
nn.init.ones_(self.to_weight.bias)
|
77 |
+
self.variance_epsilon = eps
|
78 |
+
# Disable automatic Hugging Face initialization for this custom module
|
79 |
+
self._is_hf_initialized = True
|
80 |
+
|
81 |
+
def forward(
|
82 |
+
self, hidden_states: torch.Tensor, cond_embedding: torch.Tensor
|
83 |
+
) -> torch.Tensor:
|
84 |
+
"""
|
85 |
+
Applies the adaptive RMS normalization.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
hidden_states (torch.Tensor): The input tensor, shape `(batch, seq_len, hidden_size)`.
|
89 |
+
cond_embedding (torch.Tensor): The conditional embedding, shape `(batch, dim_cond)`.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
torch.Tensor: The normalized and modulated hidden states.
|
93 |
+
"""
|
94 |
+
input_dtype = hidden_states.dtype
|
95 |
+
# Calculate variance and normalize the hidden states
|
96 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
97 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
98 |
+
|
99 |
+
# Predict the scaling factor from the conditional embedding
|
100 |
+
weight = self.to_weight(cond_embedding)
|
101 |
+
# Unsqueeze if the conditional embedding is per-batch instead of per-token
|
102 |
+
if len(weight.shape) == 2:
|
103 |
+
weight = weight.unsqueeze(1)
|
104 |
+
|
105 |
+
# Apply the learned scaling factor
|
106 |
+
return (weight * hidden_states).to(input_dtype)
|
107 |
+
|
108 |
+
|
109 |
+
class LlamaNARDecoderLayer(LlamaDecoderLayer):
|
110 |
+
"""
|
111 |
+
A Non-Autoregressive (NAR) Llama Decoder Layer using adaptive layer normalization.
|
112 |
+
|
113 |
+
This class overrides the standard LlamaDecoderLayer to replace its RMSNorm
|
114 |
+
modules with LlamaAdaptiveRMSNorm, allowing it to be conditioned on an external embedding.
|
115 |
+
"""
|
116 |
+
|
117 |
+
def __init__(self, config: LlamaConfig, layer_idx: int):
|
118 |
+
"""Overrides the LlamaDecoderLayer to use adaptive layer normalization."""
|
119 |
+
super().__init__(config, layer_idx) # init attention, mlp, etc. from parent
|
120 |
+
self.layer_idx = layer_idx
|
121 |
+
# Override the standard layer norms with our adaptive versions
|
122 |
+
self.input_layernorm = LlamaAdaptiveRMSNorm(
|
123 |
+
config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
|
124 |
+
)
|
125 |
+
self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
|
126 |
+
config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
|
127 |
+
)
|
128 |
+
|
129 |
+
def forward(
|
130 |
+
self,
|
131 |
+
hidden_states: torch.Tensor,
|
132 |
+
cond_embedding: torch.Tensor,
|
133 |
+
attention_mask: Optional[torch.Tensor] = None,
|
134 |
+
position_ids: Optional[torch.LongTensor] = None,
|
135 |
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
136 |
+
output_attentions: Optional[bool] = False,
|
137 |
+
use_cache: Optional[bool] = False,
|
138 |
+
) -> Tuple[
|
139 |
+
torch.Tensor,
|
140 |
+
Optional[torch.Tensor],
|
141 |
+
Optional[Tuple[torch.Tensor, torch.Tensor]],
|
142 |
+
]:
|
143 |
+
"""
|
144 |
+
Forward pass for the NAR decoder layer, including conditional embedding.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
hidden_states (torch.Tensor): Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
148 |
+
cond_embedding (torch.Tensor): Conditional embedding for adaptive normalization.
|
149 |
+
attention_mask (Optional[torch.Tensor]): Attention mask of size `(batch, 1, tgt_len, src_len)`.
|
150 |
+
position_ids (Optional[torch.LongTensor]): Indices of positions of each input sequence tokens.
|
151 |
+
past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached past key and value projection states.
|
152 |
+
output_attentions (Optional[bool]): Whether to return the attention tensors.
|
153 |
+
use_cache (Optional[bool]): If True, past key values are returned to speed up decoding.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
Tuple containing the output hidden states, and optionally attention weights and past key/value states.
|
157 |
+
"""
|
158 |
+
residual = hidden_states
|
159 |
+
|
160 |
+
# Apply adaptive pre-attention layer norm
|
161 |
+
hidden_states = self.input_layernorm(
|
162 |
+
hidden_states, cond_embedding=cond_embedding
|
163 |
+
)
|
164 |
+
|
165 |
+
# Self Attention block
|
166 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
167 |
+
hidden_states=hidden_states,
|
168 |
+
attention_mask=attention_mask,
|
169 |
+
position_ids=position_ids,
|
170 |
+
past_key_value=past_key_value,
|
171 |
+
output_attentions=output_attentions,
|
172 |
+
use_cache=use_cache,
|
173 |
+
)
|
174 |
+
hidden_states = residual + hidden_states
|
175 |
+
|
176 |
+
# Fully Connected block
|
177 |
+
residual = hidden_states
|
178 |
+
# Apply adaptive post-attention layer norm
|
179 |
+
hidden_states = self.post_attention_layernorm(
|
180 |
+
hidden_states, cond_embedding=cond_embedding
|
181 |
+
)
|
182 |
+
hidden_states = self.mlp(hidden_states)
|
183 |
+
hidden_states = residual + hidden_states
|
184 |
+
|
185 |
+
outputs = (hidden_states,)
|
186 |
+
|
187 |
+
if output_attentions:
|
188 |
+
outputs += (self_attn_weights,)
|
189 |
+
|
190 |
+
if use_cache:
|
191 |
+
outputs += (present_key_value,)
|
192 |
+
|
193 |
+
return outputs
|
194 |
+
|
195 |
+
|
196 |
+
class DiffLlamaPrefix(LlamaModel):
|
197 |
+
"""
|
198 |
+
A Llama-based non-autoregressive transformer model for diffusion (masked generative modeling) tasks.
|
199 |
+
|
200 |
+
This model uses a Llama architecture but modifies it for non-autoregressive generation.
|
201 |
+
Key features:
|
202 |
+
1. Non-causal (fully-visible) attention mask.
|
203 |
+
2. Adaptive layer normalization conditioned on diffusion timesteps.
|
204 |
+
3. Ability to be conditioned on phoneme (text) embeddings, which are prepended as a prefix.
|
205 |
+
"""
|
206 |
+
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
hidden_size: int = 1024,
|
210 |
+
num_heads: int = 16,
|
211 |
+
num_layers: int = 16,
|
212 |
+
use_phone_cond: bool = True,
|
213 |
+
config: LlamaConfig = LlamaConfig(
|
214 |
+
vocab_size=0, hidden_size=256, num_attention_heads=1, num_hidden_layers=1
|
215 |
+
),
|
216 |
+
):
|
217 |
+
"""
|
218 |
+
Initializes the DiffLlamaPrefix model.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
hidden_size (int): The hidden dimension of the transformer.
|
222 |
+
num_heads (int): The number of attention heads.
|
223 |
+
num_layers (int): The number of transformer layers.
|
224 |
+
use_phone_cond (bool): Whether to use phoneme embeddings as a conditional prefix.
|
225 |
+
config (LlamaConfig): A LlamaConfig object. A default is provided for convenience.
|
226 |
+
"""
|
227 |
+
super().__init__(config)
|
228 |
+
|
229 |
+
self.use_phone_cond = use_phone_cond
|
230 |
+
|
231 |
+
# Create a stack of non-autoregressive Llama layers
|
232 |
+
self.layers = nn.ModuleList(
|
233 |
+
[
|
234 |
+
LlamaNARDecoderLayer(
|
235 |
+
LlamaConfig(
|
236 |
+
hidden_size=hidden_size,
|
237 |
+
num_attention_heads=num_heads,
|
238 |
+
max_position_embeddings=4096,
|
239 |
+
intermediate_size=hidden_size * 4,
|
240 |
+
),
|
241 |
+
layer_idx=i,
|
242 |
+
)
|
243 |
+
for i in range(num_layers)
|
244 |
+
]
|
245 |
+
)
|
246 |
+
|
247 |
+
# Final adaptive layer norm
|
248 |
+
self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
|
249 |
+
|
250 |
+
# Modules for diffusion step conditioning
|
251 |
+
self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
|
252 |
+
self.diff_step_mlp = nn.Sequential(
|
253 |
+
nn.Linear(hidden_size, hidden_size * 4),
|
254 |
+
nn.SiLU(),
|
255 |
+
nn.Linear(hidden_size * 4, hidden_size),
|
256 |
+
)
|
257 |
+
|
258 |
+
# MLP for processing phoneme embedding condition
|
259 |
+
if self.use_phone_cond:
|
260 |
+
self.cond_mlp = nn.Sequential(
|
261 |
+
nn.Linear(hidden_size, hidden_size * 4),
|
262 |
+
nn.SiLU(),
|
263 |
+
nn.Linear(hidden_size * 4, hidden_size),
|
264 |
+
)
|
265 |
+
|
266 |
+
# This loop is redundant if layers are initialized correctly, but ensures config consistency.
|
267 |
+
for layer in self.layers:
|
268 |
+
layer.input_layernorm = LlamaAdaptiveRMSNorm(
|
269 |
+
hidden_size, dim_cond=hidden_size
|
270 |
+
)
|
271 |
+
layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
|
272 |
+
hidden_size, dim_cond=hidden_size
|
273 |
+
)
|
274 |
+
|
275 |
+
# We handle embeddings manually, so disable the default token embedder
|
276 |
+
self.embed_tokens = None
|
277 |
+
|
278 |
+
self.post_init()
|
279 |
+
|
280 |
+
def _prepare_decoder_attention_mask(
|
281 |
+
self,
|
282 |
+
attention_mask: torch.Tensor,
|
283 |
+
input_shape: Tuple[int, int],
|
284 |
+
inputs_embeds: torch.Tensor,
|
285 |
+
past_key_values_length: int,
|
286 |
+
) -> Optional[torch.Tensor]:
|
287 |
+
"""
|
288 |
+
Creates a non-causal (fully-visible) attention mask.
|
289 |
+
|
290 |
+
This method overrides the default causal mask creation. It converts a 2D padding mask
|
291 |
+
`[bsz, seq_len]` into a 4D attention mask `[bsz, 1, tgt_seq_len, src_seq_len]`
|
292 |
+
suitable for self-attention, without applying a causal triangle.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
attention_mask (torch.Tensor): The 2D padding mask.
|
296 |
+
input_shape (Tuple[int, int]): The shape of the input (`batch_size`, `seq_len`).
|
297 |
+
inputs_embeds (torch.Tensor): The input embeddings tensor.
|
298 |
+
past_key_values_length (int): The length of any cached key-values.
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
Optional[torch.Tensor]: The 4D attention mask, or None if the input mask is None.
|
302 |
+
"""
|
303 |
+
combined_attention_mask = None
|
304 |
+
|
305 |
+
def _expand_mask(
|
306 |
+
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
307 |
+
) -> torch.Tensor:
|
308 |
+
"""Expands a 2D attention mask to a 4D attention mask."""
|
309 |
+
bsz, src_len = mask.size()
|
310 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
311 |
+
expanded_mask = (
|
312 |
+
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
313 |
+
)
|
314 |
+
# Invert the mask and convert to additive format (-inf for masked positions)
|
315 |
+
inverted_mask = 1.0 - expanded_mask
|
316 |
+
return inverted_mask.masked_fill(
|
317 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
318 |
+
)
|
319 |
+
|
320 |
+
if attention_mask is not None:
|
321 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
322 |
+
expanded_attn_mask = _expand_mask(
|
323 |
+
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
324 |
+
).to(inputs_embeds.device)
|
325 |
+
combined_attention_mask = (
|
326 |
+
expanded_attn_mask
|
327 |
+
if combined_attention_mask is None
|
328 |
+
else expanded_attn_mask + combined_attention_mask
|
329 |
+
)
|
330 |
+
|
331 |
+
return combined_attention_mask
|
332 |
+
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
x: torch.Tensor,
|
336 |
+
diffusion_step: torch.Tensor,
|
337 |
+
x_mask: torch.Tensor,
|
338 |
+
phone_embedding: Optional[torch.Tensor] = None,
|
339 |
+
phone_mask: Optional[torch.Tensor] = None,
|
340 |
+
input_ids: Optional[torch.LongTensor] = None,
|
341 |
+
attention_mask: Optional[torch.Tensor] = None,
|
342 |
+
position_ids: Optional[torch.LongTensor] = None,
|
343 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
344 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
345 |
+
use_cache: Optional[bool] = None,
|
346 |
+
output_attentions: Optional[bool] = None,
|
347 |
+
output_hidden_states: Optional[bool] = None,
|
348 |
+
return_dict: Optional[bool] = None,
|
349 |
+
) -> torch.Tensor:
|
350 |
+
"""
|
351 |
+
Forward pass of the DiffLlamaPrefix model.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
x (torch.Tensor): The primary input tensor, e.g., noisy data (`batch, seq_len, hidden_size`).
|
355 |
+
diffusion_step (torch.Tensor): Diffusion timesteps, shape `(batch,)`.
|
356 |
+
x_mask (torch.Tensor): The padding mask for `x`, shape `(batch, seq_len)`.
|
357 |
+
phone_embedding (Optional[torch.Tensor]): Phoneme embeddings prefix, shape `(batch, phone_len, hidden_size)`.
|
358 |
+
phone_mask (Optional[torch.Tensor]): The padding mask for `phone_embedding`, shape `(batch, phone_len)`.
|
359 |
+
input_ids, etc.: Standard Hugging Face arguments, mostly for compatibility.
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
torch.Tensor: The final output tensor of shape `(batch, seq_len, hidden_size)`.
|
363 |
+
"""
|
364 |
+
# 1. Prepend conditional prefix (phoneme embeddings)
|
365 |
+
if self.use_phone_cond and phone_embedding is not None:
|
366 |
+
# Process condition through an MLP
|
367 |
+
phone_embedding = self.cond_mlp(phone_embedding) # (B, T_phone, C)
|
368 |
+
phone_length = phone_embedding.shape[1]
|
369 |
+
# Concatenate prefix and main input
|
370 |
+
inputs_embeds = torch.cat([phone_embedding, x], dim=1)
|
371 |
+
attention_mask = torch.cat([phone_mask, x_mask], dim=1)
|
372 |
+
else:
|
373 |
+
inputs_embeds = x
|
374 |
+
attention_mask = x_mask
|
375 |
+
phone_length = 0
|
376 |
+
|
377 |
+
# 2. Process diffusion step embedding for adaptive normalization
|
378 |
+
diffusion_step_emb = self.diff_step_embedding(diffusion_step).to(x.device)
|
379 |
+
diffusion_step_emb = self.diff_step_mlp(diffusion_step_emb) # (B, C)
|
380 |
+
|
381 |
+
# 3. Standard Transformer Preamble (adapted from LlamaModel)
|
382 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
383 |
+
|
384 |
+
output_attentions = (
|
385 |
+
output_attentions
|
386 |
+
if output_attentions is not None
|
387 |
+
else self.config.output_attentions
|
388 |
+
)
|
389 |
+
output_hidden_states = (
|
390 |
+
output_hidden_states
|
391 |
+
if output_hidden_states is not None
|
392 |
+
else self.config.output_hidden_states
|
393 |
+
)
|
394 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
395 |
+
return_dict = (
|
396 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
397 |
+
)
|
398 |
+
|
399 |
+
past_key_values_length = 0
|
400 |
+
if past_key_values is not None:
|
401 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
402 |
+
|
403 |
+
if position_ids is None:
|
404 |
+
device = inputs_embeds.device
|
405 |
+
position_ids = torch.arange(
|
406 |
+
past_key_values_length,
|
407 |
+
seq_length + past_key_values_length,
|
408 |
+
dtype=torch.long,
|
409 |
+
device=device,
|
410 |
+
)
|
411 |
+
position_ids = position_ids.unsqueeze(0)
|
412 |
+
|
413 |
+
# Create the non-causal attention mask
|
414 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
415 |
+
attention_mask,
|
416 |
+
(batch_size, seq_length),
|
417 |
+
inputs_embeds,
|
418 |
+
past_key_values_length,
|
419 |
+
)
|
420 |
+
|
421 |
+
hidden_states = inputs_embeds
|
422 |
+
|
423 |
+
# 4. Transformer Decoder Layers
|
424 |
+
all_hidden_states = () if output_hidden_states else None
|
425 |
+
all_self_attns = () if output_attentions else None
|
426 |
+
|
427 |
+
for idx, decoder_layer in enumerate(self.layers):
|
428 |
+
if output_hidden_states:
|
429 |
+
all_hidden_states += (hidden_states,)
|
430 |
+
|
431 |
+
past_key_value = (
|
432 |
+
past_key_values[idx] if past_key_values is not None else None
|
433 |
+
)
|
434 |
+
|
435 |
+
# Pass the processed diffusion step embedding to the adaptive layer
|
436 |
+
layer_outputs = decoder_layer(
|
437 |
+
hidden_states,
|
438 |
+
cond_embedding=diffusion_step_emb,
|
439 |
+
attention_mask=attention_mask,
|
440 |
+
position_ids=position_ids,
|
441 |
+
past_key_value=past_key_value,
|
442 |
+
output_attentions=output_attentions,
|
443 |
+
use_cache=use_cache,
|
444 |
+
)
|
445 |
+
hidden_states = layer_outputs[0]
|
446 |
+
|
447 |
+
if output_attentions:
|
448 |
+
all_self_attns += (layer_outputs[1],)
|
449 |
+
|
450 |
+
# 5. Final Normalization and Output Processing
|
451 |
+
hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step_emb)
|
452 |
+
|
453 |
+
if output_hidden_states:
|
454 |
+
all_hidden_states += (hidden_states,)
|
455 |
+
|
456 |
+
# Remove the conditional prefix from the final output sequence
|
457 |
+
return hidden_states[:, phone_length:]
|
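As a reference for how this module is meant to be driven, here is a minimal call sketch for DiffLlamaPrefix. It is an illustration only: the batch size, sequence lengths, and random tensors are made up, and whether it runs unchanged depends on the transformers version pinned by this Space.

import torch
from models.tts.llm_tts.llama_nar_prefix import DiffLlamaPrefix

model = DiffLlamaPrefix(hidden_size=1024, num_heads=16, num_layers=16, use_phone_cond=True).eval()

B, T, T_phone, H = 2, 120, 30, 1024
x = torch.randn(B, T, H)                # noisy hidden sequence to denoise
x_mask = torch.ones(B, T)               # 1 = valid frame, 0 = padding
phone_emb = torch.randn(B, T_phone, H)  # phoneme prefix, already embedded
phone_mask = torch.ones(B, T_phone)
t = torch.rand(B)                       # diffusion timesteps in [0, 1]

with torch.no_grad():
    out = model(x, t, x_mask, phone_embedding=phone_emb, phone_mask=phone_mask)

# The phoneme prefix is stripped before returning, so the output aligns with x.
assert out.shape == (B, T, H)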
models/tts/llm_tts/mgm.py
ADDED
@@ -0,0 +1,385 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn as nn
|
4 |
+
import math
|
5 |
+
from einops import rearrange
|
6 |
+
from models.tts.llm_tts.llama_nar_prefix import DiffLlamaPrefix
|
7 |
+
|
8 |
+
|
9 |
+
def top_k(logits, thres=0.9):
|
10 |
+
k = math.ceil((1 - thres) * logits.shape[-1])
|
11 |
+
val, ind = logits.topk(k, dim=-1)
|
12 |
+
probs = torch.full_like(logits, float("-inf"))
|
13 |
+
probs.scatter_(2, ind, val)
|
14 |
+
return probs
|
15 |
+
|
16 |
+
|
17 |
+
def log(t, eps=1e-10):
|
18 |
+
return torch.log(t + eps)
|
19 |
+
|
20 |
+
|
21 |
+
def gumbel_noise(t):
|
22 |
+
noise = torch.zeros_like(t).uniform_(0, 1)
|
23 |
+
return -log(-log(noise))
|
24 |
+
|
25 |
+
|
26 |
+
def gumbel_sample(t, temperature=1.0, dim=-1):
|
27 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
|
28 |
+
|
29 |
+
|
30 |
+
class MGMT2S(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
hidden_size=1024,
|
34 |
+
num_layers=16,
|
35 |
+
num_heads=16,
|
36 |
+
cfg_scale=0.2,
|
37 |
+
cond_codebook_size=8192,
|
38 |
+
cond_dim=1024,
|
39 |
+
use_phone_cond=True,
|
40 |
+
phone_vocab_size=32100,
|
41 |
+
cfg=None,
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
hidden_size = (
|
46 |
+
cfg.hidden_size
|
47 |
+
if cfg is not None and hasattr(cfg, "hidden_size")
|
48 |
+
else hidden_size
|
49 |
+
)
|
50 |
+
num_layers = (
|
51 |
+
cfg.num_layers
|
52 |
+
if cfg is not None and hasattr(cfg, "num_layers")
|
53 |
+
else num_layers
|
54 |
+
)
|
55 |
+
num_heads = (
|
56 |
+
cfg.num_heads
|
57 |
+
if cfg is not None and hasattr(cfg, "num_heads")
|
58 |
+
else num_heads
|
59 |
+
)
|
60 |
+
cfg_scale = (
|
61 |
+
cfg.cfg_scale
|
62 |
+
if cfg is not None and hasattr(cfg, "cfg_scale")
|
63 |
+
else cfg_scale
|
64 |
+
)
|
65 |
+
cond_codebook_size = (
|
66 |
+
cfg.cond_codebook_size
|
67 |
+
if cfg is not None and hasattr(cfg, "cond_codebook_size")
|
68 |
+
else cond_codebook_size
|
69 |
+
)
|
70 |
+
cond_dim = (
|
71 |
+
cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
|
72 |
+
)
|
73 |
+
use_phone_cond = (
|
74 |
+
cfg.use_phone_cond
|
75 |
+
if cfg is not None and hasattr(cfg, "use_phone_cond")
|
76 |
+
else use_phone_cond
|
77 |
+
)
|
78 |
+
phone_vocab_size = (
|
79 |
+
cfg.phone_vocab_size
|
80 |
+
if cfg is not None and hasattr(cfg, "phone_vocab_size")
|
81 |
+
else phone_vocab_size
|
82 |
+
)
|
83 |
+
|
84 |
+
self.hidden_size = hidden_size
|
85 |
+
self.num_layers = num_layers
|
86 |
+
self.num_heads = num_heads
|
87 |
+
self.cfg_scale = cfg_scale
|
88 |
+
self.cond_codebook_size = cond_codebook_size
|
89 |
+
self.cond_dim = cond_dim
|
90 |
+
self.use_phone_cond = use_phone_cond
|
91 |
+
self.phone_vocab_size = phone_vocab_size
|
92 |
+
|
93 |
+
self.mask_emb = nn.Embedding(1, self.hidden_size)
|
94 |
+
|
95 |
+
self.to_logit = nn.Linear(self.hidden_size, self.cond_codebook_size)
|
96 |
+
|
97 |
+
self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
|
98 |
+
|
99 |
+
if self.use_phone_cond:
|
100 |
+
self.phone_emb = nn.Embedding(self.phone_vocab_size, hidden_size)
|
101 |
+
torch.nn.init.normal_(self.phone_emb.weight, mean=0.0, std=0.02)
|
102 |
+
|
103 |
+
self.reset_parameters()
|
104 |
+
|
105 |
+
self.diff_estimator = DiffLlamaPrefix(
|
106 |
+
hidden_size=hidden_size,
|
107 |
+
num_heads=num_heads,
|
108 |
+
num_layers=num_layers,
|
109 |
+
use_phone_cond=use_phone_cond,
|
110 |
+
)
|
111 |
+
|
112 |
+
def mask_prob(self, t):
|
113 |
+
return torch.sin(t * np.pi / 2).to(t.device)
|
114 |
+
|
115 |
+
def forward_diffusion(self, x0, t):
|
116 |
+
# x0: semantic tokens (B, T)
|
117 |
+
new_t = t
|
118 |
+
mask_prob = self.mask_prob(new_t) # (B,)
|
119 |
+
# if mask_prob[i] < 0.2, mask_prob[i] = 0.2
|
120 |
+
mask_prob = torch.where(
|
121 |
+
mask_prob < 0.2, torch.ones_like(mask_prob) * 0.2, mask_prob
|
122 |
+
)
|
123 |
+
mask_token = self.mask_emb(
|
124 |
+
torch.LongTensor([0]).to(x0.device)
|
125 |
+
) # (1, hidden_size)
|
126 |
+
|
127 |
+
xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device)
|
128 |
+
|
129 |
+
cfg_scale = self.cfg_scale
|
130 |
+
|
131 |
+
# with prob (1 - cfg_scale), keep a random prompt prefix (up to ~40% of the sequence), so roughly 60-100% of the tokens stay eligible for masking; otherwise drop the prompt for classifier-free guidance
|
132 |
+
if torch.rand(1) > cfg_scale:
|
133 |
+
prompt_len = torch.randint(
|
134 |
+
min(x0.shape[1] // 4, 5), int(x0.shape[1] * 0.4), (x0.shape[0],)
|
135 |
+
).to(
|
136 |
+
x0.device
|
137 |
+
) # (B,)
|
138 |
+
else:
|
139 |
+
prompt_len = torch.zeros(x0.shape[0]).to(x0) # (B,)
|
140 |
+
|
141 |
+
# get is prompt
|
142 |
+
is_prompt = torch.zeros_like(x0[:, :]) # (B, T)
|
143 |
+
col_indices = (
|
144 |
+
torch.arange(is_prompt.shape[1])
|
145 |
+
.repeat(is_prompt.shape[0], 1)
|
146 |
+
.to(prompt_len)
|
147 |
+
) # (B, T)
|
148 |
+
is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt
|
149 |
+
|
150 |
+
# Add mask
|
151 |
+
mask = torch.bernoulli(torch.ones_like(x0[:, :]) * mask_prob[..., None])
|
152 |
+
mask[is_prompt.bool()] = 0
|
153 |
+
mask_num = mask[:,].sum(dim=1, keepdim=False)
|
154 |
+
all_zero_mask = (mask_num == 0).bool()
|
155 |
+
row_indices_to_modify = torch.nonzero(all_zero_mask)
|
156 |
+
mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1
|
157 |
+
mask = mask[..., None] # (B, T, 1)
|
158 |
+
xt = (
|
159 |
+
xt + mask * mask_token[:, None, :] + (1 - mask) * self.cond_emb(x0[:, :])
|
160 |
+
) # (B, T, hidden_size)
|
161 |
+
|
162 |
+
return xt, new_t, mask, prompt_len, mask_prob
|
163 |
+
|
164 |
+
def loss_t(self, x0, x_mask, t, phone_embedding=None, phone_mask=None):
|
165 |
+
xt, new_t, mask, prompt_len, mask_prob = self.forward_diffusion(x0, t)
|
166 |
+
# xt: (B, T, hidden_size)
|
167 |
+
# new_t: (B,)
|
168 |
+
# mask: (B, T, 1) mask if 1, not mask if 0
|
169 |
+
# prompt_len: (B,)
|
170 |
+
# mask_prob: (B,)
|
171 |
+
|
172 |
+
# # drop all condition for cfg, so if prompt_len is 0, we also drop phone_embedding
|
173 |
+
# if self.use_phone_cond and phone_embedding != None:
|
174 |
+
# phone_embedding = phone_embedding * torch.where(prompt_len > 0, torch.ones_like(prompt_len), torch.zeros_like(prompt_len)).to(x0.device).unsqueeze(-1).unsqueeze(-1)
|
175 |
+
|
176 |
+
embeds = self.diff_estimator(
|
177 |
+
xt, new_t, x_mask, phone_embedding=phone_embedding, phone_mask=phone_mask
|
178 |
+
) # (B, T, hidden_size)
|
179 |
+
logits = self.to_logit(embeds) # (B, T, codebook_size)
|
180 |
+
|
181 |
+
# final mask used for loss calculation
|
182 |
+
final_mask = mask * x_mask[..., None] # (B, T, 1)
|
183 |
+
|
184 |
+
return logits, final_mask, x0, prompt_len, mask_prob
|
185 |
+
|
186 |
+
def compute_loss(self, x0, x_mask, phone_embedding=None, phone_mask=None):
|
187 |
+
# x0: (B, T)
|
188 |
+
# x_mask: (B, T) mask is 0 for padding
|
189 |
+
t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False)
|
190 |
+
t = torch.clamp(t, 1e-5, 1.0)
|
191 |
+
return self.loss_t(x0, x_mask, t, phone_embedding, phone_mask)
|
192 |
+
|
193 |
+
def reset_parameters(self):
|
194 |
+
def _reset_parameters(m):
|
195 |
+
if isinstance(m, nn.MultiheadAttention):
|
196 |
+
if m._qkv_same_embed_dim:
|
197 |
+
nn.init.normal_(m.in_proj_weight, std=0.02)
|
198 |
+
else:
|
199 |
+
nn.init.normal_(m.q_proj_weight, std=0.02)
|
200 |
+
nn.init.normal_(m.k_proj_weight, std=0.02)
|
201 |
+
nn.init.normal_(m.v_proj_weight, std=0.02)
|
202 |
+
|
203 |
+
if m.in_proj_bias is not None:
|
204 |
+
nn.init.constant_(m.in_proj_bias, 0.0)
|
205 |
+
nn.init.constant_(m.out_proj.bias, 0.0)
|
206 |
+
if m.bias_k is not None:
|
207 |
+
nn.init.xavier_normal_(m.bias_k)
|
208 |
+
if m.bias_v is not None:
|
209 |
+
nn.init.xavier_normal_(m.bias_v)
|
210 |
+
|
211 |
+
elif (
|
212 |
+
isinstance(m, nn.Conv1d)
|
213 |
+
or isinstance(m, nn.ConvTranspose1d)
|
214 |
+
or isinstance(m, nn.Conv2d)
|
215 |
+
or isinstance(m, nn.ConvTranspose2d)
|
216 |
+
):
|
217 |
+
m.weight.data.normal_(0.0, 0.02)
|
218 |
+
|
219 |
+
elif isinstance(m, nn.Linear):
|
220 |
+
m.weight.data.normal_(mean=0.0, std=0.02)
|
221 |
+
if m.bias is not None:
|
222 |
+
m.bias.data.zero_()
|
223 |
+
|
224 |
+
elif isinstance(m, nn.Embedding):
|
225 |
+
m.weight.data.normal_(mean=0.0, std=0.02)
|
226 |
+
if m.padding_idx is not None:
|
227 |
+
m.weight.data[m.padding_idx].zero_()
|
228 |
+
|
229 |
+
self.apply(_reset_parameters)
|
230 |
+
|
231 |
+
@torch.no_grad()
|
232 |
+
def reverse_diffusion(
|
233 |
+
self,
|
234 |
+
prompt,
|
235 |
+
target_len,
|
236 |
+
phone_id,
|
237 |
+
prompt_mask=None,
|
238 |
+
temp=0.9,
|
239 |
+
filter_thres=0.98,
|
240 |
+
n_timesteps=25,
|
241 |
+
cfg=1.0,
|
242 |
+
rescale_cfg=1.0,
|
243 |
+
):
|
244 |
+
# prompt: (B, T)
|
245 |
+
if self.use_phone_cond and phone_id is not None:
|
246 |
+
phone_embedding = self.phone_emb(phone_id)
|
247 |
+
else:
|
248 |
+
phone_embedding = None
|
249 |
+
|
250 |
+
prompt_code = prompt # (B, prompt_len)
|
251 |
+
prompt_len = prompt_code.shape[1]
|
252 |
+
|
253 |
+
x_mask = torch.ones(prompt_code.shape[0], target_len).to(
|
254 |
+
prompt_code.device
|
255 |
+
) # (B, target_len)
|
256 |
+
phone_mask = torch.ones_like(phone_id)
|
257 |
+
|
258 |
+
if prompt_mask is None:
|
259 |
+
prompt_mask = torch.ones(prompt_code.shape[0], prompt_len).to(
|
260 |
+
prompt_code.device
|
261 |
+
) # (B, prompt_len)
|
262 |
+
|
263 |
+
cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to(
|
264 |
+
x_mask.device
|
265 |
+
) # (B, T, hidden_size)
|
266 |
+
|
267 |
+
bsz, seq_len, _ = cum.shape
|
268 |
+
|
269 |
+
choice_temp = 1.0
|
270 |
+
start_temp = temp # temperature for sampling
|
271 |
+
start_choice_temp = choice_temp # temperature for choicing mask tokens
|
272 |
+
|
273 |
+
xt = torch.LongTensor(bsz, seq_len).to(x_mask.device)
|
274 |
+
|
275 |
+
steps = n_timesteps
|
276 |
+
to_logit = self.to_logit
|
277 |
+
cond_emb = self.cond_emb
|
278 |
+
|
279 |
+
mask_token = self.mask_emb(torch.LongTensor([0]).to(xt.device))
|
280 |
+
mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device) # (B, T, 1)
|
281 |
+
seq = torch.full((bsz, seq_len), 0).to(x_mask.device)
|
282 |
+
h = 1.0 / steps
|
283 |
+
|
284 |
+
cur_prompt = 0
|
285 |
+
cur_prompt = cur_prompt + cond_emb(prompt_code)
|
286 |
+
|
287 |
+
t_list = [1.0 - i * h for i in range(steps)]
|
288 |
+
t_list.append(0.0)
|
289 |
+
for i in range(steps):
|
290 |
+
t = t_list[i] * torch.ones(bsz).to(x_mask.device)
|
291 |
+
token = cond_emb(seq) # (B, T, hidden_size)
|
292 |
+
cur = cum + mask * mask_token[:, None, :] + (~mask) * token
|
293 |
+
|
294 |
+
xt_input = torch.cat([cur_prompt, cur], dim=1) # (B, T, hidden_size)
|
295 |
+
xt_mask = torch.cat(
|
296 |
+
[prompt_mask, x_mask], dim=1
|
297 |
+
) # (B, T), mask is 0 for padding
|
298 |
+
|
299 |
+
embeds = self.diff_estimator(
|
300 |
+
xt_input,
|
301 |
+
t,
|
302 |
+
xt_mask,
|
303 |
+
phone_embedding=phone_embedding,
|
304 |
+
phone_mask=phone_mask,
|
305 |
+
)
|
306 |
+
embeds = embeds[:, prompt_len:, :]
|
307 |
+
|
308 |
+
# classifier free guidance
|
309 |
+
# the unconditional branch passes a zero-length slice of phone_embedding, which is equivalent to dropping the text condition
|
310 |
+
if cfg > 0:
|
311 |
+
mask_embeds = self.diff_estimator(
|
312 |
+
cur,
|
313 |
+
t,
|
314 |
+
x_mask,
|
315 |
+
phone_embedding=phone_embedding[:, phone_embedding.shape[1] :, :],
|
316 |
+
phone_mask=phone_mask[:, prompt_len:],
|
317 |
+
)
|
318 |
+
pos_emb_std = embeds.std() # std(g_cond)
|
319 |
+
embeds = embeds + cfg * (embeds - mask_embeds) # g_cfg
|
320 |
+
rescale_embeds = embeds * pos_emb_std / embeds.std() # g_final
|
321 |
+
embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds
|
322 |
+
|
323 |
+
logits = to_logit(embeds) # (B, T, codebook_size)
|
324 |
+
annealing_scale = t_list[i]
|
325 |
+
|
326 |
+
choice_temp = start_choice_temp * annealing_scale
|
327 |
+
temp = start_temp * annealing_scale
|
328 |
+
logits = top_k(logits, filter_thres)
|
329 |
+
|
330 |
+
if i == steps - 1:
|
331 |
+
# greedy
|
332 |
+
if steps == 1:
|
333 |
+
temp = 0.2
|
334 |
+
sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
|
335 |
+
else:
|
336 |
+
sampled_ids = logits.argmax(dim=-1)
|
337 |
+
|
338 |
+
else:
|
339 |
+
# sampling
|
340 |
+
sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3))
|
341 |
+
|
342 |
+
seq = torch.where(mask.squeeze(-1), sampled_ids, seq)
|
343 |
+
|
344 |
+
scores = logits.softmax(dim=-1)
|
345 |
+
scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1"))
|
346 |
+
scores = rearrange(scores, "b n 1 -> b n")
|
347 |
+
|
348 |
+
scores = choice_temp * gumbel_noise(scores) + scores
|
349 |
+
scores = 1 - scores
|
350 |
+
|
351 |
+
next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device)
|
352 |
+
|
353 |
+
next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item()
|
354 |
+
|
355 |
+
if next_mask_num == 0:
|
356 |
+
break
|
357 |
+
scores = scores.masked_fill(
|
358 |
+
~mask.squeeze(-1), -torch.finfo(scores.dtype).max
|
359 |
+
)
|
360 |
+
|
361 |
+
mask_indices = scores.topk(next_mask_num, dim=-1).indices
|
362 |
+
mask = torch.zeros_like(scores, dtype=torch.bool).scatter(
|
363 |
+
1, mask_indices, True
|
364 |
+
)
|
365 |
+
seq = seq.masked_fill(mask, 0)
|
366 |
+
|
367 |
+
mask = mask.unsqueeze(-1)
|
368 |
+
|
369 |
+
cum = cum + cond_emb(seq)
|
370 |
+
xt = seq
|
371 |
+
|
372 |
+
return xt
|
373 |
+
|
374 |
+
def forward(self, x0, x_mask, phone_id=None, phone_mask=None):
|
375 |
+
# x0: (B, T)
|
376 |
+
# x_mask: (B, T) mask is 0 for padding
|
377 |
+
if self.use_phone_cond and phone_id is not None:
|
378 |
+
phone_embedding = self.phone_emb(phone_id)
|
379 |
+
else:
|
380 |
+
phone_embedding = None
|
381 |
+
|
382 |
+
logits, final_mask, x0, prompt_len, mask_prob = self.compute_loss(
|
383 |
+
x0, x_mask, phone_embedding, phone_mask=phone_mask
|
384 |
+
)
|
385 |
+
return logits, final_mask, x0, prompt_len, mask_prob
|
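The training entry point is not part of this commit, so the following is only a sketch, under the assumption that MGMT2S is optimized with a masked cross-entropy over the positions selected by final_mask; the batch size, sequence lengths, and random tokens below are placeholders.

import torch
import torch.nn.functional as F
from models.tts.llm_tts.mgm import MGMT2S

model = MGMT2S(hidden_size=1024, num_layers=16, num_heads=16)

B, T, T_phone = 2, 200, 40
codes = torch.randint(0, 8192, (B, T))            # target codec tokens
x_mask = torch.ones(B, T)                         # 1 = valid, 0 = padding
phone_id = torch.randint(0, 32100, (B, T_phone))
phone_mask = torch.ones(B, T_phone)

logits, final_mask, x0, prompt_len, mask_prob = model(
    codes, x_mask, phone_id=phone_id, phone_mask=phone_mask
)

# Cross-entropy only on masked, non-padding positions (prompt frames are never masked).
per_token = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")
loss = (per_token * final_mask.squeeze(-1)).sum() / final_mask.sum().clamp(min=1)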
models/tts/tadicodec/__pycache__/infer_utils.cpython-310.pyc
ADDED
Binary file (848 Bytes)
models/tts/tadicodec/__pycache__/inference_tadicodec.cpython-310.pyc
ADDED
Binary file (7.14 kB)
models/tts/tadicodec/__pycache__/llama_nar_prefix.cpython-310.pyc
ADDED
Binary file (15.4 kB)
models/tts/tadicodec/__pycache__/modeling_tadicodec.cpython-310.pyc
ADDED
Binary file (15 kB)
models/tts/tadicodec/infer_utils.py
ADDED
@@ -0,0 +1,24 @@
from models.codec.melvqgan.melspec import MelSpectrogram
from models.codec.amphion_codec.vocos import Vocos


def build_vocoder_model(cfg, device):
    vocoder_model = Vocos(cfg=cfg.model.vocos)
    vocoder_model.eval()
    vocoder_model.to(device)
    return vocoder_model


def build_mel_model(cfg, device):
    mel_model = MelSpectrogram(
        sampling_rate=cfg.preprocess.sample_rate,
        n_fft=cfg.preprocess.n_fft,
        num_mels=cfg.preprocess.num_mels,
        hop_size=cfg.preprocess.hop_size,
        win_size=cfg.preprocess.win_size,
        fmin=cfg.preprocess.fmin,
        fmax=cfg.preprocess.fmax,
    )
    mel_model.eval()
    mel_model.to(device)
    return mel_model
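A small usage sketch for these two builders; the checkpoint path below is illustrative, and the cfg fields (cfg.preprocess.*, cfg.model.vocos) are expected to come from the config.json shipped with the released checkpoint, as done in inference_tadicodec.py below.

import torch
from utils.util import load_config
from models.tts.tadicodec.infer_utils import build_mel_model, build_vocoder_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = load_config("./ckpt/TaDiCodec/config.json", lowercase=False)  # illustrative path

mel_model = build_mel_model(cfg, device)           # waveform -> (B, n_mels, T) mel features
vocoder_model = build_vocoder_model(cfg, device)   # mel features -> waveform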
models/tts/tadicodec/inference_tadicodec.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import librosa
|
8 |
+
import safetensors
|
9 |
+
import accelerate
|
10 |
+
from huggingface_hub import snapshot_download
|
11 |
+
|
12 |
+
from transformers import AutoTokenizer
|
13 |
+
|
14 |
+
from models.tts.tadicodec.infer_utils import build_vocoder_model, build_mel_model
|
15 |
+
from models.tts.tadicodec.modeling_tadicodec import TaDiCodec
|
16 |
+
|
17 |
+
|
18 |
+
class TaDiCodecPipline(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
cfg,
|
22 |
+
model_path: str,
|
23 |
+
device: torch.device,
|
24 |
+
tokenizer_path: Optional[str] = None,
|
25 |
+
vocoder_ckpt_path: Optional[str] = None,
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
self.cfg = cfg
|
29 |
+
self.model_path = model_path
|
30 |
+
self.device = device
|
31 |
+
|
32 |
+
# tokenizer
|
33 |
+
self.tokenizer = (
|
34 |
+
AutoTokenizer.from_pretrained(tokenizer_path)
|
35 |
+
if tokenizer_path is not None
|
36 |
+
else AutoTokenizer.from_pretrained(os.path.join(model_path, "text_tokenizer"))
|
37 |
+
)
|
38 |
+
|
39 |
+
# mel feature extractor
|
40 |
+
self.mel_model = build_mel_model(cfg, device)
|
41 |
+
|
42 |
+
# main model
|
43 |
+
tadiconfig = cfg.model.tadicodec
|
44 |
+
|
45 |
+
self.tadicodec = TaDiCodec(cfg=tadiconfig)
|
46 |
+
safetensors.torch.load_model(self.tadicodec, model_path, strict=False)
|
47 |
+
self.tadicodec.to(torch.float32)
|
48 |
+
self.tadicodec.to(device)
|
49 |
+
self.tadicodec.eval()
|
50 |
+
|
51 |
+
# vocoder
|
52 |
+
self.vocoder_model = build_vocoder_model(cfg, device)
|
53 |
+
v_path = (
|
54 |
+
vocoder_ckpt_path
|
55 |
+
if vocoder_ckpt_path
|
56 |
+
else os.path.join(model_path, "vocoder")
|
57 |
+
)
|
58 |
+
accelerate.load_checkpoint_and_dispatch(self.vocoder_model, v_path)
|
59 |
+
|
60 |
+
@classmethod
|
61 |
+
def from_pretrained(
|
62 |
+
cls,
|
63 |
+
ckpt_dir: str = "./ckpt/TaDiCodec",
|
64 |
+
device: Optional[torch.device] = None,
|
65 |
+
auto_download: bool = True,
|
66 |
+
):
|
67 |
+
"""Create a pipeline from a checkpoint directory or Hugging Face model ID.
|
68 |
+
|
69 |
+
Expected structure under `ckpt_dir`:
|
70 |
+
- config.json # model and preprocess config
|
71 |
+
- model.safetensors # TaDiCodec weights
|
72 |
+
- vocoder/ # directory containing vocoder weights
|
73 |
+
model.safetensors or other *.safetensors
|
74 |
+
- text_tokenizer/ # directory containing text tokenizer
|
75 |
+
tokenizer.json
|
76 |
+
tokenizer_config.json
|
77 |
+
|
78 |
+
Args:
|
79 |
+
ckpt_dir: Directory containing `config.json`, `model.safetensors`, and `vocoder/`, or Hugging Face model ID.
|
80 |
+
device: Device to place models on. Defaults to CUDA if available else CPU.
|
81 |
+
auto_download: Whether to automatically download models from Hugging Face if not found locally
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
TaDiCodecPipline
|
85 |
+
"""
|
86 |
+
import os
|
87 |
+
import glob
|
88 |
+
from utils.util import load_config
|
89 |
+
|
90 |
+
resolved_device = (
|
91 |
+
device
|
92 |
+
if device is not None
|
93 |
+
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
94 |
+
)
|
95 |
+
|
96 |
+
# Resolve checkpoint directory
|
97 |
+
resolved_ckpt_dir = cls._resolve_model_path(
|
98 |
+
ckpt_dir, auto_download=auto_download, model_type="tadicodec"
|
99 |
+
)
|
100 |
+
|
101 |
+
# Load config
|
102 |
+
config_path = os.path.join(resolved_ckpt_dir, "config.json")
|
103 |
+
if not os.path.exists(config_path):
|
104 |
+
raise FileNotFoundError(f"Config not found: {config_path}")
|
105 |
+
cfg = load_config(config_path, lowercase=False)
|
106 |
+
|
107 |
+
# Resolve main model weights
|
108 |
+
model_weights_path = os.path.join(resolved_ckpt_dir, "model.safetensors")
|
109 |
+
|
110 |
+
# Resolve vocoder weights
|
111 |
+
vocoder_ckpt_path = os.path.join(resolved_ckpt_dir, "vocoder")
|
112 |
+
|
113 |
+
text_tokenizer_dir = os.path.join(resolved_ckpt_dir, "text_tokenizer")
|
114 |
+
|
115 |
+
return cls(
|
116 |
+
cfg=cfg,
|
117 |
+
model_path=model_weights_path,
|
118 |
+
device=resolved_device,
|
119 |
+
vocoder_ckpt_path=vocoder_ckpt_path,
|
120 |
+
tokenizer_path=text_tokenizer_dir,
|
121 |
+
)
|
122 |
+
|
123 |
+
@staticmethod
|
124 |
+
def _resolve_model_path(
|
125 |
+
model_path: str, auto_download: bool = True, model_type: str = "tadicodec"
|
126 |
+
) -> str:
|
127 |
+
"""
|
128 |
+
Resolve model path, downloading from Hugging Face if necessary
|
129 |
+
|
130 |
+
Args:
|
131 |
+
model_path: Local path or Hugging Face model ID
|
132 |
+
auto_download: Whether to auto-download from HF
|
133 |
+
model_type: Type of model ("tadicodec")
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
Resolved local path
|
137 |
+
"""
|
138 |
+
# If it's already a local path and exists, return as is
|
139 |
+
if os.path.exists(model_path):
|
140 |
+
return model_path
|
141 |
+
|
142 |
+
# If it looks like a Hugging Face model ID (contains '/')
|
143 |
+
if "/" in model_path and auto_download:
|
144 |
+
print(f"Downloading {model_type} model from Hugging Face: {model_path}")
|
145 |
+
try:
|
146 |
+
# Download to cache directory
|
147 |
+
cache_dir = os.path.join(
|
148 |
+
os.path.expanduser("~"), ".cache", "huggingface", "hub"
|
149 |
+
)
|
150 |
+
downloaded_path = snapshot_download(
|
151 |
+
repo_id=model_path,
|
152 |
+
cache_dir=cache_dir,
|
153 |
+
local_dir_use_symlinks=False,
|
154 |
+
)
|
155 |
+
print(
|
156 |
+
f"Successfully downloaded {model_type} model to: {downloaded_path}"
|
157 |
+
)
|
158 |
+
return downloaded_path
|
159 |
+
except Exception as e:
|
160 |
+
print(f"Failed to download {model_type} model from Hugging Face: {e}")
|
161 |
+
raise ValueError(
|
162 |
+
f"Could not download {model_type} model from {model_path}"
|
163 |
+
)
|
164 |
+
|
165 |
+
# If it's a local path that doesn't exist
|
166 |
+
if not os.path.exists(model_path):
|
167 |
+
if not auto_download:
|
168 |
+
raise ValueError(
|
169 |
+
f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
|
170 |
+
)
|
171 |
+
else:
|
172 |
+
raise FileNotFoundError(f"Model path does not exist: {model_path}")
|
173 |
+
|
174 |
+
return model_path
|
175 |
+
|
176 |
+
@torch.no_grad()
|
177 |
+
def __call__(
|
178 |
+
self,
|
179 |
+
text: Optional[str] = None,
|
180 |
+
speech_path: Optional[str] = None,
|
181 |
+
prompt_text: Optional[str] = None,
|
182 |
+
prompt_speech_path: Optional[str] = None,
|
183 |
+
n_timesteps: int = 32,
|
184 |
+
return_code: bool = False,
|
185 |
+
cfg_scale: float = 2.0,
|
186 |
+
):
|
187 |
+
# tokenize text
|
188 |
+
text_input_ids = self.tokenize_text(text, prompt_text)
|
189 |
+
|
190 |
+
# extract mel features
|
191 |
+
target_mel = self.extract_mel_feature(speech_path)
|
192 |
+
prompt_mel = (
|
193 |
+
self.extract_mel_feature(prompt_speech_path) if prompt_speech_path else None
|
194 |
+
)
|
195 |
+
|
196 |
+
# encode to codes from mel
|
197 |
+
if prompt_mel is not None:
|
198 |
+
vq_emb, indices = self.encode(torch.cat([prompt_mel, target_mel], dim=1))
|
199 |
+
else:
|
200 |
+
vq_emb, indices = self.encode(target_mel)
|
201 |
+
|
202 |
+
if return_code:
|
203 |
+
return indices
|
204 |
+
|
205 |
+
# decode mel from codes + optional text/prompt
|
206 |
+
rec_mel = self.decode(
|
207 |
+
vq_emb=vq_emb,
|
208 |
+
text_token_ids=text_input_ids,
|
209 |
+
prompt_mel=(
|
210 |
+
prompt_mel if prompt_mel is not None else target_mel[:, : 50 * 3]
|
211 |
+
),
|
212 |
+
n_timesteps=n_timesteps,
|
213 |
+
cfg=cfg_scale,
|
214 |
+
rescale_cfg=0.75,
|
215 |
+
)
|
216 |
+
|
217 |
+
# vocoder
|
218 |
+
rec_audio = (
|
219 |
+
self.vocoder_model(rec_mel.transpose(1, 2)).detach().cpu().numpy()[0][0]
|
220 |
+
)
|
221 |
+
return rec_audio
|
222 |
+
|
223 |
+
def tokenize_text(
|
224 |
+
self, text: Optional[str] = None, prompt_text: Optional[str] = None
|
225 |
+
):
|
226 |
+
if self.tokenizer is None or text is None:
|
227 |
+
return None
|
228 |
+
if prompt_text is not None:
|
229 |
+
text_token_ids = self.tokenizer(
|
230 |
+
prompt_text + text, return_tensors="pt", add_special_tokens=False
|
231 |
+
).input_ids.to(self.device)
|
232 |
+
else:
|
233 |
+
text_token_ids = self.tokenizer(
|
234 |
+
text, return_tensors="pt", add_special_tokens=False
|
235 |
+
).input_ids.to(self.device)
|
236 |
+
return text_token_ids
|
237 |
+
|
238 |
+
@torch.no_grad()
|
239 |
+
def extract_mel_feature(self, speech_path: Optional[str]):
|
240 |
+
assert speech_path is not None and os.path.exists(
|
241 |
+
speech_path
|
242 |
+
), f"Invalid speech_path: {speech_path}"
|
243 |
+
speech = librosa.load(speech_path, sr=24000)[0]
|
244 |
+
speech = torch.tensor(speech).to(self.device).unsqueeze(0)
|
245 |
+
mel_feature = self.mel_model(speech) # (B, n_mels, T)
|
246 |
+
mel_feature = mel_feature.transpose(1, 2) # (B, T, n_mels)
|
247 |
+
mel_feature = (mel_feature - self.cfg.preprocess.mel_mean) / math.sqrt(
|
248 |
+
self.cfg.preprocess.mel_var
|
249 |
+
)
|
250 |
+
return mel_feature
|
251 |
+
|
252 |
+
@torch.no_grad()
|
253 |
+
def encode(self, mel_feat: torch.Tensor):
|
254 |
+
vq_emb, indices = self.tadicodec.encode(
|
255 |
+
mel_feat, torch.ones(mel_feat.shape[0], mel_feat.shape[1]).to(self.device)
|
256 |
+
)
|
257 |
+
return vq_emb, indices
|
258 |
+
|
259 |
+
@torch.no_grad()
|
260 |
+
def decode(
|
261 |
+
self,
|
262 |
+
vq_emb: Optional[torch.Tensor] = None,
|
263 |
+
indices: Optional[torch.Tensor] = None,
|
264 |
+
text_token_ids: Optional[torch.Tensor] = None,
|
265 |
+
prompt_mel: Optional[torch.Tensor] = None,
|
266 |
+
n_timesteps: int = 32,
|
267 |
+
cfg: float = 1.0,
|
268 |
+
rescale_cfg: float = 0.75,
|
269 |
+
):
|
270 |
+
rec_mel = self.tadicodec.reverse_diffusion(
|
271 |
+
vq_emb=vq_emb,
|
272 |
+
indices=indices,
|
273 |
+
text_ids=text_token_ids,
|
274 |
+
prompt_mel=prompt_mel,
|
275 |
+
n_timesteps=n_timesteps,
|
276 |
+
cfg=cfg,
|
277 |
+
rescale_cfg=rescale_cfg,
|
278 |
+
)
|
279 |
+
return rec_mel
|
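A minimal end-to-end sketch of the pipeline above. It assumes the checkpoint has already been resolved to ./ckpt/TaDiCodec (a Hugging Face repo ID would also work via auto-download); the wav paths and transcripts are placeholders, and writing the result with soundfile at 24 kHz mirrors the sampling rate used in extract_mel_feature.

import soundfile as sf
from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline

pipe = TaDiCodecPipline.from_pretrained(ckpt_dir="./ckpt/TaDiCodec")

audio = pipe(
    text="Transcript of the target utterance.",   # placeholder text
    speech_path="examples/target.wav",            # placeholder path
    prompt_text="Transcript of the prompt.",      # placeholder text
    prompt_speech_path="examples/prompt.wav",     # placeholder path
    n_timesteps=32,
    cfg_scale=2.0,
)
sf.write("reconstructed.wav", audio, 24000)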
models/tts/tadicodec/llama_nar_prefix.py
ADDED
@@ -0,0 +1,572 @@
1 |
+
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import torch.nn as nn
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
import math
|
9 |
+
|
10 |
+
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
11 |
+
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
|
12 |
+
|
13 |
+
|
14 |
+
class SinusoidalPosEmb(nn.Module):
|
15 |
+
"""
|
16 |
+
Sinusoidal Positional Embedding module.
|
17 |
+
|
18 |
+
This module generates sinusoidal positional embeddings for a given 1D input tensor,
|
19 |
+
which is commonly used for representing timesteps in diffusion models.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, dim: int):
|
23 |
+
"""
|
24 |
+
Initializes the SinusoidalPosEmb module.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
dim (int): The dimension of the embedding.
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self.dim = dim
|
31 |
+
|
32 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
33 |
+
"""
|
34 |
+
Generates the positional embedding.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): A 1D tensor of positions (e.g., timesteps), shape `(batch_size,)`.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: The positional embeddings, shape `(batch_size, dim)`.
|
41 |
+
"""
|
42 |
+
device = x.device
|
43 |
+
half_dim = self.dim // 2
|
44 |
+
# Calculate the embedding frequencies
|
45 |
+
emb = math.log(10000) / (half_dim - 1)
|
46 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
47 |
+
# Create the embedding matrix by multiplying positions with frequencies
|
48 |
+
emb = x[:, None] * emb[None, :] * 1.0
|
49 |
+
# Concatenate sine and cosine components to form the final embedding
|
50 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
51 |
+
return emb
|
52 |
+
|
53 |
+
|
54 |
+
class LlamaAdaptiveRMSNorm(nn.Module):
|
55 |
+
"""
|
56 |
+
Adaptive Root Mean Square Layer Normalization.
|
57 |
+
|
58 |
+
This is a variant of RMSNorm where the scaling factor (weight) is adaptively
|
59 |
+
predicted from a conditional embedding, allowing for conditional modulation
|
60 |
+
of the normalized hidden states.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
hidden_size: int = 1024,
|
66 |
+
eps: float = 1e-6,
|
67 |
+
dim_cond: int = 1024,
|
68 |
+
use_cond: bool = True,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Initializes the LlamaAdaptiveRMSNorm module.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
hidden_size (int): The dimension of the hidden states to be normalized.
|
75 |
+
eps (float): A small value added to the variance for numerical stability.
|
76 |
+
dim_cond (int): The dimension of the conditional embedding.
|
77 |
+
use_cond (bool): If True, use conditional embedding to modulate the output.
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
self.use_cond = use_cond
|
81 |
+
if self.use_cond:
|
82 |
+
# Linear layer to project the conditional embedding to the hidden size
|
83 |
+
self.to_weight = nn.Linear(dim_cond, hidden_size)
|
84 |
+
# Initialize weights to zero and bias to one for an identity transformation at the start
|
85 |
+
nn.init.zeros_(self.to_weight.weight)
|
86 |
+
nn.init.ones_(self.to_weight.bias)
|
87 |
+
self.variance_epsilon = eps
|
88 |
+
self._is_hf_initialized = True # Disable automatic Hugging Face initialization
|
89 |
+
|
90 |
+
def forward(
|
91 |
+
self, hidden_states: torch.Tensor, cond_embedding: Optional[torch.Tensor] = None
|
92 |
+
) -> torch.Tensor:
|
93 |
+
"""
|
94 |
+
Applies the adaptive RMS normalization.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
hidden_states (torch.Tensor): The input tensor, shape `(batch, seq_len, hidden_size)`.
|
98 |
+
cond_embedding (Optional[torch.Tensor]): The conditional embedding, shape `(batch, dim_cond)` or `(batch, seq_len, dim_cond)`.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
torch.Tensor: The normalized and modulated hidden states.
|
102 |
+
"""
|
103 |
+
input_dtype = hidden_states.dtype
|
104 |
+
# Calculate variance and normalize the hidden states
|
105 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
106 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
107 |
+
|
108 |
+
# Apply conditional modulation if enabled
|
109 |
+
if self.use_cond:
|
110 |
+
# Project conditional embedding to get the scaling weight
|
111 |
+
weight = self.to_weight(cond_embedding)
|
112 |
+
# Unsqueeze if the conditional embedding is per-batch instead of per-token
|
113 |
+
if len(weight.shape) == 2:
|
114 |
+
weight = weight.unsqueeze(1)
|
115 |
+
|
116 |
+
# Apply the learned scaling factor
|
117 |
+
return (weight * hidden_states).to(input_dtype)
|
118 |
+
else:
|
119 |
+
return hidden_states
|
120 |
+
|
121 |
+
|
122 |
+
class LlamaNARDecoderLayer(LlamaDecoderLayer):
|
123 |
+
"""
|
124 |
+
A Non-Autoregressive (NAR) Llama Decoder Layer using adaptive layer normalization.
|
125 |
+
|
126 |
+
This class overrides the standard LlamaDecoderLayer to replace its RMSNorm
|
127 |
+
modules with LlamaAdaptiveRMSNorm, allowing it to be conditioned on an external embedding
|
128 |
+
(e.g., from diffusion timesteps).
|
129 |
+
"""
|
130 |
+
|
131 |
+
def __init__(self, config: LlamaConfig, layer_idx: int, use_cond: bool = True):
|
132 |
+
"""
|
133 |
+
Overrides the LlamaDecoderLayer to use adaptive layer normalization.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
config (LlamaConfig): The configuration object for the Llama model.
|
137 |
+
layer_idx (int): The index of the layer.
|
138 |
+
use_cond (bool): Whether to use adaptive conditioning in the layer norms.
|
139 |
+
"""
|
140 |
+
super().__init__(config, layer_idx) # init attention, mlp, etc.
|
141 |
+
# Override the standard layer norms with our adaptive versions
|
142 |
+
self.input_layernorm = LlamaAdaptiveRMSNorm(
|
143 |
+
config.hidden_size,
|
144 |
+
eps=config.rms_norm_eps,
|
145 |
+
dim_cond=config.hidden_size,
|
146 |
+
use_cond=use_cond,
|
147 |
+
)
|
148 |
+
self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
|
149 |
+
config.hidden_size,
|
150 |
+
eps=config.rms_norm_eps,
|
151 |
+
dim_cond=config.hidden_size,
|
152 |
+
use_cond=use_cond,
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(
|
156 |
+
self,
|
157 |
+
hidden_states: torch.Tensor,
|
158 |
+
cond_embedding: torch.Tensor,
|
159 |
+
attention_mask: Optional[torch.Tensor] = None,
|
160 |
+
position_ids: Optional[torch.LongTensor] = None,
|
161 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
162 |
+
output_attentions: Optional[bool] = False,
|
163 |
+
use_cache: Optional[bool] = False,
|
164 |
+
) -> Tuple[
|
165 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
166 |
+
]:
|
167 |
+
"""
|
168 |
+
Forward pass for the NAR decoder layer, including conditional embedding.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
hidden_states (torch.FloatTensor): Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
172 |
+
cond_embedding (torch.Tensor): Conditional embedding for adaptive normalization.
|
173 |
+
attention_mask (Optional[torch.Tensor]): Attention mask of size
|
174 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
175 |
+
position_ids (Optional[torch.LongTensor]): Indices of positions of each input sequence tokens.
|
176 |
+
past_key_value (Optional[Tuple[torch.Tensor]]): Cached past key and value projection states.
|
177 |
+
output_attentions (Optional[bool]): Whether or not to return the attentions tensors of all attention layers.
|
178 |
+
use_cache (Optional[bool]): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
182 |
+
A tuple containing the output hidden states and optionally the attention weights and past key/value states.
|
183 |
+
"""
|
184 |
+
residual = hidden_states
|
185 |
+
|
186 |
+
# Apply adaptive pre-attention layer norm
|
187 |
+
hidden_states = self.input_layernorm(
|
188 |
+
hidden_states, cond_embedding=cond_embedding
|
189 |
+
)
|
190 |
+
|
191 |
+
# Self Attention block
|
192 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
193 |
+
hidden_states=hidden_states,
|
194 |
+
attention_mask=attention_mask,
|
195 |
+
position_ids=position_ids,
|
196 |
+
past_key_value=past_key_value,
|
197 |
+
output_attentions=output_attentions,
|
198 |
+
use_cache=use_cache,
|
199 |
+
)
|
200 |
+
hidden_states = residual + hidden_states
|
201 |
+
|
202 |
+
# Fully Connected block
|
203 |
+
residual = hidden_states
|
204 |
+
# Apply adaptive post-attention layer norm
|
205 |
+
hidden_states = self.post_attention_layernorm(
|
206 |
+
hidden_states, cond_embedding=cond_embedding
|
207 |
+
)
|
208 |
+
hidden_states = self.mlp(hidden_states)
|
209 |
+
hidden_states = residual + hidden_states
|
210 |
+
|
211 |
+
outputs = (hidden_states,)
|
212 |
+
|
213 |
+
if output_attentions:
|
214 |
+
outputs += (self_attn_weights,)
|
215 |
+
|
216 |
+
if use_cache:
|
217 |
+
outputs += (present_key_value,)
|
218 |
+
|
219 |
+
return outputs
|
220 |
+
|
221 |
+
|
222 |
+
class DiffLlamaPrefix(LlamaModel):
|
223 |
+
"""
|
224 |
+
A Llama-based non-autoregressive transformer model for diffusion tasks.
|
225 |
+
|
226 |
+
This model uses a Llama architecture but modifies it for non-autoregressive generation.
|
227 |
+
Key features:
|
228 |
+
1. Non-causal (fully-visible) attention mask.
|
229 |
+
2. Adaptive layer normalization conditioned on diffusion timesteps.
|
230 |
+
3. Ability to be conditioned on text embeddings, which are prepended as a prefix.
|
231 |
+
4. Input and output linear projection layers for feature mapping.
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(
|
235 |
+
self,
|
236 |
+
hidden_size: int = 1024,
|
237 |
+
num_heads: int = 16,
|
238 |
+
num_layers: int = 16,
|
239 |
+
in_dim: Optional[int] = None,
|
240 |
+
out_dim: Optional[int] = None,
|
241 |
+
use_text_emb: bool = True,
|
242 |
+
use_diff_step: bool = True,
|
243 |
+
use_cond: bool = True,
|
244 |
+
config: LlamaConfig = LlamaConfig(
|
245 |
+
vocab_size=0, hidden_size=256, num_attention_heads=1, num_hidden_layers=1
|
246 |
+
),
|
247 |
+
):
|
248 |
+
"""
|
249 |
+
Initializes the DiffLlamaPrefix model.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
hidden_size (int): The hidden dimension of the transformer.
|
253 |
+
num_heads (int): The number of attention heads.
|
254 |
+
num_layers (int): The number of transformer layers.
|
255 |
+
in_dim (Optional[int]): Dimension of the input features. If set, an input linear layer is created.
|
256 |
+
out_dim (Optional[int]): Dimension of the output features. If set, an output linear layer is created.
|
257 |
+
use_text_emb (bool): Whether to use text embeddings as a conditional prefix.
|
258 |
+
use_diff_step (bool): Whether to use diffusion timestep conditioning.
|
259 |
+
use_cond (bool): Whether to use an additional per-token conditional input `cond`.
|
260 |
+
config (LlamaConfig): A LlamaConfig object. A default is provided for convenience.
|
261 |
+
"""
|
262 |
+
super().__init__(config)
|
263 |
+
|
264 |
+
self.use_text_emb = use_text_emb
|
265 |
+
self.use_diff_step = use_diff_step
|
266 |
+
self.use_cond = use_cond
|
267 |
+
self.in_dim = in_dim
|
268 |
+
self.out_dim = out_dim
|
269 |
+
|
270 |
+
# Create a stack of non-autoregressive Llama layers
|
271 |
+
self.layers = nn.ModuleList(
|
272 |
+
[
|
273 |
+
LlamaNARDecoderLayer(
|
274 |
+
LlamaConfig(
|
275 |
+
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                    ),
                    layer_idx=i,
                    use_cond=use_cond,
                )
                for i in range(num_layers)
            ]
        )

        # Final adaptive layer norm
        self.norm = LlamaAdaptiveRMSNorm(
            hidden_size,
            dim_cond=hidden_size,
            use_cond=(use_cond or use_diff_step),
        )

        # MLP for processing text embedding condition
        if self.use_text_emb:
            self.text_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, hidden_size),
            )

        # Modules for processing diffusion timestep condition
        if self.use_diff_step:
            self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
            self.diff_step_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, hidden_size),
            )

        # MLP for processing the general per-token `cond` condition
        if self.use_cond:
            self.cond_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, hidden_size),
            )

        # Ensure all layers use the correct adaptive norm configuration
        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size,
                dim_cond=hidden_size,
                use_cond=(use_cond or use_diff_step),
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size,
                dim_cond=hidden_size,
                use_cond=(use_cond or use_diff_step),
            )

        # We handle embeddings manually, so disable the default token embedder
        self.embed_tokens = None

        # Optional input and output projection layers
        if self.in_dim is not None:
            self.in_linear = nn.Linear(self.in_dim, hidden_size)
        if self.out_dim is not None:
            self.out_linear = nn.Linear(hidden_size, self.out_dim)

        self.post_init()

    def _prepare_decoder_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, int, int],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ) -> Optional[torch.Tensor]:
        """
        Creates a non-causal (fully-visible) attention mask.

        This method overrides the default causal mask creation in LlamaModel.
        It converts a 2D padding mask `[bsz, seq_len]` into a 4D attention mask
        `[bsz, 1, tgt_seq_len, src_seq_len]` suitable for self-attention,
        without applying a causal triangle.

        Args:
            attention_mask (torch.Tensor): The 2D padding mask.
            input_shape (Tuple[int, int, int]): The shape of the input embeddings.
            inputs_embeds (torch.Tensor): The input embeddings tensor.
            past_key_values_length (int): The length of any cached key-values.

        Returns:
            Optional[torch.Tensor]: The 4D attention mask, or None if the input mask is None.
        """
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ) -> torch.Tensor:
            """
            Expands a 2D attention mask to a 4D attention mask.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )

            # Invert the mask: 1.0 for valid tokens, 0.0 for padded tokens,
            # then convert to additive mask format (-inf for masked positions)
            inverted_mask = 1.0 - expanded_mask
            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        diffusion_step: Optional[torch.Tensor] = None,
        text_embedding: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple]:
        """
        Forward pass of the DiffLlamaPrefix model.

        Args:
            x (torch.Tensor): The primary input tensor, e.g., noisy data in diffusion `(batch, seq_len, in_dim)`.
            x_mask (torch.Tensor): The padding mask for `x`, shape `(batch, seq_len)`.
            cond (Optional[torch.Tensor]): Additional per-token conditional input, shape `(batch, seq_len, hidden_size)`.
            diffusion_step (Optional[torch.Tensor]): Diffusion timesteps, shape `(batch,)`.
            text_embedding (Optional[torch.Tensor]): Text embeddings to be used as a prefix, shape `(batch, text_len, hidden_size)`.
            text_mask (Optional[torch.Tensor]): The padding mask for `text_embedding`, shape `(batch, text_len)`.
            input_ids, attention_mask, etc.: Standard Hugging Face arguments (mostly unused here but kept for compatibility).

        Returns:
            Union[torch.Tensor, Tuple]: The final output tensor `(batch, seq_len, out_dim)`.
                If `output_hidden_states` is True, returns a tuple `(output_tensor, all_hidden_states)`.
        """
        # 1. Input and Conditional Projections
        if self.in_dim is not None:
            x = self.in_linear(x)

        if self.use_cond and cond is not None:
            cond_embedding = self.cond_mlp(cond)  # (B, T, C)
            x = x + cond_embedding

        # 2. Prepend Text Embedding Prefix
        if self.use_text_emb and text_embedding is not None:
            text_embedding = self.text_mlp(text_embedding)  # (B, T, C)
            text_length = text_embedding.shape[1]
            # Concatenate text prefix and main input
            inputs_embeds = torch.cat([text_embedding, x], dim=1)
            attention_mask = torch.cat([text_mask, x_mask], dim=1)
        else:
            inputs_embeds = x
            attention_mask = x_mask
            text_length = 0

        # 3. Process Diffusion Step Embedding for Adaptive Norm
        if self.use_diff_step and diffusion_step is not None:
            # Convert scalar timesteps to vector embeddings and project with an MLP
            diffusion_step_emb = self.diff_step_embedding(diffusion_step).to(x.device)
            diffusion_step_emb = self.diff_step_mlp(diffusion_step_emb)  # (B, C)
        else:
            diffusion_step_emb = None

        # 4. Standard Transformer Preamble (adapted from LlamaModel)
        batch_size, seq_length, _ = inputs_embeds.shape

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        # Create the non-causal attention mask
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        # 5. Transformer Decoder Layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Store hidden states before the layer, excluding the text prefix part
                all_hidden_states += (hidden_states[:, text_length:],)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            # Pass the processed diffusion step embedding to the adaptive layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cond_embedding=diffusion_step_emb,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # 6. Final Normalization and Output Processing
        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step_emb)

        if output_hidden_states:
            # Add hidden states from the last decoder layer
            all_hidden_states += (hidden_states[:, text_length:],)

        next_cache = next_decoder_cache if use_cache else None

        # Remove the text prefix from the final output sequence
        hidden_states = hidden_states[
            :,
            text_length:,
        ]

        # Apply final output projection if specified
        if self.out_dim is not None:
            hidden_states = self.out_linear(hidden_states)

        # 7. Return results (simplified for this snippet)
        if output_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states
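For orientation, here is a minimal sketch of how a prefix transformer of this kind can be exercised. It is a hypothetical smoke test, not code from the commit: the sizes are illustrative assumptions, and the constructor keywords simply mirror the way `TaDiCodec` (next file) builds its encoder and decoder from `DiffLlamaPrefix`.

```python
import torch
from models.tts.tadicodec.llama_nar_prefix import DiffLlamaPrefix

# Illustrative sizes only (assumption); real configs use much larger models.
model = DiffLlamaPrefix(
    hidden_size=256,
    num_layers=2,
    num_heads=4,
    in_dim=128,
    out_dim=128,
    use_text_emb=True,
    use_diff_step=True,
    use_cond=True,
)

x = torch.randn(2, 50, 128)        # noisy mel frames (B, T, in_dim)
x_mask = torch.ones(2, 50)         # padding mask for x
cond = torch.randn(2, 50, 256)     # per-token condition, e.g. upsampled VQ embedding (B, T, hidden_size)
text = torch.randn(2, 10, 256)     # text prefix embeddings (B, text_len, hidden_size)
text_mask = torch.ones(2, 10)
t = torch.rand(2)                  # diffusion timesteps, one per batch item

out = model(
    x=x,
    x_mask=x_mask,
    cond=cond,
    diffusion_step=t,
    text_embedding=text,
    text_mask=text_mask,
)
# The text prefix is stripped from the output, so the shape follows x: (2, 50, 128)
print(out.shape)
```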
models/tts/tadicodec/modeling_tadicodec.py
ADDED
@@ -0,0 +1,641 @@
import torch
import numpy as np
import torch.nn as nn
import math
from einops import rearrange
from typing import Optional, Dict, Any

from models.tts.tadicodec.llama_nar_prefix import DiffLlamaPrefix
from models.codec.amphion_codec.quantize.bsq import (
    BinarySphericalQuantizer,
    SimpleQuantizer,
)
import torch.nn.functional as F


class TaDiCodec(nn.Module):
    """
    TaDiCodec: A diffusion-based codec model for Text-to-Speech (TTS)
    that uses a non-autoregressive Llama-style transformer architecture.

    It consists of:
    1. An Encoder that processes input features (e.g., mel-spectrograms or SSL features)
       into a latent representation.
    2. A Vector Quantizer (VQ) that discretizes the latent representation into codes.
    3. A Decoder that generates the output feature (e.g., mel-spectrogram) from the codes,
       optional text conditioning, and a prompt, using a flow-matching diffusion process.
    """

    def __init__(
        self,
        mel_dim: int = 128,
        in_dim: int = 128,
        hidden_size: int = 1024,
        encoder_num_layers: int = 8,
        decoder_num_layers: int = 16,
        num_heads: int = 16,
        cond_drop_p: float = 0.2,  # drop code for decoder
        context_drop_p: float = 0.2,  # drop context (mel) for decoder
        down_sample_factor: int = 8,  # down sample factor for vq
        vq_emb_dim: int = 14,  # codebook size 2^vq_emb_dim, 2^14 = 16384
        use_text_cond: bool = True,  # use text cond for decoder
        text_vocab_size: int = 32100,  # vocab size
        cond_dim: int = 1024,
        cond_scale_factor: int = 1,
        sigma: float = 1e-5,
        time_scheduler: str = "linear",
        use_vq: bool = True,
        vq_type: str = "bsq",
        use_repa_loss: bool = False,
        cfg: Optional[Any] = None,
    ):
        """
        Initializes the TaDiCodec model.

        Args:
            mel_dim (int): Dimension of the mel-spectrogram.
            in_dim (int): Dimension of the encoder's input features.
            hidden_size (int): Hidden size of the transformer models.
            encoder_num_layers (int): Number of layers in the encoder transformer.
            decoder_num_layers (int): Number of layers in the decoder transformer.
            num_heads (int): Number of attention heads in the transformers.
            cond_drop_p (float): Dropout probability for the VQ code condition in the decoder.
            context_drop_p (float): Dropout probability for the prompt context in the decoder.
            down_sample_factor (int): Factor by which to downsample the latent representation before VQ.
            vq_emb_dim (int): Dimension of the vector quantizer's embedding space.
            use_text_cond (bool): Whether to use text embeddings as a condition in the decoder.
            text_vocab_size (int): Size of the text vocabulary.
            cond_dim (int): Dimension of the conditional input.
            cond_scale_factor (int): Scaling factor for the condition.
            sigma (float): Small constant used in the flow matching formula.
            time_scheduler (str): Type of time scheduler for diffusion (e.g., 'linear').
            use_vq (bool): Whether to use vector quantization.
            vq_type (str): Type of vector quantizer ('bsq' or 'simple').
            use_repa_loss (bool): Whether to use the representational alignment loss.
            cfg (Optional[Any]): A configuration object that can override the default parameters.
        """
        super().__init__()

        # Override parameters with values from the config object if provided
        mel_dim = (
            cfg.mel_dim if cfg is not None and hasattr(cfg, "mel_dim") else mel_dim
        )
        in_dim = cfg.in_dim if cfg is not None and hasattr(cfg, "in_dim") else in_dim
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        encoder_num_layers = (
            cfg.encoder_num_layers
            if cfg is not None and hasattr(cfg, "encoder_num_layers")
            else encoder_num_layers
        )
        decoder_num_layers = (
            cfg.decoder_num_layers
            if cfg is not None and hasattr(cfg, "decoder_num_layers")
            else decoder_num_layers
        )
        num_heads = (
            cfg.num_heads
            if cfg is not None and hasattr(cfg, "num_heads")
            else num_heads
        )
        cond_drop_p = (
            cfg.cond_drop_p
            if cfg is not None and hasattr(cfg, "cond_drop_p")
            else cond_drop_p
        )
        context_drop_p = (
            cfg.context_drop_p
            if cfg is not None and hasattr(cfg, "context_drop_p")
            else context_drop_p
        )
        down_sample_factor = (
            cfg.down_sample_factor
            if cfg is not None and hasattr(cfg, "down_sample_factor")
            else down_sample_factor
        )
        vq_emb_dim = (
            cfg.vq_emb_dim
            if cfg is not None and hasattr(cfg, "vq_emb_dim")
            else vq_emb_dim
        )
        use_text_cond = (
            cfg.use_text_cond
            if cfg is not None and hasattr(cfg, "use_text_cond")
            else use_text_cond
        )
        text_vocab_size = (
            cfg.text_vocab_size
            if cfg is not None and hasattr(cfg, "text_vocab_size")
            else text_vocab_size
        )
        cond_dim = (
            cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
        )
        cond_scale_factor = (
            cfg.cond_scale_factor
            if cfg is not None and hasattr(cfg, "cond_scale_factor")
            else cond_scale_factor
        )
        sigma = cfg.sigma if cfg is not None and hasattr(cfg, "sigma") else sigma
        time_scheduler = (
            cfg.time_scheduler
            if cfg is not None and hasattr(cfg, "time_scheduler")
            else time_scheduler
        )
        use_vq = cfg.use_vq if cfg is not None and hasattr(cfg, "use_vq") else use_vq
        vq_type = (
            cfg.vq_type if cfg is not None and hasattr(cfg, "vq_type") else vq_type
        )
        use_repa_loss = (
            cfg.use_repa_loss
            if cfg is not None and hasattr(cfg, "use_repa_loss")
            else use_repa_loss
        )

        self.mel_dim = mel_dim
        self.in_dim = in_dim
        self.hidden_size = hidden_size
        self.encoder_num_layers = encoder_num_layers
        self.decoder_num_layers = decoder_num_layers
        self.num_heads = num_heads
        self.cond_drop_p = cond_drop_p
        self.context_drop_p = context_drop_p
        self.vq_emb_dim = vq_emb_dim
        self.down_sample_factor = down_sample_factor
        self.use_text_cond = use_text_cond
        self.text_vocab_size = text_vocab_size
        self.cond_dim = cond_dim
        self.cond_scale_factor = cond_scale_factor
        self.sigma = sigma
        self.time_scheduler = time_scheduler
        self.use_vq = use_vq
        self.vq_type = vq_type
        self.use_repa_loss = use_repa_loss

        # Text embedding layer
        if self.use_text_cond:
            self.text_emb = nn.Embedding(text_vocab_size, hidden_size)

        # VQ related layers
        self.vq_in_linear = nn.Linear(hidden_size, vq_emb_dim)
        if self.use_vq:
            if self.vq_type == "bsq":
                self.bsq = BinarySphericalQuantizer(embed_dim=vq_emb_dim)
            else:
                self.bsq = SimpleQuantizer(embed_dim=vq_emb_dim)
        self.vq_out_linear = nn.Linear(vq_emb_dim, hidden_size)

        # Repa (Representational Alignment) MLP for auxiliary loss
        if self.use_repa_loss:
            self.repa_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 2),
                nn.GELU(),
                nn.Linear(hidden_size * 2, 1024),
            )
            self.repa_layer_idx = (
                6  # The decoder layer from which to extract hidden states
            )

        self.reset_parameters()

        # Encoder: A non-autoregressive Llama-style model without time/text conditioning
        self.encoder = DiffLlamaPrefix(
            hidden_size=hidden_size,
            num_layers=encoder_num_layers,
            num_heads=num_heads,
            in_dim=self.in_dim,
            out_dim=None,  # Outputs hidden states for VQ
            use_text_emb=False,
            use_diff_step=False,
            use_cond=False,
        )

        # Decoder: A non-autoregressive Llama-style model with time, text, and code conditioning
        self.decoder = DiffLlamaPrefix(
            hidden_size=hidden_size,
            num_layers=decoder_num_layers,
            num_heads=num_heads,
            in_dim=self.mel_dim,
            out_dim=self.mel_dim,
            use_text_emb=use_text_cond,
            use_diff_step=True,
            use_cond=True,
        )

    @torch.no_grad()
    def forward_diffusion(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Performs the forward diffusion process based on flow matching.
        It takes the clean data `x` and a timestep `t` to produce a noisy sample `xt`.
        It also creates a prompt/target mask.

        Args:
            x (torch.Tensor): The clean input data (e.g., mel-spectrogram), shape `(B, T, mel_dim)`.
            t (torch.Tensor): The diffusion timestep for each sample in the batch, shape `(B,)`.

        Returns:
            Tuple[torch.Tensor, ...]:
                - xt (torch.Tensor): The noisy sample at time `t`, shape `(B, T, mel_dim)`.
                - z (torch.Tensor): The noise vector used, drawn from N(0, I).
                - new_t (torch.Tensor): The original `t` tensor.
                - prompt_len (torch.Tensor): The length of the prompt for each sample.
                - mask (torch.Tensor): A mask where 1 indicates the target (noisy) region and 0 indicates the prompt (clean) region.
        """
        new_t = t
        t = t.unsqueeze(-1).unsqueeze(-1)  # Reshape for broadcasting
        z = torch.randn(
            x.shape, dtype=x.dtype, device=x.device, requires_grad=False
        )  # (B, T, mel_dim)

        context_drop_p = self.context_drop_p

        # Randomly decide the length of the prompt (un-noised context)
        if torch.rand(1) > context_drop_p:
            prompt_len = torch.randint(
                min(x.shape[1] // 4, 5), int(x.shape[1] * 0.4), (x.shape[0],)
            ).to(x.device)
        else:
            # Drop the context entirely by setting prompt length to 0
            prompt_len = torch.zeros(x.shape[0], device=x.device)

        # Create a mask to distinguish prompt from target
        is_prompt = torch.zeros_like(x[:, :, 0])  # (B, T)
        col_indices = torch.arange(is_prompt.shape[1], device=prompt_len.device).repeat(
            is_prompt.shape[0], 1
        )  # (B, T)
        is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1  # 1 if it's a prompt frame

        mask = torch.ones_like(x[:, :, 0])  # Mask is 1 for target, 0 for prompt
        mask[is_prompt.bool()] = 0
        mask = mask.unsqueeze(-1)  # (B, T, 1)

        # Flow matching formula: xt = (1 - (1 - sigma) * t) * x0 + t * x
        # where x0 ~ N(0, 1) and x is the clean data sample.
        # The equation is applied only to the target region.
        xt = ((1 - (1 - self.sigma) * t) * z + t * x) * mask + x * (1 - mask)

        return xt, z, new_t, prompt_len, mask

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        x_in: Optional[torch.Tensor] = None,
        text_ids: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        The main training-time forward pass of the model.

        Args:
            x (torch.Tensor): The target mel-spectrogram, shape `(B, T, mel_dim)`.
            x_mask (torch.Tensor): Padding mask for `x`, shape `(B, T)`.
            x_in (Optional[torch.Tensor]): Optional input for the encoder (e.g., SSL features). If None, `x` is used.
            text_ids (Optional[torch.Tensor]): Input text token IDs for conditioning, shape `(B, text_len)`.
            text_mask (Optional[torch.Tensor]): Padding mask for `text_ids`.

        Returns:
            Dict[str, Optional[torch.Tensor]]: A dictionary containing various tensors for loss computation,
                such as the predicted flow, the final predicted mel, the noise target, etc.
        """
        # 1. Encoder pass
        if x_in is None:
            # Use target mel as encoder input
            vq_emb_pre = self.encoder(x=x, x_mask=x_mask)
        else:
            # Use provided SSL features as encoder input
            vq_emb_pre = self.encoder(x=x_in, x_mask=x_mask)

        # 2. Downsampling before VQ
        _, T, _ = vq_emb_pre.shape
        vq_emb_pre = vq_emb_pre.transpose(1, 2)
        vq_emb_pre = F.interpolate(
            vq_emb_pre, size=T // self.down_sample_factor, mode="linear"
        )
        vq_emb_pre = vq_emb_pre.transpose(1, 2)

        # 3. Vector Quantization
        vq_emb_pre = self.vq_in_linear(vq_emb_pre)
        vq_emb_pre = F.normalize(vq_emb_pre, dim=-1)  # L2 normalize before quantization

        if self.use_vq:
            vq_emb, vq_loss, info = self.bsq(vq_emb_pre)
            commit_loss = info["commit_loss"]
        else:
            vq_emb = vq_emb_pre
            vq_loss = torch.tensor(0.0, device=x.device)
            commit_loss = torch.tensor(0.0, device=x.device)
            info = None

        vq_emb_post = self.vq_out_linear(vq_emb)

        # 4. Upsampling after VQ
        vq_emb_post = vq_emb_post.transpose(1, 2)
        vq_emb_post = F.interpolate(vq_emb_post, size=T, mode="linear")
        vq_emb_post = vq_emb_post.transpose(1, 2)

        # 5. Decoder with flow matching
        # Sample a random timestep t for each item in the batch
        t = torch.rand(x.shape[0], device=x.device, requires_grad=False)
        t = torch.clamp(t, 1e-5, 1.0)  # Clamp to avoid numerical issues at boundaries

        # Perform forward diffusion to get the noisy input `xt` and the noise `z`
        xt, z, new_t, prompt_len, mask = self.forward_diffusion(x, t)
        noise = z

        # 6. Prepare conditions for the decoder
        if self.use_text_cond:
            text_emb = self.text_emb(text_ids)
        else:
            text_emb = None

        # Use the upsampled VQ embedding as the primary condition
        cond_emb = vq_emb_post
        # Apply condition dropout for classifier-free guidance training
        if torch.rand(1) < self.cond_drop_p:
            cond_emb = torch.zeros_like(cond_emb)

        # 7. Decoder pass
        if self.use_repa_loss:
            # If using Repa loss, we need to output intermediate hidden states
            flow_pred, hidden_states = self.decoder(
                x=xt,
                x_mask=x_mask,
                text_embedding=text_emb,
                text_mask=text_mask,
                cond=cond_emb,
                diffusion_step=new_t,
                output_hidden_states=True,
            )
            ssl_feat_pred = self.repa_mlp(hidden_states[self.repa_layer_idx])
        else:
            flow_pred = self.decoder(
                x=xt,
                x_mask=x_mask,
                text_embedding=text_emb,
                text_mask=text_mask,
                cond=cond_emb,
                diffusion_step=new_t,
            )
            ssl_feat_pred = None

        # Predict the clean data from the noisy input `xt` and the predicted flow:
        # x_pred = xt + (1 - t) * flow_pred
        x_pred = xt + (1 - t.unsqueeze(-1).unsqueeze(-1)) * flow_pred

        # Final mask should consider both the prompt/target mask and the original padding mask
        final_mask = mask * x_mask.unsqueeze(-1)

        return {
            "noise": noise,
            "x": x,
            "flow_pred": flow_pred,
            "x_pred": x_pred,
            "final_mask": final_mask,
            "prompt_len": prompt_len,
            "vq_loss": vq_loss,
            "commit_loss": commit_loss,
            "ssl_feat_pred": ssl_feat_pred,
        }

    @torch.no_grad()
    def encode(
        self, x: torch.Tensor, x_mask: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes an input feature `x` into discrete VQ codes. (Inference)

        Args:
            x (torch.Tensor): Input feature, shape `(B, T, in_dim)`.
            x_mask (torch.Tensor): Padding mask for `x`, shape `(B, T)`.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - vq_emb (torch.Tensor): The quantized continuous embeddings.
                - indices (Optional[torch.Tensor]): The discrete code indices, if VQ is used.
        """
        # Encoder pass
        vq_emb_pre = self.encoder(x=x, x_mask=x_mask)

        # Downsampling
        _, T, _ = vq_emb_pre.shape
        vq_emb_pre = vq_emb_pre.transpose(1, 2)
        vq_emb_pre = F.interpolate(
            vq_emb_pre, size=T // self.down_sample_factor, mode="linear"
        )
        vq_emb_pre = vq_emb_pre.transpose(1, 2)

        # VQ
        vq_emb_pre = self.vq_in_linear(vq_emb_pre)
        vq_emb_pre = F.normalize(vq_emb_pre, dim=-1)  # L2 norm

        if self.use_vq:
            vq_emb, _, info = self.bsq(vq_emb_pre)
            indices = info["indices"]
        else:
            vq_emb = vq_emb_pre
            indices = None

        return vq_emb, indices

    @torch.no_grad()
    def index2vq(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Converts VQ code indices back to continuous embeddings.

        Args:
            indices (torch.Tensor): The discrete code indices.

        Returns:
            torch.Tensor: The corresponding continuous codebook embeddings.
        """
        return self.bsq.get_codebook_entry(indices).float()

    @torch.no_grad()
    def reverse_diffusion(
        self,
        vq_emb: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        text_ids: Optional[torch.Tensor] = None,
        prompt_mel: Optional[torch.Tensor] = None,
        x_mask: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        n_timesteps: int = 32,
        cfg: float = 1.0,
        rescale_cfg: float = 0.75,
    ) -> torch.Tensor:
        """
        Performs the reverse diffusion process to generate mel-spectrograms from conditions. (Inference)

        Args:
            vq_emb (Optional[torch.Tensor]): Pre-quantized embeddings. Takes precedence if provided.
            indices (Optional[torch.Tensor]): Discrete VQ code indices, used when `vq_emb` is None.
            text_ids (Optional[torch.Tensor]): Text token IDs for conditioning.
            prompt_mel (Optional[torch.Tensor]): A mel-spectrogram prompt.
            x_mask (Optional[torch.Tensor]): Padding mask for the target generation length.
            prompt_mask (Optional[torch.Tensor]): Padding mask for the prompt.
            text_mask (Optional[torch.Tensor]): Padding mask for the text.
            n_timesteps (int): Number of steps in the reverse diffusion process.
            cfg (float): Classifier-Free Guidance scale.
            rescale_cfg (float): Rescaling factor for CFG to prevent saturation.

        Returns:
            torch.Tensor: The generated mel-spectrogram.
        """
        if vq_emb is None:
            assert indices is not None, "Either vq_emb or indices must be provided"
            vq_emb = self.index2vq(indices.long())

        # Upsample VQ embeddings to match the target mel length
        vq_emb_post = self.vq_out_linear(vq_emb)
        vq_emb_post = vq_emb_post.transpose(1, 2)
        vq_emb_post = F.interpolate(
            vq_emb_post, scale_factor=self.down_sample_factor, mode="linear"
        )
        vq_emb_post = vq_emb_post.transpose(1, 2)

        # Prepare text embeddings
        if self.use_text_cond:
            text_emb = self.text_emb(text_ids)
            if text_mask is None:
                text_mask = torch.ones_like(text_ids)
        else:
            text_emb, text_mask = None, None

        cond_emb = vq_emb_post

        # Handle prompt
        if prompt_mel is None:
            prompt_mel = torch.zeros(
                cond_emb.shape[0], 0, self.mel_dim, device=cond_emb.device
            )

        prompt_len = prompt_mel.shape[1]
        target_len = cond_emb.shape[1] - prompt_len

        # Prepare masks
        if x_mask is None:
            x_mask = torch.ones(cond_emb.shape[0], target_len, device=cond_emb.device)
        if prompt_mask is None:
            prompt_mask = torch.ones(
                cond_emb.shape[0], prompt_len, device=cond_emb.device
            )

        xt_mask = torch.cat([prompt_mask, x_mask], dim=1)

        # Initialize with random noise
        z = torch.randn(
            (cond_emb.shape[0], target_len, self.mel_dim),
            dtype=cond_emb.dtype,
            device=cond_emb.device,
        )
        xt = z
        h = 1.0 / n_timesteps

        # Iterative denoising loop (Euler method)
        for i in range(n_timesteps):
            # Concatenate prompt and current noisy sample
            xt_input = torch.cat([prompt_mel, xt], dim=1)
            # Calculate current timestep
            t = (0 + (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )

            # Get conditional flow prediction
            flow_pred = self.decoder(
                x=xt_input,
                x_mask=xt_mask,
                text_embedding=text_emb,
                text_mask=text_mask,
                cond=cond_emb,
                diffusion_step=t,
            )
            flow_pred = flow_pred[
                :, prompt_len:, :
            ]  # Extract flow for the target region

            # Classifier-Free Guidance (CFG)
            if cfg > 0 and self.use_text_cond:
                # Get unconditional flow prediction by dropping conditions
                uncond_flow_pred = self.decoder(
                    x=xt_input,
                    x_mask=xt_mask,
                    text_embedding=None,  # Drop text
                    text_mask=None,
                    cond=torch.zeros_like(cond_emb),  # Drop code
                    diffusion_step=t,
                )
                uncond_flow_pred = uncond_flow_pred[:, prompt_len:, :]

                # Combine conditional and unconditional predictions
                flow_pred_cfg = uncond_flow_pred + cfg * (flow_pred - uncond_flow_pred)

                # Rescale to prevent saturation, as in Stable Diffusion
                if rescale_cfg > 0:
                    flow_pred_std = flow_pred.std()
                    cfg_std = flow_pred_cfg.std()
                    # Avoid division by zero
                    if cfg_std > 1e-6:
                        rescale_flow_pred = flow_pred_cfg * (flow_pred_std / cfg_std)
                        flow_pred = (
                            rescale_cfg * rescale_flow_pred
                            + (1 - rescale_cfg) * flow_pred_cfg
                        )
                    else:
                        flow_pred = flow_pred_cfg
                else:
                    flow_pred = flow_pred_cfg

            # Update the noisy sample
            dxt = flow_pred * h
            xt = xt + dxt

        return xt

    def reset_parameters(self):
        """
        Applies custom weight initialization to the model's submodules.
        """

        def _reset_parameters(m: nn.Module):
            if isinstance(m, nn.MultiheadAttention):
                if m._qkv_same_embed_dim:
                    nn.init.normal_(m.in_proj_weight, std=0.02)
                else:
                    nn.init.normal_(m.q_proj_weight, std=0.02)
                    nn.init.normal_(m.k_proj_weight, std=0.02)
                    nn.init.normal_(m.v_proj_weight, std=0.02)

                if m.in_proj_bias is not None:
                    nn.init.constant_(m.in_proj_bias, 0.0)
                    nn.init.constant_(m.out_proj.bias, 0.0)
                if m.bias_k is not None:
                    nn.init.xavier_normal_(m.bias_k)
                if m.bias_v is not None:
                    nn.init.xavier_normal_(m.bias_v)

            elif (
                isinstance(m, nn.Conv1d)
                or isinstance(m, nn.ConvTranspose1d)
                or isinstance(m, nn.Conv2d)
                or isinstance(m, nn.ConvTranspose2d)
            ):
                m.weight.data.normal_(0.0, 0.02)

            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

        self.apply(_reset_parameters)
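As a reading aid, the sketch below traces the inference path the docstrings describe: encode a mel-spectrogram to discrete codes, then run flow-matching reverse diffusion to reconstruct it. It is an illustrative assumption, not code from the commit: the tiny model sizes and tensor shapes are made up, and it presumes the quantizer classes imported above behave as documented.

```python
import torch
from models.tts.tadicodec.modeling_tadicodec import TaDiCodec

# Illustrative, untrained configuration (assumption): real checkpoints are much larger.
codec = TaDiCodec(
    mel_dim=128,
    in_dim=128,
    hidden_size=256,
    encoder_num_layers=2,
    decoder_num_layers=2,
    num_heads=4,
)
codec.eval()

mel = torch.randn(1, 64, 128)      # (B, T, mel_dim); T chosen divisible by down_sample_factor=8
mel_mask = torch.ones(1, 64)
text_ids = torch.randint(0, 32100, (1, 12))  # dummy text tokens

# Encode: 64 frames -> 64 // 8 = 8 discrete codes per utterance
_, indices = codec.encode(mel, mel_mask)

# Decode: Euler sampling over n_timesteps with text-conditioned classifier-free guidance
mel_rec = codec.reverse_diffusion(
    indices=indices,
    text_ids=text_ids,
    n_timesteps=32,
    cfg=2.0,
)
print(mel_rec.shape)  # (1, 64, 128), matching the original mel length
```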
requirements.txt
ADDED
@@ -0,0 +1,9 @@
gradio
torch
torchaudio
numpy
huggingface-hub
transformers
datasets
librosa
soundfile
utils/__pycache__/hparam.cpython-310.pyc
ADDED
Binary file (21.2 kB).
utils/__pycache__/util.cpython-310.pyc
ADDED
Binary file (19.4 kB).
utils/hparam.py
ADDED
@@ -0,0 +1,659 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
"""Hyperparameter values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numbers
import re
import six

# Define the regular expression for parsing a single clause of the input
# (delimited by commas). A legal clause looks like:
#   <variable name>[<index>]? = <rhs>
# where <rhs> is either a single token or [] enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]"
PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])?    # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*)              # single value: "a" or None
   |
   \[(?P<vals>[^\]]*)\])         # list of values: None or "1,2,3"
  ($|,\s*)""",
    re.VERBOSE,
)


def _parse_fail(name, var_type, value, values):
    """Helper function for raising a value error for bad assignment."""
    raise ValueError(
        "Could not parse hparam '%s' of type '%s' with value '%s' in %s"
        % (name, var_type.__name__, value, values)
    )


def _reuse_fail(name, values):
    """Helper function for raising a value error for reuse of name."""
    raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values))


def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
    """Update results_dictionary with a scalar value.

    Used to update the results_dictionary to be returned by parse_values when
    encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)

    Mutates results_dictionary.

    Args:
      name: Name of variable in assignment ("s" or "arr").
      parse_fn: Function for parsing the actual value.
      var_type: Type of named variable.
      m_dict: Dictionary constructed from regex parsing.
        m_dict['val']: RHS value (scalar)
        m_dict['index']: List index value (or None)
      values: Full expression being parsed
      results_dictionary: The dictionary being updated for return by the parsing
        function.

    Raises:
      ValueError: If the name has already been used.
    """
    try:
        parsed_value = parse_fn(m_dict["val"])
    except ValueError:
        _parse_fail(name, var_type, m_dict["val"], values)

    # If no index is provided
    if not m_dict["index"]:
        if name in results_dictionary:
            _reuse_fail(name, values)
        results_dictionary[name] = parsed_value
    else:
        if name in results_dictionary:
            # The name has already been used as a scalar, then it
            # will be in this dictionary and map to a non-dictionary.
            if not isinstance(results_dictionary.get(name), dict):
                _reuse_fail(name, values)
        else:
            results_dictionary[name] = {}

        index = int(m_dict["index"])
        # Make sure the index position hasn't already been assigned a value.
        if index in results_dictionary[name]:
            _reuse_fail("{}[{}]".format(name, index), values)
        results_dictionary[name][index] = parsed_value


def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
    """Update results_dictionary from a list of values.

    Used to update results_dictionary to be returned by parse_values when
    encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)

    Mutates results_dictionary.

    Args:
      name: Name of variable in assignment ("arr").
      parse_fn: Function for parsing individual values.
      var_type: Type of named variable.
      m_dict: Dictionary constructed from regex parsing.
        m_dict['val']: RHS value (scalar)
      values: Full expression being parsed
      results_dictionary: The dictionary being updated for return by the parsing
        function.

    Raises:
      ValueError: If the name has an index or the values cannot be parsed.
    """
    if m_dict["index"] is not None:
        raise ValueError("Assignment of a list to a list index.")
    elements = filter(None, re.split("[ ,]", m_dict["vals"]))
    # Make sure the name hasn't already been assigned a value
    if name in results_dictionary:
        raise _reuse_fail(name, values)
    try:
        results_dictionary[name] = [parse_fn(e) for e in elements]
    except ValueError:
        _parse_fail(name, var_type, m_dict["vals"], values)


def _cast_to_type_if_compatible(name, param_type, value):
    """Cast hparam to the provided type, if compatible.

    Args:
      name: Name of the hparam to be cast.
      param_type: The type of the hparam.
      value: The value to be cast, if compatible.

    Returns:
      The result of casting `value` to `param_type`.

    Raises:
      ValueError: If the type of `value` is not compatible with param_type.
        * If `param_type` is a string type, but `value` is not.
        * If `param_type` is a boolean, but `value` is not, or vice versa.
        * If `param_type` is an integer type, but `value` is not.
        * If `param_type` is a float type, but `value` is not a numeric type.
    """
    fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % (
        name,
        param_type,
        value,
    )

    # Some callers use None, for which we can't do any casting/checking. :(
    if issubclass(param_type, type(None)):
        return value

    # Avoid converting a non-string type to a string.
    if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance(
        value, (six.string_types, six.binary_type)
    ):
        raise ValueError(fail_msg)

    # Avoid converting a number or string type to a boolean or vice versa.
    if issubclass(param_type, bool) != isinstance(value, bool):
        raise ValueError(fail_msg)

    # Avoid converting float to an integer (the reverse is fine).
    if issubclass(param_type, numbers.Integral) and not isinstance(
        value, numbers.Integral
    ):
        raise ValueError(fail_msg)

    # Avoid converting a non-numeric type to a numeric type.
    if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number):
        raise ValueError(fail_msg)

    return param_type(value)


def parse_values(values, type_map, ignore_unknown=False):
    """Parses hyperparameter values from a string into a python map.

    `values` is a string containing comma-separated `name=value` pairs.
    For each pair, the value of the hyperparameter named `name` is set to
    `value`.

    If a hyperparameter name appears multiple times in `values`, a ValueError
    is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').

    If a hyperparameter name in both an index assignment and scalar assignment,
    a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').

    The hyperparameter name may contain '.' symbols, which will result in an
    attribute name that is only accessible through the getattr and setattr
    functions. (And must be first explicit added through add_hparam.)

    WARNING: Use of '.' in your variable names is allowed, but is not well
    supported and not recommended.

    The `value` in `name=value` must follows the syntax according to the
    type of the parameter:

    * Scalar integer: A Python-parsable integer point value. E.g.: 1,
      100, -12.
    * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
      -.54e89.
    * Boolean: Either true or false.
    * Scalar string: A non-empty sequence of characters, excluding comma,
      spaces, and square brackets. E.g.: foo, bar_1.
    * List: A comma separated list of scalar values of the parameter type
      enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].

    When index assignment is used, the corresponding type_map key should be the
    list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
    "arr[1]").

    Args:
      values: String. Comma separated list of `name=value` pairs where
        'value' must follow the syntax described above.
      type_map: A dictionary mapping hyperparameter names to types. Note every
        parameter name in values must be a key in type_map. The values must
        conform to the types indicated, where a value V is said to conform to a
        type T if either V has type T, or V is a list of elements of type T.
        Hence, for a multidimensional parameter 'x' taking float values,
        'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
      ignore_unknown: Bool. Whether values that are missing a type in type_map
        should be ignored. If set to True, a ValueError will not be raised for
        unknown hyperparameter type.

    Returns:
      A python map mapping each name to either:
      * A scalar value.
      * A list of scalar values.
      * A dictionary mapping index numbers to scalar values.
      (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")

    Raises:
      ValueError: If there is a problem with input.
      * If `values` cannot be parsed.
      * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
      * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
        'a[1]=1,a[1]=2', or 'a=1,a=[1]')
    """
    results_dictionary = {}
    pos = 0
    while pos < len(values):
        m = PARAM_RE.match(values, pos)
        if not m:
            raise ValueError("Malformed hyperparameter value: %s" % values[pos:])
        # Check that there is a comma between parameters and move past it.
        pos = m.end()
        # Parse the values.
        m_dict = m.groupdict()
        name = m_dict["name"]
        if name not in type_map:
            if ignore_unknown:
                continue
            raise ValueError("Unknown hyperparameter type for %s" % name)
        type_ = type_map[name]

        # Set up correct parsing function (depending on whether type_ is a bool)
        if type_ == bool:

            def parse_bool(value):
                if value in ["true", "True"]:
                    return True
                elif value in ["false", "False"]:
                    return False
                else:
                    try:
                        return bool(int(value))
                    except ValueError:
                        _parse_fail(name, type_, value, values)

            parse = parse_bool
        else:
            parse = type_

        # If a singe value is provided
        if m_dict["val"] is not None:
            _process_scalar_value(
                name, parse, type_, m_dict, values, results_dictionary
            )

        # If the assigned value is a list:
        elif m_dict["vals"] is not None:
            _process_list_value(name, parse, type_, m_dict, values, results_dictionary)

        else:  # Not assigned a list or value
            _parse_fail(name, type_, "", values)

    return results_dictionary


class HParams(object):
    """Class to hold a set of hyperparameters as name-value pairs.

    A `HParams` object holds hyperparameters used to build and train a model,
    such as the number of hidden units in a neural net layer or the learning rate
    to use when training.

    You first create a `HParams` object by specifying the names and values of the
    hyperparameters.

    To make them easily accessible the parameter names are added as direct
    attributes of the class. A typical usage is as follows:

    ```python
    # Create a HParams object specifying names and values of the model
    # hyperparameters:
    hparams = HParams(learning_rate=0.1, num_hidden_units=100)

    # The hyperparameter are available as attributes of the HParams object:
    hparams.learning_rate ==> 0.1
    hparams.num_hidden_units ==> 100
    ```

    Hyperparameters have type, which is inferred from the type of their value
    passed at construction type. The currently supported types are: integer,
    float, boolean, string, and list of integer, float, boolean, or string.

    You can override hyperparameter values by calling the
    [`parse()`](#HParams.parse) method, passing a string of comma separated
    `name=value` pairs. This is intended to make it possible to override
    any hyperparameter values from a single command-line flag to which
    the user passes 'hyper-param=value' pairs. It avoids having to define
    one flag for each hyperparameter.

    The syntax expected for each value depends on the type of the parameter.
    See `parse()` for a description of the syntax.

    Example:

    ```python
    # Define a command line flag to pass name=value pairs.
    # For example using argparse:
    import argparse
    parser = argparse.ArgumentParser(description='Train my model.')
    parser.add_argument('--hparams', type=str,
                        help='Comma separated list of "name=value" pairs.')
    args = parser.parse_args()
    ...
    def my_program():
      # Create a HParams object specifying the names and values of the
      # model hyperparameters:
      hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
                           activations=['relu', 'tanh'])

      # Override hyperparameters values by parsing the command line
      hparams.parse(args.hparams)

      # If the user passed `--hparams=learning_rate=0.3` on the command line
      # then 'hparams' has the following attributes:
      hparams.learning_rate ==> 0.3
      hparams.num_hidden_units ==> 100
      hparams.activations ==> ['relu', 'tanh']

      # If the hyperparameters are in json format use parse_json:
      hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
    ```
    """

    _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.

    def __init__(self, model_structure=None, **kwargs):
        """Create an instance of `HParams` from keyword arguments.

        The keyword arguments specify name-values pairs for the hyperparameters.
        The parameter types are inferred from the type of the values passed.

        The parameter names are added as attributes of `HParams` object, so they
        can be accessed directly with the dot notation `hparams._name_`.

        Example:

        ```python
        # Define 3 hyperparameters: 'learning_rate' is a float parameter,
        # 'num_hidden_units' an integer parameter, and 'activation' a string
        # parameter.
        hparams = tf.HParams(
            learning_rate=0.1, num_hidden_units=100, activation='relu')

        hparams.activation ==> 'relu'
        ```

        Note that a few names are reserved and cannot be used as hyperparameter
        names. If you use one of the reserved name the constructor raises a
        `ValueError`.

        Args:
          model_structure: An instance of ModelStructure, defining the feature
            crosses to be used in the Trial.
          **kwargs: Key-value pairs where the key is the hyperparameter name and
            the value is the value for the parameter.

        Raises:
          ValueError: If both `hparam_def` and initialization values are provided,
            or if one of the arguments is invalid.

        """
        # Register the hyperparameters and their type in _hparam_types.
        # This simplifies the implementation of parse().
        # _hparam_types maps the parameter name to a tuple (type, bool).
        # The type value is the type of the parameter for scalar hyperparameters,
        # or the type of the list elements for multidimensional hyperparameters.
        # The bool value is True if the value is a list, False otherwise.
        self._hparam_types = {}
        self._model_structure = model_structure
        for name, value in six.iteritems(kwargs):
            self.add_hparam(name, value)

    def add_hparam(self, name, value):
        """Adds {name, value} pair to hyperparameters.

        Args:
          name: Name of the hyperparameter.
          value: Value of the hyperparameter. Can be one of the following types:
            int, float, string, int list, float list, or string list.

        Raises:
          ValueError: if one of the arguments is invalid.
        """
        # Keys in kwargs are unique, but 'name' could the name of a pre-existing
        # attribute of this object. In that case we refuse to use it as a
        # hyperparameter name.
        if getattr(self, name, None) is not None:
            raise ValueError("Hyperparameter name is reserved: %s" % name)
        if isinstance(value, (list, tuple)):
            if not value:
                raise ValueError(
                    "Multi-valued hyperparameters cannot be empty: %s" % name
                )
            self._hparam_types[name] = (type(value[0]), True)
        else:
            self._hparam_types[name] = (type(value), False)
        setattr(self, name, value)

    def set_hparam(self, name, value):
        """Set the value of an existing hyperparameter.

        This function verifies that the type of the value matches the type of the
        existing hyperparameter.

        Args:
          name: Name of the hyperparameter.
          value: New value of the hyperparameter.

        Raises:
          KeyError: If the hyperparameter doesn't exist.
          ValueError: If there is a type mismatch.
        """
        param_type, is_list = self._hparam_types[name]
        if isinstance(value, list):
            if not is_list:
                raise ValueError(
                    "Must not pass a list for single-valued parameter: %s" % name
                )
            setattr(
                self,
                name,
                [_cast_to_type_if_compatible(name, param_type, v) for v in value],
            )
        else:
            if is_list:
                raise ValueError(
                    "Must pass a list for multi-valued parameter: %s." % name
                )
            setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))

    def del_hparam(self, name):
        """Removes the hyperparameter with key 'name'.

        Does nothing if it isn't present.

        Args:
          name: Name of the hyperparameter.
        """
        if hasattr(self, name):
            delattr(self, name)
            del self._hparam_types[name]

    def parse(self, values):
        """Override existing hyperparameter values, parsing new values from a string.

        See parse_values for more detail on the allowed format for values.

        Args:
          values: String. Comma separated list of `name=value` pairs where 'value'
            must follow the syntax described above.
|
490 |
+
|
491 |
+
Returns:
|
492 |
+
The `HParams` instance.
|
493 |
+
|
494 |
+
Raises:
|
495 |
+
ValueError: If `values` cannot be parsed or a hyperparameter in `values`
|
496 |
+
doesn't exist.
|
497 |
+
"""
|
498 |
+
type_map = {}
|
499 |
+
for name, t in self._hparam_types.items():
|
500 |
+
param_type, _ = t
|
501 |
+
type_map[name] = param_type
|
502 |
+
|
503 |
+
values_map = parse_values(values, type_map)
|
504 |
+
return self.override_from_dict(values_map)
|
505 |
+
|
506 |
+
def override_from_dict(self, values_dict):
|
507 |
+
"""Override existing hyperparameter values, parsing new values from a dictionary.
|
508 |
+
|
509 |
+
Args:
|
510 |
+
values_dict: Dictionary of name:value pairs.
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
The `HParams` instance.
|
514 |
+
|
515 |
+
Raises:
|
516 |
+
KeyError: If a hyperparameter in `values_dict` doesn't exist.
|
517 |
+
ValueError: If `values_dict` cannot be parsed.
|
518 |
+
"""
|
519 |
+
for name, value in values_dict.items():
|
520 |
+
self.set_hparam(name, value)
|
521 |
+
return self
|
522 |
+
|
523 |
+
def set_model_structure(self, model_structure):
|
524 |
+
self._model_structure = model_structure
|
525 |
+
|
526 |
+
def get_model_structure(self):
|
527 |
+
return self._model_structure
|
528 |
+
|
529 |
+
def to_json(self, indent=None, separators=None, sort_keys=False):
|
530 |
+
"""Serializes the hyperparameters into JSON.
|
531 |
+
|
532 |
+
Args:
|
533 |
+
indent: If a non-negative integer, JSON array elements and object members
|
534 |
+
will be pretty-printed with that indent level. An indent level of 0, or
|
535 |
+
negative, will only insert newlines. `None` (the default) selects the
|
536 |
+
most compact representation.
|
537 |
+
separators: Optional `(item_separator, key_separator)` tuple. Default is
|
538 |
+
`(', ', ': ')`.
|
539 |
+
sort_keys: If `True`, the output dictionaries will be sorted by key.
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
A JSON string.
|
543 |
+
"""
|
544 |
+
|
545 |
+
def remove_callables(x):
|
546 |
+
"""Omit callable elements from input with arbitrary nesting."""
|
547 |
+
if isinstance(x, dict):
|
548 |
+
return {
|
549 |
+
k: remove_callables(v)
|
550 |
+
for k, v in six.iteritems(x)
|
551 |
+
if not callable(v)
|
552 |
+
}
|
553 |
+
elif isinstance(x, list):
|
554 |
+
return [remove_callables(i) for i in x if not callable(i)]
|
555 |
+
return x
|
556 |
+
|
557 |
+
return json.dumps(
|
558 |
+
remove_callables(self.values()),
|
559 |
+
indent=indent,
|
560 |
+
separators=separators,
|
561 |
+
sort_keys=sort_keys,
|
562 |
+
)
|
563 |
+
|
564 |
+
def parse_json(self, values_json):
|
565 |
+
"""Override existing hyperparameter values, parsing new values from a json object.
|
566 |
+
|
567 |
+
Args:
|
568 |
+
values_json: String containing a json object of name:value pairs.
|
569 |
+
|
570 |
+
Returns:
|
571 |
+
The `HParams` instance.
|
572 |
+
|
573 |
+
Raises:
|
574 |
+
KeyError: If a hyperparameter in `values_json` doesn't exist.
|
575 |
+
ValueError: If `values_json` cannot be parsed.
|
576 |
+
"""
|
577 |
+
values_map = json.loads(values_json)
|
578 |
+
return self.override_from_dict(values_map)
|
579 |
+
|
580 |
+
def values(self):
|
581 |
+
"""Return the hyperparameter values as a Python dictionary.
|
582 |
+
|
583 |
+
Returns:
|
584 |
+
A dictionary with hyperparameter names as keys. The values are the
|
585 |
+
hyperparameter values.
|
586 |
+
"""
|
587 |
+
return {n: getattr(self, n) for n in self._hparam_types.keys()}
|
588 |
+
|
589 |
+
def get(self, key, default=None):
|
590 |
+
"""Returns the value of `key` if it exists, else `default`."""
|
591 |
+
if key in self._hparam_types:
|
592 |
+
# Ensure that default is compatible with the parameter type.
|
593 |
+
if default is not None:
|
594 |
+
param_type, is_param_list = self._hparam_types[key]
|
595 |
+
type_str = "list<%s>" % param_type if is_param_list else str(param_type)
|
596 |
+
fail_msg = (
|
597 |
+
"Hparam '%s' of type '%s' is incompatible with "
|
598 |
+
"default=%s" % (key, type_str, default)
|
599 |
+
)
|
600 |
+
|
601 |
+
is_default_list = isinstance(default, list)
|
602 |
+
if is_param_list != is_default_list:
|
603 |
+
raise ValueError(fail_msg)
|
604 |
+
|
605 |
+
try:
|
606 |
+
if is_default_list:
|
607 |
+
for value in default:
|
608 |
+
_cast_to_type_if_compatible(key, param_type, value)
|
609 |
+
else:
|
610 |
+
_cast_to_type_if_compatible(key, param_type, default)
|
611 |
+
except ValueError as e:
|
612 |
+
raise ValueError("%s. %s" % (fail_msg, e))
|
613 |
+
|
614 |
+
return getattr(self, key)
|
615 |
+
|
616 |
+
return default
|
617 |
+
|
618 |
+
def __contains__(self, key):
|
619 |
+
return key in self._hparam_types
|
620 |
+
|
621 |
+
def __str__(self):
|
622 |
+
return str(sorted(self.values().items()))
|
623 |
+
|
624 |
+
def __repr__(self):
|
625 |
+
return "%s(%s)" % (type(self).__name__, self.__str__())
|
626 |
+
|
627 |
+
@staticmethod
|
628 |
+
def _get_kind_name(param_type, is_list):
|
629 |
+
"""Returns the field name given parameter type and is_list.
|
630 |
+
|
631 |
+
Args:
|
632 |
+
param_type: Data type of the hparam.
|
633 |
+
is_list: Whether this is a list.
|
634 |
+
|
635 |
+
Returns:
|
636 |
+
A string representation of the field name.
|
637 |
+
|
638 |
+
Raises:
|
639 |
+
ValueError: If parameter type is not recognized.
|
640 |
+
"""
|
641 |
+
if issubclass(param_type, bool):
|
642 |
+
# This check must happen before issubclass(param_type, six.integer_types),
|
643 |
+
# since Python considers bool to be a subclass of int.
|
644 |
+
typename = "bool"
|
645 |
+
elif issubclass(param_type, six.integer_types):
|
646 |
+
# Setting 'int' and 'long' types to be 'int64' to ensure the type is
|
647 |
+
# compatible with both Python2 and Python3.
|
648 |
+
typename = "int64"
|
649 |
+
elif issubclass(param_type, (six.string_types, six.binary_type)):
|
650 |
+
# Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
|
651 |
+
# compatible with both Python2 and Python3.
|
652 |
+
typename = "bytes"
|
653 |
+
elif issubclass(param_type, float):
|
654 |
+
typename = "float"
|
655 |
+
else:
|
656 |
+
raise ValueError("Unsupported parameter type: %s" % str(param_type))
|
657 |
+
|
658 |
+
suffix = "list" if is_list else "value"
|
659 |
+
return "_".join([typename, suffix])
|
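Editorial note (not part of the commit): the sketch below shows how the `HParams` class above is typically driven — declare typed defaults once, then override them from a comma-separated string or a JSON string. It assumes only that `utils/hparam.py` is importable and that its dependencies (`six`, and the module-level `parse_values` helper defined earlier in the same file) behave as in the TensorFlow original this file derives from.

```python
# Hypothetical usage sketch for utils/hparam.py (illustration only).
from utils.hparam import HParams

# Defaults are declared once; the types (float, int, list of str) are inferred.
hparams = HParams(
    learning_rate=0.1, num_hidden_units=100, activations=["relu", "tanh"]
)

# Override from a command-line style "name=value" string ...
hparams.parse("learning_rate=0.3,num_hidden_units=200")

# ... or from a JSON string.
hparams.parse_json('{"learning_rate": 0.05}')

print(hparams.learning_rate)     # 0.05
print(hparams.num_hidden_units)  # 200
print(hparams.to_json())         # serialize the current values back to JSON
```
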
utils/util.py
ADDED
@@ -0,0 +1,687 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import collections
import glob
import os
import random
import time
import argparse
from collections import OrderedDict

import json5
import numpy as np
import glob
from torch.nn import functional as F


try:
    from ruamel.yaml import YAML as yaml
except:
    from ruamel_yaml import YAML as yaml  # type: ignore

import torch

from utils.hparam import HParams
import logging
from logging import handlers


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def find_checkpoint_of_mapper(mapper_ckpt_dir):
    mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt"))

    # Select the max steps
    mapper_ckpts.sort()
    mapper_weights_file = mapper_ckpts[-1]
    return mapper_weights_file


def pad_f0_to_tensors(f0s, batched=None):
    # Initialize
    tensors = []

    if batched == None:
        # Get the max frame for padding
        size = -1
        for f0 in f0s:
            size = max(size, f0.shape[-1])

        tensor = torch.zeros(len(f0s), size)

        for i, f0 in enumerate(f0s):
            tensor[i, : f0.shape[-1]] = f0[:]

        tensors.append(tensor)
    else:
        start = 0
        while start + batched - 1 < len(f0s):
            end = start + batched - 1

            # Get the max frame for padding
            size = -1
            for i in range(start, end + 1):
                size = max(size, f0s[i].shape[-1])

            tensor = torch.zeros(batched, size)

            for i in range(start, end + 1):
                tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]

            tensors.append(tensor)

            start = start + batched

        if start != len(f0s):
            end = len(f0s)

            # Get the max frame for padding
            size = -1
            for i in range(start, end):
                size = max(size, f0s[i].shape[-1])

            tensor = torch.zeros(len(f0s) - start, size)

            for i in range(start, end):
                tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]

            tensors.append(tensor)

    return tensors


def pad_mels_to_tensors(mels, batched=None):
    """
    Args:
        mels: A list of mel-specs
    Returns:
        tensors: A list of tensors containing the batched mel-specs
        mel_frames: A list of tensors containing the frames of the original mel-specs
    """
    # Initialize
    tensors = []
    mel_frames = []

    # Split mel-specs into batches to avoid cuda memory exceed
    if batched == None:
        # Get the max frame for padding
        size = -1
        for mel in mels:
            size = max(size, mel.shape[-1])

        tensor = torch.zeros(len(mels), mels[0].shape[0], size)
        mel_frame = torch.zeros(len(mels), dtype=torch.int32)

        for i, mel in enumerate(mels):
            tensor[i, :, : mel.shape[-1]] = mel[:]
            mel_frame[i] = mel.shape[-1]

        tensors.append(tensor)
        mel_frames.append(mel_frame)
    else:
        start = 0
        while start + batched - 1 < len(mels):
            end = start + batched - 1

            # Get the max frame for padding
            size = -1
            for i in range(start, end + 1):
                size = max(size, mels[i].shape[-1])

            tensor = torch.zeros(batched, mels[0].shape[0], size)
            mel_frame = torch.zeros(batched, dtype=torch.int32)

            for i in range(start, end + 1):
                tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
                mel_frame[i - start] = mels[i].shape[-1]

            tensors.append(tensor)
            mel_frames.append(mel_frame)

            start = start + batched

        if start != len(mels):
            end = len(mels)

            # Get the max frame for padding
            size = -1
            for i in range(start, end):
                size = max(size, mels[i].shape[-1])

            tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size)
            mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32)

            for i in range(start, end):
                tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
                mel_frame[i - start] = mels[i].shape[-1]

            tensors.append(tensor)
            mel_frames.append(mel_frame)

    return tensors, mel_frames

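Editorial note (not part of the commit): the two padding helpers above right-pad variable-length items to the longest item, optionally splitting the list into fixed-size batches; `pad_mels_to_tensors` additionally returns the original frame counts so the padding can be masked out later. A minimal sketch, assuming the repository root is on `PYTHONPATH` and using arbitrary shapes:

```python
# Hypothetical usage sketch for the padding helpers (illustration only).
import torch
from utils.util import pad_mels_to_tensors  # assumed import path

# Three mel-spectrograms with 80 bins and different frame counts.
mels = [torch.randn(80, n) for n in (120, 95, 143)]

# Without `batched`, everything is padded into a single (3, 80, 143) tensor.
tensors, mel_frames = pad_mels_to_tensors(mels)
print(tensors[0].shape)  # torch.Size([3, 80, 143])
print(mel_frames[0])     # tensor([120,  95, 143], dtype=torch.int32)

# With `batched=2`, the list is split into chunks of at most 2 items,
# each padded only to the longest item in its own chunk.
tensors, mel_frames = pad_mels_to_tensors(mels, batched=2)
print([t.shape for t in tensors])  # [torch.Size([2, 80, 120]), torch.Size([1, 80, 143])]
```
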
def load_model_config(args):
    """Load model configurations (in args.json under checkpoint directory)

    Args:
        args (ArgumentParser): arguments to run bins/preprocess.py

    Returns:
        dict: dictionary that stores model configurations
    """
    if args.checkpoint_dir is None:
        assert args.checkpoint_file is not None
        checkpoint_dir = os.path.split(args.checkpoint_file)[0]
    else:
        checkpoint_dir = args.checkpoint_dir
    config_path = os.path.join(checkpoint_dir, "args.json")
    print("config_path: ", config_path)

    config = load_config(config_path)
    return config


def remove_and_create(dir):
    if os.path.exists(dir):
        os.system("rm -r {}".format(dir))
    os.makedirs(dir, exist_ok=True)


def has_existed(path, warning=False):
    if not warning:
        return os.path.exists(path)

    if os.path.exists(path):
        answer = input(
            "The path {} has existed. \nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format(
                path
            )
        )
        if not answer == "n":
            return True

    return False


def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5):
    if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
        with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f:
            ckpts = [x.strip() for x in f.readlines()]
    else:
        ckpts = []
    ckpts.append(saved_model_name)
    for item in ckpts[:-max_to_keep]:
        if os.path.exists(os.path.join(checkpoint_dir, item)):
            os.remove(os.path.join(checkpoint_dir, item))
    with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
        for item in ckpts[-max_to_keep:]:
            f.write("{}\n".format(item))


def set_all_random_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def save_checkpoint(
    args,
    generator,
    g_optimizer,
    step,
    discriminator=None,
    d_optimizer=None,
    max_to_keep=5,
):
    saved_model_name = "model.ckpt-{}.pt".format(step)
    checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name)

    if discriminator and d_optimizer:
        torch.save(
            {
                "generator": generator.state_dict(),
                "discriminator": discriminator.state_dict(),
                "g_optimizer": g_optimizer.state_dict(),
                "d_optimizer": d_optimizer.state_dict(),
                "global_step": step,
            },
            checkpoint_path,
        )
    else:
        torch.save(
            {
                "generator": generator.state_dict(),
                "g_optimizer": g_optimizer.state_dict(),
                "global_step": step,
            },
            checkpoint_path,
        )

    print("Saved checkpoint: {}".format(checkpoint_path))

    if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")):
        with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f:
            ckpts = [x.strip() for x in f.readlines()]
    else:
        ckpts = []
    ckpts.append(saved_model_name)
    for item in ckpts[:-max_to_keep]:
        if os.path.exists(os.path.join(args.checkpoint_dir, item)):
            os.remove(os.path.join(args.checkpoint_dir, item))
    with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f:
        for item in ckpts[-max_to_keep:]:
            f.write("{}\n".format(item))


def attempt_to_restore(
    generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None
):
    checkpoint_list = os.path.join(checkpoint_dir, "checkpoint")
    if os.path.exists(checkpoint_list):
        checkpoint_filename = open(checkpoint_list).readlines()[-1].strip()
        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename))
        print("Restore from {}".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        if generator:
            if not list(generator.state_dict().keys())[0].startswith("module."):
                raw_dict = checkpoint["generator"]
                clean_dict = OrderedDict()
                for k, v in raw_dict.items():
                    if k.startswith("module."):
                        clean_dict[k[7:]] = v
                    else:
                        clean_dict[k] = v
                generator.load_state_dict(clean_dict)
            else:
                generator.load_state_dict(checkpoint["generator"])
        if g_optimizer:
            g_optimizer.load_state_dict(checkpoint["g_optimizer"])
        global_step = 100000
        if discriminator and "discriminator" in checkpoint.keys():
            discriminator.load_state_dict(checkpoint["discriminator"])
            global_step = checkpoint["global_step"]
            print("restore discriminator")
        if d_optimizer and "d_optimizer" in checkpoint.keys():
            d_optimizer.load_state_dict(checkpoint["d_optimizer"])
            print("restore d_optimizer...")
    else:
        global_step = 0
    return global_step

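Editorial note (not part of the commit): `save_checkpoint` and `attempt_to_restore` are meant to be used as a pair around the rolling `checkpoint` index file they both maintain. A minimal sketch under assumed arguments — the `args` namespace, directory name, and model/optimizer objects below are placeholders, not names from this repository:

```python
# Hypothetical usage sketch for the checkpoint helpers (illustration only).
import argparse
import os
import torch
from utils.util import save_checkpoint, attempt_to_restore  # assumed import path

# Placeholder model/optimizer; any nn.Module / Optimizer pair works the same way.
generator = torch.nn.Linear(4, 4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

args = argparse.Namespace(checkpoint_dir="exp/demo_ckpts")  # hypothetical directory
os.makedirs(args.checkpoint_dir, exist_ok=True)

# Writes model.ckpt-1000.pt and appends it to the "checkpoint" index file;
# entries beyond max_to_keep are deleted from disk.
save_checkpoint(args, generator, g_optimizer, step=1000)

# Later: reload the newest entry listed in the index file.
global_step = attempt_to_restore(generator, g_optimizer, args.checkpoint_dir)
print(global_step)
```

One quirk worth noting in the code above: when no discriminator is passed, `attempt_to_restore` returns a hard-coded `global_step` of 100000 rather than the step stored in the checkpoint.
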
class ExponentialMovingAverage(object):
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def update(self, name, x):
        assert name in self.shadow
        update_delta = self.shadow[name] - x
        self.shadow[name] -= (1.0 - self.decay) * update_delta


def apply_moving_average(model, ema):
    for name, param in model.named_parameters():
        if name in ema.shadow:
            ema.update(name, param.data)


def register_model_to_ema(model, ema):
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)


class YParams(HParams):
    def __init__(self, yaml_file):
        if not os.path.exists(yaml_file):
            raise IOError("yaml file: {} is not existed".format(yaml_file))
        super().__init__()
        self.d = collections.OrderedDict()
        with open(yaml_file) as fp:
            for _, v in yaml().load(fp).items():
                for k1, v1 in v.items():
                    try:
                        if self.get(k1):
                            self.set_hparam(k1, v1)
                        else:
                            self.add_hparam(k1, v1)
                        self.d[k1] = v1
                    except Exception:
                        import traceback

                        print(traceback.format_exc())

    # @property
    def get_elements(self):
        return self.d.items()


def override_config(base_config, new_config):
    """Update new configurations in the original dict with the new dict

    Args:
        base_config (dict): original dict to be overridden
        new_config (dict): dict with new configurations

    Returns:
        dict: updated configuration dict
    """
    for k, v in new_config.items():
        if type(v) == dict:
            if k not in base_config.keys():
                base_config[k] = {}
            base_config[k] = override_config(base_config[k], v)
        else:
            base_config[k] = v
    return base_config


def get_lowercase_keys_config(cfg):
    """Change all keys in cfg to lower case

    Args:
        cfg (dict): dictionary that stores configurations

    Returns:
        dict: dictionary that stores configurations
    """
    updated_cfg = dict()
    for k, v in cfg.items():
        if type(v) == dict:
            v = get_lowercase_keys_config(v)
        updated_cfg[k.lower()] = v
    return updated_cfg


def _load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether changing keys to lower case. Defaults to False.

    Returns:
        dict: dictionary that stores configurations
    """
    with open(config_fn, "r") as f:
        data = f.read()
    config_ = json5.loads(data)
    if "base_config" in config_:
        # load configurations from new path
        p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
        p_config_ = _load_config(p_config_path)
        config_ = override_config(p_config_, config_)
    if lowercase:
        # change keys in config_ to lower case
        config_ = get_lowercase_keys_config(config_)
    return config_


def load_config(config_fn, lowercase=False):
    """Load configurations into a dictionary

    Args:
        config_fn (str): path to configuration file
        lowercase (bool, optional): whether changing keys to lower case. Defaults to False.

    Returns:
        JsonHParams: an object that stores configurations
    """
    config_ = _load_config(config_fn, lowercase=lowercase)
    # create an JsonHParams object with configuration dict
    cfg = JsonHParams(**config_)
    return cfg


def save_config(save_path, cfg):
    """Save configurations into a json file

    Args:
        save_path (str): path to save configurations
        cfg (dict): dictionary that stores configurations
    """
    with open(save_path, "w") as f:
        json5.dump(
            cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
        )


class JsonHParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = JsonHParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()

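Editorial note (not part of the commit): `load_config` reads a JSON5 file, recursively merges any `base_config` it references (resolved against the `WORK_DIR` environment variable), and wraps the result in a `JsonHParams` object that supports both attribute and key access. A minimal sketch with a hypothetical config file name:

```python
# Hypothetical usage sketch for the config helpers (illustration only).
from utils.util import load_config, save_config  # assumed import path

# A hypothetical experiment config written as JSON5 (unquoted keys are allowed).
with open("exp_config.json", "w") as f:
    f.write("{ model: { hidden_size: 512 }, train: { lr: 2e-4, batch_size: 16 } }")

cfg = load_config("exp_config.json")

# Nested dicts become nested JsonHParams, so both access styles work:
print(cfg.model.hidden_size)        # 512
print(cfg["train"]["batch_size"])   # 16

# save_config serializes a plain dict back to (quoted, sorted) JSON5.
save_config("exp_config_saved.json", {"model": {"hidden_size": 512}})
```
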
class ValueWindow:
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        self._values = self._values[-(self._window_size - 1) :] + [x]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []


class Logger(object):
    def __init__(
        self,
        filename,
        level="info",
        when="D",
        backCount=10,
        fmt="%(asctime)s : %(message)s",
    ):
        self.level_relations = {
            "debug": logging.DEBUG,
            "info": logging.INFO,
            "warning": logging.WARNING,
            "error": logging.ERROR,
            "crit": logging.CRITICAL,
        }
        if level == "debug":
            fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)
        self.logger.setLevel(self.level_relations.get(level))
        sh = logging.StreamHandler()
        sh.setFormatter(format_str)
        th = handlers.TimedRotatingFileHandler(
            filename=filename, when=when, backupCount=backCount, encoding="utf-8"
        )
        th.setFormatter(format_str)
        self.logger.addHandler(sh)
        self.logger.addHandler(th)
        self.logger.info(
            "==========================New Starting Here=============================="
        )


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def get_current_time():
    pass


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
        lengths:
            A 1-D tensor containing sentence lengths.
        max_len:
            The length of masks.
    Returns:
        Return a 2-D bool tensor, where masked positions
        are filled with `True` and non-masked positions are
        filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expaned_lengths >= lengths.unsqueeze(-1)