
Small Scale Training

UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_train,ar]' loader.batch_size=8 wandb.name='10_11_ar' trainer.val_check_interval=100 debug=true

Caching

UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_caching_train]' loader.batch_size=8 wandb.name='10_11_ar' trainer.val_check_interval=100 debug=true

Large Scale Training

To train the large scale experiments, we recommend setting the following environment variables:

unset CUDA_VISIBLE_DEVICES; unset CUDA_LAUNCH_BLOCKING; unset NCCL_SOCKET_IFNAME; unset NCCL_NSOCKS_PERTHREAD; unset NCCL_SOCKET_NTHREADS; unset OMP_NUM_THREADS; unset NCCL_P2P_DISABLE; unset NCCL_P2P_LEVEL

export NCCL_DEBUG=INFO
export PYTHONUNBUFFERED=1
export UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1
export UNIDISC_ROOT_OUTPUT_DIR="outputs"

We have several stages of training. Each stage must load from the previous stage's checkpoint, either by setting trainer.load_from_state_dict or by setting HYDRA_RUN_DIR_NAME, which automatically loads the latest checkpoint/optimizer/sampler states.

1st stage:

accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train]' debug=true loader.batch_size=8

2nd stage:

accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res]' debug=true loader.batch_size=4

3rd stage:

accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res_interleaved]' debug=true loader.batch_size=1
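
For example, to launch the 2nd stage from the 1st stage's weights only (a sketch; the checkpoint path below is a placeholder, point it at the weights written by your 1st-stage run, or set HYDRA_RUN_DIR_NAME instead to also restore the optimizer/sampler states):

accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res]' debug=true loader.batch_size=4 trainer.load_from_state_dict="/path/to/stage1_weights.bin"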

SLURM Training

To train on SLURM, you have two options:

  1. Run sbatch on the script directly: sbatch scripts/train_large_scale_slurm.sh. Please edit the script to set the correct number of nodes, GPUs per node, and other parameters.

  2. Use sbatch through hydra_submitit_launcher. This is recommended for smaller experiments, particularly hyperparameter sweeps, although it has also been set up to work for multi-node training.

First, uncomment everything under hydra.launcher in configs/config.yaml. Then, uncomment the following line at the top of configs/config.yaml:

- override hydra/launcher: submitit_slurm

Next, run the following (change to uv pip install if you are using uv):

pip install 'git+ssh://[email protected]/alexanderswerdlow/hydra.git@working_ci#egg=hydra-core'
pip install 'git+ssh://[email protected]/alexanderswerdlow/hydra.git@working_ci#egg=hydra-submitit-launcher&subdirectory=plugins/hydra_submitit_launcher'

To use hydra_submitit_launcher, append the following to any command:

devices=8 nodes=1 partition=general --multirun

You may modify devices/nodes/partition based on your setup. See hydra.launcher in configs/config.yaml to set additional SLURM parameters.
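
For example, appending these overrides to the 1st-stage command submits it through SLURM (a sketch, assuming the launcher config in configs/config.yaml has been uncommented as described above):

accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train]' debug=true loader.batch_size=8 devices=8 nodes=1 partition=general --multirun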

Both of the above methods use 1 task per node. Some SLURM clusters prefer to use 1 task per GPU. Please see this amazing guide for more details: https://github.com/stas00/ml-engineering. In short, be careful when making this change, as there are many subtle differences (e.g., passing signals between processes, how checkpointing/requeuing/error handling works, etc.).
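
For reference, this is a sketch of the kind of #SBATCH header you would edit in scripts/train_large_scale_slurm.sh for option 1; all values below are placeholders and the actual script may differ:

#SBATCH --nodes=2              # number of nodes
#SBATCH --ntasks-per-node=1    # both methods above use 1 task per node
#SBATCH --gpus-per-node=8      # GPUs per node
#SBATCH --cpus-per-task=64     # CPUs available for dataloader workers
#SBATCH --partition=general    # SLURM partition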

TPU Training

TODO: Add more documentation for TPU training. Our codebase is set up to use TPUs through torch_xla, taking advantage of SPMD.

Misc TPU Notes:

  • SPMD has a very confusing setup and pretends that each node is a single device. See decoupled_utils.py for some of the logic used to handle this. Moreover, getting the proper rank can only be done after spawning SPMD, so we need to handle this carefully, as a lot of code needs the device rank at import time.

Attention Kernels

The codebase currently supports the following:

  • PyTorch SDPA (Including CUDNN Flash, Regular Flash, etc.)
  • PyTorch FlexAttention (For interleaved/caching training)
  • Flash Attention 2/3 (With varying kernels, e.g., packed/non-packed/varlen depending on the use case)
  • TorchXLA SPMD FlashAttention (For TPU training)

Generally, we use PyTorch SDPA (preferably the cuDNN kernel, which you can force with UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1) and FlexAttention for all interleaved/caching training, setting trainer.compile=true to improve MFU. We found this to be similar in speed to Flash Attention 2, which at the time of development did not have good compile support.
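
For example, combining these settings with the small-scale command above (a sketch; adjust the experiment and batch size to your run):

UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_train,ar]' loader.batch_size=8 trainer.compile=true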

Config Notes

  • data.enable_cuda_in_tensordict_collate requires loader.num_workers = 0
  • Enable data.move_tensordict_to_shm to speed up dataloading (keeping data.keep_tensordict_on_disk = true), assuming you have enough system memory. Separately, you can disable data.keep_tensordict_on_disk, but this will load the entire tensordict into each dataloader worker process (e.g., given 2 GPUs and 4 workers, this will load 8 tensordicts into system memory), which is not feasible on larger datasets. Optionally, you can set +data.shm_path='/path/to/data' to use a custom path, e.g., a scratch disk instead of system memory (see the sketch after this list).
  • You may need to set NCCL_IB_DISABLE=1 or NCCL_P2P_DISABLE=1 depending on the system configuration. Setting NCCL_P2P_LEVEL=NVL is recommended if the system has NVLink.
  • To use data.enable_cuda_in_tensordict_collate=true, you must also set data.force_mp_spawn=false and loader.num_workers>0.
  • Resuming from checkpoint can be done in multiple ways. Double check the log output to verify the correct checkpoint is being loaded. The tl;dr is the following: If you use hydra_submitit_launcher or set HYDRA_RUN_DIR_NAME, it will automatically load the latest checkpoint/optimizer/sampler states.
  • To resume from the weights only, set trainer.load_from_state_dict="/path/to/weights.bin"
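
As a concrete sketch of the tensordict options above (paths are placeholders; adjust the experiment and batch size to your run):

accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train]' debug=true loader.batch_size=8 data.keep_tensordict_on_disk=true data.move_tensordict_to_shm=true +data.shm_path='/path/to/data'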