{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3266a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from geneformer import MTLClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e12ac9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define paths\n",
    "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
    "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
    "train_path = \"/path/to/train/data.dataset\"\n",
    "val_path = \"/path/to/val/data.dataset\"\n",
    "test_path = \"/path/to/test/data.dataset\"\n",
    "results_dir = \"/path/to/results/directory\"\n",
    "model_save_path = \"/path/to/model/save/path\"\n",
    "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
    "\n",
    "# Define tasks and hyperparameters\n",
    "# task_columns should be a list of column names from your dataset\n",
    "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
    "task_columns = [\"cell_type\", \"disease_state\"]  # Example task columns"
   ]
  },
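  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f0c3d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not part of the Geneformer API): fail fast if any\n",
    "# of the paths above is missing or still a placeholder, rather than failing\n",
    "# partway through a long optimization run\n",
    "for label, path in [\n",
    "    (\"pretrained_path\", pretrained_path),\n",
    "    (\"train_path\", train_path),\n",
    "    (\"val_path\", val_path),\n",
    "    (\"test_path\", test_path),\n",
    "]:\n",
    "    if not os.path.exists(path):\n",
    "        print(f\"WARNING: {label} does not exist: {path}\")"
   ]
  },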
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9bd7562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check GPU environment\n",
    "num_gpus = torch.cuda.device_count()\n",
    "use_distributed = num_gpus > 1\n",
    "print(f\"Number of GPUs detected: {num_gpus}\")\n",
    "print(f\"Using distributed training: {use_distributed}\")\n",
    "\n",
    "# Set environment variables for distributed training when multiple GPUs are available\n",
    "if use_distributed:\n",
    "    os.environ[\"MASTER_ADDR\"] = \"localhost\"  # hostname\n",
    "    os.environ[\"MASTER_PORT\"] = \"12355\"      # Choose an available port\n",
    "    print(\"Distributed environment variables set.\")"
   ]
  },
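  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional alternative to the hard-coded port above (a sketch, standard\n",
    "# library only): ask the OS for a free port by binding to port 0, which\n",
    "# avoids \"address already in use\" errors when several jobs share a node.\n",
    "# Note the small race window between releasing and reusing the port.\n",
    "# If you use this, pass the same value as master_port to MTLClassifier below.\n",
    "import socket\n",
    "\n",
    "def find_free_port() -> int:\n",
    "    # Binding to port 0 lets the OS assign an unused port\n",
    "    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n",
    "        s.bind((\"\", 0))\n",
    "        return s.getsockname()[1]\n",
    "\n",
    "if use_distributed:\n",
    "    os.environ[\"MASTER_PORT\"] = str(find_free_port())\n",
    "    print(f\"MASTER_PORT set to {os.environ['MASTER_PORT']}\")"
   ]
  },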
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6ff3618",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Define Hyperparameters for Optimization\n",
    "hyperparameters = {\n",
    "    \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
    "    \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
    "    \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
    "    \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
    "    \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
    "    \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
    "}"
   ]
  },
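  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b9c0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: roughly how a search-space dict like the one above\n",
    "# maps onto Optuna's sampling API. MTLClassifier handles this internally;\n",
    "# this sketch just makes the encoding concrete.\n",
    "def sample_hyperparameters(trial, space):\n",
    "    sampled = {}\n",
    "    for name, spec in space.items():\n",
    "        if spec[\"type\"] == \"float\":\n",
    "            sampled[name] = trial.suggest_float(\n",
    "                name, spec[\"low\"], spec[\"high\"], log=spec.get(\"log\", False)\n",
    "            )\n",
    "        elif spec[\"type\"] == \"categorical\":\n",
    "            sampled[name] = trial.suggest_categorical(name, spec[\"choices\"])\n",
    "    return sampled"
   ]
  },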
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f665c5a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "mc = MTLClassifier(\n",
    "    task_columns=task_columns,  # Our defined classification tasks\n",
    "    study_name=\"MTLClassifier_distributed\",\n",
    "    pretrained_path=pretrained_path,\n",
    "    train_path=train_path,\n",
    "    val_path=val_path,\n",
    "    test_path=test_path,\n",
    "    model_save_path=model_save_path,\n",
    "    results_dir=results_dir,\n",
    "    tensorboard_log_dir=tensorboard_log_dir,\n",
    "    hyperparameters=hyperparameters,\n",
    "    # Distributed training parameters\n",
    "    distributed_training=use_distributed,  # Enable distributed training if multiple GPUs available\n",
    "    master_addr=\"localhost\" if use_distributed else None,\n",
    "    master_port=\"12355\" if use_distributed else None,\n",
    "    # Other training parameters\n",
    "    n_trials=15,  # Number of trials for hyperparameter optimization\n",
    "    epochs=1,     # Number of training epochs (1 suggested to prevent overfitting)\n",
    "    batch_size=8, # Adjust based on available GPU memory\n",
    "    gradient_accumulation_steps=4,  # Accumulate gradients over multiple steps\n",
    "    gradient_clipping=True,         # Enable gradient clipping for stability\n",
    "    max_grad_norm=1.0,              # Set maximum gradient norm\n",
    "    seed=42\n",
    ")"
   ]
  },
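  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2d3e4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For reference: under data-parallel training the effective batch size is\n",
    "# per-device batch size x gradient accumulation steps x number of GPUs\n",
    "# (assuming each GPU processes its own micro-batches)\n",
    "effective_batch_size = 8 * 4 * max(num_gpus, 1)\n",
    "print(f\"Effective batch size: {effective_batch_size}\")"
   ]
  },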
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f69f7b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run Hyperparameter Optimization with Distributed Training\n",
    "if __name__ == \"__main__\":\n",
    "    # This guard is required for distributed training to prevent\n",
    "    # infinite subprocess spawning when using torch.multiprocessing\n",
    "    mc.run_optuna_study()"
   ]
  },
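  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5a6c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: monitor the study with TensorBoard while it runs (assumes the\n",
    "# tensorboard package is installed). From a separate terminal:\n",
    "#   tensorboard --logdir <tensorboard_log_dir>\n",
    "print(f\"TensorBoard logs: {tensorboard_log_dir}\")"
   ]
  },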
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3affd5dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the Model on Test Data\n",
    "if __name__ == \"__main__\":\n",
    "    mc.load_and_evaluate_test_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bio",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}