feat: implement RAE autoencoder #13046
Conversation
@bytetriper if you could take a look?

Nice work @Ando233, checking.

Off the bat, let's sort out these things and then re-look.

Agree with @kashif. Also, if possible, we can bake all the params into the config so we can enable `.from_pretrained()`, which is more elegant and aligns with diffusers usage. I can help convert our released checkpoint to the Hugging Face format afterwards.

@Ando233 we're happy to provide assistance if needed.

@Ando233 the one remaining thing is the use of the

@bytetriper could you kindly try to run the conversion scripts and upload the diffusers-style weights to your Hugging Face Hub for the checkpoints you have?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
sayakpaul left a comment:
Left some comments. Let me know if this makes sense. @bytetriper it would be great if you could also test the diffusers counterparts of RAE and let us know your thoughts.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
sayakpaul left a comment:
Left a major comment regarding the presence of the encoder-specific classes now. LMK your thoughts.
```python
self.model.layernorm.weight = None
self.model.layernorm.bias = None
```
We're already stripping the layernorms in the conversion. Seems like it's not needed anymore?
```python
self.model.vision_model.post_layernorm.weight = None
self.model.vision_model.post_layernorm.bias = None
```
@bytetriper could you kindly merge the change in the PR on the hub for the weights?
sayakpaul left a comment:
Looks really good barring a few nits.
```python
# RAE-main: base_patches = (encoder_input_size // encoder_patch_size) ** 2
if self.encoder_input_size % encoder_patch_size != 0:
    raise ValueError(
        f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
    )
```
Let's error as early as possible?
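A minimal sketch of what "erroring as early as possible" could mean here, using a framework-free stand-in config class (names are hypothetical, not the PR's actual API): validation runs at the very top of `__init__`, before any state is built.

```python
# Hypothetical sketch: validate configuration before constructing anything,
# so a bad config fails immediately with a clear message.
class PatchConfig:
    def __init__(self, input_size: int = 224, patch_size: int = 16):
        # Validate first -- nothing has been allocated yet if this raises.
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}.")
        if input_size % patch_size != 0:
            raise ValueError(
                f"input_size={input_size} must be divisible by patch_size={patch_size}."
            )
        self.input_size = input_size
        self.patch_size = patch_size
        self.num_patches = (input_size // patch_size) ** 2
```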
```python
model.layernorm.weight = None
model.layernorm.bias = None
```
Can we add a small comment explaining what this is doing?
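One hedged possibility for such a comment, shown on a framework-free stand-in object (the stated rationale is a guess at the intent based on the earlier "stripping the layernorms in the conversion" remark, and should come from the PR authors):

```python
from types import SimpleNamespace

# Stand-in for the pretrained vision encoder (hypothetical; the real code
# works on a transformers model instance).
model = SimpleNamespace(layernorm=SimpleNamespace(weight="w", bias="b"))

# The encoder's final layernorm is intentionally dropped: RAE consumes the
# hidden states before this layer (assumption about the intent), so nulling
# the weights keeps them out of the converted checkpoint.
model.layernorm.weight = None
model.layernorm.bias = None
```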
```python
# Decoder patch size is independent from encoder patch size.
decoder_patch_size = int(patch_size)
if decoder_patch_size <= 0:
    raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
```
Same. Let's error as early as possible.
```python
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEEncoderIntegrationTests(unittest.TestCase):
```
We shouldn't inherit from unittest now that we're using pure pytest.
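A minimal sketch of the pure-pytest shape (class and values are illustrative): pytest collects plain classes matching its naming pattern (by default `Test*`, unless `python_classes` is configured otherwise), and `setup_method`/`teardown_method` replace `setUp`/`tearDown`, so no `unittest.TestCase` base is needed.

```python
# Hypothetical sketch: a pytest-style test class with no unittest base.
# pytest discovers plain classes and supports setup_method/teardown_method
# hooks in place of unittest's setUp/tearDown.
class TestAutoencoderRAEEncoderIntegration:
    def setup_method(self):
        # runs before each test method, like unittest's setUp
        self.expected_latent_dim = 768  # placeholder value, not from a real checkpoint

    def test_latent_dim(self):
        assert self.expected_latent_dim == 768
```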
```python
gc.collect()
backend_empty_cache(torch_device)

def test_autoencoder_rae_from_pretrained_dinov2(self):
```
Should we check against an image and assert a value slice?
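A sketch of the usual diffusers-style slice assertion, kept framework-free here (the numbers are placeholders; real expected values would come from running a fixed image through an actual checkpoint):

```python
# Hypothetical sketch: compare the first few output values against a
# hard-coded expected slice, within a tolerance.
def assert_output_slice(output, expected_slice, atol=1e-4):
    actual = [float(v) for v in output[: len(expected_slice)]]
    for a, e in zip(actual, expected_slice):
        assert abs(a - e) <= atol, f"slice mismatch: {actual} vs {expected_slice}"

# usage with a deterministic stand-in for model output
fake_output = [0.1234, -0.5678, 0.9012, 0.3456]
assert_output_slice(fake_output, [0.1234, -0.5678, 0.9012])
```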
@dg845 could you give this a review as well?

@kashif Merged!
```python
def _build_encoder(encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int) -> nn.Module:
    """Build a frozen encoder from config (no pretrained download)."""
    num_attention_heads = hidden_size // 64  # all supported encoders use head_dim=64
```
Suggested change:
```python
def _build_encoder(
    encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
) -> nn.Module:
    """Build a frozen encoder from config (no pretrained download)."""
    num_attention_heads = hidden_size // head_dim  # all supported encoders use head_dim=64
```
nit: I think having a head_dim argument here would be useful to make the code more future proof (for example, if a CoolNewRepresentationEncoder comes out in the future with head_dim != 64, it would be easier to support).
```python
self.decoder_pos_embed = nn.Parameter(
    torch.zeros(1, num_patches + 1, decoder_hidden_size), requires_grad=False
)
```
Is setting requires_grad=False here intended? If decoder_pos_embed is meant to be a fixed 2D sinusoidal positional embedding (as I believe _initialize_weights is doing below), perhaps we could register it as a buffer instead?
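A hedged sketch of the buffer alternative (the module and sizes are toy stand-ins, not the PR's real decoder): `register_buffer` keeps the tensor in `state_dict` and moves it with `.to()`, while keeping it out of `parameters()` so optimizers never see it.

```python
# Hypothetical sketch: a fixed 1D sinusoidal positional embedding stored as a
# buffer rather than an nn.Parameter(..., requires_grad=False).
import math
import torch
import torch.nn as nn


def sinusoidal_pos_embed(num_positions: int, dim: int) -> torch.Tensor:
    # classic transformer sin/cos table; dim must be even
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    pe = torch.zeros(num_positions, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)  # (1, num_positions, dim)


class TinyDecoder(nn.Module):
    def __init__(self, num_patches: int = 16, hidden: int = 8):
        super().__init__()
        # persistent buffer: saved in state_dict, excluded from parameters()
        self.register_buffer("decoder_pos_embed", sinusoidal_pos_embed(num_patches + 1, hidden))
```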
```python
self.gradient_checkpointing = False

self._initialize_weights(num_patches)
self.set_trainable_cls_token()
```
Should set_trainable_cls_token receive an argument from __init__ here? If not, maybe it would be more clear to inline its logic, e.g.
```python
self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, self.decoder_hidden_size))
```

```python
x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
return x

def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: Optional[Tuple[int, int]] = None):
```
Suggested change:
```python
def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
```
We have moved to Python 3.9+ style implicit type annotations. Can you update the type annotations here and elsewhere where the old style is used? See #12524 for more details.
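For illustration, a small framework-free example of the newer annotation style (function and types are hypothetical); with `from __future__ import annotations`, the builtin-generic spelling works even on Python 3.9 because annotations are evaluated lazily:

```python
# Hypothetical example of the newer annotation style: builtin generics
# (list, tuple) and `X | None` instead of typing.Optional/typing.Tuple.
from __future__ import annotations


def crop_values(values: list[float], size: tuple[int, int] | None = None) -> list[float]:
    # old style would be: Optional[Tuple[int, int]] plus typing imports
    if size is None:
        return values
    h, w = size
    return values[: h * w]
```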
```python
self._initialize_weights(num_patches)
self.set_trainable_cls_token()

def set_trainable_cls_token(self, tensor: Optional[torch.Tensor] = None):
```
Suggested change:
```python
def set_trainable_cls_token(self, tensor: torch.Tensor | None = None):
```
See #13046.
```python
# Slicing support (batch dimension) similar to other diffusers autoencoders
self.use_slicing = False

def _noising(self, x: torch.Tensor) -> torch.Tensor:
```
Could the _noising method take a generator argument for reproducibility?
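A hedged sketch of what a generator-aware version could look like, mirroring the formula in the quoted snippet (the free-function signature here is an assumption, not the PR's final API):

```python
# Hypothetical sketch: thread a torch.Generator through the noising step so
# that sampling is reproducible. Mirrors noise_sigma = tau * U[0,1) per sample.
from __future__ import annotations

import torch


def noising(x: torch.Tensor, noise_tau: float, generator: torch.Generator | None = None) -> torch.Tensor:
    # per-sample sigma broadcast over all non-batch dims
    shape = (x.size(0),) + (1,) * (x.ndim - 1)
    noise_sigma = noise_tau * torch.rand(shape, device=x.device, dtype=x.dtype, generator=generator)
    noise = torch.randn(x.shape, device=x.device, dtype=x.dtype, generator=generator)
    return x + noise_sigma * noise
```

With the same seed, two calls produce identical outputs, which is what makes value-slice tests feasible for the stochastic path.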
```python
noise_sigma = self.noise_tau * torch.rand((x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype)
return x + noise_sigma * torch.randn_like(x)

def _maybe_resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
```
Suggested change:
```python
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
```
nit: I think it would be more clear to drop the _maybe prefixes from the _maybe_* methods as it looks like all of these methods unconditionally perform the action in the name.
What does this PR do?
This PR adds a new representation autoencoder implementation, AutoencoderRAE, to diffusers.
Implements diffusers.models.autoencoders.autoencoder_rae.AutoencoderRAE with a frozen pretrained vision encoder (DINOv2 / SigLIP2 / ViT-MAE) and a ViT-MAE style decoder.
The decoder implementation is aligned with the RAE-main GeneralDecoder parameter structure, enabling loading of existing trained decoder checkpoints (e.g. model.pt) without key mismatches when encoder/decoder settings are consistent.
Adds unit/integration tests under diffusers/tests/models/autoencoders/test_models_autoencoder_rae.py.
Registers exports so users can import directly via from diffusers import AutoencoderRAE.
Fixes #13000
Before submitting: see the documentation guidelines and the tips on formatting docstrings.
Usage
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.