Skip to content

Commit 9bda058

Browse files
authored
Implement alpha filtering in Trilinear (#281)
* Implement separate alpha filtering function * Update tutorials
1 parent e26bb37 commit 9bda058

File tree

7 files changed

+154
-91
lines changed

7 files changed

+154
-91
lines changed

diffdrr/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@
158158
'diffdrr/renderers.py'),
159159
'diffdrr.renderers.Trilinear.dims': ('api/renderers.html#trilinear.dims', 'diffdrr/renderers.py'),
160160
'diffdrr.renderers.Trilinear.forward': ('api/renderers.html#trilinear.forward', 'diffdrr/renderers.py'),
161+
'diffdrr.renderers._filter_intersections_outside_volume': ( 'api/renderers.html#_filter_intersections_outside_volume',
162+
'diffdrr/renderers.py'),
161163
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
162164
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
163165
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py'),

diffdrr/renderers.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -114,24 +114,31 @@ def _get_alphas(
114114

115115
# Sort the intersections
116116
alphas = torch.sort(alphas, dim=-1).values
117-
118-
# Remove interesections that are outside of the volume for all rays
119117
if filter_intersections_outside_volume:
120-
alphamin, alphamax = _get_alpha_minmax(
121-
source, target, origin, spacing, dims, eps
118+
alphas = _filter_intersections_outside_volume(
119+
alphas, source, target, origin, spacing, dims, eps
122120
)
123-
good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)
124-
alphas = alphas[..., good_idxs.any(dim=[0, 1])]
125121

126122
return alphas
127123

128124

125+
def _filter_intersections_outside_volume(
126+
alphas, source, target, origin, spacing, dims, eps
127+
):
128+
"""Remove interesections that are outside of the volume for all rays."""
129+
alphamin, alphamax = _get_alpha_minmax(source, target, origin, spacing, dims, eps)
130+
good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)
131+
alphas = alphas[..., good_idxs.any(dim=[0, 1])]
132+
return alphas
133+
134+
129135
def _get_alpha_minmax(source, target, origin, spacing, dims, eps):
136+
"""Calculate the first and last intersections of each ray with the volume."""
130137
sdd = target - source + eps
131138

132-
planes = torch.zeros(3).to(source)
139+
planes = torch.zeros(3).to(source) - 0.5
133140
alpha0 = (planes * spacing + origin - source) / sdd
134-
planes = (dims - 1).to(source)
141+
planes = dims.to(source) - 0.5
135142
alpha1 = (planes * spacing + origin - source) / sdd
136143
alphas = torch.stack([alpha0, alpha1]).to(source)
137144

@@ -177,13 +184,15 @@ def __init__(
177184
self,
178185
near=0.0,
179186
far=1.0,
180-
mode="bilinear",
181-
eps=1e-8,
187+
mode: str = "bilinear", # Interpolation mode for grid_sample
188+
filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections
189+
eps: float = 1e-8, # Small constant to avoid div by zero errors
182190
):
183191
super().__init__()
184192
self.near = near
185193
self.far = far
186194
self.mode = mode
195+
self.filter_intersections_outside_volume = filter_intersections_outside_volume
187196
self.eps = eps
188197

189198
def dims(self, volume):
@@ -200,16 +209,18 @@ def forward(
200209
align_corners=True,
201210
mask=None,
202211
):
203-
# Get the raylength and reshape sources
204-
raylength = (source - target + self.eps).norm(dim=-1).unsqueeze(1)
205-
206-
# Sample points along the rays and rescale to [-1, 1]
207-
alphas = torch.linspace(self.near, self.far, n_points).to(volume)
208-
alphas = alphas[None, None, :]
209-
210-
# Render the DRR
212+
# Sample points along the rays
211213
dims = self.dims(volume)
214+
alphas = torch.linspace(self.near, self.far, n_points)[None, None].to(volume)
215+
if self.filter_intersections_outside_volume:
216+
alphas = _filter_intersections_outside_volume(
217+
alphas, source, target, origin, spacing, dims, self.eps
218+
)
219+
220+
# Get the XYZ coordinate of each alpha, normalized for grid_sample
212221
xyzs = _get_xyzs(alphas, source, target, origin, spacing, dims, self.eps)
222+
223+
# Sample the volume with trilinear interpolation
213224
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
214225

215226
# Handle optional masking
@@ -227,6 +238,7 @@ def forward(
227238
.scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))
228239
)
229240

230-
# Multiply by raylength
241+
# Multiply by raylength and return the drr
242+
raylength = (target - source + self.eps).norm(dim=-1).unsqueeze(1)
231243
img *= raylength / n_points
232244
return img

