From f93aafc648638a379ca88db4fbb4908d0ca81117 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Mon, 5 Feb 2024 15:12:04 -0500 Subject: [PATCH] Move dims to device for torch.compile --- diffdrr/siddon.py | 2 +- notebooks/api/01_siddon.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffdrr/siddon.py b/diffdrr/siddon.py index bc25d0aaa..34f6b0bf9 100644 --- a/diffdrr/siddon.py +++ b/diffdrr/siddon.py @@ -15,7 +15,7 @@ def siddon_raycast( eps: float = 1e-8, ): """An auto-differentiable implementation of the raycasting algorithm known as Siddon's method.""" - dims = torch.tensor(volume.shape) + 1 + dims = torch.tensor(volume.shape).to(source) + 1 alphas, maxidx = _get_alphas(source, target, spacing, dims, eps) alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2 voxels = _get_voxel(alphamid, source, target, volume, spacing, dims, maxidx, eps) diff --git a/notebooks/api/01_siddon.ipynb b/notebooks/api/01_siddon.ipynb index e5a273b94..3424d3fa9 100644 --- a/notebooks/api/01_siddon.ipynb +++ b/notebooks/api/01_siddon.ipynb @@ -91,7 +91,7 @@ " eps: float=1e-8,\n", "):\n", " \"\"\"An auto-differentiable implementation of the raycasting algorithm known as Siddon's method.\"\"\"\n", - " dims = torch.tensor(volume.shape) + 1\n", + " dims = torch.tensor(volume.shape).to(source) + 1\n", " alphas, maxidx = _get_alphas(source, target, spacing, dims, eps)\n", " alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2\n", " voxels = _get_voxel(alphamid, source, target, volume, spacing, dims, maxidx, eps)\n",