Skip to content

Commit 39b0daa

Browse files
authored
Implement cone-beam max-intensity projection (#331)
1 parent e3f411e commit 39b0daa

File tree

5 files changed

+87
-9
lines changed

5 files changed

+87
-9
lines changed

diffdrr/_modidx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@
164164
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
165165
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
166166
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py'),
167-
'diffdrr.renderers._get_xyzs': ('api/renderers.html#_get_xyzs', 'diffdrr/renderers.py')},
167+
'diffdrr.renderers._get_xyzs': ('api/renderers.html#_get_xyzs', 'diffdrr/renderers.py'),
168+
'diffdrr.renderers.reduce': ('api/renderers.html#reduce', 'diffdrr/renderers.py')},
168169
'diffdrr.utils': { 'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
169170
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
170171
'diffdrr.utils.make_intrinsic_matrix': ('api/utils.html#make_intrinsic_matrix', 'diffdrr/utils.py'),

diffdrr/renderers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ def __init__(
1616
mode: str = "nearest", # Interpolation mode for grid_sample
1717
stop_gradients_through_grid_sample: bool = False, # Apply torch.no_grad when calling grid_sample
1818
filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections
19+
reducefn: str = "sum", # Function for combining samples along each ray
1920
eps: float = 1e-8, # Small constant to avoid div by zero errors
2021
):
2122
super().__init__()
2223
self.mode = mode
2324
self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample
2425
self.filter_intersections_outside_volume = filter_intersections_outside_volume
26+
self.reducefn = reducefn
2527
self.eps = eps
2628

2729
def dims(self, volume):
@@ -69,7 +71,7 @@ def forward(
6971

7072
# Handle optional masking
7173
if mask is None:
72-
img = img.sum(dim=-1)
74+
img = reduce(img, self.reducefn)
7375
img = img.unsqueeze(1)
7476
else:
7577
# Thanks to @Ivan for the clutch assist w/ pytorch tensor ops
@@ -162,17 +164,28 @@ def _get_voxel(volume, xyzs, img, mode, align_corners):
162164
img = voxels
163165
return img
164166

165-
# %% ../notebooks/api/01_renderers.ipynb 10
167+
# %% ../notebooks/api/01_renderers.ipynb 9
168+
def reduce(img, reducefn):
169+
if reducefn == "sum":
170+
return img.sum(dim=-1)
171+
elif reducefn == "max":
172+
return img.max(dim=-1).values
173+
else:
174+
raise ValueError(f"Only supports reducefn 'sum' or 'max', not {reducefn}")
175+
176+
# %% ../notebooks/api/01_renderers.ipynb 11
166177
class Trilinear(torch.nn.Module):
167178
"""Differentiable X-ray renderer implemented with trilinear interpolation."""
168179

169180
def __init__(
170181
self,
171182
mode: str = "bilinear", # Interpolation mode for grid_sample
183+
reducefn: str = "sum", # Function for combining samples along each ray
172184
eps: float = 1e-8, # Small constant to avoid div by zero errors
173185
):
174186
super().__init__()
175187
self.mode = mode
188+
self.reducefn = reducefn
176189
self.eps = eps
177190

178191
def dims(self, volume):
@@ -213,7 +226,8 @@ def forward(
213226

214227
# Handle optional masking
215228
if mask is None:
216-
img = img.sum(dim=-1).unsqueeze(1)
229+
img = reduce(img, self.reducefn)
230+
img = img.unsqueeze(1)
217231
else:
218232
B, D, _ = img.shape
219233
C = int(mask.max().item() + 1)

notebooks/api/01_renderers.ipynb

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,14 @@
118118
" mode: str = \"nearest\", # Interpolation mode for grid_sample\n",
119119
" stop_gradients_through_grid_sample: bool = False, # Apply torch.no_grad when calling grid_sample\n",
120120
" filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections\n",
121+
" reducefn: str = \"sum\", # Function for combining samples along each ray\n",
121122
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
122123
" ):\n",
123124
" super().__init__()\n",
124125
" self.mode = mode\n",
125126
" self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample\n",
126127
" self.filter_intersections_outside_volume = filter_intersections_outside_volume\n",
128+
" self.reducefn = reducefn\n",
127129
" self.eps = eps\n",
128130
"\n",
129131
" def dims(self, volume):\n",
@@ -169,7 +171,7 @@
169171
"\n",
170172
" # Handle optional masking\n",
171173
" if mask is None:\n",
172-
" img = img.sum(dim=-1)\n",
174+
" img = reduce(img, self.reducefn)\n",
173175
" img = img.unsqueeze(1)\n",
174176
" else:\n",
175177
" # Thanks to @Ivan for the clutch assist w/ pytorch tensor ops\n",
@@ -270,6 +272,22 @@
270272
" return img"
271273
]
272274
},
275+
{
276+
"cell_type": "code",
277+
"execution_count": null,
278+
"metadata": {},
279+
"outputs": [],
280+
"source": [
281+
"#| exporti\n",
282+
"def reduce(img, reducefn):\n",
283+
" if reducefn == \"sum\":\n",
284+
" return img.sum(dim=-1)\n",
285+
" elif reducefn == \"max\":\n",
286+
" return img.max(dim=-1).values\n",
287+
" else:\n",
288+
" raise ValueError(f\"Only supports reducefn 'sum' or 'max', not {reducefn}\")"
289+
]
290+
},
273291
{
274292
"cell_type": "markdown",
275293
"metadata": {},
@@ -298,10 +316,12 @@
298316
" def __init__(\n",
299317
" self,\n",
300318
" mode: str = \"bilinear\", # Interpolation mode for grid_sample\n",
319+
" reducefn: str = \"sum\", # Function for combining samples along each ray\n",
301320
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
302321
" ):\n",
303322
" super().__init__()\n",
304323
" self.mode = mode\n",
324+
" self.reducefn = reducefn\n",
305325
" self.eps = eps\n",
306326
"\n",
307327
" def dims(self, volume):\n",
@@ -342,7 +362,8 @@
342362
"\n",
343363
" # Handle optional masking\n",
344364
" if mask is None:\n",
345-
" img = img.sum(dim=-1).unsqueeze(1)\n",
365+
" img = reduce(img, self.reducefn)\n",
366+
" img = img.unsqueeze(1)\n",
346367
" else:\n",
347368
" B, D, _ = img.shape\n",
348369
" C = int(mask.max().item() + 1)\n",

notebooks/index.ipynb

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

notebooks/tutorials/introduction.ipynb

Lines changed: 43 additions & 1 deletion
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)