Improve model card: Add metadata, link to project page and GitHub repository
This PR improves the model card by adding the missing `pipeline_tag` and `library_name` to the metadata. It also adds links to the project page and the GitHub repository.
README.md
CHANGED
@@ -1,3 +1,290 @@
- ---
- license: mit
- ---

---
license: mit
pipeline_tag: image-to-image
library_name: diffusers
---

<h1 align="center"> REPA-E: Unlocking VAE for End-to-End Tuning of Latent Diffusion Transformers </h1>

<p align="center">
  <a href="https://scholar.google.com.au/citations?user=GQzvqS4AAAAJ" target="_blank">Xingjian Leng</a><sup>1*</sup>&nbsp;&nbsp;<b>·</b>&nbsp;&nbsp;
  <a href="https://1jsingh.github.io/" target="_blank">Jaskirat Singh</a><sup>1*</sup>&nbsp;&nbsp;<b>·</b>&nbsp;&nbsp;
  <a href="https://hou-yz.github.io/" target="_blank">Yunzhong Hou</a><sup>1</sup>&nbsp;&nbsp;<b>·</b>&nbsp;&nbsp;
  <a href="https://people.csiro.au/X/Z/Zhenchang-Xing/" target="_blank">Zhenchang Xing</a><sup>2</sup>&nbsp;&nbsp;<b>·</b>&nbsp;&nbsp;
  <a href="https://www.sainingxie.com/" target="_blank">Saining Xie</a><sup>3</sup>&nbsp;&nbsp;<b>·</b>&nbsp;&nbsp;
  <a href="https://zheng-lab-anu.github.io/" target="_blank">Liang Zheng</a><sup>1</sup>
</p>

<p align="center">
  <sup>1</sup>Australian National University&nbsp;&nbsp; <sup>2</sup>Data61-CSIRO&nbsp;&nbsp; <sup>3</sup>New York University <br>
  <sub><sup>*</sup>Project Leads</sub>
</p>

<p align="center">
  <a href="https://End2End-Diffusion.github.io">🌐 Project Page</a>&nbsp;&nbsp;
  <a href="https://huggingface.co/REPA-E">🤗 Models</a>&nbsp;&nbsp;
  <a href="https://huggingface.co/papers/2504.10483">📄 Paper</a>&nbsp;&nbsp;
  <br><br>
  <a href="https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=repa-e-unlocking-vae-for-end-to-end-tuning-of/image-generation-on-imagenet-256x256"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/repa-e-unlocking-vae-for-end-to-end-tuning-of/image-generation-on-imagenet-256x256" alt="PWC"></a>
</p>

## Overview
We address a fundamental question: ***Can latent diffusion models and their VAE tokenizer be trained end-to-end?*** While jointly training both components with the standard diffusion loss is observed to be ineffective (it often degrades final performance), we show that this limitation can be overcome with a simple representation-alignment (REPA) loss. Our proposed method, **REPA-E**, enables stable and effective joint training of both the VAE and the diffusion model.

**REPA-E** significantly accelerates training, achieving over a **17×** speedup compared to REPA and a **45×** speedup over the vanilla training recipe. Interestingly, end-to-end tuning also improves the VAE itself: the resulting **E2E-VAE** provides better latent structure and serves as a **drop-in replacement** for existing VAEs (e.g., SD-VAE), improving convergence and generation quality across diverse LDM architectures. Our method achieves state-of-the-art FID scores on ImageNet 256×256: **1.26** with CFG and **1.83** without CFG.

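For intuition, the representation-alignment idea can be sketched in a few lines: intermediate diffusion-transformer features are projected into the space of a frozen pretrained encoder (e.g., DINOv2) and pulled toward its patch features with a negative cosine-similarity loss. The sketch below is illustrative only; the projector shape, feature dimensions, and function names are assumptions rather than the repository's API.

```python
# Illustrative REPA-style alignment loss (a sketch, not the repository's implementation).
# Assumed shapes: sit_feats (B, N, D_sit) from an intermediate SiT block,
# enc_feats (B, N, D_enc) from a frozen encoder such as DINOv2-B.
import torch
import torch.nn.functional as F

def repa_alignment_loss(sit_feats, enc_feats, projector):
    z = F.normalize(projector(sit_feats), dim=-1)   # project SiT features into encoder space
    t = F.normalize(enc_feats, dim=-1)              # frozen encoder targets
    return -(z * t).sum(dim=-1).mean()              # negative cosine similarity

# Toy usage with made-up dimensions (1152 = SiT-XL width, 768 = DINOv2-B width)
projector = torch.nn.Sequential(
    torch.nn.Linear(1152, 2048), torch.nn.SiLU(), torch.nn.Linear(2048, 768))
loss = repa_alignment_loss(torch.randn(4, 256, 1152), torch.randn(4, 256, 768), projector)
```

In REPA-E, a loss of this form is applied to the generator (weighted by `--proj-coeff`) and also used to align the VAE (weighted by `--vae-align-proj-coeff`), alongside the standard diffusion and VAE objectives.
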
## News and Updates
**[2025-04-15]** Initial Release with pre-trained models and codebase.

## Getting Started
### 1. Environment Setup
To set up our environment, please run:

```bash
git clone https://github.com/REPA-E/REPA-E.git
cd REPA-E
conda env create -f environment.yml -y
conda activate repa-e
```

### 2. Prepare the training data
Download and extract the training split of the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index) dataset. Once it's ready, run the following command to preprocess the dataset:

```bash
python preprocessing.py --imagenet-path /PATH/TO/IMAGENET_TRAIN
```

Replace `/PATH/TO/IMAGENET_TRAIN` with the actual path to the extracted training images.

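For reference, this kind of preparation typically amounts to center-cropping and resizing each image to 256×256 before training. The sketch below only illustrates that idea; it is not a description of what `preprocessing.py` actually does, and its real output format and layout may differ.

```python
# Illustrative center-crop + resize to 256x256; not the repository's preprocessing script.
from pathlib import Path
from PIL import Image

def center_crop_resize(path, size=256):
    img = Image.open(path).convert("RGB")
    side = min(img.size)                                   # shortest side
    left, top = (img.width - side) // 2, (img.height - side) // 2
    return img.crop((left, top, left + side, top + side)).resize((size, size))

out_dir = Path("data/preprocessed")                        # hypothetical output location
out_dir.mkdir(parents=True, exist_ok=True)
for p in Path("/PATH/TO/IMAGENET_TRAIN").rglob("*.JPEG"):
    center_crop_resize(p).save(out_dir / f"{p.stem}.png")
```
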
### 3. Train the REPA-E model

To train the REPA-E model, you first need to download the following pre-trained VAE checkpoints:
- [🤗 **SD-VAE (f8d4)**](https://huggingface.co/REPA-E/sdvae): Derived from the [Stability AI SD-VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse), originally trained on [Open Images](https://storage.googleapis.com/openimages/web/index.html) and fine-tuned on a subset of [LAION-2B](https://laion.ai/blog/laion-5b/).
- [🤗 **IN-VAE (f16d32)**](https://huggingface.co/REPA-E/invae): Trained from scratch on [ImageNet-1K](https://www.image-net.org/) using the [latent-diffusion](https://github.com/CompVis/latent-diffusion) codebase with our custom architecture.
- [🤗 **VA-VAE (f16d32)**](https://huggingface.co/REPA-E/vavae): Taken from [LightningDiT](https://github.com/hustvl/LightningDiT), this VAE is a visual tokenizer aligned with vision foundation models during reconstruction training. It is also trained on [ImageNet-1K](https://www.image-net.org/) for high-quality tokenization in high-dimensional latent spaces.

Recommended directory structure:
```
pretrained/
├── invae/
├── sdvae/
└── vavae/
```

Once you've downloaded a VAE checkpoint, you can launch REPA-E training with:
```bash
accelerate launch train_repae.py \
    --max-train-steps=400000 \
    --report-to="wandb" \
    --allow-tf32 \
    --mixed-precision="fp16" \
    --seed=0 \
    --data-dir="data" \
    --output-dir="exps" \
    --batch-size=256 \
    --path-type="linear" \
    --prediction="v" \
    --weighting="uniform" \
    --model="SiT-XL/2" \
    --checkpointing-steps=50000 \
    --loss-cfg-path="configs/l1_lpips_kl_gan.yaml" \
    --vae="f8d4" \
    --vae-ckpt="pretrained/sdvae/sdvae-f8d4.pt" \
    --disc-pretrained-ckpt="pretrained/sdvae/sdvae-f8d4-discriminator-ckpt.pt" \
    --enc-type="dinov2-vit-b" \
    --proj-coeff=0.5 \
    --encoder-depth=8 \
    --vae-align-proj-coeff=1.5 \
    --bn-momentum=0.1 \
    --exp-name="sit-xl-dinov2-b-enc8-repae-sdvae-0.5-1.5-400k"
```
<details>
<summary>Click to expand for configuration options</summary>

The script automatically creates a subfolder under `exps` (named after `--exp-name`) to store logs and checkpoints. You can adjust the following options:

- `--output-dir`: Directory to save checkpoints and logs
- `--exp-name`: Experiment name (a subfolder will be created under `output-dir`)
- `--vae`: Choose between `[f8d4, f16d32]`
- `--vae-ckpt`: Path to a provided or custom VAE checkpoint
- `--disc-pretrained-ckpt`: Path to a provided or custom VAE discriminator checkpoint
- `--model`: Choose from `[SiT-B/2, SiT-L/2, SiT-XL/2, SiT-B/1, SiT-L/1, SiT-XL/1]`. The number indicates the patch size. Select a model compatible with your VAE architecture.
- `--enc-type`: `[dinov2-vit-b, dinov2-vit-l, dinov2-vit-g, dinov1-vit-b, mocov3-vit-b, mocov3-vit-l, clip-vit-L, jepa-vit-h, mae-vit-l]`
- `--encoder-depth`: Any integer from 1 up to the full depth of the selected encoder
- `--proj-coeff`: REPA-E projection coefficient for SiT alignment (float > 0)
- `--vae-align-proj-coeff`: REPA-E projection coefficient for VAE alignment (float > 0)
- `--bn-momentum`: Batchnorm layer momentum (float)

</details>

### 4. Use REPA-E Tuned VAE (E2E-VAE) for Accelerated Training and Better Generation
This section shows how to use the REPA-E fine-tuned VAE (E2E-VAE) in standard latent diffusion training. E2E-VAE acts as a drop-in replacement for the original VAE, enabling significantly faster convergence and improved generation quality. You can either download a pre-trained VAE or extract it from a REPA-E checkpoint.

**Step 1**: Obtain the fine-tuned VAE from REPA-E checkpoints:
- **Option 1**: Download pre-trained REPA-E VAEs directly from Hugging Face:
  - [🤗 **E2E-SDVAE**](https://huggingface.co/REPA-E/e2e-sdvae)
  - [🤗 **E2E-INVAE**](https://huggingface.co/REPA-E/e2e-invae)
  - [🤗 **E2E-VAVAE**](https://huggingface.co/REPA-E/e2e-vavae)

  Recommended directory structure:
  ```
  pretrained/
  ├── e2e-sdvae/
  ├── e2e-invae/
  └── e2e-vavae/
  ```
- **Option 2**: Extract the VAE from a full REPA-E checkpoint manually (conceptually, this saves just the VAE weights from the combined checkpoint; see the sketch after this list):
  ```bash
  python save_vae_weights.py \
      --repae-ckpt pretrained/sit-repae-vavae/checkpoints/0400000.pt \
      --vae-name e2e-vavae \
      --save-dir exps
  ```

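A minimal sketch of that extraction idea is below; the `"vae"` key is a hypothetical checkpoint layout, so use `save_vae_weights.py` for the actual extraction.

```python
# Illustrative VAE extraction from a combined REPA-E checkpoint.
# The "vae" key is an assumption about the checkpoint layout; prefer save_vae_weights.py.
import torch

ckpt = torch.load("pretrained/sit-repae-vavae/checkpoints/0400000.pt", map_location="cpu")
vae_state = ckpt.get("vae", ckpt) if isinstance(ckpt, dict) else ckpt
torch.save(vae_state, "exps/e2e-vavae-extracted.pt")
```
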
**Step 2**: Cache latents to enable fast training:
```bash
accelerate launch --num-machines=1 --num-processes=8 cache_latents.py \
    --vae-arch="f16d32" \
    --vae-ckpt-path="pretrained/e2e-vavae/e2e-vavae-400k.pt" \
    --vae-latents-name="e2e-vavae" \
    --pproc-batch-size=128
```

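Caching works because the VAE is frozen at this stage: each training image is encoded once and the latents are reused for the whole run, removing the VAE encoder from the training loop. The sketch below shows the idea only; the `encode()` interface, scaling, and file layout are assumptions, not `cache_latents.py`'s actual behavior.

```python
# Illustrative latent caching: encode images once with the frozen VAE, store the latents.
# The vae.encode(...).sample() interface and the output layout are assumptions.
import numpy as np
import torch

@torch.no_grad()
def cache_batch(vae, images, out_path, scale=1.0):
    # images: (B, 3, 256, 256) in [-1, 1]; latents e.g. (B, 32, 16, 16) for an f16d32 VAE
    latents = vae.encode(images).sample() * scale
    np.save(out_path, latents.cpu().numpy())
```
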
**Step 3**: Train the SiT generation model using the cached latents:

```bash
accelerate launch train_ldm_only.py \
    --max-train-steps=4000000 \
    --report-to="wandb" \
    --allow-tf32 \
    --mixed-precision="fp16" \
    --seed=0 \
    --data-dir="data" \
    --batch-size=256 \
    --path-type="linear" \
    --prediction="v" \
    --weighting="uniform" \
    --model="SiT-XL/1" \
    --checkpointing-steps=50000 \
    --vae="f16d32" \
    --vae-ckpt="pretrained/e2e-vavae/e2e-vavae-400k.pt" \
    --vae-latents-name="e2e-vavae" \
    --learning-rate=1e-4 \
    --enc-type="dinov2-vit-b" \
    --proj-coeff=0.5 \
    --encoder-depth=8 \
    --output-dir="exps" \
    --exp-name="sit-xl-1-dinov2-b-enc8-ldm-only-e2e-vavae-0.5-4m"
```

For details on the available training options and argument descriptions, refer to [Section 3](#3-train-the-repa-e-model).

### 5. Generate samples and run evaluation
You can generate samples and save them as `.npz` files using the following commands. Simply set `--exp-path` and `--train-steps` to match the model you trained (REPA-E or traditional LDM training).

For a REPA-E checkpoint (e.g., SiT-XL/2 trained jointly with SD-VAE):

```bash
torchrun --nnodes=1 --nproc_per_node=8 generate.py \
    --num-fid-samples 50000 \
    --path-type linear \
    --mode sde \
    --num-steps 250 \
    --cfg-scale 1.0 \
    --guidance-high 1.0 \
    --guidance-low 0.0 \
    --exp-path pretrained/sit-repae-sdvae \
    --train-steps 400000
```

For a traditional LDM checkpoint trained on the frozen E2E-VAVAE:

```bash
torchrun --nnodes=1 --nproc_per_node=8 generate.py \
    --num-fid-samples 50000 \
    --path-type linear \
    --mode sde \
    --num-steps 250 \
    --cfg-scale 1.0 \
    --guidance-high 1.0 \
    --guidance-low 0.0 \
    --exp-path pretrained/sit-ldm-e2e-vavae \
    --train-steps 4000000
```

<details>
<summary>Click to expand for sampling options</summary>

You can adjust the following options for sampling (see the guidance-interval sketch after this list):
- `--path-type`: Noise schedule type, choose from `[linear, cosine]`
- `--mode`: Sampling mode, `[ode, sde]`
- `--num-steps`: Number of denoising steps
- `--cfg-scale`: Guidance scale (float ≥ 1); setting it to 1 disables classifier-free guidance (CFG)
- `--guidance-high`: Upper guidance interval (float in [0, 1])
- `--guidance-low`: Lower guidance interval (float in [0, 1], must be < `--guidance-high`)
- `--exp-path`: Path to the experiment directory
- `--train-steps`: Training step of the checkpoint to evaluate

</details>

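For intuition on how `--cfg-scale`, `--guidance-low`, and `--guidance-high` interact: a common interval-guidance scheme applies classifier-free guidance only when the (normalized) timestep lies inside `[guidance-low, guidance-high]` and falls back to the conditional prediction elsewhere. The sketch below illustrates that behavior only; it is not the repository's sampler.

```python
# Illustrative classifier-free guidance restricted to a guidance interval.
# Outside [guidance_low, guidance_high], the conditional prediction is used unchanged.
import torch

def guided_prediction(v_cond, v_uncond, t, cfg_scale, guidance_low=0.0, guidance_high=1.0):
    # t is assumed to be a normalized time in [0, 1]
    if cfg_scale == 1.0 or not (guidance_low <= t <= guidance_high):
        return v_cond
    return v_uncond + cfg_scale * (v_cond - v_uncond)
```

With `--cfg-scale 1.0`, as in the commands above, guidance is effectively disabled regardless of the interval.
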
### Quantitative Results
The tables below report generation performance as gFID on 50k samples, with and without classifier-free guidance (CFG); lower is better. We compare models trained end-to-end with **REPA-E** against models trained with a frozen REPA-E fine-tuned VAE (**E2E-VAE**). All linked checkpoints are hosted on our [🤗 Hugging Face Hub](https://huggingface.co/REPA-E). To reproduce these results, download the respective checkpoints to the `pretrained` folder and run the evaluation script as detailed in [Section 5](#5-generate-samples-and-run-evaluation).

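Before running FID, it can help to sanity-check the generated sample file. The snippet below assumes `generate.py` writes an ADM-style `.npz` containing a single uint8 image array; the filename and key here are illustrative assumptions, not a documented format.

```python
# Sanity-check a generated .npz before FID evaluation (filename and key are illustrative).
import numpy as np

samples = np.load("samples/sit-repae-sdvae-400000.npz")
arr = samples[samples.files[0]]        # expected roughly (50000, 256, 256, 3), dtype uint8
print(arr.shape, arr.dtype, arr.min(), arr.max())
assert arr.shape[0] == 50_000, "expected 50k samples for gFID-50k"
```
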
#### A. End-to-End Training (REPA-E)
| Tokenizer | Generation Model | Epochs | gFID-50k ↓ | gFID-50k (CFG) ↓ |
|:---------|:----------------|:-----:|:----:|:---:|
| [**SD-VAE<sup>*</sup>**](https://huggingface.co/REPA-E/sdvae) | [**SiT-XL/2**](https://huggingface.co/REPA-E/sit-repae-sdvae) | 80 | 4.07 | 1.67<sup>a</sup> |
| [**IN-VAE<sup>*</sup>**](https://huggingface.co/REPA-E/invae) | [**SiT-XL/1**](https://huggingface.co/REPA-E/sit-repae-invae) | 80 | 4.09 | 1.61<sup>b</sup> |
| [**VA-VAE<sup>*</sup>**](https://huggingface.co/REPA-E/vavae) | [**SiT-XL/1**](https://huggingface.co/REPA-E/sit-repae-vavae) | 80 | 4.05 | 1.73<sup>c</sup> |

\* The "Tokenizer" column refers to the initial VAE used for joint REPA-E training. The final (jointly optimized) VAE is bundled within the generation model checkpoint.

<details>
<summary>Click to expand for CFG parameters</summary>
<ul>
  <li><strong>a</strong>: <code>--cfg-scale=2.2</code>, <code>--guidance-low=0.0</code>, <code>--guidance-high=0.65</code></li>
  <li><strong>b</strong>: <code>--cfg-scale=1.8</code>, <code>--guidance-low=0.0</code>, <code>--guidance-high=0.825</code></li>
  <li><strong>c</strong>: <code>--cfg-scale=1.9</code>, <code>--guidance-low=0.0</code>, <code>--guidance-high=0.825</code></li>
</ul>
</details>

---

#### B. Traditional Latent Diffusion Model Training (Frozen VAE)
| Tokenizer | Generation Model | Method | Epochs | gFID-50k ↓ | gFID-50k (CFG) ↓ |
|:------|:---------|:----------------|:-----:|:----:|:---:|
| SD-VAE | SiT-XL/2 | SiT | 1400 | 8.30 | 2.06 |
| SD-VAE | SiT-XL/2 | REPA | 800 | 5.90 | 1.42 |
| VA-VAE | LightningDiT-XL/1 | LightningDiT | 800 | 2.17 | 1.36 |
| [**E2E-VAVAE (Ours)**](https://huggingface.co/REPA-E/e2e-vavae) | [**SiT-XL/1**](https://huggingface.co/REPA-E/sit-ldm-e2e-vavae) | REPA | 800 | **1.83** | **1.26**<sup>†</sup> |

In this setup, the VAE is kept frozen, and only the generator is trained. Models using our E2E-VAE (fine-tuned via REPA-E) consistently outperform baselines like SD-VAE and VA-VAE, achieving state-of-the-art performance when incorporating the REPA alignment objective.

<details>
<summary>Click to expand for CFG parameters</summary>
<ul>
  <li><strong>†</strong>: <code>--cfg-scale=2.5</code>, <code>--guidance-low=0.0</code>, <code>--guidance-high=0.75</code></li>
</ul>
</details>

## Acknowledgement
This codebase builds upon several excellent open-source projects, including:
- [1d-tokenizer](https://github.com/bytedance/1d-tokenizer)
- [edm2](https://github.com/NVlabs/edm2)
- [LightningDiT](https://github.com/hustvl/LightningDiT)
- [REPA](https://github.com/sihyun-yu/REPA)
- [Taming-Transformers](https://github.com/CompVis/taming-transformers)

We sincerely thank the authors for making their work publicly available.

## BibTeX
If you find our work useful, please consider citing:

```bibtex
@article{leng2025repae,
  title={REPA-E: Unlocking VAE for End-to-End Tuning with Latent Diffusion Transformers},
  author={Xingjian Leng and Jaskirat Singh and Yunzhong Hou and Zhenchang Xing and Saining Xie and Liang Zheng},
  year={2025},
  journal={arXiv preprint arXiv:2504.10483},
}
```