samsaara committed on
Commit
6eabb75
·
verified ·
1 Parent(s): f7008aa

delete notebook file

Browse files
Files changed (1) hide show
  1. datasets.ipynb +0 -1056
datasets.ipynb DELETED
@@ -1,1056 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "0729b762-3b84-474f-b82a-df7622b91ccb",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import torch, html\n",
11
- "from transformers import AutoTokenizer\n",
12
- "from datasets import load_dataset, load_from_disk\n",
13
- "from huggingface_hub import notebook_login\n",
14
- "from dotenv import load_dotenv\n",
15
- "import os"
16
- ]
17
- },
18
- {
19
- "cell_type": "code",
20
- "execution_count": 2,
21
- "id": "92ee5f76-2cd3-4af0-8687-dca782aa38a3",
22
- "metadata": {},
23
- "outputs": [
24
- {
25
- "data": {
26
- "text/plain": [
27
- "True"
28
- ]
29
- },
30
- "execution_count": 2,
31
- "metadata": {},
32
- "output_type": "execute_result"
33
- }
34
- ],
35
- "source": [
36
- "load_dotenv()"
37
- ]
38
- },
39
- {
40
- "cell_type": "code",
41
- "execution_count": 3,
42
- "id": "97d33c57-b03b-4bee-b051-04d707a8d773",
43
- "metadata": {},
44
- "outputs": [],
45
- "source": [
46
- "access_token = os.environ['HF_TOKEN']"
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": 4,
52
- "id": "4358520c-3d8c-42ef-967a-eddeef732ef1",
53
- "metadata": {},
54
- "outputs": [
55
- {
56
- "data": {
57
- "text/plain": [
58
- "'cuda'"
59
- ]
60
- },
61
- "execution_count": 4,
62
- "metadata": {},
63
- "output_type": "execute_result"
64
- }
65
- ],
66
- "source": [
67
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
68
- "device"
69
- ]
70
- },
71
- {
72
- "cell_type": "code",
73
- "execution_count": 5,
74
- "id": "1c2ec24f-4c6d-4469-8e85-601a4b0d3e4e",
75
- "metadata": {},
76
- "outputs": [
77
- {
78
- "data": {
79
- "text/plain": [
80
- "DatasetDict({\n",
81
- " train: Dataset({\n",
82
- " features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],\n",
83
- " num_rows: 161297\n",
84
- " })\n",
85
- " test: Dataset({\n",
86
- " features: ['Unnamed: 0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount'],\n",
87
- " num_rows: 53766\n",
88
- " })\n",
89
- "})"
90
- ]
91
- },
92
- "execution_count": 5,
93
- "metadata": {},
94
- "output_type": "execute_result"
95
- }
96
- ],
97
- "source": [
98
- "dataset = load_dataset('csv', data_files={\n",
99
- " 'train': 'data/drugsComTrain_raw.tsv',\n",
100
- " 'test': 'data/drugsComTest_raw.tsv'\n",
101
- "}, delimiter='\\t', num_proc=8)\n",
102
- "dataset"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": 6,
108
- "id": "dbb81021-9acc-46b4-87c0-23f0f787fef5",
109
- "metadata": {},
110
- "outputs": [
111
- {
112
- "data": {
113
- "text/plain": [
114
- "{'train': (161297, 7), 'test': (53766, 7)}"
115
- ]
116
- },
117
- "execution_count": 6,
118
- "metadata": {},
119
- "output_type": "execute_result"
120
- }
121
- ],
122
- "source": [
123
- "dataset.shape"
124
- ]
125
- },
126
- {
127
- "cell_type": "code",
128
- "execution_count": 7,
129
- "id": "a983147c-eb04-455f-bf02-0c57c2a549e9",
130
- "metadata": {},
131
- "outputs": [
132
- {
133
- "data": {
134
- "text/plain": [
135
- "{'Unnamed: 0': 206461,\n",
136
- " 'drugName': 'Valsartan',\n",
137
- " 'condition': 'Left Ventricular Dysfunction',\n",
138
- " 'review': '\"It has no side effect, I take it in combination of Bystolic 5 Mg and Fish Oil\"',\n",
139
- " 'rating': 9.0,\n",
140
- " 'date': 'May 20, 2012',\n",
141
- " 'usefulCount': 27}"
142
- ]
143
- },
144
- "execution_count": 7,
145
- "metadata": {},
146
- "output_type": "execute_result"
147
- }
148
- ],
149
- "source": [
150
- "dataset['train'][0]"
151
- ]
152
- },
153
- {
154
- "cell_type": "code",
155
- "execution_count": 8,
156
- "id": "ee2b8ddf-79d7-44d6-80ba-243bc2f04de8",
157
- "metadata": {},
158
- "outputs": [
159
- {
160
- "data": {
161
- "text/plain": [
162
- "DatasetDict({\n",
163
- " train: Dataset({\n",
164
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
165
- " num_rows: 138514\n",
166
- " })\n",
167
- " test: Dataset({\n",
168
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
169
- " num_rows: 46108\n",
170
- " })\n",
171
- "})"
172
- ]
173
- },
174
- "execution_count": 8,
175
- "metadata": {},
176
- "output_type": "execute_result"
177
- }
178
- ],
179
- "source": [
180
- "dataset = (\n",
181
- " dataset\n",
182
- " .filter(lambda x: x['condition'] is not None)\n",
183
- " .rename_column('Unnamed: 0', 'row_id')\n",
184
- " .map(lambda x: {'condition': [row.lower() for row in x['condition']]}, batched=True, num_proc=8, batch_size=3000)\n",
185
- " .map(lambda x: {'review': [html.unescape(row) for row in x['review']]}, batched=True, num_proc=8, batch_size=3000)\n",
186
- " .map(lambda x: {'review_length': [len(row.split()) for row in x['review']]}, batched=True, num_proc=8, batch_size=3000)\n",
187
- " # .filter(lambda x: {'review_length': [row > 30 for row in x['review_length']]}, batched=True, num_proc=8)\n",
188
- " .filter(lambda x: x['review_length'] > 30, num_proc=8, batch_size=3000)\n",
189
- ")\n",
190
- "dataset"
191
- ]
192
- },
193
- {
194
- "cell_type": "markdown",
195
- "id": "e7c4daf2-36c1-4074-91ca-8871a581052d",
196
- "metadata": {},
197
- "source": [
198
- "# Exercises"
199
- ]
200
- },
201
- {
202
- "cell_type": "markdown",
203
- "id": "ea14b998-69f1-40a7-a200-7cc53b0e22fd",
204
- "metadata": {},
205
- "source": [
206
- "## Predict patient condition based on drug review"
207
- ]
208
- },
209
- {
210
- "cell_type": "code",
211
- "execution_count": 41,
212
- "id": "dc6b299b-2d0b-4475-bfff-d0180dd672c1",
213
- "metadata": {},
214
- "outputs": [],
215
- "source": [
216
- "from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer, AutoModel, DataCollatorWithPadding\n",
217
- "from torch.utils.data import DataLoader\n",
218
- "import evaluate, numpy as np\n",
219
- "from huggingface_hub import HfApi"
220
- ]
221
- },
222
- {
223
- "cell_type": "code",
224
- "execution_count": 10,
225
- "id": "77caa284-8307-40a0-8369-621195e5c7e9",
226
- "metadata": {},
227
- "outputs": [],
228
- "source": [
229
- "def clean_condition_column(rows):\n",
230
- " target_text = 'users found this comment helpful'\n",
231
- " return {'condition': ['unknown' if target_text in condition else condition for condition in rows['condition']]}"
232
- ]
233
- },
234
- {
235
- "cell_type": "code",
236
- "execution_count": 11,
237
- "id": "058d4c64-428b-43bb-86c4-ba8f5c1b8a84",
238
- "metadata": {},
239
- "outputs": [
240
- {
241
- "data": {
242
- "text/plain": [
243
- "DatasetDict({\n",
244
- " train: Dataset({\n",
245
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
246
- " num_rows: 138514\n",
247
- " })\n",
248
- " test: Dataset({\n",
249
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
250
- " num_rows: 46108\n",
251
- " })\n",
252
- "})"
253
- ]
254
- },
255
- "execution_count": 11,
256
- "metadata": {},
257
- "output_type": "execute_result"
258
- }
259
- ],
260
- "source": [
261
- "dataset = dataset.map(clean_condition_column, batched=True, batch_size=3000, num_proc=8)\n",
262
- "dataset"
263
- ]
264
- },
265
- {
266
- "cell_type": "code",
267
- "execution_count": 12,
268
- "id": "80dc20fe-cb66-4b0d-99dc-88e84413975b",
269
- "metadata": {},
270
- "outputs": [
271
- {
272
- "data": {
273
- "text/plain": [
274
- "DatasetDict({\n",
275
- " train: Dataset({\n",
276
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
277
- " num_rows: 110811\n",
278
- " })\n",
279
- " validation: Dataset({\n",
280
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
281
- " num_rows: 27703\n",
282
- " })\n",
283
- " test: Dataset({\n",
284
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length'],\n",
285
- " num_rows: 46108\n",
286
- " })\n",
287
- "})"
288
- ]
289
- },
290
- "execution_count": 12,
291
- "metadata": {},
292
- "output_type": "execute_result"
293
- }
294
- ],
295
- "source": [
296
- "clean_data = dataset['train'].train_test_split(test_size=.2, seed=5, writer_batch_size=3000)\n",
297
- "clean_data['validation'] = clean_data.pop('test')\n",
298
- "clean_data['test'] = dataset['test']\n",
299
- "\n",
300
- "clean_data"
301
- ]
302
- },
303
- {
304
- "cell_type": "code",
305
- "execution_count": 13,
306
- "id": "8be33fbb-143f-45b5-9e18-c5662a7e0dad",
307
- "metadata": {},
308
- "outputs": [
309
- {
310
- "data": {
311
- "text/plain": [
312
- "751"
313
- ]
314
- },
315
- "execution_count": 13,
316
- "metadata": {},
317
- "output_type": "execute_result"
318
- }
319
- ],
320
- "source": [
321
- "all_conditions = sorted(set(clean_data['train']['condition']).union(set(clean_data['validation']['condition'])))\n",
322
- "len(all_conditions)"
323
- ]
324
- },
325
- {
326
- "cell_type": "code",
327
- "execution_count": 14,
328
- "id": "912ef7d5-149a-48ed-ac6b-1ff2f3c2556a",
329
- "metadata": {},
330
- "outputs": [],
331
- "source": [
332
- "id2label = dict(enumerate(all_conditions))\n",
333
- "label2id = {v:k for k, v in id2label.items()}"
334
- ]
335
- },
336
- {
337
- "cell_type": "code",
338
- "execution_count": 15,
339
- "id": "aca4a239-3f07-44bf-905e-2743b8f0889d",
340
- "metadata": {},
341
- "outputs": [
342
- {
343
- "data": {
344
- "text/plain": [
345
- "True"
346
- ]
347
- },
348
- "execution_count": 15,
349
- "metadata": {},
350
- "output_type": "execute_result"
351
- }
352
- ],
353
- "source": [
354
- "len(label2id) == len(id2label)"
355
- ]
356
- },
357
- {
358
- "cell_type": "code",
359
- "execution_count": 16,
360
- "id": "024d5faa-88f1-41b7-9f52-8178ad731089",
361
- "metadata": {},
362
- "outputs": [
363
- {
364
- "data": {
365
- "text/plain": [
366
- "DatasetDict({\n",
367
- " train: Dataset({\n",
368
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n",
369
- " num_rows: 110811\n",
370
- " })\n",
371
- " validation: Dataset({\n",
372
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n",
373
- " num_rows: 27703\n",
374
- " })\n",
375
- " test: Dataset({\n",
376
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels'],\n",
377
- " num_rows: 46108\n",
378
- " })\n",
379
- "})"
380
- ]
381
- },
382
- "execution_count": 16,
383
- "metadata": {},
384
- "output_type": "execute_result"
385
- }
386
- ],
387
- "source": [
388
- "clean_data = clean_data.map(lambda x: {'labels': [label2id.get(condition, label2id['unknown']) for condition in x['condition']]}, batched=True, batch_size=3000, num_proc=8)\n",
389
- "clean_data"
390
- ]
391
- },
392
- {
393
- "cell_type": "code",
394
- "execution_count": 17,
395
- "id": "2f71cacc-9fb4-4436-b32b-8f172bcc19b1",
396
- "metadata": {},
397
- "outputs": [
398
- {
399
- "name": "stderr",
400
- "output_type": "stream",
401
- "text": [
402
- "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
403
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
404
- ]
405
- }
406
- ],
407
- "source": [
408
- "# checkpoint = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'\n",
409
- "checkpoint = 'distilbert-base-uncased'\n",
410
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=len(id2label)).to(device)\n",
411
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
412
- ]
413
- },
414
- {
415
- "cell_type": "code",
416
- "execution_count": 18,
417
- "id": "e9b2c2bd-52d4-47e0-aaaf-eb76b3bab9fa",
418
- "metadata": {},
419
- "outputs": [],
420
- "source": [
421
- "model.config.id2label = id2label\n",
422
- "model.config.label2id = label2id\n",
423
- "model.num_labels = len(label2id)"
424
- ]
425
- },
426
- {
427
- "cell_type": "code",
428
- "execution_count": 19,
429
- "id": "2d3bb44b-e635-4e7c-b984-6379510b60b3",
430
- "metadata": {},
431
- "outputs": [],
432
- "source": [
433
- "collator = DataCollatorWithPadding(tokenizer)"
434
- ]
435
- },
436
- {
437
- "cell_type": "code",
438
- "execution_count": 20,
439
- "id": "c22a17ab-4a43-45f6-ba99-62cdb94103c5",
440
- "metadata": {},
441
- "outputs": [],
442
- "source": [
443
- "def tokenize_and_split(examples):\n",
444
- " tokens = tokenizer(\n",
445
- " examples[\"review\"],\n",
446
- " truncation=True,\n",
447
- " max_length=512,\n",
448
- " return_overflowing_tokens=True,\n",
449
- " )\n",
450
- " mappings = tokens.pop('overflow_to_sample_mapping')\n",
451
- " for key, values in examples.items():\n",
452
- " tokens[key] = [values[idx] for idx in mappings]\n",
453
- " return tokens"
454
- ]
455
- },
456
- {
457
- "cell_type": "code",
458
- "execution_count": 21,
459
- "id": "5a1b9eb6-87a1-4d7f-855b-f1c9e5ae63c2",
460
- "metadata": {},
461
- "outputs": [
462
- {
463
- "data": {
464
- "text/plain": [
465
- "DatasetDict({\n",
466
- " train: Dataset({\n",
467
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n",
468
- " num_rows: 110857\n",
469
- " })\n",
470
- " validation: Dataset({\n",
471
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n",
472
- " num_rows: 27717\n",
473
- " })\n",
474
- " test: Dataset({\n",
475
- " features: ['row_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length', 'labels', 'input_ids', 'attention_mask'],\n",
476
- " num_rows: 46118\n",
477
- " })\n",
478
- "})"
479
- ]
480
- },
481
- "execution_count": 21,
482
- "metadata": {},
483
- "output_type": "execute_result"
484
- }
485
- ],
486
- "source": [
487
- "tokenized_dataset = clean_data.map(tokenize_and_split, batched=True, batch_size=3000, num_proc=8)\n",
488
- "tokenized_dataset"
489
- ]
490
- },
491
- {
492
- "cell_type": "code",
493
- "execution_count": null,
494
- "id": "2729d5c2-499d-41f0-8ddb-27df3cf82475",
495
- "metadata": {},
496
- "outputs": [],
497
- "source": []
498
- },
499
- {
500
- "cell_type": "code",
501
- "execution_count": null,
502
- "id": "3488692d-44ef-4b99-af4c-8fa32d6ed3b2",
503
- "metadata": {},
504
- "outputs": [],
505
- "source": [
506
- "tokenized_dataset.save_to_disk('data/drugs', num_proc=4)"
507
- ]
508
- },
509
- {
510
- "cell_type": "code",
511
- "execution_count": null,
512
- "id": "2bc7b3ea-5f48-4298-b625-d313c4dc1ea3",
513
- "metadata": {},
514
- "outputs": [],
515
- "source": [
516
- "tokenized_dataset = load_from_disk('data/drugs/')\n",
517
- "tokenized_dataset"
518
- ]
519
- },
520
- {
521
- "cell_type": "code",
522
- "execution_count": null,
523
- "id": "001bb28a-9ff1-463f-90af-22dc7f6bce53",
524
- "metadata": {},
525
- "outputs": [],
526
- "source": []
527
- },
528
- {
529
- "cell_type": "code",
530
- "execution_count": 22,
531
- "id": "344a5505-f143-4389-8be3-282219f29d74",
532
- "metadata": {},
533
- "outputs": [
534
- {
535
- "data": {
536
- "text/plain": [
537
- "DatasetDict({\n",
538
- " train: Dataset({\n",
539
- " features: ['input_ids', 'attention_mask', 'labels'],\n",
540
- " num_rows: 110857\n",
541
- " })\n",
542
- " validation: Dataset({\n",
543
- " features: ['input_ids', 'attention_mask', 'labels'],\n",
544
- " num_rows: 27717\n",
545
- " })\n",
546
- " test: Dataset({\n",
547
- " features: ['input_ids', 'attention_mask', 'labels'],\n",
548
- " num_rows: 46118\n",
549
- " })\n",
550
- "})"
551
- ]
552
- },
553
- "execution_count": 22,
554
- "metadata": {},
555
- "output_type": "execute_result"
556
- }
557
- ],
558
- "source": [
559
- "filtered = tokenized_dataset.select_columns(['input_ids', 'attention_mask', 'labels'])\n",
560
- "filtered"
561
- ]
562
- },
563
- {
564
- "cell_type": "code",
565
- "execution_count": 23,
566
- "id": "b31de787-0312-4d67-8b41-ce85732308ea",
567
- "metadata": {},
568
- "outputs": [],
569
- "source": [
570
- "accuracy = evaluate.load('accuracy')"
571
- ]
572
- },
573
- {
574
- "cell_type": "code",
575
- "execution_count": 24,
576
- "id": "f6d0543e-06d5-4930-93f6-8028e4e4ead5",
577
- "metadata": {},
578
- "outputs": [],
579
- "source": [
580
- "def compute_metrics(eval_preds):\n",
581
- " logits, labels = eval_preds\n",
582
- " preds = np.argmax(logits, axis=-1)\n",
583
- " return accuracy.compute(predictions=preds, references=labels)"
584
- ]
585
- },
586
- {
587
- "cell_type": "code",
588
- "execution_count": 25,
589
- "id": "ec5be835-e194-47a4-8c2a-3eb7500645ad",
590
- "metadata": {},
591
- "outputs": [],
592
- "source": [
593
- "lr = 3e-5"
594
- ]
595
- },
596
- {
597
- "cell_type": "code",
598
- "execution_count": 26,
599
- "id": "d04e4bae-8bb0-4e5e-be0f-2ce41db1bbe6",
600
- "metadata": {},
601
- "outputs": [],
602
- "source": [
603
- "train_args = TrainingArguments(\n",
604
- " 'medical_condition_classification', \n",
605
- " overwrite_output_dir=True, \n",
606
- " eval_strategy='steps', eval_steps=2000, \n",
607
- " per_device_train_batch_size=24, \n",
608
- " per_device_eval_batch_size=24, \n",
609
- " fp16=True, num_train_epochs=5,\n",
610
- " learning_rate=lr,\n",
611
- " push_to_hub=True,\n",
612
- " hub_token=access_token\n",
613
- ")"
614
- ]
615
- },
616
- {
617
- "cell_type": "code",
618
- "execution_count": 27,
619
- "id": "e26faf3f-03ab-411d-97a2-c1a3b6e2b425",
620
- "metadata": {},
621
- "outputs": [],
622
- "source": [
623
- "trainer = Trainer(model, train_args, collator, filtered['train'], filtered['validation'], tokenizer, compute_metrics=compute_metrics)"
624
- ]
625
- },
626
- {
627
- "cell_type": "code",
628
- "execution_count": 28,
629
- "id": "52c55095-4761-4353-8222-887cdf309431",
630
- "metadata": {},
631
- "outputs": [
632
- {
633
- "data": {
634
- "text/html": [
635
- "\n",
636
- " <div>\n",
637
- " \n",
638
- " <progress value='23100' max='23100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
639
- " [23100/23100 1:14:13, Epoch 5/5]\n",
640
- " </div>\n",
641
- " <table border=\"1\" class=\"dataframe\">\n",
642
- " <thead>\n",
643
- " <tr style=\"text-align: left;\">\n",
644
- " <th>Step</th>\n",
645
- " <th>Training Loss</th>\n",
646
- " <th>Validation Loss</th>\n",
647
- " <th>Accuracy</th>\n",
648
- " </tr>\n",
649
- " </thead>\n",
650
- " <tbody>\n",
651
- " <tr>\n",
652
- " <td>2000</td>\n",
653
- " <td>1.862500</td>\n",
654
- " <td>1.719871</td>\n",
655
- " <td>0.639680</td>\n",
656
- " </tr>\n",
657
- " <tr>\n",
658
- " <td>4000</td>\n",
659
- " <td>1.459000</td>\n",
660
- " <td>1.369566</td>\n",
661
- " <td>0.688963</td>\n",
662
- " </tr>\n",
663
- " <tr>\n",
664
- " <td>6000</td>\n",
665
- " <td>1.173700</td>\n",
666
- " <td>1.213141</td>\n",
667
- " <td>0.717249</td>\n",
668
- " </tr>\n",
669
- " <tr>\n",
670
- " <td>8000</td>\n",
671
- " <td>1.042000</td>\n",
672
- " <td>1.101419</td>\n",
673
- " <td>0.732908</td>\n",
674
- " </tr>\n",
675
- " <tr>\n",
676
- " <td>10000</td>\n",
677
- " <td>0.843100</td>\n",
678
- " <td>1.032237</td>\n",
679
- " <td>0.750983</td>\n",
680
- " </tr>\n",
681
- " <tr>\n",
682
- " <td>12000</td>\n",
683
- " <td>0.801200</td>\n",
684
- " <td>0.988939</td>\n",
685
- " <td>0.758668</td>\n",
686
- " </tr>\n",
687
- " <tr>\n",
688
- " <td>14000</td>\n",
689
- " <td>0.731200</td>\n",
690
- " <td>0.949687</td>\n",
691
- " <td>0.772703</td>\n",
692
- " </tr>\n",
693
- " <tr>\n",
694
- " <td>16000</td>\n",
695
- " <td>0.656100</td>\n",
696
- " <td>0.933845</td>\n",
697
- " <td>0.780496</td>\n",
698
- " </tr>\n",
699
- " <tr>\n",
700
- " <td>18000</td>\n",
701
- " <td>0.613200</td>\n",
702
- " <td>0.907262</td>\n",
703
- " <td>0.787531</td>\n",
704
- " </tr>\n",
705
- " <tr>\n",
706
- " <td>20000</td>\n",
707
- " <td>0.519500</td>\n",
708
- " <td>0.901089</td>\n",
709
- " <td>0.792943</td>\n",
710
- " </tr>\n",
711
- " <tr>\n",
712
- " <td>22000</td>\n",
713
- " <td>0.501500</td>\n",
714
- " <td>0.892959</td>\n",
715
- " <td>0.795072</td>\n",
716
- " </tr>\n",
717
- " </tbody>\n",
718
- "</table><p>"
719
- ],
720
- "text/plain": [
721
- "<IPython.core.display.HTML object>"
722
- ]
723
- },
724
- "metadata": {},
725
- "output_type": "display_data"
726
- },
727
- {
728
- "data": {
729
- "text/plain": [
730
- "TrainOutput(global_step=23100, training_loss=1.0162131207949154, metrics={'train_runtime': 4454.3937, 'train_samples_per_second': 124.436, 'train_steps_per_second': 5.186, 'total_flos': 2.958796560013029e+16, 'train_loss': 1.0162131207949154, 'epoch': 5.0})"
731
- ]
732
- },
733
- "execution_count": 28,
734
- "metadata": {},
735
- "output_type": "execute_result"
736
- }
737
- ],
738
- "source": [
739
- "trainer.train()"
740
- ]
741
- },
742
- {
743
- "cell_type": "code",
744
- "execution_count": 29,
745
- "id": "7c8d06d3-ef08-42ca-9dad-651c3a7c45fc",
746
- "metadata": {},
747
- "outputs": [
748
- {
749
- "data": {
750
- "text/html": [],
751
- "text/plain": [
752
- "<IPython.core.display.HTML object>"
753
- ]
754
- },
755
- "metadata": {},
756
- "output_type": "display_data"
757
- }
758
- ],
759
- "source": [
760
- "with torch.no_grad():\n",
761
- " preds = trainer.predict(filtered['test'])"
762
- ]
763
- },
764
- {
765
- "cell_type": "code",
766
- "execution_count": 33,
767
- "id": "cab2f41e-d00f-41cb-a5a6-daf9e713077d",
768
- "metadata": {},
769
- "outputs": [
770
- {
771
- "data": {
772
- "text/plain": [
773
- "{'test_loss': 0.8813542127609253,\n",
774
- " 'test_accuracy': 0.8004249967474739,\n",
775
- " 'test_runtime': 87.98,\n",
776
- " 'test_samples_per_second': 524.188,\n",
777
- " 'test_steps_per_second': 21.846}"
778
- ]
779
- },
780
- "execution_count": 33,
781
- "metadata": {},
782
- "output_type": "execute_result"
783
- }
784
- ],
785
- "source": [
786
- "preds.metrics"
787
- ]
788
- },
789
- {
790
- "cell_type": "code",
791
- "execution_count": 34,
792
- "id": "1e323be4-ac78-498e-99fc-3133b11dc241",
793
- "metadata": {},
794
- "outputs": [
795
- {
796
- "data": {
797
- "application/vnd.jupyter.widget-view+json": {
798
- "model_id": "d45b5b475cce4bb09298a7278ec51c64",
799
- "version_major": 2,
800
- "version_minor": 0
801
- },
802
- "text/plain": [
803
- "model.safetensors: 0%| | 0.00/270M [00:00<?, ?B/s]"
804
- ]
805
- },
806
- "metadata": {},
807
- "output_type": "display_data"
808
- },
809
- {
810
- "data": {
811
- "text/plain": [
812
- "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/73d201a08dd78ce2afea66736b188d271e652052', commit_message='End of training', commit_description='', oid='73d201a08dd78ce2afea66736b188d271e652052', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
813
- ]
814
- },
815
- "execution_count": 34,
816
- "metadata": {},
817
- "output_type": "execute_result"
818
- }
819
- ],
820
- "source": [
821
- "trainer.push_to_hub()"
822
- ]
823
- },
824
- {
825
- "cell_type": "code",
826
- "execution_count": 36,
827
- "id": "6e754047-4b6c-4b3c-80a1-8493009ac7ca",
828
- "metadata": {},
829
- "outputs": [
830
- {
831
- "data": {
832
- "application/vnd.jupyter.widget-view+json": {
833
- "model_id": "9f853954de9b4fe5bcf797878413702f",
834
- "version_major": 2,
835
- "version_minor": 0
836
- },
837
- "text/plain": [
838
- "README.md: 0%| | 0.00/2.11k [00:00<?, ?B/s]"
839
- ]
840
- },
841
- "metadata": {},
842
- "output_type": "display_data"
843
- },
844
- {
845
- "data": {
846
- "text/plain": [
847
- "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/9a28b6773707f2b88e7eeb37bc811c642bb524c7', commit_message='tokenizer', commit_description='', oid='9a28b6773707f2b88e7eeb37bc811c642bb524c7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
848
- ]
849
- },
850
- "execution_count": 36,
851
- "metadata": {},
852
- "output_type": "execute_result"
853
- }
854
- ],
855
- "source": [
856
- "tokenizer.push_to_hub('medical_condition_classification', commit_message='tokenizer')"
857
- ]
858
- },
859
- {
860
- "cell_type": "code",
861
- "execution_count": 40,
862
- "id": "2940e50a-328d-4f30-8cb7-30dd047d2f92",
863
- "metadata": {},
864
- "outputs": [
865
- {
866
- "data": {
867
- "application/vnd.jupyter.widget-view+json": {
868
- "model_id": "51c9d91e63664cba896e957841756995",
869
- "version_major": 2,
870
- "version_minor": 0
871
- },
872
- "text/plain": [
873
- "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
874
- ]
875
- },
876
- "metadata": {},
877
- "output_type": "display_data"
878
- },
879
- {
880
- "data": {
881
- "application/vnd.jupyter.widget-view+json": {
882
- "model_id": "e3bb0c86842e427eba71257d113c0845",
883
- "version_major": 2,
884
- "version_minor": 0
885
- },
886
- "text/plain": [
887
- "Creating parquet from Arrow format: 0%| | 0/111 [00:00<?, ?ba/s]"
888
- ]
889
- },
890
- "metadata": {},
891
- "output_type": "display_data"
892
- },
893
- {
894
- "data": {
895
- "application/vnd.jupyter.widget-view+json": {
896
- "model_id": "73d3373d5da04164a237ff56dc57fac5",
897
- "version_major": 2,
898
- "version_minor": 0
899
- },
900
- "text/plain": [
901
- "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
902
- ]
903
- },
904
- "metadata": {},
905
- "output_type": "display_data"
906
- },
907
- {
908
- "data": {
909
- "application/vnd.jupyter.widget-view+json": {
910
- "model_id": "df6679bdb3c342ffa2aab53922a3d1af",
911
- "version_major": 2,
912
- "version_minor": 0
913
- },
914
- "text/plain": [
915
- "Creating parquet from Arrow format: 0%| | 0/28 [00:00<?, ?ba/s]"
916
- ]
917
- },
918
- "metadata": {},
919
- "output_type": "display_data"
920
- },
921
- {
922
- "data": {
923
- "application/vnd.jupyter.widget-view+json": {
924
- "model_id": "60772b393cdc4dca9191a538b9116169",
925
- "version_major": 2,
926
- "version_minor": 0
927
- },
928
- "text/plain": [
929
- "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]"
930
- ]
931
- },
932
- "metadata": {},
933
- "output_type": "display_data"
934
- },
935
- {
936
- "data": {
937
- "application/vnd.jupyter.widget-view+json": {
938
- "model_id": "9b859655cbe34da184c8e6b6a07d0911",
939
- "version_major": 2,
940
- "version_minor": 0
941
- },
942
- "text/plain": [
943
- "Creating parquet from Arrow format: 0%| | 0/47 [00:00<?, ?ba/s]"
944
- ]
945
- },
946
- "metadata": {},
947
- "output_type": "display_data"
948
- },
949
- {
950
- "data": {
951
- "text/plain": [
952
- "CommitInfo(commit_url='https://huggingface.co/datasets/samsaara/medical_condition_classification/commit/7aea5155fcba521a02ec3e9e8fb4e86d09dc44ba', commit_message='Upload dataset', commit_description='', oid='7aea5155fcba521a02ec3e9e8fb4e86d09dc44ba', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='dataset', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
953
- ]
954
- },
955
- "execution_count": 40,
956
- "metadata": {},
957
- "output_type": "execute_result"
958
- }
959
- ],
960
- "source": [
961
- "tokenized_dataset.push_to_hub('medical_condition_classification')"
962
- ]
963
- },
964
- {
965
- "cell_type": "code",
966
- "execution_count": 42,
967
- "id": "4bbf39bb-fea9-4b5b-8d87-7a9835975358",
968
- "metadata": {},
969
- "outputs": [],
970
- "source": [
971
- "api = HfApi(token=access_token)"
972
- ]
973
- },
974
- {
975
- "cell_type": "code",
976
- "execution_count": 46,
977
- "id": "16e0625b-5518-463a-96e2-2d008341b1f1",
978
- "metadata": {},
979
- "outputs": [
980
- {
981
- "data": {
982
- "text/plain": [
983
- "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/98ede0386065880a4cfefcb0ab1c9d7bfc9d081d', commit_message='update README', commit_description='', oid='98ede0386065880a4cfefcb0ab1c9d7bfc9d081d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
984
- ]
985
- },
986
- "execution_count": 46,
987
- "metadata": {},
988
- "output_type": "execute_result"
989
- }
990
- ],
991
- "source": [
992
- "api.upload_file(\n",
993
- " path_or_fileobj='./medical_condition_classification/README.md', \n",
994
- " path_in_repo='README.md',\n",
995
- " repo_id='samsaara/medical_condition_classification', \n",
996
- " commit_message='update README'\n",
997
- ")"
998
- ]
999
- },
1000
- {
1001
- "cell_type": "code",
1002
- "execution_count": 47,
1003
- "id": "4144a617-f0d6-41d3-9a44-90833b8bb1f5",
1004
- "metadata": {},
1005
- "outputs": [
1006
- {
1007
- "data": {
1008
- "text/plain": [
1009
- "CommitInfo(commit_url='https://huggingface.co/samsaara/medical_condition_classification/commit/53b029a816883983962ebe0050977be8ee501d82', commit_message='notebook for training & evaluation', commit_description='', oid='53b029a816883983962ebe0050977be8ee501d82', pr_url=None, repo_url=RepoUrl('https://huggingface.co/samsaara/medical_condition_classification', endpoint='https://huggingface.co', repo_type='model', repo_id='samsaara/medical_condition_classification'), pr_revision=None, pr_num=None)"
1010
- ]
1011
- },
1012
- "execution_count": 47,
1013
- "metadata": {},
1014
- "output_type": "execute_result"
1015
- }
1016
- ],
1017
- "source": [
1018
- "api.upload_file(\n",
1019
- " path_or_fileobj='datasets.ipynb', \n",
1020
- " path_in_repo='datasets.ipynb',\n",
1021
- " repo_id='samsaara/medical_condition_classification', \n",
1022
- " commit_message='notebook for training & evaluation'\n",
1023
- ")"
1024
- ]
1025
- },
1026
- {
1027
- "cell_type": "code",
1028
- "execution_count": null,
1029
- "id": "2de298eb-82f2-482b-acee-36c6a1e630b8",
1030
- "metadata": {},
1031
- "outputs": [],
1032
- "source": []
1033
- }
1034
- ],
1035
- "metadata": {
1036
- "kernelspec": {
1037
- "display_name": "Python 3 (ipykernel)",
1038
- "language": "python",
1039
- "name": "python3"
1040
- },
1041
- "language_info": {
1042
- "codemirror_mode": {
1043
- "name": "ipython",
1044
- "version": 3
1045
- },
1046
- "file_extension": ".py",
1047
- "mimetype": "text/x-python",
1048
- "name": "python",
1049
- "nbconvert_exporter": "python",
1050
- "pygments_lexer": "ipython3",
1051
- "version": "3.11.10"
1052
- }
1053
- },
1054
- "nbformat": 4,
1055
- "nbformat_minor": 5
1056
- }