Skip to content

Commit 11b1f0a

Browse files
authored
Flip height and width in the detector (#357)
* Flip height and width; fix reorientations * Fix orientations * Rename f to sdd * Handle errors in volume masking * Fix perspective projection * Fix x/y conventions * Bump version
1 parent e41c365 commit 11b1f0a

File tree

10 files changed

+126
-85
lines changed

10 files changed

+126
-85
lines changed

diffdrr/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.4.5"
1+
__version__ = "0.4.6"

diffdrr/data.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -85,35 +85,37 @@ def read(
8585
# Frame-of-reference change
8686
if orientation == "AP":
8787
# Rotates the C-arm about the x-axis by 90 degrees
88-
# Rotates the C-arm about the z-axis by -90 degrees
8988
reorient = torch.tensor(
9089
[
91-
[0.0, 1.0, 0.0, 0.0],
92-
[0.0, 0.0, -1.0, 0.0],
93-
[-1.0, 0.0, 0.0, 0.0],
94-
[0.0, 0.0, 0.0, 1.0],
95-
]
90+
[1, 0, 0, 0],
91+
[0, 0, -1, 0],
92+
[0, 1, 0, 0],
93+
[0, 0, 0, 1],
94+
],
95+
dtype=torch.float32,
9696
)
9797
elif orientation == "PA":
9898
# Rotates the C-arm about the x-axis by 90 degrees
99-
# Rotates the C-arm about the z-axis by 90 degrees
99+
# Reverses the direction of the y-axis
100100
reorient = torch.tensor(
101101
[
102-
[0.0, 1.0, 0.0, 0.0],
103-
[0.0, 0.0, 1.0, 0.0],
104-
[-1.0, 0.0, 0.0, 0.0],
105-
[0.0, 0.0, 0.0, 1.0],
106-
]
102+
[1, 0, 0, 0],
103+
[0, 0, 1, 0],
104+
[0, 1, 0, 0],
105+
[0, 0, 0, 1],
106+
],
107+
dtype=torch.float32,
107108
)
108109
elif orientation is None:
109110
# Identity transform
110111
reorient = torch.tensor(
111112
[
112-
[1.0, 0.0, 0.0, 0.0],
113-
[0.0, 1.0, 0.0, 0.0],
114-
[0.0, 0.0, 1.0, 0.0],
115-
[0.0, 0.0, 0.0, 1.0],
116-
]
113+
[1, 0, 0, 0],
114+
[0, 1, 0, 0],
115+
[0, 0, 1, 0],
116+
[0, 0, 0, 1],
117+
],
118+
dtype=torch.float32,
117119
)
118120
else:
119121
raise ValueError(f"Unrecognized orientation {orientation}")
@@ -122,6 +124,7 @@ def read(
122124
subject = Subject(
123125
volume=volume,
124126
mask=mask,
127+
orientation=orientation,
125128
reorient=reorient,
126129
density=density,
127130
fiducials=fiducials,
@@ -161,9 +164,13 @@ def read(
161164
dim=0,
162165
)
163166

164-
subject.volume.data = subject.volume.data * mask
165-
subject.mask.data = subject.mask.data * mask
166-
subject.density.data = subject.density.data * mask
167+
# Mask all volumes, unless error, then just mask the density
168+
try:
169+
subject.volume.data = subject.volume.data * mask
170+
subject.mask.data = subject.mask.data * mask
171+
subject.density.data = subject.density.data * mask
172+
except:
173+
subject.density.data = subject.density.data * mask
167174

168175
return subject
169176

diffdrr/detector.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def __init__(
5151
"_calibration",
5252
torch.tensor(
5353
[
54-
[dely, 0, 0, -y0],
55-
[0, delx, 0, -x0],
54+
[delx, 0, 0, x0],
55+
[0, dely, 0, y0],
5656
[0, 0, sdd, 0],
5757
[0, 0, 0, 1],
5858
]
@@ -65,19 +65,19 @@ def sdd(self):
6565

6666
@property
6767
def delx(self):
68-
return self._calibration[1, 1].item()
68+
return self._calibration[0, 0].item()
6969

7070
@property
7171
def dely(self):
72-
return self._calibration[0, 0].item()
72+
return self._calibration[1, 1].item()
7373

7474
@property
7575
def x0(self):
76-
return -self._calibration[1, -1].item()
76+
return -self._calibration[0, -1].item()
7777

7878
@property
7979
def y0(self):
80-
return -self._calibration[0, -1].item()
80+
return -self._calibration[1, -1].item()
8181

8282
@property
8383
def reorient(self):
@@ -107,7 +107,7 @@ def _initialize_carm(self: Detector):
107107
center = torch.tensor([[0.0, 0.0, 1.0]], device=device)
108108

109109
# Use the standard basis for the detector plane
110-
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)
110+
basis = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], device=device)
111111

112112
# Construct the detector plane with different offsets for even or odd heights
113113
# These ensure that the detector plane is centered around (0, 0, 1)
@@ -117,8 +117,12 @@ def _initialize_carm(self: Detector):
117117
# Construct equally spaced points along the basis vectors
118118
t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
119119
s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
120-
if self.reverse_x_axis:
120+
121+
t = -t
122+
s = -s
123+
if not self.reverse_x_axis:
121124
s = -s
125+
122126
coefs = torch.cartesian_prod(t, s).reshape(-1, 2)
123127
target = torch.einsum("cd,nc->nd", basis, coefs)
124128
target += center

diffdrr/drr.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ def set_intrinsics_(
230230
width if width is not None else self.detector.width,
231231
delx if delx is not None else self.detector.delx,
232232
dely if dely is not None else self.detector.dely,
233-
x0 if x0 is not None else self.detector.x0,
234-
y0 if y0 is not None else self.detector.y0,
233+
x0 if x0 is not None else -self.detector.x0,
234+
y0 if y0 is not None else -self.detector.y0,
235235
self.subject.reorient,
236236
n_subsample if n_subsample is not None else self.detector.n_subsample,
237237
reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,
@@ -256,14 +256,21 @@ def perspective_projection(
256256
pts: torch.Tensor,
257257
):
258258
"""Project points in world coordinates (3D) onto the pixel plane (2D)."""
259+
# Poses in DiffDRR are world2camera, but perspective transforms use camera2world, so invert
259260
extrinsic = (self.detector.reorient.compose(pose)).inverse()
260261
x = extrinsic(pts)
262+
263+
# Project onto the detector plane
261264
x = torch.einsum("ij, bnj -> bni", self.detector.intrinsic, x)
262265
z = x[..., -1].unsqueeze(-1).clone()
263266
x = x / z
267+
268+
# Move origin to upper-left corner
269+
x[..., 1] = self.detector.height - x[..., 1]
264270
if self.detector.reverse_x_axis:
265-
x[..., 1] = self.detector.width - x[..., 1]
266-
return x[..., :2].flip(-1)
271+
x[..., 0] = self.detector.width - x[..., 0]
272+
273+
return x[..., :2]
267274

268275
# %% ../notebooks/api/00_drr.ipynb 14
269276
from torch.nn.functional import pad
@@ -276,7 +283,7 @@ def inverse_projection(
276283
pts: torch.Tensor,
277284
):
278285
"""Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D)."""
279-
pts = pts.flip(-1)
286+
# pts = pts.flip(-1)
280287
if self.detector.reverse_x_axis:
281288
pts[..., 1] = self.detector.width - pts[..., 1]
282289
x = self.detector.sdd * torch.einsum(

diffdrr/utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def resample(
5454

5555
# %% ../notebooks/api/07_utils.ipynb 6
5656
from kornia.geometry.camera.pinhole import PinholeCamera as KorniaPinholeCamera
57+
from torchio import Subject
5758

5859
from diffdrr.detector import Detector
5960

@@ -66,9 +67,11 @@ def __init__(
6667
height: torch.Tensor,
6768
width: torch.Tensor,
6869
detector: Detector,
70+
subject: Subject,
6971
):
7072
super().__init__(intrinsics, extrinsics, height, width)
71-
self.f = detector.sdd
73+
multiplier = -1 if subject.orientation == "PA" else 1
74+
self.sdd = multiplier * detector.sdd
7275
self.delx = detector.delx
7376
self.dely = detector.dely
7477
self.x0 = detector.x0
@@ -94,9 +97,9 @@ def pose(self):
9497

9598
from kornia.geometry.calibration import solve_pnp_dlt
9699

100+
from .detector import make_intrinsic_matrix
97101
from .drr import DRR
98102
from .pose import RigidTransform
99-
from .detector import make_intrinsic_matrix
100103

101104

102105
def get_pinhole_camera(
@@ -107,14 +110,12 @@ def get_pinhole_camera(
107110
pose = deepcopy(pose).to(device="cpu", dtype=dtype)
108111

109112
# Make the intrinsic matrix (in pixels)
110-
fx = drr.detector.sdd / drr.detector.delx
111-
fy = drr.detector.sdd / drr.detector.dely
113+
multiplier = -1 if drr.subject.orientation == "PA" else 1
114+
fx = multiplier * drr.detector.sdd / drr.detector.delx
115+
fy = multiplier * drr.detector.sdd / drr.detector.dely
112116
u0 = drr.detector.x0 / drr.detector.delx + drr.detector.width / 2
113117
v0 = drr.detector.y0 / drr.detector.dely + drr.detector.height / 2
114-
intrinsics = torch.eye(4)[None]
115-
intrinsics[0, :3, :3] = make_intrinsic_matrix(drr.detector)
116-
117-
torch.tensor(
118+
intrinsics = torch.tensor(
118119
[
119120
[
120121
[fx, 0.0, u0, 0.0],
@@ -156,6 +157,7 @@ def get_pinhole_camera(
156157
torch.tensor([drr.detector.height]),
157158
torch.tensor([drr.detector.width]),
158159
drr.detector,
160+
drr.subject,
159161
)
160162

161163
return camera

notebooks/api/00_drr.ipynb

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@
357357
" width if width is not None else self.detector.width,\n",
358358
" delx if delx is not None else self.detector.delx,\n",
359359
" dely if dely is not None else self.detector.dely,\n",
360-
" x0 if x0 is not None else self.detector.x0,\n",
361-
" y0 if y0 is not None else self.detector.y0,\n",
360+
" x0 if x0 is not None else -self.detector.x0,\n",
361+
" y0 if y0 is not None else -self.detector.y0,\n",
362362
" self.subject.reorient,\n",
363363
" n_subsample if n_subsample is not None else self.detector.n_subsample,\n",
364364
" reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,\n",
@@ -399,14 +399,21 @@
399399
" pts: torch.Tensor,\n",
400400
"):\n",
401401
" \"\"\"Project points in world coordinates (3D) onto the pixel plane (2D).\"\"\"\n",
402+
" # Poses in DiffDRR are world2camera, but perspective transforms use camera2world, so invert\n",
402403
" extrinsic = (self.detector.reorient.compose(pose)).inverse()\n",
403404
" x = extrinsic(pts)\n",
405+
"\n",
406+
" # Project onto the detector plane\n",
404407
" x = torch.einsum(\"ij, bnj -> bni\", self.detector.intrinsic, x)\n",
405408
" z = x[..., -1].unsqueeze(-1).clone()\n",
406409
" x = x / z\n",
410+
"\n",
411+
" # Move origin to upper-left corner\n",
412+
" x[..., 1] = self.detector.height - x[..., 1]\n",
407413
" if self.detector.reverse_x_axis:\n",
408-
" x[..., 1] = self.detector.width - x[..., 1]\n",
409-
" return x[..., :2].flip(-1)"
414+
" x[..., 0] = self.detector.width - x[..., 0]\n",
415+
" \n",
416+
" return x[..., :2]"
410417
]
411418
},
412419
{
@@ -427,7 +434,7 @@
427434
" pts: torch.Tensor,\n",
428435
"):\n",
429436
" \"\"\"Backproject points in pixel plane (2D) onto the image plane in world coordinates (3D).\"\"\"\n",
430-
" pts = pts.flip(-1)\n",
437+
" # pts = pts.flip(-1)\n",
431438
" if self.detector.reverse_x_axis:\n",
432439
" pts[..., 1] = self.detector.width - pts[..., 1]\n",
433440
" x = self.detector.sdd * torch.einsum(\n",

notebooks/api/02_detector.ipynb

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@
106106
" \"_calibration\",\n",
107107
" torch.tensor(\n",
108108
" [\n",
109-
" [dely, 0, 0, -y0],\n",
110-
" [0, delx, 0, -x0],\n",
109+
" [delx, 0, 0, x0],\n",
110+
" [0, dely, 0, y0],\n",
111111
" [0, 0, sdd, 0],\n",
112112
" [0, 0, 0, 1],\n",
113113
" ]\n",
@@ -120,19 +120,19 @@
120120
"\n",
121121
" @property\n",
122122
" def delx(self):\n",
123-
" return self._calibration[1, 1].item()\n",
123+
" return self._calibration[0, 0].item()\n",
124124
"\n",
125125
" @property\n",
126126
" def dely(self):\n",
127-
" return self._calibration[0, 0].item()\n",
127+
" return self._calibration[1, 1].item()\n",
128128
"\n",
129129
" @property\n",
130130
" def x0(self):\n",
131-
" return -self._calibration[1, -1].item()\n",
131+
" return -self._calibration[0, -1].item()\n",
132132
"\n",
133133
" @property\n",
134134
" def y0(self):\n",
135-
" return -self._calibration[0, -1].item()\n",
135+
" return -self._calibration[1, -1].item()\n",
136136
"\n",
137137
" @property\n",
138138
" def reorient(self):\n",
@@ -170,7 +170,7 @@
170170
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device)\n",
171171
"\n",
172172
" # Use the standard basis for the detector plane\n",
173-
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
173+
" basis = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], device=device)\n",
174174
"\n",
175175
" # Construct the detector plane with different offsets for even or odd heights\n",
176176
" # These ensure that the detector plane is centered around (0, 0, 1)\n",
@@ -180,8 +180,12 @@
180180
" # Construct equally spaced points along the basis vectors\n",
181181
" t = torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
182182
" s = torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
183-
" if self.reverse_x_axis:\n",
183+
"\n",
184+
" t = -t\n",
185+
" s = -s\n",
186+
" if not self.reverse_x_axis:\n",
184187
" s = -s\n",
188+
"\n",
185189
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
186190
" target = torch.einsum(\"cd,nc->nd\", basis, coefs)\n",
187191
" target += center\n",

0 commit comments

Comments
 (0)