Skip to content

Commit 98c1109

Browse files
committed
fix: Rotation applied in surface_image_stencil() based on image orientation
1 parent 1eab935 commit 98c1109

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

src/deepali/utils/vtk/image.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
vtkImageData,
99
vtkImageStencilData,
1010
vtkImageStencilToImage,
11-
vtkMatrixToLinearTransform,
1211
vtkPolyData,
1312
vtkPolyDataToImageStencil,
13+
vtkTransform,
1414
vtkTransformPolyDataFilter,
1515
)
1616

@@ -42,29 +42,31 @@ def surface_mesh_grid(*mesh: vtkPolyData, resolution: Optional[float] = None) ->
4242

4343

4444
def surface_image_stencil(mesh: vtkPolyData, grid: Grid) -> vtkImageStencilData:
45-
r"""Convert vtkPolyData surface mesh to image stencil."""
46-
max_index = [n - 1 for n in grid.size().tolist()]
47-
48-
rot = np.eye(4, dtype=np.float)
49-
rot[:3, :3] = np.array(grid.direction).reshape(3, 3)
50-
rot = numpy_to_vtk_matrix4x4(rot)
51-
52-
transform = vtkMatrixToLinearTransform()
53-
transform.SetInput(rot)
54-
45+
r"""Convert vtkPolyData surface mesh to image stencil."""
46+
# Create the transform
47+
transform = vtkTransform()
48+
transform.Translate(grid.center().tolist())
49+
transform.Concatenate(numpy_to_vtk_matrix4x4(grid.direction().numpy().T)) # type: ignore
50+
transform.Translate(grid.center().neg().tolist())
51+
52+
# Apply the transform to the polydata
5553
transformer = vtkTransformPolyDataFilter()
5654
transformer.SetInputData(mesh)
5755
transformer.SetTransform(transform)
5856

59-
converter = vtkPolyDataToImageStencil()
60-
converter.SetInputConnection(transformer.GetOutputPort())
61-
converter.SetOutputOrigin(grid.origin().tolist())
62-
converter.SetOutputSpacing(grid.spacing().tolist())
63-
converter.SetOutputWholeExtent([0, max_index[0], 0, max_index[1], 0, max_index[2]])
64-
converter.Update()
65-
57+
# Convert the transformed polydata to an image stencil
58+
stencil_grid = Grid(size=grid.size(), spacing=grid.spacing(), center=grid.center())
59+
stencil_extent = [0, grid.size(0) - 1, 0, grid.size(1) - 1, 0, grid.size(2) - 1]
60+
polydata_to_stencil = vtkPolyDataToImageStencil()
61+
polydata_to_stencil.SetInputConnection(transformer.GetOutputPort())
62+
polydata_to_stencil.SetOutputOrigin(stencil_grid.origin().tolist())
63+
polydata_to_stencil.SetOutputSpacing(stencil_grid.spacing().tolist())
64+
polydata_to_stencil.SetOutputWholeExtent(stencil_extent)
65+
polydata_to_stencil.Update()
66+
67+
# Get the output stencil
6668
stencil = vtkImageStencilData()
67-
stencil.DeepCopy(converter.GetOutput())
69+
stencil.DeepCopy(polydata_to_stencil.GetOutput())
6870
return stencil
6971

7072

0 commit comments

Comments
 (0)