Skip to content

Commit

Permalink
Add a data object to wrap a CT scan
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Mar 18, 2024
1 parent 8e38f7d commit 089631f
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 201 deletions.
14 changes: 6 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,20 @@ from diffdrr.data import load_example_ct
from diffdrr.visualization import plot_drr

# Read in the volume and get its origin and spacing in world coordinates
volume, origin, spacing = load_example_ct()
subject = load_example_ct()

# Initialize the DRR module for generating synthetic X-rays
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
drr = DRR(
volume, # The CT volume as a numpy array
origin, # Location of the voxel [0, 0, 0] in world coordinates
spacing, # Voxel dimensions of the CT
sdr=510.0, # Source-to-detector radius (half of the source-to-detector distance)
subject, # An object storing the CT volume, origin, and voxel spacing
sdd=1020.0, # Source-to-detector radius (half of the source-to-detector distance)
height=200, # Height of the DRR (if width is not seperately provided, the generated image is square)
delx=2.0, # Pixel spacing (in mm)
).to(device)

# Set the camera pose with rotations (yaw, pitch, roll) and translations (x, y, z)
rotations = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]], device=device)
translations = torch.tensor([[350.0, 325.0, -175.0]], device=device)
rotations = torch.tensor([[0.0, torch.pi / 2, torch.pi / 2]], device=device)
translations = torch.tensor([[-10.0, 850.0, -175.0]], device=device)

# 📸 Also note that DiffDRR can take many representations of SO(3) 📸
# For example, quaternions, rotation matrix, axis-angle, etc...
Expand All @@ -72,7 +70,7 @@ plt.show()

On a single NVIDIA RTX 2080 Ti GPU, producing such an image takes

