import os
import sys
import json
import warnings
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import remove_weight_norm
from tqdm import tqdm

from higgs_audio_tokenizer import HiggsAudioTokenizer

warnings.filterwarnings('ignore')


def remove_weight_norms_from_model(model):
    """Strip weight normalization from every module that has it (inference only)."""
    for module in model.modules():
        try:
            remove_weight_norm(module)
        except Exception:
            # Module has no weight norm applied; skip it.
            continue
    return model


class EncodedResult:
    """Lightweight container for the quantizer output."""

    def __init__(self, audio_codes):
        self.audio_codes = audio_codes


def encode_batch(model, x_batch):
    """Encode a padded batch of waveforms (B, 1, T) into discrete codes (B, n_q, T_codes)."""
    # Semantic branch: regression target from the semantic teacher, then the semantic encoder.
    e_semantic_input = model.get_regress_target(x_batch).detach()
    e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))
    # Acoustic branch.
    e_acoustic = model.encoder(x_batch)
    # If the two branches disagree on frame count, re-encode the acoustic branch
    # with symmetric padding and trim both to the shorter length.
    if e_acoustic.shape[2] != e_semantic.shape[2]:
        pad_size = 160 * model.semantic_downsample_factor
        x_slice = x_batch[:, 0, :]
        x_padded = F.pad(x_slice, (pad_size, pad_size))
        e_acoustic = model.encoder(x_padded.unsqueeze(1))
    min_len = min(e_acoustic.shape[2], e_semantic.shape[2])
    e_acoustic = e_acoustic[:, :, :min_len]
    e_semantic = e_semantic[:, :, :min_len]
    # Fuse the two branches and project before quantization.
    e = torch.cat([e_acoustic, e_semantic], dim=1)
    e = model.fc_prior(e.transpose(1, 2))
    if model.quantizer_type == "RVQ":
        e = e.transpose(1, 2)
        _, codes, _, _ = model.quantizer(e, model.frame_rate, None)
        codes = codes.permute(1, 0, 2)
    else:
        _, codes = model.quantizer(e)
        codes = codes.permute(0, 2, 1)
    return EncodedResult(audio_codes=codes)


def fix_all_inference_issues(model):
    """Force the whole model (including the semantic teacher) into inference mode."""
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        # Make sure every submodule is in eval mode.
        for module in model.modules():
            if isinstance(module, nn.Module):
                module.eval()
                if hasattr(module, 'training'):
                    module.training = False

        if hasattr(model, 'semantic_model'):
            print("Fixing semantic model...")
            model.semantic_model = model.semantic_model.to(device)
            model.semantic_model.eval()

            def disable_gradient_checkpointing(module):
                if hasattr(module, 'gradient_checkpointing'):
                    module.gradient_checkpointing = False
                if hasattr(module, 'gradient_checkpointing_disable'):
                    try:
                        module.gradient_checkpointing_disable()
                    except Exception:
                        pass
                for child in module.children():
                    disable_gradient_checkpointing(child)

            disable_gradient_checkpointing(model.semantic_model)
            if hasattr(model.semantic_model, 'encoder'):
                model.semantic_model.encoder.gradient_checkpointing = False
                if hasattr(model.semantic_model.encoder, 'layers'):
                    for layer in model.semantic_model.encoder.layers:
                        if hasattr(layer, 'gradient_checkpointing'):
                            layer.gradient_checkpointing = False

        def set_dropout_eval(module):
            if isinstance(module, nn.Dropout):
                module.eval()
                module.training = False
            for child in module.children():
                set_dropout_eval(child)

        set_dropout_eval(model)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model


def inference_pipeline(checkpoint_path, config_path, device='cuda'):
    """Build the tokenizer from a config file and load weights from a training checkpoint."""
    print("Loading config...")
    with open(config_path, 'r') as f:
        config = json.load(f)

    print("Creating model...")
    model = HiggsAudioTokenizer(
        n_filters=config['n_filters'],
        D=config['D'],
        target_bandwidths=config['target_bandwidths'],
        ratios=config['ratios'],
        sample_rate=config['sample_rate'],
        bins=config['bins'],
        n_q=config['n_q'],
        codebook_dim=config.get('codebook_dim', None),
        semantic_techer=config['semantic_techer'],
        device=device,
    ).to(device)

    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
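    # Checkpoints from training may nest the weights under 'model_state_dict', and
    # DataParallel/DistributedDataParallel runs prefix every key with 'module.'.
    # Unwrap and strip so the keys match the bare model before loading.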
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[len('module.'):]] = v
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict, strict=False)

    print("Fixing inference issues...")
    model = fix_all_inference_issues(model)
    return model


# --- Configuration ---
OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz"
BATCH_SIZE = 32
SAMPLE_RATE = 44100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data"

print(f"Using device: {DEVICE}")
os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing")

from datasets import load_from_disk

# --- Load the dataset and drop columns that are not needed for encoding ---
print(f"Loading dataset from: {DATASET_PATH}")
ds = load_from_disk(DATASET_PATH)
print(f"Dataset info: {ds}")

columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask']
existing_columns = [col for col in columns_to_remove if col in ds.column_names]
if existing_columns:
    ds = ds.remove_columns(existing_columns)

df = ds.to_pandas()
print(f"Loaded {len(df)} files from dataset")

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory '{OUTPUT_DIR}' is ready.")

# --- Skip files that already have encoded outputs ---
print("Checking for already processed files...")


def get_output_path(audio_path):
    base_name = Path(audio_path).stem
    return os.path.join(OUTPUT_DIR, f"{base_name}.pt")


original_count = len(df)
df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x)))
df_filtered = df[~df['output_exists']].copy()
skipped_count = original_count - len(df_filtered)
print(f"Found {skipped_count} already processed files. Skipping them.")
print(f"Processing {len(df_filtered)} remaining files.")

if len(df_filtered) == 0:
    print("All files have already been processed!")
    sys.exit(0)

# --- Load the tokenizer model ---
print("Loading Higgs Audio Tokenizer model...")
checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth'
config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json'

model = inference_pipeline(checkpoint_path, config_path, DEVICE)
model.eval()
model = remove_weight_norms_from_model(model)
print(f"Model loaded on {DEVICE}")

hop_length = model.hop_length
print(f"Encoder hop length: {hop_length}")

print(f"\nStarting batch processing with batch size {BATCH_SIZE}...")

filenames = df_filtered['filename'].tolist()
total_processed = 0
total_errors = 0

with torch.no_grad():
    for batch_start in tqdm(range(0, len(filenames), BATCH_SIZE), desc="Processing batches"):
        batch_end = min(batch_start + BATCH_SIZE, len(filenames))
        batch_filenames = filenames[batch_start:batch_end]

        # Load the audio for this batch, remembering each clip's original length.
        batch_audio = []
        batch_lengths = []
        batch_outputs = []
        for filename in batch_filenames:
            output_path = get_output_path(filename)
            if os.path.exists(output_path):
                continue
            try:
                wav, _ = librosa.load(filename, sr=SAMPLE_RATE)
                wav_tensor = torch.from_numpy(wav).float()
                batch_audio.append(wav_tensor)
                batch_lengths.append(len(wav))
                batch_outputs.append(output_path)
            except Exception as e:
                print(f"\nError loading {filename}: {e}")
                total_errors += 1
                continue

        if not batch_audio:
            continue

        # Right-pad every clip to the longest clip in the batch.
        max_len = max(len(x) for x in batch_audio)
        padded_batch = []
        for audio in batch_audio:
            pad_len = max_len - len(audio)
            if pad_len > 0:
                audio = F.pad(audio, (0, pad_len), mode='constant', value=0)
            padded_batch.append(audio)
        batch_tensor = torch.stack(padded_batch, dim=0)
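        # (B, T) -> (B, 1, T): the tokenizer's encoder expects a channel dimension.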
        batch_tensor = batch_tensor.unsqueeze(1)
        batch_tensor = batch_tensor.to(DEVICE)

        try:
            encoded = encode_batch(model, batch_tensor)
            codes = encoded.audio_codes
            for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)):
                # Trim each item's codes back to its true (unpadded) length.
                true_code_len = int(np.ceil(orig_len / hop_length))
                item_codes = codes[idx, :, :true_code_len].cpu()
                torch.save(item_codes, output_path)
                total_processed += 1
        except Exception as e:
            print(f"\nError encoding batch: {e}")
            total_errors += len(batch_outputs)

print("\n" + "=" * 50)
print("PROCESSING COMPLETE!")
print("=" * 50)
print(f"Successfully processed: {total_processed} files")
print(f"Previously processed: {skipped_count} files")
print(f"Errors encountered: {total_errors} files")
print(f"Output directory: {OUTPUT_DIR}")

final_count = len(list(Path(OUTPUT_DIR).glob("*.pt")))
print(f"Total .pt files in output: {final_count}")
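# Minimal sketch of reading an encoded output back for downstream use (the
# filename below is hypothetical). Each .pt file holds one tensor of shape
# (n_q, T_codes), as saved in the loop above:
#
#     codes = torch.load(os.path.join(OUTPUT_DIR, "example_utterance.pt"))
#     print(codes.shape)  # e.g. torch.Size([n_q, T_codes])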