{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "b3266a7b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from geneformer import MTLClassifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e12ac9f",
"metadata": {},
"outputs": [],
"source": [
"# Define paths\n",
"pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
"train_path = \"/path/to/train/data.dataset\"\n",
"val_path = \"/path/to/val/data.dataset\"\n",
"test_path = \"/path/to/test/data.dataset\"\n",
"results_dir = \"/path/to/results/directory\"\n",
"model_save_path = \"/path/to/model/save/path\"\n",
"tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
"\n",
"# Define tasks and hyperparameters\n",
"# task_columns should be a list of column names from your dataset\n",
"# Each column represents a specific classification task (e.g. cell type, disease state)\n",
"task_columns = [\"cell_type\", \"disease_state\"] # Example task columns"
]
},
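{
"cell_type": "code",
"execution_count": null,
"id": "a7d41c0e",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check, not part of the Geneformer API: warn early if any\n",
"# configured path does not exist yet, using plain os.path calls. With the\n",
"# placeholder paths above this will print warnings until real paths are set.\n",
"for p in [pretrained_path, train_path, val_path, test_path, results_dir]:\n",
"    if not os.path.exists(p):\n",
"        print(f\"Warning: path not found: {p}\")"
]
},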
{
"cell_type": "code",
"execution_count": null,
"id": "c9bd7562",
"metadata": {},
"outputs": [],
"source": [
"# Check GPU environment\n",
"num_gpus = torch.cuda.device_count()\n",
"use_distributed = num_gpus > 1\n",
"print(f\"Number of GPUs detected: {num_gpus}\")\n",
"print(f\"Using distributed training: {use_distributed}\")\n",
"\n",
"# Set environment variables for distributed training when multiple GPUs are available\n",
"if use_distributed:\n",
" os.environ[\"MASTER_ADDR\"] = \"localhost\" # hostname\n",
" os.environ[\"MASTER_PORT\"] = \"12355\" # Choose an available port\n",
" print(\"Distributed environment variables set.\")"
]
},
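{
"cell_type": "code",
"execution_count": null,
"id": "c2e91b47",
"metadata": {},
"outputs": [],
"source": [
"# Optional: list each visible GPU's name and memory. Standard torch.cuda\n",
"# calls, independent of Geneformer; the loop is a no-op on CPU-only machines.\n",
"for i in range(num_gpus):\n",
"    props = torch.cuda.get_device_properties(i)\n",
"    print(f\"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GiB\")"
]
},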
{
"cell_type": "code",
"execution_count": null,
"id": "b6ff3618",
"metadata": {},
"outputs": [],
"source": [
"#Define Hyperparameters for Optimization\n",
"hyperparameters = {\n",
" \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
" \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
" \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
" \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
" \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
" \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
"}"
]
},
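{
"cell_type": "code",
"execution_count": null,
"id": "e3b8f2a6",
"metadata": {},
"outputs": [],
"source": [
"# For reference only: entries of the shape above correspond roughly to\n",
"# Optuna's suggest API as sketched here. MTLClassifier consumes the dict\n",
"# itself during run_optuna_study; this helper is illustrative and is not\n",
"# part of Geneformer.\n",
"def suggest_from_spec(trial, name, spec):\n",
"    if spec[\"type\"] == \"float\":\n",
"        return trial.suggest_float(name, spec[\"low\"], spec[\"high\"], log=spec.get(\"log\", False))\n",
"    if spec[\"type\"] == \"categorical\":\n",
"        return trial.suggest_categorical(name, spec[\"choices\"])\n",
"    raise ValueError(f\"Unsupported hyperparameter type: {spec['type']}\")"
]
},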
{
"cell_type": "code",
"execution_count": null,
"id": "f665c5a7",
"metadata": {},
"outputs": [],
"source": [
"mc = MTLClassifier(\n",
" task_columns=task_columns, # Our defined classification tasks\n",
" study_name=\"MTLClassifier_distributed\",\n",
" pretrained_path=pretrained_path,\n",
" train_path=train_path,\n",
" val_path=val_path,\n",
" test_path=test_path,\n",
" model_save_path=model_save_path,\n",
" results_dir=results_dir,\n",
" tensorboard_log_dir=tensorboard_log_dir,\n",
" hyperparameters=hyperparameters,\n",
" # Distributed training parameters\n",
" distributed_training=use_distributed, # Enable distributed training if multiple GPUs available\n",
" master_addr=\"localhost\" if use_distributed else None,\n",
" master_port=\"12355\" if use_distributed else None,\n",
" # Other training parameters\n",
" n_trials=15, # Number of trials for hyperparameter optimization\n",
" epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
" batch_size=8, # Adjust based on available GPU memory\n",
" gradient_accumulation_steps=4, # Accumulate gradients over multiple steps\n",
" gradient_clipping=True, # Enable gradient clipping for stability\n",
" max_grad_norm=1.0, # Set maximum gradient norm\n",
" seed=42\n",
")"
]
},
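{
"cell_type": "code",
"execution_count": null,
"id": "9d4a6c13",
"metadata": {},
"outputs": [],
"source": [
"# Back-of-the-envelope check, assuming standard data-parallel scaling: the\n",
"# effective batch size per optimizer step is batch_size *\n",
"# gradient_accumulation_steps * number of GPUs (the literals mirror\n",
"# batch_size=8 and gradient_accumulation_steps=4 above).\n",
"effective_batch_size = 8 * 4 * max(num_gpus, 1)\n",
"print(f\"Effective batch size per optimizer step: {effective_batch_size}\")"
]
},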
{
"cell_type": "code",
"execution_count": null,
"id": "f69f7b6a",
"metadata": {},
"outputs": [],
"source": [
"# Run Hyperparameter Optimization with Distributed Training\n",
"if __name__ == \"__main__\":\n",
" # This guard is required for distributed training to prevent\n",
" # infinite subprocess spawning when using torch.multiprocessing\n",
" mc.run_optuna_study()"
]
},
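{
"cell_type": "code",
"execution_count": null,
"id": "5f2c8e0b",
"metadata": {},
"outputs": [],
"source": [
"# If subprocess spawning misbehaves inside Jupyter, a common fallback is to\n",
"# export this notebook to a plain script and launch it from the shell so the\n",
"# __main__ guard behaves as in a normal Python program. Generic Jupyter\n",
"# tooling, not Geneformer-specific; replace the placeholder with this\n",
"# notebook's actual filename.\n",
"# !jupyter nbconvert --to script /path/to/this/notebook.ipynb\n",
"# !python /path/to/this/notebook.py"
]
},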
{
"cell_type": "code",
"execution_count": null,
"id": "3affd5dd",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the Model on Test Data\n",
"if __name__ == \"__main__\":\n",
" mc.load_and_evaluate_test_model()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "bio",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}