This repository contains the code to reproduce the results from the paper Disentangling representations of retinal images with generative models.
We present a novel population model for retinal fundus images that effectively disentangles patient attributes from camera effects with a disentanglement loss based on distance correlation. The resulting models enable controllable and highly realistic fundus image generation.
Set up a Python environment with Python version 3.9. Then download the repository, activate the environment, and install all other dependencies with
cd disentangling-retinal-images
pip install --editable .
This installs the code in src as an editable package and all the dependencies in requirements.txt.
- configs: Configuration files for all experiments.
- scripts: Bash scripts for model training, testing and evaluation.
- src: Main source code to run the experiments.
  - dataset: PyTorch EyePACS dataset.
  - generative_model: PyTorch Lightning StyleGAN module.
  - evaluation: Model evaluation with kNN classifiers, image quality metrics (FID score), and swapped subspace classification.
  - train.py: Model training script.
  - test.py: Model testing script.
  - predict.py: Image embedding prediction script (model inference).
The EyePACS dataset can be accessed upon request: https://www.eyepacs.com/. Our dataset parser can be checked out in dataset/eyepacs_parsing.
For our EyePACS PyTorch dataset, you will need factorized metadata (with a categorical columns mapping) and a directory containing your dataset splits. We therefore also share our scripts to factorize and split the dataset. In dataset/eyepacs_parsing, we share our categorical columns mapping as a reference.
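To illustrate what "factorized metadata with a categorical columns mapping" can look like, here is a minimal sketch in plain Python. It is not the repository's factorization script: the column names, category values, and the helper `factorize_column` are hypothetical and chosen only for illustration.

```python
def factorize_column(values):
    """Map each distinct category to an integer code, in order of first appearance."""
    mapping = {}
    codes = []
    for value in values:
        if value not in mapping:
            mapping[value] = len(mapping)
        codes.append(mapping[value])
    return codes, mapping


# Hypothetical metadata rows; real EyePACS metadata has different columns.
rows = [
    {"image_id": "img_0", "camera": "camera_A", "sex": "female"},
    {"image_id": "img_1", "camera": "camera_B", "sex": "male"},
    {"image_id": "img_2", "camera": "camera_A", "sex": "female"},
]

categorical_columns = ["camera", "sex"]  # assumption: columns to factorize
column_mappings = {}
for column in categorical_columns:
    codes, mapping = factorize_column([row[column] for row in rows])
    column_mappings[column] = mapping
    # Replace the string category with its integer code in place.
    for row, code in zip(rows, codes):
        row[column] = code
```

After this step, `rows` holds integer codes and `column_mappings` records how to translate them back, which is the role the shared categorical columns mapping plays as a reference.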
Moreover, before factorizing and splitting the dataset, we pre-processed the retinal fundus images with: https://github.com/berenslab/fundus_circle_cropping.
For model training, run the following command
python src/train.py -c ./configs/configs_train/test.yaml
Here we run the model with a test training configuration file. All model configuration files for reproducing the models of the paper can be found here.
To test the model run the script
python src/test.py -d path/to/experiment/folder -c ./configs/configs_test/file.yaml
To predict the learned image embeddings for all data set splits (train, val, test), execute the bash script
sh scripts/predict_embeddings.sh path/to/experiment/folder configs/configs_predict/default.yaml
with the arguments $1: path to model experiment folder, $2: configuration file for predict.py. Hint: you need to adjust PROJECT_DIR and python_path.
Evaluate the kNN classifier performance with the predicted embeddings for EyePACS:
sh scripts/knn_eval_embeddings.sh path/to/experiment/folder 4 12 16
with the arguments $1: path to model experiment folder, $2-$end: subspace dimensions. Here, we chose the subspace dimensions [age, camera, identity] = [4, 12, 16]. Hint: you need to adjust PROJECT_DIR and python_path.
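As a rough illustration of how the subspace dimensions relate to an embedding, the sketch below slices a 32-dimensional vector into age, camera, and identity parts. It assumes the subspaces are stored contiguously in that order, which is an assumption made for this example and not something the evaluation script is confirmed to do; `split_subspaces` is a hypothetical helper.

```python
from itertools import accumulate


def split_subspaces(embedding, dims, names):
    """Slice a flat embedding into named, contiguous subspaces."""
    if sum(dims) != len(embedding):
        raise ValueError("subspace dimensions must sum to the embedding size")
    bounds = [0, *accumulate(dims)]  # e.g. [0, 4, 16, 32]
    return {
        name: embedding[start:stop]
        for name, start, stop in zip(names, bounds[:-1], bounds[1:])
    }


# Dummy embedding with 4 + 12 + 16 = 32 entries.
embedding = list(range(32))
subspaces = split_subspaces(embedding, [4, 12, 16], ["age", "camera", "identity"])
```

A kNN classifier would then be fit on one subspace at a time, for example on `subspaces["age"]` for every image, to check which attributes each subspace encodes.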
Compute image quality metrics (FID, KID):
python src/evaluation/eval_image_quality.py -d path/to/experiment/folder -c ./configs/configs_image_quality/default.yaml
Train subspace classifiers on age subspaces:
python src/evaluation/swapped_subspaces/train_age_classification.py -d path/to/experiment/folder -c configs/configs_swapped_subspaces/train_age_classification.yaml
Test trained classification model on swapped age subspaces:
python src/evaluation/swapped_subspaces/test_age_classification.py -d path/to/classification/model -c configs/configs_swapped_subspaces/test_age_classification.yaml
The StyleGAN model interface is dataset-agnostic. Therefore, if you want to train our model on a different dataset, start by replacing our EyePACS dataset with your dataset and return an identically structured dictionary in the __getitem__ function.
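A replacement dataset could follow the skeleton below. This is only a sketch: the dictionary keys `"image"` and `"labels"` are hypothetical placeholders, and you should copy the exact keys and value types from the EyePACS dataset class in src/dataset. The class is written without a PyTorch import to keep it self-contained; in practice you would typically subclass `torch.utils.data.Dataset`.

```python
class CustomFundusDataset:
    """Map-style dataset sketch: implements __len__ and __getitem__."""

    def __init__(self, samples):
        # samples: list of (image, labels) pairs, already loaded or lazily loadable.
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, labels = self.samples[index]
        # Hypothetical keys: mirror the structure returned by the EyePACS dataset.
        return {"image": image, "labels": labels}


# Dummy data standing in for real images and attribute labels.
dataset = CustomFundusDataset([([0.0, 0.5], {"age": 3}), ([1.0, 0.2], {"age": 1})])
sample = dataset[0]
```

As long as the returned dictionary has the same structure as the original one, the rest of the training code should not need to know which dataset produced it.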
The model weights of our trained generative models from the paper can be found on Zenodo.
We used a StyleGAN2-ADA PyTorch Lightning implementation as a starting point for our experiments: https://github.com/nihalsid/stylegan2-ada-lightning. Starting from this repository, we extended the GAN architecture with GAN inversion and independent subspace learning (subspace classifiers and distance correlation loss).
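For readers unfamiliar with distance correlation, the sketch below computes the standard (biased) sample distance correlation between two sets of paired vectors in plain Python. It is meant only to show the quantity involved; it is not the repository's loss, which would need a differentiable, batched implementation, and the function names are our own.

```python
import math


def _centered_distances(samples):
    """Pairwise Euclidean distances, double-centered (row, column, grand mean)."""
    n = len(samples)
    dist = [[math.dist(a, b) for b in samples] for a in samples]
    row_means = [sum(row) / n for row in dist]  # equals column means by symmetry
    grand_mean = sum(row_means) / n
    return [
        [dist[i][j] - row_means[i] - row_means[j] + grand_mean for j in range(n)]
        for i in range(n)
    ]


def distance_correlation(x, y):
    """Sample distance correlation of paired samples x[i], y[i] (tuples of floats)."""
    n = len(x)
    a = _centered_distances(x)
    b = _centered_distances(y)

    def mean_product(p, q):
        return sum(p[i][j] * q[i][j] for i in range(n) for j in range(n)) / n**2

    dcov2_xy = mean_product(a, b)  # squared distance covariance
    denom = math.sqrt(mean_product(a, a) * mean_product(b, b))
    if denom == 0.0:
        return 0.0  # one variable is constant
    return math.sqrt(max(dcov2_xy, 0.0) / denom)


x = [(0.0,), (1.0,), (2.0,), (3.0,)]
y_linear = [(2.0 * v[0] + 1.0,) for v in x]  # perfectly dependent on x
y_constant = [(5.0,)] * 4                    # carries no information about x
```

A value near 1 indicates strong dependence and a value of 0 indicates none in the sample, so penalizing distance correlation between subspaces pushes them toward encoding independent information.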
If you find our code or paper useful, please consider citing this work.
@misc{mueller2024disentangling,
title = {Disentangling representations of retinal images with generative models},
author = {M\"uller, Sarah and Koch, Lisa M. and Lensch, Hendrik P. A. and Berens, Philipp},
year = {2024},
eprint = {2402.19186},
archivePrefix = {arXiv},
}