diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index a4c667afa..9b34a313e 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -29,6 +29,7 @@ 'diffdrr.drr.DRR.affine_inverse': ('api/drr.html#drr.affine_inverse', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'), + 'diffdrr.drr.DRR.n_patches': ('api/drr.html#drr.n_patches', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.render': ('api/drr.html#drr.render', 'diffdrr/drr.py'), 'diffdrr.drr.DRR.rescale_detector_': ('api/drr.html#drr.rescale_detector_', 'diffdrr/drr.py'), diff --git a/diffdrr/drr.py b/diffdrr/drr.py index 8cc29eb9d..c02778d83 100644 --- a/diffdrr/drr.py +++ b/diffdrr/drr.py @@ -98,8 +98,6 @@ def __init__( ) self.reshape = reshape self.patch_size = patch_size - if self.patch_size is not None: - self.n_patches = (height * width) // (self.patch_size**2) def reshape_transform(self, img, batch_size): if self.reshape: @@ -122,6 +120,10 @@ def affine(self): def affine_inverse(self): return RigidTransform(self._affine_inverse) + @property + def n_patches(self): + return (self.detector.height * self.detector.width) // (self.patch_size**2) + # %% ../notebooks/api/00_drr.ipynb 8 def reshape_subsampled_drr(img: torch.Tensor, detector: Detector, batch_size: int): n_points = detector.height * detector.width diff --git a/notebooks/api/00_drr.ipynb b/notebooks/api/00_drr.ipynb index 7273d1b35..95da3da1e 100644 --- a/notebooks/api/00_drr.ipynb +++ b/notebooks/api/00_drr.ipynb @@ -193,8 +193,6 @@ " )\n", " self.reshape = reshape\n", " self.patch_size = patch_size\n", - " if self.patch_size is not None:\n", - " self.n_patches = (height * width) // (self.patch_size**2)\n", "\n", " def reshape_transform(self, img, batch_size):\n", " if self.reshape:\n", @@ -215,7 +213,11 @@ "\n", " @property\n", " def affine_inverse(self):\n", - " return RigidTransform(self._affine_inverse)" + " return RigidTransform(self._affine_inverse)\n", + "\n", + " @property\n", + " def n_patches(self):\n", + " return (self.detector.height * self.detector.width) // (self.patch_size**2)" ] }, {