Commit 131da64
Parent(s): Initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- .gitignore +37 -0
- .gitmodules +15 -0
- Dockerfile +79 -0
- README.md +82 -0
- __builtins__.pyi +7 -0
- configs/config.yaml +451 -0
- configs/config_empty.yaml +8 -0
- configs/experiments/ar.yaml +10 -0
- configs/experiments/elm.yaml +15 -0
- configs/experiments/eval_model.yaml +21 -0
- configs/experiments/eval_text.yaml +26 -0
- configs/experiments/eval_text_only.yaml +30 -0
- configs/experiments/eval_unified.yaml +27 -0
- configs/experiments/fid_cc12m.yaml +22 -0
- configs/experiments/fid_datacomp1b.yaml +22 -0
- configs/experiments/fid_hf.yaml +25 -0
- configs/experiments/jan_cub.yaml +51 -0
- configs/experiments/large_maskdit_exp.yaml +7 -0
- configs/experiments/large_scale_high_res_interleaved_inference.yaml +51 -0
- configs/experiments/large_scale_train.yaml +151 -0
- configs/experiments/large_scale_train_high_res.yaml +39 -0
- configs/experiments/large_scale_train_high_res_inference.yaml +30 -0
- configs/experiments/large_scale_train_high_res_interleaved.yaml +105 -0
- configs/experiments/maskgit.yaml +6 -0
- configs/experiments/master_eval.yaml +49 -0
- configs/experiments/mscoco_fid.yaml +21 -0
- configs/experiments/paired_standalone_fid_eval.yaml +29 -0
- configs/experiments/small_scale_train.yaml +187 -0
- configs/experiments/small_scale_train_caching.yaml +186 -0
- configs/experiments/small_text_only.yaml +28 -0
- configs/experiments/standalone_fid_eval.yaml +18 -0
- configs/experiments/titok.yaml +8 -0
- configs/experiments/titok_sl256.yaml +7 -0
- configs/experiments/txt_only.yaml +21 -0
- configs/experiments/unified.yaml +23 -0
- configs/experiments/vq16.yaml +9 -0
- configs/experiments/vq16_1024.yaml +8 -0
- configs/experiments/vq16_magvit.yaml +9 -0
- configs/experiments/vq16_t2i.yaml +10 -0
- configs/experiments/webdataset.yaml +12 -0
- configs/experiments/zero_shot_eval.yaml +29 -0
- configs/lr_scheduler/constant_warmup.yaml +2 -0
- configs/lr_scheduler/constant_warmup_cosine_decay.yaml +3 -0
- configs/lr_scheduler/cosine_decay_warmup.yaml +7 -0
- configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml +4 -0
- configs/model/extra_large.yaml +10 -0
- configs/model/large.yaml +14 -0
- configs/model/medium.yaml +12 -0
- configs/model/small-ar.yaml +11 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.webp filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,37 @@
+__pycache__/
+outputs/
+ckpts/
+vqgan/vqgan_pretrained/
+vqgan/vqgan_taming_ckpt/
+data/
+models/datasets/.cache/
+*.json
+output/
+tmp*
+multirun/
+.nfs*
+lightning_logs/
+static/
+archive/
+output_profile/
+logs/
+.history/
+.cache/
+output*/
+*.out
+*.parquet
+wandb/
+vqgan/
+*.csv
+.python-version
+ft_cache/
+alias.txt
+env.sh
+generated_image.png
+Untitled-1.ipynb
+*.log
+demo/old
+*.pem
+.sesskey
+icons.py
+generated/
.gitmodules
ADDED
@@ -0,0 +1,15 @@
+[submodule "third_party/LlamaGen"]
+	path = third_party/LlamaGen
+	url = https://github.com/alexanderswerdlow/LlamaGen.git
+	branch = wip_v1
+[submodule "third_party/Lumina-mGPT"]
+	path = third_party/Lumina-mGPT
+	url = https://github.com/alexanderswerdlow/Lumina-mGPT.git
+	branch = non_causal
+[submodule "third_party/Show-o"]
+	path = third_party/Show-o
+	url = https://github.com/showlab/Show-o.git
+[submodule "third_party/1d-tokenizer"]
+	path = third_party/1d-tokenizer
+	url = https://github.com/bytedance/1d-tokenizer.git
+	branch = main
Dockerfile
ADDED
@@ -0,0 +1,79 @@
+# Base image with CUDA 12.6.3 and cuDNN
+FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
+
+# Set environment variables
+ARG DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1 \
+    SYSTEM=spaces \
+    AM_I_IN_A_DOCKER_CONTAINER=Yes \
+    PYTHONPATH=/home/appuser/app \
+    HF_HOME=/home/appuser/.cache \
+    TORCH_HOME=/home/appuser/.cache \
+    TMP_DIR=/home/appuser/tmp \
+    TRANSFORMERS_CACHE=/home/appuser/.cache/transformers \
+    NVIDIA_VISIBLE_DEVICES=all \
+    NVIDIA_DRIVER_CAPABILITIES=compute,utility
+
+# Install system dependencies and set Python 3.10 as default
+RUN apt-get update && apt-get install --no-install-recommends -y \
+    build-essential \
+    python3.10 \
+    python3.10-distutils \
+    python3-pip \
+    ffmpeg \
+    libsm6 \
+    libxext6 \
+    libgl1 \
+    git \
+    openssh-client \
+    && ln -sf /usr/bin/python3.10 /usr/bin/python \
+    && ln -sf /usr/bin/pip3 /usr/bin/pip \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+# Install `uv`
+RUN pip install --upgrade pip \
+    && pip install uv
+
+# Create a non-root user
+RUN useradd -m -u 1000 appuser
+
+# Set working directory
+WORKDIR /home/appuser/app
+
+# Copy dependency files and install dependencies
+COPY --chown=appuser pyproject.toml uv.lock README.md ./
+RUN mkdir -p -m 0600 ~/.ssh && ssh-keyscan github.com >> ~/.ssh/known_hosts
+
+RUN --mount=type=ssh uv sync --no-group dev
+RUN --mount=type=ssh uv sync --frozen --no-cache \
+    && chown -R appuser:appuser /home/appuser/app/.venv \
+    && rm -rf /root/.cache /home/appuser/.cache
+
+# Ensure non-root user has write access to cache and tmp directories
+RUN mkdir -p /home/appuser/.cache/transformers /home/appuser/tmp /home/appuser/.cache \
+    && chown -R appuser:appuser /home/appuser/.cache /home/appuser/tmp/ /home/appuser/app/
+
+RUN chmod -R 777 /tmp
+
+# Copy application code
+COPY --chown=appuser demo demo
+COPY --chown=appuser unidisc unidisc
+COPY --chown=appuser models models
+COPY --chown=appuser configs configs
+COPY --chown=appuser third_party third_party
+COPY --chown=appuser ckpts ckpts
+COPY --chown=appuser ./__* ./
+COPY --chown=appuser ./*.py ./
+COPY --chown=appuser ./archive/pytorch_model_fsdp.bin ./
+
+# Switch to non-root user
+USER appuser
+
+# Expose port for Gradio
+EXPOSE 5003
+
+# Command to run the application
+CMD ["bash", "demo/demo.sh"]
+
+# DOCKER_BUILDKIT=1 docker build --ssh default --network=host -t unidisc .
+# docker run --network=host -it -p 5003:5003 unidisc
README.md
ADDED
@@ -0,0 +1,82 @@
+<div align="center">
+<br>
+<img src="docs/images/banner.webp" width="1000">
+<h3>Unified Multimodal Discrete Diffusion</h3>
+
+[Alexander Swerdlow](https://aswerdlow.com/)<sup>1*</sup>
+[Mihir Prabhudesai](https://mihirp1998.github.io/)<sup>1*</sup>
+[Siddharth Gandhi](https://www.ssgandhi.com/)<sup>1</sup>
+[Deepak Pathak](https://www.cs.cmu.edu/~dpathak/)<sup>1</sup>
+[Katerina Fragkiadaki](https://www.cs.cmu.edu/~katef/)<sup>1</sup>
+<br>
+
+<sup>1</sup> Carnegie Mellon University
+
+[](https://arxiv.org/pdf/0000.00000) [](https://unidisc.github.io/)
+
+<!-- [](https://huggingface.co/spaces/todo) -->
+
+</div>
+
+## Hugging Face models and annotations
+
+The UniDisc checkpoints are available on [Hugging Face](https://huggingface.co/unidisc):
+* [unidisc/todo](https://huggingface.co/unidisc/todo)
+
+## Getting Started
+
+To install the dependencies, run:
+```bash
+git submodule update --init --recursive
+uv sync --no-group dev
+uv sync
+```
+
+For a more detailed installation guide, please refer to [INSTALL.md](docs/INSTALL.md).
+
+## Training
+
+See [TRAIN.md](docs/TRAIN.md) for details.
+
+## Inference
+
+<!-- Inference demo for **TODO**.
+```
+TODO
+``` -->
+<!-- <img src="docs/todo.png" width="1000"> -->
+
+Interactive demo for **TODO**.
+```
+python demo/server.py
+python demo/client_simple_fasthtml.py
+```
+
+## Evaluation
+
+See [EVAL.md](docs/EVAL.md) for details.
+
+### Citation
+To cite our work, please use the following:
+```
+@article{TODO,
+  title={TODO},
+  author={TODO},
+  journal={arXiv preprint arXiv:TODO},
+  year={TODO}
+}
+```
+
+## Credits
+
+This repository is built on top of the following repositories:
+
+- [MDLM](https://github.com/kuleshov-group/mdlm)
+- [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X)
__builtins__.pyi
ADDED
@@ -0,0 +1,7 @@
+from ipdb import set_trace as st
+from decoupled_utils import start_timing as start_timing
+from decoupled_utils import end_timing as end_timing
+ENABLE_TIMING: bool
+ENABLE_TIMING_SYNC: bool
+DEVICE_BACKEND_TYPE: str
+exists = lambda v: v is not None
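The stub above only declares types for globals that appear to be injected at process startup; the snippet below is a minimal, hypothetical sketch of such injection via Python's `builtins` module (the default values and the injection site are assumptions, not taken from this commit):

```python
# Hypothetical sketch of how the globals declared in __builtins__.pyi could be
# injected so every module sees them without importing; the .pyi stub itself
# only informs the type checker. Default values here are assumptions.
import builtins

from ipdb import set_trace
from decoupled_utils import start_timing, end_timing  # assumed importable per the stub

builtins.st = set_trace                  # breakpoint shorthand, used as st()
builtins.start_timing = start_timing
builtins.end_timing = end_timing
builtins.ENABLE_TIMING = False
builtins.ENABLE_TIMING_SYNC = False
builtins.DEVICE_BACKEND_TYPE = "cuda"
builtins.exists = lambda v: v is not None
```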
configs/config.yaml
ADDED
@@ -0,0 +1,451 @@
+defaults:
+  - _self_
+  - /model: small
+  - /noise: loglinear
+  - /lr_scheduler: constant_warmup
+  - /experiments: []
+  # - override hydra/launcher: submitit_slurm
+
+slurm: False
+debug: False
+mode: train # train / eval
+diffusion: absorbing_state
+backbone: dit # dit / dimamba / ar
+parameterization: subs # subs / d3pm / sedd
+time_conditioning: False
+T: 0 # 0 (continuous time) / 1000
+subs_masking: False
+seed: 42
+profile: False
+# These belong in trainer.* and hydra.launcher.* but are put here for CLI convenience
+devices: ${device_count:}
+nodes: 1
+partition: ${find_partition:}
+constraint: ${find_constraint:}
+ckpt: null
+
+loader:
+  desired_global_batch_size: 512
+  global_batch_size: null
+  eval_global_batch_size: ${.global_batch_size}
+  batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  eval_batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  num_workers: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 16, 4)"}
+  pin_memory: True
+  persistent_workers: True
+
+sampling:
+  predictor: ddpm_cache # analytic, ddpm, ddpm_cache
+  steps: 1000
+  max_sampling_steps: 500 # The highest level we use for sampling
+  noise_removal: True
+  num_sample_log: 2
+  semi_ar: False
+  stride_length: 1
+  num_strides: 1
+
+eval:
+  checkpoint_path: '' # Used to evaluate a checkpoint after training.
+  disable_ema: False
+  compute_generative_perplexity: False
+  perplexity_batch_size: 8
+  gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
+  generate_samples: True
+  cfg: null
+  num_masking_viz_batches: 1
+  num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+  test_eval_speed: False
+  standalone_fid: False
+  visualize_data_only: false
+  val_with_train_data: false
+  max_num_fid_batches_per_device: null
+  class_conditional_fid: false
+  compute_entropy: false
+  compute_standalone_mauve: false
+  compute_standalone_entropy: false
+  compute_img_to_txt_mauve_clip: false
+  compute_img_to_txt_mauve_during_unconditional_fid: false
+  mauve_num_samples: 5000
+  mauve_divergence_curve_discretization_size: 25 # default in mauve repo
+  mauve_average_over_seeds: 3
+  mauve_scaling_factor: 5 # default in mauve repo
+  txt_conditional_fid: false
+  unconditional_fid: false
+  fid_mode: inline
+  calculate_clip_score: false
+  clean_fid_use_precomputed_stats: false
+  clean_fid_precomputed_name: null
+  clean_fid_precomputed_split: null
+  clean_fid_precomputed_res: null
+  attention_caching: false
+  set_random_gen_seed: false
+  compute_val_metrics_standalone: false
+  num_val_metrics_standalone_batches_per_device: ${eval:'max(${eval.num_val_metrics_standalone_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+  num_val_metrics_standalone_samples: -1
+  return_unweighed_sim: false
+  compute_chameleon_perplexity: false
+  global_disable_mauve: false
+  bypass_normal_validation: false
+  auto_enhance: false
+  num_auto_enhance_iter: 2
+  ar_inpainting_min_val: 0.5
+  ar_inpainting_max_val: 1.0
+  ar_inpainting_force_val: null
+
+optim:
+  weight_decay: 0
+  lr: 3e-4
+  beta1: 0.9
+  beta2: 0.999
+  eps: 1e-8
+  fused: true
+
+model:
+  use_custom_vae_config: false
+  use_custom_vae_ckpt: null
+  downscale_ratio: null
+  image_vocab_size: null
+  vae_type: null
+  use_attention_mask: false
+
+  cond_use_custom_vae_config: false
+  cond_use_custom_vae_ckpt: null
+  cond_downscale_ratio: null
+  cond_image_vocab_size: null
+  cond_vae_type: null
+  text_model: true
+
+  attn_type: flash
+  force_varlen_attn: false
+  force_cast_bf16: false
+  norm_type: layernorm
+  mup: false
+  qk_norm: false
+  distillation: false
+  force_argmax_valid_indices: false
+  use_flash_attn_3: false
+  use_spda_attn: false # Spelled wrong...
+  rope_2d: false
+  modality_embed: false
+  zero_linear_init: true
+  full_attention: true
+  use_lora: false
+  use_kv_cache: false
+  force_optimized_native_attn: false
+  use_pretrained_img_emb: true
+  use_flex_attention: false
+  add_labels: null
+  flex_attention_txt_masking_prob: null
+  flex_attention_img_masking_prob: null
+
+trainer:
+  _target_: lightning.Trainer
+  accelerator: cuda
+  num_nodes: ${nodes}
+  devices: ${devices}
+
+  # Given a desired global batch size (i.e., how many samples we see before an optim.step, summed over all nodes/gpus/accum_steps), we find the number of gradient accumulations that gets us closest given our current configuration. We assume that loader.batch_size is the largest that can fit in a single fwd/bwd.
+  accumulate_grad_batches: ${find_grad_accum:${loader.desired_global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+  gradient_clip_val: 1.0
+  precision: 'bf16'
+  max_steps: 1_000_000_000
+
+  num_epochs: 1_000_000_000
+  optimizer_cls: adamw
+  set_grads_to_none: true
+  eval_on_start: true
+  eval_decay_steps: false
+  eval_epochs: null
+  ckpt_steps: 100000
+  fsdp: false
+  force_enable_checkpointing: false
+  limit_val_batches: null
+  ckpt_every_n_minutes: 60
+  ckpt_recent_timeout_minutes: 10
+  checkpoint_all_ranks: true
+  force_null_sigma: false
+
+  log_every_n_steps: 10
+  limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
+  val_check_interval: 100
+
+  ema: 0.9999
+  antithetic_sampling: True
+  importance_sampling: False
+  sampling_eps: 1e-3
+  change_of_variables: False
+  benchmark: true
+  backward_pass: true
+  forward_pass: true
+  profile_memory: false
+  pytorch_profile: false
+  nvtx_profile: false
+  custom_ddp_bf16: true
+  log_seperate_modal_losses: true
+  use_gradient_checkpointing: false
+  text_loss_weight: null
+  img_loss_weight: null
+  disable_strict_load: false
+  attach_oom_observer_eval: false
+  find_unused_parameters: false
+  restart_on_failure: false
+  skip_early_checkpointing: true
+  log_flops: true
+  sync_timing: false
+  use_custom_ema: false
+  scale_lr_by_batch_size: false
+  tpu_eager: false
+  allow_dynamic_nodes: false
+  force_disable_signal_handler: false
+  tpu_profile: false
+  tpu_cache: false
+  enable_jax_smi: false
+  tpu_compile_debug: false
+  xla_spmd: false
+  log_grad_norm: true
+  tpu_profile_markers: true
+  compile: false
+  disable_all_checkpointing: false
+  tpu_force_mark_step: false
+  ar_shift: false
+  ar_llm_loss: false
+  ar_print_loss: false
+  chameleon_z_loss: null
+  image_mode: discrete # continuous / discrete
+  chameleon_use_ce_loss: false
+  low_precision_loss: false
+  low_precision_params: false
+  scratch: false
+  use_spmd_distributed_checkpointing: null
+  use_simple_spmd_distributed_checkpointing: false
+  load_from_state_dict: null
+  load_from_optimizer_state_dict: null
+  multimodal_batches: false
+  sync_dataloader_timing: false
+  compile_flag_pos_emb: false
+  compile_fullgraph: false
+  compile_mode: max-autotune-no-cudagraphs
+  joint_ar_nar_prob: null
+  joint_ar_nar_prob_warmup_steps: null
+  joint_ar_nar_timestep_warmup_steps: null
+  spmd_mesh: null
+  detect_anomaly: false
+  freeze_chameleon_embeddings: false
+  ckpt_model_only: false
+  use_orig_params: null
+  disable_adjust_num_warmup_steps: false
+  mask_entire_modality: null
+  iterate_dataloader_only: false
+  force_bf16_eval: false
+  disable_all_eval_generation: false
+  debug_xla_sept: false
+  ignore_text_in_unified: false
+  allow_null_sigma: false
+  disable_forward_autocast_during_eval: false
+  viz_images_only: false
+  add_label: false
+  first_token_dropout: null
+  disable_ddp_optimizer: false
+  rand_flip_ar_prob: null
+  rand_ar_modality_dropout: null
+  use_linear_warmup_cosine_annealing: false
+  no_ce_weighting: false
+  interleaved: false
+  interleaved_training_flex_attention: false
+  awr: false
+  ar_inpainting: false
+
+wandb:
+  entity: grads
+  project: ${eval:'"unidisc-debug" if ${debug} else "unidisc"'}
+  resume: ${eval:'"allow" if ${slurm} else None'}
+  id: null
+  group: null
+  job_type: null
+  name: null
+  tags:
+    - ${data.train}
+
+checkpointing_root_dir: ${oc.env:UNIDISC_CHECKPOINTING_ROOT_DIR,null}
+root_output_dir: ${oc.env:UNIDISC_ROOT_OUTPUT_DIR,outputs}
+python_orig: |
+  accelerate launch \
+  --num_machines $SLURM_NNODES \
+  --num_processes $NUM_PROCESSES \
+  --rdzv_backend c10d \
+  --main_process_ip $MASTER_ADDR \
+  --main_process_port $MASTER_PORT \
+  --machine_rank $SLURM_PROCID \
+  --mixed_precision bf16 \
+  --dynamo_backend no \
+  --enable_cpu_affinity \
+  --max_restarts 0 \
+
+mem_per_gpu: 40
+cpus_per_gpu: 8
+slurm_name: null
+timeout_min: ${partition_limit:${partition}}
+hydra:
+  run:
+    dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
+  sweep:
+    dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
+    subdir: ${hydra.job.id}
+  job:
+    chdir: true
+  # launcher:
+  #   name: ${get_slurm_name:}
+  #   # See https://hydra.cc/docs/configure_hydra/workdir/
+  #   submitit_folder: ${hydra.sweep.dir}/%j
+  #   nodes: ${nodes} # Number of nodes. This value is *per* node
+  #   mem_gb: ${eval:'${mem_per_gpu} * ${trainer.devices}'} # 40GB per gpu. This value is *per* node
+  #   gpus_per_node: ${trainer.devices}
+  #   partition: ${partition}
+  #   constraint: ${constraint}
+  #   exclude: ${exclude_nodes:}
+
+  #   timeout_min: ${timeout_min}
+  #   max_num_timeout: 12 # Num requeue excluding pre-emptions
+  #   comment: aswerdlo
+  #   stderr_to_stdout: true
+
+  #   # Be careful with changing anything below.
+  #   # see: https://github.com/stas00/ml-engineering/tree/master/training/fault-tolerance#approach-b2-choosing-which-process-to-send-the-signal-to
+  #   # see: https://github.com/huggingface/accelerate/issues/1918
+
+  #   # The accelerate launcher w/1 initial process and then spawn 1 per GPU
+  #   tasks_per_node: 1
+  #   cpus_per_task: ${eval:'${cpus_per_gpu} * ${trainer.devices}'}
+  #   python: |
+  #     bash -c "torchrun --nnodes $SLURM_NNODES --nproc_per_node $SLURM_GPUS_PER_NODE --role \$(hostname -s|tr -dc '0-9'): --node_rank \$SLURM_PROCID --max-restarts=2 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+
+  #   # python: "${getpythoncmd:}"
+  #   # tasks_per_node: ${devices}
+  #   # cpus_per_task: 8
+  #   # python: 'python'
+
+  #   python_suffix: ' --dummy-arg $SLURM_JOB_ID" &'
+  #   signal: 'B:USR2@360'
+  #   post_srun_commands:
+  #   - ''
+  #   - wait
+
+  #   srun_args:
+  #   - '--jobid $SLURM_JOB_ID'
+
+  #   setup:
+  #   - |
+  #     export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+  #     export MASTER_PORT=$(( ($SLURM_JOB_ID % 20001) + 30000 ))
+  #     export NUM_PROCESSES=$((SLURM_NNODES * SLURM_GPUS_PER_NODE))
+  #     export NCCL_DEBUG=INFO
+  #     export NCCL_NSOCKS_PERTHREAD=4
+  #     export NCCL_SOCKET_NTHREADS=2
+  #     export OMP_NUM_THREADS=2
+  #     export PYTHONUNBUFFERED=1
+  #     export STDOUT_PATH=$(scontrol show job $SLURM_JOB_ID | grep -oP "StdOut=\K[^ ]+")
+  #     export LOCAL_JOB_FOLDER=$(dirname $STDOUT_PATH)
+  #     export NCCL_TOPO_DUMP_FILE="$LOCAL_JOB_FOLDER/nccl_topo.xml"
+  #     if [ -n "$SLURM_RESTART_COUNT" ]; then
+  #       export RESTART_COUNT=$SLURM_RESTART_COUNT
+  #     else
+  #       export RESTART_COUNT=0
+  #     fi
+  #     export MAIN_LOG_PATH="$LOCAL_JOB_FOLDER/log_$RESTART_COUNT.txt"
+
+  #     mkdir -p $LOCAL_JOB_FOLDER
+  #     printenv > "$LOCAL_JOB_FOLDER"/env_"$SLURM_LOCALID_$RESTART_COUNT.txt"
+
+  #     echo "ibstatus: $(ibstatus)"
+  #     echo "ibdev2netdev: $(ibdev2netdev)"
+  #     echo "rdma device: $(rdma link)"
+  #     echo "environment: $(env | grep NCCL)"
+  #     echo "NUM_PROCESSES: $NUM_PROCESSES, SLURM_NNODES: $SLURM_NNODES SLURM_GPUS_PER_NODE: $SLURM_GPUS_PER_NODE"
+  #     echo "NODE_ID: $SLURM_NODEID, SLURM_PROCID: $SLURM_PROCID, MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT"
+  #     echo "PWD: $PWD, LOCAL_JOB_FOLDER: $LOCAL_JOB_FOLDER, MAIN_LOG_PATH: $MAIN_LOG_PATH"
+
+  #     trap 'echo "SIGUSR2 received for $SLURM_JOB_ID"; \
+  #       if [ -n "$SLURM_ARRAY_JOB_ID" ]; then echo "SLURM_ARRAY_JOB_ID: $SLURM_ARRAY_JOB_ID"; fi; \
+  #       if [ -n "$SLURM_ARRAY_TASK_ID" ]; then echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"; fi; \
+  #       # ps auxww | grep $USER; \
+  #       pid=$(pgrep -u $USER -f "python.*(accelerate|torchrun|deepspeed|distributed\.run).*dummy-arg $SLURM_JOB_ID"); \
+  #       echo "Found parent PIDs: $pid"; \
+  #       for p in $pid; do \
+  #         echo "Parent PID has cmd: $(ps -p $p -o cmd=)"; \
+  #         children=$(pgrep -P $p); \
+  #         echo "Children: $children"; \
+  #         if [ -n "$children" ]; then \
+  #           for child in $children; do \
+  #             ppid=$(ps -o ppid= -p $child | tr -d " ")
+  #             if [ "$ppid" -eq "$p" ]; then
+  #               echo "Killing direct child process: PID $child with cmd: $(ps -p $child -o cmd=)"
+  #               kill -USR2 $child &
+  #             else
+  #               echo "Skipping non-direct child process: PID $child with PPID $ppid"
+  #             fi
+  #           done; \
+  #           echo "Sent kill signals to children of $p"; \
+  #         else \
+  #           echo "No children found for $p"; \
+  #         fi; \
+  #       done; \
+  #       wait;' SIGUSR2
+
+checkpointing:
+  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+  save_dir: ${cwd:}/checkpoints
+  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+  resume_from_ckpt: true
+  resume_ckpt_path: ${cwd:}/checkpoints
+  initial_resume_ckpt_path: null
+  resume_wandb: true
+  checkpoints_total_limit: 2
+  use_automatic_naming: false
+
+
+data:
+  cache_dir: ${oc.env:HF_DATASETS_CACHE,/grogu/user/mprabhud/aswerdlo/huggingface/datasets}
+  num_proc: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 4, 16)"}
+  cond_resolution: null
+  iterable: false
+  force_disable_shuffle: false
+  pin_dataset_to_gpu: false
+  webdataset_iterable: false
+  webdataset_train_data: null
+  webdataset_val_data: null
+  webdataset_train_num_samples: null
+  webdataset_val_num_samples: null
+  webdataset_indexed: false
+  dataset_type: null
+  keep_tensordict_on_disk: false
+  use_token_dataset: false
+  use_custom_tensordict_collate: false
+  use_weighted_tensordict_sampler: false
+  enable_cuda_in_tensordict_collate: true
+  data_dir_train: null
+  data_dir_val: null
+  token_output_dir: null
+  wrap_dataloaders: true
+  force_shuffle_train: false
+  move_tensordict_to_shm: false
+  keep_hf_dataset_in_memory: false
+  use_chameleon: false
+  tokenize_vqvae_in_dataloader: false
+  force_mp_spawn: false
+  force_raw_images_in_multiple_tensordict: false
+  disable_text_modality: false
+  txt_only: false
+  disable_mask_after_eos: false
+  allow_label: false
+  split_dataset: false
+  img_token_shift: ${model.text_vocab_size}
+  zero_shot_eval_dataset: null
+  require_sample_ids: false
+  use_packing_collate: false
+  dynamic_packing_lengths: false
+  remove_txt_img_padding: false
+  add_image_gen_tokens: false
+  use_slow_tokenizer: false
+  add_image_token: false
+
+dummyarg: null
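`div_up` and `find_grad_accum` above are custom OmegaConf resolvers registered elsewhere in the repo; the sketch below shows the arithmetic the inline comment describes, under the assumption that the resolvers behave as their names suggest (this is not the repo's actual implementation):

```python
# Minimal sketch of the derived batch-size fields in configs/config.yaml,
# assuming resolvers equivalent to these are registered at startup.
from omegaconf import OmegaConf

# loader.batch_size = ceil(desired_global_batch_size / (devices * num_nodes))
OmegaConf.register_new_resolver("div_up", lambda a, b: (int(a) + int(b) - 1) // int(b))

# trainer.accumulate_grad_batches: grad-accumulation count so that
# devices * batch_size * num_nodes * accum best matches the desired global batch.
OmegaConf.register_new_resolver(
    "find_grad_accum",
    lambda target, per_step: max(round(int(target) / int(per_step)), 1),
)

# Worked example: desired_global_batch_size=512 on 8 GPUs, 1 node:
#   batch_size = div_up(512, 8) = 64; per_step = 8 * 64 * 1 = 512
#   accumulate_grad_batches = find_grad_accum(512, 512) = 1 (no accumulation)
```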
configs/config_empty.yaml
ADDED
@@ -0,0 +1,8 @@
+defaults:
+  - _self_
+  - /model: small
+  - /experiments: []
+
+# from omegaconf import OmegaConf
+# with open("config.yaml", "w") as fp:
+#     OmegaConf.save(config=config, f=fp.name)
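The commented hint at the end of `config_empty.yaml` sketches how a composed config can be dumped; a runnable version (with a stand-in `config`, since the real one is composed by Hydra) might look like:

```python
# Round-trip a config to YAML with OmegaConf, mirroring the commented hint in
# configs/config_empty.yaml. `config` here is a stand-in DictConfig.
from omegaconf import OmegaConf

config = OmegaConf.create({"model": "small", "experiments": []})
with open("config.yaml", "w") as fp:
    OmegaConf.save(config=config, f=fp.name)

reloaded = OmegaConf.load("config.yaml")
assert reloaded.model == "small"
```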
configs/experiments/ar.yaml
ADDED
@@ -0,0 +1,10 @@
+# @package _global_
+
+parameterization: ar
+
+trainer:
+  ar_shift: true
+
+model:
+  full_attention: false
+  use_flex_attention: false
configs/experiments/elm.yaml
ADDED
@@ -0,0 +1,15 @@
+# @package _global_
+
+backbone: elm
+
+data:
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+
+model:
+  use_lora: false
+  full_attention: true
+  model_id: apple/OpenELM-270M # apple/OpenELM-1_1B
+
+trainer:
+  use_gradient_checkpointing: false
+  sd3_compile_config: false
configs/experiments/eval_model.yaml
ADDED
@@ -0,0 +1,21 @@
+# @package _global_
+
+mode: eval
+
+loader:
+  batch_size: 16
+  eval_batch_size: 16
+
+trainer:
+  disable_all_eval_generation: false
+
+eval:
+  compute_generative_perplexity: true
+  generate_samples: true
+  num_sample_batches: 20
+  log_every_n_fid: 1
+  log_every_n_evals: 1
+  compute_standalone_mauve: true
+  mauve_num_samples: 5000
+  # mauve_divergence_curve_discretization_size: 200 # works well for our repo
+  # mauve_scaling_factor: 2 # works well for our repo
configs/experiments/eval_text.yaml
ADDED
@@ -0,0 +1,26 @@
+# @package _global_
+
+mode: eval
+
+sampling:
+  steps: 100
+  max_sampling_steps: 100
+
+loader:
+  batch_size: 2
+  eval_batch_size: 2
+
+trainer:
+  fsdp: false
+
+eval:
+  perplexity_batch_size: 2
+  num_masking_viz_batches: 2
+  log_every_n_evals: 1
+  num_uncond_sample_batches: 2
+  num_sample_batches: 2
+  num_random_masking: 1
+  masking_batch_size: 2
+  cfg: null
+  generate_samples: true
+  compute_generative_perplexity: false
configs/experiments/eval_text_only.yaml
ADDED
@@ -0,0 +1,30 @@
+# @package _global_
+
+mode: eval
+debug: true
+
+sampling:
+  steps: 100
+  max_sampling_steps: 100
+
+loader:
+  batch_size: 2
+  eval_batch_size: 2
+
+trainer:
+  fsdp: false
+
+model:
+  image_model_fid_eval: false
+
+eval:
+  log_every_n_evals: 1
+  perplexity_batch_size: 2
+  num_uncond_sample_batches: 2
+  num_sample_batches: 2
+  num_masking_viz_batches: -1
+  num_random_masking: -1
+  masking_batch_size: -1
+  cfg: null
+  generate_samples: true
+  compute_generative_perplexity: true
configs/experiments/eval_unified.yaml
ADDED
@@ -0,0 +1,27 @@
+# @package _global_
+
+mode: eval
+devices: ${device_count:}
+
+sampling:
+  steps: 500
+  max_sampling_steps: 1000
+
+loader:
+  batch_size: 6
+  eval_batch_size: 6
+
+trainer:
+  fsdp: false
+  disable_all_eval_generation: false
+
+eval:
+  perplexity_batch_size: 6
+  num_masking_viz_batches: 12
+  log_every_n_evals: 1
+  num_uncond_sample_batches: 5
+  num_sample_batches: 2
+  num_random_masking: 3
+  masking_batch_size: 6
+  cfg: 6.0
+  generate_samples: false
configs/experiments/fid_cc12m.yaml
ADDED
@@ -0,0 +1,22 @@
+# @package _global_
+
+data:
+  keep_hf_dataset_in_memory: true
+  aggressive_aug: false
+  n_duplicate_train: null
+  n_duplicate_val: null
+
+  tokenize_vqvae_in_dataloader: false
+  enable_cuda_in_tensordict_collate: false
+  force_mp_spawn: false
+  keep_tensordict_on_disk: false
+  move_tensordict_to_shm: false
+
+  fid_dataset: cc12m_tokens_val_256
+  image_data_train: null
+  image_data_val: null
+  data_dir_train: ${data.data_dir_val}
+  data_dir_val:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+      weight: 1
+      name: ${data.fid_dataset}
configs/experiments/fid_datacomp1b.yaml
ADDED
@@ -0,0 +1,22 @@
+# @package _global_
+
+data:
+  keep_hf_dataset_in_memory: true
+  aggressive_aug: false
+  n_duplicate_train: null
+  n_duplicate_val: null
+
+  tokenize_vqvae_in_dataloader: false
+  enable_cuda_in_tensordict_collate: false
+  force_mp_spawn: false
+  keep_tensordict_on_disk: false
+  move_tensordict_to_shm: false
+
+  fid_dataset: datacomp1b_8_magvit_val
+  image_data_train: null
+  image_data_val: null
+  data_dir_train: ${data.data_dir_val}
+  data_dir_val:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+      weight: -1
+      name: ${data.fid_dataset}
configs/experiments/fid_hf.yaml
ADDED
@@ -0,0 +1,25 @@
+# @package _global_
+
+data:
+  disable_text_modality: false
+  keep_hf_dataset_in_memory: true
+  aggressive_aug: false
+  n_duplicate_train: null
+  n_duplicate_val: null
+  data_dir_train: []
+  data_dir_val: []
+  fid_dataset: sayakpaul/coco-30-val-2014
+  train: combined_tokens
+  val: ${.train}
+  image_data_val:
+    - val: ${data.fid_dataset}
+      weight: -1
+      name: ${.val}
+      tokenize_vqvae_in_dataloader: false
+      raw_images: true
+  image_data_train:
+    - train: ${data.fid_dataset}
+      weight: -1
+      name: ${.train}
+      tokenize_vqvae_in_dataloader: false
+      raw_images: true
configs/experiments/jan_cub.yaml
ADDED
@@ -0,0 +1,51 @@
+# @package _global_
+
+defaults:
+  - override /model: medium
+  - override /lr_scheduler: cosine_with_hard_restarts_schedule_with_warmup
+
+loader:
+  batch_size: 16
+  eval_batch_size: 16
+  desired_global_batch_size: 128
+  num_workers: 4
+
+trainer:
+  ckpt_steps: 5000
+  val_check_interval: 100
+  use_legacy_update_batch_fn: true
+  mask_txt_only: true
+  mask_entire_modality: 0.15
+  ema: 0.9999
+  use_custom_ema: true
+  force_enable_checkpointing: true
+  skip_early_checkpointing: false
+  force_after_eos_padding: false
+
+checkpointing:
+  checkpoints_total_limit: 20
+
+lr_scheduler:
+  num_warmup_steps: 10000
+  num_training_steps: 400000
+  num_cycles: 80
+
+data:
+  resolution: 256
+  train: cub2011_custom
+  use_weighted_tensordict_sampler: false
+
+model:
+  vae_type: titok128
+  txt_length: 18
+  img_length: 128
+  rope_2d: false
+  force_text_vocab_size: 5450
+  text_vocab_size: 5451
+  image_vocab_size: 8192
+  attn_dropout: 0.1
+
+optim:
+  lr: 1.0e-04
+  weight_decay: 0.2
+  beta2: 0.99
configs/experiments/large_maskdit_exp.yaml
ADDED
@@ -0,0 +1,7 @@
+# @package _global_
+
+defaults:
+  - override /model: large_maskdit
+
+
+backbone: maskdit
configs/experiments/large_scale_high_res_interleaved_inference.yaml
ADDED
@@ -0,0 +1,51 @@
+# @package _global_
+
+debug: true
+seed: 163
+
+loader:
+  eval_batch_size: 1
+  batch_size: 1
+
+data:
+  move_tensordict_to_shm: false
+  resolution: 1024
+  disable_mask_after_eos: true
+  disable_packing: true
+  data_dir_val:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+      weight: 1.0
+      name: HPDv2_image_reward_512
+
+model:
+  img_length: 4096
+  txt_length: 1024
+  length: 5120
+
+trainer:
+  compile: false
+  limit_val_batches: 2
+  fsdp: false
+  force_full_attention_mask: true
+  force_null_sigma: true
+  allow_null_sigma: true
+
+eval:
+  num_sample_batches: 1
+  num_random_masking: 0
+  num_masking_viz_batches: 0
+  limit_val_batches_manual: 1
+  num_uncond_sample_batches: 10
+  eval_large_batch: 10
+  val_with_train_data: false
+  maskgit_r_temp: 4.5
+  half_uncond: false
+  cfg: 3.0
+  return_interleaved_modalities_split: true
+  static_img_txt_demo: true
+  visualize_sample: true
+
+sampling:
+  steps: 50
+  max_sampling_steps: 50
+  predictor: "maskgit"
configs/experiments/large_scale_train.yaml
ADDED
@@ -0,0 +1,151 @@
+# @package _global_
+
+defaults:
+  - vq16_t2i
+  - override /model: extra_large
+
+data:
+  train: combined_tokens
+  valid: ${.train}
+  precache: false
+  streaming: false
+  resolution: 256
+  block_size: 128
+  tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+  wrap: true
+  iterable: false
+  webdataset_iterable: false
+  webdataset_indexed: false
+  unpaired: false
+  dataset_type: null
+  tokens_flip_collate: false
+  n_val_samples: null
+  n_train_samples: null
+  n_duplicate_train: null
+  n_duplicate_val: null
+  raw_data_dir: null
+  save_train_dataloader: true
+  save_validation_dataloader: true
+  tokenizers_parallelism: false
+  token_data_dir: null
+  force_disable_shuffle: false
+  use_custom_tensordict_collate: true
+  use_weighted_tensordict_sampler: true
+  force_mp_spawn: false
+  enable_cuda_in_tensordict_collate: false
+  use_token_dataset: true
+  keep_tensordict_on_disk: true
+  move_tensordict_to_shm: false
+  add_text_to_weighted_sampler: false
+  data_dir_train:
+    # - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+    #   weight: 15.0
+    #   name: hpdv2
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+      weight: 1.0
+      name: pixelprose
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/journeydb_train
+      weight: 10.0
+      name: journeydb_train
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
+      weight: 1.0
+      name: datacomp0
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
+      weight: 1.0
+      name: datacomp1
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
+      weight: 1.0
+      name: datacomp2
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_3_tokens
+      weight: 1.0
+      name: datacomp3
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
+      weight: 1.0
+      name: datacomp4
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
+      weight: 1.0
+      name: datacomp5
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_6_tokens
+      weight: 1.0
+      name: datacomp6
+  data_dir_val:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+      weight: 1.0
+      name: dummy_1
+
+model:
+  img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
+  txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
+  length: ${eval:'${.txt_length} + ${.img_length}'}
+  unified_model: true
+  image_model: true
+  text_model: true
+  image_model_fid_eval: false
+  force_argmax_valid_indices: true
+  use_pretrained_img_emb: false
+  rope_2d: true
+  modality_embed: true
+  norm_type: rms
+  qk_norm: true
+  sandwich_normalization: true
+  text_vocab_size: 32001
+
+loader:
+  batch_size: 8
+  eval_batch_size: ${eval:'${.batch_size} // 2'}
+  desired_global_batch_size: 512
+  persistent_workers: true
+  pin_memory: false
+  num_workers: 0
+  num_eval_workers: 0
+eval:
+  log_every_n_evals: -1
+  log_every_n_fid: -1
+  limit_val_batches_manual: 16
+  generate_samples: true
+  compute_generative_perplexity: false
+  perplexity_batch_size: ${loader.eval_batch_size}
+  cfg: 5.0
+  num_val_metrics_standalone_samples: -1
+  num_val_metrics_standalone_batches_per_device: -1
+  auto_enhance_reward_config:
+    dfn_score: 1.0
+    laion_aesthetic_score: 1.0
+
+trainer:
+  log_flops: false
+  log_every_n_steps: 10
+  custom_ddp_bf16: true
+  log_seperate_modal_losses: true
+  limit_val_batches: 16
+  softmin_snr: 5
+  text_loss_weight: 1.0
+  img_loss_weight: 0.6
+  use_gradient_checkpointing: false
+  ckpt_steps: 20000
+  ckpt_every_n_minutes: 180
+  ckpt_recent_timeout_minutes: 10
+  use_custom_ema: false
+  ema: 0.0
+  fsdp: true
+  restart_on_failure: true
+  eval_on_start: false
+  val_check_interval: 100000000000
+  scale_lr_by_batch_size: false
+  watch_gradients: false
+  compile: true
+  mask_entire_modality: 0.15
+  compile_flag_pos_emb: true
+  multimodal_batches: true
+optim:
+  lr: 0.0001
+sampling:
+  steps: 128
+  num_sample_batches: 2
+wandb:
+  mode: online
+checkpointing:
+  checkpoints_total_limit: 10
+  use_automatic_naming: false
+lr_scheduler:
+  num_warmup_steps: 10000
configs/experiments/large_scale_train_high_res.yaml
ADDED
@@ -0,0 +1,39 @@
+
+# @package _global_
+
+data:
+  resolution: 512
+  data_dir_train:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+      weight: 1
+      name: HPDv2_image_reward_512
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
+      weight: 2
+      name: pick_score_sac_prompts_v1_v2_v3_512
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
+      weight: 0.5
+      name: datacomp1b_7_512
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/text/slimpajama6b
+      weight: 2.5
+      name: slimpajama6b
+  data_dir_val:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
+      weight: 1.0
+      name: gecko_eval_512
+
+trainer:
+  text_loss_weight: 1.0
+  img_loss_weight: 0.5
+  force_full_attention_mask: true
+  mask_entire_modality: 0.1
+
+loader:
+  pin_memory: false
+  num_workers: 4
+  num_eval_workers: 4
+
+lr_scheduler:
+  num_warmup_steps: 5000
+
+model:
+  linear_factor: 2
configs/experiments/large_scale_train_high_res_inference.yaml
ADDED
@@ -0,0 +1,30 @@
+# @package _global_
+
+data:
+  use_token_dataset: true
+  disable_mask_after_eos: true
+  move_tensordict_to_shm: false
+
+trainer:
+  compile_flag_pos_emb: true
+  multimodal_batches: true
+  allow_null_sigma: true
+
+eval:
+  num_sample_batches: 1
+  num_random_masking: 0
+  num_masking_viz_batches: 0
+  limit_val_batches_manual: 1
+  num_uncond_sample_batches: 10
+  eval_large_batch: 10
+  val_with_train_data: false
+  maskgit_r_temp: 4.5
+  half_uncond: false
+  cfg: 3.0
+  static_img_txt_demo: true
+  visualize_sample: true
+
+sampling:
+  steps: 50
+  max_sampling_steps: 50
+  predictor: "maskgit"
configs/experiments/large_scale_train_high_res_interleaved.yaml
ADDED
@@ -0,0 +1,105 @@
+
+# @package _global_
+
+data:
+  move_tensordict_to_shm: false
+  enable_cuda_in_tensordict_collate: false
+  force_mp_spawn: false
+  resolution: 512
+  add_text_to_weighted_sampler: false
+
+  add_image_gen_tokens: true
+  use_packing_collate: true
+  dynamic_packing_lengths: true
+  remove_txt_img_padding: true
+  require_sample_ids: true
+  block_size: ${model.length}
+  disable_mask_after_eos: true
+  add_image_token: true
+  use_slow_tokenizer: true
+  force_seed: true
+
+  data_dir_train:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/HPDv2_image_reward_v1_v2_v3/train
+      weight: 0.5
+      name: HPDv2_image_reward_v1_v2_v3 # 3593248
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_512
+      weight: 1.0
+      name: pick_score_sac_prompts_v1_v2_v3_512 # 9330810
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/pixelprose_tokens
+      weight: 1.0
+      name: pixelprose_tokens # 6627589
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cambrian_10m_v5
+      weight: 1.0
+      name: cambrian_10m_v5 # 8215264
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_7_512
+      weight: 1.0
+      name: datacomp1b_7_512 # 23955209
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_2_tokens
+      weight: 0.5
+      name: datacomp_1b_datacomp1b_2_tokens # 10161505
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_4_tokens
+      weight: 0.5
+      name: datacomp_1b_datacomp1b_4_tokens # 27895717
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/mmc4_fewer_faces_v0
+      weight: 2.0
+      name: mmc4_fewer_faces_v0 # 22605524
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_matrix/datacomp_1b_datacomp1b_5_tokens
+      weight: 0.5
+      name: datacomp_1b_datacomp1b_5_tokens
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_0_tokens
+      weight: 0.5
+      name: datacomp_1b_datacomp1b_0_tokens
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/07_31_2024_grogu/datacomp_1b_datacomp1b_1_tokens
+      weight: 0.5
+      name: datacomp_1b_datacomp1b_1_tokens
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/cosmopedia_2_v0
+      weight: 1.0
+      name: cosmopedia_v2
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/babel/fineweb_edu_dedup_v0
+      weight: 1.0
+      name: fineweb_edu_dedup
+  data_dir_val:
+    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/matrix/gecko_eval_512
+      weight: 1.0
+      name: gecko_eval_512
+
+trainer:
+  text_loss_weight: 1.0
+  img_loss_weight: 0.2
+  mask_entire_modality: 0.2
+
+  force_full_attention_mask: false
+  force_full_attention_mask_loss_only: false
+  disable_all_eval_generation: true
+  interleaved: true
+  interleaved_training_flex_attention: true
+  force_convert_to_dict: true
+  val_check_interval: -1
+  use_gradient_checkpointing: true
+  disable_all_checkpointing: false
+  set_max_txt_loss_ratio: true
+  gradient_clip_val: 1.0
+  skip_early_checkpointing: false
+  bypass_load_from_state_dicts_if_resuming: true
+
+loader:
+  num_workers: 4
+  num_eval_workers: 4
+
+lr_scheduler:
+  num_warmup_steps: 5000
+
+model:
+  linear_factor: 2
+  use_flex_attention: true
+  use_spda_attn: true
+
+  length: 1536
+  txt_length: ${.length}
+  img_length: ${.length}
+
+eval:
+  generate_samples: false
+  disable_visualization: true
+
configs/experiments/maskgit.yaml
ADDED
@@ -0,0 +1,6 @@
+# @package _global_
+
+model:
+  downscale_ratio: 16
+  image_vocab_size: 1024
+  vae_type: maskgit
configs/experiments/master_eval.yaml
ADDED
@@ -0,0 +1,49 @@
+# @package _global_
+
+mode: eval
+
+eval:
+  fid_samples: 4096
+  max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+  compute_generative_perplexity: true
+  generate_samples: true
+  log_every_n_fid: 1
+  log_every_n_evals: 1
+  class_conditional_fid: false
+  txt_conditional_fid: true
+  calculate_clip_score: true
+  cfg: 5
+  num_sample_batches: 2
+  compute_standalone_mauve: false
+  mauve_num_samples: -1
+  set_random_gen_seed: true
+  # gen_ppl_eval_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
+  compute_img_to_txt_mauve_clip: true
+  compute_img_to_txt_mauve_during_unconditional_fid: true
+  force_eval_uncond: true
+  ablation_config: true
+  compute_val_metrics_standalone: true
+  num_val_metrics_standalone_samples: 2000
+
+trainer:
+  disable_all_eval_generation: false
+  force_after_eos_padding: true
+
+model:
+  image_model_fid_eval: true
+  use_kv_cache: ${is_ar:${parameterization}}
+
+loader:
+  batch_size: 64
+  eval_batch_size: 64
+  num_workers: 0
+  num_eval_workers: 1
+
+sampling:
+  steps: ${model.length}
+  max_sampling_steps: ${sampling.steps}
+  sampling_step_frac: null
+
+
+data:
+  fid_dataset: null
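A quick worked check of `max_num_fid_batches_per_device` in `master_eval.yaml`, using this config's own `loader` values; the device count is an assumption, since the base config resolves it at runtime via `${device_count:}`:

```python
# eval.max_num_fid_batches_per_device =
#   max(fid_samples // (devices * eval_batch_size), 1)
fid_samples = 4096       # eval.fid_samples
devices = 8              # assumed GPU count; resolved at runtime in the config
eval_batch_size = 64     # loader.eval_batch_size

print(max(fid_samples // (devices * eval_batch_size), 1))  # -> 8 batches per device
```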
configs/experiments/mscoco_fid.yaml
ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ data:
+   disable_text_modality: false
+   keep_hf_dataset_in_memory: true
+   aggressive_aug: false
+   n_duplicate_train: null
+   n_duplicate_val: null
+   data_dir_train: []
+   data_dir_val: []
+   image_data_train: ${data.image_data_val}
+   image_data_val:
+     - val: sayakpaul/coco-30-val-2014
+       weight: -1
+       name: mscoco_val
+   tokenize_vqvae_in_dataloader: false
+   raw_images: true
+
+ eval:
+   compute_generative_perplexity: true
+   generate_samples: true
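Here image_data_val points at a Hugging Face dataset used as the FID reference set; a quick way to inspect it (the split name is an assumption):

from datasets import load_dataset

ds = load_dataset("sayakpaul/coco-30-val-2014", split="train")  # split name assumed
print(ds)  # raw MS-COCO validation images plus captions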
configs/experiments/paired_standalone_fid_eval.yaml
ADDED
@@ -0,0 +1,29 @@
+ # @package _global_
+
+ mode: eval
+ debug: true
+
+ eval:
+   fid_samples: 4096
+   max_num_fid_batches_per_device: ${eval:'max(${eval.fid_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
+   compute_generative_perplexity: false
+   generate_samples: false
+   log_every_n_fid: 1
+   log_every_n_evals: 1
+   class_conditional_fid: false
+   txt_conditional_fid: true
+   calculate_clip_score: true
+   cfg: 5
+
+ model:
+   image_model_fid_eval: true
+
+ loader:
+   eval_batch_size: 32
+
+ sampling:
+   steps: ${model.length}
+   max_sampling_steps: ${model.length}
+
+ data:
+   keep_hf_dataset_in_memory: false
configs/experiments/small_scale_train.yaml
ADDED
@@ -0,0 +1,187 @@
+ # @package _global_
+
+ defaults:
+   - vq16_magvit
+   - override /model: small
+   - override /lr_scheduler: constant_warmup_cosine_decay
+
+ model:
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
+   txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
+   length: ${eval:'${.txt_length} + ${.img_length}'}
+   image_model: true
+   text_model: true
+   unified_model: true
+   image_model_fid_eval: false
+   force_argmax_valid_indices: true
+   use_pretrained_img_emb: false
+   codebook_embed_dim: 256
+   qk_norm: true
+   norm_type: rms
+   sandwich_normalization: true
+   zero_linear_init: false
+   modality_embed: true
+   rope_2d: false
+   use_spda_attn: true
+   force_optimized_native_attn: true
+   freeze_txt_emb: false
+   add_labels: null
+   txt_dropout: null
+   text_vocab_size: 32001
+
+ data:
+   train: combined_tokens
+   valid: ${.train}
+   n_duplicate_train: null
+   wrap: true
+   streaming: false
+   precache: false
+   tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+   resolution: 256
+   block_size: 128
+   n_val_samples: null
+   unpaired: false
+   n_duplicate_val: null
+   save_train_dataloader: true
+   save_validation_dataloader: true
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: false
+   dataset_type: null
+   tokens_flip_collate: false
+   n_train_samples: null
+   raw_data_dir: null
+   tokenizers_parallelism: false
+   token_data_dir: null
+   force_disable_shuffle: false
+   keep_tensordict_on_disk: true
+   use_custom_tensordict_collate: true
+   force_mp_spawn: false
+   enable_cuda_in_tensordict_collate: false
+   use_weighted_tensordict_sampler: true
+   fraction_txt_data: 0.0
+   tokenize_vqvae_in_dataloader: false
+   use_token_dataset: true
+   image_dataset: tglcourse/lsun_church_train
+   image_data_train: null
+   image_data_val: null
+   keep_hf_dataset_in_memory: true
+   allow_label: false
+   disable_text_modality: true
+   force_raw_train_images: false
+   aggressive_aug: true
+   allow_aug_vqvae_dataloader: true
+   move_tensordict_to_shm: false
+   data_dir_train:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
+       weight: -1
+       name: datacomp1b_8_magvit_train
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
+       weight: -1
+       name: cc12m_tokens_train_256
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
+       weight: -1
+       name: HPDv2_image_reward_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
+       weight: -1
+       name: pick_score_sac_prompts_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
+       weight: -1
+       name: datacomp1b_0_1_6_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
+       weight: -1
+       name: laion400m_magvit_part_0
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
+       weight: -1
+       name: laion400m_magvit_part_1
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+       weight: 1
+       name: datacomp1b_8_magvit_val
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+       weight: 1
+       name: cc12m_tokens_val_256
+
+ eval:
+   generate_samples: true
+   compute_generative_perplexity: true
+   log_every_n_evals: 10
+   log_every_n_fid: 20
+   limit_val_batches_manual: 16
+   perplexity_batch_size: ${loader.eval_batch_size}
+   num_masking_viz_batches: -1
+   cfg: null
+   class_conditional_fid: false
+   force_cfg_value: true
+   split_cfg_batches: true
+   max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
+   fid_mode: clean
+   clean_fid_precomputed_name: lsun_church
+   clean_fid_precomputed_split: trainfull
+   clean_fid_precomputed_res: 256
+
+ trainer:
+   log_every_n_steps: 10
+   val_check_interval: 1000
+   custom_ddp_bf16: true
+   scale_lr_by_batch_size: false
+   limit_val_batches: 16
+   use_gradient_checkpointing: false
+   log_seperate_modal_losses: true
+   softmin_snr: 5
+   text_loss_weight: 1.0
+   img_loss_weight: null
+   low_precision_loss: false
+   compile: true
+   multimodal_batches: true
+   compile_fullgraph: false
+   log_grad_norm_every_n_steps: 10
+   mask_entire_modality: 0.1
+   force_shift_image_batches: false
+   ckpt_steps: 10000
+   ckpt_every_n_minutes: -1
+   ignore_text_in_unified: false
+   disable_all_eval_generation: true
+   eval_on_start: false
+   ckpt_model_only: false
+   ema: 0.0
+   use_custom_ema: false
+   log_flops: false
+   disable_distributed_torchmetrics: true
+   restart_on_failure: true
+   force_null_sigma: true
+   allow_null_sigma: true
+   compile_flag_pos_emb: true
+   add_label: false
+   first_token_dropout: null
+   force_shift_raw_image_batches: true
+   txt_dropout: 0.1
+   force_full_attention_mask_loss_only: true
+
+ optim:
+   lr: 0.0003
+   weight_decay: 0.05
+
+ loader:
+   batch_size: 64
+   eval_batch_size: ${loader.batch_size}
+   num_workers: 4
+   desired_global_batch_size: 512
+   persistent_workers: true
+   pin_memory: true
+   num_eval_workers: 1
+
+ sampling:
+   steps: ${model.length}
+   num_sample_batches: 2
+   max_sampling_steps: ${model.length}
+
+ wandb:
+   mode: online
+
+ lr_scheduler:
+   num_warmup_steps: 5000
+   num_training_steps: ${trainer.max_steps}
+
+ checkpointing:
+   checkpoints_total_limit: 10
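The three eval-interpolated fields at the top of the model block reduce to simple arithmetic; spelled out with this config's values (the device count is an assumption):

# Derived sequence lengths in small_scale_train.yaml, spelled out.
resolution, downscale_ratio = 256, 16     # data.resolution, model.downscale_ratio (from vq16_magvit)
block_size, unified_model = 128, True

img_length = (resolution // downscale_ratio) ** 2   # 16 x 16 grid -> 256 image tokens
txt_length = block_size if unified_model else 0     # 128 text tokens
length = txt_length + img_length                    # 384 tokens per sample
print(img_length, txt_length, length)               # 256 128 384

# desired_global_batch_size=512 with batch_size=64 per device: on 8 GPUs
# (assumed) no gradient accumulation is needed, since 64 * 8 = 512.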
configs/experiments/small_scale_train_caching.yaml
ADDED
@@ -0,0 +1,186 @@
+ # @package _global_
+
+ defaults:
+   - /model: small
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 8192
+   vae_type: magvit
+   use_custom_vae_ckpt: null
+   custom_vae_name: null
+   img_length: 256
+   txt_length: 128
+   image_model: true
+   text_model: true
+   unified_model: true
+   image_model_fid_eval: false
+   force_argmax_valid_indices: true
+   use_pretrained_img_emb: false
+   codebook_embed_dim: 256
+   qk_norm: true
+   norm_type: rms
+   sandwich_normalization: true
+   zero_linear_init: false
+   modality_embed: true
+   rope_2d: false
+   use_spda_attn: true
+   force_optimized_native_attn: true
+   freeze_txt_emb: false
+   add_labels: null
+   txt_dropout: null
+   text_vocab_size: 32001
+   use_flex_attention: true
+   flex_attention_txt_masking_prob: 0.1
+   flex_attention_img_masking_prob: 0.1
+   linear_factor: 1
+ data:
+   train: combined_tokens
+   valid: ${.train}
+   n_duplicate_train: null
+   wrap: true
+   streaming: false
+   precache: false
+   tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
+   resolution: 256
+   block_size: 128
+   n_val_samples: null
+   unpaired: false
+   n_duplicate_val: null
+   save_train_dataloader: true
+   save_validation_dataloader: true
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: false
+   dataset_type: null
+   tokens_flip_collate: false
+   n_train_samples: null
+   raw_data_dir: null
+   tokenizers_parallelism: false
+   token_data_dir: null
+   force_disable_shuffle: false
+   keep_tensordict_on_disk: true
+   use_custom_tensordict_collate: true
+   force_mp_spawn: false
+   enable_cuda_in_tensordict_collate: false
+   use_weighted_tensordict_sampler: true
+   fraction_txt_data: 0.0
+   data_dir_train:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
+       weight: -1
+       name: datacomp1b_8_magvit_train
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
+       weight: -1
+       name: cc12m_tokens_train_256
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
+       weight: -1
+       name: HPDv2_image_reward_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
+       weight: -1
+       name: pick_score_sac_prompts_v1_v2_v3_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
+       weight: -1
+       name: datacomp1b_0_1_6_magvit
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
+       weight: -1
+       name: laion400m_magvit_part_0
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
+       weight: -1
+       name: laion400m_magvit_part_1
+   data_dir_val:
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
+       weight: 1
+       name: datacomp1b_8_magvit_val
+     - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
+       weight: 1
+       name: cc12m_tokens_val_256
+   tokenize_vqvae_in_dataloader: false
+   val:
+     .train: null
+   use_token_dataset: true
+   image_dataset: tglcourse/lsun_church_train
+   image_data_train: null
+   image_data_val: null
+   keep_hf_dataset_in_memory: true
+   allow_label: false
+   disable_text_modality: true
+   force_raw_train_images: false
+   aggressive_aug: true
+   allow_aug_vqvae_dataloader: true
+   move_tensordict_to_shm: false
+   force_full_attention_mask: false
+ eval:
+   generate_samples: false
+   compute_generative_perplexity: false
+   log_every_n_evals: 10
+   log_every_n_fid: 20
+   limit_val_batches_manual: 16
+   perplexity_batch_size: ${loader.eval_batch_size}
+   num_masking_viz_batches: -1
+   max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
+   cfg: null
+   class_conditional_fid: false
+   force_cfg_value: true
+   split_cfg_batches: true
+   fid_mode: clean
+   clean_fid_precomputed_name: lsun_church
+   clean_fid_precomputed_split: trainfull
+   clean_fid_precomputed_res: 256
+ trainer:
+   log_every_n_steps: 10
+   val_check_interval: 1000
+   custom_ddp_bf16: true
+   scale_lr_by_batch_size: false
+   limit_val_batches: 16
+   use_gradient_checkpointing: false
+   log_seperate_modal_losses: true
+   softmin_snr: 5
+   text_loss_weight: 1.0
+   img_loss_weight: null
+   low_precision_loss: false
+   compile: false
+   multimodal_batches: true
+   compile_fullgraph: false
+   log_grad_norm_every_n_steps: 10
+   mask_entire_modality: 0.1
+   force_shift_image_batches: false
+   ckpt_steps: 10000
+   ckpt_every_n_minutes: -1
+   ignore_text_in_unified: false
+   disable_all_eval_generation: false
+   eval_on_start: false
+   ckpt_model_only: false
+   ema: 0.0
+   use_custom_ema: false
+   log_flops: false
+   disable_distributed_torchmetrics: true
+   restart_on_failure: true
+   force_null_sigma: true
+   allow_null_sigma: true
+   compile_flag_pos_emb: true
+   add_label: false
+   first_token_dropout: null
+   force_shift_raw_image_batches: true
+   txt_dropout: 0.1
+   disable_ddp_optimizer: true
+ optim:
+   lr: 0.0003
+   weight_decay: 0.05
+ loader:
+   batch_size: 64
+   eval_batch_size: ${loader.batch_size}
+   num_workers: 1
+   desired_global_batch_size: 512
+   persistent_workers: true
+   pin_memory: true
+   num_eval_workers: 1
+ sampling:
+   steps: ${model.length}
+   num_sample_batches: 2
+   max_sampling_steps: ${model.length}
+ wandb:
+   mode: online
+ lr_scheduler:
+   num_warmup_steps: 5000
+ checkpointing:
+   checkpoints_total_limit: 4
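This caching variant leans on TensorDict-backed token datasets (keep_tensordict_on_disk, use_custom_tensordict_collate). A minimal sketch of the on-disk pattern using the tensordict library; the key names, sizes, and path here are assumptions, not the repo's schema:

import torch
from tensordict import TensorDict

n, txt_len, img_len = 1024, 128, 256
td = TensorDict(
    {"txt_tokens": torch.randint(0, 32001, (n, txt_len)),  # text_vocab_size
     "img_tokens": torch.randint(0, 8192, (n, img_len))},  # image_vocab_size
    batch_size=[n],
)
td.memmap_("ft_cache/example_tokens")  # persists as memory-mapped files on disk
sample = td[0]                         # later reads touch only the indexed rows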
configs/experiments/small_text_only.yaml
ADDED
@@ -0,0 +1,28 @@
+ # @package _global_
+
+ defaults:
+   - lsun_text8_exp_2
+   - owt_only
+   - override /model: small
+
+ backbone: dit
+
+ loader:
+   batch_size: 64
+
+ trainer:
+   val_check_interval: 10000
+   ckpt_steps: 10000
+   softmin_snr: null
+
+ optim:
+   fused: true
+   weight_decay: 0.03
+
+ sampling:
+   num_sample_batches: 4
+   max_sampling_steps: 256
+
+ model:
+   txt_length: 1024
+
configs/experiments/standalone_fid_eval.yaml
ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ mode: eval
+ debug: true
+
+ eval:
+   max_num_fid_batches_per_device: ${eval:'4096 // (${trainer.devices} * ${loader.eval_batch_size})'}
+   compute_generative_perplexity: false
+   generate_samples: false
+   log_every_n_fid: 1
+   log_every_n_evals: 1
+
+ loader:
+   eval_batch_size: 32
+
+ sampling:
+   steps: 500
+   max_sampling_steps: 500
configs/experiments/titok.yaml
ADDED
@@ -0,0 +1,8 @@
+ # @package _global_
+
+ data:
+   resolution: 256
+   downscale_ratio: 16
+
+ model:
+   vae_type: titok
configs/experiments/titok_sl256.yaml
ADDED
@@ -0,0 +1,7 @@
+ # @package _global_
+
+ data:
+   resolution: 256
+
+ model:
+   vae_type: titok
configs/experiments/txt_only.yaml
ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ data:
+   streaming: False
+   unpaired: false
+
+ trainer:
+   img_loss_weight: null
+   text_loss_weight: null
+
+ model:
+   use_pretrained_img_emb: false
+   image_model_fid_eval: false
+   unified_model: false
+   image_model: false
+   txt_length: 256
+   img_length: 0
+
+ eval:
+   log_every_n_evals: -1
+   log_every_n_fid: -1
configs/experiments/unified.yaml
ADDED
@@ -0,0 +1,23 @@
+ # @package _global_
+
+ data:
+   zero_shot_eval_dataset: "nlphuji/flickr30k"
+   precache: False
+   tokenizers_parallelism: False  # parallelism causes some weird error
+   n_val_samples: 2048
+   block_size: 128
+
+ model:
+   unified_model: True
+   text_model: true
+
+ checkpointing:
+   resume_from_ckpt: True
+   load_from_text_model: "ckpts/unidisc-owt/model.safetensors"
+
+ loader:
+   batch_size: 12
+
+ trainer:
+   val_check_interval: 2000
+   log_seperate_modal_losses: true
configs/experiments/vq16.yaml
ADDED
@@ -0,0 +1,9 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 16384
+   vae_type: VQ-16
+   use_custom_vae_ckpt: null
+   custom_vae_name: null
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
configs/experiments/vq16_1024.yaml
ADDED
@@ -0,0 +1,8 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 1024
+   codebook_embed_dim: 256
+   vae_type: VQ-16
+   use_custom_vae_ckpt: ${oc.env:DIFFUSION_DATA_DIR}/ckpts/2024-07-03-01-10-53_022-VQ-16_0042000.pt
configs/experiments/vq16_magvit.yaml
ADDED
@@ -0,0 +1,9 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 8192
+   vae_type: magvit
+   use_custom_vae_ckpt: null
+   custom_vae_name: null
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
configs/experiments/vq16_t2i.yaml
ADDED
@@ -0,0 +1,10 @@
+ # @package _global_
+
+ model:
+   downscale_ratio: 16
+   image_vocab_size: 16384
+   vae_type: VQ-16
+   use_custom_vae_ckpt: ${get_repo_dir:}/ckpts/vq_ds16_t2i.pt
+   custom_vae_name: _t2i
+   codebook_embed_dim: 8
+   img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
configs/experiments/webdataset.yaml
ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ data:
+   train: datacomp1b_indexed
+   valid: ${.train}
+
+   iterable: false
+   webdataset_iterable: false
+   webdataset_indexed: true
+   unpaired: false
+   dataset_type: null
+   tokens_flip_collate: false
configs/experiments/zero_shot_eval.yaml
ADDED
@@ -0,0 +1,29 @@
+ # @package _global_
+
+ mode: zero-shot-eval
+
+ data:
+   # train: "nlphuji/flickr30k"
+   train: "facebook/winoground"
+   precache: False
+   tokenizers_parallelism: False  # parallelism causes some weird error
+   n_val_samples: 2048
+   block_size: 128
+   disable_text_modality: false
+
+ eval:
+   cfg: 5
+   compute_val_metrics_standalone: false
+   compute_img_to_txt_mauve_clip: false
+
+ loader:
+   batch_size: 16
+   eval_batch_size: 16
+
+
+ model:
+   unified_model: True
+   text_model: true
+   image_model: true
+   vae_type: magvit
+   force_optimized_native_attn: false
configs/lr_scheduler/constant_warmup.yaml
ADDED
@@ -0,0 +1,2 @@
+ _target_: transformers.get_constant_schedule_with_warmup
+ num_warmup_steps: 2500
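Each lr_scheduler config is a Hydra _target_ spec; the optimizer is supplied at instantiation time. A minimal sketch of the usual wiring (the repo's actual call site may differ):

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "_target_": "transformers.get_constant_schedule_with_warmup",
    "num_warmup_steps": 2500,
})
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
scheduler = instantiate(cfg, optimizer=optimizer)  # extra kwargs merge into the target call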
configs/lr_scheduler/constant_warmup_cosine_decay.yaml
ADDED
@@ -0,0 +1,3 @@
+ _target_: transformers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 2500
+ num_training_steps: 1000000
configs/lr_scheduler/cosine_decay_warmup.yaml
ADDED
@@ -0,0 +1,7 @@
+ _target_: utils.CosineDecayWarmupLRScheduler
+ t_in_epochs: False
+ t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
+ warmup_prefix: True
+ warmup_lr_init: 1e-6
+ warmup_t: ${eval:0.1*${trainer.max_steps}}
+ lr_min: 1e-6
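The two eval interpolations here just split max_steps 10/90 between warmup and decay; worked through with trainer.max_steps = 1_000_000 (an assumed value):

max_steps = 1_000_000              # trainer.max_steps, assumed for illustration
warmup_t = int(0.1 * max_steps)    # 100_000 warmup steps, ramping from 1e-6 to the base lr
t_initial = max_steps - warmup_t   # 900_000 cosine-decay steps (warmup_prefix=True)
print(warmup_t, t_initial)         # 100000 900000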
configs/lr_scheduler/cosine_with_hard_restarts_schedule_with_warmup.yaml
ADDED
@@ -0,0 +1,4 @@
+ _target_: transformers.get_cosine_with_hard_restarts_schedule_with_warmup
+ num_warmup_steps: 2500
+ num_training_steps: 1000000
+ num_cycles: 1
configs/model/extra_large.yaml
ADDED
@@ -0,0 +1,10 @@
+ name: extra_large
+ type: ddit
+ hidden_size: 2048
+ cond_dim: 128
+ length: 1024
+ n_blocks: 24
+ n_heads: 16
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/large.yaml
ADDED
@@ -0,0 +1,14 @@
+ name: large
+ type: ddit
+ hidden_size: 1280
+ cond_dim: 128
+ length: 1024
+ base_n_blocks: 28
+ # We try to roughly match parameter count
+ n_blocks: ${adjust_n_blocks:}
+ n_heads: 20
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
+
+ # 36 1280 20
configs/model/medium.yaml
ADDED
@@ -0,0 +1,12 @@
+ name: medium
+ type: ddit
+ hidden_size: 1024
+ cond_dim: 128
+ length: 1024
+ base_n_blocks: 24
+ # We try to roughly match parameter count
+ n_blocks: ${adjust_n_blocks:}
+ n_heads: 16
+ scale_by_sigma: True
+ dropout: 0.1
+ tie_word_embeddings: False
configs/model/small-ar.yaml
ADDED
@@ -0,0 +1,11 @@
+ name: small
+ type: ddit
+ hidden_size: 768
+ cond_dim: 128
+ length: 1024
+ n_blocks: 12
+ n_heads: 12
+ scale_by_sigma: True
+ dropout: 0.1
+ causal: True
+ tie_word_embeddings: False
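small-ar is the small config with causal attention, i.e. the autoregressive baseline. A back-of-the-envelope parameter count for these model configs, using a vanilla pre-LN transformer block as a stand-in (the DDiT blocks add conditioning parameters, and medium/large resolve n_blocks via adjust_n_blocks, so base_n_blocks is used here as an assumption):

def approx_params(hidden, blocks, vocab=32001 + 8192):  # text + magvit image vocab, untied embeddings
    attn = 4 * hidden * hidden   # q, k, v, and output projections
    mlp = 8 * hidden * hidden    # 4x MLP expansion, up and down projections
    return blocks * (attn + mlp) + 2 * vocab * hidden

for name, hidden, blocks in [("small", 768, 12), ("medium", 1024, 24),
                             ("large", 1280, 28), ("extra_large", 2048, 24)]:
    print(f"{name}: ~{approx_params(hidden, blocks) / 1e6:.0f}M")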