Skip to content

Commit

Permalink
Move dims to device for torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Feb 5, 2024
1 parent e671d45 commit f93aafc
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion diffdrr/siddon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def siddon_raycast(
eps: float = 1e-8,
):
"""An auto-differentiable implementation of the raycasting algorithm known as Siddon's method."""
dims = torch.tensor(volume.shape) + 1
dims = torch.tensor(volume.shape).to(source) + 1
alphas, maxidx = _get_alphas(source, target, spacing, dims, eps)
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
voxels = _get_voxel(alphamid, source, target, volume, spacing, dims, maxidx, eps)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/api/01_siddon.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
" eps: float=1e-8,\n",
"):\n",
" \"\"\"An auto-differentiable implementation of the raycasting algorithm known as Siddon's method.\"\"\"\n",
" dims = torch.tensor(volume.shape) + 1\n",
" dims = torch.tensor(volume.shape).to(source) + 1\n",
" alphas, maxidx = _get_alphas(source, target, spacing, dims, eps)\n",
" alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2\n",
" voxels = _get_voxel(alphamid, source, target, volume, spacing, dims, maxidx, eps)\n",
Expand Down

0 comments on commit f93aafc

Please sign in to comment.