import os
import sys
import json
import warnings
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import remove_weight_norm
from tqdm import tqdm

from higgs_audio_tokenizer import HiggsAudioTokenizer

warnings.filterwarnings('ignore')

def remove_weight_norms_from_model(model):
    # Strip the weight_norm reparametrization from every module that has one,
    # folding it back into a plain weight tensor for inference.
    for module in model.modules():
        try:
            remove_weight_norm(module)
        except Exception:
            # Module has no weight norm attached; skip it.
            continue
    return model

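# Usage sketch (assumes the tokenizer has already been built, e.g. via inference_pipeline below):
#   model = remove_weight_norms_from_model(model)
# After removal, weight-normed conv layers expose a single fused `weight` tensor instead of the
# (weight_g, weight_v) pair, avoiding the reparametrization overhead at inference time.
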
class EncodedResult:
    """Lightweight container holding the quantizer codes returned by encode_batch."""
    def __init__(self, audio_codes):
        self.audio_codes = audio_codes

def encode_batch(model, x_batch):
    # Semantic branch: features from the frozen semantic teacher, passed through the semantic encoder.
    e_semantic_input = model.get_regress_target(x_batch).detach()
    e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))
    # Acoustic branch: encode the raw waveform.
    e_acoustic = model.encoder(x_batch)
    if e_acoustic.shape[2] != e_semantic.shape[2]:
        # Frame counts can disagree slightly; re-encode with symmetric padding.
        pad_size = 160 * model.semantic_downsample_factor
        x_slice = x_batch[:, 0, :]
        x_padded = F.pad(x_slice, (pad_size, pad_size))
        e_acoustic = model.encoder(x_padded.unsqueeze(1))
    # Truncate both branches to a common length before fusing.
    min_len = min(e_acoustic.shape[2], e_semantic.shape[2])
    e_acoustic = e_acoustic[:, :, :min_len]
    e_semantic = e_semantic[:, :, :min_len]
    e = torch.cat([e_acoustic, e_semantic], dim=1)
    e = model.fc_prior(e.transpose(1, 2))
    if model.quantizer_type == "RVQ":
        e = e.transpose(1, 2)
        _, codes, _, _ = model.quantizer(e, model.frame_rate, None)
        codes = codes.permute(1, 0, 2)
    else:
        quantized, codes = model.quantizer(e)
        codes = codes.permute(0, 2, 1)
    # Codes are returned as (batch, n_q, frames).
    return EncodedResult(audio_codes=codes)

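# Usage sketch for a single clip (illustrative; 'sample.wav' is a placeholder path, and the
# tensor layout is the (batch, channels, samples) shape encode_batch expects):
#   wav, _ = librosa.load('sample.wav', sr=44100)
#   x = torch.from_numpy(wav).float().unsqueeze(0).unsqueeze(0).to('cuda')
#   codes = encode_batch(model, x).audio_codes  # shape: (1, n_q, n_frames)
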
def fix_all_inference_issues(model):
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        # Force eval mode on every submodule, not just the top-level model.
        for module in model.modules():
            if isinstance(module, nn.Module):
                module.eval()
                if hasattr(module, 'training'):
                    module.training = False
        if hasattr(model, 'semantic_model'):
            print("Fixing semantic model...")
            model.semantic_model = model.semantic_model.to(device)
            model.semantic_model.eval()

            def disable_gradient_checkpointing(module):
                if hasattr(module, 'gradient_checkpointing'):
                    module.gradient_checkpointing = False
                if hasattr(module, 'gradient_checkpointing_disable'):
                    try:
                        module.gradient_checkpointing_disable()
                    except Exception:
                        pass
                for child in module.children():
                    disable_gradient_checkpointing(child)

            disable_gradient_checkpointing(model.semantic_model)
            if hasattr(model.semantic_model, 'encoder'):
                model.semantic_model.encoder.gradient_checkpointing = False
                if hasattr(model.semantic_model.encoder, 'layers'):
                    for layer in model.semantic_model.encoder.layers:
                        if hasattr(layer, 'gradient_checkpointing'):
                            layer.gradient_checkpointing = False

        def set_dropout_eval(module):
            if isinstance(module, nn.Dropout):
                module.eval()
                module.training = False
            for child in module.children():
                set_dropout_eval(child)

        set_dropout_eval(model)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return model

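# Note: gradient checkpointing in the semantic (HuBERT-style) encoder trades compute for memory
# during training by re-running forward passes in backward. Under torch.no_grad() there is no
# backward pass, so leaving it enabled only adds overhead and warnings, hence the blanket
# disabling above.
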
def inference_pipeline(checkpoint_path, config_path, device='cuda'):
    print("Loading config...")
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("Creating model...")
    model = HiggsAudioTokenizer(
        n_filters=config['n_filters'],
        D=config['D'],
        target_bandwidths=config['target_bandwidths'],
        ratios=config['ratios'],
        sample_rate=config['sample_rate'],
        bins=config['bins'],
        n_q=config['n_q'],
        codebook_dim=config.get('codebook_dim', None),
        semantic_techer=config['semantic_techer'],
        device=device,
    ).to(device)
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    # Strip the 'module.' prefix left by DataParallel/DDP checkpoints; strict=False tolerates
    # keys that are absent from the inference-time model.
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict, strict=False)
    print("Fixing inference issues...")
    model = fix_all_inference_issues(model)
    return model

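# Usage sketch (paths are placeholders for illustration only):
#   model = inference_pipeline('checkpoints/step_99000.pth', 'config.json', device='cuda')
#   model = remove_weight_norms_from_model(model)
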
| warnings.filterwarnings("ignore") | |
| OUTPUT_DIR = "/home/ubuntu/data_boson_44.1khz" | |
| BATCH_SIZE = 32 | |
| SAMPLE_RATE = 44100 | |
| DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| DATASET_PATH = "/home/ubuntu/ttsar/Layla/src_bpe_2/Qanary_data" | |
| print(f"Using device: {DEVICE}") | |
| os.chdir("/home/ubuntu/ttsar/boson_audio_codec/audio_processing") | |
| from datasets import load_from_disk | |
| print(f"Loading dataset from: {DATASET_PATH}") | |
| ds = load_from_disk(DATASET_PATH) | |
| print(f"Dataset info: {ds}") | |
| columns_to_remove = ['spk', 'duration', 'codes', 'input_ids', 'attention_mask'] | |
| existing_columns = [col for col in columns_to_remove if col in ds.column_names] | |
| if existing_columns: | |
| ds = ds.remove_columns(existing_columns) | |
| df = ds.to_pandas() | |
| print(f"Loaded {len(df)} files from dataset") | |
| os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| print(f"Output directory '{OUTPUT_DIR}' is ready.") | |
| print("Checking for already processed files...") | |
| def get_output_path(audio_path): | |
| base_name = Path(audio_path).stem | |
| return os.path.join(OUTPUT_DIR, f"{base_name}.pt") | |
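# Illustrative mapping: '/some/dir/utt_0001.wav' -> f'{OUTPUT_DIR}/utt_0001.pt'. Only the file
# stem is kept, so filenames must be unique across the dataset's subdirectories or outputs
# will collide.
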
original_count = len(df)
df['output_exists'] = df['filename'].apply(lambda x: os.path.exists(get_output_path(x)))
df_filtered = df[~df['output_exists']].copy()
skipped_count = original_count - len(df_filtered)
print(f"Found {skipped_count} already processed files. Skipping them.")
print(f"Processing {len(df_filtered)} remaining files.")
if len(df_filtered) == 0:
    print("All files have already been processed!")
    sys.exit(0)
| print("Loading Higgs Audio Tokenizer model...") | |
| from transformers import HubertModel | |
| from higgs_audio_tokenizer import HiggsAudioTokenizer | |
| checkpoint_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/outputs_CQT/checkpoints/step_99000.pth' | |
| config_path = '/home/ubuntu/ttsar/boson_audio_codec/audio_processing/config copy.json' | |
| device = 'cuda' | |
| model = inference_pipeline(checkpoint_path, config_path, device) | |
| _ = model.eval() | |
| model = remove_weight_norms_from_model(model) | |
| print(f"Model loaded on {DEVICE}") | |
| hop_length = model.hop_length | |
| print(f"Encoder hop length: {hop_length}") | |
| print(f"\nStarting batch processing with batch size {BATCH_SIZE}...") | |
| filenames = df_filtered['filename'].tolist() | |
| total_processed = 0 | |
| total_errors = 0 | |
| with torch.no_grad(): | |
| for batch_start in tqdm(range(0, len(filenames), BATCH_SIZE), desc="Processing batches"): | |
| batch_end = min(batch_start + BATCH_SIZE, len(filenames)) | |
| batch_filenames = filenames[batch_start:batch_end] | |
| batch_audio = [] | |
| batch_lengths = [] | |
| batch_outputs = [] | |
| for filename in batch_filenames: | |
| output_path = get_output_path(filename) | |
| if os.path.exists(output_path): | |
| continue | |
| try: | |
| wav, _ = librosa.load(filename, sr=SAMPLE_RATE) | |
| wav_tensor = torch.from_numpy(wav).float() | |
| batch_audio.append(wav_tensor) | |
| batch_lengths.append(len(wav)) | |
| batch_outputs.append(output_path) | |
| except Exception as e: | |
| print(f"\nError loading {filename}: {e}") | |
| total_errors += 1 | |
| continue | |
| if not batch_audio: | |
| continue | |
| max_len = max(len(x) for x in batch_audio) | |
| padded_batch = [] | |
| for audio in batch_audio: | |
| pad_len = max_len - len(audio) | |
| if pad_len > 0: | |
| audio = F.pad(audio, (0, pad_len), mode='constant', value=0) | |
| padded_batch.append(audio) | |
| batch_tensor = torch.stack(padded_batch, dim=0) | |
| batch_tensor = batch_tensor.unsqueeze(1) | |
| batch_tensor = batch_tensor.to(DEVICE) | |
| try: | |
| encoded = encode_batch(model, batch_tensor) | |
| codes = encoded.audio_codes | |
| for idx, (output_path, orig_len) in enumerate(zip(batch_outputs, batch_lengths)): | |
| true_code_len = int(np.ceil(orig_len / hop_length)) | |
| item_codes = codes[idx, :, :true_code_len].cpu() | |
| torch.save(item_codes, output_path) | |
| total_processed += 1 | |
| except Exception as e: | |
| print(f"\nError encoding batch: {e}") | |
| total_errors += len(batch_outputs) | |
| print("\n" + "="*50) | |
| print("PROCESSING COMPLETE!") | |
| print("="*50) | |
| print(f"Successfully processed: {total_processed} files") | |
| print(f"Previously processed: {skipped_count} files") | |
| print(f"Errors encountered: {total_errors} files") | |
| print(f"Output directory: {OUTPUT_DIR}") | |
| final_count = len(list(Path(OUTPUT_DIR).glob("*.pt"))) | |
| print(f"Total .pt files in output: {final_count}") |