Update train_boson_mixed_precision.py
train_boson_mixed_precision.py (CHANGED, +19 -21)
@@ -81,21 +81,20 @@ class AudioDataset(Dataset):
         audio_path = self.audio_paths[idx]

         try:
-
+
             audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)

-
+
             if len(audio) > self.segment_length:
                 if self.is_train:
                     start = random.randint(0, len(audio) - self.segment_length)
                 else:
-                    start = 0
+                    start = 0
                 audio = audio[start:start + self.segment_length]
             else:
                 # Pad if too short
                 audio = np.pad(audio, (0, self.segment_length - len(audio)))
-
-            # Convert to tensor and add batch dimension
+
             audio_tensor = torch.FloatTensor(audio).unsqueeze(0)

             return audio_tensor, audio_path
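Note on this hunk: the crop/pad logic reduces to the standalone sketch below (a hypothetical helper, not part of the script; it assumes segment_length is measured in samples). Random crops augment the training data, while the fixed start = 0 crop keeps validation segments reproducible across epochs.

    import random

    import numpy as np
    import torch

    def make_segment(audio: np.ndarray, segment_length: int, is_train: bool) -> torch.Tensor:
        # Crop or pad a 1-D waveform to exactly segment_length samples
        if len(audio) > segment_length:
            # Random crop for training, deterministic start=0 crop for validation
            start = random.randint(0, len(audio) - segment_length) if is_train else 0
            audio = audio[start:start + segment_length]
        else:
            # Zero-pad on the right if the clip is shorter than one segment
            audio = np.pad(audio, (0, segment_length - len(audio)))
        # Shape (1, segment_length): the channel dimension the model expects
        return torch.FloatTensor(audio).unsqueeze(0)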
@@ -242,7 +241,7 @@ class BosonTrainer:
         # return discriminator

     def build_discriminator(self):
-
+
         discriminator = Discriminator(
             rates=[],  # No multi-rate discriminator
             periods=[2, 3, 5, 7, 11],
@@ -257,7 +256,7 @@ class BosonTrainer:
         return discriminator

     def setup_losses(self):
-
+
         # Basic losses
         self.l1_loss = L1Loss()
         self.stft_loss = MultiScaleSTFTLoss(
@@ -277,16 +276,15 @@ class BosonTrainer:
             log_weight=1.0,
         )

-
+
         if self.discriminator is not None:
             self.gan_loss = GANLoss(self.discriminator)

-
+
         self.loss_weights = {
             'rec': 1.,       # Waveform L1 loss
             'stft': 1.,      # Multi-scale STFT loss
-            'mel': 45.0,     # Mel-spectrogram loss
-            #'mel': 0.0,     # Mel-spectrogram loss (DISABLED)
+            'mel': 45.0,     # Mel-spectrogram loss
             'commit': 0.25,  # Commitment loss
             'semantic': 1.,  # Semantic loss
             'gen': 1.,       # Generator adversarial loss
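How these weights enter the objective is outside this hunk; the usual reduction is a weighted sum, sketched below (the function name and reduction are assumptions, not this trainer's exact code). The heavy 45.0 on the mel term is the same value HiFi-GAN uses for its mel-spectrogram loss.

    import torch

    def total_loss(losses: dict, weights: dict) -> torch.Tensor:
        # Weighted sum over whichever component losses were computed this step
        return sum(weights[k] * v for k, v in losses.items() if k in weights)

    # Example with the weights from this hunk
    weights = {'rec': 1., 'stft': 1., 'mel': 45.0, 'commit': 0.25, 'semantic': 1., 'gen': 1.}
    losses = {k: torch.tensor(0.1) for k in weights}
    print(total_loss(losses, weights))  # tensor(4.9250)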
@@ -294,7 +292,7 @@ class BosonTrainer:
         }

     def setup_data_loaders(self):
-
+
         # Split data into train/val
         df = pd.read_csv(self.args.data_csv)
         n_total = len(df)
@@ -308,7 +306,7 @@ class BosonTrainer:
         df[:n_train].to_csv(train_csv, index=False)
         df[n_train:].to_csv(val_csv, index=False)

-
+
         if self.distributed:
             dist.barrier()

@@ -392,11 +390,11 @@ class BosonTrainer:
             output, commit_loss, semantic_loss, _ = self.model(audio, bw)
             recons_signal = AudioSignal(output, self.config['sample_rate'])

-
+
             use_discriminator = (self.discriminator is not None and
                                  self.global_step >= self.args.discriminator_start_step)

-
+
             if use_discriminator and self.global_step % self.args.disc_interval == 0:
                 self.optimizer_d.zero_grad()

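The two gates above implement a common GAN training recipe: the discriminator stays idle until the generator has warmed up, then updates only every few steps. In isolation (a sketch; the parameter names mirror the trainer's args):

    def should_update_discriminator(global_step: int, has_discriminator: bool,
                                    discriminator_start_step: int, disc_interval: int) -> bool:
        # Idle until the warm-up threshold, then update D every disc_interval steps
        if not has_discriminator or global_step < discriminator_start_step:
            return False
        return global_step % disc_interval == 0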
@@ -426,7 +424,7 @@ class BosonTrainer:
             losses['rec'] = self.l1_loss(recons_signal, audio_signal)
             losses['stft'] = self.stft_loss(recons_signal, audio_signal)
             losses['mel'] = self.mel_loss(recons_signal, audio_signal)
-            # losses['mel'] = torch.tensor(0.0, device=self.device) #
+            # losses['mel'] = torch.tensor(0.0, device=self.device) # uncomment this for the first 30k steps, it's faster if you pretrain it on semantic / commit loss first
             losses['commit'] = commit_loss
             losses['semantic'] = semantic_loss

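The extended comment suggests zeroing the mel loss for roughly the first 30k steps and pretraining on the semantic/commit terms. The same toggle can be written as a step gate instead of a comment edit (a sketch; the 30k threshold comes from the comment, the helper itself is hypothetical):

    import torch

    MEL_WARMUP_STEPS = 30_000  # from the comment: pretrain on semantic/commit loss first

    def mel_term(mel_loss_fn, recons_signal, audio_signal, global_step: int, device) -> torch.Tensor:
        # Skip the expensive mel loss early in training, matching the
        # commented-out torch.tensor(0.0) toggle in the hunk above
        if global_step < MEL_WARMUP_STEPS:
            return torch.tensor(0.0, device=device)
        return mel_loss_fn(recons_signal, audio_signal)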
@@ -513,7 +511,7 @@ class BosonTrainer:
             'commit': 0, 'semantic': 0
         }

-
+
         audio_samples = {'train': [], 'val': []}

         for batch_idx, (audio, paths) in enumerate(tqdm(self.val_loader, desc='Validation', disable=not self.is_main_process())):
@@ -712,7 +710,7 @@ class BosonTrainer:


     def load_checkpoint(self):
-
+
         checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth')
         if os.path.exists(checkpoint_path):
             print(f"Loading checkpoint from {checkpoint_path}")
@@ -734,7 +732,7 @@ class BosonTrainer:
             if 'scheduler_g_last_epoch' in checkpoint:
                 self.scheduler_g.last_epoch = checkpoint['scheduler_g_last_epoch']
             else:
-
+
                 self.scheduler_g.last_epoch = checkpoint['global_step']

             # Force scheduler to recompute its internal state
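The else-branch fallback assumes one scheduler step per optimizer step, so global_step is a usable stand-in for last_epoch on older checkpoints that predate the saved counter. The same logic, as a hypothetical one-liner:

    def restore_last_epoch(scheduler, checkpoint: dict, key: str = 'scheduler_g_last_epoch') -> None:
        # Prefer the explicitly saved counter; fall back to global_step for
        # checkpoints written before the counter was added
        scheduler.last_epoch = checkpoint.get(key, checkpoint['global_step'])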
@@ -761,7 +759,7 @@ class BosonTrainer:

                 self.scheduler_d._last_lr = self.scheduler_d.get_lr()

-
+
             if self.scaler_d is not None and 'scaler_d_state_dict' in checkpoint:
                 self.scaler_d.load_state_dict(checkpoint['scaler_d_state_dict'])

@@ -798,7 +796,7 @@ class BosonTrainer:
         print(f"Next step checkpoint at: step {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}")
         print(f"{'='*60}\n")

-
+
         if self.global_step > 0:
             temp_scheduler = CosineWarmupScheduler(
                 self.optimizer_g,
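The hunk ends before showing what temp_scheduler is for; together with the last_epoch handling above, it reads like a resume pattern where the step counter is jumped forward and the LR recomputed. A generic sketch of that pattern with a stock closed-form scheduler (LambdaLR stands in for the script's CosineWarmupScheduler, whose signature is not visible here; the warm-up and total lengths are made up):

    import math

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    WARMUP, TOTAL = 5_000, 100_000  # hypothetical schedule lengths

    def lr_lambda(step: int) -> float:
        if step < WARMUP:  # linear warm-up
            return step / max(1, WARMUP)
        t = (step - WARMUP) / max(1, TOTAL - WARMUP)
        return 0.5 * (1.0 + math.cos(math.pi * t))  # cosine decay

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    sched = LambdaLR(opt, lr_lambda)

    # On resume: jump the step counter, then step once so the LR is recomputed,
    # mirroring the last_epoch / _last_lr handling in the hunks above
    global_step = 50_000  # as restored from the checkpoint
    sched.last_epoch = global_step - 1
    sched.step()
    print(sched.get_last_lr())  # LR the resumed run continues from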