brunodox committed
Commit 0d56d2e · verified · 1 Parent(s): a41ac79

Update TTS/tts/layers/xtts/dvae.py

Files changed (1):
1. TTS/tts/layers/xtts/dvae.py +10 -18
TTS/tts/layers/xtts/dvae.py CHANGED
@@ -24,9 +24,7 @@ def eval_decorator(fn):
     return inner
 
 
-def dvae_wav_to_mel(
-    wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
-):
+def dvae_wav_to_mel(wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")):
     mel_stft = torchaudio.transforms.MelSpectrogram(
         n_fft=1024,
         hop_length=256,
@@ -44,7 +42,7 @@ def dvae_wav_to_mel(
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
         mel_norms = torch.load(mel_norms_file, map_location=device)
-    mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
+    mel = mel / (mel_norms.unsqueeze(0).unsqueeze(-1) + 1e-8)  # add a small value to avoid division by zero
     return mel
 
 
@@ -112,7 +110,7 @@ class Quantize(nn.Module):
             self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
             n = self.cluster_size.sum()
             cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            embed_normalized = self.embed_avg / (cluster_size.unsqueeze(0) + self.eps)  # add eps to avoid division by zero
             self.embed.data.copy_(embed_normalized)
 
         diff = (quantize.detach() - input).pow(2).mean()
@@ -198,6 +196,7 @@ class UpsampledConv(nn.Module):
 
 # DiscreteVAE partially derived from lucidrains DALLE implementation
 # Credit: https://github.com/lucidrains/DALLE-pytorch
+
 class DiscreteVAE(nn.Module):
     def __init__(
         self,
@@ -215,7 +214,7 @@ class DiscreteVAE(nn.Module):
         activation="relu",
         smooth_l1_loss=False,
         straight_through=False,
-        normalization=None,  # ((0.5,) * 3, (0.5,) * 3),
+        normalization=None,
         record_codes=False,
         discretization_loss_averaging_steps=100,
         lr_quantizer_args={},
@@ -231,7 +230,7 @@ class DiscreteVAE(nn.Module):
             num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
         )
 
-        assert positional_dims > 0 and positional_dims < 3  # This VAE only supports 1d and 2d inputs for now.
+        assert positional_dims > 0 and positional_dims < 3
         if positional_dims == 2:
             conv = nn.Conv2d
             conv_transpose = nn.ConvTranspose2d
@@ -246,7 +245,7 @@ class DiscreteVAE(nn.Module):
         elif activation == "silu":
             act = nn.SiLU
         else:
-            assert NotImplementedError()
+            raise NotImplementedError()
 
         enc_layers = []
         dec_layers = []
@@ -293,7 +292,6 @@ class DiscreteVAE(nn.Module):
         self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
         self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
 
-        # take care of normalization within class
         self.normalization = normalization
         self.record_codes = record_codes
         if record_codes:
@@ -303,19 +301,18 @@ class DiscreteVAE(nn.Module):
         self.internal_step = 0
 
     def norm(self, images):
-        if not self.normalization is not None:
+        if self.normalization is None:
             return images
 
         means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
         arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
         means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
         images = images.clone()
-        images.sub_(means).div_(stds)
+        images.sub_(means).div_(stds + 1e-8)  # add eps to avoid division by zero
         return images
 
     def get_debug_values(self, step, __):
         if self.record_codes and self.total_codes > 0:
-            # Report annealing schedule
             return {"histogram_codes": self.codes[: self.total_codes]}
         else:
             return {}
@@ -356,9 +353,6 @@ class DiscreteVAE(nn.Module):
         sampled, codes, commitment_loss = self.codebook(logits)
         return self.decode(codes)
 
-    # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
-    # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
-    # more lossy (but useful for determining network performance).
     def forward(self, img):
         img = self.norm(img)
         logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
@@ -371,16 +365,13 @@ class DiscreteVAE(nn.Module):
                 out = d(out)
             self.log_codes(codes)
         else:
-            # This is non-differentiable, but gives a better idea of how the network is actually performing.
             out, _ = self.decode(codes)
 
-        # reconstruction loss
         recon_loss = self.loss_fn(img, out, reduction="none")
 
         return recon_loss, commitment_loss, out
 
     def log_codes(self, codes):
-        # This is so we can debug the distribution of codes being learned.
        if self.record_codes and self.internal_step % 10 == 0:
             codes = codes.flatten()
             l = codes.shape[0]
@@ -391,3 +382,4 @@ class DiscreteVAE(nn.Module):
                 self.code_ind = 0
             self.total_codes += 1
         self.internal_step += 1
+
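
For context, the fix in dvae_wav_to_mel is the usual additive-epsilon guard: a zero anywhere in mel_norms would otherwise turn the division into inf and poison every downstream op. A minimal standalone sketch, with illustrative shapes and values that are not part of the file:

import torch

mel = torch.ones(1, 80, 10)   # (batch, n_mels, frames), made-up values
mel_norms = torch.zeros(80)   # worst case: zeros in the loaded norms

# Unguarded division propagates inf through every later op.
unguarded = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
print(torch.isinf(unguarded).any())    # tensor(True)

# The epsilon guard keeps the result finite.
guarded = mel / (mel_norms.unsqueeze(0).unsqueeze(-1) + 1e-8)
print(torch.isfinite(guarded).all())   # tensor(True)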
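The guard in Quantize covers a subtler corner: the smoothed cluster_size is multiplied back by n = self.cluster_size.sum(), so if no codes have been counted yet (n == 0) the divisor is exactly zero and embed_normalized goes non-finite. A hedged standalone sketch of that case, where the names mirror the class attributes but the values are illustrative:

import torch

eps, n_embed, dim = 1e-5, 8, 4
cluster_size = torch.zeros(n_embed)    # no codes assigned yet
embed_avg = torch.zeros(dim, n_embed)

n = cluster_size.sum()                                     # 0
smoothed = (cluster_size + eps) / (n + n_embed * eps) * n  # all zeros when n == 0
print(torch.isnan(embed_avg / smoothed.unsqueeze(0)).any())           # tensor(True), 0/0
print(torch.isnan(embed_avg / (smoothed.unsqueeze(0) + eps)).any())   # tensor(False)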
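The activation fallback change is a real bug fix rather than a cleanup: an exception instance is truthy in Python, so assert NotImplementedError() always passes silently, leaving act undefined for an unsupported activation. For example:

# An exception *instance* is truthy, so this assert never fires:
assert NotImplementedError()   # passes silently

# Raising actually aborts on an unsupported option:
raise NotImplementedError()    # NotImplementedError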