{ "cells": [
{ "cell_type": "code", "execution_count": 1, "id": "0729b762-3b84-474f-b82a-df7622b91ccb", "metadata": {}, "outputs": [], "source": [ "import torch, html\n", "from transformers import AutoTokenizer\n", "from datasets import load_dataset, load_from_disk\n", "from huggingface_hub import notebook_login\n", "from dotenv import load_dotenv\n", "import os" ] },
{ "cell_type": "code", "execution_count": 3, "id": "92ee5f76-2cd3-4af0-8687-dca782aa38a3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Read key=value pairs from a local .env file into os.environ.\n", "load_dotenv()" ] },
{ "cell_type": "code", "execution_count": 4, "id": "97d33c57-b03b-4bee-b051-04d707a8d773", "metadata": {}, "outputs": [], "source": [ "# The Hub token is taken from the environment; a missing HF_TOKEN raises KeyError here.\n", "access_token = os.environ['HF_TOKEN']" ] },
{ "cell_type": "code", "execution_count": 4, "id": "4358520c-3d8c-42ef-967a-eddeef732ef1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'cuda'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use the GPU when torch can see one, otherwise stay on the CPU.\n", "if torch.cuda.is_available():\n", "    device = 'cuda'\n", "else:\n", "    device = 'cpu'\n", "device" ] },
{ "cell_type": "code", "execution_count": 5, "id": "1c2ec24f-4c6d-4469-8e85-601a4b0d3e4e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],\n", " num_rows: 161297\n", " })\n", " test: Dataset({\n", " features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],\n", " num_rows: 53766\n", " })\n", "})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Both splits are tab-separated files read with the generic 'csv' loader.\n", "dataset = load_dataset(\n", "    'csv',\n", "    data_files=dict(\n", "        train='data/drugsComTrain_raw.tsv',\n", "        test='data/drugsComTest_raw.tsv',\n", "    ),\n", "    delimiter='\\t',\n", "    num_proc=8,\n", ")\n", "dataset" ] },
{ "cell_type": "code", "execution_count": 6, "id": "dbb81021-9acc-46b4-87c0-23f0f787fef5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': (161297, 7), 'test': (53766, 7)}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.shape" ] },
{ "cell_type": "code", "execution_count": 7, "id": "a983147c-eb04-455f-bf02-0c57c2a549e9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Unnamed: 0': 206461,\n", " 'drugName': 'Valsartan',\n", " 'condition': 'Left Ventricular Dysfunction',\n", " 'review': '\"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil\"',\n", " 'rating': 9.0,\n", " 'date': 'May 20, 2012',\n", " 'usefulCount': 27}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset['train'][0]" ] },
{ "cell_type": "code", "execution_count": 8, "id": "ee2b8ddf-79d7-44d6-80ba-243bc2f04de8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 138514\n", " })\n", " test: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 46108\n", " })\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = (\n", "    dataset\n", "    # Drop rows whose condition is missing.\n", "    .filter(lambda x: x['condition'] is not None)\n", "    .rename_column('Unnamed: 0', 'row_id')\n", "    # Lower-case the condition names.\n", "    .map(lambda x: {'condition': [row.lower() for row in x['condition']]}, batched=True, num_proc=8, batch_size=3000)\n", "    # Decode HTML entities in the review text.\n", "    .map(lambda x: {'review': [html.unescape(row) for row in x['review']]}, batched=True, num_proc=8, batch_size=3000)\n", "    # Whitespace-separated word count of every review.\n", "    .map(lambda x: {'review_length': [len(row.split()) for row in x['review']]}, batched=True, num_proc=8, batch_size=3000)\n", "    # Keep only the reviews that pass the length check below.\n", "    .filter(lambda x:
x['review_length'] > 30, num_proc=8, batch_size=3000)\n", ")\n", "dataset" ] },
{ "cell_type": "markdown", "id": "e7c4daf2-36c1-4074-91ca-8871a581052d", "metadata": {}, "source": [ "# Exercises" ] },
{ "cell_type": "markdown", "id": "ea14b998-69f1-40a7-a200-7cc53b0e22fd", "metadata": {}, "source": [ "## Predict patient condition based on drug review" ] },
{ "cell_type": "code", "execution_count": 6, "id": "dc6b299b-2d0b-4475-bfff-d0180dd672c1", "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer, AutoModel, DataCollatorWithPadding\n", "from torch.utils.data import DataLoader\n", "import evaluate, numpy as np\n", "from huggingface_hub import HfApi" ] },
{ "cell_type": "code", "execution_count": 10, "id": "77caa284-8307-40a0-8369-621195e5c7e9", "metadata": {}, "outputs": [], "source": [ "def clean_condition_column(rows):\n", "    \"\"\"Replace any condition containing 'users found this comment helpful' with 'unknown'.\n", "\n", "    Args:\n", "        rows: batch with a 'condition' list of strings.\n", "\n", "    Returns:\n", "        dict holding the cleaned 'condition' list, in the same order as the input.\n", "    \"\"\"\n", "    target_text = 'users found this comment helpful'\n", "    cleaned = []\n", "    for condition in rows['condition']:\n", "        if target_text in condition:\n", "            cleaned.append('unknown')\n", "        else:\n", "            cleaned.append(condition)\n", "    return {'condition': cleaned}" ] },
{ "cell_type": "code", "execution_count": 11, "id": "058d4c64-428b-43bb-86c4-ba8f5c1b8a84", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 138514\n", " })\n", " test: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 46108\n", " })\n", "})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Apply the cleaner batch-wise to every split; the row counts do not change.\n", "dataset = dataset.map(clean_condition_column, batched=True, batch_size=3000, num_proc=8)\n", "dataset" ] },
{ "cell_type": "code", "execution_count": 12, "id": "80dc20fe-cb66-4b0d-99dc-88e84413975b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train:
Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 110811\n", " })\n", " validation: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 27703\n", " })\n", " test: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n", " num_rows: 46108\n", " })\n", "})" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Carve a validation split out of train; the original test split is attached as-is.\n", "clean_data = dataset['train'].train_test_split(test_size=.2, seed=5, writer_batch_size=3000)\n", "clean_data['validation'] = clean_data.pop('test')\n", "clean_data['test'] = dataset['test']\n", "\n", "clean_data" ] },
{ "cell_type": "code", "execution_count": 13, "id": "8be33fbb-143f-45b5-9e18-c5662a7e0dad", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "751" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Label vocabulary = every condition seen in train or validation.\n", "# 'unknown' is always included: it is the fallback label used when encoding\n", "# conditions outside this vocabulary, so it must exist even if no row carries it.\n", "all_conditions = sorted(\n", "    set(clean_data['train']['condition'])\n", "    | set(clean_data['validation']['condition'])\n", "    | {'unknown'}\n", ")\n", "len(all_conditions)" ] },
{ "cell_type": "code", "execution_count": 14, "id": "912ef7d5-149a-48ed-ac6b-1ff2f3c2556a", "metadata": {}, "outputs": [], "source": [ "# Integer id <-> condition name, in sorted order of the names.\n", "id2label = dict(enumerate(all_conditions))\n", "label2id = {label: idx for idx, label in id2label.items()}" ] },
{ "cell_type": "code", "execution_count": 15, "id": "aca4a239-3f07-44bf-905e-2743b8f0889d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(label2id) == len(id2label)" ] },
{ "cell_type": "code", "execution_count": 16, "id": "024d5faa-88f1-41b7-9f52-8178ad731089", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date',
'usefulCount', 'review_length', 'labels'],\n", " num_rows: 110811\n", " })\n", " validation: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n", " num_rows: 27703\n", " })\n", " test: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n", " num_rows: 46108\n", " })\n", "})" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Conditions missing from label2id are encoded with the id of 'unknown'.\n", "# The fallback id is looked up once instead of once per row.\n", "unknown_id = label2id['unknown']\n", "clean_data = clean_data.map(\n", "    lambda x: {'labels': [label2id.get(condition, unknown_id) for condition in x['condition']]},\n", "    batched=True,\n", "    batch_size=3000,\n", "    num_proc=8,\n", ")\n", "clean_data" ] },
{ "cell_type": "code", "execution_count": 17, "id": "2f71cacc-9fb4-4436-b32b-8f172bcc19b1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "checkpoint = 'distilbert-base-uncased'\n", "# Hand the label maps to from_pretrained so the model config carries the\n", "# condition names from the moment the classification head is created.\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", "    checkpoint,\n", "    num_labels=len(id2label),\n", "    id2label=id2label,\n", "    label2id=label2id,\n", ").to(device)\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)" ] },
{ "cell_type": "code", "execution_count": 18, "id": "e9b2c2bd-52d4-47e0-aaaf-eb76b3bab9fa", "metadata": {}, "outputs": [], "source": [ "# Explicit re-assignment of the label metadata (already supplied at load time above).\n", "model.config.id2label = id2label\n", "model.config.label2id = label2id\n", "model.num_labels = len(label2id)" ] },
{ "cell_type": "code", "execution_count": 19, "id": "2d3bb44b-e635-4e7c-b984-6379510b60b3", "metadata": {}, "outputs": [], "source": [ "collator =
DataCollatorWithPadding(tokenizer)" ] },
{ "cell_type": "code", "execution_count": 20, "id": "c22a17ab-4a43-45f6-ba99-62cdb94103c5", "metadata": {}, "outputs": [], "source": [ "def tokenize_and_split(examples):\n", "    \"\"\"Tokenize a batch of reviews, producing one row per chunk of at most 512 tokens.\n", "\n", "    Overflowing tokens are returned as extra rows. Every input column is\n", "    repeated for each chunk via 'overflow_to_sample_mapping', so all returned\n", "    columns have the same length.\n", "    \"\"\"\n", "    encoded = tokenizer(\n", "        examples['review'],\n", "        truncation=True,\n", "        max_length=512,\n", "        return_overflowing_tokens=True,\n", "    )\n", "    # source_rows[i] is the index of the input example that chunk i was cut from.\n", "    source_rows = encoded.pop('overflow_to_sample_mapping')\n", "    for column, column_values in examples.items():\n", "        encoded[column] = [column_values[row] for row in source_rows]\n", "    return encoded" ] },
{ "cell_type": "code", "execution_count": 21, "id": "5a1b9eb6-87a1-4d7f-855b-f1c9e5ae63c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n", " num_rows: 110857\n", " })\n", " validation: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n", " num_rows: 27717\n", " })\n", " test: Dataset({\n", " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n", " num_rows: 46118\n", " })\n", "})" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Row counts grow slightly because long reviews are split into several chunks.\n", "tokenized_dataset = clean_data.map(tokenize_and_split, batched=True, batch_size=3000, num_proc=8)\n", "tokenized_dataset" ] },
{ "cell_type": "code", "execution_count": null, "id": "2729d5c2-499d-41f0-8ddb-27df3cf82475", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "id": "3488692d-44ef-4b99-af4c-8fa32d6ed3b2", "metadata": {}, "outputs": [], "source": [ "tokenized_dataset.save_to_disk('data/drugs', num_proc=4)" ] },
{ "cell_type": "code", "execution_count": null, "id": "2bc7b3ea-5f48-4298-b625-d313c4dc1ea3", "metadata": {}, "outputs": [], "source": [ "tokenized_dataset = load_from_disk('data/drugs/')\n", "tokenized_dataset" ] },
{ "cell_type": "code", "execution_count": null, "id": "001bb28a-9ff1-463f-90af-22dc7f6bce53", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": 22, "id": "344a5505-f143-4389-8be3-282219f29d74", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 110857\n", " })\n", " validation: Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 27717\n", " })\n", " test: Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 46118\n", " })\n", "})" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Keep only the three columns that are passed on to the Trainer.\n", "filtered = tokenized_dataset.select_columns(['input_ids', 'attention_mask', 'labels'])\n", "filtered" ] },
{ "cell_type": "code", "execution_count": 23, "id": "b31de787-0312-4d67-8b41-ce85732308ea", "metadata": {}, "outputs": [], "source": [ "accuracy = evaluate.load('accuracy')" ] },
{ "cell_type": "code", "execution_count": 24, "id": "f6d0543e-06d5-4930-93f6-8028e4e4ead5", "metadata": {}, "outputs": [], "source": [ "def compute_metrics(eval_preds):\n", "    \"\"\"Score the arg-max class of each row against the reference labels.\n", "\n", "    eval_preds unpacks into (logits, labels); the index of the largest logit\n", "    along the last axis is used as the predicted class.\n", "    \"\"\"\n", "    scores, references = eval_preds\n", "    predicted = np.argmax(scores, axis=-1)\n", "    return accuracy.compute(predictions=predicted, references=references)" ] },
{ "cell_type": "code", "execution_count": 25, "id": "ec5be835-e194-47a4-8c2a-3eb7500645ad", "metadata": {}, "outputs": [], "source": [ "lr = 3e-5" ] },
{ "cell_type": "code", "execution_count": 26, "id": "d04e4bae-8bb0-4e5e-be0f-2ce41db1bbe6", "metadata": {}, "outputs": [], "source": [ "train_args = TrainingArguments(\n", "    'medical_condition_classification',\n", "    overwrite_output_dir=True,\n", "    eval_strategy='steps', eval_steps=2000,\n", "    per_device_train_batch_size=24,\n", "    per_device_eval_batch_size=24,
\n", " fp16=True, num_train_epochs=5,\n", " learning_rate=lr,\n", " push_to_hub=True,\n", " hub_token=access_token\n", ")" ] }, { "cell_type": "code", "execution_count": 27, "id": "e26faf3f-03ab-411d-97a2-c1a3b6e2b425", "metadata": {}, "outputs": [], "source": [ "# Positional arguments, in order: model, training arguments, data collator, train split, validation split, tokenizer.\n", "trainer = Trainer(model, train_args, collator, filtered['train'], filtered['validation'], tokenizer, compute_metrics=compute_metrics)" ] }, { "cell_type": "code", "execution_count": 28, "id": "52c55095-4761-4353-8222-887cdf309431", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [23100/23100 1:14:13, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation LossAccuracy
20001.8625001.7198710.639680
40001.4590001.3695660.688963
60001.1737001.2131410.717249
80001.0420001.1014190.732908
100000.8431001.0322370.750983
120000.8012000.9889390.758668
140000.7312000.9496870.772703
160000.6561000.9338450.780496
180000.6132000.9072620.787531
200000.5195000.9010890.792943
220000.5015000.8929590.795072

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=23100, training_loss=1.0162131207949154, metrics={'train_runtime': 4454.3937, 'train_samples_per_second': 124.436, 'train_steps_per_second': 5.186, 'total_flos': 2.958796560013029e+16, 'train_loss': 1.0162131207949154, 'epoch': 5.0})" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 29, "id": "7c8d06d3-ef08-42ca-9dad-651c3a7c45fc", "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Run the trained model over the test split; the result exposes a .metrics dict (shown in the next cell).\n", "# NOTE(review): Trainer.predict is expected to disable gradient tracking itself, so this\n", "# no_grad wrapper is probably redundant - confirm against the Trainer docs before removing it.\n", "with torch.no_grad():\n", " preds = trainer.predict(filtered['test'])" ] }, { "cell_type": "code", "execution_count": 33, "id": "cab2f41e-d00f-41cb-a5a6-daf9e713077d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'test_loss': 0.8813542127609253,\n", " 'test_accuracy': 0.8004249967474739,\n", " 'test_runtime': 87.98,\n", " 'test_samples_per_second': 524.188,\n", " 'test_steps_per_second': 21.846}" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds.metrics" ] }, { "cell_type": "code", "execution_count": 34, "id": "1e323be4-ac78-498e-99fc-3133b11dc241", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d45b5b475cce4bb09298a7278ec51c64", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/270M [00:00