In [None]:
import os
import torch
from geneformer import MTLClassifier

In [None]:
# --- Inputs (placeholder paths; replace before running) ---
# Location of the pretrained Geneformer model.
pretrained_path = "/path/to/pretrained/Geneformer/model"

# The dataset splits must already be tokenized into rank value encodings by the
# Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb).
train_path = "/path/to/train/data.dataset"
val_path = "/path/to/val/data.dataset"
test_path = "/path/to/test/data.dataset"

# --- Outputs ---
results_dir = "/path/to/results/directory"
model_save_path = "/path/to/model/save/path"
tensorboard_log_dir = "/path/to/tensorboard/log/dir"

# --- Classification tasks ---
# One entry per task. Each name must be a column of the dataset, and each column
# is treated as its own classification task (e.g. cell type, disease state).
task_columns = ["cell_type", "disease_state"]  # example columns

In [None]:
# Count the CUDA devices visible to this process; more than one switches the
# run into distributed mode.
num_gpus = torch.cuda.device_count()
use_distributed = num_gpus > 1
print(f"Number of GPUs detected: {num_gpus}")
print(f"Using distributed training: {use_distributed}")

# The rendezvous address/port are exported only in the multi-GPU case.
# The port must be free on this host; change it if 12355 is already taken.
if use_distributed:
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})
    print("Distributed environment variables set.")

In [None]:
# Search space handed to the hyperparameter optimization study.
# Every entry names a tunable parameter and describes how it is sampled:
# "float" entries give inclusive-looking low/high bounds, "categorical" entries
# list the allowed choices.
hyperparameters = {
    # NOTE(review): "log": True presumably requests log-scale sampling — confirm.
    "learning_rate": {
        "type": "float",
        "low": 1e-5,
        "high": 1e-3,
        "log": True,
    },
    "warmup_ratio": {
        "type": "float",
        "low": 0.005,
        "high": 0.01,
    },
    "weight_decay": {
        "type": "float",
        "low": 0.01,
        "high": 0.1,
    },
    "dropout_rate": {
        "type": "float",
        "low": 0.0,
        "high": 0.7,
    },
    # Only one scheduler is offered, so this dimension is effectively fixed.
    "lr_scheduler_type": {
        "type": "categorical",
        "choices": ["cosine"],
    },
    # NOTE(review): presumably the range for each task's loss weight — confirm.
    "task_weights": {
        "type": "float",
        "low": 0.1,
        "high": 2.0,
    },
}

In [None]:
# Build the multi-task classifier. Arguments are grouped by concern; every value
# comes from the configuration cells above.
mc = MTLClassifier(
    # --- Study identity and tasks ---
    study_name="MTLClassifier_distributed",
    task_columns=task_columns,  # one classification task per dataset column
    # --- Inputs ---
    pretrained_path=pretrained_path,
    train_path=train_path,
    val_path=val_path,
    test_path=test_path,
    # --- Outputs ---
    results_dir=results_dir,
    model_save_path=model_save_path,
    tensorboard_log_dir=tensorboard_log_dir,
    # --- Hyperparameter search ---
    hyperparameters=hyperparameters,
    n_trials=15,  # number of optimization trials
    # --- Distributed training (active only when several GPUs were detected) ---
    distributed_training=use_distributed,
    # Address/port mirror the MASTER_ADDR / MASTER_PORT values exported earlier;
    # both are None for single-device runs.
    master_addr="localhost" if use_distributed else None,
    master_port="12355" if use_distributed else None,
    # --- Training loop ---
    epochs=1,  # a single epoch is suggested to limit overfitting
    batch_size=8,  # tune to the available GPU memory
    gradient_accumulation_steps=4,  # gradients are accumulated over this many steps
    gradient_clipping=True,  # clip gradients for training stability
    max_grad_norm=1.0,  # maximum gradient norm
    seed=42,
)

In [None]:
# Launch the Optuna hyperparameter search (distributed when `mc` was configured
# with distributed_training=True).
# Keep the call behind the __main__ guard: with torch.multiprocessing, an
# unguarded launch can be re-executed by the spawned workers and lead to
# endless subprocess creation.
if __name__ == "__main__":
    mc.run_optuna_study()

In [None]:
# Load the model and evaluate it on the test data.
# NOTE(review): presumably this expects the study above to have finished and
# saved a model under model_save_path — confirm before running this cell alone.
if __name__ == "__main__":
    mc.load_and_evaluate_test_model()