{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3266a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from geneformer import MTLClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e12ac9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define paths\n",
    "pretrained_path = \"/path/to/pretrained/Geneformer/model\"\n",
    "# Input data is tokenized rank value encodings generated by the Geneformer tokenizer\n",
    "# (see tokenizing_scRNAseq_data.ipynb)\n",
    "train_path = \"/path/to/train/data.dataset\"\n",
    "val_path = \"/path/to/val/data.dataset\"\n",
    "test_path = \"/path/to/test/data.dataset\"\n",
    "results_dir = \"/path/to/results/directory\"\n",
    "model_save_path = \"/path/to/model/save/path\"\n",
    "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
    "\n",
    "# Define tasks\n",
    "# task_columns should be a list of column names from your dataset\n",
    "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
    "task_columns = [\"cell_type\", \"disease_state\"]  # Example task columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9bd7562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check GPU environment\n",
    "num_gpus = torch.cuda.device_count()\n",
    "use_distributed = num_gpus > 1\n",
    "print(f\"Number of GPUs detected: {num_gpus}\")\n",
    "print(f\"Using distributed training: {use_distributed}\")\n",
    "\n",
    "# Address and port for distributed training. Defined once so the environment\n",
    "# variables set here and the MTLClassifier arguments below cannot drift apart.\n",
    "master_addr = \"localhost\"  # hostname\n",
    "master_port = \"12355\"  # Choose an available port\n",
    "\n",
    "# Set environment variables for distributed training when multiple GPUs are available\n",
    "if use_distributed:\n",
    "    os.environ[\"MASTER_ADDR\"] = master_addr\n",
    "    os.environ[\"MASTER_PORT\"] = master_port\n",
    "    print(\"Distributed environment variables set.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6ff3618",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Hyperparameters for Optimization\n",
    "hyperparameters = {\n",
    "    \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
    "    \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
    "    \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
    "    \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
    "    \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
    "    \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f665c5a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "mc = MTLClassifier(\n",
    "    task_columns=task_columns,  # Our defined classification tasks\n",
    "    study_name=\"MTLClassifier_distributed\",\n",
    "    pretrained_path=pretrained_path,\n",
    "    train_path=train_path,\n",
    "    val_path=val_path,\n",
    "    test_path=test_path,\n",
    "    model_save_path=model_save_path,\n",
    "    results_dir=results_dir,\n",
    "    tensorboard_log_dir=tensorboard_log_dir,\n",
    "    hyperparameters=hyperparameters,\n",
    "    # Distributed training parameters\n",
    "    distributed_training=use_distributed,  # Enable distributed training if multiple GPUs available\n",
    "    # Reuse the address/port defined in the GPU environment cell\n",
    "    master_addr=master_addr if use_distributed else None,\n",
    "    master_port=master_port if use_distributed else None,\n",
    "    # Other training parameters\n",
    "    n_trials=15,  # Number of trials for hyperparameter optimization\n",
    "    epochs=1,  # Number of training epochs (1 suggested to prevent overfitting)\n",
    "    batch_size=8,  # Adjust based on available GPU memory\n",
    "    gradient_accumulation_steps=4,  # Accumulate gradients over multiple steps\n",
    "    gradient_clipping=True,  # Enable gradient clipping for stability\n",
    "    max_grad_norm=1.0,  # Set maximum gradient norm\n",
    "    seed=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f69f7b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run Hyperparameter Optimization with Distributed Training\n",
    "if __name__ == \"__main__\":\n",
    "    # This guard is required for distributed training to prevent\n",
    "    # infinite subprocess spawning when using torch.multiprocessing\n",
    "    mc.run_optuna_study()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3affd5dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the Model on Test Data\n",
    "if __name__ == \"__main__\":\n",
    "    mc.load_and_evaluate_test_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bio",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}