Skip to content

Commit 2ab187f

Browse files
authored
Merge pull request #189 from eigenvivek/fix-trilinear
Hotfix volume permute
2 parents 93c9d53 + edbc25e commit 2ab187f

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

diffdrr/renderers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,6 @@ def dims(self, volume):
132132
def forward(
133133
self, volume, spacing, source, target, n_points=100, align_corners=True
134134
):
135-
# Reorder array to match torch conventions
136-
volume = volume.permute(2, 1, 0)
137-
spacing = spacing[[2, 1, 0]]
138-
139135
# Get the raylength and reshape sources
140136
raylength = (source - target + self.eps).norm(dim=-1)
141137
source = source[:, None, :, None, :]
@@ -147,6 +143,9 @@ def forward(
147143
rays = source + alphas * (target - source)
148144
rays = 2 * rays / (spacing * self.dims(volume)) - 1
149145

146+
# Reorder array to match torch conventions
147+
volume = volume.permute(2, 1, 0)
148+
150149
# Render the DRR
151150
batch_size = len(rays)
152151
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)

notebooks/api/01_renderers.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,6 @@
257257
" return torch.tensor(volume.shape).to(volume) + 1\n",
258258
"\n",
259259
" def forward(self, volume, spacing, source, target, n_points=100, align_corners=True): \n",
260-
" # Reorder array to match torch conventions\n",
261-
" volume = volume.permute(2, 1, 0)\n",
262-
" spacing = spacing[[2, 1, 0]]\n",
263-
"\n",
264260
" # Get the raylength and reshape sources\n",
265261
" raylength = (source - target + self.eps).norm(dim=-1)\n",
266262
" source = source[:, None, :, None, :]\n",
@@ -272,6 +268,9 @@
272268
" rays = source + alphas * (target - source)\n",
273269
" rays = 2 * rays / (spacing * self.dims(volume)) - 1\n",
274270
"\n",
271+
" # Reorder array to match torch conventions\n",
272+
" volume = volume.permute(2, 1, 0)\n",
273+
" \n",
275274
" # Render the DRR\n",
276275
" batch_size = len(rays)\n",
277276
" vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)\n",

0 commit comments

Comments
 (0)