|
import time |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import optim |
|
from torch.utils.data import DataLoader |
|
|
|
import vocoder.hparams as hp |
|
from vocoder.display import stream, simple_table |
|
from vocoder.distribution import discretized_mix_logistic_loss |
|
from vocoder.gen_wavernn import gen_testset |
|
from vocoder.models.fatchord_version import WaveRNN |
|
from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder |
|
|
|
|
|
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool, save_every: int,
          backup_every: int, force_restart: bool, *, epochs: int = 349):
    """Train a WaveRNN vocoder, resuming from ``models_dir/run_id/vocoder.pt`` when it exists.

    Args:
        run_id: Name of the run; weights and checkpoints are written under ``models_dir / run_id``.
        syn_dir: Synthesizer data directory. Always provides the target audio (``audio/``); when
            ``ground_truth`` is True it also provides the metadata (``train.txt``) and mels (``mels/``).
        voc_dir: Vocoder data directory. When ``ground_truth`` is False it provides the metadata
            (``synthesized.txt``) and mels (``mels_gta/``).
        models_dir: Root directory for runs. It is created, together with the run directory, if missing.
        ground_truth: Select which metadata/mel source to train on (see ``syn_dir`` / ``voc_dir``).
        save_every: Overwrite ``vocoder.pt`` every this many steps; 0 disables periodic saving.
        backup_every: Call ``model.checkpoint`` every this many steps; 0 disables backups.
        force_restart: Start from freshly initialized weights (saving them over ``vocoder.pt``)
            even if a weights file already exists.
        epochs: Number of epochs to run. The default of 349 reproduces the previously
            hard-coded ``range(1, 350)`` loop.
    """
    # The upsample factors must multiply out to the hop length; fail fast on an inconsistent hparams file.
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length, \
        "the product of hp.voc_upsample_factors must equal hp.hop_length"

    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )

    # Queried once; the same answer decides where every batch is moved.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=hp.voc_lr)
    # RAW mode is trained with cross entropy; any other mode uses the discretized
    # mixture-of-logistics loss.
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Either start fresh (persisting the initial weights) or resume from the saved weights.
    model_dir = models_dir / run_id
    model_dir.mkdir(parents=True, exist_ok=True)
    weights_fpath = model_dir / "vocoder.pt"
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        # NOTE(review): if model.load also restores the optimizer state, the learning rate stored
        # in the checkpoint replaces hp.voc_lr (which is still what the table below reports) — confirm.
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Pick the metadata and mel source; target audio always comes from syn_dir.
    if ground_truth:
        metadata_fpath = syn_dir / "train.txt"
        mel_dir = syn_dir / "mels"
    else:
        metadata_fpath = voc_dir / "synthesized.txt"
        mel_dir = voc_dir / "mels_gta"
    wav_dir = syn_dir / "audio"
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    # Built once: with shuffle=True the loader draws a fresh permutation on every pass,
    # so re-creating it each epoch bought nothing.
    data_loader = DataLoader(dataset, hp.voc_batch_size, shuffle=True, num_workers=2,
                             collate_fn=collate_vocoder)
    n_batches = len(data_loader)

    for epoch in range(1, epochs + 1):
        # perf_counter is monotonic and high-resolution, unlike time.time(), so it suits the
        # steps/s figure.
        start = time.perf_counter()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            if use_cuda:
                x, m, y = x.cuda(), m.cuda(), y.cuda()

            y_hat = model(x, m)
            if model.mode == 'RAW':
                # Swaps dims 1 and 2 and adds a trailing axis so the logits line up with the
                # unsqueezed targets for F.cross_entropy; assumes y_hat is (batch, time, classes)
                # — TODO confirm.
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                # The MOL loss works on float targets rather than integer class indices.
                y = y.float()
            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.perf_counter() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            # A period of 0 disables the corresponding periodic write.
            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{n_batches}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)

        # Produce test samples with the current weights at the end of every epoch.
        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                    hp.voc_target, hp.voc_overlap, model_dir)
        print("")
|
|