From fdec9652d7d41ecf93a2016cfba0f5ea7fe1972d Mon Sep 17 00:00:00 2001 From: haarisr <122410226+haarisr@users.noreply.github.com> Date: Mon, 8 Jul 2024 11:45:49 -0700 Subject: [PATCH] Add missing parameter mode (#300) --- .gitignore | 3 +++ diffdrr/renderers.py | 8 ++++++-- notebooks/api/01_renderers.ipynb | 4 ++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 01ff62ca1..92dfba872 100644 --- a/.gitignore +++ b/.gitignore @@ -96,3 +96,6 @@ target/ # Mypy cache .mypy_cache/ + +# Jupyter notebooks +**/.last_checked diff --git a/diffdrr/renderers.py b/diffdrr/renderers.py index d26ade481..4d1b691fc 100644 --- a/diffdrr/renderers.py +++ b/diffdrr/renderers.py @@ -73,7 +73,9 @@ def forward( # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614 B, D, _ = img.shape C = int(mask.max().item() + 1) - channels = _get_voxel(mask, xyzs, align_corners=align_corners).long() + channels = _get_voxel( + mask, xyzs, self.mode, align_corners=align_corners + ).long() img = ( torch.zeros(B, C, D) .to(img) @@ -203,7 +205,9 @@ def forward( else: B, D, _ = img.shape C = int(mask.max().item() + 1) - channels = _get_voxel(mask, xyzs, align_corners=align_corners).long() + channels = _get_voxel( + mask, xyzs, self.mode, align_corners=align_corners + ).long() img = ( torch.zeros(B, C, D) .to(img) diff --git a/notebooks/api/01_renderers.ipynb b/notebooks/api/01_renderers.ipynb index af20b3e47..1947303f9 100644 --- a/notebooks/api/01_renderers.ipynb +++ b/notebooks/api/01_renderers.ipynb @@ -175,7 +175,7 @@ " # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614\n", " B, D, _ = img.shape\n", " C = int(mask.max().item() + 1)\n", - " channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()\n", + " channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n", " img = (\n", " torch.zeros(B, C, D)\n", " .to(img)\n", @@ -334,7 +334,7 @@ " else:\n", " B, D, _ = img.shape\n", " C = int(mask.max().item() + 1)\n", - " channels = _get_voxel(mask, xyzs, align_corners=align_corners).long()\n", + " channels = _get_voxel(mask, xyzs, self.mode, align_corners=align_corners).long()\n", " img = (\n", " torch.zeros(B, C, D)\n", " .to(img)\n",