
feat: implement rae autoencoder. #13046

Open

Ando233 wants to merge 48 commits into huggingface:main from Ando233:rae

Conversation

@Ando233 commented Jan 28, 2026

What does this PR do?

This PR adds a new representation autoencoder implementation, AutoencoderRAE, to diffusers.

  • Implements diffusers.models.autoencoders.autoencoder_rae.AutoencoderRAE with a frozen pretrained vision encoder (DINOv2 / SigLIP2 / ViT-MAE) and a ViT-MAE-style decoder.
  • Aligns the decoder implementation with the RAE-main GeneralDecoder parameter structure, so existing trained decoder checkpoints (e.g. model.pt) load without key mismatches when the encoder/decoder settings are consistent.
  • Adds unit/integration tests under diffusers/tests/models/autoencoders/test_models_autoencoder_rae.py.
  • Registers exports so users can import directly via from diffusers import AutoencoderRAE.

Fixes #13000

Before submitting

Usage

import torch
from diffusers import AutoencoderRAE

# encoder_path, image_size, patch_size, num_patches, device, x and args.decoder_ckpt
# are placeholders to be filled in by the user.
ae = AutoencoderRAE(
    encoder_cls="dinov2",
    encoder_name_or_path=encoder_path,
    image_size=image_size,
    encoder_input_size=image_size,
    patch_size=patch_size,
    num_patches=num_patches,
    decoder_hidden_size=1152,
    decoder_num_hidden_layers=28,
    decoder_num_attention_heads=16,
    decoder_intermediate_size=4096,
).to(device)
ae.eval()

# Load an existing RAE-main decoder checkpoint into the ViT-MAE-style decoder.
state = torch.load(args.decoder_ckpt, map_location="cpu")
ae.decoder.load_state_dict(state, strict=False)

with torch.no_grad():
    recon = ae(x).sample

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul sayakpaul requested a review from kashif January 30, 2026 11:31
@sayakpaul
Member

@bytetriper if you could take a look?

@kashif
Contributor

kashif commented Jan 30, 2026

Nice work @Ando233, checking.

@kashif
Contributor

kashif commented Jan 30, 2026

Off the bat:

  • let's have a nice convention for the output datatype classes; have a look at the other autoencoders for the convention in diffusers
  • some of the tests might need to be marked as slow, and some paths are hard-coded

Let's sort out these things and then re-look.

@bytetriper

Agree with @kashif. Also, if possible, we can bake all the params into the config so we can enable .from_pretrained(), which is more elegant and aligns with diffusers usage. I can help convert our released checkpoints to the Hugging Face format afterwards.
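For reference, once the params live in the config, loading could look roughly like this (a minimal sketch; the repo id is a placeholder, not a published checkpoint):

import torch
from diffusers import AutoencoderRAE

# Hypothetical hub repo id; real checkpoints would only exist after conversion.
ae = AutoencoderRAE.from_pretrained("<org>/rae-dinov2").eval()

images = torch.randn(1, 3, 256, 256)  # dummy batch in place of real preprocessed images
with torch.no_grad():
    recon = ae(images).sample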

@sayakpaul
Member

@Ando233 we're happy to provide assistance if needed.

@kashif
Contributor

kashif commented Feb 15, 2026

@Ando233 the one remaining thing is the use of use_encoder_loss, and perhaps an example real-world training script.
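For context, a rough sketch of what a minimal decoder-training loop could look like (everything here is illustrative: the dataloader, the ae.encoder attribute name, the plain MSE loss, and the role of use_encoder_loss are assumptions, not the PR's actual script):

import torch
import torch.nn.functional as F

# `ae` is an AutoencoderRAE built as in the usage snippet above; the encoder stays frozen.
for p in ae.encoder.parameters():
    p.requires_grad_(False)
optimizer = torch.optim.AdamW(ae.decoder.parameters(), lr=1e-4)

for images in dataloader:  # assumed to yield (B, 3, H, W) tensors on the right device
    recon = ae(images).sample
    loss = F.mse_loss(recon, images)  # use_encoder_loss would presumably add a feature-space term
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()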

@kashif
Contributor

kashif commented Feb 15, 2026

@bytetriper could you kindly try to run the conversion scripts and upload the diffusers style weights to your huggingface hub for the checkpoints you have?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment

Left some comments. Let me know if this makes sense. @bytetriper it would be great if you could also test the diffusers counterparts of RAE and let us know your thoughts.

@sayakpaul sayakpaul requested a review from stevhliu February 28, 2026 16:49
@sayakpaul (Member) left a comment

Left a major comment regarding the presence of the encoder-specific classes now. LMK your thoughts.

Comment on lines +66 to +67
self.model.layernorm.weight = None
self.model.layernorm.bias = None
Member

We're already stripping the layernorms in the conversion. Seems like it's not needed anymore?

Comment on lines +102 to +103
self.model.vision_model.post_layernorm.weight = None
self.model.vision_model.post_layernorm.bias = None
Member

Same for these?

@kashif
Contributor

kashif commented Mar 3, 2026

@bytetriper could you kindly merge the change in the PR on the hub for the weights?

@sayakpaul (Member) left a comment

Looks really good barring a few nits.

Comment on lines +528 to +532
# RAE-main: base_patches = (encoder_input_size // encoder_patch_size) ** 2
if self.encoder_input_size % encoder_patch_size != 0:
    raise ValueError(
        f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
    )
Member

Let's error as early as possible?
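For illustration, the check could move to the top of __init__, before any submodules are constructed (a sketch only; the surrounding signature is assumed):

def __init__(self, encoder_input_size: int, encoder_patch_size: int, **kwargs):
    # Validate the configuration up front so a misconfigured model fails immediately,
    # rather than deep inside encoder/decoder construction.
    if encoder_input_size % encoder_patch_size != 0:
        raise ValueError(
            f"encoder_input_size={encoder_input_size} must be divisible by "
            f"encoder_patch_size={encoder_patch_size}."
        )
    # ... build the encoder and decoder afterwards ...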

Comment on lines +96 to +97
model.layernorm.weight = None
model.layernorm.bias = None
Member

Can we add a small comment explaining what this is doing?
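For example, something like this (the stated intent is inferred from the conversion-script discussion above, so worth double-checking):

# The backbone's final layernorm is stripped during checkpoint conversion, so null it
# out here as well to keep the module's state dict consistent with converted weights.
model.layernorm.weight = None
model.layernorm.bias = None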

Comment on lines +535 to +538
# Decoder patch size is independent from encoder patch size.
decoder_patch_size = int(patch_size)
if decoder_patch_size <= 0:
    raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
Member

Same. Let's error as early as possible.


@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEEncoderIntegrationTests(unittest.TestCase):
Member

We shouldn't inherit from unittest now that we're using pure pytest.
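Roughly like this, for instance (a sketch reusing the names from the snippet above):

@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEEncoderIntegrationTests:
    # Plain pytest class: no unittest.TestCase base; setUp/tearDown become
    # setup_method/teardown_method (or fixtures), and assertions use bare asserts.
    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)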

        gc.collect()
        backend_empty_cache(torch_device)

    def test_autoencoder_rae_from_pretrained_dinov2(self):
Member

Should we check against an image and assert a value slice?
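Something along these lines, for example (the repo id, preprocessing helper and expected values are all placeholders to fill in once reference outputs exist):

def test_autoencoder_rae_from_pretrained_dinov2(self):
    ae = AutoencoderRAE.from_pretrained("<org>/<rae-dinov2-repo>").to(torch_device)
    x = self.get_reference_image_batch().to(torch_device)  # hypothetical helper

    with torch.no_grad():
        recon = ae(x).sample

    # Placeholder reference slice; replace with values computed from the released checkpoint.
    expected_slice = torch.tensor([0.0, 0.0, 0.0, 0.0])
    assert torch.allclose(recon.flatten()[:4].cpu(), expected_slice, atol=1e-3)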

@sayakpaul
Member

@dg845 could you give this a review as well?

@bytetriper

> @bytetriper could you kindly merge the change in the PR on the hub for the weights?

@kashif Merged!

Comment on lines +83 to +85
def _build_encoder(encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int) -> nn.Module:
    """Build a frozen encoder from config (no pretrained download)."""
    num_attention_heads = hidden_size // 64  # all supported encoders use head_dim=64
Collaborator

Suggested change
- def _build_encoder(encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int) -> nn.Module:
-     """Build a frozen encoder from config (no pretrained download)."""
-     num_attention_heads = hidden_size // 64  # all supported encoders use head_dim=64
+ def _build_encoder(
+     encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
+ ) -> nn.Module:
+     """Build a frozen encoder from config (no pretrained download)."""
+     num_attention_heads = hidden_size // head_dim  # all supported encoders use head_dim=64

nit: I think having a head_dim argument here would be useful to make the code more future-proof (for example, if a CoolNewRepresentationEncoder comes out in the future with head_dim != 64, it would be easier to support).

Comment on lines +259 to +261
self.decoder_pos_embed = nn.Parameter(
    torch.zeros(1, num_patches + 1, decoder_hidden_size), requires_grad=False
)
Collaborator

Is setting requires_grad=False here intended? If decoder_pos_embed is meant to be a fixed 2D sinusoidal positional embedding (as I believe _initialize_weights is doing below), perhaps we could register it as a buffer instead?
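For illustration, a buffer-based alternative could look roughly like this (assuming the embedding really is meant to stay fixed):

# A fixed sinusoidal positional embedding can live in a buffer rather than a frozen
# Parameter: it still moves with .to()/.cuda() but is never handed to the optimizer.
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))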

self.gradient_checkpointing = False

self._initialize_weights(num_patches)
self.set_trainable_cls_token()
@dg845 (Collaborator) commented Mar 4, 2026

Should set_trainable_cls_token receive an argument from __init__ here? If not, maybe it would be more clear to inline its logic, e.g.

    self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, self.decoder_hidden_size))

        x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
        return x

    def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: Optional[Tuple[int, int]] = None):
@dg845 (Collaborator) commented Mar 4, 2026

Suggested change
- def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: Optional[Tuple[int, int]] = None):
+ def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):

We have moved to Python 3.9+ style implicit type annotations. Can you update the type annotations here and elsewhere where the old style is used? See #12524 for more details.

        self._initialize_weights(num_patches)
        self.set_trainable_cls_token()

    def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None):
Collaborator

Suggested change
- def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None):
+ def set_trainable_cls_token(self, tensor: torch.Tensor | None = None):

See #13046.

        # Slicing support (batch dimension) similar to other diffusers autoencoders
        self.use_slicing = False

    def _noising(self, x: torch.Tensor) -> torch.Tensor:
Collaborator

Could the _noising method take a generator argument for reproducibility?
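For reference, a generator-aware version could look roughly like this (a sketch of the suggestion, not the PR's current code):

def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
    # Draw the per-sample sigma and the Gaussian noise from the optional generator so
    # noising becomes reproducible when a seeded generator is passed in.
    noise_sigma = self.noise_tau * torch.rand(
        (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
    )
    noise = torch.randn(x.shape, device=x.device, dtype=x.dtype, generator=generator)
    return x + noise_sigma * noise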

        noise_sigma = self.noise_tau * torch.rand((x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype)
        return x + noise_sigma * torch.randn_like(x)

    def _maybe_resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
Collaborator

Suggested change
- def _maybe_resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
+ def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:

nit: I think it would be more clear to drop the _maybe prefixes from the _maybe_* methods as it looks like all of these methods unconditionally perform the action in the name.



Development

Successfully merging this pull request may close these issues.

RAE support

7 participants