A deep learning project for Fundus Autofluorescence (FAF) image generation, inversion and editing using Scalable Interpolant Transformers (SiT). This repository enables conditional generation of synthetic FAF images based on genetic mutations, patient age, and eye laterality, with support for real-to-latent inversion and semantic image editing.
This project implements a flow-based generative model (SiT) for synthesizing and editing FAF images conditioned on:
- Gene: The genetic mutation associated with inherited retinal diseases (36 gene classes)
- Laterality: Left (L) or Right (R) eye
- Age: Patient age (normalized 0-100)
The model supports:
- Conditional Generation: Generate novel FAF images given specific conditions
- Image Inversion: Map real FAF images back to the latent space
- Semantic Editing: Modify conditions (gene, age, laterality) to generate edited versions of real images
- Multi-conditional generation (gene, laterality, age)
- ODE-based image inversion for real images
- Semantic editing via latent space manipulation
- Comprehensive evaluation suite (conditioning accuracy, FID, TSTR)
- Distributed training with PyTorch DDP
- Weights & Biases integration for experiment tracking
- Operating System: Linux
- GPU: NVIDIA GPU with CUDA support (2 x RTX A6000s used)
- CUDA: Version 11.x or higher
- Conda: Miniconda or Anaconda for environment management
- Python: 3.10+
faf_flow_edit/
├── data/ # Datasets and metadata
│ ├── class_mapping.json # Gene-to-index mapping
│ ├── images_cleaned/ # Cleaned full-resolution images
│ ├── images_256_cleaned/ # Resized 256x256 images
│ ├── synthetic_10kSamples/ # Generated synthetic datasets
│ └── eval_10kSamples_256res/ # Evaluation datasets
├── environments/ # Conda environment files
│ ├── env_sit.yml
│ └── env_stylegan2ada.yml
├── evaluation/ # Evaluation results
│ ├── conditioning_results_*/ # Conditioning evaluation outputs
│ └── TSTR_reports/ # TSTR classification reports
├── scripts/
│ ├── evaluation/ # Evaluation scripts
│ │ ├── evaluate_conditioning.py
│ │ └── run_tstr.py
│ └── prepare_datasets/ # Data preparation scripts
│ ├── clean_dataset.py
│ ├── resize_dataset.py
│ └── generate_synthetic_dataset.py
├── SiT/ # Core SiT model
│ ├── train.py # Training script
│ ├── sample.py # Sampling script
│ ├── invert.py # Inversion and editing
│ ├── dataset.py # Custom dataset class
│ ├── models.py # Model architectures
│ ├── transport/ # Flow matching transport
│ ├── weights/ # Model checkpoints
│ ├── inversions/ # Inversion outputs
│ └── edits/ # Editing outputs
└── stylegan2-ada-pytorch/ # StyleGAN2-ADA baseline
git clone <repository-url>
cd faf_flow_edit

The SiT environment includes all necessary dependencies:
conda env create -f environments/env_sit.yml
conda activate SiT_flow

This installs:
- PyTorch with CUDA support
- Diffusers (for VAE)
- torchvision, numpy, pandas
- pytorch-fid (for FID computation)
- wandb (for experiment tracking)
If you plan to compare with StyleGAN2-ADA:
conda env create -f environments/env_stylegan2ada.yml
conda activate stylegan2ada

Confidential Medical Data: This project uses clinical Fundus Autofluorescence (FAF) images from patients with inherited retinal diseases, provided by Moorfields Eye Hospital. The data is confidential and not publicly available.
Dataset Statistics:
- Total Samples: ~34,000 FAF images after cleaning
- Gene Classes: 36 unique genetic mutations
- Image Resolution: Original variable resolution, standardized to 256×256 for training
- Attributes per Image:
  - file_name: Image filename
  - gene: Genetic mutation (e.g., ABCA4, USH2A, CHM)
  - laterality: Eye side (L/R)
  - age: Patient age at imaging
Class Mapping (data/class_mapping.json):
{"ABCA4": 0, "BBS1": 1, "BEST1": 2, ..., "USH2A": 35}

The cleaning script filters images to include only samples with valid gene labels from the class mapping:
python scripts/prepare_datasets/clean_dataset.py \
--input_csv <path/to/original_metadata.csv> \
--class_mapping_json data/class_mapping.json \
--output_csv data/images_cleaned/metadata_cleaned.csv \
--source_images_dir <path/to/original_images> \
--output_images_dir data/images_cleaned

What it does:
- Filters the metadata CSV to keep only genes present in class_mapping.json
- Copies the corresponding images to the cleaned directory
- Reports statistics on filtered and missing files
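The filtering step amounts to keeping only rows whose gene appears in the class mapping. An illustrative pandas sketch (column names follow the metadata attributes described above; the actual script may differ):

```python
import pandas as pd

# Toy metadata with one gene that is absent from the class mapping.
mapping = {"ABCA4": 0, "USH2A": 35}  # excerpt of class_mapping.json
df = pd.DataFrame({
    "file_name": ["a.png", "b.png", "c.png"],
    "gene": ["ABCA4", "UNKNOWN", "USH2A"],
})

# Keep only rows whose gene is a valid class; isin() on a dict checks keys.
cleaned = df[df["gene"].isin(mapping)].reset_index(drop=True)
print(len(cleaned))  # 2 rows kept; the UNKNOWN gene is dropped
```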
SiT requires images at a fixed resolution (256×256). Resize the cleaned dataset:
python scripts/prepare_datasets/resize_dataset.py \
--src data/images_cleaned \
--dest data/images_256_cleaned \
--size 256

Features:
- Multi-threaded processing for speed
- High-quality LANCZOS resampling
- Automatic handling of various image formats (PNG, JPG, JPEG)
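The multi-threaded pattern can be sketched as follows. This toy version uses a simple area-average downsample on NumPy arrays purely for illustration; the real script resizes image files with PIL's LANCZOS filter:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def downsample(img: np.ndarray, size: int) -> np.ndarray:
    # Area-average resize: split the image into size x size blocks and
    # average each block (a stand-in for PIL's higher-quality LANCZOS).
    h, w = img.shape
    fh, fw = h // size, w // size
    return img[: size * fh, : size * fw].reshape(size, fh, size, fw).mean(axis=(1, 3))

# Process several "images" concurrently, as the resize script does for files.
images = [np.full((512, 512), float(i)) for i in range(4)]
with ThreadPoolExecutor(max_workers=4) as ex:
    resized = list(ex.map(lambda im: downsample(im, 256), images))
print(resized[0].shape)  # (256, 256)
```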
After resizing, create the metadata file for the resized images:
python scripts/prepare_datasets/create_metadata_256.py \
--input_csv data/images_cleaned/metadata_cleaned.csv \
--output_csv data/images_256_cleaned/metadata_cleaned_256.csv

Train the SiT model using Distributed Data Parallel (DDP):
cd SiT
torchrun --nproc_per_node=<NUM_GPUS> train.py \
--data-path ../data/images_256_cleaned/metadata_cleaned_256.csv \
--img-dir ../data/images_256_cleaned \
--mapping-file ../data/class_mapping.json \
--model SiT-XL/2 \
--image-size 256 \
--epochs 200 \
--global-batch-size 16 \
--results-dir results \
--ckpt-every 50000 \
--sample-every 10000 \
--log-every 100 \
--cfg-scale 4.0 \
--wandb

Key Arguments:
| Argument | Description | Default |
|---|---|---|
| `--data-path` | Path to metadata CSV | Required |
| `--img-dir` | Directory containing images | Required |
| `--mapping-file` | Gene-to-index JSON mapping | Required |
| `--model` | Model architecture | SiT-XL/2 |
| `--image-size` | Training resolution | 256 |
| `--epochs` | Number of training epochs | 200 |
| `--global-batch-size` | Total batch size across GPUs | 256 |
| `--cfg-scale` | Classifier-free guidance scale | 4.0 |
| `--wandb` | Enable W&B logging | Flag |
| `--ckpt` | Resume from checkpoint | Optional |
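Note that `--global-batch-size` is the total across all GPUs, so each DDP rank sees an even share. A sketch of the arithmetic (a common DDP pattern; the actual train.py may compute this differently):

```python
# Per-GPU batch size derived from the global batch size under DDP.
def per_gpu_batch(global_batch: int, world_size: int) -> int:
    assert global_batch % world_size == 0, "global batch must divide evenly across GPUs"
    return global_batch // world_size

print(per_gpu_batch(256, 2))  # 128 samples per GPU on 2 x RTX A6000
print(per_gpu_batch(16, 2))   # 8, for the example command above
```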
Resume Training:
torchrun --nproc_per_node=<NUM_GPUS> train.py \
--data-path ../data/images_256_cleaned/metadata_cleaned_256.csv \
--img-dir ../data/images_256_cleaned \
--mapping-file ../data/class_mapping.json \
--ckpt results/<experiment>/checkpoints/<step>.pt \
...

The trained SiT-XL/2 model can be found in the Hugging Face repo for this project.
Generate individual samples with specific conditions:
cd SiT
python sample.py \
--ckpt weights/<checkpoint>.pt \
--mapping-file ../data/class_mapping.json \
--gene ABCA4 \
--laterality L \
--age 45 \
--num-samples 4 \
--cfg-scale 4.0 \
--output-dir samples

Generate thousands of samples matching the demographic distribution of real data:
python scripts/prepare_datasets/generate_synthetic_dataset.py \
--ckpt SiT/weights/<checkpoint>.pt \
--data-path data/images_256_cleaned/metadata_cleaned_256.csv \
--mapping-file data/class_mapping.json \
--output-dir data/synthetic_10kSamples \
--num-samples 10000 \
--batch-size 32 \
--cfg-scale 4.0 \
--seed 42

This script:
- Samples demographic distributions (gene, age, laterality) from real metadata
- Generates synthetic images matching those distributions
- Creates a manifest CSV for evaluation
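Sampling conditions from the real metadata with replacement keeps the synthetic set's demographics close to the real distribution. An illustrative pandas sketch (the actual script may sample differently):

```python
import pandas as pd

# Toy real metadata; in practice this is metadata_cleaned_256.csv.
real = pd.DataFrame({
    "gene": ["ABCA4", "ABCA4", "ABCA4", "USH2A"],
    "laterality": ["L", "R", "L", "R"],
    "age": [32, 45, 60, 70],
})

# Draw conditioning tuples with replacement so rare genes appear at roughly
# their real frequency; each row then conditions one generated image.
conditions = real[["gene", "laterality", "age"]].sample(
    n=8, replace=True, random_state=42
)
print(conditions["gene"].value_counts().to_dict())
```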
Invert a real FAF image to obtain its latent noise representation:
cd SiT
python invert.py invert \
--ckpt weights/<checkpoint>.pt \
--input-image <path/to/real_image.png> \
--gene ABCA4 \
--laterality L \
--age 55 \
--mapping-file ../data/class_mapping.json \
--output-dir inversions \
--verify

Arguments:
| Argument | Description |
|---|---|
| `--ckpt` | Path to trained model checkpoint |
| `--input-image` | Path to real FAF image |
| `--gene` | Gene label of the input image |
| `--laterality` | Eye laterality (L/R) |
| `--age` | Patient age |
| `--verify` | Reconstruct the image to verify inversion quality |
| `--output-dir` | Directory to save outputs |
Outputs:
- inverted_noise.pt: Latent noise tensor with metadata
- reconstruction.png: Reconstructed image (if the --verify flag is used)
- original.png: Copy of the input image for comparison
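Conceptually, inversion integrates the learned velocity field backwards in time, mapping the image latent to noise. A minimal fixed-step Euler sketch, with a toy velocity field standing in for the trained SiT network (the real script also supports adaptive solvers such as dopri5):

```python
import numpy as np

def velocity(x, t):
    # Toy linear field for illustration only; in practice this is the
    # conditioned SiT network evaluated at time t.
    return -x

def invert_euler(x, num_steps=50):
    # Integrate the flow ODE backwards from the data end (t=1) to the
    # noise end (t=0) with fixed Euler steps.
    dt = 1.0 / num_steps
    t = 1.0
    for _ in range(num_steps):
        x = x - dt * velocity(x, t)
        t -= dt
    return x

# Shape matches the VAE latent reported in the example run: (1, 4, 32, 32).
latent = invert_euler(np.ones((1, 4, 32, 32)))
print(latent.shape)  # (1, 4, 32, 32)
```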
ODE Solver Options:
--sampling-method dopri5 # ODE solver: dopri5, euler, heun
--num-steps 50 # Number of ODE steps
--atol 1e-6 # Absolute tolerance
--rtol 1e-3 # Relative tolerance

Edit an inverted image by changing its conditioning attributes:
cd SiT
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/inverted_noise.pt \
--target-gene USH2A \
--target-age 70 \
--target-laterality R \
--mapping-file ../data/class_mapping.json \
--output-dir edits

Arguments:
| Argument | Description |
|---|---|
| `--noise-file` | Path to inverted noise (.pt file) |
| `--target-gene` | New gene (or None to keep the original) |
| `--target-laterality` | New laterality (or None to keep the original) |
| `--target-age` | New age (or None to keep the original) |
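The "None keeps the original" semantics means unspecified targets fall back to the metadata stored alongside the inverted noise. A sketch of that resolution step (illustrative; the real invert.py may implement it differently):

```python
# Hypothetical helper: merge edit targets with the conditions saved in
# inverted_noise.pt, keeping any attribute the user did not override.
def resolve_targets(stored, target_gene=None, target_laterality=None, target_age=None):
    return {
        "gene": target_gene if target_gene is not None else stored["gene"],
        "laterality": target_laterality if target_laterality is not None else stored["laterality"],
        "age": target_age if target_age is not None else stored["age"],
    }

stored = {"gene": "ABCA4", "laterality": "R", "age": 32}
print(resolve_targets(stored, target_gene="USH2A", target_age=70))
# {'gene': 'USH2A', 'laterality': 'R', 'age': 70}
```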
Below are real experiments demonstrating the inversion and editing pipeline on a sample FAF image.
Invert a real ABCA4 patient image (Right eye, Age 32) to latent space:
cd SiT
python invert.py invert \
--ckpt weights/<checkpoint>.pt \
--input-image ../data/eval_10kSamples_256res/real_10kSamples_256res/00000018.pat_00448798.sdb_AF_B-0_0.png \
--gene ABCA4 \
--laterality R \
--age 32 \
--output-dir inversions/00000018_pat_00448798_sdb \
--verify

Terminal Output:
Loading checkpoint: weights/<checkpoint>.pt
Loaded config from checkpoint: Model=SiT-XL/2, Size=256
Condition: Gene=ABCA4(0), Eye=R(1), Age=32(0.32)
Loading EMA weights...
Encoded to latent shape: torch.Size([1, 4, 32, 32])
Inverting (Data -> Noise)...
Inverted noise shape: torch.Size([1, 4, 32, 32])
Inverted noise stats: mean=0.0040, std=0.9685
SUCCESS: Inverted noise saved to inversions/00000018_pat_00448798_sdb/inverted_noise.pt
Verifying (Noise -> Data reconstruction)...
SUCCESS: Reconstruction saved to inversions/00000018_pat_00448798_sdb/reconstruction.png
SUCCESS: Original saved to inversions/00000018_pat_00448798_sdb/original.png
Latent reconstruction MSE: 0.000006
| Original (ABCA4, R, 32) | Reconstructed (ABCA4, R, 32) |
|---|---|
| ![]() | ![]() |
The low reconstruction MSE (0.000006) confirms high-quality inversion.
Age Editing — Increase or decrease patient age:
# Increase age: 32 → 65
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-age 65 \
--output-dir edits/00000018_pat_00448798_sdb
# Decrease age: 32 → 20
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-age 20 \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (Age 32) | Older (Age 65) | Younger (Age 20) |
|---|---|---|
| ![]() | ![]() | ![]() |
Laterality Editing — Flip from Right to Left eye:
# Flip laterality: R → L
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-laterality L \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (R) | Edited (L) |
|---|---|
| ![]() | ![]() |
Gene Editing — Transform between genetic phenotypes:
# Change gene: ABCA4 → USH2A
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene USH2A \
--output-dir edits/00000018_pat_00448798_sdb
# Change gene: ABCA4 → OPA1
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene OPA1 \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (ABCA4) | Edited (USH2A) | Edited (OPA1) |
|---|---|---|
| ![]() | ![]() | ![]() |
Gene + Age:
# USH2A + Age 65
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene USH2A \
--target-age 65 \
--output-dir edits/00000018_pat_00448798_sdb
# USH2A + Age 20
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene USH2A \
--target-age 20 \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (ABCA4, 32) | USH2A + Age 65 | USH2A + Age 20 |
|---|---|---|
| ![]() | ![]() | ![]() |
Gene + Laterality:
# USH2A + Left Eye
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene USH2A \
--target-laterality L \
--output-dir edits/00000018_pat_00448798_sdb
# OPA1 + Left Eye
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene OPA1 \
--target-laterality L \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (ABCA4, R) | USH2A + L | OPA1 + L |
|---|---|---|
| ![]() | ![]() | ![]() |
Age + Laterality:
# Left Eye + Age 65
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-laterality L \
--target-age 65 \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (R, 32) | Edited (L, 65) |
|---|---|
| ![]() | ![]() |
# USH2A + Left Eye + Age 75
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene USH2A \
--target-laterality L \
--target-age 75 \
--output-dir edits/00000018_pat_00448798_sdb
# OPA1 + Right Eye + Age 40
python invert.py edit \
--ckpt weights/<checkpoint>.pt \
--noise-file inversions/00000018_pat_00448798_sdb/inverted_noise.pt \
--target-gene OPA1 \
--target-laterality R \
--target-age 40 \
--output-dir edits/00000018_pat_00448798_sdb

| Reconstructed (ABCA4, R, 32) | USH2A, L, 75 | OPA1, R, 40 |
|---|---|---|
| ![]() | ![]() | ![]() |
All models used for evaluation can be found in the huggingface repo for this project.
Evaluate how well synthetic images match their conditioning labels using trained judge classifiers:
python scripts/evaluation/evaluate_conditioning.py \
--real-csv data/images_256_cleaned/metadata_cleaned_256.csv \
--real-img-dir data/images_256_cleaned \
--synth-csv data/synthetic_10kSamples/synthetic_manifest.csv \
--synth-img-dir data/synthetic_10kSamples \
--output-dir evaluation/conditioning_results_10k \
--train-samples 5000 \
--batch-size 64 \
--epochs 5 \
--save-models

What it evaluates:
- Laterality Accuracy: Classification accuracy for L/R prediction
- Age Regression: Correlation (R), R², and mean absolute error (MAE) for age prediction
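The age metrics reduce to standard statistics over the judge's predictions. A minimal NumPy sketch of how they can be computed (illustrative; the evaluation script may compute them differently):

```python
import numpy as np

# Toy ground-truth ages and judge predictions on synthetic images.
true_age = np.array([30.0, 45.0, 60.0, 75.0])
pred_age = np.array([28.0, 50.0, 55.0, 80.0])

r = np.corrcoef(true_age, pred_age)[0, 1]     # Pearson correlation (R)
mae = np.abs(true_age - pred_age).mean()      # mean absolute error in years
print(round(r, 3), round(mae, 2))             # 0.974 4.25
```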
Example Results (10k samples):
LATERALITY
------------------------------
Overall Accuracy: 95.15%
Left Eye Accuracy: 93.74%
Right Eye Accuracy: 96.52%
AGE
------------------------------
Correlation (R): 0.8488
R-squared: 0.2662
Mean Absolute Error: 14.84 years
Evaluation-only mode (using pre-trained judges):
python scripts/evaluation/evaluate_conditioning.py \
--synth-csv data/synthetic_10kSamples/synthetic_manifest.csv \
--synth-img-dir data/synthetic_10kSamples \
--output-dir evaluation/conditioning_results_10k \
--eval-only

Compute Fréchet Inception Distance (FID) to measure image quality and diversity.
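For reference, FID fits a Gaussian to the Inception features of each image set and compares the two fits (lower is better):

```latex
\mathrm{FID} = \lVert \mu_r - \mu_s \rVert_2^2
  + \mathrm{Tr}\!\left( \Sigma_r + \Sigma_s - 2\,(\Sigma_r \Sigma_s)^{1/2} \right)
```

where $(\mu_r, \Sigma_r)$ and $(\mu_s, \Sigma_s)$ are the mean and covariance of Inception features for the real and synthetic sets, respectively.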
Using pytorch-fid:
# Ensure both directories contain images at the same resolution
python -m pytorch_fid \
data/eval_10kSamples_256res/real_10kSamples_256res \
data/eval_10kSamples_256res/synthetic_10kSamples_SiT_256res

Compare SiT vs StyleGAN2-ADA:
# SiT FID
python -m pytorch_fid \
data/eval_10kSamples_256res/real_10kSamples_256res \
data/eval_10kSamples_256res/synthetic_10kSamples_SiT_256res
# StyleGAN2-ADA FID
python -m pytorch_fid \
data/eval_10kSamples_256res/real_10kSamples_256res \
data/eval_10kSamples_256res/synthetic_10kSamples_stylegan2ada_256res

Train on Synthetic, Test on Real (TSTR) evaluation measures how well a classifier trained on synthetic data performs on real data:
# Evaluate SiT synthetic data
python scripts/evaluation/run_tstr.py \
--experiment_name TSTR_SiT \
--train_csv data/synthetic_10kSamples/synthetic_manifest.csv \
--train_img_dir data/eval_10kSamples_256res/synthetic_10kSamples_SiT_256res \
--train_mode synthetic \
--test_csv data/images_256_cleaned/metadata_cleaned_256.csv \
--test_img_dir data/images_256_cleaned \
--test_mode real \
--mapping_json data/class_mapping.json \
--outdir evaluation/TSTR_reports \
--batch_size 32 \
--epochs 10
# Evaluate StyleGAN2-ADA synthetic data
python scripts/evaluation/run_tstr.py \
--experiment_name TSTR_SG2ADA \
--train_csv data/synthetic_10kSamples_stylegan2ada/stylegan_manifest.csv \
--train_img_dir data/eval_10kSamples_256res/synthetic_10kSamples_stylegan2ada_256res \
--train_mode synthetic \
--test_csv data/images_256_cleaned/metadata_cleaned_256.csv \
--test_img_dir data/images_256_cleaned \
--test_mode real \
--mapping_json data/class_mapping.json \
--outdir evaluation/TSTR_reports \
--batch_size 32 \
--epochs 10
# Real data upper bound (Train on Real, Test on Real)
python scripts/evaluation/run_tstr.py \
--experiment_name TSTR_Real \
--train_csv data/eval_10kSamples_256res/real_10kSamples_train.csv \
--train_img_dir data/eval_10kSamples_256res/real_10kSamples_256res \
--train_mode real \
--test_csv data/images_256_cleaned/metadata_cleaned_256.csv \
--test_img_dir data/images_256_cleaned \
--test_mode real \
--mapping_json data/class_mapping.json \
--outdir evaluation/TSTR_reports \
--batch_size 32 \
--epochs 10

Example Results (10k samples):
| Model | Test Accuracy |
|---|---|
| Real (Upper Bound) | 78.92% |
| SiT | 67.85% |
| StyleGAN2-ADA | 46.33% |
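The TSTR protocol itself is simple: fit on synthetic samples, score on real ones. A toy NumPy illustration with a nearest-centroid classifier on synthetic feature vectors (sketch only; run_tstr.py trains a real image classifier):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic "training" features for two classes, and real "test" features
# drawn from the same underlying distributions.
synth_X = np.vstack([rng.normal(0, 1, (50, 8)), rng.normal(3, 1, (50, 8))])
synth_y = np.array([0] * 50 + [1] * 50)
real_X = np.vstack([rng.normal(0, 1, (50, 8)), rng.normal(3, 1, (50, 8))])
real_y = synth_y.copy()

# Train on synthetic: one centroid per class. Test on real: nearest centroid.
centroids = np.stack([synth_X[synth_y == c].mean(axis=0) for c in (0, 1)])
pred = np.argmin(((real_X[:, None] - centroids) ** 2).sum(-1), axis=1)
print((pred == real_y).mean())  # TSTR accuracy on the toy data
```

If synthetic data captures the real class structure, this accuracy approaches the train-on-real upper bound, which is the gap the table above quantifies.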
If you use this code or methodology in your research, please cite:
@misc{sit_faf_generate_edit,
author = {Amit John},
title = {SiT FAF Generation and Editing},
year = {2025},
url = {https://github.com/johnamit/sit-faf-generate-edit}
}

SiT (Scalable Interpolant Transformers):
@article{ma2024sit,
title={SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers},
author={Ma, Nanye and Goldstein, Mark and Albergo, Michael S and Boffi, Nicholas M and Vanden-Eijnden, Eric and Xie, Saining},
journal={arXiv preprint arXiv:2401.08740},
year={2024}
}

This project is licensed under the terms specified in [TO ADD LATER].
Note: The medical imaging data used in this project is confidential and not included in this repository. Please ensure you have appropriate permissions and ethics approval before working with medical imaging data.