Update TTS/tts/layers/xtts/dvae.py
Browse files- TTS/tts/layers/xtts/dvae.py +10 -18
TTS/tts/layers/xtts/dvae.py
CHANGED
@@ -24,9 +24,7 @@ def eval_decorator(fn):
|
|
24 |
return inner
|
25 |
|
26 |
|
27 |
-
def dvae_wav_to_mel(
|
28 |
-
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
|
29 |
-
):
|
30 |
mel_stft = torchaudio.transforms.MelSpectrogram(
|
31 |
n_fft=1024,
|
32 |
hop_length=256,
|
@@ -44,7 +42,7 @@ def dvae_wav_to_mel(
|
|
44 |
mel = torch.log(torch.clamp(mel, min=1e-5))
|
45 |
if mel_norms is None:
|
46 |
mel_norms = torch.load(mel_norms_file, map_location=device)
|
47 |
-
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
|
48 |
return mel
|
49 |
|
50 |
|
@@ -112,7 +110,7 @@ class Quantize(nn.Module):
|
|
112 |
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
113 |
n = self.cluster_size.sum()
|
114 |
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
115 |
-
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
116 |
self.embed.data.copy_(embed_normalized)
|
117 |
|
118 |
diff = (quantize.detach() - input).pow(2).mean()
|
@@ -198,6 +196,7 @@ class UpsampledConv(nn.Module):
|
|
198 |
|
199 |
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
200 |
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
|
|
201 |
class DiscreteVAE(nn.Module):
|
202 |
def __init__(
|
203 |
self,
|
@@ -215,7 +214,7 @@ class DiscreteVAE(nn.Module):
|
|
215 |
activation="relu",
|
216 |
smooth_l1_loss=False,
|
217 |
straight_through=False,
|
218 |
-
normalization=None,
|
219 |
record_codes=False,
|
220 |
discretization_loss_averaging_steps=100,
|
221 |
lr_quantizer_args={},
|
@@ -231,7 +230,7 @@ class DiscreteVAE(nn.Module):
|
|
231 |
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
|
232 |
)
|
233 |
|
234 |
-
assert positional_dims > 0 and positional_dims < 3
|
235 |
if positional_dims == 2:
|
236 |
conv = nn.Conv2d
|
237 |
conv_transpose = nn.ConvTranspose2d
|
@@ -246,7 +245,7 @@ class DiscreteVAE(nn.Module):
|
|
246 |
elif activation == "silu":
|
247 |
act = nn.SiLU
|
248 |
else:
|
249 |
-
|
250 |
|
251 |
enc_layers = []
|
252 |
dec_layers = []
|
@@ -293,7 +292,6 @@ class DiscreteVAE(nn.Module):
|
|
293 |
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
294 |
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
295 |
|
296 |
-
# take care of normalization within class
|
297 |
self.normalization = normalization
|
298 |
self.record_codes = record_codes
|
299 |
if record_codes:
|
@@ -303,19 +301,18 @@ class DiscreteVAE(nn.Module):
|
|
303 |
self.internal_step = 0
|
304 |
|
305 |
def norm(self, images):
|
306 |
-
if
|
307 |
return images
|
308 |
|
309 |
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
310 |
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
311 |
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
312 |
images = images.clone()
|
313 |
-
images.sub_(means).div_(stds)
|
314 |
return images
|
315 |
|
316 |
def get_debug_values(self, step, __):
|
317 |
if self.record_codes and self.total_codes > 0:
|
318 |
-
# Report annealing schedule
|
319 |
return {"histogram_codes": self.codes[: self.total_codes]}
|
320 |
else:
|
321 |
return {}
|
@@ -356,9 +353,6 @@ class DiscreteVAE(nn.Module):
|
|
356 |
sampled, codes, commitment_loss = self.codebook(logits)
|
357 |
return self.decode(codes)
|
358 |
|
359 |
-
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
360 |
-
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
361 |
-
# more lossy (but useful for determining network performance).
|
362 |
def forward(self, img):
|
363 |
img = self.norm(img)
|
364 |
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
@@ -371,16 +365,13 @@ class DiscreteVAE(nn.Module):
|
|
371 |
out = d(out)
|
372 |
self.log_codes(codes)
|
373 |
else:
|
374 |
-
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
375 |
out, _ = self.decode(codes)
|
376 |
|
377 |
-
# reconstruction loss
|
378 |
recon_loss = self.loss_fn(img, out, reduction="none")
|
379 |
|
380 |
return recon_loss, commitment_loss, out
|
381 |
|
382 |
def log_codes(self, codes):
|
383 |
-
# This is so we can debug the distribution of codes being learned.
|
384 |
if self.record_codes and self.internal_step % 10 == 0:
|
385 |
codes = codes.flatten()
|
386 |
l = codes.shape[0]
|
@@ -391,3 +382,4 @@ class DiscreteVAE(nn.Module):
|
|
391 |
self.code_ind = 0
|
392 |
self.total_codes += 1
|
393 |
self.internal_step += 1
|
|
|
|
24 |
return inner
|
25 |
|
26 |
|
27 |
+
def dvae_wav_to_mel(wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")):
|
|
|
|
|
28 |
mel_stft = torchaudio.transforms.MelSpectrogram(
|
29 |
n_fft=1024,
|
30 |
hop_length=256,
|
|
|
42 |
mel = torch.log(torch.clamp(mel, min=1e-5))
|
43 |
if mel_norms is None:
|
44 |
mel_norms = torch.load(mel_norms_file, map_location=device)
|
45 |
+
mel = mel / (mel_norms.unsqueeze(0).unsqueeze(-1) + 1e-8) # Adicionando um valor pequeno para evitar divisão por zero
|
46 |
return mel
|
47 |
|
48 |
|
|
|
110 |
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
111 |
n = self.cluster_size.sum()
|
112 |
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
113 |
+
embed_normalized = self.embed_avg / (cluster_size.unsqueeze(0) + self.eps) # Adicionando eps para evitar divisão por zero
|
114 |
self.embed.data.copy_(embed_normalized)
|
115 |
|
116 |
diff = (quantize.detach() - input).pow(2).mean()
|
|
|
196 |
|
197 |
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
198 |
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
199 |
+
|
200 |
class DiscreteVAE(nn.Module):
|
201 |
def __init__(
|
202 |
self,
|
|
|
214 |
activation="relu",
|
215 |
smooth_l1_loss=False,
|
216 |
straight_through=False,
|
217 |
+
normalization=None,
|
218 |
record_codes=False,
|
219 |
discretization_loss_averaging_steps=100,
|
220 |
lr_quantizer_args={},
|
|
|
230 |
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
|
231 |
)
|
232 |
|
233 |
+
assert positional_dims > 0 and positional_dims < 3
|
234 |
if positional_dims == 2:
|
235 |
conv = nn.Conv2d
|
236 |
conv_transpose = nn.ConvTranspose2d
|
|
|
245 |
elif activation == "silu":
|
246 |
act = nn.SiLU
|
247 |
else:
|
248 |
+
raise NotImplementedError()
|
249 |
|
250 |
enc_layers = []
|
251 |
dec_layers = []
|
|
|
292 |
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
293 |
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
294 |
|
|
|
295 |
self.normalization = normalization
|
296 |
self.record_codes = record_codes
|
297 |
if record_codes:
|
|
|
301 |
self.internal_step = 0
|
302 |
|
303 |
def norm(self, images):
|
304 |
+
if self.normalization is None:
|
305 |
return images
|
306 |
|
307 |
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
308 |
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
309 |
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
310 |
images = images.clone()
|
311 |
+
images.sub_(means).div_(stds + 1e-8) # Adicionando eps para evitar divisão por zero
|
312 |
return images
|
313 |
|
314 |
def get_debug_values(self, step, __):
|
315 |
if self.record_codes and self.total_codes > 0:
|
|
|
316 |
return {"histogram_codes": self.codes[: self.total_codes]}
|
317 |
else:
|
318 |
return {}
|
|
|
353 |
sampled, codes, commitment_loss = self.codebook(logits)
|
354 |
return self.decode(codes)
|
355 |
|
|
|
|
|
|
|
356 |
def forward(self, img):
|
357 |
img = self.norm(img)
|
358 |
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
|
|
365 |
out = d(out)
|
366 |
self.log_codes(codes)
|
367 |
else:
|
|
|
368 |
out, _ = self.decode(codes)
|
369 |
|
|
|
370 |
recon_loss = self.loss_fn(img, out, reduction="none")
|
371 |
|
372 |
return recon_loss, commitment_loss, out
|
373 |
|
374 |
def log_codes(self, codes):
|
|
|
375 |
if self.record_codes and self.internal_step % 10 == 0:
|
376 |
codes = codes.flatten()
|
377 |
l = codes.shape[0]
|
|
|
382 |
self.code_ind = 0
|
383 |
self.total_codes += 1
|
384 |
self.internal_step += 1
|
385 |
+
|