Respair committed on
Commit 10c6e0f · verified · 1 Parent(s): 0426fc7

Update boson_codeit.py

Files changed (1):
  1. boson_codeit.py +15 -361
boson_codeit.py CHANGED
@@ -1,260 +1,3 @@
- # #!/usr/bin/env python3
- # """
- # Audio Processing Script for Boson Codes
- # Processes audio files in parallel using Higgs Audio Tokenizer
- # and saves encoded representations as .pt files.
- # """
-
- # import os
- # import sys
- # import json
- # import torch
- # import librosa
- # import numpy as np
- # import warnings
- # import argparse
- # from pathlib import Path
- # from multiprocessing import Pool
- # from tqdm import tqdm
-
- # from datasets import load_from_disk
- # from higgs_audio_tokenizer import HiggsAudioTokenizer
-
- # # Suppress PyTorch FutureWarnings
- # warnings.filterwarnings("ignore", category=FutureWarning)
-
- # # Global configuration
- # DEFAULT_OUTPUT_DIR = "/home/ubuntu/boson_codes"
- # DEFAULT_NUM_CORES = 48
- # DEFAULT_SAMPLE_RATE = 44100
- # DEFAULT_DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/data"
-
- # # Model paths
- # CONFIG_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/config.json"
- # MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/model.pth"
-
- # # Global model variable (initialized in each worker)
- # model = None
-
-
- # def init_worker():
- #     """Initialize model once per worker process."""
- #     global model
- #     device = 'cpu'
-
- #     # Load config
- #     with open(CONFIG_PATH, 'r') as f:
- #         config = json.load(f)
-
- #     # Initialize model
- #     model = HiggsAudioTokenizer(
- #         **config,
- #         device=device,
- #     )
-
- #     # Load weights
- #     parameter_dict = torch.load(MODEL_PATH, map_location=device)
- #     _ = model.load_state_dict(parameter_dict, strict=False)
- #     model = model.to(device)
- #     _ = model.eval()
-
- #     print(f"Model loaded in worker {os.getpid()}")
-
-
- # def process_audio_file(args):
- #     """Process a single audio file using pre-loaded model."""
- #     filename, output_dir, sample_rate = args
-
- #     try:
- #         # Output filename - same name, just change extension to .pt
- #         base_name = Path(filename).stem
- #         output_path = os.path.join(output_dir, f"{base_name}.pt")
-
- #         # Skip if exists (double-check in case of race conditions)
- #         if os.path.exists(output_path):
- #             return ("skipped", filename)
-
- #         # Load and process audio
- #         wav, sr = librosa.load(filename, sr=sample_rate)
- #         wav = torch.from_numpy(wav).unsqueeze(0).float().to('cpu')
-
- #         # Encode using the pre-loaded model
- #         with torch.no_grad():
- #             encoded = model._xcodec_encode(wav.unsqueeze(0))
-
- #         # Save codes only
- #         torch.save(encoded.audio_codes, output_path)
-
- #         return ("success", filename)
-
- #     except Exception as e:
- #         return ("error", filename, str(e))
-
-
- # def load_dataset(dataset_path):
- #     """Load and prepare the dataset."""
- #     print(f"Loading dataset from: {dataset_path}")
- #     ds = load_from_disk(dataset_path)
- #     print(f"Dataset info: {ds}")
-
- #     # Remove unnecessary columns
- #     columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
- #     existing_columns = [col for col in columns_to_remove if col in ds.column_names]
- #     if existing_columns:
- #         ds = ds.remove_columns(existing_columns)
- #         print(f"Removed columns: {existing_columns}")
-
- #     # Convert to pandas DataFrame
- #     df = ds.to_pandas()
- #     print(f"Loaded {len(df)} files from dataset")
- #     return df
-
-
- # def main(args):
- #     """Main processing function."""
- #     # Change to audio processing directory
- #     os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")
- #     print(f"Working directory: {os.getcwd()}")
-
- #     # Create output directory
- #     os.makedirs(args.output_dir, exist_ok=True)
- #     print(f"Output directory: {args.output_dir}")
-
- #     # Check if model files exist
- #     if not os.path.exists(CONFIG_PATH):
- #         print(f"Error: Config file not found at {CONFIG_PATH}")
- #         sys.exit(1)
- #     if not os.path.exists(MODEL_PATH):
- #         print(f"Error: Model file not found at {MODEL_PATH}")
- #         sys.exit(1)
-
- #     # Load dataset
- #     df = load_dataset(args.dataset_path)
-
- #     # Get filenames from dataframe
- #     all_filenames = df['filename'].tolist()
-
- #     # Pre-filter to exclude already processed files
- #     filenames_to_process = []
- #     already_processed = []
-
- #     print(f"\nChecking for already processed files...")
- #     for filename in all_filenames:
- #         base_name = Path(filename).stem
- #         output_path = os.path.join(args.output_dir, f"{base_name}.pt")
- #         if os.path.exists(output_path):
- #             already_processed.append(filename)
- #         else:
- #             filenames_to_process.append(filename)
-
- #     print(f"\nTotal files: {len(all_filenames)}")
- #     print(f"Already processed: {len(already_processed)}")
- #     print(f"To process: {len(filenames_to_process)}")
-
- #     if len(filenames_to_process) == 0:
- #         print("\nAll files have already been processed!")
- #         return
-
- #     print(f"\nProcessing {len(filenames_to_process)} files using {args.num_cores} cores...")
- #     print(f"Sample rate: {args.sample_rate} Hz")
-
- #     # Prepare arguments for multiprocessing
- #     process_args = [(filename, args.output_dir, args.sample_rate)
- #                     for filename in filenames_to_process]
-
- #     # Process in parallel with model reuse
- #     with Pool(processes=args.num_cores, initializer=init_worker) as pool:
- #         results = list(tqdm(
- #             pool.imap(process_audio_file, process_args, chunksize=args.chunksize),
- #             total=len(filenames_to_process),
- #             desc="Processing audio files"
- #         ))
-
- #     # Count results
- #     processed = sum(1 for r in results if r[0] == "success")
- #     skipped = sum(1 for r in results if r[0] == "skipped")
- #     errors = sum(1 for r in results if r[0] == "error")
-
- #     print(f"\nProcessing complete!")
- #     print(f" Successfully processed: {processed}")
- #     print(f" Previously processed: {len(already_processed)}")
- #     print(f" Skipped (race condition): {skipped}")
- #     print(f" Errors: {errors}")
-
- #     # Show errors if any
- #     if errors > 0:
- #         print("\nErrors encountered:")
- #         error_log_path = os.path.join(args.output_dir, "processing_errors.log")
- #         with open(error_log_path, 'w') as f:
- #             for r in results:
- #                 if r[0] == "error":
- #                     error_msg = f"{r[1]}: {r[2]}"
- #                     print(f" {error_msg}")
- #                     f.write(error_msg + "\n")
- #         print(f"\nError log saved to: {error_log_path}")
-
- #     # Show summary of all processed files
- #     total_processed_files = len(list(Path(args.output_dir).glob("*.pt")))
- #     print(f"\nTotal .pt files in {args.output_dir}: {total_processed_files}")
-
-
- # if __name__ == "__main__":
- #     parser = argparse.ArgumentParser(
- #         description="Process audio files using Higgs Audio Tokenizer and save as .pt files"
- #     )
-
- #     parser.add_argument(
- #         "--dataset-path",
- #         type=str,
- #         default=DEFAULT_DATASET_PATH,
- #         help=f"Path to the dataset (default: {DEFAULT_DATASET_PATH})"
- #     )
-
- #     parser.add_argument(
- #         "--output-dir",
- #         type=str,
- #         default=DEFAULT_OUTPUT_DIR,
- #         help=f"Output directory for .pt files (default: {DEFAULT_OUTPUT_DIR})"
- #     )
-
- #     parser.add_argument(
- #         "--num-cores",
- #         type=int,
- #         default=DEFAULT_NUM_CORES,
- #         help=f"Number of CPU cores to use (default: {DEFAULT_NUM_CORES})"
- #     )
-
- #     parser.add_argument(
- #         "--sample-rate",
- #         type=int,
- #         default=DEFAULT_SAMPLE_RATE,
- #         help=f"Sample rate for audio processing (default: {DEFAULT_SAMPLE_RATE})"
- #     )
-
- #     parser.add_argument(
- #         "--chunksize",
- #         type=int,
- #         default=1,
- #         help="Chunksize for multiprocessing pool (default: 1)"
- #     )
-
- #     args = parser.parse_args()
-
- #     # Run main processing
- #     try:
- #         main(args)
- #     except KeyboardInterrupt:
- #         print("\n\nProcessing interrupted by user")
- #         sys.exit(1)
- #     except Exception as e:
- #         print(f"\n\nError: {e}")
- #         sys.exit(1)
-
- #!/usr/bin/env python3
- """
- GPU Batch Processing Script for Boson Codes with Dataset Loading
- """
-
  import os
  import sys
  import json
