samsaara commited on
Commit
39b56ae
·
verified ·
1 Parent(s): d6beaa1

update notebook

Browse files
Files changed (1) hide show
  1. modeling.ipynb +1066 -0
modeling.ipynb ADDED
@@ -0,0 +1,1066 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "0729b762-3b84-474f-b82a-df7622b91ccb",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch, html\n",
11
+ "from transformers import AutoTokenizer\n",
12
+ "from datasets import load_dataset, load_from_disk\n",
13
+ "from huggingface_hub import notebook_login\n",
14
+ "from dotenv import load_dotenv\n",
15
+ "import os"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 3,
21
+ "id": "92ee5f76-2cd3-4af0-8687-dca782aa38a3",
22
+ "metadata": {},
23
+ "outputs": [
24
+ {
25
+ "data": {
26
+ "text/plain": [
27
+ "True"
28
+ ]
29
+ },
30
+ "execution_count": 3,
31
+ "metadata": {},
32
+ "output_type": "execute_result"
33
+ }
34
+ ],
35
+ "source": [
36
+ "load_dotenv()"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 4,
42
+ "id": "97d33c57-b03b-4bee-b051-04d707a8d773",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "access_token = os.environ['HF_TOKEN']"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 4,
52
+ "id": "4358520c-3d8c-42ef-967a-eddeef732ef1",
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "data": {
57
+ "text/plain": [
58
+ "'cuda'"
59
+ ]
60
+ },
61
+ "execution_count": 4,
62
+ "metadata": {},
63
+ "output_type": "execute_result"
64
+ }
65
+ ],
66
+ "source": [
67
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
68
+ "device"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 5,
74
+ "id": "1c2ec24f-4c6d-4469-8e85-601a4b0d3e4e",
75
+ "metadata": {},
76
+ "outputs": [
77
+ {
78
+ "data": {
79
+ "text/plain": [
80
+ "DatasetDict({\n",
81
+ " train: Dataset({\n",
82
+ " features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],\n",
83
+ " num_rows: 161297\n",
84
+ " })\n",
85
+ " test: Dataset({\n",
86
+ " features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],\n",
87
+ " num_rows: 53766\n",
88
+ " })\n",
89
+ "})"
90
+ ]
91
+ },
92
+ "execution_count": 5,
93
+ "metadata": {},
94
+ "output_type": "execute_result"
95
+ }
96
+ ],
97
+ "source": [
98
+ "dataset = load_dataset('csv', data_files={\n",
99
+ " 'train': 'data/drugsComTrain_raw.tsv',\n",
100
+ " 'test': 'data/drugsComTest_raw.tsv'\n",
101
+ "}, delimiter='\\t', num_proc=8)\n",
102
+ "dataset"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 6,
108
+ "id": "dbb81021-9acc-46b4-87c0-23f0f787fef5",
109
+ "metadata": {},
110
+ "outputs": [
111
+ {
112
+ "data": {
113
+ "text/plain": [
114
+ "{'train': (161297, 7), 'test': (53766, 7)}"
115
+ ]
116
+ },
117
+ "execution_count": 6,
118
+ "metadata": {},
119
+ "output_type": "execute_result"
120
+ }
121
+ ],
122
+ "source": [
123
+ "dataset.shape"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 7,
129
+ "id": "a983147c-eb04-455f-bf02-0c57c2a549e9",
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "data": {
134
+ "text/plain": [
135
+ "{'Unnamed: 0': 206461,\n",
136
+ " 'drugName': 'Valsartan',\n",
137
+ " 'condition': 'Left Ventricular Dysfunction',\n",
138
+ " 'review': '\"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil\"',\n",
139
+ " 'rating': 9.0,\n",
140
+ " 'date': 'May 20, 2012',\n",
141
+ " 'usefulCount': 27}"
142
+ ]
143
+ },
144
+ "execution_count": 7,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "dataset['train'][0]"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 8,
156
+ "id": "ee2b8ddf-79d7-44d6-80ba-243bc2f04de8",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "data": {
161
+ "text/plain": [
162
+ "DatasetDict({\n",
163
+ " train: Dataset({\n",
164
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
165
+ " num_rows: 138514\n",
166
+ " })\n",
167
+ " test: Dataset({\n",
168
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
169
+ " num_rows: 46108\n",
170
+ " })\n",
171
+ "})"
172
+ ]
173
+ },
174
+ "execution_count": 8,
175
+ "metadata": {},
176
+ "output_type": "execute_result"
177
+ }
178
+ ],
179
+ "source": [
180
+ "dataset = (\n",
181
+ " dataset\n",
182
+ " .filter(lambda x: x['condition'] is not None)\n",
183
+ " .rename_column('Unnamed: 0', 'row_id')\n",
184
+ " .map(lambda x: {'condition': [row.lower() for row in x['condition']]}, batched=True, num_proc=8, batch_size=3000)\n",
185
+ " .map(lambda x: {'review': [html.unescape(row) for row in x['review']]}, batched=True, num_proc=8, batch_size=3000)\n",
186
+ " .map(lambda x: {'review_length': [len(row.split()) for row in x['review']]}, batched=True, num_proc=8, batch_size=3000)\n",
187
+ " # .filter(lambda x: {'review_length': [row > 30 for row in x['review_length']]}, batched=True, num_proc=8)\n",
188
+ " .filter(lambda x: x['review_length'] > 30, num_proc=8, batch_size=3000)\n",
189
+ ")\n",
190
+ "dataset"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "e7c4daf2-36c1-4074-91ca-8871a581052d",
196
+ "metadata": {},
197
+ "source": [
198
+ "# Exercises"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "markdown",
203
+ "id": "ea14b998-69f1-40a7-a200-7cc53b0e22fd",
204
+ "metadata": {},
205
+ "source": [
206
+ "## Predict patient condition based on drug review"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": 6,
212
+ "id": "dc6b299b-2d0b-4475-bfff-d0180dd672c1",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer, AutoModel, DataCollatorWithPadding\n",
217
+ "from torch.utils.data import DataLoader\n",
218
+ "import evaluate, numpy as np\n",
219
+ "from huggingface_hub import HfApi"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 10,
225
+ "id": "77caa284-8307-40a0-8369-621195e5c7e9",
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "def clean_condition_column(rows):\n",
230
+ " target_text = 'users found this comment helpful'\n",
231
+ " return {'condition': ['unknown' if target_text in condition else condition for condition in rows['condition']]}"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 11,
237
+ "id": "058d4c64-428b-43bb-86c4-ba8f5c1b8a84",
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "data": {
242
+ "text/plain": [
243
+ "DatasetDict({\n",
244
+ " train: Dataset({\n",
245
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
246
+ " num_rows: 138514\n",
247
+ " })\n",
248
+ " test: Dataset({\n",
249
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
250
+ " num_rows: 46108\n",
251
+ " })\n",
252
+ "})"
253
+ ]
254
+ },
255
+ "execution_count": 11,
256
+ "metadata": {},
257
+ "output_type": "execute_result"
258
+ }
259
+ ],
260
+ "source": [
261
+ "dataset = dataset.map(clean_condition_column, batched=True, batch_size=3000, num_proc=8)\n",
262
+ "dataset"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 12,
268
+ "id": "80dc20fe-cb66-4b0d-99dc-88e84413975b",
269
+ "metadata": {},
270
+ "outputs": [
271
+ {
272
+ "data": {
273
+ "text/plain": [
274
+ "DatasetDict({\n",
275
+ " train: Dataset({\n",
276
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
277
+ " num_rows: 110811\n",
278
+ " })\n",
279
+ " validation: Dataset({\n",
280
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
281
+ " num_rows: 27703\n",
282
+ " })\n",
283
+ " test: Dataset({\n",
284
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
285
+ " num_rows: 46108\n",
286
+ " })\n",
287
+ "})"
288
+ ]
289
+ },
290
+ "execution_count": 12,
291
+ "metadata": {},
292
+ "output_type": "execute_result"
293
+ }
294
+ ],
295
+ "source": [
296
+ "clean_data = dataset['train'].train_test_split(test_size=.2, seed=5, writer_batch_size=3000)\n",
297
+ "clean_data['validation'] = clean_data.pop('test')\n",
298
+ "clean_data['test'] = dataset['test']\n",
299
+ "\n",
300
+ "clean_data"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": 13,
306
+ "id": "8be33fbb-143f-45b5-9e18-c5662a7e0dad",
307
+ "metadata": {},
308
+ "outputs": [
309
+ {
310
+ "data": {
311
+ "text/plain": [
312
+ "751"
313
+ ]
314
+ },
315
+ "execution_count": 13,
316
+ "metadata": {},
317
+ "output_type": "execute_result"
318
+ }
319
+ ],
320
+ "source": [
321
+ "all_conditions = sorted(set(clean_data['train']['condition']).union(set(clean_data['validation']['condition'])))\n",
322
+ "len(all_conditions)"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": 14,
328
+ "id": "912ef7d5-149a-48ed-ac6b-1ff2f3c2556a",
329
+ "metadata": {},
330
+ "outputs": [],
331
+ "source": [
332
+ "id2label = dict(enumerate(all_conditions))\n",
333
+ "label2id = {v:k for k, v in id2label.items()}"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": 15,
339
+ "id": "aca4a239-3f07-44bf-905e-2743b8f0889d",
340
+ "metadata": {},
341
+ "outputs": [
342
+ {
343
+ "data": {
344
+ "text/plain": [
345
+ "True"
346
+ ]
347
+ },
348
+ "execution_count": 15,
349
+ "metadata": {},
350
+ "output_type": "execute_result"
351
+ }
352
+ ],
353
+ "source": [
354
+ "len(label2id) == len(id2label)"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": 16,
360
+ "id": "024d5faa-88f1-41b7-9f52-8178ad731089",
361
+ "metadata": {},
362
+ "outputs": [
363
+ {
364
+ "data": {
365
+ "text/plain": [
366
+ "DatasetDict({\n",
367
+ " train: Dataset({\n",
368
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n",
369
+ " num_rows: 110811\n",
370
+ " })\n",
371
+ " validation: Dataset({\n",
372
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n",
373
+ " num_rows: 27703\n",
374
+ " })\n",
375
+ " test: Dataset({\n",
376
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n",
377
+ " num_rows: 46108\n",
378
+ " })\n",
379
+ "})"
380
+ ]
381
+ },
382
+ "execution_count": 16,
383
+ "metadata": {},
384
+ "output_type": "execute_result"
385
+ }
386
+ ],
387
+ "source": [
388
+ "clean_data = clean_data.map(lambda x: {'labels': [label2id.get(condition, label2id['unknown']) for condition in x['condition']]}, batched=True, batch_size=3000, num_proc=8)\n",
389
+ "clean_data"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 17,
395
+ "id": "2f71cacc-9fb4-4436-b32b-8f172bcc19b1",
396
+ "metadata": {},
397
+ "outputs": [
398
+ {
399
+ "name": "stderr",
400
+ "output_type": "stream",
401
+ "text": [
402
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
403
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
404
+ ]
405
+ }
406
+ ],
407
+ "source": [
408
+ "# checkpoint = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'\n",
409
+ "checkpoint = 'distilbert-base-uncased'\n",
410
+ "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=len(id2label)).to(device)\n",
411
+ "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": 18,
417
+ "id": "e9b2c2bd-52d4-47e0-aaaf-eb76b3bab9fa",
418
+ "metadata": {},
419
+ "outputs": [],
420
+ "source": [
421
+ "model.config.id2label = id2label\n",
422
+ "model.config.label2id = label2id\n",
423
+ "model.num_labels = len(label2id)"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": 19,
429
+ "id": "2d3bb44b-e635-4e7c-b984-6379510b60b3",
430
+ "metadata": {},
431
+ "outputs": [],
432
+ "source": [
433
+ "collator = DataCollatorWithPadding(tokenizer)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": 20,
439
+ "id": "c22a17ab-4a43-45f6-ba99-62cdb94103c5",
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "def tokenize_and_split(examples):\n",
444
+ " tokens = tokenizer(\n",
445
+ " examples[\"review\"],\n",
446
+ " truncation=True,\n",
447
+ " max_length=512,\n",
448
+ " return_overflowing_tokens=True,\n",
449
+ " )\n",
450
+ " mappings = tokens.pop('overflow_to_sample_mapping')\n",
451
+ " for key, values in examples.items():\n",
452
+ " tokens[key] = [values[idx] for idx in mappings]\n",
453
+ " return tokens"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": 21,
459
+ "id": "5a1b9eb6-87a1-4d7f-855b-f1c9e5ae63c2",
460
+ "metadata": {},
461
+ "outputs": [
462
+ {
463
+ "data": {
464
+ "text/plain": [
465
+ "DatasetDict({\n",
466
+ " train: Dataset({\n",
467
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n",
468
+ " num_rows: 110857\n",
469
+ " })\n",
470
+ " validation: Dataset({\n",
471
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n",
472
+ " num_rows: 27717\n",
473
+ " })\n",
474
+ " test: Dataset({\n",
475
+ " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n",
476
+ " num_rows: 46118\n",
477
+ " })\n",
478
+ "})"
479
+ ]
480
+ },
481
+ "execution_count": 21,
482
+ "metadata": {},
483
+ "output_type": "execute_result"
484
+ }
485
+ ],
486
+ "source": [
487
+ "tokenized_dataset = clean_data.map(tokenize_and_split, batched=True, batch_size=3000, num_proc=8)\n",
488
+ "tokenized_dataset"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "id": "2729d5c2-499d-41f0-8ddb-27df3cf82475",
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": []
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": null,
502
+ "id": "3488692d-44ef-4b99-af4c-8fa32d6ed3b2",
503
+ "metadata": {},
504
+ "outputs": [],
505
+ "source": [
506
+ "tokenized_dataset.save_to_disk('data/drugs', num_proc=4)"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "code",
511
+ "execution_count": null,
512
+ "id": "2bc7b3ea-5f48-4298-b625-d313c4dc1ea3",
513
+ "metadata": {},
514
+ "outputs": [],
515
+ "source": [
516
+ "tokenized_dataset = load_from_disk('data/drugs/')\n",
517
+ "tokenized_dataset"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "code",
522
+ "execution_count": null,
523
+ "id": "001bb28a-9ff1-463f-90af-22dc7f6bce53",
524
+ "metadata": {},
525
+ "outputs": [],
526
+ "source": []
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 22,
531
+ "id": "344a5505-f143-4389-8be3-282219f29d74",
532
+ "metadata": {},
533
+ "outputs": [
534
+ {
535
+ "data": {
536
+ "text/plain": [
537
+ "DatasetDict({\n",
538
+ " train: Dataset({\n",
539
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
540
+ " num_rows: 110857\n",
541
+ " })\n",
542
+ " validation: Dataset({\n",
543
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
544
+ " num_rows: 27717\n",
545
+ " })\n",
546
+ " test: Dataset({\n",
547
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
548
+ " num_rows: 46118\n",
549
+ " })\n",
550
+ "})"
551
+ ]
552
+ },
553
+ "execution_count": 22,
554
+ "metadata": {},
555
+ "output_type": "execute_result"
556
+ }
557
+ ],
558
+ "source": [
559
+ "filtered = tokenized_dataset.select_columns(['input_ids', 'attention_mask', 'labels'])\n",
560
+ "filtered"
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": 23,
566
+ "id": "b31de787-0312-4d67-8b41-ce85732308ea",
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": [
570
+ "accuracy = evaluate.load('accuracy')"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": 24,
576
+ "id": "f6d0543e-06d5-4930-93f6-8028e4e4ead5",
577
+ "metadata": {},
578
+ "outputs": [],
579
+ "source": [
580
+ "def compute_metrics(eval_preds):\n",
581
+ " logits, labels = eval_preds\n",
582
+ " preds = np.argmax(logits, axis=-1)\n",
583
+ " return accuracy.compute(predictions=preds, references=labels)"
584
+ ]
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "execution_count": 25,
589
+ "id": "ec5be835-e194-47a4-8c2a-3eb7500645ad",
590
+ "metadata": {},
591
+ "outputs": [],
592
+ "source": [
593
+ "lr = 3e-5"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "execution_count": 26,
599
+ "id": "d04e4bae-8bb0-4e5e-be0f-2ce41db1bbe6",
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": [
603
+ "train_args = TrainingArguments(\n",
604
+ " 'medical_condition_classification', \n",
605
+ " overwrite_output_dir=True, \n",
606
+ " eval_strategy='steps', eval_steps=2000, \n",
607
+ " per_device_train_batch_size=24, \n",
608
+ " per_device_eval_batch_size=24, \n",
609
+ " fp16=True, num_train_epochs=5,\n",
610
+ " learning_rate=lr,\n",
611
+ " push_to_hub=True,\n",
612
+ " hub_token=access_token\n",
613
+ ")"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": 27,
619
+ "id": "e26faf3f-03ab-411d-97a2-c1a3b6e2b425",
620
+ "metadata": {},
621
+ "outputs": [],
622
+ "source": [
623
+ "trainer = Trainer(model, train_args, collator, filtered['train'], filtered['validation'], tokenizer, compute_metrics=compute_metrics)"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": 28,
629
+ "id": "52c55095-4761-4353-8222-887cdf309431",
630
+ "metadata": {},
631
+ "outputs": [
632
+ {
633
+ "data": {
634
+ "text/html": [
635
+ "\n",
636
+ " <div>\n",
637
+ " \n",
638
+ " <progress value='23100' max='23100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
639
+ " [23100/23100 1:14:13, Epoch 5/5]\n",
640
+ " </div>\n",
641
+ " <table border=\"1\" class=\"dataframe\">\n",
642
+ " <thead>\n",
643
+ " <tr style=\"text-align: left;\">\n",
644
+ " <th>Step</th>\n",
645
+ " <th>Training Loss</th>\n",
646
+ " <th>Validation Loss</th>\n",
647
+ " <th>Accuracy</th>\n",
648
+ " </tr>\n",
649
+ " </thead>\n",
650
+ " <tbody>\n",
651
+ " <tr>\n",
652
+ " <td>2000</td>\n",
653
+ " <td>1.862500</td>\n",
654
+ " <td>1.719871</td>\n",
655
+ " <td>0.639680</td>\n",
656
+ " </tr>\n",
657
+ " <tr>\n",
658
+ " <td>4000</td>\n",
659
+ " <td>1.459000</td>\n",
660
+ " <td>1.369566</td>\n",
661
+ " <td>0.688963</td>\n",
662
+ " </tr>\n",
663
+ " <tr>\n",
664
+ " <td>6000</td>\n",
665
+ " <td>1.173700</td>\n",
666
+ " <td>1.213141</td>\n",
667
+ " <td>0.717249</td>\n",
668
+ " </tr>\n",
669
+ " <tr>\n",
670
+ " <td>8000</td>\n",
671
+ " <td>1.042000</td>\n",
672
+ " <td>1.101419</td>\n",
673
+ " <td>0.732908</td>\n",
674
+ " </tr>\n",
675
+ " <tr>\n",
676
+ " <td>10000</td>\n",
677
+ " <td>0.843100</td>\n",
678
+ " <td>1.032237</td>\n",
679
+ " <td>0.750983</td>\n",
680
+ " </tr>\n",
681
+ " <tr>\n",
682
+ " <td>12000</td>\n",
683
+ " <td>0.801200</td>\n",
684
+ " <td>0.988939</td>\n",
685
+ " <td>0.758668</td>\n",
686
+ " </tr>\n",
687
+ " <tr>\n",
688
+ " <td>14000</td>\n",
689
+ " <td>0.731200</td>\n",
690
+ " <td>0.949687</td>\n",
691
+ " <td>0.772703</td>\n",
692
+ " </tr>\n",
693
+ " <tr>\n",
694
+ " <td>16000</td>\n",
695
+ " <td>0.656100</td>\n",
696
+ " <td>0.933845</td>\n",
697
+ " <td>0.780496</td>\n",
698
+ " </tr>\n",
699
+ " <tr>\n",
700
+ " <td>18000</td>\n",
701
+ " <td>0.613200</td>\n",
702
+ " <td>0.907262</td>\n",
703
+ " <td>0.787531</td>\n",
704
+ " </tr>\n",
705
+ " <tr>\n",
706
+ " <td>20000</td>\n",
707
+ " <td>0.519500</td>\n",
708
+ " <td>0.901089</td>\n",
709
+ " <td>0.792943</td>\n",
710
+ " </tr>\n",
711
+ " <tr>\n",
712
+ " <td>22000</td>\n",
713
+ " <td>0.501500</td>\n",
714
+ " <td>0.892959</td>\n",
715
+ " <td>0.795072</td>\n",
716
+ " </tr>\n",
717
+ " </tbody>\n",
718
+ "</table><p>"
719
+ ],
720
+ "text/plain": [
721
+ "<IPython.core.display.HTML object>"
722
+ ]
723
+ },
724
+ "metadata": {},
725
+ "output_type": "display_data"
726
+ },
727
+ {
728
+ "data": {
729
+ "text/plain": [
730
+ "TrainOutput(global_step=23100, training_loss=1.0162131207949154, metrics={'train_runtime': 4454.3937, 'train_samples_per_second': 124.436, 'train_steps_per_second': 5.186, 'total_flos': 2.958796560013029e+16, 'train_loss': 1.0162131207949154, 'epoch': 5.0})"
731
+ ]
732
+ },
733
+ "execution_count": 28,
734
+ "metadata": {},
735
+ "output_type": "execute_result"
736
+ }
737
+ ],
738
+ "source": [
739
+ "trainer.train()"
740
+ ]
741
+ },
742
+ {
743
+ "cell_type": "code",
744
+ "execution_count": 29,
745
+ "id": "7c8d06d3-ef08-42ca-9dad-651c3a7c45fc",
746
+ "metadata": {},
747
+ "outputs": [
748
+ {
749
+ "data": {
750
+ "text/html": [],
751
+ "text/plain": [
752
+ "<IPython.core.display.HTML object>"
753
+ ]
754
+ },
755
+ "metadata": {},
756
+ "output_type": "display_data"
757
+ }
758
+ ],
759
+ "source": [
760
+ "with torch.no_grad():\n",
761
+ " preds = trainer.predict(filtered['test'])"
762
+ ]
763
+ },
764
+ {
765
+ "cell_type": "code",
766
+ "execution_count": 33,
767
+ "id": "cab2f41e-d00f-41cb-a5a6-daf9e713077d",
768
+ "metadata": {},
769
+ "outputs": [
770
+ {
771
+ "data": {
772
+ "text/plain": [
773
+ "{'test_loss': 0.8813542127609253,\n",
774
+ " 'test_accuracy': 0.8004249967474739,\n",
775
+ " 'test_runtime': 87.98,\n",
776
+ " 'test_samples_per_second': 524.188,\n",
777
+ " 'test_steps_per_second': 21.846}"
778
+ ]
779
+ },
780
+ "execution_count": 33,
781
+ "metadata": {},
782
+ "output_type": "execute_result"
783
+ }
784
+ ],
785
+ "source": [
786
+ "preds.metrics"
787
+ ]
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "execution_count": 34,
792
+ "id": "1e323be4-ac78-498e-99fc-3133b11dc241",
793
+ "metadata": {},
794
+ "outputs": [
795
+ {
796
+ "data": {
797
+ "application/vnd.jupyter.widget-view+json": {
798
+ "model_id": "d45b5b475cce4bb09298a7278ec51c64",
799
+ "version_major": 2,
800
+ "version_minor": 0
801
+ },
802
+ "text/plain": [
803
+ "model.safetensors: 0%| | 0.00/270M [00:00<?, ?B/s]"
804
+ ]
805
+ },
806
+ "metadata": {},
807
+ "output_type": "display_data"
808
+ },
809
+ {
810
+ "data": {
811
+ "text/plain": [
812
+ "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/73d201a08dd78ce2afea66736b188d271e652052', commit_message='End of training', commit_description='', oid='73d201a08dd78ce2afea66736b188d271e652052', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
813
+ ]
814
+ },
815
+ "execution_count": 34,
816
+ "metadata": {},
817
+ "output_type": "execute_result"
818
+ }
819
+ ],
820
+ "source": [
821
+ "trainer.push_to_hub()"
822
+ ]
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "execution_count": 36,
827
+ "id": "6e754047-4b6c-4b3c-80a1-8493009ac7ca",
828
+ "metadata": {},
829
+ "outputs": [
830
+ {
831
+ "data": {
832
+ "application/vnd.jupyter.widget-view+json": {
833
+ "model_id": "9f853954de9b4fe5bcf797878413702f",
834
+ "version_major": 2,
835
+ "version_minor": 0
836
+ },
837
+ "text/plain": [
838
+ "README.md: 0%| | 0.00/2.11k [00:00<?, ?B/s]"
839
+ ]
840
+ },
841
+ "metadata": {},
842
+ "output_type": "display_data"
843
+ },
844
+ {
845
+ "data": {
846
+ "text/plain": [
847
+ "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/9a28b6773707f2b88e7eeb37bc811c642bb524c7', commit_message='tokenizer', commit_description='', oid='9a28b6773707f2b88e7eeb37bc811c642bb524c7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
848
+ ]
849
+ },
850
+ "execution_count": 36,
851
+ "metadata": {},
852
+ "output_type": "execute_result"
853
+ }
854
+ ],
855
+ "source": [
856
+ "tokenizer.push_to_hub('medical_condition_classification', commit_message='tokenizer')"
857
+ ]
858
+ },
859
+ {
860
+ "cell_type": "code",
861
+ "execution_count": 40,
862
+ "id": "2940e50a-328d-4f30-8cb7-30dd047d2f92",
863
+ "metadata": {},
864
+ "outputs": [
865
+ {
866
+ "data": {
867
+ "application/vnd.jupyter.widget-view+json": {
868
+ "model_id": "51c9d91e63664cba896e957841756995",
869
+ "version_major": 2,
870
+ "version_minor": 0
871
+ },
872
+ "text/plain": [
873
+ "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
874
+ ]
875
+ },
876
+ "metadata": {},
877
+ "output_type": "display_data"
878
+ },
879
+ {
880
+ "data": {
881
+ "application/vnd.jupyter.widget-view+json": {
882
+ "model_id": "e3bb0c86842e427eba71257d113c0845",
883
+ "version_major": 2,
884
+ "version_minor": 0
885
+ },
886
+ "text/plain": [
887
+ "Creating parquet from Arrow format: 0%| | 0/111 [00:00<?, ?ba/s]"
888
+ ]
889
+ },
890
+ "metadata": {},
891
+ "output_type": "display_data"
892
+ },
893
+ {
894
+ "data": {
895
+ "application/vnd.jupyter.widget-view+json": {
896
+ "model_id": "73d3373d5da04164a237ff56dc57fac5",
897
+ "version_major": 2,
898
+ "version_minor": 0
899
+ },
900
+ "text/plain": [
901
+ "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
902
+ ]
903
+ },
904
+ "metadata": {},
905
+ "output_type": "display_data"
906
+ },
907
+ {
908
+ "data": {
909
+ "application/vnd.jupyter.widget-view+json": {
910
+ "model_id": "df6679bdb3c342ffa2aab53922a3d1af",
911
+ "version_major": 2,
912
+ "version_minor": 0
913
+ },
914
+ "text/plain": [
915
+ "Creating parquet from Arrow format: 0%| | 0/28 [00:00<?, ?ba/s]"
916
+ ]
917
+ },
918
+ "metadata": {},
919
+ "output_type": "display_data"
920
+ },
921
+ {
922
+ "data": {
923
+ "application/vnd.jupyter.widget-view+json": {
924
+ "model_id": "60772b393cdc4dca9191a538b9116169",
925
+ "version_major": 2,
926
+ "version_minor": 0
927
+ },
928
+ "text/plain": [
929
+ "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
930
+ ]
931
+ },
932
+ "metadata": {},
933
+ "output_type": "display_data"
934
+ },
935
+ {
936
+ "data": {
937
+ "application/vnd.jupyter.widget-view+json": {
938
+ "model_id": "9b859655cbe34da184c8e6b6a07d0911",
939
+ "version_major": 2,
940
+ "version_minor": 0
941
+ },
942
+ "text/plain": [
943
+ "Creating parquet from Arrow format: 0%| | 0/47 [00:00<?, ?ba/s]"
944
+ ]
945
+ },
946
+ "metadata": {},
947
+ "output_type": "display_data"
948
+ },
949
+ {
950
+ "data": {
951
+ "text/plain": [
952
+ "CommitInfo(commit_url='https://huggingface.co/datasets/samsaara/medical_condition_classification/commit/7aea5155fcba521a02ec3e9e8fb4e86d09dc44ba', commit_message='Upload dataset', commit_description='', oid='7aea5155fcba521a02ec3e9e8fb4e86d09dc44ba', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='dataset', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
953
+ ]
954
+ },
955
+ "execution_count": 40,
956
+ "metadata": {},
957
+ "output_type": "execute_result"
958
+ }
959
+ ],
960
+ "source": [
961
+ "tokenized_dataset.push_to_hub('medical_condition_classification')"
962
+ ]
963
+ },
964
+ {
965
+ "cell_type": "code",
966
+ "execution_count": 7,
967
+ "id": "4bbf39bb-fea9-4b5b-8d87-7a9835975358",
968
+ "metadata": {},
969
+ "outputs": [],
970
+ "source": [
971
+ "api = HfApi(token=access_token)"
972
+ ]
973
+ },
974
+ {
975
+ "cell_type": "code",
976
+ "execution_count": 48,
977
+ "id": "16e0625b-5518-463a-96e2-2d008341b1f1",
978
+ "metadata": {},
979
+ "outputs": [
980
+ {
981
+ "data": {
982
+ "text/plain": [
983
+ "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/1067a267d69af46563d3a6b5a36d65030ccaa318', commit_message='update README', commit_description='', oid='1067a267d69af46563d3a6b5a36d65030ccaa318', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
984
+ ]
985
+ },
986
+ "execution_count": 48,
987
+ "metadata": {},
988
+ "output_type": "execute_result"
989
+ }
990
+ ],
991
+ "source": [
992
+ "api.upload_file(\n",
993
+ " path_or_fileobj='./medical_condition_classification/README.md', \n",
994
+ " path_in_repo='README.md',\n",
995
+ " repo_id='samsaara/medical_condition_classification', \n",
996
+ " commit_message='update README'\n",
997
+ ")"
998
+ ]
999
+ },
1000
+ {
1001
+ "cell_type": "code",
1002
+ "execution_count": null,
1003
+ "id": "031993dd-33b2-4c31-b199-3fe1c2fb88dc",
1004
+ "metadata": {},
1005
+ "outputs": [],
1006
+ "source": [
1007
+ "api.delete_file('datasets.ipynb', 'samsaara/medical_condition_classification', commit_message='')"
1008
+ ]
1009
+ },
1010
+ {
1011
+ "cell_type": "code",
1012
+ "execution_count": 49,
1013
+ "id": "4144a617-f0d6-41d3-9a44-90833b8bb1f5",
1014
+ "metadata": {},
1015
+ "outputs": [
1016
+ {
1017
+ "data": {
1018
+ "text/plain": [
1019
+ "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/f7008aa4e9f2c5d5fd4f87632cef56c86106a574', commit_message='update notebook', commit_description='', oid='f7008aa4e9f2c5d5fd4f87632cef56c86106a574', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
1020
+ ]
1021
+ },
1022
+ "execution_count": 49,
1023
+ "metadata": {},
1024
+ "output_type": "execute_result"
1025
+ }
1026
+ ],
1027
+ "source": [
1028
+ "api.upload_file(\n",
1029
+ " path_or_fileobj='datasets.ipynb', \n",
1030
+ " path_in_repo='datasets.ipynb',\n",
1031
+ " repo_id='samsaara/medical_condition_classification', \n",
1032
+ " commit_message='update notebook'\n",
1033
+ ")"
1034
+ ]
1035
+ },
1036
+ {
1037
+ "cell_type": "code",
1038
+ "execution_count": null,
1039
+ "id": "2de298eb-82f2-482b-acee-36c6a1e630b8",
1040
+ "metadata": {},
1041
+ "outputs": [],
1042
+ "source": []
1043
+ }
1044
+ ],
1045
+ "metadata": {
1046
+ "kernelspec": {
1047
+ "display_name": "Python 3 (ipykernel)",
1048
+ "language": "python",
1049
+ "name": "python3"
1050
+ },
1051
+ "language_info": {
1052
+ "codemirror_mode": {
1053
+ "name": "ipython",
1054
+ "version": 3
1055
+ },
1056
+ "file_extension": ".py",
1057
+ "mimetype": "text/x-python",
1058
+ "name": "python",
1059
+ "nbconvert_exporter": "python",
1060
+ "pygments_lexer": "ipython3",
1061
+ "version": "3.11.10"
1062
+ }
1063
+ },
1064
+ "nbformat": 4,
1065
+ "nbformat_minor": 5
1066
+ }