notebooks/api/01_renderers.ipynb

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -211,22 +211,31 @@
211211
"\n",
212212
" # Sort the intersections\n",
213213
" alphas = torch.sort(alphas, dim=-1).values\n",
214-
"\n",
215-
" # Remove interesections that are outside of the volume for all rays\n",
216214
" if filter_intersections_outside_volume:\n",
217-
" alphamin, alphamax = _get_alpha_minmax(source, target, origin, spacing, dims, eps)\n",
218-
" good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)\n",
219-
" alphas = alphas[..., good_idxs.any(dim=[0, 1])]\n",
215+
" alphas = _filter_intersections_outside_volume(\n",
216+
" alphas, source, target, origin, spacing, dims, eps\n",
217+
" )\n",
220218
" \n",
221219
" return alphas\n",
222220
"\n",
223221
"\n",
222+
"def _filter_intersections_outside_volume(\n",
223+
" alphas, source, target, origin, spacing, dims, eps\n",
224+
"):\n",
225+
" \"\"\"Remove interesections that are outside of the volume for all rays.\"\"\"\n",
226+
" alphamin, alphamax = _get_alpha_minmax(source, target, origin, spacing, dims, eps)\n",
227+
" good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)\n",
228+
" alphas = alphas[..., good_idxs.any(dim=[0, 1])]\n",
229+
" return alphas\n",
230+
"\n",
231+
"\n",
224232
"def _get_alpha_minmax(source, target, origin, spacing, dims, eps):\n",
233+
" \"\"\"Calculate the first and last intersections of each ray with the volume.\"\"\"\n",
225234
" sdd = target - source + eps\n",
226235
" \n",
227-
" planes = torch.zeros(3).to(source)\n",
236+
" planes = torch.zeros(3).to(source) - 0.5\n",
228237
" alpha0 = (planes * spacing + origin - source) / sdd\n",
229-
" planes = (dims - 1).to(source)\n",
238+
" planes = dims.to(source) - 0.5\n",
230239
" alpha1 = (planes * spacing + origin - source) / sdd\n",
231240
" alphas = torch.stack([alpha0, alpha1]).to(source)\n",
232241
"\n",
@@ -294,13 +303,15 @@
294303
" self,\n",
295304
" near=0.0,\n",
296305
" far=1.0,\n",
297-
" mode=\"bilinear\",\n",
298-
" eps=1e-8,\n",
306+
" mode: str = \"bilinear\", # Interpolation mode for grid_sample\n",
307+
" filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections\n",
308+
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
299309
" ):\n",
300310
" super().__init__()\n",
301311
" self.near = near\n",
302312
" self.far = far\n",
303313
" self.mode = mode\n",
314+
" self.filter_intersections_outside_volume = filter_intersections_outside_volume\n",
304315
" self.eps = eps\n",
305316
"\n",
306317
" def dims(self, volume):\n",
@@ -317,16 +328,18 @@
317328
" align_corners=True,\n",
318329
" mask=None,\n",
319330
" ):\n",
320-
" # Get the raylength and reshape sources\n",
321-
" raylength = (source - target + self.eps).norm(dim=-1).unsqueeze(1)\n",
322-
"\n",
323-
" # Sample points along the rays and rescale to [-1, 1]\n",
324-
" alphas = torch.linspace(self.near, self.far, n_points).to(volume)\n",
325-
" alphas = alphas[None, None, :]\n",
326-
"\n",
327-
" # Render the DRR\n",
331+
" # Sample points along the rays\n",
328332
" dims = self.dims(volume)\n",
333+
" alphas = torch.linspace(self.near, self.far, n_points)[None, None].to(volume)\n",
334+
" if self.filter_intersections_outside_volume:\n",
335+
" alphas = _filter_intersections_outside_volume(\n",
336+
" alphas, source, target, origin, spacing, dims, self.eps\n",
337+
" )\n",
338+
"\n",
339+
" # Get the XYZ coordinate of each alpha, normalized for grid_sample\n",
329340
" xyzs = _get_xyzs(alphas, source, target, origin, spacing, dims, self.eps)\n",
341+
"\n",
342+
" # Sample the volume with trilinear interpolation\n",
330343
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
331344
"\n",
332345
" # Handle optional masking\n",
@@ -344,7 +357,8 @@
344357
" .scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))\n",
345358
" )\n",
346359
"\n",
347-
" # Multiply by raylength\n",
360+
" # Multiply by raylength and return the drr\n",
361+
" raylength = (target - source + self.eps).norm(dim=-1).unsqueeze(1)\n",
348362
" img *= raylength / n_points\n",
349363
" return img"
350364
]

notebooks/tutorials/introduction.ipynb

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

notebooks/tutorials/reconstruction.ipynb

Lines changed: 21 additions & 21 deletions
Large diffs are not rendered by default.

notebooks/tutorials/registration.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@
895895
"id": "ebf4a6da-a6d2-421a-bb36-c8157716f776",
896896
"metadata": {},
897897
"source": [
898-
"L-BFGS with line search converges so quickly that a GIF with ~30 FPS is imperceptable. Here's the same GIF but at 1 FPS."
898+
"L-BFGS with line search converges so quickly that a GIF with ~30 FPS is imperceptible. Here's the same GIF but at 1 FPS."
899899
]
900900
},
901901
{

notebooks/tutorials/trilinear.ipynb

Lines changed: 58 additions & 23 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)