diff --git "a/codec_inference.ipynb" "b/codec_inference.ipynb"
new file mode 100644
--- /dev/null
+++ "b/codec_inference.ipynb"
@@ -0,0 +1,258 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/ubuntu/higgs_audio_train\n"
+ ]
+ }
+ ],
+ "source": [
+ "%cd /home/ubuntu/higgs_audio_train\n",
+ "\n",
+ "import librosa\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "import json\n",
+ "import torch\n",
+ "from IPython.display import Audio as Sawt\n",
+ "from higgs_audio_tokenizer import HiggsAudioTokenizer\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import warnings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/ubuntu/higgs_audio_train\n",
+ "Loading config...\n",
+ "Creating model...\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading checkpoint...\n"
+ ]
+ }
+ ],
+ "source": [
+ "%cd /home/ubuntu/higgs_audio_train\n",
+ "\n",
+ "import librosa\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "import json\n",
+ "import torch\n",
+ "from IPython.display import Audio as Sawt\n",
+ "from higgs_audio_tokenizer import HiggsAudioTokenizer\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import warnings\n",
+ "\n",
+ "\n",
+ "class EncodedResult:\n",
+ " def __init__(self, audio_codes, quantized):\n",
+ " self.audio_codes = audio_codes\n",
+ " self.quantized = quantized\n",
+ "\n",
+ "\n",
+ "def encode_batch(model, x_batch):\n",
+ " \"\"\"\n",
+ " Encodes a batch of audio tensors using the HiggsAudioTokenizer model.\n",
+ " Args:\n",
+ " model: The loaded HiggsAudioTokenizer model.\n",
+ " x_batch: A tensor of shape [B, 1, T]\n",
+ " \"\"\"\n",
+ " # Acoustic and Semantic Feature Extraction\n",
+ " e_semantic_input = model.get_regress_target(x_batch).detach()\n",
+ " e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))\n",
+ " e_acoustic = model.encoder(x_batch)\n",
+ "\n",
+ " # This block contains the fix for batch processing\n",
+ " if e_acoustic.shape[2] != e_semantic.shape[2]:\n",
+ " pad_size = 160 * model.semantic_downsample_factor\n",
+ " \n",
+ " # 1. Remove channel dim, preserving batch dim -> [B, T]\n",
+ " x_slice = x_batch[:, 0, :]\n",
+ " \n",
+ " # 2. Pad the tensor\n",
+ " x_padded = F.pad(x_slice, (pad_size, pad_size))\n",
+ " \n",
+ " # 3. Re-add channel dim before passing to encoder -> [B, 1, T_padded]\n",
+ " e_acoustic = model.encoder(x_padded.unsqueeze(1))\n",
+ "\n",
+ " # Ensure dimensions match before concatenating\n",
+ " min_len = min(e_acoustic.shape[2], e_semantic.shape[2])\n",
+ " e_acoustic = e_acoustic[:, :, :min_len]\n",
+ " e_semantic = e_semantic[:, :, :min_len]\n",
+ "\n",
+ " # Remainder of the original encoding logic\n",
+ " e = torch.cat([e_acoustic, e_semantic], dim=1)\n",
+ " e = model.fc_prior(e.transpose(1, 2))\n",
+ "\n",
+ " if model.quantizer_type == \"RVQ\":\n",
+ " e = e.transpose(1, 2)\n",
+ " quantized, codes, _, _ = model.quantizer(e, model.frame_rate, None)\n",
+ " codes = codes.permute(1, 0, 2)\n",
+ " else: # RFSQ\n",
+ " quantized, codes = model.quantizer(e)\n",
+ " codes = codes.permute(0, 2, 1)\n",
+ "\n",
+ " return EncodedResult(audio_codes=codes, quantized=quantized)\n",
+ "\n",
+ "def prepare(checkpoint_path, config_path, device='cuda'):\n",
+ "\n",
+ " # Load config\n",
+ " print(\"Loading config...\")\n",
+ " with open(config_path, 'r') as f:\n",
+ " config = json.load(f)\n",
+ " \n",
+ " # Create model\n",
+ " print(\"Creating model...\")\n",
+ " model = HiggsAudioTokenizer(\n",
+ " n_filters=config['n_filters'],\n",
+ " D=config['D'],\n",
+ " target_bandwidths=config['target_bandwidths'],\n",
+ " ratios=config['ratios'],\n",
+ " sample_rate=config['sample_rate'],\n",
+ " bins=config['bins'],\n",
+ " n_q=config['n_q'],\n",
+ " codebook_dim=config.get('codebook_dim', None),\n",
+ " semantic_techer=config['semantic_techer'],\n",
+ " device=device\n",
+ " ).to(device)\n",
+ " \n",
+ " # Load checkpoint\n",
+ " print(\"Loading checkpoint...\")\n",
+ " checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n",
+ " \n",
+ " if 'model_state_dict' in checkpoint:\n",
+ " state_dict = checkpoint['model_state_dict']\n",
+ " else:\n",
+ " state_dict = checkpoint\n",
+ " \n",
+ " # Remove 'module.' prefix if present (from DDP)\n",
+ " new_state_dict = {}\n",
+ " for k, v in state_dict.items():\n",
+ " if k.startswith('module.'):\n",
+ " new_state_dict[k[7:]] = v\n",
+ " else:\n",
+ " new_state_dict[k] = v\n",
+ " \n",
+ " model.load_state_dict(new_state_dict, strict=False)\n",
+ " \n",
+ "\n",
+ " \n",
+ " return model\n",
+ "\n",
+ "# Run the complete pipeline\n",
+ "checkpoint_path = '/home/ubuntu/higgs_audio_train/25hz_CQT_step_99000.pth' #NOTE: this is a 25cps test model trained during a single afternoon on a small dataset. in no way it is an indication of this architecture at its best.\n",
+ "config_path = '/home/ubuntu/higgs_audio_train/config_25.json'\n",
+ "\n",
+ "device = 'cuda'\n",
+ "model = prepare(checkpoint_path, config_path, device)\n",
+ "_ = model.eval()"
+ ]
+ },
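+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal sketch (not part of the original run): encoding several clips in one\n",
+ "# batch with encode_batch. Clips in a batch must share a length, so shorter ones\n",
+ "# are right-padded with zeros; the file names here are hypothetical placeholders.\n",
+ "paths = [\"clip_a.wav\", \"clip_b.wav\"]  # hypothetical files\n",
+ "clips = [torch.from_numpy(librosa.load(p, sr=44100)[0]).float() for p in paths]\n",
+ "\n",
+ "max_len = max(c.shape[0] for c in clips)\n",
+ "batch = torch.stack([F.pad(c, (0, max_len - c.shape[0])) for c in clips])  # [B, T]\n",
+ "batch = batch.unsqueeze(1).to(device)  # [B, 1, T]\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    encoded = encode_batch(model, batch)\n",
+ "\n",
+ "print(encoded.audio_codes.shape)  # expected [B, n_q, T_codes]"
+ ]
+ },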
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "# ---------------------------------------------------------------------------------------------------\n",
+ "\n",
+ "\n",
+ "path = \"shiki_test.wav\"\n",
+ "# path = \"/home/ubuntu/qatilu.wav\"\n",
+ "wav, sr = librosa.load(path, sr=44100)\n",
+ "\n",
+ "wav = torch.from_numpy(wav).unsqueeze(0).float().to('cuda')\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "\n",
+ " encoded = encode_batch(model, wav.unsqueeze(0)) \n",
+ " recon = model.decode(encoded.audio_codes).squeeze(0)\n",
+ " \n",
+ "display(Sawt(recon, rate=sr))\n",
+ "display(Sawt(path))\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "respair",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}