aswerdlow committed
Commit 131da64 · 0 Parent(s)

Initial commit

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +37 -0
  3. .gitmodules +15 -0
  4. Dockerfile +79 -0
  5. README.md +82 -0
  6. __builtins__.pyi +7 -0
  7. configs/config.yaml +451 -0
  8. configs/config_empty.yaml +8 -0
  9. configs/experiments/ar.yaml +10 -0
  10. configs/experiments/elm.yaml +15 -0
  11. configs/experiments/eval_model.yaml +21 -0
  12. configs/experiments/eval_text.yaml +26 -0
  13. configs/experiments/eval_text_only.yaml +30 -0
  14. configs/experiments/eval_unified.yaml +27 -0
  15. configs/experiments/fid_cc12m.yaml +22 -0
  16. configs/experiments/fid_datacomp1b.yaml +22 -0
  17. configs/experiments/fid_hf.yaml +25 -0
  18. configs/experiments/jan_cub.yaml +51 -0
  19. configs/experiments/large_maskdit_exp.yaml +7 -0
  20. configs/experiments/large_scale_high_res_interleaved_inference.yaml +51 -0
  21. configs/experiments/large_scale_train.yaml +151 -0
  22. configs/experiments/large_scale_train_high_res.yaml +39 -0
  23. configs/experiments/large_scale_train_high_res_inference.yaml +30 -0
  24. configs/experiments/large_scale_train_high_res_interleaved.yaml +105 -0
  25. configs/experiments/maskgit.yaml +6 -0
  26. configs/experiments/master_eval.yaml +49 -0
  27. configs/experiments/mscoco_fid.yaml +21 -0
  28. configs/experiments/paired_standalone_fid_eval.yaml +29 -0
  29. configs/experiments/small_scale_train.yaml +187 -0
  30. configs/experiments/small_scale_train_caching.yaml +186 -0
  31. configs/experiments/small_text_only.yaml +28 -0
  32. configs/experiments/standalone_fid_eval.yaml +18 -0
  33. configs/experiments/titok.yaml +8 -0
  34. configs/experiments/titok_sl256.yaml +7 -0
  35. configs/experiments/txt_only.yaml +21 -0
  36. configs/experiments/unified.yaml +23 -0
  37. configs/experiments/vq16.yaml +9 -0
  38. configs/experiments/vq16_1024.yaml +8 -0
  39. configs/experiments/vq16_magvit.yaml +9 -0
  40. configs/experiments/vq16_t2i.yaml +10 -0
  41. configs/experiments/webdataset.yaml +12 -0
  42. configs/experiments/zero_shot_eval.yaml +29 -0
  43. configs/lr_scheduler/constant_warmup.yaml +2 -0
  44. configs/lr_scheduler/constant_warmup_cosine_decay.yaml +3 -0
  45. configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
  46. configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml +4 -0
  47. configs/model/extra_large.yaml +10 -0
  48. configs/model/large.yaml +14 -0
  49. configs/model/medium.yaml +12 -0
  50. configs/model/small-ar.yaml +11 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,37 @@
+ __pycache__/
+ outputs/
+ ckpts/
+ vqgan/vqgan_pretrained/
+ vqgan/vqgan_taming_ckpt/
+ data/
+ models/datasets/.cache/
+ *.json
+ output/
+ tmp*
+ multirun/
+ .nfs*
+ lightning_logs/
+ static/
+ archive/
+ output_profile/
+ logs/
+ .history/
+ .cache/
+ output*/
+ *.out
+ *.parquet
+ wandb/
+ vqgan/
+ *.csv
+ .python-version
+ ft_cache/
+ alias.txt
+ env.sh
+ generated_image.png
+ Untitled-1.ipynb
+ *.log
+ demo/old
+ *.pem
+ .sesskey
+ icons.py
+ generated/
.gitmodules ADDED
@@ -0,0 +1,15 @@
+ [submodule "third_party/LlamaGen"]
+   path = third_party/LlamaGen
+   url = https://github.com/alexanderswerdlow/LlamaGen.git
+   branch = wip_v1
+ [submodule "third_party/Lumina-mGPT"]
+   path = third_party/Lumina-mGPT
+   url = https://github.com/alexanderswerdlow/Lumina-mGPT.git
+   branch = non_causal
+ [submodule "third_party/Show-o"]
+   path = third_party/Show-o
+   url = https://github.com/showlab/Show-o.git
+ [submodule "third_party/1d-tokenizer"]
+   path = third_party/1d-tokenizer
+   url = https://github.com/bytedance/1d-tokenizer.git
+   branch = main
Dockerfile ADDED
@@ -0,0 +1,79 @@
+ # Base image with CUDA 12.6.3 and cuDNN
+ FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
+
+ # Set environment variables
+ ARG DEBIAN_FRONTEND=noninteractive
+ ENV PYTHONUNBUFFERED=1 \
+     SYSTEM=spaces \
+     AM_I_IN_A_DOCKER_CONTAINER=Yes \
+     PYTHONPATH=/home/appuser/app \
+     HF_HOME=/home/appuser/.cache \
+     TORCH_HOME=/home/appuser/.cache \
+     TMP_DIR=/home/appuser/tmp \
+     TRANSFORMERS_CACHE=/home/appuser/.cache/transformers \
+     NVIDIA_VISIBLE_DEVICES=all \
+     NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+ # Install system dependencies and set Python 3.10 as default
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     python3.10 \
+     python3.10-distutils \
+     python3-pip \
+     ffmpeg \
+     libsm6 \
+     libxext6 \
+     libgl1 \
+     git \
+     openssh-client \
+     && ln -sf /usr/bin/python3.10 /usr/bin/python \
+     && ln -sf /usr/bin/pip3 /usr/bin/pip \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ # Install `uv`
+ RUN pip install --upgrade pip \
+     && pip install uv
+
+ # Create a non-root user
+ RUN useradd -m -u 1000 appuser
+
+ # Set working directory
+ WORKDIR /home/appuser/app
+
+ # Copy dependency files and install dependencies
+ COPY --chown=appuser pyproject.toml uv.lock README.md ./
+ RUN mkdir -p -m 0600 ~/.ssh && ssh-keyscan github.com >> ~/.ssh/known_hosts
+
+ RUN --mount=type=ssh uv sync --no-group dev
+ RUN --mount=type=ssh uv sync --frozen --no-cache \
+     && chown -R appuser:appuser /home/appuser/app/.venv \
+     && rm -rf /root/.cache /home/appuser/.cache
+
+ # Ensure non-root user has write access to cache and tmp directories
+ RUN mkdir -p /home/appuser/.cache/transformers /home/appuser/tmp /home/appuser/.cache \
+     && chown -R appuser:appuser /home/appuser/.cache /home/appuser/tmp/ /home/appuser/app/
+
+ RUN chmod -R 777 /tmp
+
+ # Copy application code
+ COPY --chown=appuser demo demo
+ COPY --chown=appuser unidisc unidisc
+ COPY --chown=appuser models models
+ COPY --chown=appuser configs configs
+ COPY --chown=appuser third_party third_party
+ COPY --chown=appuser ckpts ckpts
+ COPY --chown=appuser ./__* ./
+ COPY --chown=appuser ./*.py ./
+ COPY --chown=appuser ./archive/pytorch_model_fsdp.bin ./
+
+ # Switch to non-root user
+ USER appuser
+
+ # Expose port for Gradio
+ EXPOSE 5003
+
+ # Command to run the application
+ CMD ["bash", "demo/demo.sh"]
+
+ # DOCKER_BUILDKIT=1 docker build --ssh default --network=host -t unidisc .
+ # docker run --network=host -it -p 5003:5003 unidisc
README.md ADDED
@@ -0,0 +1,82 @@
+ <div align="center">
+ <br>
+ <img src="docs/images/banner.webp" width="1000">
+ <h3>Unified Multimodal Discrete Diffusion</h3>
+
+ [Alexander Swerdlow](https://aswerdlow.com/)<sup>1&#42;</sup>&nbsp;
+ [Mihir Prabhudesai](https://mihirp1998.github.io/)<sup>1&#42;</sup>&nbsp;
+ [Siddharth Gandhi](https://www.ssgandhi.com/)<sup>1</sup>&nbsp;
+ [Deepak Pathak](https://www.cs.cmu.edu/~dpathak/)<sup>1</sup>&nbsp;
+ [Katerina Fragkiadaki](https://www.cs.cmu.edu/~katef/)<sup>1</sup>&nbsp;
+ <br>
+
+ <sup>1</sup> Carnegie Mellon University&nbsp;
+
+ [![ArXiv](https://img.shields.io/badge/ArXiv-<0000.00000>-<COLOR>.svg)](https://arxiv.org/pdf/0000.00000) [![Webpage](https://img.shields.io/badge/Webpage-UniDisc-<COLOR>.svg)](https://unidisc.github.io/)
+
+ <!-- [![Demo](https://img.shields.io/badge/Demo-Custom-<COLOR>.svg)](https://huggingface.co/spaces/todo) -->
+
+ </div>
+
+ ## Hugging Face models and annotations
+
+ The UniDisc checkpoints are available on [Hugging Face](https://huggingface.co/unidisc):
+ * [unidisc/todo](https://huggingface.co/unidisc/todo)
+
+ ## Getting Started
+
+ To install the dependencies, run:
+ ```bash
+ git submodule update --init --recursive
+ uv sync --no-group dev
+ uv sync
+ ```
+
+ For a more detailed installation guide, please refer to [INSTALL.md](docs/INSTALL.md).
+
+ ## Training
+
+ See [TRAIN.md](docs/TRAIN.md) for details.
+
+ ## Inference
+
+ <!-- Inference demo for **TODO**.
+ ```
+ TODO
+ ``` -->
+ <!-- <img src="docs/todo.png" width="1000"> -->
+
+ Interactive demo for **TODO**.
+ ```
+ python demo/server.py
+ python demo/client_simple_fasthtml.py
+ ```
+
+ ## Evaluation
+
+ See [EVAL.md](docs/EVAL.md) for details.
+
+ ### Citation
+ To cite our work, please use the following:
+ ```
+ @article{TODO,
+ title={TODO},
+ author={TODO},
+ journal={arXiv preprint arXiv:TODO},
+ year={TODO}
+ }
+ ```
+
+ ## Credits
+
+ This repository is built on top of the following repositories:
+
+ - [MDLM](https://github.com/kuleshov-group/mdlm)
+ - [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X)
__builtins__.pyi ADDED
@@ -0,0 +1,7 @@
+ from ipdb import set_trace as st
+ from decoupled_utils import start_timing as start_timing
+ from decoupled_utils import end_timing as end_timing
+ ENABLE_TIMING: bool
+ ENABLE_TIMING_SYNC: bool
+ DEVICE_BACKEND_TYPE: str
+ exists = lambda v: v is not None
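Note: the `.pyi` stub above only informs the type checker that `st`, the timing helpers, and the flags exist as ambient builtins; something must still create them at runtime. A minimal sketch of how such names are typically injected at startup, assuming `decoupled_utils` exports `start_timing`/`end_timing` as the stub suggests (the concrete default values below are illustrative, not taken from this commit):

```python
# Hypothetical startup hook: inject the names declared in __builtins__.pyi
# into the builtins module so every module can use them without importing.
import builtins

from ipdb import set_trace as st
from decoupled_utils import start_timing, end_timing

builtins.st = st
builtins.start_timing = start_timing
builtins.end_timing = end_timing
builtins.ENABLE_TIMING = False          # illustrative default
builtins.ENABLE_TIMING_SYNC = False     # illustrative default
builtins.DEVICE_BACKEND_TYPE = "cuda"   # illustrative default
builtins.exists = lambda v: v is not None
```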
configs/config.yaml ADDED
@@ -0,0 +1,451 @@
+ defaults:
+   - _self_
+   - /model: small
+   - /noise: loglinear
+   - /lr_scheduler: constant_warmup
+   - /experiments: []
+   # - override hydra/launcher: submitit_slurm
+
+ slurm: False
+ debug: False
+ mode: train # train / eval
+ diffusion: absorbing_state
+ backbone: dit # dit / dimamba / ar
+ parameterization: subs # subs / d3pm / sedd
+ time_conditioning: False
+ T: 0 # 0 (continuous time) / 1000
+ subs_masking: False
+ seed: 42
+ profile: False
+ # These belong in trainer.* and hydra.launcher.* but are put here for CLI convenience
+ devices: ${device_count:}
+ nodes: 1
+ partition: ${find_partition:}
+ constraint: ${find_constraint:}
+ ckpt: null
+
+ loader:
+   desired_global_batch_size: 512
+   global_batch_size: null
+   eval_global_batch_size: ${.global_batch_size}
+   batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   eval_batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+   num_workers: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 16, 4)"}
+   pin_memory: True
+   persistent_workers: True
+
+ sampling:
+   predictor: ddpm_cache # analytic, ddpm, ddpm_cache
+   steps: 1000
+   max_sampling_steps: 500 # The highest level we use for sampling
+   noise_removal: True
+   num_sample_log: 2
+   semi_ar: False
+   stride_length: 1
+   num_strides: 1
+
+ eval:
+   checkpoint_path: '' # Used to evaluate a checkpoint after training.
+   disable_ema: False
+   compute_generative_perplexity: False
+   perplexity_batch_size: 8
+   gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
+   generate_samples: True
+   cfg: null
+   num_masking_viz_batches: 1
+   num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+   test_eval_speed: False
+   standalone_fid: False
+   visualize_data_only: false
+   val_with_train_data: false
+   max_num_fid_batches_per_device: null
+   class_conditional_fid: false
+   compute_entropy: false
+   compute_standalone_mauve: false
+   compute_standalone_entropy: false
+   compute_img_to_txt_mauve_clip: false
+   compute_img_to_txt_mauve_during_unconditional_fid: false
+   mauve_num_samples: 5000
+   mauve_divergence_curve_discretization_size: 25 # default in mauve repo
+   mauve_average_over_seeds: 3
+   mauve_scaling_factor: 5 # default in mauve repo
+   txt_conditional_fid: false
+   unconditional_fid: false
+   fid_mode: inline
+   calculate_clip_score: false
+   clean_fid_use_precomputed_stats: false
+   clean_fid_precomputed_name: null
+   clean_fid_precomputed_split: null
+   clean_fid_precomputed_res: null
+   attention_caching: false
+   set_random_gen_seed: false
+   compute_val_metrics_standalone: false
+   num_val_metrics_standalone_batches_per_device: ${eval:'max(${eval.num_val_metrics_standalone_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+   num_val_metrics_standalone_samples: -1
+   return_unweighed_sim: false
+   compute_chameleon_perplexity: false
+   global_disable_mauve: false
+   bypass_normal_validation: false
+   auto_enhance: false
+   num_auto_enhance_iter: 2
+   ar_inpainting_min_val: 0.5
+   ar_inpainting_max_val: 1.0
+   ar_inpainting_force_val: null
+
+ optim:
+   weight_decay: 0
+   lr: 3e-4
+   beta1: 0.9
+   beta2: 0.999
+   eps: 1e-8
+   fused: true
+
+ model:
+   use_custom_vae_config: false
+   use_custom_vae_ckpt: null
+   downscale_ratio: null
+   image_vocab_size: null
+   vae_type: null
+   use_attention_mask: false
+
+   cond_use_custom_vae_config: false
+   cond_use_custom_vae_ckpt: null
+   cond_downscale_ratio: null
+   cond_image_vocab_size: null
+   cond_vae_type: null
+   text_model: true
+
+   attn_type: flash
+   force_varlen_attn: false
+   force_cast_bf16: false
+   norm_type: layernorm
+   mup: false
+   qk_norm: false
+   distillation: false
+   force_argmax_valid_indices: false
+   use_flash_attn_3: false
+   use_spda_attn: false # Spelled wrong...
+   rope_2d: false
+   modality_embed: false
+   zero_linear_init: true
+   full_attention: true
+   use_lora: false
+   use_kv_cache: false
+   force_optimized_native_attn: false
+   use_pretrained_img_emb: true
+   use_flex_attention: false
+   add_labels: null
+   flex_attention_txt_masking_prob: null
+   flex_attention_img_masking_prob: null
+
+ trainer:
+   _target_: lightning.Trainer
+   accelerator: cuda
+   num_nodes: ${nodes}
+   devices: ${devices}
+
+   # Given a desired global batch size (e.g., how many samples we see before an optim.step, summed over all nodes/gpus/accum_steps), we find the number of gradient accumulations that gets us closest given our current configuration. We assume that loader.batch_size is the largest that can fit in a single fwd/bwd.
+   accumulate_grad_batches: ${find_grad_accum:${loader.desired_global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+   gradient_clip_val: 1.0
+   precision: 'bf16'
+   max_steps: 1_000_000_000
+
+   num_epochs: 1_000_000_000
+   optimizer_cls: adamw
+   set_grads_to_none: true
+   eval_on_start: true
+   eval_decay_steps: false
+   eval_epochs: null
+   ckpt_steps: 100000
+   fsdp: false
+   force_enable_checkpointing: false
+   limit_val_batches: null
+   ckpt_every_n_minutes: 60
+   ckpt_recent_timeout_minutes: 10
+   checkpoint_all_ranks: true
+   force_null_sigma: false
+
+   log_every_n_steps: 10
+   limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
+   val_check_interval: 100
+
+   ema: 0.9999
+   antithetic_sampling: True
+   importance_sampling: False
+   sampling_eps: 1e-3
+   change_of_variables: False
+   benchmark: true
+   backward_pass: true
+   forward_pass: true
+   profile_memory: false
+   pytorch_profile: false
+   nvtx_profile: false
+   custom_ddp_bf16: true
+   log_seperate_modal_losses: true
+   use_gradient_checkpointing: false
+   text_loss_weight: null
+   img_loss_weight: null
+   disable_strict_load: false
+   attach_oom_observer_eval: false
+   find_unused_parameters: false
+   restart_on_failure: false
+   skip_early_checkpointing: true
+   log_flops: true
+   sync_timing: false
+   use_custom_ema: false
+   scale_lr_by_batch_size: false
+   tpu_eager: false
+   allow_dynamic_nodes: false
+   force_disable_signal_handler: false
+   tpu_profile: false
+   tpu_cache: false
+   enable_jax_smi: false
+   tpu_compile_debug: false
+   xla_spmd: false
+   log_grad_norm: true
+   tpu_profile_markers: true
+   compile: false
+   disable_all_checkpointing: false
+   tpu_force_mark_step: false
+   ar_shift: false
+   ar_llm_loss: false
+   ar_print_loss: false
+   chameleon_z_loss: null
+   image_mode: discrete # continuous / discrete
+   chameleon_use_ce_loss: false
+   low_precision_loss: false
+   low_precision_params: false
+   scratch: false
+   use_spmd_distributed_checkpointing: null
+   use_simple_spmd_distributed_checkpointing: false
+   load_from_state_dict: null
+   load_from_optimizer_state_dict: null
+   multimodal_batches: false
+   sync_dataloader_timing: false
+   compile_flag_pos_emb: false
+   compile_fullgraph: false
+   compile_mode: max-autotune-no-cudagraphs
+   joint_ar_nar_prob: null
+   joint_ar_nar_prob_warmup_steps: null
+   joint_ar_nar_timestep_warmup_steps: null
+   spmd_mesh: null
+   detect_anomaly: false
+   freeze_chameleon_embeddings: false
+   ckpt_model_only: false
+   use_orig_params: null
+   disable_adjust_num_warmup_steps: false
+   mask_entire_modality: null
+   iterate_dataloader_only: false
+   force_bf16_eval: false
+   disable_all_eval_generation: false
+   debug_xla_sept: false
+   ignore_text_in_unified: false
+   allow_null_sigma: false
+   disable_forward_autocast_during_eval: false
+   viz_images_only: false
+   add_label: false
+   first_token_dropout: null
+   disable_ddp_optimizer: false
+   rand_flip_ar_prob: null
+   rand_ar_modality_dropout: null
+   use_linear_warmup_cosine_annealing: false
+   no_ce_weighting: false
+   interleaved: false
+   interleaved_training_flex_attention: false
+   awr: false
+   ar_inpainting: false
+
+ wandb:
+   entity: grads
+   project: ${eval:'"unidisc-debug" if ${debug} else "unidisc"'}
+   resume: ${eval:'"allow" if ${slurm} else None'}
+   id: null
+   group: null
+   job_type: null
+   name: null
+   tags:
+     - ${data.train}
+
+ checkpointing_root_dir: ${oc.env:UNIDISC_CHECKPOINTING_ROOT_DIR,null}
+ root_output_dir: ${oc.env:UNIDISC_ROOT_OUTPUT_DIR,outputs}
+ python_orig: |
+   accelerate launch \
+     --num_machines $SLURM_NNODES \
+     --num_processes $NUM_PROCESSES \
+     --rdzv_backend c10d \
+     --main_process_ip $MASTER_ADDR \
+     --main_process_port $MASTER_PORT \
+     --machine_rank $SLURM_PROCID \
+     --mixed_precision bf16 \
+     --dynamo_backend no \
+     --enable_cpu_affinity \
+     --max_restarts 0 \
+
+ mem_per_gpu: 40
+ cpus_per_gpu: 8
+ slurm_name: null
+ timeout_min: ${partition_limit:${partition}}
+ hydra:
+   run:
+     dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
+   sweep:
+     dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
+     subdir: ${hydra.job.id}
+   job:
+     chdir: true
+   # launcher:
+   #   name: ${get_slurm_name:}
+   #   # See https://hydra.cc/docs/configure_hydra/workdir/
+   #   submitit_folder: ${hydra.sweep.dir}/%j
+   #   nodes: ${nodes} # Number of nodes. This value is *per* node
+   #   mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
+   #   gpus_per_node: ${trainer.devices}
+   #   partition: ${partition}
+   #   constraint: ${constraint}
+   #   exclude: ${exclude_nodes:}
+
+   #   timeout_min: ${timeout_min}
+   #   max_num_timeout: 12 # Num requeue excluding pre-emptions
+   #   comment: aswerdlo
+   #   stderr_to_stdout: true
+
+   #   # Be careful with changing anything below.
+   #   # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
+   #   # see: https://github.com/huggingface/accelerate/issues/1918
+
+   #   # The accelerate launcher w/1 initial process and then spawn 1 per GPU
+   #   tasks_per_node: 1
+   #   cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
+   #   python: |
+   #     bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+
+   #   # python: "${getpythoncmd:}"
+   #   # tasks_per_node: ${devices}
+   #   # cpus_per_task: 8
+   #   # python: 'python'
+
+   #   python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
+   #   signal: 'B:USR2@360'
+   #   post_srun_commands:
+   #     - ''
+   #     - wait
+
+   #   srun_args:
+   #     - '--jobid $SLURM_JOB_ID'
+
+   #   setup:
+   #     - |
+   #       export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+   #       export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
+   #       export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
+   #       export NCCL_DEBUG=INFO
+   #       export NCCL_NSOCKS_PERTHREAD=4
+   #       export NCCL_SOCKET_NTHREADS=2
+   #       export OMP_NUM_THREADS=2
+   #       export PYTHONUNBUFFERED=1
+   #       export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
+   #       export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
+   #       export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
+   #       if [ -n "$SLURM_RESTART_COUNT" ]; then
+   #         export RESTART_COUNT=$SLURM_RESTART_COUNT
+   #       else
+   #         export RESTART_COUNT=0
+   #       fi
+   #       export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
+
+   #       mkdir -p $LOCAL_JOB_FOLDER
+   #       printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
+
+   #       echo "ibstatus: $(ibstatus)"
+   #       echo "ibdev2netdev: $(ibdev2netdev)"
+   #       echo "rdma device: $(rdma link)"
+   #       echo "environment: $(env | grep NCCL)"
+   #       echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
+   #       echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
+   #       echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
+
+   #       trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
+   #         if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
+   #         if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
+   #         # ps auxww | grep $USER; \
+   #         pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
+   #         echo "Found parent PIDs: $pid"; \
+   #         for p in $pid; do \
+   #           echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
+   #           children=$(pgrep -P $p); \
+   #           echo "Children: $children"; \
+   #           if [ -n "$children" ]; then \
+   #             for child in $children; do \
+   #               ppid=$(ps -o ppid= -p $child | tr -d " ")
+   #               if [ "$ppid" -eq "$p" ]; then
+   #                 echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
+   #                 kill -USR2 $child &
+   #               else
+   #                 echo "Skipping non-direct child process: PID $child with PPID $ppid"
+   #               fi
+   #             done; \
+   #             echo "Sent kill signals to children of $p"; \
+   #           else \
+   #             echo "No children found for $p"; \
+   #           fi; \
+   #         done; \
+   #         wait;' SIGUSR2
+
+ checkpointing:
+   # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+   save_dir: ${cwd:}/checkpoints
+   # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+   resume_from_ckpt: true
+   resume_ckpt_path: ${cwd:}/checkpoints
+   initial_resume_ckpt_path: null
+   resume_wandb: true
+   checkpoints_total_limit: 2
+   use_automatic_naming: false
+
+
+ data:
+   cache_dir: ${oc.env:HF_DATASETS_CACHE,/grogu/user/mprabhud/aswerdlo/huggingface/datasets}
+   num_proc: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 4, 16)"}
+   cond_resolution: null
+   iterable: false
+   force_disable_shuffle: false
+   pin_dataset_to_gpu: false
+   webdataset_iterable: false
+   webdataset_train_data: null
+   webdataset_val_data: null
+   webdataset_train_num_samples: null
+   webdataset_val_num_samples: null
+   webdataset_indexed: false
+   dataset_type: null
+   keep_tensordict_on_disk: false
+   use_token_dataset: false
+   use_custom_tensordict_collate: false
+   use_weighted_tensordict_sampler: false
+   enable_cuda_in_tensordict_collate: true
+   data_dir_train: null
+   data_dir_val: null
+   token_output_dir: null
+   wrap_dataloaders: true
+   force_shuffle_train: false
+   move_tensordict_to_shm: false
+   keep_hf_dataset_in_memory: false
+   use_chameleon: false
+   tokenize_vqvae_in_dataloader: false
+   force_mp_spawn: false
+   force_raw_images_in_multiple_tensordict: false
+   disable_text_modality: false
+   txt_only: false
+   disable_mask_after_eos: false
+   allow_label: false
+   split_dataset: false
+   img_token_shift: ${model.text_vocab_size}
+   zero_shot_eval_dataset: null
+   require_sample_ids: false
+   use_packing_collate: false
+   dynamic_packing_lengths: false
+   remove_txt_img_padding: false
+   add_image_gen_tokens: false
+   use_slow_tokenizer: false
+   add_image_token: false
+
+ dummyarg: null
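Note: this config leans on several custom OmegaConf resolvers (`${device_count:}`, `${div_up:...}`, `${find_grad_accum:...}`, `${eval:...}`) whose implementations are not part of this commit. Below is a minimal sketch of what they plausibly compute, inferred from the comment above `accumulate_grad_batches`; it is an assumption for illustration, not the repo's actual code:

```python
# Hypothetical resolver registrations matching the interpolations used above.
import torch
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("device_count", lambda: max(torch.cuda.device_count(), 1))
OmegaConf.register_new_resolver("div_up", lambda x, y: (int(x) + int(y) - 1) // int(y))
OmegaConf.register_new_resolver("eval", eval)

def find_grad_accum(desired_global_batch_size: int, per_step_samples: int) -> int:
    """Accumulation steps that get closest to the desired global batch size,
    given devices * batch_size * nodes samples per forward/backward."""
    return max(round(desired_global_batch_size / per_step_samples), 1)

OmegaConf.register_new_resolver("find_grad_accum", find_grad_accum)
```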
configs/config_empty.yaml ADDED
@@ -0,0 +1,8 @@
+ defaults:
+   - _self_
+   - /model: small
+   - /experiments: []
+
+ # from omegaconf import OmegaConf
+ # with open("config.yaml", "w") as fp:
+ #     OmegaConf.save(config=config, f=fp.name)
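For reference, a runnable version of the commented snippet at the end of this file, round-tripping a config through YAML with OmegaConf (the keys here are illustrative; note `OmegaConf.save` accepts the file object directly, so `fp.name` is unnecessary):

```python
from omegaconf import OmegaConf

config = OmegaConf.create({"defaults": ["_self_"], "model": "small"})
with open("config.yaml", "w") as fp:
    OmegaConf.save(config=config, f=fp)  # accepts a file object or a path

loaded = OmegaConf.load("config.yaml")
assert loaded.model == "small"
```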
configs/experiments/ar.yaml ADDED
@@ -0,0 +1,10 @@
+ # @package _global_
+
+ parameterization: ar
+
+ trainer:
+   ar_shift: true
+
+ model:
+   full_attention: false
+   use_flex_attention: false
configs/experiments/elm.yaml ADDED
@@ -0,0 +1,15 @@
+ # @package _global_
+
+ backbone: elm
+
+ data:
+   tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+
+ model:
+   use_lora: false
+   full_attention: true
+   model_id: apple/OpenELM-270M # apple/OpenELM-1_1B
+
+ trainer:
+   use_gradient_checkpointing: false
+   sd3_compile_config: false
configs/experiments/eval_model.yaml ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ mode: eval
+
+ loader:
+   batch_size: 16
+   eval_batch_size: 16
+
+ trainer:
+   disable_all_eval_generation: false
+
+ eval:
+   compute_generative_perplexity: true
+   generate_samples: true
+   num_sample_batches: 20
+   log_every_n_fid: 1
+   log_every_n_evals: 1
+   compute_standalone_mauve: true
+   mauve_num_samples: 5000
+   # mauve_divergence_curve_discretization_size: 200 # works well for our repo
+   # mauve_scaling_factor: 2 # works well for our repo
configs/experiments/eval_text.yaml ADDED
@@ -0,0 +1,26 @@
+ # @package _global_
+
+ mode: eval
+
+ sampling:
+   steps: 100
+   max_sampling_steps: 100
+
+ loader:
+   batch_size: 2
+   eval_batch_size: 2
+
+ trainer:
+   fsdp: false
+
+ eval:
+   perplexity_batch_size: 2
+   num_masking_viz_batches: 2
+   log_every_n_evals: 1
+   num_uncond_sample_batches: 2
+   num_sample_batches: 2
+   num_random_masking: 1
+   masking_batch_size: 2
+   cfg: null
+   generate_samples: true
+   compute_generative_perplexity: false
configs/experiments/eval_text_only.yaml ADDED
@@ -0,0 +1,30 @@
+ # @package _global_
+
+ mode: eval
+ debug: true
+
+ sampling:
+   steps: 100
+   max_sampling_steps: 100
+
+ loader:
+   batch_size: 2
+   eval_batch_size: 2
+
+ trainer:
+   fsdp: false
+
+ model:
+   image_model_fid_eval: false
+
+ eval:
+   log_every_n_evals: 1
+   perplexity_batch_size: 2
+   num_uncond_sample_batches: 2
+   num_sample_batches: 2
+   num_masking_viz_batches: -1
+   num_random_masking: -1
+   masking_batch_size: -1
+   cfg: null
+   generate_samples: true
+   compute_generative_perplexity: true
configs/experiments/eval_unified.yaml ADDED
@@ -0,0 +1,27 @@
+ # @package _global_
+
+ mode: eval
+ devices: ${device_count:}
+
+ sampling:
+   steps: 500
+   max_sampling_steps: 1000
+
+ loader:
+   batch_size: 6
+   eval_batch_size: 6
+
+ trainer:
+   fsdp: false
+   disable_all_eval_generation: false
+
+ eval:
+   perplexity_batch_size: 6
+   num_masking_viz_batches: 12
+   log_every_n_evals: 1
+   num_uncond_sample_batches: 5
+   num_sample_batches: 2
+   num_random_masking: 3
+   masking_batch_size: 6
+   cfg: 6.0
+   generate_samples: false
configs/experiments/fid_cc12m.yaml ADDED
@@ -0,0 +1,22 @@
+ # @package _global_
+
+ data:
+   keep_hf_dataset_in_memory: true
+   aggressive_aug: false
+   n_duplicate_train: null
+   n_duplicate_val: null
+
+   tokenize_vqvae_in_dataloader: false
+   enable_cuda_in_tensordict_collate: false
+   force_mp_spawn: false
+   keep_tensordict_on_disk: false
+   move_tensordict_to_shm: false
+
+   fid_dataset: cc12m_tokens_val_256
+   image_data_train: null
+   image_data_val: null
+   data_dir_train: ${data.data_dir_val}
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+       weight: 1
+       name: ${data.fid_dataset}
configs/experiments/fid_datacomp1b.yaml ADDED
@@ -0,0 +1,22 @@
+ # @package _global_
+
+ data:
+   keep_hf_dataset_in_memory: true
+   aggressive_aug: false
+   n_duplicate_train: null
+   n_duplicate_val: null
+
+   tokenize_vqvae_in_dataloader: false
+   enable_cuda_in_tensordict_collate: false
+   force_mp_spawn: false
+   keep_tensordict_on_disk: false
+   move_tensordict_to_shm: false
+
+   fid_dataset: datacomp1b_8_magvit_val
+   image_data_train: null
+   image_data_val: null
+   data_dir_train: ${data.data_dir_val}
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+       weight: -1
+       name: ${data.fid_dataset}
configs/experiments/fid_hf.yaml ADDED
@@ -0,0 +1,25 @@
+ # @package _global_
+
+ data:
+   disable_text_modality: false
+   keep_hf_dataset_in_memory: true
+   aggressive_aug: false
+   n_duplicate_train: null
+   n_duplicate_val: null
+   data_dir_train: []
+   data_dir_val: []
+   fid_dataset: sayakpaul/coco-30-val-2014
+   train: combined_tokens
+   val: ${.train}
+   image_data_val:
+     - val: ${data.fid_dataset}
+       weight: -1
+       name: ${.val}
+       tokenize_vqvae_in_dataloader: false
+       raw_images: true
+   image_data_train:
+     - train: ${data.fid_dataset}
+       weight: -1
+       name: ${.train}
+       tokenize_vqvae_in_dataloader: false
+       raw_images: true
configs/experiments/jan_cub.yaml ADDED
@@ -0,0 +1,51 @@
+ # @package _global_
+
+ defaults:
+   - override /model: medium
+   - override /lr_scheduler: cosine_with_hard_restarts_schedule_with_warmup
+
+ loader:
+   batch_size: 16
+   eval_batch_size: 16
+   desired_global_batch_size: 128
+   num_workers: 4
+
+ trainer:
+   ckpt_steps: 5000
+   val_check_interval: 100
+   use_legacy_update_batch_fn: true
+   mask_txt_only: true
+   mask_entire_modality: 0.15
+   ema: 0.9999
+   use_custom_ema: true
+   force_enable_checkpointing: true
+   skip_early_checkpointing: false
+   force_after_eos_padding: false
+
+ checkpointing:
+   checkpoints_total_limit: 20
+
+ lr_scheduler:
+   num_warmup_steps: 10000
+   num_training_steps: 400000
+   num_cycles: 80
+
+ data:
+   resolution: 256
+   train: cub2011_custom
+   use_weighted_tensordict_sampler: false
+
+ model:
+   vae_type: titok128
+   txt_length: 18
+   img_length: 128
+   rope_2d: false
+   force_text_vocab_size: 5450
+   text_vocab_size: 5451
+   image_vocab_size: 8192
+   attn_dropout: 0.1
+
+ optim:
+   lr: 1.0e-04
+   weight_decay: 0.2
+   beta2: 0.99
configs/experiments/large_maskdit_exp.yaml ADDED
@@ -0,0 +1,7 @@
+ # @package _global_
+
+ defaults:
+   - override /model: large_maskdit
+
+
+ backbone: maskdit
configs/experiments/large_scale_high_res_interleaved_inference.yaml ADDED
@@ -0,0 +1,51 @@
+ # @package _global_
+
+ debug: true
+ seed: 163
+
+ loader:
+   eval_batch_size: 1
+   batch_size: 1
+
+ data:
+   move_tensordict_to_shm: false
+   resolution: 1024
+   disable_mask_after_eos: true
+   disable_packing: true
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+       weight: 1.0
+       name: HPDv2_image_reward_512
+
+ model:
+   img_length: 4096
+   txt_length: 1024
+   length: 5120
+
+ trainer:
+   compile: false
+   limit_val_batches: 2
+   fsdp: false
+   force_full_attention_mask: true
+   force_null_sigma: true
+   allow_null_sigma: true
+
+ eval:
+   num_sample_batches: 1
+   num_random_masking: 0
+   num_masking_viz_batches: 0
+   limit_val_batches_manual: 1
+   num_uncond_sample_batches: 10
+   eval_large_batch: 10
+   val_with_train_data: false
+   maskgit_r_temp: 4.5
+   half_uncond: false
+   cfg: 3.0
+   return_interleaved_modalities_split: true
+   static_img_txt_demo: true
+   visualize_sample: true
+
+ sampling:
+   steps: 50
+   max_sampling_steps: 50
+   predictor: "maskgit"
configs/experiments/large_scale_train.yaml ADDED
@@ -0,0 +1,151 @@
+ # @package _global_
+
+ defaults:
+   - vq16_t2i
+   - override /model: extra_large
+
+ data:
+   train: combined_tokens
+   valid: ${.train}
+   precache: false
+   streaming: false
+   resolution: 256
+   block_size: 128
+   tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+   wrap: true
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: false
+   unpaired: false
+   dataset_type: null
+   tokens_flip_collate: false
+   n_val_samples: null
+   n_train_samples: null
+   n_duplicate_train: null
+   n_duplicate_val: null
+   raw_data_dir: null
+   save_train_dataloader: true
+   save_validation_dataloader: true
+   tokenizers_parallelism: false
+   token_data_dir: null
+   force_disable_shuffle: false
+   use_custom_tensordict_collate: true
+   use_weighted_tensordict_sampler: true
+   force_mp_spawn: false
+   enable_cuda_in_tensordict_collate: false
+   use_token_dataset: true
+   keep_tensordict_on_disk: true
+   move_tensordict_to_shm: false
+   add_text_to_weighted_sampler: false
+   data_dir_train:
+     # - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+     #   weight: 15.0
+     #   name: hpdv2
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+       weight: 1.0
+       name: pixelprose
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/journeydb_train
+       weight: 10.0
+       name: journeydb_train
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
+       weight: 1.0
+       name: datacomp0
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
+       weight: 1.0
+       name: datacomp1
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
+       weight: 1.0
+       name: datacomp2
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_3_tokens
+       weight: 1.0
+       name: datacomp3
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
+       weight: 1.0
+       name: datacomp4
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
+       weight: 1.0
+       name: datacomp5
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_6_tokens
+       weight: 1.0
+       name: datacomp6
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+       weight: 1.0
+       name: dummy_1
+
+ model:
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
+   txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
+   length: ${eval:'${.txt_length} + ${.img_length}'}
+   unified_model: true
+   image_model: true
+   text_model: true
+   image_model_fid_eval: false
+   force_argmax_valid_indices: true
+   use_pretrained_img_emb: false
+   rope_2d: true
+   modality_embed: true
+   norm_type: rms
+   qk_norm: true
+   sandwich_normalization: true
+   text_vocab_size: 32001
+
+ loader:
+   batch_size: 8
+   eval_batch_size: ${eval:'${.batch_size} // 2'}
+   desired_global_batch_size: 512
+   persistent_workers: true
+   pin_memory: false
+   num_workers: 0
+   num_eval_workers: 0
+ eval:
+   log_every_n_evals: -1
+   log_every_n_fid: -1
+   limit_val_batches_manual: 16
+   generate_samples: true
+   compute_generative_perplexity: false
+   perplexity_batch_size: ${loader.eval_batch_size}
+   cfg: 5.0
+   num_val_metrics_standalone_samples: -1
+   num_val_metrics_standalone_batches_per_device: -1
+   auto_enhance_reward_config:
+     dfn_score: 1.0
+     laion_aesthetic_score: 1.0
+
+ trainer:
+   log_flops: false
+   log_every_n_steps: 10
+   custom_ddp_bf16: true
+   log_seperate_modal_losses: true
+   limit_val_batches: 16
+   softmin_snr: 5
+   text_loss_weight: 1.0
+   img_loss_weight: 0.6
+   use_gradient_checkpointing: false
+   ckpt_steps: 20000
+   ckpt_every_n_minutes: 180
+   ckpt_recent_timeout_minutes: 10
+   use_custom_ema: false
+   ema: 0.0
+   fsdp: true
+   restart_on_failure: true
+   eval_on_start: false
+   val_check_interval: 100000000000
+   scale_lr_by_batch_size: false
+   watch_gradients: false
+   compile: true
+   mask_entire_modality: 0.15
+   compile_flag_pos_emb: true
+   multimodal_batches: true
+ optim:
+   lr: 0.0001
+ sampling:
+   steps: 128
+   num_sample_batches: 2
+ wandb:
+   mode: online
+ checkpointing:
+   checkpoints_total_limit: 10
+   use_automatic_naming: false
+ lr_scheduler:
+   num_warmup_steps: 10000
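Note: the `img_length`/`txt_length`/`length` interpolations in this config resolve to concrete sequence lengths once the tokenizer is known. A worked example, assuming the vq16 tokenizer's downscale ratio of 16 (per the vq16 experiment configs in this commit):

```python
resolution, downscale_ratio, block_size = 256, 16, 128  # values from this config
img_length = (resolution // downscale_ratio) ** 2  # 16 x 16 grid -> 256 image tokens
txt_length = block_size                            # 128 text tokens (unified_model: true)
length = txt_length + img_length                   # 384-token packed sequences
print(img_length, txt_length, length)              # 256 128 384
```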
configs/experiments/large_scale_train_high_res.yaml ADDED
@@ -0,0 +1,39 @@
+
+ # @package _global_
+
+ data:
+   resolution: 512
+   data_dir_train:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+       weight: 1
+       name: HPDv2_image_reward_512
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
+       weight: 2
+       name: pick_score_sac_prompts_v1_v2_v3_512
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
+       weight: 0.5
+       name: datacomp1b_7_512
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/text/slimpajama6b
+       weight: 2.5
+       name: slimpajama6b
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
+       weight: 1.0
+       name: gecko_eval_512
+
+ trainer:
+   text_loss_weight: 1.0
+   img_loss_weight: 0.5
+   force_full_attention_mask: true
+   mask_entire_modality: 0.1
+
+ loader:
+   pin_memory: false
+   num_workers: 4
+   num_eval_workers: 4
+
+ lr_scheduler:
+   num_warmup_steps: 5000
+
+ model:
+   linear_factor: 2
configs/experiments/large_scale_train_high_res_inference.yaml ADDED
@@ -0,0 +1,30 @@
+ # @package _global_
+
+ data:
+   use_token_dataset: true
+   disable_mask_after_eos: true
+   move_tensordict_to_shm: false
+
+ trainer:
+   compile_flag_pos_emb: true
+   multimodal_batches: true
+   allow_null_sigma: true
+
+ eval:
+   num_sample_batches: 1
+   num_random_masking: 0
+   num_masking_viz_batches: 0
+   limit_val_batches_manual: 1
+   num_uncond_sample_batches: 10
+   eval_large_batch: 10
+   val_with_train_data: false
+   maskgit_r_temp: 4.5
+   half_uncond: false
+   cfg: 3.0
+   static_img_txt_demo: true
+   visualize_sample: true
+
+ sampling:
+   steps: 50
+   max_sampling_steps: 50
+   predictor: "maskgit"
configs/experiments/large_scale_train_high_res_interleaved.yaml ADDED
@@ -0,0 +1,105 @@
+
+ # @package _global_
+
+ data:
+   move_tensordict_to_shm: false
+   enable_cuda_in_tensordict_collate: false
+   force_mp_spawn: false
+   resolution: 512
+   add_text_to_weighted_sampler: false
+
+   add_image_gen_tokens: true
+   use_packing_collate: true
+   dynamic_packing_lengths: true
+   remove_txt_img_padding: true
+   require_sample_ids: true
+   block_size: ${model.length}
+   disable_mask_after_eos: true
+   add_image_token: true
+   use_slow_tokenizer: true
+   force_seed: true
+
+   data_dir_train:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+       weight: 0.5
+       name: HPDv2_image_reward_v1_v2_v3 # 3593248
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
+       weight: 1.0
+       name: pick_score_sac_prompts_v1_v2_v3_512 # 9330810
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+       weight: 1.0
+       name: pixelprose_tokens # 6627589
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cambrian_10m_v5
+       weight: 1.0
+       name: cambrian_10m_v5 # 8215264
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
+       weight: 1.0
+       name: datacomp1b_7_512 # 23955209
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
+       weight: 0.5
+       name: datacomp_1b_datacomp1b_2_tokens # 10161505
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
+       weight: 0.5
+       name: datacomp_1b_datacomp1b_4_tokens # 27895717
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/mmc4_fewer_faces_v0
+       weight: 2.0
+       name: mmc4_fewer_faces_v0 # 22605524
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
+       weight: 0.5
+       name: datacomp_1b_datacomp1b_5_tokens
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
+       weight: 0.5
+       name: datacomp_1b_datacomp1b_0_tokens
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
+       weight: 0.5
+       name: datacomp_1b_datacomp1b_1_tokens
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cosmopedia_2_v0
+       weight: 1.0
+       name: cosmopedia_v2
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/fineweb_edu_dedup_v0
+       weight: 1.0
+       name: fineweb_edu_dedup
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
+       weight: 1.0
+       name: gecko_eval_512
+
+ trainer:
+   text_loss_weight: 1.0
+   img_loss_weight: 0.2
+   mask_entire_modality: 0.2
+
+   force_full_attention_mask: false
+   force_full_attention_mask_loss_only: false
+   disable_all_eval_generation: true
+   interleaved: true
+   interleaved_training_flex_attention: true
+   force_convert_to_dict: true
+   val_check_interval: -1
+   use_gradient_checkpointing: true
+   disable_all_checkpointing: false
+   set_max_txt_loss_ratio: true
+   gradient_clip_val: 1.0
+   skip_early_checkpointing: false
+   bypass_load_from_state_dicts_if_resuming: true
+
+ loader:
+   num_workers: 4
+   num_eval_workers: 4
+
+ lr_scheduler:
+   num_warmup_steps: 5000
+
+ model:
+   linear_factor: 2
+   use_flex_attention: true
+   use_spda_attn: true
+
+   length: 1536
+   txt_length: ${.length}
+   img_length: ${.length}
+
+ eval:
+   generate_samples: false
+   disable_visualization: true
+
configs/experiments/maskgit.yaml ADDED
@@ -0,0 +1,6 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 1024
+   vae_type: maskgit
configs/experiments/master_eval.yaml ADDED
@@ -0,0 +1,49 @@
+ # @package _global_
+
+ mode: eval
+
+ eval:
+   fid_samples: 4096
+   max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+   compute_generative_perplexity: true
+   generate_samples: true
+   log_every_n_fid: 1
+   log_every_n_evals: 1
+   class_conditional_fid: false
+   txt_conditional_fid: true
+   calculate_clip_score: true
+   cfg: 5
+   num_sample_batches: 2
+   compute_standalone_mauve: false
+   mauve_num_samples: -1
+   set_random_gen_seed: true
+   # gen_ppl_eval_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
+   compute_img_to_txt_mauve_clip: true
+   compute_img_to_txt_mauve_during_unconditional_fid: true
+   force_eval_uncond: true
+   ablation_config: true
+   compute_val_metrics_standalone: true
+   num_val_metrics_standalone_samples: 2000
+
+ trainer:
+   disable_all_eval_generation: false
+   force_after_eos_padding: true
+
+ model:
+   image_model_fid_eval: true
+   use_kv_cache: ${is_ar:${parameterization}}
+
+ loader:
+   batch_size: 64
+   eval_batch_size: 64
+   num_workers: 0
+   num_eval_workers: 1
+
+ sampling:
+   steps: ${model.length}
+   max_sampling_steps: ${sampling.steps}
+   sampling_step_frac: null
+
+
+ data:
+   fid_dataset: null
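Note: the `max_num_fid_batches_per_device` interpolation above splits the FID sample budget evenly across devices. A quick check of the arithmetic with this config's values (`fid_samples: 4096`, `eval_batch_size: 64`):

```python
fid_samples, eval_batch_size = 4096, 64
for devices in (1, 4, 8):
    batches = max(fid_samples // (devices * eval_batch_size), 1)
    print(devices, batches)  # 1 -> 64, 4 -> 16, 8 -> 8 batches per device
```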
configs/experiments/mscoco_fid.yaml ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ data:
+   disable_text_modality: false
+   keep_hf_dataset_in_memory: true
+   aggressive_aug: false
+   n_duplicate_train: null
+   n_duplicate_val: null
+   data_dir_train: []
+   data_dir_val: []
+   image_data_train: ${data.image_data_val}
+   image_data_val:
+     - val: sayakpaul/coco-30-val-2014
+       weight: -1
+       name: mscoco_val
+       tokenize_vqvae_in_dataloader: false
+       raw_images: true
+
+ eval:
+   compute_generative_perplexity: true
+   generate_samples: true
configs/experiments/paired_standalone_fid_eval.yaml ADDED
@@ -0,0 +1,29 @@
+ # @package _global_
+
+ mode: eval
+ debug: true
+
+ eval:
+   fid_samples: 4096
+   max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+   compute_generative_perplexity: false
+   generate_samples: false
+   log_every_n_fid: 1
+   log_every_n_evals: 1
+   class_conditional_fid: false
+   txt_conditional_fid: true
+   calculate_clip_score: true
+   cfg: 5
+
+ model:
+   image_model_fid_eval: true
+
+ loader:
+   eval_batch_size: 32
+
+ sampling:
+   steps: ${model.length}
+   max_sampling_steps: ${model.length}
+
+ data:
+   keep_hf_dataset_in_memory: false
configs/experiments/small_scale_train.yaml ADDED
@@ -0,0 +1,187 @@
+ # @package _global_
+
+ defaults:
+   - vq16_magvit
+   - override /model: small
+   - override /lr_scheduler: constant_warmup_cosine_decay
+
+ model:
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
+   txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
+   length: ${eval:'${.txt_length} + ${.img_length}'}
+   image_model: true
+   text_model: true
+   unified_model: true
+   image_model_fid_eval: false
+   force_argmax_valid_indices: true
+   use_pretrained_img_emb: false
+   codebook_embed_dim: 256
+   qk_norm: true
+   norm_type: rms
+   sandwich_normalization: true
+   zero_linear_init: false
+   modality_embed: true
+   rope_2d: false
+   use_spda_attn: true
+   force_optimized_native_attn: true
+   freeze_txt_emb: false
+   add_labels: null
+   txt_dropout: null
+   text_vocab_size: 32001
+
+ data:
+   train: combined_tokens
+   valid: ${.train}
+   n_duplicate_train: null
+   wrap: true
+   streaming: false
+   precache: false
+   tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+   resolution: 256
+   block_size: 128
+   n_val_samples: null
+   unpaired: false
+   n_duplicate_val: null
+   save_train_dataloader: true
+   save_validation_dataloader: true
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: false
+   dataset_type: null
+   tokens_flip_collate: false
+   n_train_samples: null
+   raw_data_dir: null
+   tokenizers_parallelism: false
+   token_data_dir: null
+   force_disable_shuffle: false
+   keep_tensordict_on_disk: true
+   use_custom_tensordict_collate: true
+   force_mp_spawn: false
+   enable_cuda_in_tensordict_collate: false
+   use_weighted_tensordict_sampler: true
+   fraction_txt_data: 0.0
+   tokenize_vqvae_in_dataloader: false
+   use_token_dataset: true
+   image_dataset: tglcourse/lsun_church_train
+   image_data_train: null
+   image_data_val: null
+   keep_hf_dataset_in_memory: true
+   allow_label: false
+   disable_text_modality: true
+   force_raw_train_images: false
+   aggressive_aug: true
+   allow_aug_vqvae_dataloader: true
+   move_tensordict_to_shm: false
+   data_dir_train:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
+       weight: -1
+       name: datacomp1b_8_magvit_train
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
+       weight: -1
+       name: cc12m_tokens_train_256
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
+       weight: -1
+       name: HPDv2_image_reward_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
+       weight: -1
+       name: pick_score_sac_prompts_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
+       weight: -1
+       name: datacomp1b_0_1_6_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
+       weight: -1
+       name: laion400m_magvit_part_0
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
+       weight: -1
+       name: laion400m_magvit_part_1
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+       weight: 1
+       name: datacomp1b_8_magvit_val
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+       weight: 1
+       name: cc12m_tokens_val_256
+
+ eval:
+   generate_samples: true
+   compute_generative_perplexity: true
+   log_every_n_evals: 10
+   log_every_n_fid: 20
+   limit_val_batches_manual: 16
+   perplexity_batch_size: ${loader.eval_batch_size}
+   num_masking_viz_batches: -1
+   cfg: null
+   class_conditional_fid: false
+   force_cfg_value: true
+   split_cfg_batches: true
+   max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
+   fid_mode: clean
+   clean_fid_precomputed_name: lsun_church
+   clean_fid_precomputed_split: trainfull
+   clean_fid_precomputed_res: 256
+
+ trainer:
+   log_every_n_steps: 10
+   val_check_interval: 1000
+   custom_ddp_bf16: true
+   scale_lr_by_batch_size: false
+   limit_val_batches: 16
+   use_gradient_checkpointing: false
+   log_seperate_modal_losses: true
+   softmin_snr: 5
+   text_loss_weight: 1.0
+   img_loss_weight: null
+   low_precision_loss: false
+   compile: true
+   multimodal_batches: true
+   compile_fullgraph: false
+   log_grad_norm_every_n_steps: 10
+   mask_entire_modality: 0.1
+   force_shift_image_batches: false
+   ckpt_steps: 10000
+   ckpt_every_n_minutes: -1
+   ignore_text_in_unified: false
+   disable_all_eval_generation: true
+   eval_on_start: false
+   ckpt_model_only: false
+   ema: 0.0
+   use_custom_ema: false
+   log_flops: false
+   disable_distributed_torchmetrics: true
+   restart_on_failure: true
+   force_null_sigma: true
+   allow_null_sigma: true
+   compile_flag_pos_emb: true
+   add_label: false
+   first_token_dropout: null
+   force_shift_raw_image_batches: true
+   txt_dropout: 0.1
+   force_full_attention_mask_loss_only: true
+
+ optim:
+   lr: 0.0003
+   weight_decay: 0.05
+
+ loader:
+   batch_size: 64
+   eval_batch_size: ${loader.batch_size}
+   num_workers: 4
+   desired_global_batch_size: 512
+   persistent_workers: true
+   pin_memory: true
+   num_eval_workers: 1
+
+ sampling:
+   steps: ${model.length}
+   num_sample_batches: 2
+   max_sampling_steps: ${model.length}
+
+ wandb:
+   mode: online
+
+ lr_scheduler:
+   num_warmup_steps: 5000
+   num_training_steps: ${trainer.max_steps}
+
+ checkpointing:
+   checkpoints_total_limit: 10
configs/experiments/small_scale_train_caching.yaml ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # @package _global_
+
+ defaults:
+   - /model: small
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 8192
+   vae_type: magvit
+   use_custom_vae_ckpt: null
+   custom_vae_name: null
+   img_length: 256
+   txt_length: 128
+   image_model: true
+   text_model: true
+   unified_model: true
+   image_model_fid_eval: false
+   force_argmax_valid_indices: true
+   use_pretrained_img_emb: false
+   codebook_embed_dim: 256
+   qk_norm: true
+   norm_type: rms
+   sandwich_normalization: true
+   zero_linear_init: false
+   modality_embed: true
+   rope_2d: false
+   use_spda_attn: true
+   force_optimized_native_attn: true
+   freeze_txt_emb: false
+   add_labels: null
+   txt_dropout: null
+   text_vocab_size: 32001
+   use_flex_attention: true
+   flex_attention_txt_masking_prob: 0.1
+   flex_attention_img_masking_prob: 0.1
+   linear_factor: 1
+ data:
+   train: combined_tokens
+   valid: ${.train}
+   n_duplicate_train: null
+   wrap: true
+   streaming: false
+   precache: false
+   tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+   resolution: 256
+   block_size: 128
+   n_val_samples: null
+   unpaired: false
+   n_duplicate_val: null
+   save_train_dataloader: true
+   save_validation_dataloader: true
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: false
+   dataset_type: null
+   tokens_flip_collate: false
+   n_train_samples: null
+   raw_data_dir: null
+   tokenizers_parallelism: false
+   token_data_dir: null
+   force_disable_shuffle: false
+   keep_tensordict_on_disk: true
+   use_custom_tensordict_collate: true
+   force_mp_spawn: false
+   enable_cuda_in_tensordict_collate: false
+   use_weighted_tensordict_sampler: true
+   fraction_txt_data: 0.0
+   data_dir_train:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
+       weight: -1
+       name: datacomp1b_8_magvit_train
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
+       weight: -1
+       name: cc12m_tokens_train_256
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
+       weight: -1
+       name: HPDv2_image_reward_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
+       weight: -1
+       name: pick_score_sac_prompts_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
+       weight: -1
+       name: datacomp1b_0_1_6_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
+       weight: -1
+       name: laion400m_magvit_part_0
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
+       weight: -1
+       name: laion400m_magvit_part_1
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+       weight: 1
+       name: datacomp1b_8_magvit_val
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+       weight: 1
+       name: cc12m_tokens_val_256
+   tokenize_vqvae_in_dataloader: false
+   val:
+   .train: null
+   use_token_dataset: true
+   image_dataset: tglcourse/lsun_church_train
+   image_data_train: null
+   image_data_val: null
+   keep_hf_dataset_in_memory: true
+   allow_label: false
+   disable_text_modality: true
+   force_raw_train_images: false
+   aggressive_aug: true
+   allow_aug_vqvae_dataloader: true
+   move_tensordict_to_shm: false
+   force_full_attention_mask: false
+ eval:
+   generate_samples: false
+   compute_generative_perplexity: false
+   log_every_n_evals: 10
+   log_every_n_fid: 20
+   limit_val_batches_manual: 16
+   perplexity_batch_size: ${loader.eval_batch_size}
+   num_masking_viz_batches: -1
+   max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
+   cfg: null
+   class_conditional_fid: false
+   force_cfg_value: true
+   split_cfg_batches: true
+   fid_mode: clean
+   clean_fid_precomputed_name: lsun_church
+   clean_fid_precomputed_split: trainfull
+   clean_fid_precomputed_res: 256
+ trainer:
+   log_every_n_steps: 10
+   val_check_interval: 1000
+   custom_ddp_bf16: true
+   scale_lr_by_batch_size: false
+   limit_val_batches: 16
+   use_gradient_checkpointing: false
+   log_seperate_modal_losses: true
+   softmin_snr: 5
+   text_loss_weight: 1.0
+   img_loss_weight: null
+   low_precision_loss: false
+   compile: false
+   multimodal_batches: true
+   compile_fullgraph: false
+   log_grad_norm_every_n_steps: 10
+   mask_entire_modality: 0.1
+   force_shift_image_batches: false
+   ckpt_steps: 10000
+   ckpt_every_n_minutes: -1
+   ignore_text_in_unified: false
+   disable_all_eval_generation: false
+   eval_on_start: false
+   ckpt_model_only: false
+   ema: 0.0
+   use_custom_ema: false
+   log_flops: false
+   disable_distributed_torchmetrics: true
+   restart_on_failure: true
+   force_null_sigma: true
+   allow_null_sigma: true
+   compile_flag_pos_emb: true
+   add_label: false
+   first_token_dropout: null
+   force_shift_raw_image_batches: true
+   txt_dropout: 0.1
+   disable_ddp_optimizer: true
+ optim:
+   lr: 0.0003
+   weight_decay: 0.05
+ loader:
+   batch_size: 64
+   eval_batch_size: ${loader.batch_size}
+   num_workers: 1
+   desired_global_batch_size: 512
+   persistent_workers: true
+   pin_memory: true
+   num_eval_workers: 1
+ sampling:
+   steps: ${model.length}
+   num_sample_batches: 2
+   max_sampling_steps: ${model.length}
+ wandb:
+   mode: online
+ lr_scheduler:
+   num_warmup_steps: 5000
+ checkpointing:
+   checkpoints_total_limit: 4
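
The `data_dir_train` entries pair token directories with sampling weights, and `use_weighted_tensordict_sampler: true` suggests they feed a weighted mixture. A sketch of the apparent semantics (reading `weight: -1` as "sample in proportion to dataset size" is an assumption, not verified against this repo's sampler):

    # Python sketch: drawing from several token datasets with per-dataset weights.
    import random

    def mixture(datasets, weights):
        sizes = [len(d) for d in datasets]
        # Assumed convention: -1 means proportional to size; positive values fix the ratio.
        probs = [s if w == -1 else w for s, w in zip(sizes, weights)]
        total = sum(probs)
        while True:
            i = random.choices(range(len(datasets)), weights=[p / total for p in probs])[0]
            yield datasets[i][random.randrange(sizes[i])]

Under that reading, the validation mix above uses explicit `weight: 1` per dataset, i.e. an even split regardless of dataset size.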
configs/experiments/small_text_only.yaml ADDED
@@ -0,0 +1,28 @@
+ # @package _global_
+
+ defaults:
+   - lsun_text8_exp_2
+   - owt_only
+   - override /model: small
+
+ backbone: dit
+
+ loader:
+   batch_size: 64
+
+ trainer:
+   val_check_interval: 10000
+   ckpt_steps: 10000
+   softmin_snr: null
+
+ optim:
+   fused: true
+   weight_decay: 0.03
+
+ sampling:
+   num_sample_batches: 4
+   max_sampling_steps: 256
+
+ model:
+   txt_length: 1024
+
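
This file shows the composition pattern used throughout `configs/experiments`: `# @package _global_` merges the file's keys at the config root, and the defaults list layers other experiment files (`lsun_text8_exp_2`, `owt_only`) plus a model-group override before the local keys apply. A sketch of composing it programmatically (the config path and group name are assumptions based on this repo's layout):

    # Python sketch: composing an experiment config with Hydra's compose API.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="configs"):
        cfg = compose(config_name="config",
                      overrides=["+experiments=small_text_only"])
    print(cfg.model.txt_length)  # 1024, set by this experiment's override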
configs/experiments/standalone_fid_eval.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ mode: eval
+ debug: true
+
+ eval:
+   max_num_fid_batches_per_device: ${eval:'4096 // (${trainer.devices} * ${loader.eval_batch_size})'}
+   compute_generative_perplexity: false
+   generate_samples: false
+   log_every_n_fid: 1
+   log_every_n_evals: 1
+
+ loader:
+   eval_batch_size: 32
+
+ sampling:
+   steps: 500
+   max_sampling_steps: 500
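
Here the FID cap targets roughly 4096 generated samples in total across all devices. Worked arithmetic for the batch cap (the device count is a runtime value; 4 is illustrative):

    # Python sketch: total FID samples implied by the interpolation above.
    devices, eval_batch_size = 4, 32
    batches_per_device = 4096 // (devices * eval_batch_size)  # 4096 // 128 = 32
    total = batches_per_device * devices * eval_batch_size    # 4096 samples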
configs/experiments/titok.yaml ADDED
@@ -0,0 +1,8 @@
+ # @package _global_
+
+ data:
+   resolution: 256
+   downscale_ratio: 16
+
+ model:
+   vae_type: titok
configs/experiments/titok_sl256.yaml ADDED
@@ -0,0 +1,7 @@
+ # @package _global_
+
+ data:
+   resolution: 256
+
+ model:
+   vae_type: titok
configs/experiments/txt_only.yaml ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ data:
+   streaming: False
+   unpaired: false
+
+ trainer:
+   img_loss_weight: null
+   text_loss_weight: null
+
+ model:
+   use_pretrained_img_emb: false
+   image_model_fid_eval: false
+   unified_model: false
+   image_model: false
+   txt_length: 256
+   img_length: 0
+
+ eval:
+   log_every_n_evals: -1
+   log_every_n_fid: -1
configs/experiments/unified.yaml ADDED
@@ -0,0 +1,23 @@
+ # @package _global_
+
+ data:
+   zero_shot_eval_dataset: "nlphuji/flickr30k"
+   precache: False
+   tokenizers_parallelism: False # parallelism causes some weird error
+   n_val_samples: 2048
+   block_size: 128
+
+ model:
+   unified_model: True
+   text_model: true
+
+ checkpointing:
+   resume_from_ckpt: True
+   load_from_text_model: "ckpts/unidisc-owt/model.safetensors"
+
+ loader:
+   batch_size: 12
+
+ trainer:
+   val_check_interval: 2000
+   log_seperate_modal_losses: true
configs/experiments/vq16.yaml ADDED
@@ -0,0 +1,9 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 16384
+   vae_type: VQ-16
+   use_custom_vae_ckpt: null
+   custom_vae_name: null
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
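
`img_length` follows directly from the resolution and the VAE's downscale ratio: a 256-pixel image becomes a 16x16 latent grid, i.e. 256 discrete tokens.

    # Python sketch: token count implied by the interpolation above.
    resolution, downscale_ratio = 256, 16
    img_length = (resolution // downscale_ratio) ** 2  # (256 // 16)**2 = 256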
configs/experiments/vq16_1024.yaml ADDED
@@ -0,0 +1,8 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 1024
+   codebook_embed_dim: 256
+   vae_type: VQ-16
+   use_custom_vae_ckpt: ${oc.env:DIFFUSION_DATA_DIR}/ckpts/2024-07-03-01-10-53_022-VQ-16_0042000.pt
configs/experiments/vq16_magvit.yaml ADDED
@@ -0,0 +1,9 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 8192
+   vae_type: magvit
+   use_custom_vae_ckpt: null
+   custom_vae_name: null
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
configs/experiments/vq16_t2i.yaml ADDED
@@ -0,0 +1,10 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 16384
+   vae_type: VQ-16
+   use_custom_vae_ckpt: ${get_repo_dir:}/ckpts/vq_ds16_t2i.pt
+   custom_vae_name: _t2i
+   codebook_embed_dim: 8
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
configs/experiments/webdataset.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ data:
+   train: datacomp1b_indexed
+   valid: ${.train}
+
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: true
+   unpaired: false
+   dataset_type: null
+   tokens_flip_collate: false
configs/experiments/zero_shot_eval.yaml ADDED
@@ -0,0 +1,29 @@
+ # @package _global_
+
+ mode: zero-shot-eval
+
+ data:
+   # train: "nlphuji/flickr30k"
+   train: "facebook/winoground"
+   precache: False
+   tokenizers_parallelism: False # parallelism causes some weird error
+   n_val_samples: 2048
+   block_size: 128
+   disable_text_modality: false
+
+ eval:
+   cfg: 5
+   compute_val_metrics_standalone: false
+   compute_img_to_txt_mauve_clip: false
+
+ loader:
+   batch_size: 16
+   eval_batch_size: 16
+
+ model:
+   unified_model: True
+   text_model: true
+   image_model: true
+   vae_type: magvit
+   force_optimized_native_attn: false
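
`eval.cfg: 5` is a classifier-free guidance scale. The standard CFG combination, which this setting presumably feeds, mixes conditional and unconditional predictions:

    # Python sketch: standard classifier-free guidance mixing with scale w = eval.cfg.
    import torch

    def cfg_logits(cond: torch.Tensor, uncond: torch.Tensor, w: float = 5.0) -> torch.Tensor:
        return uncond + w * (cond - uncond)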
configs/lr_scheduler/constant_warmup.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: transformers.get_constant_schedule_with_warmup
+ num_warmup_steps: 2500
configs/lr_scheduler/constant_warmup_cosine_decay.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: transformers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 2500
+ num_training_steps: 1000000
configs/lr_scheduler/cosine_decay_warmup.yaml ADDED
@@ -0,0 +1,7 @@
+ _target_: utils.CosineDecayWarmupLRScheduler
+ t_in_epochs: False
+ t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+ warmup_prefix: True
+ warmup_lr_init: 1e-6
+ warmup_t: ${eval:0.1*${trainer.max_steps}}
+ lr_min: 1e-6
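
`warmup_t` is 10% of `trainer.max_steps` and `t_initial` is the remainder, so warmup plus decay spans exactly `max_steps`. The repo's `utils.CosineDecayWarmupLRScheduler` is its own class; the sketch below reproduces only the schedule shape implied by these fields:

    # Python sketch of the warmup-then-cosine schedule (shape only, assumed from the config).
    import math

    def lr_at(step, max_steps, base_lr, warmup_lr_init=1e-6, lr_min=1e-6):
        warmup_t = 0.1 * max_steps
        if step < warmup_t:  # linear warmup from warmup_lr_init up to base_lr
            return warmup_lr_init + (base_lr - warmup_lr_init) * step / warmup_t
        t = (step - warmup_t) / (max_steps - warmup_t)  # cosine decay to lr_min
        return lr_min + 0.5 * (base_lr - lr_min) * (1 + math.cos(math.pi * t))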
configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: transformers.get_cosine_with_hard_restarts_schedule_with_warmup
+ num_warmup_steps: 2500
+ num_training_steps: 1000000
+ num_cycles: 1
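
These `_target_` entries name real `transformers` schedulers, so the trainer can presumably build them with `hydra.utils.instantiate`, supplying the optimizer at call time. A sketch under that assumption:

    # Python sketch: instantiating the scheduler config; extra kwargs merge into the call.
    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "_target_": "transformers.get_cosine_with_hard_restarts_schedule_with_warmup",
        "num_warmup_steps": 2500,
        "num_training_steps": 1000000,
        "num_cycles": 1,
    })
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
    scheduler = instantiate(cfg, optimizer=opt)
    opt.step()
    scheduler.step()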
configs/model/extra_large.yaml ADDED
@@ -0,0 +1,10 @@
+ name: extra_large
+ type: ddit
+ hidden_size: 2048
+ cond_dim: 128
+ length: 1024
+ n_blocks: 24
+ n_heads: 16
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/large.yaml ADDED
@@ -0,0 +1,14 @@
+ name: large
+ type: ddit
+ hidden_size: 1280
+ cond_dim: 128
+ length: 1024
+ base_n_blocks: 28
+ # We try to roughly match parameter count
+ n_blocks: ${adjust_n_blocks:}
+ n_heads: 20
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
+
+ # 36 1280 20
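
`${adjust_n_blocks:}` is a custom resolver whose registration and exact rule are not shown in this diff; the trailing comment suggests this model resolves to 36 blocks at width 1280 with 20 heads. A hypothetical sketch of a depth-for-width parameter-matching rule, assuming roughly 12 * hidden^2 parameters per transformer block (both the rule and the reference width below are illustrative, not this repo's code):

    # Python sketch (hypothetical): pick a depth that matches a reference parameter count.
    def adjust_n_blocks(base_n_blocks, base_hidden, hidden):
        target = base_n_blocks * 12 * base_hidden ** 2   # ~params of the reference stack
        return max(1, round(target / (12 * hidden ** 2)))

    # e.g. a 28-block reference at an assumed width of 1440 maps to ~35 blocks at 1280:
    print(adjust_n_blocks(28, 1440, 1280))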
configs/model/medium.yaml ADDED
@@ -0,0 +1,12 @@
+ name: medium
+ type: ddit
+ hidden_size: 1024
+ cond_dim: 128
+ length: 1024
+ base_n_blocks: 24
+ # We try to roughly match parameter count
+ n_blocks: ${adjust_n_blocks:}
+ n_heads: 16
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/small-ar.yaml ADDED
@@ -0,0 +1,11 @@
+ name: small
+ type: ddit
+ hidden_size: 768
+ cond_dim: 128
+ length: 1024
+ n_blocks: 12
+ n_heads: 12
+ scale_by_sigma: True
+ dropout: 0.1
+ causal: True
+ tie_word_embeddings: False
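
Relative to `model/small`, the defining change here is `causal: True`: the AR variant restricts attention to earlier positions instead of attending bidirectionally. A minimal sketch of the implied mask:

    # Python sketch: lower-triangular attention mask implied by `causal: True`.
    import torch

    def causal_mask(seq_len: int) -> torch.Tensor:
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))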