Christina Theodoris committed
Commit 596f1c2 · 1 Parent(s): f3ff19d

update dist multitask example name

examples/distributed_multitask_cell_classification.ipynb ADDED
@@ -0,0 +1,149 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b3266a7b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "from geneformer import MTLClassifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3e12ac9f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define paths\n",
+ "pretrained_path = \"/path/to/pretrained/Geneformer/model\"\n",
+ "# Input data is tokenized rank value encodings generated by the Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
+ "train_path = \"/path/to/train/data.dataset\"\n",
+ "val_path = \"/path/to/val/data.dataset\"\n",
+ "test_path = \"/path/to/test/data.dataset\"\n",
+ "results_dir = \"/path/to/results/directory\"\n",
+ "model_save_path = \"/path/to/model/save/path\"\n",
+ "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
+ "\n",
+ "# Define tasks and hyperparameters\n",
+ "# task_columns should be a list of column names from your dataset\n",
+ "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
+ "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c9bd7562",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check GPU environment\n",
+ "num_gpus = torch.cuda.device_count()\n",
+ "use_distributed = num_gpus > 1\n",
+ "print(f\"Number of GPUs detected: {num_gpus}\")\n",
+ "print(f\"Using distributed training: {use_distributed}\")\n",
+ "\n",
+ "# Set environment variables for distributed training when multiple GPUs are available\n",
+ "if use_distributed:\n",
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n",
+ " os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n",
+ " print(\"Distributed environment variables set.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b6ff3618",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define hyperparameters for optimization\n",
+ "hyperparameters = {\n",
+ " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
+ " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
+ " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
+ " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
+ " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
+ " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f665c5a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mc = MTLClassifier(\n",
+ " task_columns=task_columns, # Our defined classification tasks\n",
+ " study_name=\"MTLClassifier_distributed\",\n",
+ " pretrained_path=pretrained_path,\n",
+ " train_path=train_path,\n",
+ " val_path=val_path,\n",
+ " test_path=test_path,\n",
+ " model_save_path=model_save_path,\n",
+ " results_dir=results_dir,\n",
+ " tensorboard_log_dir=tensorboard_log_dir,\n",
+ " hyperparameters=hyperparameters,\n",
+ " # Distributed training parameters\n",
+ " distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n",
+ " master_addr=\"localhost\" if use_distributed else None,\n",
+ " master_port=\"12355\" if use_distributed else None,\n",
+ " # Other training parameters\n",
+ " n_trials=15, # Number of trials for hyperparameter optimization\n",
+ " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
+ " batch_size=8, # Adjust based on available GPU memory\n",
+ " gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n",
+ " gradient_clipping=True, # Enable gradient clipping for stability\n",
+ " max_grad_norm=1.0, # Set maximum gradient norm\n",
+ " seed=42\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f69f7b6a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run hyperparameter optimization with distributed training\n",
+ "if __name__ == \"__main__\":\n",
+ " # This guard is required for distributed training to prevent\n",
+ " # infinite subprocess spawning when using torch.multiprocessing\n",
+ " mc.run_optuna_study()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3affd5dd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Evaluate the model on test data\n",
+ "if __name__ == \"__main__\":\n",
+ " mc.load_and_evaluate_test_model()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "bio",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
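
Before launching the study, the tokenized splits can be inspected directly. A minimal pre-flight sketch, assuming the .dataset paths are Hugging Face datasets written with save_to_disk (as the Geneformer tokenizer produces) and reusing the placeholder paths and example task_columns from the notebook above:

from datasets import load_from_disk

# Confirm each split loads and carries every task column passed to MTLClassifier.
# Paths and column names are the notebook's placeholders, not real data.
for name, path in [("train", train_path), ("val", val_path), ("test", test_path)]:
    ds = load_from_disk(path)
    missing = [col for col in task_columns if col not in ds.column_names]
    print(f"{name}: {len(ds)} cells, columns: {ds.column_names}")
    assert not missing, f"{name} split is missing task columns: {missing}"

If a task column is missing here, the MTLClassifier run would fail much later, after model loading, so this check saves a wasted trial.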