35.8 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
34.9 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The full example is available at
[`introduction.ipynb`](https://vivekg.dev/DiffDRR/tutorials/introduction.html).
Expand Down
15 changes: 7 additions & 8 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,24 @@
'doc_host': 'https://vivekg.dev',
'git_url': 'https://github.com/eigenvivek/DiffDRR',
'lib_path': 'diffdrr'},
'syms': { 'diffdrr.data': { 'diffdrr.data.load_example_ct': ('api/data.html#load_example_ct', 'diffdrr/data.py'),
'diffdrr.data.read_nifti': ('api/data.html#read_nifti', 'diffdrr/data.py')},
'syms': { 'diffdrr.data': { 'diffdrr.data.Subject': ('api/data.html#subject', 'diffdrr/data.py'),
'diffdrr.data.Subject.__init__': ('api/data.html#subject.__init__', 'diffdrr/data.py'),
'diffdrr.data.Subject.from_dicom': ('api/data.html#subject.from_dicom', 'diffdrr/data.py'),
'diffdrr.data.Subject.from_nifti': ('api/data.html#subject.from_nifti', 'diffdrr/data.py'),
'diffdrr.data.Subject.parse_density': ('api/data.html#subject.parse_density', 'diffdrr/data.py'),
'diffdrr.data.load_example_ct': ('api/data.html#load_example_ct', 'diffdrr/data.py')},
'diffdrr.detector': { 'diffdrr.detector.Detector': ('api/detector.html#detector', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.__init__': ('api/detector.html#detector.__init__', 'diffdrr/detector.py'),
'diffdrr.detector.Detector._initialize_carm': ( 'api/detector.html#detector._initialize_carm',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.flip_xz': ('api/detector.html#detector.flip_xz', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.forward': ('api/detector.html#detector.forward', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.translate': ('api/detector.html#detector.translate', 'diffdrr/detector.py'),
'diffdrr.detector.diffdrr_to_deepdrr': ('api/detector.html#diffdrr_to_deepdrr', 'diffdrr/detector.py')},
'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py')},
'diffdrr.drr': { 'diffdrr.drr.DRR': ('api/drr.html#drr', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.__init__': ('api/drr.html#drr.__init__', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.set_bone_attenuation_multiplier': ( 'api/drr.html#drr.set_bone_attenuation_multiplier',
'diffdrr/drr.py'),
'diffdrr.drr.DRR.set_intrinsics': ('api/drr.html#drr.set_intrinsics', 'diffdrr/drr.py'),
'diffdrr.drr.Registration': ('api/drr.html#registration', 'diffdrr/drr.py'),
'diffdrr.drr.Registration.__init__': ('api/drr.html#registration.__init__', 'diffdrr/drr.py'),
Expand Down
86 changes: 59 additions & 27 deletions diffdrr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,69 @@

import nibabel as nib
import numpy as np
import torch

# %% auto 0
__all__ = ['read_nifti', 'load_example_ct']
__all__ = ['load_example_ct', 'Subject']

# %% ../notebooks/api/03_data.ipynb 4
def read_nifti(filename: Path | str, return_affine=False):
"""Read a NIFTI and return the volume, affine matrix, and voxel spacings."""
img = nib.load(filename)
affine = img.affine
volume = img.get_fdata().astype(np.float32)
spacing = img.header.get_zooms()

# If affine matrix has negative spacing, flip axis
for axis in range(volume.ndim):
if affine[axis, axis] < 0:
volume = np.flip(volume, axis)
volume = np.copy(volume)

# Get the origin in world coordinates from the affine matrix, correcting for negative spacings
corners = np.array([[0, 0, 0, 1], [*volume.shape, 1]])
origin = np.einsum("ij, nj -> ni", affine, corners).min(axis=0)[:3]
origin = tuple(origin.astype(np.float32))

if return_affine:
return volume, origin, spacing, affine
else:
return volume, origin, spacing

# %% ../notebooks/api/03_data.ipynb 5
def load_example_ct():
def load_example_ct(bone_attenuation_multiplier=1.0):
"""Load an example chest CT for demonstration purposes."""
datadir = Path(__file__).resolve().parent / "data"
filename = datadir / "cxr.nii"
return read_nifti(filename)
return Subject.from_nifti(filename, bone_attenuation_multiplier)

# %% ../notebooks/api/03_data.ipynb 5
class Subject:
def __init__(
self,
volume,
affine,
origin,
spacing,
bone_attenuation_multiplier,
):
self.volume = torch.from_numpy(volume)
self.affine = torch.from_numpy(affine)
self.origin = torch.tensor(origin)
self.spacing = torch.tensor(spacing)
self.density = self.parse_density(self.volume, bone_attenuation_multiplier)
self.bone_attenuation_multiplier = bone_attenuation_multiplier

@staticmethod
def parse_density(volume, bone_attenuation_multiplier):
volume[torch.where(350 < volume)] *= bone_attenuation_multiplier
density = torch.max(
torch.min(
0.001029 * volume + 1.03,
0.0005886 * volume + 1.03,
),
torch.zeros_like(volume),
)
return density

@staticmethod
def from_nifti(filename: Path | str, bone_attenuation_multiplier=1.0):
# Read the NIFTI volume
img = nib.load(filename)
affine = img.affine
volume = img.get_fdata().astype(np.float32)
spacing = img.header.get_zooms()

# If affine matrix has negative spacing, flip axis
for axis in range(volume.ndim):
if affine[axis, axis] < 0:
volume = np.flip(volume, axis)
volume = np.copy(volume)

# Get the origin in world coordinates from the affine matrix, correcting for negative spacings
corners = np.array([[0, 0, 0, 1], [*volume.shape, 1]])
origin = np.einsum("ij, nj -> ni", affine, corners).min(axis=0)[:3]
origin = tuple(origin.astype(np.float32))
return Subject(volume, affine, origin, spacing, bone_attenuation_multiplier)

@staticmethod
def from_dicom(filename: Path | str, bone_attenuation_multiplier=1.0):
raise NotImplementedError(
"First use dcm2niix to convert your DICOM: https://github.com/rordenlab/dcm2niix"
)
61 changes: 17 additions & 44 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
__all__ = ['DRR', 'Registration']

# %% ../notebooks/api/00_drr.ipynb 7
from .data import Subject


class DRR(nn.Module):
"""PyTorch module that computes differentiable digitally reconstructed radiographs."""

def __init__(
self,
volume: np.ndarray, # CT volume
origin: tuple, # Origin of the CT volume in world coordinates
spacing: tuple, # Dimensions in the CT volume in world coordinates
subject: Subject, # Data wrapper for the CT volume
sdd: float, # Source-to-detector distance (i.e., the C-arm's focal length)
height: int, # Height of the rendered DRR
delx: float, # X-axis pixel size
Expand All @@ -34,7 +35,6 @@ def __init__(
reshape: bool = True, # Return DRR with shape (b, 1, h, w)
reverse_x_axis: bool = True, # If pose includes reflection (i.e., E(3), not SE(3)), reverse x-axis
patch_size: int | None = None, # Render patches of the DRR in series
bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue
mask: np.ndarray = None, # Segmentation mask the same size as the CT volume
renderer: str = "siddon", # Rendering backend, either "siddon" or "trilinear"
**renderer_kwargs, # Kwargs for the renderer
Expand All @@ -60,21 +60,12 @@ def __init__(
)

# Initialize the volume
self.register_buffer("volume", torch.tensor(volume))
self.register_buffer("origin", torch.tensor(origin))
self.register_buffer("spacing", torch.tensor(spacing))
self.register_buffer("density", subject.density)
self.register_buffer("spacing", subject.spacing)
self.register_buffer("origin", subject.origin)
self.register_buffer("affine", subject.affine)
if mask is not None:
self.register_buffer("mask", torch.tensor(mask).to(torch.int16))
self.reshape = reshape
self.patch_size = patch_size
if self.patch_size is not None:
self.n_patches = (height * width) // (self.patch_size**2)

# Parameters for segmenting the CT volume and reweighting voxels
self.air = torch.where(self.volume <= -800)
self.soft_tissue = torch.where((-800 < self.volume) & (self.volume <= 350))
self.bone = torch.where(350 < self.volume)
self.bone_attenuation_multiplier = bone_attenuation_multiplier

# Initialize the renderer
if renderer == "siddon":
Expand All @@ -83,6 +74,10 @@ def __init__(
self.renderer = Trilinear(**renderer_kwargs)
else:
raise ValueError(f"renderer must be 'siddon', not {renderer}")
self.reshape = reshape
self.patch_size = patch_size
if self.patch_size is not None:
self.n_patches = (height * width) // (self.patch_size**2)

def reshape_transform(self, img, batch_size):
if self.reshape:
Expand All @@ -93,11 +88,7 @@ def reshape_transform(self, img, batch_size):
return img

# %% ../notebooks/api/00_drr.ipynb 8
def reshape_subsampled_drr(
img: torch.Tensor,
detector: Detector,
batch_size: int,
):
def reshape_subsampled_drr(img: torch.Tensor, detector: Detector, batch_size: int):
n_points = detector.height * detector.width
drr = torch.zeros(batch_size, n_points).to(img)
drr[:, detector.subsamples[-1]] = img
Expand All @@ -114,17 +105,10 @@ def forward(
*args, # Some batched representation of SE(3)
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue
labels: list = None, # Labels from the mask of structures to render
**kwargs, # Passed to the renderer
):
"""Generate DRR with rotational and translational parameters."""
# Initialize a density map from the volume
if not hasattr(self, "density"):
self.set_bone_attenuation_multiplier(self.bone_attenuation_multiplier)
if bone_attenuation_multiplier is not None:
self.set_bone_attenuation_multiplier(bone_attenuation_multiplier)

# Initialize the camera pose
if parameterization is None:
pose = args[0]
Expand All @@ -141,7 +125,7 @@ def forward(
else:
density = self.density

# Render the drr
# Render the DRR
if self.patch_size is not None:
n_points = target.shape[1] // self.n_patches
img = []
Expand All @@ -160,17 +144,6 @@ def forward(

# %% ../notebooks/api/00_drr.ipynb 11
@patch
def set_bone_attenuation_multiplier(self: DRR, bone_attenuation_multiplier: float):
self.density = torch.empty_like(self.volume)
self.density[self.air] = self.volume[self.soft_tissue].min()
self.density[self.soft_tissue] = self.volume[self.soft_tissue]
self.density[self.bone] = self.volume[self.bone] * bone_attenuation_multiplier
self.density -= self.density.min()
self.density /= self.density.max()
self.bone_attenuation_multiplier = bone_attenuation_multiplier

# %% ../notebooks/api/00_drr.ipynb 12
@patch
def set_intrinsics(
self: DRR,
sdd: float = None,
Expand All @@ -191,7 +164,7 @@ def set_intrinsics(
reverse_x_axis=self.detector.reverse_x_axis,
).to(self.volume)

# %% ../notebooks/api/00_drr.ipynb 13
# %% ../notebooks/api/00_drr.ipynb 12
from .pose import RigidTransform


Expand All @@ -208,7 +181,7 @@ def perspective_projection(
x = x / z
return x[..., :2]

# %% ../notebooks/api/00_drr.ipynb 14
# %% ../notebooks/api/00_drr.ipynb 13
from torch.nn.functional import pad


Expand All @@ -230,7 +203,7 @@ def inverse_projection(
)
return extrinsic(x)

# %% ../notebooks/api/00_drr.ipynb 16
# %% ../notebooks/api/00_drr.ipynb 15
class Registration(nn.Module):
"""Perform automatic 2D-to-3D registration using differentiable rendering."""

Expand Down
Loading

0 comments on commit 089631f

Please sign in to comment.