@@ -266,27 +9,20 @@ from pathlib import Path
  from tqdm import tqdm
  import warnings
  from torch.nn.utils import remove_weight_norm, weight_norm
-
-
- # from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
- # model = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer")
  import librosa
  import torch
  import torch.nn.functional as F
  import numpy as np
  import json
  import torch
-
  from higgs_audio_tokenizer import HiggsAudioTokenizer
- # model = load_higgs_audio_tokenizer("bosonai/higgs-audio-v2-tokenizer")
-
  import torch
  import torch.nn as nn
  import warnings

- # Suppress warnings
  warnings.filterwarnings('ignore')

+
  def remove_weight_norms_from_model(model):
      for module in model.modules():
          try:
@@ -300,58 +36,42 @@ class EncodedResult:
      def __init__(self, audio_codes):
          self.audio_codes = audio_codes

+
  def encode_batch(model, x_batch):
-     """
-     Encodes a batch of audio tensors using the HiggsAudioTokenizer model.
-     Args:
-         model: The loaded HiggsAudioTokenizer model.
-         x_batch: A tensor of shape [B, 1, T]
-     """
-     # Acoustic and Semantic Feature Extraction
      e_semantic_input = model.get_regress_target(x_batch).detach()
      e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))
      e_acoustic = model.encoder(x_batch)
