Respair committed on
Commit
bab13b0
·
verified ·
1 Parent(s): 6ff8078

Update train_boson_mixed_precision.py

Browse files
Files changed (1) hide show
  1. train_boson_mixed_precision.py +19 -21
train_boson_mixed_precision.py CHANGED
@@ -81,21 +81,20 @@ class AudioDataset(Dataset):
81
  audio_path = self.audio_paths[idx]
82
 
83
  try:
84
- # Load audio using librosa
85
  audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
86
 
87
- # Random segment extraction for training
88
  if len(audio) > self.segment_length:
89
  if self.is_train:
90
  start = random.randint(0, len(audio) - self.segment_length)
91
  else:
92
- start = 0 # Always use beginning for validation
93
  audio = audio[start:start + self.segment_length]
94
  else:
95
  # Pad if too short
96
  audio = np.pad(audio, (0, self.segment_length - len(audio)))
97
-
98
- # Convert to tensor and add batch dimension
99
  audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
100
 
101
  return audio_tensor, audio_path
@@ -242,7 +241,7 @@ class BosonTrainer:
242
  # return discriminator
243
 
244
  def build_discriminator(self):
245
- """Build discriminator with DDP if needed"""
246
  discriminator = Discriminator(
247
  rates=[], # No multi-rate discriminator
248
  periods=[2, 3, 5, 7, 11],
@@ -257,7 +256,7 @@ class BosonTrainer:
257
  return discriminator
258
 
259
  def setup_losses(self):
260
- """Setup all loss functions"""
261
  # Basic losses
262
  self.l1_loss = L1Loss()
263
  self.stft_loss = MultiScaleSTFTLoss(
@@ -277,16 +276,15 @@ class BosonTrainer:
277
  log_weight=1.0,
278
  )
279
 
280
- # GAN loss if using discriminator
281
  if self.discriminator is not None:
282
  self.gan_loss = GANLoss(self.discriminator)
283
 
284
- # Loss weights (matching DAC's proven configuration)
285
  self.loss_weights = {
286
  'rec': 1., # Waveform L1 loss
287
  'stft': 1., # Multi-scale STFT loss
288
- 'mel': 45.0, # Mel-spectrogram loss (DISABLED)
289
- #'mel': 0.0, # Mel-spectrogram loss (DISABLED)
290
  'commit': 0.25, # Commitment loss
291
  'semantic': 1., # Semantic loss
292
  'gen': 1., # Generator adversarial loss
@@ -294,7 +292,7 @@ class BosonTrainer:
294
  }
295
 
296
  def setup_data_loaders(self):
297
- """Setup data loaders (distributed or single GPU)"""
298
  # Split data into train/val
299
  df = pd.read_csv(self.args.data_csv)
300
  n_total = len(df)
@@ -308,7 +306,7 @@ class BosonTrainer:
308
  df[:n_train].to_csv(train_csv, index=False)
309
  df[n_train:].to_csv(val_csv, index=False)
310
 
311
- # Synchronize across processes if distributed
312
  if self.distributed:
313
  dist.barrier()
314
 
@@ -392,11 +390,11 @@ class BosonTrainer:
392
  output, commit_loss, semantic_loss, _ = self.model(audio, bw)
393
  recons_signal = AudioSignal(output, self.config['sample_rate'])
394
 
395
- # Check if discriminator should be active (after discriminator_start_step)
396
  use_discriminator = (self.discriminator is not None and
397
  self.global_step >= self.args.discriminator_start_step)
398
 
399
- # Train discriminator first if using GAN and past the start step
400
  if use_discriminator and self.global_step % self.args.disc_interval == 0:
401
  self.optimizer_d.zero_grad()
402
 
@@ -426,7 +424,7 @@ class BosonTrainer:
426
  losses['rec'] = self.l1_loss(recons_signal, audio_signal)
427
  losses['stft'] = self.stft_loss(recons_signal, audio_signal)
428
  losses['mel'] = self.mel_loss(recons_signal, audio_signal)
429
- # losses['mel'] = torch.tensor(0.0, device=self.device) # 15.
430
  losses['commit'] = commit_loss
431
  losses['semantic'] = semantic_loss
432
 
@@ -513,7 +511,7 @@ class BosonTrainer:
513
  'commit': 0, 'semantic': 0
514
  }
515
 
516
- # Store audio samples for tensorboard
517
  audio_samples = {'train': [], 'val': []}
518
 
519
  for batch_idx, (audio, paths) in enumerate(tqdm(self.val_loader, desc='Validation', disable=not self.is_main_process())):
@@ -712,7 +710,7 @@ class BosonTrainer:
712
 
713
 
714
  def load_checkpoint(self):
715
- """Load checkpoint with proper state restoration"""
716
  checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth')
717
  if os.path.exists(checkpoint_path):
718
  print(f"Loading checkpoint from {checkpoint_path}")
@@ -734,7 +732,7 @@ class BosonTrainer:
734
  if 'scheduler_g_last_epoch' in checkpoint:
735
  self.scheduler_g.last_epoch = checkpoint['scheduler_g_last_epoch']
736
  else:
737
- # Fallback: use global_step if the explicit value wasn't saved
738
  self.scheduler_g.last_epoch = checkpoint['global_step']
