diff --git a/diffdrr/renderers.py b/diffdrr/renderers.py index 4d1b691fc..6eb87bff2 100644 --- a/diffdrr/renderers.py +++ b/diffdrr/renderers.py @@ -182,13 +182,16 @@ def forward( n_points=500, align_corners=False, mask=None, + alphamin=None, + alphamax=None, ): dims = self.dims(volume) # Sample points along the rays and rescale to [-1, 1] - alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps) - alphamin = alphamin.min() - alphamax = alphamax.max() + if alphamin is None or alphamax is None: + alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps) + alphamin = alphamin.min() + alphamax = alphamax.max() alphas = torch.linspace(0, 1, n_points)[None, None].to(volume) alphas = alphas * (alphamax - alphamin) + alphamin diff --git a/notebooks/api/01_renderers.ipynb b/notebooks/api/01_renderers.ipynb index 1947303f9..dab23f566 100644 --- a/notebooks/api/01_renderers.ipynb +++ b/notebooks/api/01_renderers.ipynb @@ -175,7 +175,9 @@ " # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614\n", " B, D, _ = img.shape\n", " C = int(mask.max().item() + 1)\n", - " channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n", + " channels = _get_voxel(\n", + " mask, xyzs, self.mode, align_corners=align_corners\n", + " ).long()\n", " img = (\n", " torch.zeros(B, C, D)\n", " .to(img)\n", @@ -311,13 +313,16 @@ " n_points=500,\n", " align_corners=False,\n", " mask=None,\n", + " alphamin=None,\n", + " alphamax=None,\n", " ):\n", " dims = self.dims(volume)\n", "\n", " # Sample points along the rays and rescale to [-1, 1]\n", - " alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps)\n", - " alphamin = alphamin.min()\n", - " alphamax = alphamax.max()\n", + " if alphamin is None or alphamax is None:\n", + " alphamin, alphamax = _get_alpha_minmax(source, target, dims, self.eps)\n", + " alphamin = alphamin.min()\n", + " alphamax = alphamax.max()\n", " alphas = torch.linspace(0, 1, n_points)[None, None].to(volume)\n", " alphas = alphas * (alphamax - alphamin) + alphamin\n", "\n", @@ -334,7 +339,9 @@ " else:\n", " B, D, _ = img.shape\n", " C = int(mask.max().item() + 1)\n", - " channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n", + " channels = _get_voxel(\n", + " mask, xyzs, self.mode, align_corners=align_corners\n", + " ).long()\n", " img = (\n", " torch.zeros(B, C, D)\n", " .to(img)\n",