@@ -25,16 +25,14 @@ def __init__(
25
25
sdr : float , # Source-to-detector radius for the C-arm (half of the source-to-detector distance)
26
26
height : int , # Height of the rendered DRR
27
27
delx : float , # X-axis pixel size
28
- width : int
29
- | None = None , # Width of the rendered DRR (if not provided, set to `height`)
28
+ width : int | None = None , # Width of the rendered DRR (default to `height`)
30
29
dely : float | None = None , # Y-axis pixel size (if not provided, set to `delx`)
31
30
x0 : float = 0.0 , # Principal point X-offset
32
31
y0 : float = 0.0 , # Principal point Y-offset
33
32
p_subsample : float | None = None , # Proportion of pixels to randomly subsample
34
33
reshape : bool = True , # Return DRR with shape (b, 1, h, w)
35
34
reverse_x_axis : bool = False , # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
36
- patch_size : int
37
- | None = None , # If the entire DRR can't fit in memory, render patches of the DRR in series
35
+ patch_size : int | None = None , # Render patches of the DRR in series
38
36
bone_attenuation_multiplier : float = 1.0 , # Contrast ratio of bone to soft tissue
39
37
):
40
38
super ().__init__ ()
@@ -92,18 +90,16 @@ def reshape_subsampled_drr(
92
90
return drr
93
91
94
92
# %% ../notebooks/api/00_drr.ipynb 10
95
- from . detector import make_xrays
96
- from .utils import Transform3d
93
+ # from diffdrr.se3 import RigidTransform, convert
94
+ from .pose import convert
97
95
98
96
99
97
@patch
100
98
def forward (
101
99
self : DRR ,
102
- rotation : torch .Tensor ,
103
- translation : torch .Tensor ,
104
- parameterization : str ,
105
- convention : str = None ,
106
- pose : Transform3d = None , # If you have a preformed pose, can pass it directly
100
+ * args , # Some batched representation of SE(3)
101
+ parameterization : str = None , # Specifies the representation of the rotation
102
+ convention : str = None , # If parameterization is Euler angles, specify convention
107
103
bone_attenuation_multiplier : float = None , # Contrast ratio of bone to soft tissue
108
104
):
109
105
"""Generate DRR with rotational and translational parameters."""
@@ -112,18 +108,11 @@ def forward(
112
108
if bone_attenuation_multiplier is not None :
113
109
self .set_bone_attenuation_multiplier (bone_attenuation_multiplier )
114
110
115
- if pose is None :
116
- assert len (rotation ) == len (translation )
117
- batch_size = len (rotation )
118
- source , target = self .detector (
119
- rotation = rotation ,
120
- translation = translation ,
121
- parameterization = parameterization ,
122
- convention = convention ,
123
- )
111
+ if parameterization is None :
112
+ pose = args [0 ]
124
113
else :
125
- batch_size = len ( pose )
126
- source , target = make_xrays ( pose , self .detector . source , self . detector . target )
114
+ pose = convert ( * args , parameterization = parameterization , convention = convention )
115
+ source , target = self .detector ( pose )
127
116
128
117
if self .patch_size is not None :
129
118
n_points = target .shape [1 ] // self .n_patches
@@ -135,7 +124,7 @@ def forward(
135
124
img = torch .cat (img , dim = 1 )
136
125
else :
137
126
img = siddon_raycast (source , target , self .density , self .spacing )
138
- return self .reshape_transform (img , batch_size = batch_size )
127
+ return self .reshape_transform (img , batch_size = len ( pose ) )
139
128
140
129
# %% ../notebooks/api/00_drr.ipynb 11
141
130
@patch
@@ -171,9 +160,6 @@ def set_intrinsics(
171
160
).to (self .volume )
172
161
173
162
# %% ../notebooks/api/00_drr.ipynb 14
174
- from .utils import convert
175
-
176
-
177
163
class Registration (nn .Module ):
178
164
"""Perform automatic 2D-to-3D registration using differentiable rendering."""
179
165
@@ -183,38 +169,25 @@ def __init__(
183
169
rotation : torch .Tensor ,
184
170
translation : torch .Tensor ,
185
171
parameterization : str ,
186
- input_convention : str = None ,
187
- output_convention : str = "ZYX" ,
172
+ convention : str = None ,
188
173
):
189
174
super ().__init__ ()
190
175
self .drr = drr
191
176
self .rotation = nn .Parameter (rotation )
192
177
self .translation = nn .Parameter (translation )
193
178
self .parameterization = parameterization
194
- self .input_convention = input_convention
195
- self .output_convention = output_convention
179
+ self .convention = convention
196
180
197
181
def forward (self ):
198
182
return self .drr (
199
183
self .rotation ,
200
184
self .translation ,
201
- self .parameterization ,
202
- self .input_convention ,
185
+ parameterization = self .parameterization ,
186
+ convention = self .convention ,
203
187
)
204
188
205
189
def get_rotation (self ):
206
- return (
207
- convert (
208
- self .rotation ,
209
- input_parameterization = self .parameterization ,
210
- output_parameterization = "euler_angles" ,
211
- input_convention = self .input_convention ,
212
- output_convention = self .output_convention ,
213
- )
214
- .clone ()
215
- .detach ()
216
- .cpu ()
217
- )
190
+ return self .rotation .clone ().detach ().cpu ()
218
191
219
192
def get_translation (self ):
220
193
return self .translation .clone ().detach ().cpu ()
0 commit comments