-
-     # This block contains the fix for batch processing
+
      if e_acoustic.shape[2] != e_semantic.shape[2]:
          pad_size = 160 * model.semantic_downsample_factor

-         # 1. Remove channel dim, preserving batch dim -> [B, T]
          x_slice = x_batch[:, 0, :]

-         # 2. Pad the tensor
          x_padded = F.pad(x_slice, (pad_size, pad_size))

-         # 3. Re-add channel dim before passing to encoder -> [B, 1, T_padded]
          e_acoustic = model.encoder(x_padded.unsqueeze(1))
-
-     # Ensure dimensions match before concatenating
+
      min_len = min(e_acoustic.shape[2], e_semantic.shape[2])
      e_acoustic = e_acoustic[:, :, :min_len]
      e_semantic = e_semantic[:, :, :min_len]
-
-     # Remainder of the original encoding logic
+
      e = torch.cat([e_acoustic, e_semantic], dim=1)
      e = model.fc_prior(e.transpose(1, 2))
-
+
      if model.quantizer_type == "RVQ":
          e = e.transpose(1, 2)
          _, codes, _, _ = model.quantizer(e, model.frame_rate, None)
          codes = codes.permute(1, 0, 2)
-     else: # RFSQ
+     else:
          quantized, codes = model.quantizer(e)
          codes = codes.permute(0, 2, 1)
-
+
      return EncodedResult(audio_codes=codes)


  def fix_all_inference_issues(model):
-     """
-     Comprehensive fix for all potential inference issues
-     """
      device = next(model.parameters()).device

-     # 1. Force everything to eval mode
      model.eval()
      with torch.no_grad():
          for module in model.modules():
