{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning `bert-base-uncased` for paraphrase identification (GLUE MRPC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Map: 100%|██████████| 408/408 [00:00<00:00, 8416.87 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "# Checkpoint name reused by the tokenizer and by the models loaded further down.\n",
    "checkpoint = \"bert-base-uncased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    checkpoint,\n",
    ")\n",
    "\n",
    "# Load the \"mrpc\" configuration of the \"glue\" dataset.\n",
    "raw_datasets = load_dataset(path=\"glue\", name=\"mrpc\")\n",
    "\n",
    "\n",
    "def tokenize_function(example):\n",
    "    \"\"\"Tokenize `sentence1` together with `sentence2`, truncating over-long inputs.\"\"\"\n",
    "    first_sentences = example[\"sentence1\"]\n",
    "    second_sentences = example[\"sentence2\"]\n",
    "    return tokenizer(first_sentences, second_sentences, truncation=True)\n",
    "\n",
    "\n",
    "# Tokenize in batches; no padding is applied here, the collator handles it.\n",
    "tokenized_datasets = raw_datasets.map(function=tokenize_function, batched=True)\n",
    "data_collator = DataCollatorWithPadding(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `training_args` is created in this cell so that it does not depend on an\n",
    "# object that is only defined by a later cell (Restart & Run All safe).\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\"test-trainer\")\n",
    "training_args.num_train_epochs = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TrainingArguments(\n",
       "_n_gpu=1,\n",
       "adafactor=False,\n",
       "adam_beta1=0.9,\n",
       "adam_beta2=0.999,\n",
       "adam_epsilon=1e-08,\n",
       "auto_find_batch_size=False,\n",
       "bf16=False,\n",
       "bf16_full_eval=False,\n",
       "data_seed=None,\n",
       "dataloader_drop_last=False,\n",
       "dataloader_num_workers=0,\n",
       "dataloader_pin_memory=True,\n",
       "ddp_bucket_cap_mb=None,\n",
       "ddp_find_unused_parameters=None,\n",
       "ddp_timeout=1800,\n",
       "debug=[],\n",
       "deepspeed=None,\n",
       "disable_tqdm=False,\n",
       "do_eval=False,\n",
       "do_predict=False,\n",
       "do_train=False,\n",
       "eval_accumulation_steps=None,\n",
       "eval_delay=0,\n",
       "eval_steps=None,\n",
       "evaluation_strategy=no,\n",
       "fp16=False,\n",
       "fp16_backend=auto,\n",
       "fp16_full_eval=False,\n",
       "fp16_opt_level=O1,\n",
       "fsdp=[],\n",
       "fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},\n",
       "fsdp_min_num_params=0,\n",
       "fsdp_transformer_layer_cls_to_wrap=None,\n",
       "full_determinism=False,\n",
       "gradient_accumulation_steps=1,\n",
       "gradient_checkpointing=False,\n",
       "greater_is_better=None,\n",
       "group_by_length=False,\n",
       "half_precision_backend=auto,\n",
       "hub_model_id=None,\n",
       "hub_private_repo=False,\n",
       "hub_strategy=every_save,\n",
       "hub_token=<HUB_TOKEN>,\n",
       "ignore_data_skip=False,\n",
       "include_inputs_for_metrics=False,\n",
       "jit_mode_eval=False,\n",
       "label_names=None,\n",
       "label_smoothing_factor=0.0,\n",
       "learning_rate=5e-05,\n",
       "length_column_name=length,\n",
       "load_best_model_at_end=False,\n",
       "local_rank=-1,\n",
       "log_level=passive,\n",
       "log_level_replica=warning,\n",
       "log_on_each_node=True,\n",
       "logging_dir=test-trainer/runs/Jul21_12-03-08_602d65b93b25,\n",
       "logging_first_step=False,\n",
       "logging_nan_inf_filter=True,\n",
       "logging_steps=500,\n",
       "logging_strategy=steps,\n",
       "lr_scheduler_type=linear,\n",
       "max_grad_norm=1.0,\n",
       "max_steps=-1,\n",
       "metric_for_best_model=None,\n",
       "mp_parameters=,\n",
       "no_cuda=False,\n",
       "num_train_epochs=2,\n",
       "optim=adamw_hf,\n",
       "optim_args=None,\n",
       "output_dir=test-trainer,\n",
       "overwrite_output_dir=False,\n",
       "past_index=-1,\n",
       "per_device_eval_batch_size=8,\n",
       "per_device_train_batch_size=8,\n",
       "prediction_loss_only=False,\n",
       "push_to_hub=False,\n",
       "push_to_hub_model_id=None,\n",
       "push_to_hub_organization=None,\n",
       "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
       "ray_scope=last,\n",
       "remove_unused_columns=True,\n",
       "report_to=[],\n",
       "resume_from_checkpoint=None,\n",
       "run_name=test-trainer,\n",
       "save_on_each_node=False,\n",
       "save_steps=500,\n",
       "save_strategy=steps,\n",
       "save_total_limit=None,\n",
       "seed=42,\n",
       "sharded_ddp=[],\n",
       "skip_memory_metrics=True,\n",
       "tf32=None,\n",
       "torch_compile=False,\n",
       "torch_compile_backend=None,\n",
       "torch_compile_mode=None,\n",
       "torchdynamo=None,\n",
       "tpu_metrics_debug=False,\n",
       "tpu_num_cores=None,\n",
       "use_ipex=False,\n",
       "use_legacy_prediction_loop=False,\n",
       "use_mps_device=False,\n",
       "warmup_ratio=0.0,\n",
       "warmup_steps=0,\n",
       "weight_decay=0.0,\n",
       "xpu_backend=None,\n",
       ")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import TrainingArguments\n",
    "# Library defaults, except for the number of training epochs.\n",
    "training_args = TrainingArguments(output_dir=\"test-trainer\", num_train_epochs=2)\n",
    "training_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "# Sequence-classification model loaded from `checkpoint` with two labels.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    checkpoint,\n",
    "    num_labels=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=768, out_features=2, bias=True)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the `classifier` submodule of the loaded model.\n",
    "classification_head = model.classifier\n",
    "classification_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "# Trainer for the first run; no `compute_metrics` function is supplied.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='918' max='918' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [918/918 01:11, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.323900</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=918, training_loss=0.26239446669102756, metrics={'train_runtime': 72.0933, 'train_samples_per_second': 101.757, 'train_steps_per_second': 12.733, 'total_flos': 270693998197680.0, 'train_loss': 0.26239446669102756, 'epoch': 2.0})"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Plain training: run the trainer built in the previous cell and show its summary.\n",
    "train_output = trainer.train()\n",
    "train_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6' max='51' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 6/51 00:00 < 00:00, 50.51 it/s]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(408, 2) (408,)\n"
     ]
    }
   ],
   "source": [
    "# Predict on the validation split and print the shapes of the outputs and labels.\n",
    "validation_split = tokenized_datasets[\"validation\"]\n",
    "predictions = trainer.predict(validation_split)\n",
    "logits_shape = predictions.predictions.shape\n",
    "labels_shape = predictions.label_ids.shape\n",
    "print(logits_shape, labels_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Predicted class = index of the largest value along the last axis.\n",
    "preds = predictions.predictions.argmax(axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.8553921568627451, 'f1': 0.8963093145869947}"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import evaluate\n",
    "# Score the argmax predictions against the reference labels.\n",
    "metric = evaluate.load(path=\"glue\", config_name=\"mrpc\")\n",
    "metric.compute(references=predictions.label_ids, predictions=preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the metric once here instead of on every evaluation call.\n",
    "mrpc_metric = evaluate.load(\"glue\", \"mrpc\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    \"\"\"Score model outputs with the \"glue\"/\"mrpc\" metric.\n",
    "\n",
    "    Args:\n",
    "        eval_preds: An object that unpacks into `(logits, labels)`.\n",
    "\n",
    "    Returns:\n",
    "        The result of `mrpc_metric.compute` on the argmax (last axis) of the\n",
    "        logits, with `labels` as references.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return mrpc_metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Second MRPC setup: evaluate at the end of every epoch using `compute_metrics`.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"test-trainer\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    num_train_epochs=2,\n",
    ")\n",
    "\n",
    "# Rebind `model` to a freshly loaded checkpoint.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    checkpoint,\n",
    "    num_labels=2,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='918' max='918' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [918/918 01:21, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.418620</td>\n",
       "      <td>0.830882</td>\n",
       "      <td>0.883249</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.498100</td>\n",
       "      <td>0.485925</td>\n",
       "      <td>0.860294</td>\n",
       "      <td>0.903226</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=918, training_loss=0.39945665579735584, metrics={'train_runtime': 82.0502, 'train_samples_per_second': 89.409, 'train_steps_per_second': 11.188, 'total_flos': 270693998197680.0, 'train_loss': 0.39945665579735584, 'epoch': 2.0})"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run training with the trainer defined in the previous cell and show its summary.\n",
    "train_output = trainer.train()\n",
    "train_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning on GLUE SST-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Map: 100%|██████████| 1821/1821 [00:00<00:00, 9257.22 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "# Load the \"sst2\" configuration of the \"glue\" dataset.\n",
    "raw_dataset = load_dataset(path=\"glue\", name=\"sst2\")\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "# Same checkpoint name as in the MRPC section.\n",
    "checkpoint = \"bert-base-uncased\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    checkpoint,\n",
    ")\n",
    "# NOTE(review): `model` is rebound in the Trainer cell below before it is used.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    checkpoint,\n",
    ")\n",
    "\n",
    "def tokenize_function(sequence):\n",
    "    \"\"\"Tokenize the `sentence` column, truncating over-long inputs.\n",
    "\n",
    "    No padding is applied and no tensor type is requested here: padding is\n",
    "    left to the data collator so it happens per training batch at collation\n",
    "    time, which keeps the stored dataset free of padding tokens.\n",
    "    \"\"\"\n",
    "    return tokenizer(sequence[\"sentence\"], truncation=True)\n",
    "\n",
    "# Apply the tokenizer function in batches.\n",
    "tokenized_dataset = raw_dataset.map(function=tokenize_function, batched=True)\n",
    "\n",
    "from transformers import DataCollatorWithPadding\n",
    "# Collator for this section, with padding enabled.\n",
    "dc = DataCollatorWithPadding(\n",
    "    tokenizer=tokenizer,\n",
    "    padding=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the metric once here instead of on every evaluation call.\n",
    "sst2_metric = evaluate.load(\"glue\", \"sst2\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    \"\"\"Score model outputs with the \"glue\"/\"sst2\" metric.\n",
    "\n",
    "    NOTE: this rebinds the `compute_metrics` name defined in the MRPC section.\n",
    "\n",
    "    Args:\n",
    "        eval_preds: An object that unpacks into `(logits, labels)`.\n",
    "\n",
    "    Returns:\n",
    "        The result of `sst2_metric.compute` on the argmax (last axis) of the\n",
    "        logits, with `labels` as references.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return sst2_metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Per-device batch sizes are set explicitly; evaluation runs after every epoch.\n",
    "training_args = TrainingArguments(\n",
    "    \"test-trainer\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    num_train_epochs=2,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=32,\n",
    ")\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
    "\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"validation\"],\n",
    "    # Use the collator built in this section (`dc`), not the one from the MRPC section.\n",
    "    data_collator=dc,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2106' max='2106' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2106/2106 11:38, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.166400</td>\n",
       "      <td>0.236536</td>\n",
       "      <td>0.918578</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.089700</td>\n",
       "      <td>0.238962</td>\n",
       "      <td>0.925459</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=2106, training_loss=0.1493011535289507, metrics={'train_runtime': 698.8876, 'train_samples_per_second': 192.732, 'train_steps_per_second': 3.013, 'total_flos': 4556217062352120.0, 'train_loss': 0.1493011535289507, 'epoch': 2.0})"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run training with the trainer defined in the previous cell and show its summary.\n",
    "train_output = trainer.train()\n",
    "train_output"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}