feat: implement rae autoencoder. #13046
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# AutoencoderRAE

The Representation Autoencoder (RAE) model was introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, and Saining Xie from NYU VISIONx.

RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).

The following RAE models are released and supported in Diffusers:

| Model | Encoder | Latent shape (224px input) |
|:------|:--------|:---------------------------|
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |

## Loading a pretrained model

```python
from diffusers import AutoencoderRAE

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
```

## Encoding and decoding a real image

```python
import torch
from diffusers import AutoencoderRAE
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()

image = Image.open("cat.png").convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to("cuda")  # (1, 3, 224, 224), values in [0, 1]

with torch.no_grad():
    latents = model.encode(x).latent  # (1, 768, 16, 16)
    recon = model.decode(latents).sample  # (1, 3, 256, 256)

recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
recon_image.save("recon.png")
```
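
To sanity-check a reconstruction numerically, a generic PSNR helper works well. This is standard image-metric code, not part of the diffusers API, and the dummy tensors below only stand in for an image and its reconstruction; note that the decoder output above is 256x256 while the input is 224x224, so resize one side before comparing real tensors.

```python
import torch

def psnr(a: torch.Tensor, b: torch.Tensor) -> float:
    """Peak signal-to-noise ratio for tensors with values in [0, 1]."""
    mse = torch.mean((a - b) ** 2)
    return float(10 * torch.log10(1.0 / mse))

# Dummy tensors stand in for an input image and its reconstruction.
x = torch.rand(1, 3, 224, 224)
recon = (x + 0.01 * torch.randn_like(x)).clamp(0, 1)
print(f"PSNR: {psnr(x, recon):.1f} dB")
```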

## Latent normalization

Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.

```python
import torch
from diffusers import AutoencoderRAE

model = AutoencoderRAE.from_pretrained(
    "nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()

x = torch.rand(1, 3, 224, 224, device="cuda")  # any batch of images in [0, 1]

# Latent normalization is handled automatically inside encode/decode
# when the checkpoint config includes latents_mean/latents_std.
with torch.no_grad():
    latents = model.encode(x).latent  # normalized latents
    recon = model.decode(latents).sample
```
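
Conceptually, the normalization is a per-channel affine transform: subtract `latents_mean` and divide by `latents_std` on the way out of the encoder, and invert that before the decoder. A minimal sketch with dummy statistics (the values and broadcasting here are illustrative, not read from a real checkpoint config):

```python
import torch

# Illustrative per-channel statistics (a real checkpoint stores these in its config).
channels = 768
latents_mean = torch.randn(channels)
latents_std = torch.rand(channels) + 0.5

raw = torch.randn(1, channels, 16, 16)  # stand-in for raw encoder features

# Normalization, applied per channel (dim 1):
normalized = (raw - latents_mean[None, :, None, None]) / latents_std[None, :, None, None]

# Denormalization inverts it exactly:
restored = normalized * latents_std[None, :, None, None] + latents_mean[None, :, None, None]

assert torch.allclose(restored, raw, atol=1e-5)
```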

## AutoencoderRAE

[[autodoc]] AutoencoderRAE
    - encode
    - decode
    - all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput

# Training AutoencoderRAE

This example trains the decoder of `AutoencoderRAE` (stage 1 of the RAE recipe) while keeping the representation encoder frozen.

It follows the same high-level training recipe as the official RAE stage-1 setup:

- frozen encoder
- trainable decoder
- pixel reconstruction loss
- optional encoder feature consistency loss

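The recipe above can be sketched with stand-in modules. The convolutional encoder/decoder, shapes, and the `0.1` weight below are illustrative only; the real script trains the `AutoencoderRAE` decoder against a frozen pretrained representation encoder.

```python
import torch
import torch.nn.functional as F

# Stand-ins: a frozen "encoder" and a trainable "decoder".
encoder = torch.nn.Conv2d(3, 8, kernel_size=16, stride=16)
decoder = torch.nn.ConvTranspose2d(8, 3, kernel_size=16, stride=16)
encoder.requires_grad_(False)

optimizer = torch.optim.AdamW(decoder.parameters(), lr=1e-4)
images = torch.rand(2, 3, 64, 64)

latents = encoder(images)                # frozen encoder forward
recon = torch.sigmoid(decoder(latents))  # trainable decoder forward, output in [0, 1]

# Pixel reconstruction loss (cf. --reconstruction_loss_type l1).
loss = F.l1_loss(recon, images)

# Optional encoder feature consistency loss (cf. --use_encoder_loss).
loss = loss + 0.1 * F.mse_loss(encoder(recon), latents.detach())

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only the decoder receives gradients: the frozen encoder still participates in the consistency term, but its parameters are excluded from both autograd and the optimizer.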
## Quickstart

### Resume or finetune from pretrained weights

```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
  --pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
  --train_data_dir /path/to/imagenet_like_folder \
  --output_dir /tmp/autoencoder-rae \
  --resolution 256 \
  --train_batch_size 8 \
  --learning_rate 1e-4 \
  --num_train_epochs 10 \
  --report_to wandb \
  --reconstruction_loss_type l1 \
  --use_encoder_loss \
  --encoder_loss_weight 0.1
```

### Train from scratch with a pretrained encoder

The following command launches RAE training with `facebook/dinov2-with-registers-base` as the base encoder.

```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
  --train_data_dir /path/to/imagenet_like_folder \
  --output_dir /tmp/autoencoder-rae \
  --resolution 256 \
  --encoder_type dinov2 \
  --encoder_name_or_path facebook/dinov2-with-registers-base \
  --encoder_input_size 224 \
  --patch_size 16 \
  --image_size 256 \
  --decoder_hidden_size 1152 \
  --decoder_num_hidden_layers 28 \
  --decoder_num_attention_heads 16 \
  --decoder_intermediate_size 4096 \
  --train_batch_size 8 \
  --learning_rate 1e-4 \
  --num_train_epochs 10 \
  --report_to wandb \
  --reconstruction_loss_type l1 \
  --use_encoder_loss \
  --encoder_loss_weight 0.1
```

Note: the stage-1 reconstruction loss assumes matching target/output spatial sizes, so `--resolution` must equal `--image_size`.

The dataset format is expected to be `ImageFolder`-compatible:

```text
train_data_dir/
  class_a/
    img_0001.jpg
  class_b/
    img_0002.jpg
```