@@ -360,15 +80,12 @@ def fix_all_inference_issues(model):
              if hasattr(module, 'training'):
                  module.training = False

-     # 2. Fix semantic model specifically
      if hasattr(model, 'semantic_model'):
          print("Fixing semantic model...")

-         # Move to correct device
          model.semantic_model = model.semantic_model.to(device)
          model.semantic_model.eval()

-         # Disable ALL gradient checkpointing
          def disable_gradient_checkpointing(module):
              if hasattr(module, 'gradient_checkpointing'):
                  module.gradient_checkpointing = False
@@ -382,7 +99,6 @@ def fix_all_inference_issues(model):

          disable_gradient_checkpointing(model.semantic_model)

-         # For HuBERT specifically
          if hasattr(model.semantic_model, 'encoder'):
              model.semantic_model.encoder.gradient_checkpointing = False
              if hasattr(model.semantic_model.encoder, 'layers'):
@@ -390,7 +106,6 @@ def fix_all_inference_issues(model):
                  if hasattr(layer, 'gradient_checkpointing'):
                      layer.gradient_checkpointing = False

-     # 3. Set all dropout to eval mode
      def set_dropout_eval(module):
          if isinstance(module, nn.Dropout):
              module.eval()
@@ -400,21 +115,16 @@ def fix_all_inference_issues(model):

      set_dropout_eval(model)

-     # 4. Clear any cached computations
      torch.cuda.empty_cache() if torch.cuda.is_available() else None

      return model

+
  def inference_pipeline(checkpoint_path, config_path, device='cuda'):
-     """
-     Complete pipeline for inference with your trained model
-     """
-     # Load config
      print("Loading config...")
      with open(config_path, 'r') as f:
          config = json.load(f)

-     # Create model
      print("Creating model...")
      model = HiggsAudioTokenizer(
          n_filters=config['n_filters'],
@@ -429,7 +139,6 @@ def inference_pipeline(checkpoint_path, config_path, device='cuda'):
          device=device
      ).to(device)

-     # Load checkpoint
      print("Loading checkpoint...")
      checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

@@ -438,7 +147,6 @@ def inference_pipeline(checkpoint_path, config_path, device='cuda'):
      else:
          state_dict = checkpoint

-     # Remove 'module.' prefix if present (from DDP)
      new_state_dict = {}
      for k, v in state_dict.items():
          if k.startswith('module.'):
@@ -448,47 +156,30 @@ def inference_pipeline(checkpoint_path, config_path, device='cuda'):

      model.load_state_dict(new_state_dict, strict=False)

-     # Fix all inference issues
      print("Fixing inference issues...")
      model = fix_all_inference_issues(model)
-

      return model


-
- # # Add paths
- # sys.path.insert(0, "/home/ubuntu/AP-BWE")
-
- # Suppress warnings
  warnings.filterwarnings("ignore")

- # Configuration
  OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz"
  BATCH_SIZE = 32
  SAMPLE_RATE = 44100
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
  DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data"

- # # Model paths
- # CONFIG_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/config.json"
- # MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--bosonai--higgs-audio-v2-tokenizer/snapshots/9d4988fbd4ad07b4cac3a5fa462741a41810dbec/model.pth"
-
- # --- Setup ---
  print(f"Using device: {DEVICE}")

- # Change to working directory
  os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")

- # Load dataset
  from datasets import load_from_disk

-
  print(f"Loading dataset from: {DATASET_PATH}")
  ds = load_from_disk(DATASET_PATH)
  print(f"Dataset info: {ds}")

- # Remove unnecessary columns
  columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
  existing_columns = [col for col in columns_to_remove if col in ds.column_names]
  if existing_columns:
@@ -500,14 +191,14 @@ print(f"Loaded {len(df)} files from dataset")
  os.makedirs(OUTPUT_DIR, exist_ok=True)
  print(f"Output directory '{OUTPUT_DIR}' is ready.")

- # --- Filter already processed ---
  print("Checking for already processed files...")

