File size: 5,751 Bytes
131da64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
## Small Scale Training
UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_train,ar]' loader.batch_size=8 wandb.name='10_11_ar' trainer.val_check_interval=100 debug=true

### Caching
UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1 accelerate launch --main_process_port=$RANDOM main.py +experiments='[small_scale_caching_train]' loader.batch_size=8 wandb.name='10_11_ar' trainer.val_check_interval=100 debug=true

## Large Scale Training

To train the large scale experiments, we recommend to set the following environment variables:
```
unset CUDA_VISIBLE_DEVICES; unset CUDA_LAUNCH_BLOCKING; unset NCCL_SOCKET_IFNAME; unset NCCL_NSOCKS_PERTHREAD; unset NCCL_SOCKET_NTHREADS; unset OMP_NUM_THREADS; unset NCCL_P2P_DISABLE; unset NCCL_P2P_LEVEL

export NCCL_DEBUG=INFO
export PYTHONUNBUFFERED=1
export UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1
export UNIDISC_ROOT_OUTPUT_DIR="outputs"
```

We have several stages of training. You will have to load from the previous stage's checkpoint, using either `trainer.load_from_state_dict` or by setting `HYDRA_RUN_DIR_NAME` which will automatically load the latest checkpoint/optimizer/sampler states.

1st stage:
```
accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train]' debug=true loader.batch_size=8
```

2nd stage:
```
accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res]' debug=true loader.batch_size=4
```

3rd stage:
```
accelerate launch --main_process_port=$RANDOM main.py +experiments='[large_scale_train,large_scale_train_high_res_interleaved]' debug=true loader.batch_size=1
```

## SLURM Training
To train on SLURM, you have two options:

1. sbatch directly from the script: `sbatch scripts/train_large_scale_slurm.sh`. Please edit the script to set the correct number of nodes, gpus per node, and other parameters.

2. Use sbatch through hydra_submitit_launcher. This is recommended for smaller experiments, partially hyperparameter sweeps, although it has been setup to work for multi-node training.

First, uncomment everything under `hydra.launcher` in `configs/config.yaml`. Next, uncomment this at the top of `configs/config.yaml`:
```
- override hydra/launcher: submitit_slurm
```

Next, run the following (change to `uv pip install` if you are using uv):

```
pip install 'git+ssh://[email protected]/alexanderswerdlow/hydra.git@working_ci#egg=hydra-core'
pip install 'git+ssh://[email protected]/alexanderswerdlow/hydra.git@working_ci#egg=hydra-submitit-launcher&subdirectory=plugins/hydra_submitit_launcher'
```

To use hydra_submitit_launcher, append the following to any command:

`devices=8 nodes=1 partition=general --multirun`

You may modify `devices`/`nodes`/`partition` based on your set. See `hydra.launcher` in `configs/config.yaml` to set additional SLURM parameters.


Both of the above methods use 1 task per node. Some SLURM clusters prefer to use 1 task per GPU. Please see this amazing guide for more details: `https://github.com/stas00/ml-engineering`. In short, be careful in making this change as there are many subtle differences (e.g., passing signals between processes, how checkpointing/requeing/error handling works, etc.)

## TPU Training

TODO: Add more documentation for TPU training. Our codebase is setup to use TPUs through `torch_xla`, taking advantage of SPMD.

Misc TPU Notes:
- SPMD has a very confusing setup for SPMD and pretends that each node is a single device. See `decoupled_utils.py` for some of the logic used to handle this. Moreover, getting the proper rank can only be done after spawning SPMD, so we need to handle this as a lot of code needs the device rank on import.

## Attention Kernels
The codebase currently supports the following:

- PyTorch SDPA (Including CUDNN Flash, Regular Flash, etc.)
- PyTorch FlexAttention (For interleaved/caching training)
- Flash Attention 2/3 (With varying kernels, e.g., packed/non-packed/varlen depending on the use case)
- TorchXLA SPMD FlashAttention (For TPU training)

Generally, we use PyTorch SDPA (preferrably the CUDDN Kernel which you can force with `UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=1`) and FlexAttention for all interleaved/caching training, setting `trainer.compile=true` to improve MFU. We found this to be similar in speed to Flash Attention 2, which at the time of development did not have good compile support.

# Config Notes
- `data.enable_cuda_in_tensordict_collate` requires `loader.num_workers = 0`
- Enable `data.move_tensordict_to_shm` to speed up dataloading (keeping `data.keep_tensordict_on_disk = true`), assuming you have enough system memory. Separately, you can disable `data.keep_tensordict_on_disk`, but this will load the entire tensordict into each dataloader worker process (e.g., given 2 GPUs and 4 workers, this will load 8 tensordicts into system memory) which is not possible on larger datasets. Optionally, you can set `+data.shm_path='/path/to/data'` to use a custom path, e.g., to use a scratch disk instead of system memory.
- You may need to set `NCCL_IB_DISABLE=1` or `NCCL_P2P_DISABLE=1` depending on the system configuration. Setting `NCCL_P2P_LEVEL=NVL` is recommended if the system has NVLink.
- To use `data.enable_cuda_in_tensordict_collate=true`, you must also set `data.force_mp_spawn=false` and `loader.num_workers>0`.
- Resuming from checkpoint can be done in multiple ways. Double check the log output to verify the correct checkpoint is being loaded. The tl;dr is the following: If you use `hydra_submitit_launcher` or set `HYDRA_RUN_DIR_NAME`, it will automatically load the latest checkpoint/optimizer/sampler states.
- To resume from weights only set: `trainer.load_from_state_dict="/path/to/weights.bin"`