Skip to content

Commit

Permalink
Hotfix volume permute
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Feb 9, 2024
1 parent 93c9d53 commit edbc25e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
7 changes: 3 additions & 4 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@ def dims(self, volume):
def forward(
self, volume, spacing, source, target, n_points=100, align_corners=True
):
# Reorder array to match torch conventions
volume = volume.permute(2, 1, 0)
spacing = spacing[[2, 1, 0]]

# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1)
source = source[:, None, :, None, :]
Expand All @@ -147,6 +143,9 @@ def forward(
rays = source + alphas * (target - source)
rays = 2 * rays / (spacing * self.dims(volume)) - 1

# Reorder array to match torch conventions
volume = volume.permute(2, 1, 0)

# Render the DRR
batch_size = len(rays)
vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)
Expand Down
7 changes: 3 additions & 4 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,6 @@
" return torch.tensor(volume.shape).to(volume) + 1\n",
"\n",
" def forward(self, volume, spacing, source, target, n_points=100, align_corners=True): \n",
" # Reorder array to match torch conventions\n",
" volume = volume.permute(2, 1, 0)\n",
" spacing = spacing[[2, 1, 0]]\n",
"\n",
" # Get the raylength and reshape sources\n",
" raylength = (source - target + self.eps).norm(dim=-1)\n",
" source = source[:, None, :, None, :]\n",
Expand All @@ -272,6 +268,9 @@
" rays = source + alphas * (target - source)\n",
" rays = 2 * rays / (spacing * self.dims(volume)) - 1\n",
"\n",
" # Reorder array to match torch conventions\n",
" volume = volume.permute(2, 1, 0)\n",
" \n",
" # Render the DRR\n",
" batch_size = len(rays)\n",
" vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)\n",
Expand Down

0 comments on commit edbc25e

Please sign in to comment.