+
  def get_output_path(audio_path):
      base_name = Path(audio_path).stem
      return os.path.join(OUTPUT_DIR, f"{base_name}.pt")

- # Filter
+
  original_count = len(df)
  df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x)))
  df_filtered = df[~df['output_exists']].copy()
@@ -520,47 +211,24 @@ if len(df_filtered) == 0:
      print("All files have already been processed!")
      exit()

- # --- Load Model ---
  print("Loading Higgs Audio Tokenizer model...")
-
  from transformers import HubertModel
  from higgs_audio_tokenizer import HiggsAudioTokenizer

- # Load config
- # with open(CONFIG_PATH, 'r') as f:
- #     config = json.load(f)
-
- # # Initialize model
- # model = HiggsAudioTokenizer(
- #     **config,
- #     device=DEVICE,
- # )
-
- # Load weights
- # parameter_dict = torch.load(MODEL_PATH, map_location=DEVICE)
- # _ = model.load_state_dict(parameter_dict, strict=False)
- # model = model.to(DEVICE)
- # _ = model.eval()
-
-
  checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth'
  config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json'
  device = 'cuda'
+
  model = inference_pipeline(checkpoint_path, config_path, device)
  _ = model.eval()
-
  model = remove_weight_norms_from_model(model)
-
  print(f"Model loaded on {DEVICE}")

- # Get hop length
  hop_length = model.hop_length
  print(f"Encoder hop length: {hop_length}")

- # --- Batch Processing ---
  print(f"\nStarting batch processing with batch size {BATCH_SIZE}...")

- # Process in batches
  filenames = df_filtered['filename'].tolist()
  total_processed = 0
  total_errors = 0
@@ -574,16 +242,13 @@ with torch.no_grad():
          batch_lengths = []
          batch_outputs = []

-         # Load batch
          for filename in batch_filenames:
              output_path = get_output_path(filename)

-             # Skip if exists (race condition check)
              if os.path.exists(output_path):
                  continue

              try:
-                 # Load audio
                  wav, _ = librosa.load(filename, sr=SAMPLE_RATE)
                  wav_tensor = torch.from_numpy(wav).float()

@@ -599,7 +264,6 @@ with torch.no_grad():
          if not batch_audio:
              continue

-         # Pad batch to same length
          max_len = max(len(x) for x in batch_audio)
          padded_batch = []

@@ -607,30 +271,21 @@ with torch.no_grad():
              pad_len = max_len - len(audio)
              if pad_len > 0:
                  audio = F.pad(audio, (0, pad_len), mode='constant', value=0)
-             # Don't add extra dimensions here, just collect the padded audio
              padded_batch.append(audio)

-         # Convert list to tensor and add channel dimension
-         # Stack along batch dimension to get [B, T]
-         batch_tensor = torch.stack(padded_batch, dim=0) # [B, T]
-         # Add channel dimension
-         batch_tensor = batch_tensor.unsqueeze(1) # [B, 1, T]
+         batch_tensor = torch.stack(padded_batch, dim=0)
+         batch_tensor = batch_tensor.unsqueeze(1)
          batch_tensor = batch_tensor.to(DEVICE)

-         # Encode batch
          try:
              encoded = encode_batch(model, batch_tensor)
-             codes = encoded.audio_codes # [B, n_codebooks, T_compressed]
+             codes = encoded.audio_codes

-             # Save each item
              for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)):
-                 # Calculate true code length
                  true_code_len = int(np.ceil(orig_len / hop_length))

-                 # Extract non-padded codes
                  item_codes = codes[idx, :, :true_code_len].cpu()

-                 # Save
                  torch.save(item_codes, output_path)
                  total_processed += 1

@@ -646,6 +301,5 @@ print(f"Previously processed: {skipped_count} files")
  print(f"Errors encountered: {total_errors} files")
  print(f"Output directory: {OUTPUT_DIR}")

- # Final count
  final_count = len(list(Path(OUTPUT_DIR).glob("*.pt")))
  print(f"Total .pt files in output: {final_count}")
 