diff --git "a/examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners.ipynb" "b/examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners.ipynb" new file mode 100644--- /dev/null +++ "b/examples/FinGPT_Training_LoRA_with_ChatGLM2_6B_for_Beginners.ipynb" @@ -0,0 +1,19783 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Getting Started with FinGPT\n", + "Welcome to this comprehensive guide aimed at beginners diving into the realm of Financial Large Language Models (FinLLMs) with FinGPT. This blog post demystifies the process of training FinGPT using Low-Rank Adaptation (LoRA) with the robust base model ChatGlm2-6b.\n", + "\n" + ], + "metadata": { + "id": "X8H-Vc6w6WSU" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Part 1: Preparing the Data\n", + "Data preparation is a crucial step when it comes to training Financial Large Language Models. Here, we’ll guide you on how to get your dataset ready for FinGPT using Python.\n", + "\n", + "In this section, you’ve initialized your working directory and loaded a financial sentiment dataset. Let’s break down the steps:\n", + "\n" + ], + "metadata": { + "id": "4oLjc7bbv0tO" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install datasets transformers torch tqdm pandas huggingface_hub\n", + "!pip install sentencepiece\n", + "!pip install protobuf transformers==4.30.2 cpm_kernels torch>=2.0 gradio mdtex2html sentencepiece accelerate\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-maUV8CH7JPB", + "outputId": "dc512a8f-b4e3-44cc-f489-b8f768d82f7e" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.5)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.30.2)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (1.5.3)\n", + "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.16.4)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n", + "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.3)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (17.0.2)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2023.3.post1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", + "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 1.1 Initialize Directories:\n", + "This block checks if certain paths exist; if they do, it deletes them to avoid data conflicts, and then creates a new directory for the upcoming data.\n", + "\n" + ], + "metadata": { + "id": "hJp_UOiB70o3" + } + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "aBE7gRUJ3L8u" + }, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "\n", + "jsonl_path = \"../data/dataset_new.jsonl\"\n", + "save_path = '../data/dataset_new'\n", + "\n", + "\n", + "if os.path.exists(jsonl_path):\n", + " os.remove(jsonl_path)\n", + "\n", + "if os.path.exists(save_path):\n", + " shutil.rmtree(save_path)\n", + "\n", + "directory = \"../data\"\n", + "if not os.path.exists(directory):\n", + " os.makedirs(directory)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OvIBHhS5pV8o" + }, + "source": [ + "### 1.2 Load and Prepare Dataset:\n", + "\n", + "* Import necessary libraries from the datasets package: https://huggingface.co/docs/datasets/index\n", + "* Load the Twitter Financial News Sentiment (TFNS) dataset and convert it to a Pandas dataframe. https://huggingface.co/datasets/zeroshot/twitter-financial-news-sentiment\n", + "* Map numerical labels to their corresponding sentiments (negative, positive, neutral).\n", + "* Add instruction for each data entry, which is crucial for Instruction Tuning.\n", + "* Convert the Pandas dataframe back to a Hugging Face Dataset object.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 310, + "referenced_widgets": [ + "6573cca7c7eb4d8cae686699a2239067", + "1f8cdca59218433b8816a1a92c08afd4", + "55f739dce229477886de18f0c642f222", + "17efe5d0bb5546dfbacb81e84f838d9b", + "7d52d2f9f0334ca1b8fcc1cc94a49033", + "6db51490d9c64dfd9962fccfad9904f3", + "429c7390cf0a4c24940dba27f9d21539", + "8a97cc02ef4149afaed023bf17c25a10", + "d5dc927524dc4585a7613213bb8216be", + "28a2d7a657d24be5b6cbb04031e7473a", + "5c510a8b7acb47f9a414455d41cac438", + "5694f8f68e2248f69e7df724e08a6195", + "fd09e1d8c0a24cc69dcdb5a7747b8667", + "3b063090da0b4a48b1d5eeb75bbabe15", + "5b502b155170438b8f333665c41fa991", + "bf1a8d32137d434286a9e04cb6ea1d57", + "0f3e12be01f747278897f34557e03ee3", + "25190fdce66a4c4d97007073c7353d1d", + "af54642a39bb41d7906d2f54e6e28950", + "551f8f30319f4b2caa7265709345410d", + "7d92560d9e5d4894989e436b5427b530", + "5196f93d756f4a04881cc97a7b131db6", + "2dd75ac84881458ea838d2ca0795ebe0", + "f9809fbbcd5440fa8f6aa1aa8b2aa16e", + "93cbab99432743b2a8ffb742ec6c6689", + "40c1d516a2044471afdde9ae935b6135", + "cd55af40f7d943fc8662f909c6777624", + "a3a291f58b5f4ff99f242bf9b80c0b07", + "c0b49871e23849d7879824b7d802af25", + "7ad9c5b26e034ad288017a937c3db645", + "bd9d88fe1c6642599d6ae09dfda9442f", + "af8f987952f842b192cadca19a337b77", + "6010e617345648f3aa5394a128b84ac2", + "a9ac7dcbc7e24f2bb3f32df048a3bfde", + "f38f2ace74574a839e48f8b759b1e5de", + "7b1b05174b6848d99f54ab731699df37", + "d4a6effdfa5b499f896cecc5ffafb9e3", + "4c994cad782d4cd4b431e82883be93ba", + "fc6bc52c48034b21947b540912e7efe1", + "8d2ec668f1c9459fb54568e03d38394b", + "35a454b114e843a6a86489911ba05a45", + "5483aa2a3627407892f403af82534bc8", + "10ede871ce7f4062823be58587ec393d", + "22a6ff1db2fe4224a96bf584a9de303a", + "4a2421ab197141c78ff4e2b1d68a525b", + "5685027d2d264c2987b0161e6d5210c2", + "e596b924d42f4a1f98427d11ae7de71e", + "fba7496f6b2548df985a386ccd8b5f3b", + "8f11428ba9f54e17a24cce1fb5f0bf89", + "5e13f6c2128440ae8f3ffe3c21e0971b", + "a17ea5e5f5e44e629b7da31df35d0b53", + "c30f0bedb7ac4683b3b942d4d6ee65b7", + "8bfa056e0d5a4a8a86e484e7126bad66", + "2b9d93df766c4bf98a4127fde973ceed", + "658151cb98dc4f83b06fbff9c5749359", + "4c04bb4b1f714bb8b7fa45fe4109533a", + "ce711dac250c4bd28dcee2fd59cbcd56", + "cc58ed0821684ce0a747e62fc4700410", + "e98762694e4b4b10aa99ff67bee6658f", + "3c22d7322ca34ec7961e034c15975739", + "536e88e5cecb41e3941428695f4f01c2", + "a9d770653dc6465783b36041480eb2b2", + "9d3e8ec490d94ff2a1c59b5b83625da8", + "114b691b525747119ec500897d737099", + "48014711b00a44d6be079c7187b203fd", + "9650e39fc9884696abb1a52428ca7258", + "ea2db2f0c44b461da017f2115123faa5", + "43ec3650376a4bd6bd569c0a67fe6715", + "42ad8776790b42adb04d2f6986a5d5f9", + "14ec5e9872774ab996dd75d8211e99b8", + "1e7ea2ad609d4e59a6f561a329da8ac6", + "e28c2135866a40308e8e91feb663d124", + "034a74360ebf4b9c96eb9ae8b922c585", + "52709cf51cba478da76d778853b161e5", + "019f57dd755d4972ae29628592f957ad", + "99589e5c1eb04a849d0ef67406cfe45f", + "d4cb320b1c7549fb90ae272b40f31a9b" + ] + }, + "id": "wVVWAq54ohCT", + "outputId": "6db17b03-b647-4af7-bf3d-aa79ae03cd83" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading readme: 0%| | 0.00/1.57k [00:00 dict:\n", + " context = f\"Instruction: {example['instruction']}\\n\"\n", + " if example.get(\"input\"):\n", + " context += f\"Input: {example['input']}\\n\"\n", + " context += \"Answer: \"\n", + " target = example[\"output\"]\n", + " return {\"context\": context, \"target\": target}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "oomCwuggpjsF" + }, + "outputs": [], + "source": [ + "data_list = []\n", + "for item in all_dataset.to_pandas().itertuples():\n", + " tmp = {}\n", + " tmp[\"instruction\"] = item.instruction\n", + " tmp[\"input\"] = item.input\n", + " tmp[\"output\"] = item.output\n", + " data_list.append(tmp)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "0c47ca540b134cc980002fece7027cd1", + "55a63e75fb0f4e9b8abcdf97d6396f4c", + "828b5c4bc5304f2f8d087e6be2aea3de", + "4410d3b4c0054d839220de695b43d39e", + "a2ffd06a989144a0b24ff9867afb216c", + "9760258333124b609e9ce1d68d99104b", + "78f0701517ac406199d9352fc14621ee", + "36c3ce5239624f5296d3b8c2c7299023", + "44bf2888fb6b489ab4b1285204e1241f", + "9bd5a0968db9477a90bb965def93819e", + "eeee958db1ce4d98a8dfc6d64fce44d6" + ] + }, + "id": "VuuRhlD1pjqH", + "outputId": "52b73e43-d23d-43ea-ee00-a94c504527f9" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "formatting..: 0%| | 0/19086 [00:00 max_seq_length:\n", + " continue\n", + " feature[\"input_ids\"] = feature[\"input_ids\"][:max_seq_length]\n", + " yield feature" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 2.3 Save the dataset" + ], + "metadata": { + "id": "j9M_NsM2-9Mk" + } + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 377, + "referenced_widgets": [ + "cab66d7e3c104f87b990ae93ccafbe0a", + "b2094294470c4fbbb5cbf035bb2d9100", + "8b59d6ba6fb04e8f8daaf8e9a70d3a67", + "c5c1965b109b42918bb4bfa00583e1bb", + "9179dfa7238a454fa802b1f8c289291d", + "b543d4d83d714c7ca73986aab40312d9", + "481dfd1ab35e4993b73c8115829e003f", + "ed0a6d4c825540b1baba24e23a40a597", + "919cf192efd0496f80eefe2381e21611", + "6fa516c80f9245b9962250ac1d424a09", + "2ef88176c4f344e19d81c0b30f21cdf7", + "069070c8ff784daea79efbf7e61af6be", + "d5e6ffbedea84088b73b90c004e5a75c", + "0ce02204ecb843238cdbdef35ab3aec7", + "0e254352c2484611897973bf107391ea", + "088f77ad3daf4cf6b22c207508d01f28", + "44bec16172c442f3846e1d7b03ed25a9", + "8a91b7639f974291aac885c35e2b5538", + "d4aecd757dd142bfb97028990f739ce4", + "c2b9561d79a8410fa94cecf63da3d072", + "3f5bcdf93a204eb4af03df92597211f3", + "b3c4b520dee04c8a91b117b90ef42f8b", + "e7a91951ddf5486187e1cb976fc3c666", + "b50651bc8dcd43528bd5491e8c103430", + "46346adec17946fb8a77dac7c91cdcf6", + "4dc1959896a145d5810f790ebc3a2e0e", + "d9b8a038666c479f8f9421054b985ef8", + "ab7a49736356450d931c1d312afb3248", + "693df0108c1a468a92a22fd9121caa71", + "dd7e5e0cef41413a96862d7a4c575e50", + "a20a481889454ef0982d86044c0b3436", + "da95116446aa4dd69a68a2f52c0d74e2", + "03c712f54b1b4c0d9b428f1236d56b73", + "c37604b40abc4680871dfda07fef8994", + "07da2fc551ec47f1b3b488b21a56dda9", + "b6a916d73bf9430aa7a92bfeb6ce7368", + "84d53294e36b4602bb9c5c477fb3fcda", + "f444394113264eeba7d6f8ff3dc03b04", + "a762c7583c1a474fae09eedf4ddcb10a", + "adad8d6b45b8481b9f235a8417acc911", + "89d07c37c1ad4322b1b20c1d03238d7d", + "2ab485d024404b2d91ade2ddfb320187", + "045c4050829d409491c9ee7c1abad664", + "7011bb5d764c475f9f586b2575e2346c", + "eafeedacf9fe43aa8a6305908701f8ac", + "8fa433809db24ec0ad3f7351f47dd04b", + "409bb62c3a7747f28bedb07c60703268", + "684963775bc14cddbe83698df6e07b44", + "f2348a2b3924491a8480e299cfec61ad", + "18138557a9424a27a584945d2f8cf593", + "1b7a91b980224223957eccd0c448e524", + "a4027742e02a4becae9db06b6536a394", + "ac6469633e834243b15435b944193e68", + "0e956b9ec9314b87af222da764efb16f", + "70f21e42eb0d48beb51f0952db58edcf", + "d653840a8636418eaf0ff15027c4f9a8", + "37d196cedc84401f958dbd84b5d29d05", + "c8110176d1754e9eaafd4a4f4eb4ee1e", + "540c4e93247149f88ae7577fc6a1bc08", + "4cbb60f5c0c24f5daa48aa56b419ea05", + "f413f8a87ebd49d2bec468b21626eab7", + "39a8b861d23b4b3baae535607a510745", + "494d449d273e4244a87b7545c5c40a56", + "68db9b8eca73403d8c974f308eee075b", + "ae687857bd93447c8270b34a06d923e9", + "cc6a68c6eca8422ead9de611a4ec8114", + "dc0acdd6bcb7453c8dc1efce21559562", + "4b575dff1050494293f899d07781d094", + "6ae54d0b9ac04768a2f25ded4c9f2a95", + "594e2960ef414a16ac74ee9eb7ae44ee", + "d224e745c8ee43ecb6f6ebe93b57a3f7", + "be2b1064148141a9bb28b79ad2d70510", + "a88a5bf023194d1eb631fbb1fa280043", + "ed4d24236dc34ac5b6276edf001dc678", + "ab73990c1ef64a73a6c19f0dd2456718", + "85074dc29a5d4a23b58a4e815c87f319", + "68d14efd5cd7420abfe06e18ab1efec1", + "90053e24fb314c4cb8b2314f5f4b454c", + "881468e12a1e4b9cbb7b48862a4b260c", + "fbdf2503cb1f49e59e87feb5fee7c17b", + "c3a335fff0364ce39881e3abca0c3567", + "81cf157ed7d9401abbab6f91ed56ca6b", + "d4a03d42a8134a05945d89dc372cfc45", + "77bc972c6f0c40d7ba4459997aa6f3f9", + "7bbcd34ca42d45faa32e5192350e6c37", + "bdd5fe6edfde4611bea6a83ef2f7b718", + "a60a3bd364ed4ab5a3e48ff20b15cd80", + "1f6470abb8574791b446c270dbd9283f" + ] + }, + "id": "B5yRJ52HpjiD", + "outputId": "3b5aa77d-59ff-4a62-b19b-fb283935342e" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Generating train split: 0 examples [00:00, ? examples/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "cab66d7e3c104f87b990ae93ccafbe0a" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading (…)okenizer_config.json: 0%| | 0.00/244 [00:00torch) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (17.0.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.23.5)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision) (2.31.0)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2023.7.22)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.30.2)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.4)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.3)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.5.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n", + "Requirement already satisfied: loguru in /usr/local/lib/python3.10/dist-packages (0.7.2)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.5)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n", + "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", + "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.16.4)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", + "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.5.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", + "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.5.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", + "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.0.1+cu118)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.30.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.1)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft) (0.23.0)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.3.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (17.0.2)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate->peft) (0.16.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.31.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.13.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate->peft) (2023.6.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n", + "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.41.1)\n", + "Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (2.13.0)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.4.0)\n", + "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.59.0)\n", + "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (2.17.3)\n", + "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.0.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (3.4.4)\n", + "Requirement already satisfied: numpy>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (1.23.5)\n", + "Requirement already satisfied: protobuf>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (3.20.3)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (2.31.0)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (67.7.2)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (0.7.1)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (3.0.0)\n", + "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.10/dist-packages (from tensorboard) (0.41.2)\n", + "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (5.3.1)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (0.3.0)\n", + "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (1.16.0)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard) (4.9)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard) (1.3.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard) (2023.7.22)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard) (2.1.3)\n", + "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard) (0.5.0)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard) (3.2.2)\n", + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.23.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.16.4)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (17.0.2)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2023.6.0)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2023.7.22)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" + ] + } + ], + "source": [ + "!pip install torch torchvision torchaudio\n", + "!pip install transformers\n", + "!pip install loguru\n", + "!pip install datasets\n", + "!pip install peft\n", + "!pip install bitsandbytes\n", + "!pip install tensorboard\n", + "!pip install sentencepiece\n", + "!pip install accelerate -U" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "Jpn7KAIW9RuP" + }, + "outputs": [], + "source": [ + "# Ensure CUDA is accessible in the system path\n", + "# Only for Windows Subsystem for Linux (WSL)\n", + "import os\n", + "os.environ[\"PATH\"] = f\"{os.environ['PATH']}:/usr/local/cuda/bin\"\n", + "os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/wsl/lib:/usr/local/cuda/lib64\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 3.1 Training Arguments Setup:\n", + "Initialize and set training arguments.\n", + "\n" + ], + "metadata": { + "id": "aLjcEr3FBGzi" + } + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "gW5hW4-F9RsX" + }, + "outputs": [], + "source": [ + "from typing import List, Dict, Optional\n", + "import torch\n", + "from loguru import logger\n", + "from transformers import (\n", + " AutoModel,\n", + " AutoTokenizer,\n", + " TrainingArguments,\n", + " Trainer,\n", + " BitsAndBytesConfig\n", + ")\n", + "from peft import (\n", + " TaskType,\n", + " LoraConfig,\n", + " get_peft_model,\n", + " set_peft_model_state_dict,\n", + " prepare_model_for_kbit_training,\n", + " prepare_model_for_int8_training,\n", + ")\n", + "from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "4w2Pe1rT9Rqv" + }, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir='./finetuned_model', # saved model path\n", + " logging_steps = 500,\n", + " # max_steps=10000,\n", + " num_train_epochs = 2,\n", + " per_device_train_batch_size=4,\n", + " gradient_accumulation_steps=8,\n", + " learning_rate=1e-4,\n", + " weight_decay=0.01,\n", + " warmup_steps=1000,\n", + " save_steps=500,\n", + " fp16=True,\n", + " # bf16=True,\n", + " torch_compile = False,\n", + " load_best_model_at_end = True,\n", + " evaluation_strategy=\"steps\",\n", + " remove_unused_columns=False,\n", + "\n", + " )" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 3.2 Quantization Config Setup:\n", + "Set quantization configuration to reduce model size without losing significant precision.\n", + "\n" + ], + "metadata": { + "id": "QGabm7cyBM6r" + } + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "OfKTizyP9Rmn" + }, + "outputs": [], + "source": [ + " # Quantization\n", + "q_config = BitsAndBytesConfig(load_in_4bit=True,\n", + " bnb_4bit_quant_type='nf4',\n", + " bnb_4bit_use_double_quant=True,\n", + " bnb_4bit_compute_dtype=torch.float16\n", + " )" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 3.3 Model Loading & Preparation:\n", + "Load the base model and tokenizer, and prepare the model for INT8 training.\n", + "\n", + "* **Runtime -> Change runtime type -> A100 GPU**\n", + "* retart runtime and run again if not working\n" + ], + "metadata": { + "id": "WMqFjp_mBVqO" + } + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "th_3Rnqy9Rkg", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 392, + "referenced_widgets": [ + "2740623b3fb44f27ba6e4c8ad2261403", + "41faa0829f5940f5ba6e4b92315b2c0d", + "55fe34ff85f1400caf04f248320fc3f8", + "09879c43d78e4325a79fed99b4d95792", + "156fee6a3567490d84844c58b6431b2d", + "999a6449b6ea46df820f4f07c4037f7d", + "01299e8c8b7d47be8b17fdf660a27d8b", + "a4e52b3f8a394635bd1e47fe7663bd1c", + "d23b0ed7eba04c7ebbafdb44c7b5e60b", + "8b354be158374e4d843f0e069f0fd9c4", + "5a61d5e363f44e2488cb93431627a636", + "fa3548ca466c466a8d1d1294b2c52f4d", + "314d5643183945688c322a4d7f46f6e6", + "c4da7184483f4791ad14edd04e386d7c", + "6c052cc6d58448fb91b390b9998567f4", + "776ac3877e6f4d1c85410604a32c3768", + "f6cc40a971764c6e912a5a1255733e46", + "321311dbb13648eba09dea32019af115", + "6ac343f30f394400aa968a47c5d0c204", + "3e340dc86f244372b1c976e92783b981", + "6b64fca589ad4f9d91294e9658c1e482", + "6614098a16ad464d9c255e33719e527c", + "b157f7eaad234cdbaeb400ebbdc39d68", + "4c0d0ebfad7c4da394d5d2b8cb4677ec", + "72735d52ab4b44748e907f20c61464b1", + "f18c8022f7a14be9a9159d8fb2152238", + "5808b763539443c28e3186e35b4593f3", + "abac754de8e04bbb9122d52f27247b0e", + "1ca6336b74d34c5a8bed97c916b4ac04", + "548b16e12fa44936917c3c48c22f4442", + "250da57b21a9409aaaf41f1a8f972a5d", + "b8ba10ddae554923b20a0af100bee1ca", + "1fce8665b14c42a18945a4a6362f6e79", + "d1581466c1a04fcdb35c14dd352b6ad1", + "92601dc799214006b49c2296e3059db8", + "eccb174ab6a64a3d995731c92f5dcca5", + "1b2969439cd24d188f5bf84f24b5205b", + "93e13162ec0647989f275012b955bd89", + "52366ad75f394766aabac53d474b317e", + "77555dd440934d5fabd4c0c85d3e3a80", + "4023f05a72a94f34ac9c99f73c595d0a", + "6f6a8d619bf947c7979bb6371716a0db", + "1dc2cc5fe1594ce7b95c79798033b542", + "84ff5512ae8d4a4a950311d5e0e27f0f", + "0a5474b0a1b945deb57aaa9ae9ae0467", + "54d1528d53e640998f09688f4ea1d75c", + "488d3b4c4d3748cea18b3906d7543712", + "06a65c4a9cf9483e9a814b966769ea3f", + "3a635187fbad4cc39054bc59658e20e5", + "4f07cd6f4c46439887c9a85c33d3baae", + "512654d21fa645b98b702c37e438784a", + "e6456544cfbc424baa63f8fdd6d17c7a", + "246fe12c849745ada9f9676389e5a71f", + "3f9b9d9a7b8b4b3296d8ab48e334ff96", + "8a7d598f36eb4e39b57a3cc725ccfeae", + "f05f80f9ebc940478a6ff8f485e17fbb", + "612820e4b57440e0acc6005e5c675139", + "be9fe158498f462b904e8386899e50a9", + "b53c1f7535ad41bd9cdbac55efb912c2", + "700fcac7f4cc4ce881f5cf1605d9f112", + "02e64751c685444291ec5d43eddeda67", + "09a6bf680ea94547a7e965500779d254", + "788f68f501ec46b3872a151b9ed58ac5", + "7e2ac9a84e5843719aea1a39a404b058", + "a1c33ce68b35412a97e77c8f2556e6c1", + "b2b8ad4765d04dfb89ec62e14c35a3a2", + "4056b6768e9449c2a0b34479e9599b0f", + "e65fa0a769cf48f5b3f373858af45d56", + "438cb4c759f3477583452dc5d8bdf389", + "1b44124c77c14ea69afa532ef8b1c60a", + "bf8a87edb4bc401999cfe335c00496b8", + "28d0f93e391149f6ad3ac030d1e8e32a", + "9e04a8712b7c4719869bf1d5f774e7d0", + "e50411984f654b2f8a6f8b18a54ce913", + "77a0471892804bc795fbc382dcccd511", + "563939a1bf9c43b28677f951f51843dd", + "cb8298aa164a40edb0d9059da15f3470", + "5a4e14560e8d434f87adb8dfbcb69513", + "698926ed25b94c17927f454062e8af2f", + "15520924776f444cac64312728130f1f", + "f5604e8b61ee46fc8da80b677413bcb8", + "8933b6cfc52e4cb5945c8cf9977d1a52", + "5c721d9db95b4cb88c826ac9e8bff365", + "7d5c389c0d2a470e9e93611988af6879", + "97168560e8154abf95c68b74231fc4d1", + "bc91132e6f2248e4ab7ab78d7bf6d02c", + "12fb5d18878e45d9b523740bb0894f0b", + "217cb2d68fc148d4aa4154ab5afb8e33", + "34a61ba5c4b5422396d002af25fa4429", + "212cac56e05b43aaa3b68cfe29028708", + "73185ad5e4a34b71a68ca563584a5dd3", + "c6ea1e6115a64d9199baf1fb73de43cf", + "5ba4865098804fcbb65cff01421c9541", + "f6e682f649ae45429680a5633c8757da", + "3320d1c2b9a44b8e987dae250b895284", + "03351d22d3574ded9ef868d9b29d1402", + "1c1cad9387bb4a619d08406de6c41368", + "81c7cdbdfd924517b973870365e3a40c", + "64afa277d684419b8d26ef5ce059d09a", + "3eefa6154f1f464493db26cad778cdf1", + "2a52f13cd0b44342948a85eb9dc53d2c", + "fcee88e52db9483d827cdbeb5de3c187", + "e66a7d914d7743c78fb498b402dae997", + "6ea4575b71254e75a1517efcefabbbce", + "d2303a1b0fd3454bbdace8bbcf42a1d8", + "1dd0f715fbf348d9bb7f1467cd64b951", + "a039389e25a94ccab17bba51ac1ce132", + "0009ba50dc5e47dd8bec74de2d6465e9", + "2779bea2fa5745eab7d09a95348f3847", + "169215329249494aaacfa43106b7f542" + ] + }, + "outputId": "1e91c290-beec-416f-f7ec-5f33a2953581" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading (…)model.bin.index.json: 0%| | 0.00/20.4k [00:00 dict:\n", + " len_ids = [len(feature[\"input_ids\"]) for feature in features]\n", + " longest = max(len_ids)\n", + " input_ids = []\n", + " labels_list = []\n", + " for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):\n", + " ids = feature[\"input_ids\"]\n", + " seq_len = feature[\"seq_len\"]\n", + " labels = (\n", + " [tokenizer.pad_token_id] * (seq_len - 1) + ids[(seq_len - 1) :] + [tokenizer.pad_token_id] * (longest - ids_l)\n", + " )\n", + " ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)\n", + " _ids = torch.LongTensor(ids)\n", + " labels_list.append(torch.LongTensor(labels))\n", + " input_ids.append(_ids)\n", + " input_ids = torch.stack(input_ids)\n", + " labels = torch.stack(labels_list)\n", + " return {\n", + " \"input_ids\": input_ids,\n", + " \"labels\": labels,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "ITubEZSK9RVv" + }, + "outputs": [], + "source": [ + "from torch.utils.tensorboard import SummaryWriter\n", + "from transformers.integrations import TensorBoardCallback" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "Cw4Zik6a9RT3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 231 + }, + "outputId": "fcf72d9b-3267-4d72-93fc-0fd09057c338" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "You are adding a to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is\n", + ":DefaultFlowCallback\n", + "TensorBoardCallback\n", + "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n", + "WARNING:transformers_modules.THUDM.chatglm2-6b.8fd7fba285f7171d3ae7ea3b35c53b6340501ed1.modeling_chatglm:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [954/954 40:18, Epoch 1/2]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining LossValidation Loss
5009.6197006.023976

" + ] + }, + "metadata": {} + } + ], + "source": [ + "# Train\n", + "# Took about 10 compute units\n", + "writer = SummaryWriter()\n", + "trainer = ModifiedTrainer(\n", + " model=model,\n", + " args=training_args, # Trainer args\n", + " train_dataset=dataset[\"train\"], # Training set\n", + " eval_dataset=dataset[\"test\"], # Testing set\n", + " data_collator=data_collator, # Data Collator\n", + " callbacks=[TensorBoardCallback(writer)],\n", + ")\n", + "trainer.train()\n", + "writer.close()\n", + "# save model\n", + "model.save_pretrained(training_args.output_dir)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 4.3 Model Saving and Download:\n", + "After training, save and download your model. You can also check the model's size.\n", + "\n" + ], + "metadata": { + "id": "brHTWfnmCn5D" + } + }, + { + "cell_type": "code", + "source": [ + "!zip -r /content/saved_model.zip /content/{training_args.output_dir}\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HUYxzwS_9lMI", + "outputId": "7840afdf-ab03-4664-fd1b-99aa55f81125" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " adding: content/./finetuned_model/ (stored 0%)\n", + " adding: content/./finetuned_model/checkpoint-500/ (stored 0%)\n", + " adding: content/./finetuned_model/checkpoint-500/optimizer.pt (deflated 7%)\n", + " adding: content/./finetuned_model/checkpoint-500/trainer_state.json (deflated 53%)\n", + " adding: content/./finetuned_model/checkpoint-500/scheduler.pt (deflated 49%)\n", + " adding: content/./finetuned_model/checkpoint-500/rng_state.pth (deflated 28%)\n", + " adding: content/./finetuned_model/checkpoint-500/adapter_model.bin (deflated 8%)\n", + " adding: content/./finetuned_model/checkpoint-500/training_args.bin (deflated 48%)\n", + " adding: content/./finetuned_model/runs/ (stored 0%)\n", + " adding: content/./finetuned_model/runs/Oct06_08-19-46_81e899208623/ (stored 0%)\n", + " adding: content/./finetuned_model/runs/Oct06_08-19-46_81e899208623/events.out.tfevents.1696580727.81e899208623.2594.1 (deflated 59%)\n", + " adding: content/./finetuned_model/adapter_model.bin (deflated 8%)\n", + " adding: content/./finetuned_model/README.md (deflated 39%)\n", + " adding: content/./finetuned_model/adapter_config.json (deflated 42%)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# download to local\n", + "from google.colab import files\n", + "files.download('/content/saved_model.zip')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "s-7s2Cjw9pAM", + "outputId": "f00c656b-1122-4a81-896b-c8d94a31979c" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "download(\"download_0a618fc5-4945-4791-b2cd-8ac358536eeb\", \"saved_model.zip\", 28956881)" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "# save to google drive\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rvBgMgO8RADU", + "outputId": "52bd847d-5bef-4b12-b9ed-d7e9795b3fcb" + }, + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/drive\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# save the finetuned model to google drive\n", + "!cp -r \"/content/finetuned_model\" \"/content/drive/MyDrive\"\n" + ], + "metadata": { + "id": "UUctmjm8RIfQ" + }, + "execution_count": 23, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "unRoLshR9RQZ", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "9e39e6a5-cbc4-4459-9346-7df02f7c5c5d" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model size: 29.84746265411377 MB\n" + ] + } + ], + "source": [ + "def get_folder_size(folder_path):\n", + " total_size = 0\n", + " for dirpath, _, filenames in os.walk(folder_path):\n", + " for f in filenames:\n", + " fp = os.path.join(dirpath, f)\n", + " total_size += os.path.getsize(fp)\n", + " return total_size / 1024 / 1024 # Size in MB\n", + "\n", + "model_size = get_folder_size(training_args.output_dir)\n", + "print(f\"Model size: {model_size} MB\")\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now your model is trained and saved! You can download it and use it for generating financial insights or any other relevant tasks in the finance domain. The usage of TensorBoard allows you to deeply understand and visualize the training dynamics and performance of your model in real-time.\n", + "\n", + "Happy FinGPT Training! 🚀" + ], + "metadata": { + "id": "1LCjYKuoCusU" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Part 5: Inference and Benchmarks using FinGPT\n", + "Now that your model is trained, let’s understand how to use it to infer and run benchmarks.\n", + "* Took about 10 compute units\n", + "\n" + ], + "metadata": { + "id": "76g_Qlp8t_Yp" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install transformers==4.30.2 peft==0.4.0\n", + "!pip install sentencepiece\n", + "!pip install accelerate\n", + "!pip install torch\n", + "!pip install peft\n", + "!pip install datasets\n", + "!pip install bitsandbytes" + ], + "metadata": { + "id": "ehjG2bpft_OH", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "outputId": "1806add3-b383-41eb-f6c9-76f982162b1b" + }, + "execution_count": 25, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: transformers==4.30.2 in /usr/local/lib/python3.10/dist-packages (4.30.2)\n", + "Collecting peft==0.4.0\n", + " Downloading peft-0.4.0-py3-none-any.whl (72 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m72.9/72.9 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (3.12.4)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (0.16.4)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (6.0.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (2023.6.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (2.31.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (0.13.3)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (0.3.3)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.30.2) (4.66.1)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft==0.4.0) (5.9.5)\n", + "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft==0.4.0) (2.0.1+cu118)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft==0.4.0) (0.23.0)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (2023.6.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft==0.4.0) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft==0.4.0) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft==0.4.0) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft==0.4.0) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft==0.4.0) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft==0.4.0) (17.0.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.30.2) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.30.2) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.30.2) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.30.2) (2023.7.22)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft==0.4.0) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft==0.4.0) (1.3.0)\n", + "Installing collected packages: peft\n", + " Attempting uninstall: peft\n", + " Found existing installation: peft 0.5.0\n", + " Uninstalling peft-0.5.0:\n", + " Successfully uninstalled peft-0.5.0\n", + "Successfully installed peft-0.4.0\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "peft" + ] + } + } + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.23.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.16.4)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (17.0.2)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2023.6.0)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2023.7.22)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (17.0.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.4.0)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", + "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.0.1+cu118)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.30.2)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft) (0.23.0)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.3.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (17.0.2)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate->peft) (0.16.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.31.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.13.3)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (4.66.1)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate->peft) (2023.6.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.5)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n", + "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", + "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.16.4)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", + "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.3.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (3.12.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.5.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3.post1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", + "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.41.1)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### 5.1 Load the model" + ], + "metadata": { + "id": "P91SXTrLS34i" + } + }, + { + "cell_type": "code", + "source": [ + "#clone the FinNLP repository\n", + "!git clone https://github.com/AI4Finance-Foundation/FinNLP.git\n", + "\n", + "import sys\n", + "sys.path.append('/content/FinNLP/')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "y5jyY7S_uEls", + "outputId": "739e49ed-0b31-46dd-96e0-162b4ef8073d" + }, + "execution_count": 26, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'FinNLP'...\n", + "remote: Enumerating objects: 1316, done.\u001b[K\n", + "remote: Counting objects: 100% (375/375), done.\u001b[K\n", + "remote: Compressing objects: 100% (166/166), done.\u001b[K\n", + "remote: Total 1316 (delta 219), reused 303 (delta 175), pack-reused 941\u001b[K\n", + "Receiving objects: 100% (1316/1316), 4.21 MiB | 13.55 MiB/s, done.\n", + "Resolving deltas: 100% (592/592), done.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM\n", + "\n", + "from peft import PeftModel\n", + "import torch\n", + "\n", + "# Load benchmark datasets from FinNLP\n", + "from finnlp.benchmarks.fpb import test_fpb\n", + "from finnlp.benchmarks.fiqa import test_fiqa , add_instructions\n", + "from finnlp.benchmarks.tfns import test_tfns\n", + "from finnlp.benchmarks.nwgi import test_nwgi" + ], + "metadata": { + "id": "zRsmSTFZuEjt" + }, + "execution_count": 27, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade peft" + ], + "metadata": { + "id": "EBqKeUYV9VjF", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 756 + }, + "outputId": "20f6ad0e-9e52-4e8b-dde8-33509ef2f381" + }, + "execution_count": 28, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.4.0)\n", + "Collecting peft\n", + " Using cached peft-0.5.0-py3-none-any.whl (85 kB)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.23.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", + "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.0.1+cu118)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.30.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.1)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft) (0.23.0)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.3.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.12.4)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (3.27.6)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.13.0->peft) (17.0.2)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate->peft) (0.16.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.6.3)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2.31.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.13.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate->peft) (2023.6.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.3.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2.0.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->peft) (2023.7.22)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n", + "Installing collected packages: peft\n", + " Attempting uninstall: peft\n", + " Found existing installation: peft 0.4.0\n", + " Uninstalling peft-0.4.0:\n", + " Successfully uninstalled peft-0.4.0\n", + "Successfully installed peft-0.5.0\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "peft" + ] + } + } + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "# load model from google drive\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ], + "metadata": { + "id": "gRRw9drdA2hv", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d75720b1-35da-4305-cd38-309f64d38eb0" + }, + "execution_count": 29, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Define the path you want to check\n", + "path_to_check = \"/content/drive/My Drive/finetuned_model\"\n", + "\n", + "# Check if the specified path exists\n", + "if os.path.exists(path_to_check):\n", + " print(\"Path exists.\")\n", + "else:\n", + " print(\"Path does not exist.\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LxCAhg9QpkyI", + "outputId": "81989b85-4abf-4403-c730-ac6e8ebe3488" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Path exists.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "## load the chatglm2-6b base model\n", + "base_model = \"THUDM/chatglm2-6b\"\n", + "peft_model = training_args.output_dir\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)\n", + "model = AutoModel.from_pretrained(base_model, trust_remote_code=True, load_in_8bit=True, device_map=\"auto\")\n", + "\n", + "model = PeftModel.from_pretrained(model, peft_model)\n", + "\n", + "model = model.eval()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "424390ff1b0f4c3884e9f0c59fc9a0d7", + "9a09a691e0064c708c39fee52e5a85cd", + "111f553446b6427aa8b752c3bb9bf2a0", + "9275857ca36447b2af2b238aeb64c192", + "616211b4a10348198d070b35dcfa4ec0", + "5355d4c77ea54fe3a6f72e7dba2125d0", + "fa72cf75c46a4a619ce33320573fa9cb", + "3275141d70f64ae4ae7e79efc792869c", + "636fd6c2446b47a9a6168ce3634c6b01", + "3eb06cbb516e4cf28f968ba08a1c7b44", + "d007d3a04df2494b80d1300f1b8e624a" + ] + }, + "id": "bRljPCKC_srt", + "outputId": "b7ab1104-fae1-4f31-9bfc-06cc53b3fbc9" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/7 [00:00