Skip to content

Commit

Permalink
WIP 3d atlas crop
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorTatarnikov committed May 31, 2024
1 parent 449b369 commit 55ca5fc
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 36 deletions.
45 changes: 37 additions & 8 deletions brainglobe_registration/elastix/register.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List
from pathlib import Path
from typing import List, Tuple

import itk
import numpy as np
import numpy.typing as npt
from brainglobe_atlasapi import BrainGlobeAtlas


Expand All @@ -25,24 +27,33 @@ def get_atlas_by_name(atlas_name: str) -> BrainGlobeAtlas:


def run_registration(
atlas_image,
moving_image,
annotation_image,
atlas_image: npt.NDArray,
moving_image: npt.NDArray,
annotation_image: npt.NDArray,
atlas_voxel_size: Tuple[float, ...],
moving_voxel_size: Tuple[float, ...],
parameter_lists: List[tuple[str, dict]],
) -> tuple[np.ndarray, itk.ParameterObject, np.ndarray]:
output_directory: Path,
) -> Tuple[np.ndarray, itk.ParameterObject, np.ndarray]:
"""
Run the registration process on the given images.
Parameters
----------
atlas_image : np.ndarray
atlas_image : npt.NDArray
The atlas image.
moving_image : np.ndarray
moving_image : npt.NDArray
The moving image.
annotation_image : np.ndarray
atlas_voxel_size : Tuple[float, ...]
The voxel size of the atlas image in um.
moving_voxel_size : Tuple[float, ...]
The voxel size of the moving image in um.
annotation_image : npt.NDArray
The annotation image.
parameter_lists : List[tuple[str, dict]], optional
The list of parameter lists, by default None
output_directory: Path
The output directory for the registration process.
Returns
-------
Expand All @@ -54,6 +65,13 @@ def run_registration(
# convert to ITK, view only
atlas_image = itk.GetImageViewFromArray(atlas_image).astype(itk.F)
moving_image = itk.GetImageViewFromArray(moving_image).astype(itk.F)
annotation_image = itk.GetImageViewFromArray(annotation_image).astype(
itk.F
)

atlas_image.SetSpacing(atlas_voxel_size)
annotation_image.SetSpacing(atlas_voxel_size)
moving_image.SetSpacing(moving_voxel_size)

# This syntax needed for 3D images
elastix_object = itk.ElastixRegistrationMethod.New(
Expand All @@ -63,6 +81,7 @@ def run_registration(
parameter_object = setup_parameter_object(parameter_lists=parameter_lists)

elastix_object.SetParameterObject(parameter_object)
elastix_object.SetOutputDirectory(str(output_directory))

# update filter object
elastix_object.UpdateLargestPossibleRegion()
Expand All @@ -82,6 +101,16 @@ def run_registration(
result_transform_parameters,
)

# Load Transformix Object
transformix_object = itk.TransformixFilter.New(annotation_image)
transformix_object.SetTransformParameterObject(result_transform_parameters)

# Update object (required)
transformix_object.UpdateLargestPossibleRegion()

# Results of Transformation
annotation_image_transformix = transformix_object.GetOutput()

result_transform_parameters.SetParameter(
"FinalBSplineInterpolationOrder", temp_interp_order
)
Expand Down
85 changes: 66 additions & 19 deletions brainglobe_registration/registration_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from napari.qt.threading import thread_worker
from napari.utils.notifications import show_error
from napari.viewer import Viewer
from ome_zarr.dask_utils import resize
from pytransform3d.rotations import active_matrix_from_angle
from qtpy.QtWidgets import (
QPushButton,
Expand Down Expand Up @@ -293,16 +294,21 @@ def _on_sample_dropdown_index_changed(self, index):
self._viewer, self._sample_images[index]
)
self._moving_image = self._viewer.layers[viewer_index]
self._moving_image_data_backup = self._moving_image.data.copy()
self._moving_image_data_backup = da.asarray(self._moving_image.data)

def _on_adjust_moving_image(self, x: int, y: int, rotate: float):
if len(self._moving_image.data.shape) == 3:
self.adjust_moving_image_widget.set_moving_image_to3d()
else:
self.adjust_moving_image_widget.set_moving_image_to2d()

def _on_adjust_moving_image(self, x: int, y: int, z: int, rotate: float):
if not self._moving_image:
show_error(
"No moving image selected. "
"Please select a moving image before adjusting"
)
return
adjust_napari_image_layer(self._moving_image, x, y, rotate)
adjust_napari_image_layer(self._moving_image, x, y, z, rotate)

def _on_adjust_moving_image_reset_button_click(self):
if not self._moving_image:
Expand Down Expand Up @@ -371,14 +377,35 @@ def _on_crop_atlas_z_signal(self, start: int, end: int):
)

def _on_run_button_click(self):
if not (self._atlas and self._moving_image):
show_error(
"Sample image or atlas not selected. "
"Please select a sample image and atlas before running"
)
return

if len(self._moving_image.data.shape) == 3:
reference_data = self._atlas_data_layer.data
annotation_data = self._atlas_annotations_layer.data.compute()
else:
current_atlas_slice = self._viewer.dims.current_step[0]
reference_data = self._atlas_data_layer.data[
current_atlas_slice, :, :
]
annotation_data = self._atlas_annotations_layer.data[
current_atlas_slice, :, :
]

current_atlas_slice = self._viewer.dims.current_step[0]
output_path = Path.home() / "NIU-dev" / "elastix_output"

result, parameters, registered_annotation_image = run_registration(
self._atlas_data_layer.data[current_atlas_slice, :, :],
self._moving_image.data,
self._atlas_annotations_layer.data[current_atlas_slice, :, :],
self.transform_selections,
atlas_image=reference_data,
moving_image=self._moving_image.data,
atlas_voxel_size=self._atlas.resolution,
moving_voxel_size=(25, 25, 25),
annotation_image=annotation_data,
parameter_lists=self.transform_selections,
output_directory=output_path,
)

boundaries = find_boundaries(
Expand Down Expand Up @@ -475,7 +502,7 @@ def _on_sample_popup_about_to_show(self):
self._sample_images = get_image_layer_names(self._viewer)
self.get_atlas_widget.update_sample_image_names(self._sample_images)

def _on_scale_moving_image(self, x: float, y: float):
def _on_scale_moving_image(self, x: float, y: float, z: float = 1.0):
"""
Scale the moving image to have resolution equal to the atlas.
Expand All @@ -485,11 +512,13 @@ def _on_scale_moving_image(self, x: float, y: float):
Moving image x pixel size (> 0.0).
y : float
Moving image y pixel size (> 0.0).
z : float
Moving image z pixel size (> 0.0).
Will show an error if the pixel sizes are less than or equal to 0.
Will show an error if the moving image or atlas is not selected.
"""
if x <= 0 or y <= 0:
if x <= 0 or y <= 0 or z <= 0:
show_error("Pixel sizes must be greater than 0")
return

Expand All @@ -501,18 +530,34 @@ def _on_scale_moving_image(self, x: float, y: float):
return

if self._moving_image_data_backup is None:
self._moving_image_data_backup = self._moving_image.data.copy()
self._moving_image_data_backup = da.asarray(
self._moving_image.data
)

x_factor = x / self._atlas.resolution[0]
y_factor = y / self._atlas.resolution[1]

self._moving_image.data = rescale(
self._moving_image_data_backup,
(y_factor, x_factor),
mode="constant",
preserve_range=True,
anti_aliasing=True,
)
# z_factor = z / self._atlas.resolution[2]

if len(self._moving_image.data.shape) == 3:
self._moving_image.data = resize(
self._moving_image_data_backup,
(
self._atlas.reference.shape[0],
self._atlas.reference.shape[2],
self._atlas.reference.shape[1],
),
mode="constant",
preserve_range=True,
anti_aliasing=True,
).compute()
else:
self._moving_image.data = rescale(
self._moving_image_data_backup,
(y_factor, x_factor),
mode="constant",
preserve_range=True,
anti_aliasing=True,
)

def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float):
if not (
Expand Down Expand Up @@ -580,6 +625,8 @@ def _on_adjust_atlas_rotation(self, pitch: float, yaw: float, roll: float):
worker = self.compute_atlas_rotation(self._atlas_data_layer.data)
worker.returned.connect(self.set_atlas_layer_data)
worker.start()
self._atlas_data_layer.experimental_clipping_planes = None
self._atlas_annotations_layer.experimental_clipping_planes = None

@thread_worker
def compute_atlas_rotation(self, dask_array: da.Array):
Expand Down
30 changes: 23 additions & 7 deletions brainglobe_registration/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@


def adjust_napari_image_layer(
image_layer: napari.layers.Image, x: int, y: int, rotate: float
image_layer: napari.layers.Image,
x: int,
y: int,
z: int = 0,
rotate: float = 0,
):
"""
Adjusts the napari image layer by the given x, y, and rotation values.
Expand All @@ -28,19 +32,31 @@ def adjust_napari_image_layer(
The x-coordinate for the translation.
y : int
The y-coordinate for the translation.
rotate : float
z : int, optional
The z-coordinate for the translation.
rotate : float, optional
The angle of rotation in degrees.
Returns
--------
None
"""
image_layer.translate = (y, x)
num_dimensions = len(image_layer.data.shape)
if num_dimensions == 3:
image_layer.translate = (z, y, x)
translation = np.asarray([z, y, x])
else:
image_layer.translate = (y, x)
translation = np.asarray([y, x])

rotation_matrix = np.eye(num_dimensions + 1)
rotation_matrix[:num_dimensions, :num_dimensions] = (
active_matrix_from_angle(0, np.deg2rad(rotate))
)

rotation_matrix = active_matrix_from_angle(2, np.deg2rad(rotate))
translate_matrix = np.eye(3)
origin = np.asarray(image_layer.data.shape) // 2 + np.asarray([y, x])
translate_matrix[:2, -1] = origin
translate_matrix = np.eye(num_dimensions + 1)
origin = np.asarray(image_layer.data.shape) // 2 + translation
translate_matrix[:num_dimensions, -1] = origin
transform_matrix = (
translate_matrix @ rotation_matrix @ np.linalg.inv(translate_matrix)
)
Expand Down
32 changes: 30 additions & 2 deletions brainglobe_registration/widgets/adjust_moving_image_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class AdjustMovingImageView(QWidget):
Resets the pitch, yaw, and roll to 0 and emits the atlas_reset_signal.
"""

adjust_image_signal = Signal(int, int, float)
scale_image_signal = Signal(float, float)
adjust_image_signal = Signal(int, int, int, float)
scale_image_signal = Signal(float, float, float)
atlas_rotation_signal = Signal(float, float, float)
reset_atlas_signal = Signal()
reset_image_signal = Signal()
Expand All @@ -67,9 +67,16 @@ def __init__(self, parent=None):
self.adjust_moving_image_pixel_size_x = QDoubleSpinBox(parent=self)
self.adjust_moving_image_pixel_size_x.setDecimals(2)
self.adjust_moving_image_pixel_size_x.setRange(0.01, 100.00)

self.adjust_moving_image_pixel_size_y = QDoubleSpinBox(parent=self)
self.adjust_moving_image_pixel_size_y.setDecimals(2)
self.adjust_moving_image_pixel_size_y.setRange(0.01, 100.00)

self.adjust_moving_image_pixel_size_z = QDoubleSpinBox(parent=self)
self.adjust_moving_image_pixel_size_z.setDecimals(2)
self.adjust_moving_image_pixel_size_z.setRange(0.01, 100.00)
self.adjust_moving_image_pixel_size_z.setEnabled(False)

self.scale_moving_image_button = QPushButton()
self.scale_moving_image_button.setText("Scale Image")
self.scale_moving_image_button.clicked.connect(
Expand Down Expand Up @@ -105,6 +112,11 @@ def __init__(self, parent=None):
self.adjust_moving_image_y.setRange(-offset_range, offset_range)
self.adjust_moving_image_y.valueChanged.connect(self._on_adjust_image)

self.adjust_moving_image_z = QSpinBox(parent=self)
self.adjust_moving_image_z.setRange(-offset_range, offset_range)
self.adjust_moving_image_z.valueChanged.connect(self._on_adjust_image)
self.adjust_moving_image_z.setEnabled(False)

self.adjust_moving_image_rotate = QDoubleSpinBox(parent=self)
self.adjust_moving_image_rotate.setRange(
-rotation_range, rotation_range
Expand All @@ -129,6 +141,10 @@ def __init__(self, parent=None):
"Sample image Y pixel size (\u03BCm / pixel):",
self.adjust_moving_image_pixel_size_y,
)
self.layout().addRow(
"Sample image Z pixel size (\u03BCm / pixel):",
self.adjust_moving_image_pixel_size_z,
)
self.layout().addRow(self.scale_moving_image_button)

self.layout().addRow(QLabel("Adjust the atlas pitch and yaw: "))
Expand All @@ -141,6 +157,7 @@ def __init__(self, parent=None):
self.layout().addRow(QLabel("Adjust the moving image position: "))
self.layout().addRow("X offset:", self.adjust_moving_image_x)
self.layout().addRow("Y offset:", self.adjust_moving_image_y)
self.layout().addRow("Z offset:", self.adjust_moving_image_z)
self.layout().addRow(
"Rotation (degrees):", self.adjust_moving_image_rotate
)
Expand All @@ -154,6 +171,7 @@ def _on_adjust_image(self):
self.adjust_image_signal.emit(
self.adjust_moving_image_x.value(),
self.adjust_moving_image_y.value(),
self.adjust_moving_image_z.value(),
self.adjust_moving_image_rotate.value(),
)

Expand All @@ -164,6 +182,7 @@ def _on_reset_image_button_click(self):
"""
self.adjust_moving_image_x.setValue(0)
self.adjust_moving_image_y.setValue(0)
self.adjust_moving_image_z.setValue(0)
self.adjust_moving_image_rotate.setValue(0)

self.reset_image_signal.emit()
Expand All @@ -175,6 +194,7 @@ def _on_scale_image_button_click(self):
self.scale_image_signal.emit(
self.adjust_moving_image_pixel_size_x.value(),
self.adjust_moving_image_pixel_size_y.value(),
self.adjust_moving_image_pixel_size_z.value(),
)

def _on_adjust_atlas_rotation(self):
Expand All @@ -196,3 +216,11 @@ def _on_atlas_reset(self):
self.adjust_atlas_roll.setValue(0)

self.reset_atlas_signal.emit()

def set_moving_image_to3d(self):
self.adjust_moving_image_z.setEnabled(True)
self.adjust_moving_image_pixel_size_z.setEnabled(True)

def set_moving_image_to2d(self):
self.adjust_moving_image_z.setEnabled(False)
self.adjust_moving_image_pixel_size_z.setEnabled(False)

0 comments on commit 55ca5fc

Please sign in to comment.