{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33e4a305",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCTC, Wav2Vec2Processor\n",
    "from datasets import load_dataset, load_metric, Audio\n",
    "from pyctcdecode import build_ctcdecoder\n",
    "from pydub import AudioSegment\n",
    "from pydub.playback import play\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import kenlm\n",
    "import pandas as pd\n",
    "import random\n",
    "import soundfile as sf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "328d0662",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCTC.from_pretrained(\".\")\n",
    "processor = Wav2Vec2Processor.from_pretrained(\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "0fea2518",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = AutoModelForCTC.from_pretrained(\"vitouphy/xls-r-300m-km\").to('cuda')\n",
    "# processor = Wav2Vec2Processor.from_pretrained(\"vitouphy/xls-r-300m-km\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9cfef23c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-36119ec2a15afb82\n",
      "Reusing dataset csv (/workspace/.cache/huggingface/datasets/csv/default-36119ec2a15afb82/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e)\n"
     ]
    }
   ],
   "source": [
    "common_voice_test  = (load_dataset('csv', data_files='km_kh_male/line_index_test.csv', split = 'train')\n",
    "                      .remove_columns([\"Unnamed: 0\", \"drop\"])\n",
    "                      .rename_column('text', 'sentence')\n",
    "                      .cast_column(\"path\", Audio(sampling_rate=16_000)).rename_column('path', 'audio'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "29e6bb1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'path': '/workspace/xls-r-300m-km/km_kh_male/wavs/khm_3154_2555595821.wav',\n",
       "  'array': array([ 0.00014737,  0.00016698,  0.00013704, ..., -0.00011244,\n",
       "         -0.0001059 , -0.00011476], dtype=float32),\n",
       "  'sampling_rate': 16000},\n",
       " 'sentence': 'ការ ធ្វើ អាជីវកម្ម រ៉ែ ដំបូង នៅ កម្ពុជា'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "common_voice_test[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0554b8d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    audio = batch[\"audio\"]\n",
    "    \n",
    "    # batched output is \"un-batched\"\n",
    "    batch[\"input_values\"] = processor(np.array(audio[\"array\"]), sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n",
    "    batch[\"input_length\"] = len(batch[\"input_values\"])\n",
    "    \n",
    "    with processor.as_target_processor():\n",
    "        batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d26a6659",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/csv/default-36119ec2a15afb82/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e/cache-081703c0621182da.arrow\n"
     ]
    }
   ],
   "source": [
    "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "04a94f74",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3993d2c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
     ]
    }
   ],
   "source": [
    "input_dict = processor(common_voice_test[i][\"input_values\"], return_tensors=\"pt\", padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7e3026dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_values': tensor([[ 2.8537e-04,  2.5043e-04,  2.7738e-04,  ..., -4.8949e-05,\n",
       "         -1.1382e-04,  2.7166e-04]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32)}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "adf215c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
     ]
    }
   ],
   "source": [
    "input_dict = processor(common_voice_test[i][\"input_values\"], return_tensors=\"pt\", padding=True)\n",
    "logits = model(input_dict.input_values.to(\"cuda\")).logits\n",
    "pred_ids = torch.argmax(logits, dim=-1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e8310629",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,\n",
       "        72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 10, 70, 70, 70, 10, 72,\n",
       "        43, 72, 72, 72, 72, 72, 72,  0,  0, 72, 72, 18, 72, 54, 72, 72, 72, 72,\n",
       "        72,  0, 72, 21, 72, 49, 72, 72, 72, 72, 72, 72, 23, 70, 70, 27, 72, 46,\n",
       "        72, 72, 72,  1, 72,  0,  0, 30, 72, 72, 72, 72, 25, 70, 70, 72, 72, 11,\n",
       "        55, 72, 72, 72, 72,  5, 72,  0, 20, 58, 72, 72, 72,  0,  0, 16, 72, 72,\n",
       "        72, 20, 70, 70, 72, 72, 16, 70, 27, 72, 72, 72, 72, 72, 45,  0,  0, 30,\n",
       "        30, 70, 70, 27, 72, 43, 72, 72, 72, 72, 72, 72, 21, 72, 53, 72, 72, 72,\n",
       "        27, 72,  0,  1, 72, 72, 72, 72, 25, 70, 23, 23, 48, 72, 72, 72, 72, 72,\n",
       "        72,  8, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,\n",
       "        72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,\n",
       "        72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72,\n",
       "        72, 72, 72, 72, 72, 72, 72, 72, 43], device='cuda:0')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5dd986a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction:\n",
      "កញ្ញា ទេ បូព្រឹក សម្ដែង នៅ តន្ត្រី ស្រាបៀរ កម្ពុជា\n",
      "\n",
      "Reference:\n",
      "កញ្ញា ទេព បូព្រឹក្ស សម្ដែង នៅ តន្ត្រី ស្រាបៀរ កម្ពុជា\n"
     ]
    }
   ],
   "source": [
    "print(\"Prediction:\")\n",
    "pred_ids = pred_ids[pred_ids != processor.tokenizer.pad_token_id]\n",
    "print(processor.decode(pred_ids))\n",
    "\n",
    "print(\"\\nReference:\")\n",
    "print(processor.decode(common_voice_test['labels'][i]))\n",
    "# print(common_voice_test_transcription[0][\"sentence\"].lower())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e39b112",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "562af933",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}