Skip to content

Commit edf1007

Browse files
authored
Merge pull request #178 from eigenvivek/intrinsics
Support autodifferentiability for intrinsic parameters
2 parents cb20a7f + 970bab3 commit edf1007

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

diffdrr/detector.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,29 @@ def __init__(
5757
@patch
5858
def _initialize_carm(self: Detector):
5959
"""Initialize the default position for the source and detector plane."""
60-
# Initialize the source on the x-axis
61-
source = torch.tensor([[self.sdr, 0.0, 0.0]])
60+
try:
61+
device = self.sdr.device
62+
except AttributeError:
63+
device = torch.device("cpu")
6264

63-
# Initialize the center of the detector plane on the negative x-axis
64-
center = torch.tensor([[-self.sdr, 0.0, 0.0]])
65+
# Initialize the source on the x-axis and the center of the detector plane on the negative x-axis
66+
source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr
67+
center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr
6568

6669
# Use the standard basis for the detector plane
67-
basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
70+
basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)
6871

6972
# Construct the detector plane with different offsets for even or odd heights
7073
h_off = 1.0 if self.height % 2 else 0.5
7174
w_off = 1.0 if self.width % 2 else 0.5
7275

7376
# Construct equally spaced points along the basis vectors
74-
t = (torch.arange(-self.height // 2, self.height // 2) + h_off) * self.delx
75-
s = (torch.arange(-self.width // 2, self.width // 2) + w_off) * self.dely
77+
t = (
78+
torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
79+
) * self.delx
80+
s = (
81+
torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
82+
) * self.dely
7683
if self.reverse_x_axis:
7784
s = -s
7885
coefs = torch.cartesian_prod(t, s).reshape(-1, 2)

notebooks/api/02_detector.ipynb

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,25 @@
128128
"@patch\n",
129129
"def _initialize_carm(self: Detector):\n",
130130
" \"\"\"Initialize the default position for the source and detector plane.\"\"\"\n",
131-
" # Initialize the source on the x-axis\n",
132-
" source = torch.tensor([[self.sdr, 0.0, 0.0]])\n",
133-
"\n",
134-
" # Initialize the center of the detector plane on the negative x-axis\n",
135-
" center = torch.tensor([[-self.sdr, 0.0, 0.0]])\n",
131+
" try:\n",
132+
" device = self.sdr.device\n",
133+
" except AttributeError:\n",
134+
" device = torch.device(\"cpu\")\n",
135+
" \n",
136+
" # Initialize the source on the x-axis and the center of the detector plane on the negative x-axis\n",
137+
" source = torch.tensor([[1.0, 0.0, 0.0]], device=device) * self.sdr\n",
138+
" center = torch.tensor([[-1.0, 0.0, 0.0]], device=device) * self.sdr\n",
136139
"\n",
137140
" # Use the standard basis for the detector plane\n",
138-
" basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n",
141+
" basis = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device)\n",
139142
"\n",
140143
" # Construct the detector plane with different offsets for even or odd heights\n",
141144
" h_off = 1.0 if self.height % 2 else 0.5\n",
142145
" w_off = 1.0 if self.width % 2 else 0.5\n",
143146
"\n",
144147
" # Construct equally spaced points along the basis vectors\n",
145-
" t = (torch.arange(-self.height // 2, self.height // 2) + h_off) * self.delx\n",
146-
" s = (torch.arange(-self.width // 2, self.width // 2) + w_off) * self.dely\n",
148+
" t = (torch.arange(-self.height // 2, self.height // 2, device=device) + h_off) * self.delx\n",
149+
" s = (torch.arange(-self.width // 2, self.width // 2, device=device) + w_off) * self.dely\n",
147150
" if self.reverse_x_axis:\n",
148151
" s = -s\n",
149152
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",

0 commit comments

Comments
 (0)