Skip to content

Commit

Permalink
Merge pull request #187 from eigenvivek/trilinear-renderer
Browse files Browse the repository at this point in the history
Trilinear renderer
  • Loading branch information
eigenvivek authored Feb 9, 2024
2 parents f707cc0 + 5a0d035 commit b0243d4
Show file tree
Hide file tree
Showing 6 changed files with 709 additions and 68 deletions.
19 changes: 14 additions & 5 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,20 @@
'diffdrr.pose.so3_relative_angle': ('api/pose.html#so3_relative_angle', 'diffdrr/pose.py'),
'diffdrr.pose.so3_rotation_angle': ('api/pose.html#so3_rotation_angle', 'diffdrr/pose.py'),
'diffdrr.pose.standardize_quaternion': ('api/pose.html#standardize_quaternion', 'diffdrr/pose.py')},
'diffdrr.siddon': { 'diffdrr.siddon._get_alpha_minmax': ('api/siddon.html#_get_alpha_minmax', 'diffdrr/siddon.py'),
'diffdrr.siddon._get_alphas': ('api/siddon.html#_get_alphas', 'diffdrr/siddon.py'),
'diffdrr.siddon._get_index': ('api/siddon.html#_get_index', 'diffdrr/siddon.py'),
'diffdrr.siddon._get_voxel': ('api/siddon.html#_get_voxel', 'diffdrr/siddon.py'),
'diffdrr.siddon.siddon_raycast': ('api/siddon.html#siddon_raycast', 'diffdrr/siddon.py')},
'diffdrr.renderers': { 'diffdrr.renderers.Siddon': ('api/renderers.html#siddon', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.__init__': ('api/renderers.html#siddon.__init__', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.dims': ('api/renderers.html#siddon.dims', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.forward': ('api/renderers.html#siddon.forward', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.maxidx': ('api/renderers.html#siddon.maxidx', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear': ('api/renderers.html#trilinear', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.__init__': ( 'api/renderers.html#trilinear.__init__',
'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.dims': ('api/renderers.html#trilinear.dims', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.forward': ('api/renderers.html#trilinear.forward', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_index': ('api/renderers.html#_get_index', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py')},
'diffdrr.utils': { 'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
'diffdrr.utils.parse_intrinsic_matrix': ('api/utils.html#parse_intrinsic_matrix', 'diffdrr/utils.py')},
Expand Down
17 changes: 14 additions & 3 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastcore.basics import patch

from .detector import Detector
from .siddon import siddon_raycast
from .renderers import Siddon, Trilinear

# %% auto 0
__all__ = ['DRR', 'Registration']
Expand All @@ -34,6 +34,8 @@ def __init__(
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
patch_size: int | None = None, # Render patches of the DRR in series
bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue
renderer: str = "siddon", # Rendering backend, either "siddon" or "trilinear"
**renderer_kwargs, # Kwargs for the renderer
):
super().__init__()

Expand Down Expand Up @@ -69,6 +71,14 @@ def __init__(
self.bone = torch.where(350 < self.volume)
self.bone_attenuation_multiplier = bone_attenuation_multiplier

# Initialize the renderer
if renderer == "siddon":
self.renderer = Siddon(**renderer_kwargs)
elif renderer == "trilinear":
self.renderer = Trilinear(**renderer_kwargs)
else:
raise ValueError(f"renderer must be 'siddon', not {renderer}")

def reshape_transform(self, img, batch_size):
if self.reshape:
if self.detector.n_subsample is None:
Expand Down Expand Up @@ -101,6 +111,7 @@ def forward(
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue
**kwargs, # Passed to the renderer
):
"""Generate DRR with rotational and translational parameters."""
if not hasattr(self, "density"):
Expand All @@ -119,11 +130,11 @@ def forward(
img = []
for idx in range(self.n_patches):
t = target[:, idx * n_points : (idx + 1) * n_points]
partial = siddon_raycast(source, t, self.density, self.spacing)
partial = self.renderer(self.density, self.spacing, source, t, **kwargs)
img.append(partial)
img = torch.cat(img, dim=1)
else:
img = siddon_raycast(source, target, self.density, self.spacing)
img = self.renderer(self.density, self.spacing, source, target, **kwargs)
return self.reshape_transform(img, batch_size=len(pose))

# %% ../notebooks/api/00_drr.ipynb 11
Expand Down
112 changes: 83 additions & 29 deletions diffdrr/siddon.py → diffdrr/renderers.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,44 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_siddon.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_renderers.ipynb.

# %% auto 0
__all__ = ['siddon_raycast']
__all__ = ['Siddon', 'Trilinear']

# %% ../notebooks/api/01_siddon.ipynb 3
# %% ../notebooks/api/01_renderers.ipynb 3
import torch

# %% ../notebooks/api/01_siddon.ipynb 6
def siddon_raycast(
source: torch.Tensor,
target: torch.Tensor,
volume: torch.Tensor,
spacing: torch.Tensor,
eps: float = 1e-8,
):
"""An auto-differentiable implementation of the raycasting algorithm known as Siddon's method."""
maxidx = volume.numel() - 1
dims = torch.tensor(volume.shape).to(source) + 1
alphas = _get_alphas(source, target, spacing, dims, eps)
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
voxels = _get_voxel(alphamid, source, target, volume, spacing, dims, maxidx, eps)

# Step length for alphas out of range will be nan
# These nans cancel out voxels convereted to 0 index
step_length = torch.diff(alphas, dim=-1)
weighted_voxels = voxels * step_length

drr = torch.nansum(weighted_voxels, dim=-1)
raylength = (target - source + eps).norm(dim=-1)
drr *= raylength
return drr

# %% ../notebooks/api/01_siddon.ipynb 8
# %% ../notebooks/api/01_renderers.ipynb 6
class Siddon(torch.nn.Module):
def __init__(self, eps=1e-8):
super().__init__()
self.eps = eps

def dims(self, volume):
return torch.tensor(volume.shape).to(volume) + 1

def maxidx(self, volume):
return volume.numel() - 1

def forward(self, volume, spacing, source, target):
dims = self.dims(volume)
maxidx = self.maxidx(volume)

alphas = _get_alphas(source, target, spacing, dims, self.eps)
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
voxels = _get_voxel(
alphamid, source, target, volume, spacing, dims, maxidx, self.eps
)

# Step length for alphas out of range will be nan
# These nans cancel out voxels convereted to 0 index
step_length = torch.diff(alphas, dim=-1)
weighted_voxels = voxels * step_length

drr = torch.nansum(weighted_voxels, dim=-1)
raylength = (target - source + self.eps).norm(dim=-1)
drr *= raylength
return drr

# %% ../notebooks/api/01_renderers.ipynb 8
def _get_alphas(source, target, spacing, dims, eps):
# Get the CT sizing and spacing parameters
dx, dy, dz = spacing
Expand Down Expand Up @@ -100,3 +107,50 @@ def _get_index(alpha, source, target, spacing, dims, maxidx, eps):
idxs[idxs < 0] = 0
idxs[idxs > maxidx] = maxidx
return idxs

# %% ../notebooks/api/01_renderers.ipynb 10
from torch.nn.functional import grid_sample


class Trilinear(torch.nn.Module):
def __init__(
self,
near=0.0,
far=1.0,
eps=1e-8,
mode="bilinear",
):
super().__init__()
self.near = near
self.far = far
self.eps = eps
self.mode = mode

def dims(self, volume):
return torch.tensor(volume.shape).to(volume) + 1

def forward(
self, volume, spacing, source, target, n_points=100, align_corners=True
):
# Reorder array to match torch conventions
volume = volume.permute(2, 1, 0)
spacing = spacing[[2, 1, 0]]

# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1)
source = source[:, None, :, None, :]
target = target[:, None, :, None, :]

# Sample points along the rays and rescale to [-1, 1]
alphas = torch.linspace(self.near, self.far, n_points).to(volume)
alphas = alphas[None, None, None, :, None]
rays = source + alphas * (target - source)
rays = 2 * rays / (spacing * self.dims(volume)) - 1

# Render the DRR
batch_size = len(rays)
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)
drr = grid_sample(vol, rays, mode=self.mode, align_corners=align_corners)
drr = drr[:, 0, 0].sum(dim=-1)
drr *= raylength
return drr
17 changes: 14 additions & 3 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"from fastcore.basics import patch\n",
"\n",
"from diffdrr.detector import Detector\n",
"from diffdrr.siddon import siddon_raycast"
"from diffdrr.renderers import Siddon, Trilinear"
]
},
{
Expand Down Expand Up @@ -129,6 +129,8 @@
" reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
" patch_size: int | None = None, # Render patches of the DRR in series\n",
" bone_attenuation_multiplier: float = 1.0, # Contrast ratio of bone to soft tissue\n",
" renderer: str = \"siddon\", # Rendering backend, either \"siddon\" or \"trilinear\"\n",
" **renderer_kwargs, # Kwargs for the renderer\n",
" ):\n",
" super().__init__()\n",
"\n",
Expand Down Expand Up @@ -164,6 +166,14 @@
" self.bone = torch.where(350 < self.volume)\n",
" self.bone_attenuation_multiplier = bone_attenuation_multiplier\n",
"\n",
" # Initialize the renderer\n",
" if renderer == \"siddon\":\n",
" self.renderer = Siddon(**renderer_kwargs)\n",
" elif renderer == \"trilinear\":\n",
" self.renderer = Trilinear(**renderer_kwargs)\n",
" else:\n",
" raise ValueError(f\"renderer must be 'siddon', not {renderer}\")\n",
"\n",
" def reshape_transform(self, img, batch_size):\n",
" if self.reshape:\n",
" if self.detector.n_subsample is None:\n",
Expand Down Expand Up @@ -220,6 +230,7 @@
" parameterization: str = None, # Specifies the representation of the rotation\n",
" convention: str = None, # If parameterization is Euler angles, specify convention\n",
" bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue\n",
" **kwargs, # Passed to the renderer\n",
"):\n",
" \"\"\"Generate DRR with rotational and translational parameters.\"\"\"\n",
" if not hasattr(self, \"density\"):\n",
Expand All @@ -238,11 +249,11 @@
" img = []\n",
" for idx in range(self.n_patches):\n",
" t = target[:, idx * n_points : (idx + 1) * n_points]\n",
" partial = siddon_raycast(source, t, self.density, self.spacing)\n",
" partial = self.renderer(self.density, self.spacing, source, t, **kwargs)\n",
" img.append(partial)\n",
" img = torch.cat(img, dim=1)\n",
" else:\n",
" img = siddon_raycast(source, target, self.density, self.spacing)\n",
" img = self.renderer(self.density, self.spacing, source, target, **kwargs)\n",
" return self.reshape_transform(img, batch_size=len(pose))"
]
},
Expand Down
Loading

0 comments on commit b0243d4

Please sign in to comment.