diff --git a/diffdrr/renderers.py b/diffdrr/renderers.py index e26cda210..32259c825 100644 --- a/diffdrr/renderers.py +++ b/diffdrr/renderers.py @@ -132,10 +132,6 @@ def dims(self, volume): def forward( self, volume, spacing, source, target, n_points=100, align_corners=True ): - # Reorder array to match torch conventions - volume = volume.permute(2, 1, 0) - spacing = spacing[[2, 1, 0]] - # Get the raylength and reshape sources raylength = (source - target + self.eps).norm(dim=-1) source = source[:, None, :, None, :] @@ -147,6 +143,9 @@ def forward( rays = source + alphas * (target - source) rays = 2 * rays / (spacing * self.dims(volume)) - 1 + # Reorder array to match torch conventions + volume = volume.permute(2, 1, 0) + # Render the DRR batch_size = len(rays) vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1) diff --git a/notebooks/api/01_renderers.ipynb b/notebooks/api/01_renderers.ipynb index 07672e03a..f305b43f4 100644 --- a/notebooks/api/01_renderers.ipynb +++ b/notebooks/api/01_renderers.ipynb @@ -257,10 +257,6 @@ " return torch.tensor(volume.shape).to(volume) + 1\n", "\n", " def forward(self, volume, spacing, source, target, n_points=100, align_corners=True): \n", - " # Reorder array to match torch conventions\n", - " volume = volume.permute(2, 1, 0)\n", - " spacing = spacing[[2, 1, 0]]\n", - "\n", " # Get the raylength and reshape sources\n", " raylength = (source - target + self.eps).norm(dim=-1)\n", " source = source[:, None, :, None, :]\n", @@ -272,6 +268,9 @@ " rays = source + alphas * (target - source)\n", " rays = 2 * rays / (spacing * self.dims(volume)) - 1\n", "\n", + " # Reorder array to match torch conventions\n", + " volume = volume.permute(2, 1, 0)\n", + " \n", " # Render the DRR\n", " batch_size = len(rays)\n", " vol = volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1)\n",