739
 
740
  # Force scheduler to recompute its internal state
@@ -761,7 +759,7 @@ class BosonTrainer:
761
 
762
  self.scheduler_d._last_lr = self.scheduler_d.get_lr()
763
 
764
- # Load discriminator gradient scaler state if using mixed precision
765
  if self.scaler_d is not None and 'scaler_d_state_dict' in checkpoint:
766
  self.scaler_d.load_state_dict(checkpoint['scaler_d_state_dict'])
767
 
@@ -798,7 +796,7 @@ class BosonTrainer:
798
  print(f"Next step checkpoint at: step {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}")
799
  print(f"{'='*60}\n")
800
 
801
- # Double-check by creating a fresh scheduler and comparing
802
  if self.global_step > 0:
803
  temp_scheduler = CosineWarmupScheduler(
804
  self.optimizer_g,
 
81
  audio_path = self.audio_paths[idx]
82
 
83
  try:
84
+
85
  audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
86
 
87
+ =
88
  if len(audio) > self.segment_length:
89
  if self.is_train:
90
  start = random.randint(0, len(audio) - self.segment_length)
91
  else:
92
+ start = 0 =
93
  audio = audio[start:start + self.segment_length]
94
  else:
95
  # Pad if too short
96
  audio = np.pad(audio, (0, self.segment_length - len(audio)))
97
+
 
98
  audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
99
 
100
  return audio_tensor, audio_path
 
241
  # return discriminator
242
 
243
  def build_discriminator(self):
244
+
245
  discriminator = Discriminator(
246
  rates=[], # No multi-rate discriminator
247
  periods=[2, 3, 5, 7, 11],
 
256
  return discriminator
257
 
258
  def setup_losses(self):
259
+
260
  # Basic losses
261
  self.l1_loss = L1Loss()
262
  self.stft_loss = MultiScaleSTFTLoss(
 
276
  log_weight=1.0,
277
  )
278
 
279
+
280
  if self.discriminator is not None:
281
  self.gan_loss = GANLoss(self.discriminator)
282
 
283
+
284
  self.loss_weights = {
285
  'rec': 1., # Waveform L1 loss
286
  'stft': 1., # Multi-scale STFT loss
287
+ 'mel': 45.0, # Mel-spectrogram loss
 
288
  'commit': 0.25, # Commitment loss
289
  'semantic': 1., # Semantic loss
290
  'gen': 1., # Generator adversarial loss
 
292
  }
293
 
294
  def setup_data_loaders(self):
295
+
296
  # Split data into train/val
297
  df = pd.read_csv(self.args.data_csv)
298
  n_total = len(df)
 
306
  df[:n_train].to_csv(train_csv, index=False)
307
  df[n_train:].to_csv(val_csv, index=False)
308
 
309
+
310
  if self.distributed:
311
  dist.barrier()
312
 
 
390
  output, commit_loss, semantic_loss, _ = self.model(audio, bw)
391
  recons_signal = AudioSignal(output, self.config['sample_rate'])
392
 
393
+
394
  use_discriminator = (self.discriminator is not None and
395
  self.global_step >= self.args.discriminator_start_step)
396
 
397
+
398
  if use_discriminator and self.global_step % self.args.disc_interval == 0:
399
  self.optimizer_d.zero_grad()
400
 
 
424
  losses['rec'] = self.l1_loss(recons_signal, audio_signal)
425
  losses['stft'] = self.stft_loss(recons_signal, audio_signal)
426
  losses['mel'] = self.mel_loss(recons_signal, audio_signal)
427
+ # losses['mel'] = torch.tensor(0.0, device=self.device) # uncomment this for the first 30k steps, it's faster if you pretrain it on semantic / commit loss first
428
  losses['commit'] = commit_loss
429
  losses['semantic'] = semantic_loss
430
 
 
511
  'commit': 0, 'semantic': 0
512
  }
513
 
514
+
515
  audio_samples = {'train': [], 'val': []}
516
 
517
  for batch_idx, (audio, paths) in enumerate(tqdm(self.val_loader, desc='Validation', disable=not self.is_main_process())):
 
710
 
711
 
712
  def load_checkpoint(self):
713
+
714
  checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth')
715
  if os.path.exists(checkpoint_path):
716
  print(f"Loading checkpoint from {checkpoint_path}")
 
732
  if 'scheduler_g_last_epoch' in checkpoint:
733
  self.scheduler_g.last_epoch = checkpoint['scheduler_g_last_epoch']
734
  else:
735
+
736
  self.scheduler_g.last_epoch = checkpoint['global_step']
737
 
738
  # Force scheduler to recompute its internal state
 
759
 
760
  self.scheduler_d._last_lr = self.scheduler_d.get_lr()
761
 
762
+
763
  if self.scaler_d is not None and 'scaler_d_state_dict' in checkpoint:
764
  self.scaler_d.load_state_dict(checkpoint['scaler_d_state_dict'])
765
 
 
796
  print(f"Next step checkpoint at: step {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}")
797
  print(f"{'='*60}\n")
798
 
799
+ g
800
  if self.global_step > 0:
801
  temp_scheduler = CosineWarmupScheduler(
802
  self.optimizer_g,