diff --git "a/codec_inference.ipynb" "b/codec_inference.ipynb" new file mode 100644--- /dev/null +++ "b/codec_inference.ipynb" @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/ubuntu/higgs_audio_train\n" + ] + } + ], + "source": [ + "%cd /home/ubuntu/higgs_audio_train\n", + "\n", + "import librosa\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "import json\n", + "import torch\n", + "from IPython.display import Audio as Sawt\n", + "from higgs_audio_tokenizer import HiggsAudioTokenizer\n", + "import torch\n", + "import torch.nn as nn\n", + "import warnings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/ubuntu/higgs_audio_train\n", + "Loading config...\n", + "Creating model...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading checkpoint...\n" + ] + } + ], + "source": [ + "%cd /home/ubuntu/higgs_audio_train\n", + "\n", + "import librosa\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "import json\n", + "import torch\n", + "from IPython.display import Audio as Sawt\n", + "from higgs_audio_tokenizer import HiggsAudioTokenizer\n", + "import torch\n", + "import torch.nn as nn\n", + "import warnings\n", + "\n", + "\n", + "class EncodedResult:\n", + " def __init__(self, audio_codes, quantized):\n", + " self.audio_codes = audio_codes\n", + " self.quantized = quantized\n", + "\n", + "\n", + "def encode_batch(model, x_batch):\n", + " \"\"\"\n", + " Encodes a batch of audio tensors using the HiggsAudioTokenizer model.\n", + " Args:\n", + " model: The loaded HiggsAudioTokenizer model.\n", + " x_batch: A tensor of shape [B, 1, T]\n", + " \"\"\"\n", + " # Acoustic and Semantic Feature Extraction\n", + " e_semantic_input = model.get_regress_target(x_batch).detach()\n", + " e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))\n", + " e_acoustic = model.encoder(x_batch)\n", + "\n", + " # This block contains the fix for batch processing\n", + " if e_acoustic.shape[2] != e_semantic.shape[2]:\n", + " pad_size = 160 * model.semantic_downsample_factor\n", + " \n", + " # 1. Remove channel dim, preserving batch dim -> [B, T]\n", + " x_slice = x_batch[:, 0, :]\n", + " \n", + " # 2. Pad the tensor\n", + " x_padded = F.pad(x_slice, (pad_size, pad_size))\n", + " \n", + " # 3. 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ---------------------------------------------------------------------------------------------------\n",
+    "\n",
+    "path = \"shiki_test.wav\"\n",
+    "# path = \"/home/ubuntu/qatilu.wav\"\n",
+    "wav, sr = librosa.load(path, sr=44100)\n",
+    "\n",
+    "wav = torch.from_numpy(wav).unsqueeze(0).float().to(device)  # [1, T]\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    encoded = encode_batch(model, wav.unsqueeze(0))  # input shape [1, 1, T]\n",
+    "    # Bring the reconstruction back as a 1-D CPU array so IPython can play it\n",
+    "    recon = model.decode(encoded.audio_codes).squeeze().cpu().numpy()\n",
+    "\n",
+    "display(Sawt(recon, rate=sr))\n",
+    "display(Sawt(path))"
+   ]
+  },
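+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Added sanity check, not in the original notebook: estimate the effective frame rate of the codes from the clip just encoded. For this 25 cps checkpoint the value should land near 25 frames per second.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Added sketch: codes-per-second implied by the encode above\n",
+    "num_frames = encoded.audio_codes.shape[-1]\n",
+    "duration_s = wav.shape[-1] / sr\n",
+    "print(f'~{num_frames / duration_s:.1f} frames/s ({num_frames} frames over {duration_s:.2f} s)')"
+   ]
+  }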
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "respair",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}