🆕 For examples running DiffDRR on real data, please check out our latest work, DiffPose.
Auto-differentiable DRR synthesis and optimization in PyTorch
DiffDRR is a PyTorch-based digitally reconstructed radiograph (DRR) generator that provides
- Auto-differentiable DRR synthesis
- GPU-accelerated rendering
- A pure Python implementation
Most importantly, DiffDRR implements DRR synthesis as a PyTorch module, making it interoperable in deep learning pipelines.
Below is a comparison of DiffDRR to a real X-ray (X-rays and CTs from the DeepFluoro dataset):
To install DiffDRR from PyPI:
pip install diffdrr
DiffDRR also requires PyTorch3D, which gives us the ability to use multiple parameterizations of SO(3) when constructing camera poses! For most users,
conda install pytorch3d -c pytorch3d
should work perfectly well. Otherwise, see PyTorch3D's installation guide.
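"Multiple parameterizations of SO(3)" means that the same 3x3 rotation matrix can be built from different inputs. The sketch below, written in plain PyTorch rather than with PyTorch3D, builds one from Euler angles and one from an axis-angle vector via Rodrigues' formula. The helper names are hypothetical, and reading a "ZYX" convention as the product R_Z R_Y R_X with angles listed in Z, Y, X order is an assumption about how such convention strings are commonly interpreted, not a description of DiffDRR internals.

```python
import math

import torch


def axis_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """Right-handed rotation about one coordinate axis; angle has shape (B,)."""
    c, s = torch.cos(angle), torch.sin(angle)
    one, zero = torch.ones_like(angle), torch.zeros_like(angle)
    if axis == "X":
        entries = [one, zero, zero, zero, c, -s, zero, s, c]
    elif axis == "Y":
        entries = [c, zero, s, zero, one, zero, -s, zero, c]
    elif axis == "Z":
        entries = [c, -s, zero, s, c, zero, zero, zero, one]
    else:
        raise ValueError(f"Unknown axis {axis!r}")
    return torch.stack(entries, dim=-1).reshape(*angle.shape, 3, 3)


def euler_zyx_to_matrix(angles: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper. ASSUMPTION: angles[..., 0:3] rotate about Z, Y, X and compose as R_Z @ R_Y @ R_X."""
    z, y, x = angles.unbind(-1)
    return axis_rotation("Z", z) @ axis_rotation("Y", y) @ axis_rotation("X", x)


def axis_angle_to_matrix(v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper. Rodrigues' formula: the direction of v is the axis, its norm is the angle."""
    theta = v.norm(dim=-1, keepdim=True).clamp_min(1e-12)  # avoid dividing by zero
    kx, ky, kz = (v / theta).unbind(-1)
    zero = torch.zeros_like(kx)
    # Skew-symmetric cross-product matrix of the unit axis
    K = torch.stack([zero, -kz, ky, kz, zero, -kx, -ky, kx, zero], dim=-1)
    K = K.reshape(*v.shape[:-1], 3, 3)
    theta = theta[..., None]  # shape (B, 1, 1) to broadcast over each 3x3 matrix
    eye = torch.eye(3, dtype=v.dtype, device=v.device)
    return eye + torch.sin(theta) * K + (1 - torch.cos(theta)) * (K @ K)


# A 90-degree turn about Z, written in two different parameterizations
R_axis = axis_angle_to_matrix(torch.tensor([[0.0, 0.0, math.pi / 2]]))
R_euler = euler_zyx_to_matrix(torch.tensor([[math.pi / 2, 0.0, 0.0]]))
print(torch.allclose(R_axis, R_euler, atol=1e-6))
```

Both helpers take batches of shape (B, 3), like the (1, 3) rotation tensor passed to the DRR module in the example below; the point is only that different inputs land on the same group of rotation matrices.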
The following minimal example specifies the geometry of the projectional radiograph imaging system and traces rays through a CT volume:
import matplotlib.pyplot as plt
import torch
from diffdrr.drr import DRR
from diffdrr.data import load_example_ct
from diffdrr.visualization import plot_drr
# Read in the volume and get the isocenter
volume, spacing = load_example_ct()
bx, by, bz = torch.tensor(volume.shape) * torch.tensor(spacing) / 2
# Initialize the DRR module for generating synthetic X-rays
device = "cuda" if torch.cuda.is_available() else "cpu"
drr = DRR(
    volume,  # The CT volume as a numpy array
    spacing,  # Voxel dimensions of the CT
    sdr=300.0,  # Source-to-detector radius (half of the source-to-detector distance)
    height=200,  # Height of the DRR (if width is not separately provided, the generated image is square)
    delx=4.0,  # Pixel spacing (in mm)
).to(device)
# Set the camera pose with rotation (yaw, pitch, roll) and translation (x, y, z)
rotation = torch.tensor([[torch.pi, 0.0, torch.pi / 2]], device=device)
translation = torch.tensor([[bx, by, bz]], device=device)
# 📸 Also note that DiffDRR can take many representations of SO(3) 📸
# For example, quaternions, rotation matrix, axis-angle, etc...
img = drr(rotation, translation, parameterization="euler_angles", convention="ZYX")
plot_drr(img, ticks=False)
plt.show()
On a single NVIDIA RTX 2080 Ti GPU, producing such an image takes 33.3 ms ± 6.78 µs per loop (mean ± std. dev. of 7 runs, 10 loops each).
The full example is available at introduction.ipynb.
We demonstrate the utility of our auto-differentiable DRR generator by solving the 2D/3D registration problem with gradient-based optimization. Here, we generate two DRRs:
- A fixed DRR from a set of ground truth parameters
- A moving DRR from randomly initialized parameters
To solve the registration problem, we use gradient descent to maximize an image similarity metric between the two DRRs. This produces optimization runs like this:
The full example is available at optimizers.ipynb.
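The fixed/moving setup described above can be sketched end to end on a toy problem that needs no CT volume. In this hypothetical example, a 2D Gaussian blob stands in for the DRR renderer and its center stands in for the pose parameters: a fixed image is rendered from ground-truth parameters, a moving image from a perturbed initial guess, and Adam maximizes normalized cross-correlation (NCC) by minimizing its negative. The renderer, the choice of NCC as the similarity metric, and all hyperparameters are assumptions for illustration; in the notebook, DiffDRR's DRR module plays the renderer's role.

```python
import torch


def render_blob(center: torch.Tensor, size: int = 64, sigma: float = 6.0) -> torch.Tensor:
    """Hypothetical differentiable 'renderer': a Gaussian blob centered at `center` (in pixels)."""
    coords = torch.arange(size, dtype=center.dtype)
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    d2 = (yy - center[0]) ** 2 + (xx - center[1]) ** 2
    return torch.exp(-d2 / (2 * sigma**2))


def ncc(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Normalized cross-correlation; 1.0 means the images match up to brightness and contrast."""
    a = (a - a.mean()) / (a.std() + eps)
    b = (b - b.mean()) / (b.std() + eps)
    return (a * b).mean()


def register(true_center, init_center, steps: int = 300, lr: float = 0.3):
    fixed = render_blob(torch.tensor(true_center, dtype=torch.float32))  # from ground truth
    params = torch.tensor(init_center, dtype=torch.float32, requires_grad=True)  # moving
    optimizer = torch.optim.Adam([params], lr=lr)
    with torch.no_grad():
        initial = ncc(fixed, render_blob(params)).item()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = -ncc(fixed, render_blob(params))  # maximizing similarity = minimizing -NCC
        loss.backward()  # gradients flow back through the renderer to the parameters
        optimizer.step()
    with torch.no_grad():
        final = ncc(fixed, render_blob(params)).item()
    return params.detach(), final, initial


estimate, final_ncc, initial_ncc = register([32.0, 30.0], [26.0, 36.0])
print(estimate, initial_ncc, final_ncc)
```

The loop only works because every step from parameters to image to loss is differentiable, which is the property DiffDRR provides for real DRRs.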
DiffDRR source code, docs, and CI are all built using nbdev. To get set up with nbdev, install the following
mamba install jupyterlab nbdev -c fastai -c conda-forge
nbdev_install_quarto # To build docs
nbdev_install_hooks # Make notebooks git-friendly
Running nbdev_help will give you the full list of options. The most important ones are
nbdev_preview # Render docs locally and inspect in browser
nbdev_clean # NECESSARY BEFORE PUSHING
nbdev_test # Test notebooks
nbdev_export # Export notebooks to the Python package
For more details, follow this in-depth tutorial.
DiffDRR reformulates Siddon’s method,¹ the canonical algorithm for calculating the radiologic path of an X-ray through a volume, as a series of vectorized tensor operations. This version of the algorithm is easily implemented in tensor algebra libraries like PyTorch to achieve a fast auto-differentiable DRR generator.
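To give a feel for what "Siddon's method as vectorized tensor operations" can look like, here is a simplified sketch for a batch of rays: compute every parametric value α at which each ray crosses a voxel boundary plane, sort them, and sum voxel density times the physical length of each segment between consecutive crossings. This is a minimal illustration under stated assumptions (axis-aligned grid with its corner at the origin, the density of the voxel containing each segment's midpoint, no interpolation); `siddon_line_integral` is a hypothetical name and this is not DiffDRR's implementation.

```python
import torch


def siddon_line_integral(volume, spacing, source, target, eps=1e-8):
    """Hypothetical helper: sum of density x path length along each ray source[n] -> target[n].

    ASSUMPTION: voxel (i, j, k) spans [i*sx, (i+1)*sx) x [j*sy, (j+1)*sy) x [k*sz, (k+1)*sz).
    volume: (D, H, W) densities; spacing: 3 voxel sizes; source, target: (N, 3) points.
    """
    spacing = torch.as_tensor(spacing, dtype=source.dtype)
    dims = torch.tensor(volume.shape)
    ray = target - source  # (N, 3)

    # 1. Parametric alphas where each ray crosses every boundary plane, per axis.
    alphas = []
    for axis in range(3):
        planes = torch.arange(volume.shape[axis] + 1, dtype=source.dtype) * spacing[axis]
        denom = ray[:, axis : axis + 1]
        # Rays parallel to this axis get huge alphas that the clamp below removes
        denom = torch.where(denom.abs() < eps, torch.full_like(denom, eps), denom)
        alphas.append((planes[None, :] - source[:, axis : axis + 1]) / denom)
    alphas = torch.cat(alphas, dim=1).clamp(0.0, 1.0)  # keep crossings between the endpoints
    ends = torch.zeros_like(alphas[:, :1])
    alphas = torch.cat([ends, alphas, ends + 1.0], dim=1)
    alphas, _ = torch.sort(alphas, dim=1)

    # 2. The midpoint of each segment identifies the voxel that segment lies in.
    mid = (alphas[:, 1:] + alphas[:, :-1]) / 2
    points = source[:, None, :] + mid[..., None] * ray[:, None, :]
    idx = torch.floor(points / spacing).long()
    inside = ((idx >= 0) & (idx < dims)).all(dim=-1)
    idx = torch.where(inside[..., None], idx, torch.zeros_like(idx))
    density = volume[idx[..., 0], idx[..., 1], idx[..., 2]] * inside  # zero outside the volume

    # 3. Physical segment lengths, then the weighted sum along each ray.
    lengths = (alphas[:, 1:] - alphas[:, :-1]) * ray.norm(dim=-1, keepdim=True)
    return (density * lengths).sum(dim=1)


# Through a volume of ones, the result is just the length of the ray inside the volume
volume = torch.ones(4, 5, 6)
source = torch.tensor([[-1.0, 5.3, 1.3]])
target = torch.tensor([[6.0, 5.3, 1.3]])
print(siddon_line_integral(volume, [1.0, 2.0, 0.5], source, target))
```

Apart from a three-iteration loop over axes, every step is a batched tensor operation, so the same code runs on a GPU and autograd can propagate gradients from the result back to `volume`.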
If you find DiffDRR useful in your work, please cite our paper (or the freely accessible arXiv version):
@inproceedings{gopalakrishnanDiffDRR2022,
  author    = {Gopalakrishnan, Vivek and Golland, Polina},
  title     = {Fast Auto-Differentiable Digitally Reconstructed Radiographs for Solving Inverse Problems in Intraoperative Imaging},
  year      = {2022},
  booktitle = {Clinical Image-based Procedures: 11th International Workshop, CLIP 2022, Held in Conjunction with MICCAI 2022, Singapore, Proceedings},
  series    = {Lecture Notes in Computer Science},
  publisher = {Springer},
  doi       = {10.1007/978-3-031-23179-7_1},
}