This repository implements β-VAE and λβ-VAE models to evaluate disentanglement metrics on the dSprites, Shapes3D, and MPI3D datasets. It includes both linear and nonlinear implementations. The β-VAE baselines are trained first, followed by continued training with an additional λ term in the loss function to assess its impact on reconstruction quality and disentanglement performance.
- Models: Linear and convolutional encoder-decoder VAE architectures.
- Loss Functions: β-VAE (reconstruction + β × KL divergence) and λβ-VAE (reconstruction + β × KL divergence + λ × L2 loss).
- Datasets: dSprites, Shapes3D, and MPI3D (expected in the `data` folder).
- Metrics: Negative Log-Likelihood (NLL), Mutual Information Gap (MIG), Separated Attribute Predictability (SAP), and $I_m$ score.
- Visualizations: Image reconstruction grids, latent traversal GIFs, mutual information heatmaps, boxplots and heatmaps for all metrics, and interactive Plotly heatmaps.
- Reproducibility: Multi-seed experiments with fixed random seeds.
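
The two objectives listed above can be sketched in a few lines. This is a minimal illustration, not the repository's `losses.py`: it assumes a diagonal-Gaussian posterior with a standard-normal prior (so the KL term has its usual closed form), and it takes the reconstruction and L2 terms as precomputed numbers, because the list above does not say what the L2 penalty is applied to. All function names are hypothetical.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) in closed form."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def beta_vae_loss(recon: float, mu, logvar, beta: float) -> float:
    """beta-VAE objective: reconstruction + beta * KL."""
    return recon + beta * gaussian_kl(mu, logvar)


def lambda_beta_vae_loss(recon, mu, logvar, beta, l2_term, lam) -> float:
    """lambda-beta-VAE objective: beta-VAE loss + lam * L2 term.

    l2_term is supplied by the caller; what it penalises is an assumption
    left open here.
    """
    return beta_vae_loss(recon, mu, logvar, beta) + lam * l2_term
```

With `lam = 0` the second function reduces to the plain β-VAE loss, which matches the idea of continuing training from a β-VAE baseline.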
- Download and unzip the repository archive from GitHub.
- Ensure the `data` folder contains the datasets.
- Install dependencies using Python 3.8+:
pip install -r requirements.txt
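Run the experiments in order: start with the linear models, then run the nonlinear β-VAE and λβ-VAE experiments, and finally generate the interactive visualizations. The order matters because `main_lambda.py` continues training from the β-VAE checkpoints and the visualizer reads the metrics produced by the nonlinear runs; it also allows comparison between linear and nonlinear results.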
To execute everything step by step:
python linear_betavae.py
python main_beta.py
python main_lambda.py
python interactive_heatmap_visualizer.py
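
If you prefer a single entry point, the four commands above could be wrapped in a small driver that stops at the first failure. This is a sketch, not a file shipped with the repository; `run_pipeline` and its `runner` parameter are hypothetical names, and the script order is simply the one listed above.

```python
import subprocess
import sys

# Script order taken from the commands above.
PIPELINE = [
    "linear_betavae.py",
    "main_beta.py",
    "main_lambda.py",
    "interactive_heatmap_visualizer.py",
]


def run_pipeline(scripts=PIPELINE, runner=subprocess.run):
    """Run each script in order; stop at the first non-zero exit code.

    Returns the list of scripts that completed successfully.
    """
    completed = []
    for script in scripts:
        # Use the current interpreter so the active environment is reused.
        result = runner([sys.executable, script])
        if result.returncode != 0:
            print(f"{script} exited with code {result.returncode}; stopping.")
            break
        completed.append(script)
    return completed
```

Calling `run_pipeline()` from the repository root runs all four steps. Stopping early avoids starting `main_lambda.py` when the β-VAE checkpoints it depends on were not produced.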
The linear experiments use two methods (fixed-point iteration and AdamW optimization) within the script. Run this before nonlinear experiments:
python linear_betavae.py
- Results are saved to `./linear_betavae_results/`.
- Pre-computed results: Download `linear_betavae_results.zip`, attached in the repository main.
Train and evaluate nonlinear β-VAE models:
python main_beta.py
- Results are saved to `./betavae_results/<dataset>/`.
- Includes trained models, metrics, visualizations, and summary files.
- Aggregated boxplots: `./betavae_results/metric_boxplots_across_datasets.png`.
Continue training from β-VAE checkpoints (run main_beta.py first):
python main_lambda.py
- Results are saved to `./lambda_betavae_results/<dataset>/`.
- Includes updated models, metrics, visualizations, and summary files.
- Aggregated heatmaps: `./lambda_betavae_results/metric_heatmaps_across_datasets.png`.
After running the nonlinear experiments, load metrics from the `.npz` files, calculate percentage changes vs. baseline for NLL, SAP, MIG, and $I_m$, and generate the static and interactive heatmaps:
python interactive_heatmap_visualizer.py
- Results are saved to `./lambda_betavae_results/`, including `percent_heatmaps_across_datasets.png` and HTML files (e.g., `interactive_tchebycheff_heatmap_<dataset>.html`).
- Use the slider to adjust weights (0-100% reconstruction weight) and find the best β-λ pair for different priorities.
- Pre-computed results: Download `interactive_heatmap_results.zip`, attached in the repository main.
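
The percentage-change and slider logic can be illustrated as follows. This is a sketch under assumptions, not the visualizer's code: the output file name suggests a weighted Tchebycheff scalarization, so the standard form is used here, with one reconstruction cost and one disentanglement cost per (β, λ) pair, both "lower is better" and already scaled to comparable ranges. The function names and the input layout are hypothetical.

```python
def percent_change(value: float, baseline: float) -> float:
    """Percentage change of value relative to a non-zero baseline."""
    return 100.0 * (value - baseline) / abs(baseline)


def tchebycheff_best(scores, recon_weight):
    """Pick the (beta, lam) pair minimising a weighted Tchebycheff score.

    scores: dict mapping (beta, lam) -> (recon_cost, disent_cost).
    recon_weight: slider value in [0, 1]; the remainder goes to
                  disentanglement.
    """
    w_r, w_d = recon_weight, 1.0 - recon_weight
    # Ideal point: the best value reached on each objective separately.
    ideal_r = min(r for r, _ in scores.values())
    ideal_d = min(d for _, d in scores.values())

    def scalarised(pair):
        r, d = scores[pair]
        # Worst weighted distance to the ideal point across objectives.
        return max(w_r * (r - ideal_r), w_d * (d - ideal_d))

    return min(scores, key=scalarised)
```

Moving `recon_weight` from 0 to 1 shifts the selected pair from the best-disentangled configuration toward the best-reconstructing one, which is the trade-off the slider exposes.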
Customize hyperparameters (e.g., β values, λ values, number of seeds, training steps) in config.py.
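
As a rough picture of what such a configuration controls, the sweep can be thought of as a grid over β, λ, and seeds. The sketch below is illustrative only: the class and field names are hypothetical and may not match `config.py`, the β and λ lists reuse one of the grids mentioned in this README, and the seed count and step count are placeholders.

```python
from dataclasses import dataclass, field
from itertools import product
from typing import List, Tuple


@dataclass
class ExperimentConfig:
    # Hypothetical field names; check config.py for the real ones.
    betas: List[float] = field(default_factory=lambda: [1, 4, 8, 16, 32])
    lambdas: List[float] = field(default_factory=lambda: [0, 2, 4, 6, 8])
    num_seeds: int = 3          # placeholder value
    train_steps: int = 10_000   # placeholder value

    def grid(self) -> List[Tuple[float, float, int]]:
        """Every (beta, lambda, seed) combination, in sweep order."""
        return list(product(self.betas, self.lambdas, range(self.num_seeds)))
```

With these placeholder values the grid has 5 × 5 × 3 = 75 runs, which is a quick way to estimate the cost of a sweep before changing the lists.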
We do not redistribute datasets. Please download them from the official sources below and place the files in the data folder so loaders work correctly.
- `dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz`: dSprites (shapes with variations in shape, scale, orientation, position).
- `3dshapes.h5`: Shapes3D (3D shapes with variations in hue, shape, scale, orientation).
- `mpi3d_real.npz`: MPI3D (realistic 3D objects with variations in color, shape, size, camera position).
  - Source: https://github.com/rr-learning/disentanglement_dataset

Expected layout:
data/
├─ dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz
├─ 3dshapes.h5
└─ mpi3d_real.npz
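
Before starting a long run, it can help to confirm that the layout above is in place. The following is a small optional check, not part of the repository; `missing_datasets` is a hypothetical helper, and the file names are copied from the layout above.

```python
from pathlib import Path

# File names taken from the expected layout above.
EXPECTED_FILES = [
    "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
    "3dshapes.h5",
    "mpi3d_real.npz",
]


def missing_datasets(data_dir="data"):
    """Return the expected dataset files that are not present in data_dir."""
    root = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

An empty list from `missing_datasets()` means all three files were found; otherwise the returned names are the ones still to download.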
- Metrics: Reconstruction loss (NLL), MIG, SAP, $I_m$.
- Visualizations:
  - Original and reconstructed image grids.
  - Latent traversal animations (GIFs).
  - Mutual information heatmaps.
  - Boxplots and heatmaps for all metrics.
  - Interactive weighted score heatmaps.
- Summary statistics (mean ± std across seeds) in text files per dataset.
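
For readers unfamiliar with MIG, the sketch below shows the usual definition: for each ground-truth factor, take the gap between the two latents with the highest mutual information, normalize by the factor's entropy, and average over factors. It assumes the mutual-information table and factor entropies are already computed; how `metrics.py` estimates them (for example, through binning) is not covered here, and the function name is hypothetical.

```python
from typing import Sequence


def mutual_information_gap(mi: Sequence[Sequence[float]],
                           factor_entropy: Sequence[float]) -> float:
    """MIG from a precomputed mutual-information table.

    mi[j][k] is the mutual information between latent j and factor k;
    factor_entropy[k] is the entropy of factor k. Needs at least two latents.
    """
    num_factors = len(factor_entropy)
    gaps = []
    for k in range(num_factors):
        # Sort this factor's MI values across latents, largest first.
        column = sorted((row[k] for row in mi), reverse=True)
        gaps.append((column[0] - column[1]) / factor_entropy[k])
    return sum(gaps) / num_factors
```

A score near 1 means each factor is captured almost entirely by a single latent, while a score near 0 means at least two latents carry similar information about each factor.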
Pre-trained models, metrics, and visualizations from training runs are available for download. These can be used to reproduce or analyze results without re-training.
- $\beta=[1,4,8,16,32], \lambda=[0,2,4,6,8], m=10$:
- $\beta=[1,4,8,16,32,64,128,256], \lambda=[0,4,8,16,32], m=15$:
- `config.py`: Configuration settings.
- `datasets.py`: Dataset loading.
- `models.py`: VAE model definitions.
- `losses.py`: Loss functions and training utilities.
- `metrics.py`: Disentanglement metric calculations.
- `utils.py`: Utility functions for seeding and data handling.
- `visualizations.py`: Functions for generating plots and images.
- `main_beta.py`: Script for nonlinear β-VAE training and evaluation.
- `main_lambda.py`: Script for nonlinear λβ-VAE continuation and evaluation.
- `linear_betavae.py`: Main script for linear β-VAE and λβ-VAE experiments.
- `interactive_heatmap_visualizer.py`: Script for generating static percentage change heatmaps and interactive Plotly heatmaps.
- `requirements.txt`: List of dependencies.
- Datasets and metrics inspired by the Disentanglement Library: https://github.com/google-research/disentanglement_lib.