diff --git a/.gitignore b/.gitignore index 73d56d3..8e5bfd8 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ venv/ # written by setuptools_scm **/_version.py + +# Other test files +Logs/ diff --git a/README.md b/README.md index 1219a7e..f22db06 100644 --- a/README.md +++ b/README.md @@ -11,29 +11,51 @@ Registration to a BrainGlobe atlas using Elastix ---------------------------------- -This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. - - - +A [napari] plugin for registering images to a BrainGlobe atlas. + +![brainglobe-registration](./imgs/brainglobe_registration_main.png) + +## Usage + +1. Open `napari`. +2. Install the plugin with `pip install git+https://github.com/brainglobe/brainglobe-registration.git`. +3. Open the widget by selecting `Plugins > BrainGlobe Registration` in the napari menu bar near the +top left of the window. +![brainglobe-registration-plugin](./imgs/brainglobe_registration_plugin_window.png) +The `BrainGlobe Registration` plugin will appear on the right hand side of the napari window. +4. Open the image you want to register in napari (a sample 2D image can be found by selecting `File > Open Sample > Sample Brain Slice`). +5. Select the atlas you want to register to from the dropdown menu. +![brainglobe-registration-atlas-selection](./imgs/brainglobe_registration_atlas_selection.png) +The atlas will appear in the napari viewer. Select the approximate `Z` slice of the atlas that you want to register to, +using the slider at the bottom of the napari viewer. +![brainglobe-registration-atlas-selection](./imgs/brainglobe_registration_atlas_selection_2.png) +6. Adjust the sample image to roughly match the atlas image. +You can do this by adjusting X and Y translation as well as rotating around the centre of the image. +You can overlay the two images by toggling `Grid` mode in the napari viewer (Ctrl+G). 
+You can then adjust the color map and opacity of the atlas image to make manual alignment easier. +![brainglobe-registration-overlay](./imgs/brainglobe_registration_overlay.png) +The sample image can be reset to its original position and orientation by clicking `Reset Image` in the `BrainGlobe Registration` plugin window. +7. Select the transformations you want to use from the dropdown menu. Set the transformation type to empty to remove a step. +Select from one of the three provided default parameter sets (elastix, ARA, or IBL). Customise the parameters further in the +`Parameters` tab. +8. Click `Run` to register the image. The registered image will appear in the napari viewer. +![brainglobe-registration-registered](./imgs/brainglobe_registration_registered.png) +![brainglobe-registration-registered](./imgs/brainglobe_registration_registered_stacked.png) ## Installation -You can install `brainglobe-registration` via [pip]: +We strongly recommend to use a virtual environment manager (like `conda` or `venv`). The installation instructions below +will not specify the Qt backend for napari, and you will therefore need to install that separately. Please see the +[`napari` installation instructions](https://napari.org/stable/tutorials/fundamentals/installation.html) for further advice on this. - pip install brainglobe-registration +[WIP] You can install `brainglobe-registration` via [pip]: + pip install brainglobe-registration To install latest development version : pip install git+https://github.com/brainglobe/brainglobe-registration.git - ## Contributing Contributions are very welcome. Tests can be run with [tox], please ensure @@ -48,6 +70,19 @@ Distributed under the terms of the [BSD-3] license, If you encounter any problems, please [file an issue] along with a detailed description. + +## Acknowledgements + +This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. 
+ + + [napari]: https://github.com/napari/napari [Cookiecutter]: https://github.com/audreyr/cookiecutter [@napari]: https://github.com/napari diff --git a/imgs/brainglobe_registration_atlas_selection.png b/imgs/brainglobe_registration_atlas_selection.png new file mode 100644 index 0000000..f23ec30 Binary files /dev/null and b/imgs/brainglobe_registration_atlas_selection.png differ diff --git a/imgs/brainglobe_registration_atlas_selection_2.png b/imgs/brainglobe_registration_atlas_selection_2.png new file mode 100644 index 0000000..07cf11f Binary files /dev/null and b/imgs/brainglobe_registration_atlas_selection_2.png differ diff --git a/imgs/brainglobe_registration_main.png b/imgs/brainglobe_registration_main.png new file mode 100644 index 0000000..beae32c Binary files /dev/null and b/imgs/brainglobe_registration_main.png differ diff --git a/imgs/brainglobe_registration_overlay.png b/imgs/brainglobe_registration_overlay.png new file mode 100644 index 0000000..25f44fe Binary files /dev/null and b/imgs/brainglobe_registration_overlay.png differ diff --git a/imgs/brainglobe_registration_plugin_window.png b/imgs/brainglobe_registration_plugin_window.png new file mode 100644 index 0000000..4ff1f3b Binary files /dev/null and b/imgs/brainglobe_registration_plugin_window.png differ diff --git a/imgs/brainglobe_registration_registered.png b/imgs/brainglobe_registration_registered.png new file mode 100644 index 0000000..d4acc29 Binary files /dev/null and b/imgs/brainglobe_registration_registered.png differ diff --git a/imgs/brainglobe_registration_registered_stacked.png b/imgs/brainglobe_registration_registered_stacked.png new file mode 100644 index 0000000..b5bf813 Binary files /dev/null and b/imgs/brainglobe_registration_registered_stacked.png differ diff --git a/pyproject.toml b/pyproject.toml index 3b88d90..d335581 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] -write_to = 
"src/bg_elastix/_version.py" +write_to = "src/brainglobe_registration/_version.py" [tool.black] line-length = 79 @@ -12,3 +12,9 @@ line-length = 79 [tool.isort] profile = "black" line_length = 79 + +[tool.pytest.ini_options] +testpaths = "src/tests" +markers = [ + "slow: mark test as slow" +] diff --git a/setup.cfg b/setup.cfg index e7a52cb..824757a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ install_requires = qtpy itk-elastix bg-atlasapi + pytransform3d python_requires = >=3.8 include_package_data = True diff --git a/src/brainglobe_registration/_tests/test_widget.py b/src/brainglobe_registration/_tests/test_widget.py deleted file mode 100644 index 46cc872..0000000 --- a/src/brainglobe_registration/_tests/test_widget.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np - -from bg_elastix.elastix import register - -# make_napari_viewer is a pytest fixture that returns a napari viewer object -# capsys is a pytest fixture that captures stdout and stderr output streams - - -def test_example_magic_widget(make_napari_viewer, capsys): - viewer = make_napari_viewer() - layer = viewer.add_image(np.random.random((100, 100))) - - # this time, our widget will be a MagicFactory or FunctionGui instance - my_widget = register() - - # if we "call" this object, it'll execute our function - my_widget(viewer.layers[0]) - - # read captured output and check that it's as we expected - captured = capsys.readouterr() - assert captured.out == f"you have selected {layer}\n" diff --git a/src/brainglobe_registration/_widget.py b/src/brainglobe_registration/_widget.py deleted file mode 100644 index b410bcc..0000000 --- a/src/brainglobe_registration/_widget.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import TYPE_CHECKING - -from magicgui import magic_factory - -from bg_elastix.elastix.register import run_registration - -if TYPE_CHECKING: - import napari - - -@magic_factory -def register( - viewer: "napari.Viewer", - image: "napari.layers.Image", - atlas_image: 
"napari.layers.Image", - rigid=True, - affine=True, - bspline=True, - affine_iterations="2048", - log=False, -): - result, parameters = run_registration( - atlas_image.data, - image.data, - rigid=rigid, - affine=affine, - bspline=bspline, - affine_iterations=affine_iterations, - log=log, - ) - viewer.add_image(result, name="Registered Image") diff --git a/src/brainglobe_registration/elastix/register.py b/src/brainglobe_registration/elastix/register.py index a271a8e..c7b1a8c 100644 --- a/src/brainglobe_registration/elastix/register.py +++ b/src/brainglobe_registration/elastix/register.py @@ -1,33 +1,67 @@ import itk import numpy as np +from bg_atlasapi import BrainGlobeAtlas +from typing import List + + +def get_atlas_by_name(atlas_name: str) -> BrainGlobeAtlas: + """ + Get a BrainGlobeAtlas object by its name. + + Parameters + ---------- + atlas_name : str + The name of the atlas. + + Returns + ------- + BrainGlobeAtlas + The BrainGlobeAtlas object. + """ + atlas = BrainGlobeAtlas(atlas_name) + + return atlas def run_registration( - fixed_image, + atlas_image, moving_image, - rigid=True, - affine=True, - bspline=True, - affine_iterations="2048", - log=False, -): + annotation_image, + parameter_lists: List[tuple[str, dict]] = None, +) -> tuple[np.ndarray, itk.ParameterObject, np.ndarray]: + """ + Run the registration process on the given images. + + Parameters + ---------- + atlas_image : np.ndarray + The atlas image. + moving_image : np.ndarray + The moving image. + annotation_image : np.ndarray + The annotation image. + parameter_lists : List[tuple[str, dict]], optional + The list of parameter lists, by default None + + Returns + ------- + np.ndarray + The result image. + itk.ParameterObject + The result transform parameters. 
+ """ # convert to ITK, view only - fixed_image = itk.GetImageViewFromArray(fixed_image).astype(itk.F) + atlas_image = itk.GetImageViewFromArray(atlas_image).astype(itk.F) moving_image = itk.GetImageViewFromArray(moving_image).astype(itk.F) # This syntax needed for 3D images elastix_object = itk.ElastixRegistrationMethod.New( - fixed_image, moving_image + moving_image, atlas_image ) - parameter_object = setup_parameter_object( - rigid=rigid, - affine=affine, - bspline=bspline, - affine_iterations=affine_iterations, - ) + parameter_object = setup_parameter_object(parameter_lists=parameter_lists) + elastix_object.SetParameterObject(parameter_object) - elastix_object.SetLogToConsole(log) # update filter object elastix_object.UpdateLargestPossibleRegion() @@ -35,32 +69,52 @@ def run_registration( # get results result_image = elastix_object.GetOutput() result_transform_parameters = elastix_object.GetTransformParameterObject() - return np.asarray(result_image), result_transform_parameters + temp_interp_order = result_transform_parameters.GetParameter( + 0, "FinalBSplineInterpolationOrder" + ) + result_transform_parameters.SetParameter( + "FinalBSplineInterpolationOrder", "0" + ) + annotation_image_transformix = itk.transformix_filter( + annotation_image.astype(np.float32, copy=False), + result_transform_parameters, + ) + + result_transform_parameters.SetParameter( + "FinalBSplineInterpolationOrder", temp_interp_order + ) + + return ( + np.asarray(result_image), + result_transform_parameters, + np.asarray(annotation_image_transformix), + ) -def setup_parameter_object( - rigid=True, - affine=True, - bspline=True, - affine_iterations="2048", -): + +def setup_parameter_object(parameter_lists: List[tuple[str, dict]] = None): + """ + Set up the parameter object for the registration process. 
+ + Parameters + ---------- + parameter_lists : List[tuple[str, dict]], optional + The list of parameter lists, by default None + + Returns + ------- + itk.ParameterObject + The parameter object.# + """ parameter_object = itk.ParameterObject.New() - if rigid: - parameter_map_rigid = parameter_object.GetDefaultParameterMap("rigid") - parameter_object.AddParameterMap(parameter_map_rigid) - - if affine: - parameter_map_affine = parameter_object.GetDefaultParameterMap( - "affine" - ) - parameter_map_affine["MaximumNumberOfIterations"] = [affine_iterations] - parameter_object.AddParameterMap(parameter_map_affine) - - if bspline: - parameter_map_bspline = parameter_object.GetDefaultParameterMap( - "bspline" - ) - parameter_object.AddParameterMap(parameter_map_bspline) + for transform_type, parameter_dict in parameter_lists: + parameter_map = parameter_object.GetDefaultParameterMap(transform_type) + parameter_map.clear() + + for k, v in parameter_dict.items(): + parameter_map[k] = v + + parameter_object.AddParameterMap(parameter_map) return parameter_object diff --git a/src/brainglobe_registration/napari.yaml b/src/brainglobe_registration/napari.yaml index e2cd2d1..a5c3aaa 100644 --- a/src/brainglobe_registration/napari.yaml +++ b/src/brainglobe_registration/napari.yaml @@ -1,10 +1,14 @@ name: brainglobe-registration -display_name: BrainGlobe Elastix Registration +display_name: BrainGlobe Registration contributions: commands: - - id: brainglobe-registration.register - python_name: brainglobe_registration._widget:register - title: BrainGlobe Elastix Registration + - id: brainglobe-registration.make_registration_widget + python_name: brainglobe_registration.registration_widget:RegistrationWidget + title: BrainGlobe Registration + sample_data: + - key: example + display_name: Sample Brain Slice + uri: src/brainglobe_registration/resources/sample_hipp.tif widgets: - - command: brainglobe-registration.register - display_name: BrainGlobe Elastix Registration + - command: 
brainglobe-registration.make_registration_widget + display_name: BrainGlobe Registration diff --git a/src/brainglobe_registration/parameters/ara_tools/affine.txt b/src/brainglobe_registration/parameters/ara_tools/affine.txt new file mode 100644 index 0000000..cabe450 --- /dev/null +++ b/src/brainglobe_registration/parameters/ara_tools/affine.txt @@ -0,0 +1,66 @@ +//Affine Transformation - updated May 2012 + +// Description: affine, MI, ASGD + +//ImageTypes +(FixedInternalImagePixelType "float") +(FixedImageDimension 3) +(MovingInternalImagePixelType "float") +(MovingImageDimension 3) + +//Components +(Registration "MultiResolutionRegistration") +(FixedImagePyramid "FixedSmoothingImagePyramid") +(MovingImagePyramid "MovingSmoothingImagePyramid") +(Interpolator "BSplineInterpolator") +(Metric "AdvancedMattesMutualInformation") +(Optimizer "AdaptiveStochasticGradientDescent") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") +(Transform "AffineTransform") + +(ErodeMask "false" ) + +(NumberOfResolutions 4) + +(HowToCombineTransforms "Compose") +(AutomaticTransformInitialization "true") +(AutomaticScalesEstimation "true") + +(WriteTransformParametersEachIteration "false") +(WriteResultImage "true") +(ResultImageFormat "tiff") +(CompressResultImage "false") +(WriteResultImageAfterEachResolution "false") +(ShowExactMetricValue "false") + +//Maximum number of iterations in each resolution level: +(MaximumNumberOfIterations 500 ) + +//Number of grey level bins in each resolution level: +(NumberOfHistogramBins 32 ) +(FixedLimitRangeRatio 0.0) +(MovingLimitRangeRatio 0.0) +(FixedKernelBSplineOrder 3) +(MovingKernelBSplineOrder 3) + +//Number of spatial samples used to compute the mutual information in each resolution level: +(ImageSampler "RandomCoordinate") +(FixedImageBSplineInterpolationOrder 3) +(UseRandomSampleRegion "false") +(NumberOfSpatialSamples 4000 ) +(NewSamplesEveryIteration "true") +(CheckNumberOfSamples "true") 
+(MaximumNumberOfSamplingAttempts 10) + +//Order of B-Spline interpolation used in each resolution level: +(BSplineInterpolationOrder 3) + +//Order of B-Spline interpolation used for applying the final deformation: +(FinalBSplineInterpolationOrder 3) + +//Default pixel value for pixels that come from outside the picture: +(DefaultPixelValue 0) + +//SP: Param_A in each resolution level. a_k = a/(A+k+1)^alpha +(SP_A 20.0 ) diff --git a/src/brainglobe_registration/parameters/ara_tools/bspline.txt b/src/brainglobe_registration/parameters/ara_tools/bspline.txt new file mode 100644 index 0000000..e87a225 --- /dev/null +++ b/src/brainglobe_registration/parameters/ara_tools/bspline.txt @@ -0,0 +1,74 @@ +//Bspline Transformation - updated May 2012 + +//ImageTypes +(FixedInternalImagePixelType "float") +(FixedImageDimension 3) +(MovingInternalImagePixelType "float") +(MovingImageDimension 3) + +//Components +(Registration "MultiResolutionRegistration") +(FixedImagePyramid "FixedSmoothingImagePyramid") +(MovingImagePyramid "MovingSmoothingImagePyramid") +(Interpolator "BSplineInterpolator") +(Metric "AdvancedMattesMutualInformation") +(Optimizer "StandardGradientDescent") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") +(Transform "BSplineTransform") + +(ErodeMask "false" ) + +(NumberOfResolutions 3) +(FinalGridSpacingInVoxels 25.000000 25.000000 25.000000) + +(HowToCombineTransforms "Compose") + +(WriteTransformParametersEachIteration "false") +(ResultImageFormat "tiff") +(WriteResultImage "true") +(CompressResultImage "false") +(WriteResultImageAfterEachResolution "false") +(ShowExactMetricValue "false") +(WriteDiffusionFiles "true") + +// Option supported in elastix 4.1: +(UseFastAndLowMemoryVersion "true") + +//Maximum number of iterations in each resolution level: +(MaximumNumberOfIterations 5000 ) + +//Number of grey level bins in each resolution level: +(NumberOfHistogramBins 32 ) +(FixedLimitRangeRatio 0.0) +(MovingLimitRangeRatio 
0.0) +(FixedKernelBSplineOrder 3) +(MovingKernelBSplineOrder 3) + +//Number of spatial samples used to compute the mutual information in each resolution level: +(ImageSampler "RandomCoordinate") +(FixedImageBSplineInterpolationOrder 1 ) +(UseRandomSampleRegion "true") +(SampleRegionSize 50.0 50.0 50.0) +(NumberOfSpatialSamples 10000 ) +(NewSamplesEveryIteration "true") +(CheckNumberOfSamples "true") +(MaximumNumberOfSamplingAttempts 10) + +//Order of B-Spline interpolation used in each resolution level: +(BSplineInterpolationOrder 3) + +//Order of B-Spline interpolation used for applying the final deformation: +(FinalBSplineInterpolationOrder 3) + +//Default pixel value for pixels that come from outside the picture: +(DefaultPixelValue 0) + +//SP: Param_a in each resolution level. a_k = a/(A+k+1)^alpha +(SP_a 10000.0 ) + +//SP: Param_A in each resolution level. a_k = a/(A+k+1)^alpha +(SP_A 100.0 ) + +//SP: Param_alpha in each resolution level. a_k = a/(A+k+1)^alpha +(SP_alpha 0.6 ) diff --git a/src/brainglobe_registration/parameters/brainregister_IBL/affine.txt b/src/brainglobe_registration/parameters/brainregister_IBL/affine.txt new file mode 100644 index 0000000..1e81449 --- /dev/null +++ b/src/brainglobe_registration/parameters/brainregister_IBL/affine.txt @@ -0,0 +1,237 @@ +// ********** Affine Transformation ********** +// -------------------------------------------------------------------------------- +// Optimised Affine transform for Mouse Brain serial section 2-photon datasets +// +// Steven J. 
West, SWC, UCL, UK 2020 + + +// ********** ImageTypes ********** +// -------------------------------------------------------------------------------- + +(FixedInternalImagePixelType "float") // automatically converted to this type +(FixedImageDimension 3) + +(MovingInternalImagePixelType "float") // automatically converted to this type +(MovingImageDimension 3) + +(UseDirectionCosines "true") +// Setting it to false means that you choose to ignore important information +// from the image, which relates voxel coordinates to world coordinates +// Ignoring it may easily lead to left/right swaps for example, which could +// screw up a (medical) analysis + + +// ********** Registration ********** +// -------------------------------------------------------------------------------- + +(Registration "MultiResolutionRegistration") +// the default + + +// ********** Pyramid ********** +// -------------------------------------------------------------------------------- + +(FixedImagePyramid "FixedSmoothingImagePyramid") +// Applies gaussian smoothing and no down-sampling + +(MovingImagePyramid "MovingSmoothingImagePyramid") +// Applies gaussian smoothing and no down-sampling + +(NumberOfResolutions 4) +// 4 levels + +(ImagePyramidSchedule 8 8 8 4 4 4 2 2 2 1 1 1) +// sigma: 8/2 XYZ, 4/2 XYZ, 2/2 XYZ, 1/2 XYZ + +(ErodeMask "false" ) +// information from inside any mask will flow into the ROI due to the +// smoothing step + + +// ********** Metric ********** +// -------------------------------------------------------------------------------- + +(Metric "AdvancedMattesMutualInformation") +// Cost Function Metric +// quantifies the "amount of information" (in units of shannons, commonly called +// bits) obtained about one random variable through observing the other random +// variable +// only a relation between the probability distributions of the intensities of +// the fixed and moving image is assumed +// often a good choice for image registration + +(NumberOfHistogramBins 32 
) +(NumberOfFixedHistogramBins 32 ) +(NumberOfMovingHistogramBins 32 ) +// The size of the histogram. Must be given for each resolution, or for all +// resolutions at once + +(FixedKernelBSplineOrder 3) +(MovingKernelBSplineOrder 3) +// The B-spline order of the Parzen window, used to estimate the joint histogram + +(FixedLimitRangeRatio 0.0) +(MovingLimitRangeRatio 0.0) +// The relative extension of the intensity range of the fixed image. +// 0.0 - turned off + + +(ShowExactMetricValue "false" "false" "false" "false") +// get exact metric on final resolution +// computes the exact metric value (computed on all voxels rather than on the +// set of spatial samples) and shows it each iteration +// Must be given for each resolution +// This is very SLOW for large images + + +(UseMultiThreadingForMetrics "true") +// Whether to compute metric with multi-threading + +(UseFastAndLowMemoryVersion "true") +// select between two methods for computing mutual information metric +// false : computes the derivatives of the joint histogram to each transformation +// parameter +// true : computes the mutual information via another route + +(UseJacobianPreconditioning "false") +// whether to implement the preconditioning technique by Nicholas Tustison: +// "Directly Manipulated Freeform Deformations" + +(FiniteDifferenceDerivative "false") +// Experimental feature, do not use. 
+ +(ASGDParameterEstimationMethod "Original") +// ASGD parameter estimation method used in this optimizer + + +// ********** ImageSampler ********** +// -------------------------------------------------------------------------------- + +(ImageSampler "RandomCoordinate") + +(NumberOfSpatialSamples 4000 ) +// Number of spatial samples used to compute +// the mutual information in each resolution level + +(NewSamplesEveryIteration "true" "true" "true" "true") +// whether to select a new set of spatial samples in every iteration + +(UseRandomSampleRegion "false") +// whether to randomly select a subregion of the image in each iteration + +(CheckNumberOfSamples "true") +// whether to check if at least a certain fraction (default 1/4) of the samples map +// inside the moving image. + +(MaximumNumberOfSamplingAttempts 10 10 10 10) +// maximum number of sampling attempts + + +// ********** Interpolator and Resampler ********** +// -------------------------------------------------------------------------------- + +(Interpolator "BSplineInterpolator") +// The interpolator to use during registration process +// BSpline : Evaluates the Values of non-voxel Positions in the Moving Image +// Basis Function for Splines - set of Piecewise Polynomial Lines + +(BSplineInterpolationOrder 3) +// Order of B-Spline interpolation used in each resolution level +// 0 Nearest Neighbour, 1 Linear interpolation, +// 2+ non-linear curves with increasing degrees of freedom/power + + +// Order of B-Spline interpolation used when interpolating the fixed + // image - if using MultiInputRandomCoordinate sampler +(FixedImageBSplineInterpolationOrder 3) + +//Default pixel value for pixels that come from outside the picture: +(DefaultPixelValue 0) + +(Resampler "DefaultResampler") +// Either DefaultResampler or OpenCLResampler + +(ResampleInterpolator "FinalBSplineInterpolator") +// The interpolator to use to generate the resulting deformed moving image +// BSpline : Evaluates the Values of non-voxel 
Positions in the Moving Image +// Basis Function for Splines - set of Piecewise Polynomial Lines + +(FinalBSplineInterpolationOrder 3) +// Order of B-Spline interpolation used for applying the final deformation +// 0 Nearest Neighbour, 1 Linear interpolation, +// 2+ non-linear curves with increasing degrees of freedom/power + + +// ********** Transform ********** +// -------------------------------------------------------------------------------- + +(Transform "AffineTransform") +// translate, rotate, scale, shear + +(AutomaticScalesEstimation "true") +// if "true" the Scales parameter is ignored and the scales are determined +// automatically. + +(AutomaticTransformInitialization "true") +// whether the initial translation between images should be estimated as the +// distance between their centers. + +(AutomaticTransformInitializationMethod "GeometricalCenter") +// how to initialize this transform + +(HowToCombineTransforms "Compose") +// Always use Compose for combining transforms + + +// ********** Optimizer ********** +// -------------------------------------------------------------------------------- + +(Optimizer "AdaptiveStochasticGradientDescent") +// take the search direction as the negative gradient of the cost function +// Adaptive version: requires less parameters to be set and tends to be +// more robust. + +(MaximumNumberOfIterations 500 500 500 500) +// Maximum number of iterations in each resolution level + +(SP_A 20.0) +// SP: Param_A in each resolution level. 
a_k = a/(A+k+1)^alpha + +(SigmoidInitialTime 0.0) +// initial time input for the sigmoid +// When increased, the optimization starts with smaller steps +// If set to 0.0, the method starts with with the largest step allowed + +(MaxBandCovSize 192) +(NumberOfBandStructureSamples 10) +(UseAdaptiveStepSizes "true") +(AutomaticParameterEstimation "true") +(UseConstantStep "false") +(MaximumStepLengthRatio 1) +(NumberOfGradientMeasurements 0) +(NumberOfJacobianMeasurements 1000) +(NumberOfSamplesForExactGradient 100000) +(SigmoidScaleFactor 0.1) + + +// ********** Output ********** +// -------------------------------------------------------------------------------- + +(WriteResultImage "true") +// Whether to write the final deformed image when elastix has optimised the +// transformation. + +(ResultImageFormat "tiff") // commented out as not writing any images +// What image format to write the image as +// can use: "tiff" "dcm" "nrrd" "png" + +// (CompressResultImage "false") +// whether lossless compression of the written image is performed + + +(WriteTransformParametersEachIteration "false") +// whether to save a transform parameter file to disk in every iteration + +(WriteResultImageAfterEachResolution "false" "false" "false" "false") +// whether the intermediate result image is resampled and written after +// each resolution diff --git a/src/brainglobe_registration/parameters/brainregister_IBL/bspline.txt b/src/brainglobe_registration/parameters/brainregister_IBL/bspline.txt new file mode 100644 index 0000000..578e24d --- /dev/null +++ b/src/brainglobe_registration/parameters/brainregister_IBL/bspline.txt @@ -0,0 +1,223 @@ +// ********** B-Spline Transformation ********** +// -------------------------------------------------------------------------------- +// Optimised B-Spline transform for Mouse Brain serial section 2-photon datasets +// +// Steven J. 
West, SWC, UCL, UK 2020 + + +// ********** ImageTypes ********** +// -------------------------------------------------------------------------------- + +(FixedInternalImagePixelType "float") // automatically converted to this type +(FixedImageDimension 3) + +(MovingInternalImagePixelType "float") // automatically converted to this type +(MovingImageDimension 3) + +(UseDirectionCosines "true") +// Setting it to false means that you choose to ignore important information +// from the image, which relates voxel coordinates to world coordinates +// Ignoring it may easily lead to left/right swaps for example, which could +// screw up a (medical) analysis + + +// ********** Registration ********** +// -------------------------------------------------------------------------------- + +(Registration "MultiResolutionRegistration") +// the default + + +// ********** Pyramid ********** +// -------------------------------------------------------------------------------- + +(FixedImagePyramid "FixedSmoothingImagePyramid") +// Applies gaussian smoothing and no down-sampling + +(MovingImagePyramid "MovingSmoothingImagePyramid") +// Applies gaussian smoothing and no down-sampling + +(NumberOfResolutions 3) +// 3 levels + +(ImagePyramidSchedule 4 4 4 2 2 2 1 1 1) +// sigma: 4/2 XYZ, 2/2 XYZ, 1/2 XYZ + +(ErodeMask "false" ) +// information from inside any mask will flow into the ROI due to the +// smoothing step + + +// ********** Metric ********** +// -------------------------------------------------------------------------------- + +(Metric "AdvancedMattesMutualInformation") +// Cost Function Metric +// quantifies the "amount of information" (in units of shannons, commonly called +// bits) obtained about one random variable through observing the other random +// variable +// only a relation between the probability distributions of the intensities of +// the fixed and moving image is assumed +// often a good choice for image registration + +(NumberOfFixedHistogramBins 32 ) 
+(NumberOfMovingHistogramBins 32 ) +// The size of the histogram. Must be given for each resolution, or for all +// resolutions at once + + +(FixedKernelBSplineOrder 3) +(MovingKernelBSplineOrder 3) +// The B-spline order of the Parzen window, used to estimate the joint histogram + +(FixedLimitRangeRatio 0.0) +(MovingLimitRangeRatio 0.0) +// The relative extension of the intensity range of the fixed image. +// 0.0 - turned off + +(ShowExactMetricValue "false" "false" "false") +// get exact metric on final resolution +// computes the exact metric value (computed on all voxels rather than on the +// set of spatial samples) and shows it each iteration +// Must be given for each resolution +// This is very SLOW for large images + +(UseMultiThreadingForMetrics "true") +// Whether to compute metric with multi-threading + +(UseFastAndLowMemoryVersion "true") +// select between two methods for computing mutual information metric +// false : computes the derivatives of the joint histogram to each transformation +// parameter +// true : computes the mutual information via another route + +(UseJacobianPreconditioning "false") +// whether to implement the preconditioning technique by Nicholas Tustison: +// "Directly Manipulated Freeform Deformations" + +(FiniteDifferenceDerivative "false") +// Experimental feature, do not use. 
+ + +// ********** ImageSampler ********** +// -------------------------------------------------------------------------------- + +(ImageSampler "RandomCoordinate") + +(NumberOfSpatialSamples 10000 ) +// Number of spatial samples used to compute +// the mutual information in each resolution level + +(NewSamplesEveryIteration "true" "true" "true" "true") +// whether to select a new set of spatial samples in every iteration + +(UseRandomSampleRegion "false") +// whether to randomly select a subregion of the image in each iteration + +(CheckNumberOfSamples "true") +// whether to check if at least a certain fraction (default 1/4) of the samples map +// inside the moving image. + +(MaximumNumberOfSamplingAttempts 10 10 10) +// maximum number of sampling attempts + + +// ********** Interpolator and Resampler ********** +// -------------------------------------------------------------------------------- + +(Interpolator "BSplineInterpolator") +// The interpolator to use during registration process +// BSpline : Evaluates the Values of non-voxel Positions in the Moving Image +// Basis Function for Splines - set of Piecewise Polynomial Lines + +(BSplineInterpolationOrder 3) +// Order of B-Spline interpolation used in each resolution level +// 0 Nearest Neighbour, 1 Linear interpolation, +// 2+ non-linear curves with increasing degrees of freedom/power + + +// Order of B-Spline interpolation used when interpolating the fixed + // image - if using MultiInputRandomCoordinate sampler +// (FixedImageBSplineInterpolationOrder 3) + +//Default pixel value for pixels that come from outside the picture: +(DefaultPixelValue 0) + +(Resampler "DefaultResampler") +// Either DefaultResampler or OpenCLResampler + +(ResampleInterpolator "FinalBSplineInterpolator") +// The interpolator to use to generate the resulting deformed moving image +// BSpline : Evaluates the Values of non-voxel Positions in the Moving Image +// Basis Function for Splines - set of Piecewise Polynomial Lines + 
+(FinalBSplineInterpolationOrder 3) +// Order of B-Spline interpolation used for applying the final deformation +// 0 Nearest Neighbour, 1 Linear interpolation, +// 2+ non-linear curves with increasing degrees of freedom/power + + +// ********** Transform ********** +// -------------------------------------------------------------------------------- + +(Transform "BSplineTransform") +// Set of control points are defined on a regular grid, overlayed on the +// fixed image +// Control Point Grid is spaced according to n-dimensional vector +// Number of Control Points in each dimension is approx. the image length/spacing, +// plus extra points at each end +// Pixels are LOCALLY Transformed by the B-splines at surrounding control points +// This models local transformations, and is fast to compute +// B-spline coefficients pk are the B-Spline PARAMETERS +// Number of coefficients is the number of control points x number of dimensions +// Coefficients are ordered by coefficient index first +// (p1x, p2x..., p1y, p2y..., p1z, p2z) + +(FinalGridSpacingInVoxels 25.000000 25.000000 25.000000) +// grid spacing of the B-spline transform for each dimension +// spacing is in "voxel size units" + +(HowToCombineTransforms "Compose") +// Always use Compose for combining transforms + + +// ********** Optimizer ********** +// -------------------------------------------------------------------------------- + +(Optimizer "StandardGradientDescent") +// take the search direction as the negative gradient of the cost function + +(MaximumNumberOfIterations 5000 5000 5000) +// Maximum number of iterations in each resolution level + +(SP_a 10000.0 ) +// Param_a in each resolution level. a_k = a/(A+k+1)^alpha + +(SP_A 100.0 ) +// Param_A in each resolution level. a_k = a/(A+k+1)^alpha + +(SP_alpha 0.6 ) +// Param_alpha in each resolution level. 
a_k = a/(A+k+1)^alpha + + +// ********** Output ********** +// -------------------------------------------------------------------------------- + +(WriteResultImage "true") +// Whether to write the final deformed image when elastix has optimised the +// transformation. + +(ResultImageFormat "tiff") // image format used when WriteResultImage is "true" +// What image format to write the image as +// can use: "tiff" "dcm" "nrrd" "png" + +// (CompressResultImage "false") +// whether lossless compression of the written image is performed + + +(WriteTransformParametersEachIteration "false") +// whether to save a transform parameter file to disk in every iteration + +(WriteResultImageAfterEachResolution "false" "false" "false" "false") +// whether the intermediate result image is resampled and written after +// each resolution diff --git a/src/brainglobe_registration/parameters/elastix_default/affine.txt b/src/brainglobe_registration/parameters/elastix_default/affine.txt new file mode 100644 index 0000000..77e2e87 --- /dev/null +++ b/src/brainglobe_registration/parameters/elastix_default/affine.txt @@ -0,0 +1,24 @@ +(AutomaticParameterEstimation "true") +(AutomaticScalesEstimation "true") +(CheckNumberOfSamples "true") +(DefaultPixelValue 0) +(FinalBSplineInterpolationOrder 3) +(FixedImagePyramid "FixedSmoothingImagePyramid") +(ImageSampler "RandomCoordinate") +(Interpolator "LinearInterpolator") +(MaximumNumberOfIterations 256) +(MaximumNumberOfSamplingAttempts 8) +(Metric "AdvancedMattesMutualInformation") +(MovingImagePyramid "MovingSmoothingImagePyramid") +(NewSamplesEveryIteration "true") +(NumberOfResolutions 4) +(NumberOfSamplesForExactGradient 4096) +(NumberOfSpatialSamples 2048) +(Optimizer "AdaptiveStochasticGradientDescent") +(Registration "MultiResolutionRegistration") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") +(ResultImageFormat "nii") +(Transform "AffineTransform") +(WriteIterationInfo "false") +(WriteResultImage "true") diff 
--git a/src/brainglobe_registration/parameters/elastix_default/bspline.txt b/src/brainglobe_registration/parameters/elastix_default/bspline.txt new file mode 100644 index 0000000..c14fa00 --- /dev/null +++ b/src/brainglobe_registration/parameters/elastix_default/bspline.txt @@ -0,0 +1,27 @@ +(AutomaticParameterEstimation "true") +(CheckNumberOfSamples "true") +(DefaultPixelValue 0) +(FinalBSplineInterpolationOrder 3) +(FinalGridSpacingInPhysicalUnits 10.000000) +(FixedImagePyramid "FixedSmoothingImagePyramid") +(GridSpacingSchedule 2.803221 1.988100 1.410000 1.000000) +(ImageSampler "RandomCoordinate") +(Interpolator "LinearInterpolator") +(MaximumNumberOfIterations 256) +(MaximumNumberOfSamplingAttempts 8) +(Metric "AdvancedMattesMutualInformation" "TransformBendingEnergyPenalty") +(Metric0Weight 1.0) +(Metric1Weight 1.0) +(MovingImagePyramid "MovingSmoothingImagePyramid") +(NewSamplesEveryIteration "true") +(NumberOfResolutions 4) +(NumberOfSamplesForExactGradient 4096) +(NumberOfSpatialSamples 2048) +(Optimizer "AdaptiveStochasticGradientDescent") +(Registration "MultiMetricMultiResolutionRegistration") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") +(ResultImageFormat "nii") +(Transform "BSplineTransform") +(WriteIterationInfo "false") +(WriteResultImage "true") diff --git a/src/brainglobe_registration/parameters/elastix_default/rigid.txt b/src/brainglobe_registration/parameters/elastix_default/rigid.txt new file mode 100644 index 0000000..80ed2c9 --- /dev/null +++ b/src/brainglobe_registration/parameters/elastix_default/rigid.txt @@ -0,0 +1,24 @@ +(AutomaticParameterEstimation "true") +(AutomaticScalesEstimation "true") +(CheckNumberOfSamples "true") +(DefaultPixelValue 0) +(FinalBSplineInterpolationOrder 3) +(FixedImagePyramid "FixedSmoothingImagePyramid") +(ImageSampler "RandomCoordinate") +(Interpolator "LinearInterpolator") +(MaximumNumberOfIterations 256) +(MaximumNumberOfSamplingAttempts 8) +(Metric 
"AdvancedMattesMutualInformation") +(MovingImagePyramid "MovingSmoothingImagePyramid") +(NewSamplesEveryIteration "true") +(NumberOfResolutions 4) +(NumberOfSamplesForExactGradient 4096) +(NumberOfSpatialSamples 2048) +(Optimizer "AdaptiveStochasticGradientDescent") +(Registration "MultiResolutionRegistration") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") +(ResultImageFormat "nii") +(Transform "EulerTransform") +(WriteIterationInfo "false") +(WriteResultImage "true") diff --git a/src/brainglobe_registration/registration_widget.py b/src/brainglobe_registration/registration_widget.py new file mode 100644 index 0000000..793c723 --- /dev/null +++ b/src/brainglobe_registration/registration_widget.py @@ -0,0 +1,307 @@ +""" +A napari widget to view atlases. + +Atlases that are exposed by the Brainglobe atlas API are +shown in a table view using the Qt model/view framework +[Qt Model/View framework](https://doc.qt.io/qt-6/model-view-programming.html) + +Users can download and add the atlas images/structures as layers to the viewer. 
+""" + +from pathlib import Path + +import numpy as np + +from brainglobe_registration.elastix.register import run_registration +from brainglobe_registration.widgets.select_images_view import SelectImagesView +from brainglobe_registration.widgets.adjust_moving_image_view import ( + AdjustMovingImageView, +) +from brainglobe_registration.widgets.parameter_list_view import ( + RegistrationParameterListView, +) +from brainglobe_registration.widgets.transform_select_view import ( + TransformSelectView, +) +from brainglobe_registration.utils.utils import ( + adjust_napari_image_layer, + open_parameter_file, + find_layer_index, + get_image_layer_names, +) + +from bg_atlasapi import BrainGlobeAtlas +from bg_atlasapi.list_atlases import get_downloaded_atlases +from napari.viewer import Viewer +from qtpy.QtCore import Qt +from qtpy.QtWidgets import ( + QGroupBox, + QVBoxLayout, + QWidget, + QTabWidget, + QPushButton, +) +from skimage.segmentation import find_boundaries + +from brainglobe_registration.utils.brainglobe_logo import header_widget + + +class RegistrationWidget(QWidget): + def __init__(self, napari_viewer: Viewer): + super().__init__() + + self._viewer = napari_viewer + self._atlas: BrainGlobeAtlas = None + self._moving_image = None + + self.transform_params = {"rigid": {}, "affine": {}, "bspline": {}} + self.transform_selections = [] + + for transform_type in self.transform_params: + file_path = ( + Path(__file__).parent.resolve() + / "parameters" + / "elastix_default" + / f"{transform_type}.txt" + ) + + if file_path.exists(): + self.transform_params[transform_type] = open_parameter_file( + file_path + ) + self.transform_selections.append( + (transform_type, self.transform_params[transform_type]) + ) + + # Hacky way of having an empty first option for the dropdown + self._available_atlases = ["------"] + get_downloaded_atlases() + self._sample_images = get_image_layer_names(self._viewer) + + if len(self._sample_images) > 0: + self._moving_image = 
self._viewer.layers[0] + else: + self._moving_image = None + + self.setLayout(QVBoxLayout()) + self.layout().addWidget(header_widget()) + + self.main_tabs = QTabWidget(parent=self) + self.main_tabs.setTabPosition(QTabWidget.West) + + self.settings_tab = QGroupBox() + self.settings_tab.setLayout(QVBoxLayout()) + self.parameters_tab = QTabWidget() + + self.get_atlas_widget = SelectImagesView( + available_atlases=self._available_atlases, + sample_image_names=self._sample_images, + parent=self, + ) + self.get_atlas_widget.atlas_index_change.connect( + self._on_atlas_dropdown_index_changed + ) + self.get_atlas_widget.moving_image_index_change.connect( + self._on_sample_dropdown_index_changed + ) + self.get_atlas_widget.sample_image_popup_about_to_show.connect( + self._on_sample_popup_about_to_show + ) + + self.adjust_moving_image_widget = AdjustMovingImageView(parent=self) + self.adjust_moving_image_widget.adjust_image_signal.connect( + self._on_adjust_moving_image + ) + + self.adjust_moving_image_widget.reset_image_signal.connect( + self._on_adjust_moving_image_reset_button_click + ) + + self.transform_select_view = TransformSelectView() + self.transform_select_view.transform_type_added_signal.connect( + self._on_transform_type_added + ) + self.transform_select_view.transform_type_removed_signal.connect( + self._on_transform_type_removed + ) + self.transform_select_view.file_option_changed_signal.connect( + self._on_default_file_selection_change + ) + + self.run_button = QPushButton("Run") + self.run_button.clicked.connect(self._on_run_button_click) + self.run_button.setEnabled(False) + + self.settings_tab.layout().addWidget(self.get_atlas_widget) + self.settings_tab.layout().addWidget(self.adjust_moving_image_widget) + self.settings_tab.layout().addWidget(self.transform_select_view) + self.settings_tab.layout().addWidget(self.run_button) + self.settings_tab.layout().setAlignment(Qt.AlignTop) + + self.parameter_setting_tabs_lists = [] + + for transform_type in 
self.transform_params: + new_tab = RegistrationParameterListView( + param_dict=self.transform_params[transform_type], + transform_type=transform_type, + ) + + self.parameters_tab.addTab(new_tab, transform_type) + self.parameter_setting_tabs_lists.append(new_tab) + + self.main_tabs.addTab(self.settings_tab, "Settings") + self.main_tabs.addTab(self.parameters_tab, "Parameters") + + self.layout().addWidget(self.main_tabs) + + def _on_atlas_dropdown_index_changed(self, index): + # Hacky way of having an empty first dropdown + if index == 0: + if self._atlas: + curr_atlas_layer_index = find_layer_index( + self._viewer, self._atlas.atlas_name + ) + + self._viewer.layers.pop(curr_atlas_layer_index) + self._atlas = None + self.run_button.setEnabled(False) + self._viewer.grid.enabled = False + + return + + atlas_name = self._available_atlases[index] + atlas = BrainGlobeAtlas(atlas_name) + + if self._atlas: + curr_atlas_layer_index = find_layer_index( + self._viewer, self._atlas.atlas_name + ) + + self._viewer.layers.pop(curr_atlas_layer_index) + else: + self.run_button.setEnabled(True) + + self._viewer.add_image( + atlas.reference, + name=atlas_name, + colormap="gray", + blending="translucent", + ) + + self._atlas = BrainGlobeAtlas(atlas_name=atlas_name) + self._viewer.grid.enabled = True + + def _on_sample_dropdown_index_changed(self, index): + viewer_index = find_layer_index( + self._viewer, self._sample_images[index] + ) + self._moving_image = self._viewer.layers[viewer_index] + + def _on_adjust_moving_image(self, x: int, y: int, rotate: float): + adjust_napari_image_layer(self._moving_image, x, y, rotate) + + def _on_adjust_moving_image_reset_button_click(self): + adjust_napari_image_layer(self._moving_image, 0, 0, 0) + + def _on_run_button_click(self): + current_atlas_slice = self._viewer.dims.current_step[0] + + result, parameters, registered_annotation_image = run_registration( + self._atlas.reference[current_atlas_slice, :, :], + self._moving_image.data, + 
self._atlas.annotation[current_atlas_slice, :, :], + self.transform_selections, + ) + + boundaries = find_boundaries( + registered_annotation_image, mode="inner" + ).astype(np.int8, copy=False) + + self._viewer.add_image(result, name="Registered Image", visible=False) + + atlas_layer_index = find_layer_index( + self._viewer, self._atlas.atlas_name + ) + self._viewer.layers[atlas_layer_index].visible = False + + self._viewer.add_labels( + registered_annotation_image.astype(np.uint32, copy=False), + name="Registered Annotations", + visible=False, + ) + self._viewer.add_image( + boundaries, + name="Registered Boundaries", + visible=True, + blending="additive", + opacity=0.8, + ) + + self._viewer.grid.enabled = False + + def _on_transform_type_added( + self, transform_type: str, transform_order: int + ) -> None: + if transform_order > len(self.transform_selections): + raise IndexError( + f"Transform added out of order index: {transform_order}" + f" is greater than length: {len(self.transform_selections)}" + ) + elif len(self.parameter_setting_tabs_lists) == transform_order: + self.transform_selections.append( + (transform_type, self.transform_params[transform_type].copy()) + ) + new_tab = RegistrationParameterListView( + param_dict=self.transform_selections[transform_order][1], + transform_type=transform_type, + ) + self.parameters_tab.addTab(new_tab, transform_type) + self.parameter_setting_tabs_lists.append(new_tab) + + else: + self.transform_selections[transform_order] = ( + transform_type, + self.transform_params[transform_type], + ) + self.parameters_tab.setTabText(transform_order, transform_type) + self.parameter_setting_tabs_lists[transform_order].set_data( + self.transform_params[transform_type].copy() + ) + + def _on_transform_type_removed(self, transform_order: int) -> None: + if transform_order >= len(self.transform_selections): + raise IndexError("Transform removed out of order") + else: + self.transform_selections.pop(transform_order) + 
self.parameters_tab.removeTab(transform_order) + self.parameter_setting_tabs_lists.pop(transform_order) + + def _on_default_file_selection_change( + self, default_file_type: str, index: int + ) -> None: + if index >= len(self.transform_selections): + raise IndexError("Transform file selection out of order") + + transform_type = self.transform_selections[index][0] + file_path = ( + Path(__file__).parent.resolve() + / "parameters" + / default_file_type + / f"{transform_type}.txt" + ) + + if not file_path.exists(): + file_path = ( + Path(__file__).parent.resolve() + / "parameters" + / "elastix_default" + / f"{transform_type}.txt" + ) + + param_dict = open_parameter_file(file_path) + + self.transform_selections[index] = (transform_type, param_dict) + self.parameter_setting_tabs_lists[index].set_data(param_dict) + + def _on_sample_popup_about_to_show(self): + self._sample_images = get_image_layer_names(self._viewer) + self.get_atlas_widget.update_sample_image_names(self._sample_images) diff --git a/src/brainglobe_registration/resources/brainglobe.png b/src/brainglobe_registration/resources/brainglobe.png new file mode 100644 index 0000000..427bdab Binary files /dev/null and b/src/brainglobe_registration/resources/brainglobe.png differ diff --git a/src/brainglobe_registration/resources/sample_hipp.tif b/src/brainglobe_registration/resources/sample_hipp.tif new file mode 100644 index 0000000..d30ff69 Binary files /dev/null and b/src/brainglobe_registration/resources/sample_hipp.tif differ diff --git a/src/brainglobe_registration/_tests/__init__.py b/src/brainglobe_registration/utils/__init__.py similarity index 100% rename from src/brainglobe_registration/_tests/__init__.py rename to src/brainglobe_registration/utils/__init__.py diff --git a/src/brainglobe_registration/utils/brainglobe_logo.py b/src/brainglobe_registration/utils/brainglobe_logo.py new file mode 100644 index 0000000..66740f1 --- /dev/null +++ b/src/brainglobe_registration/utils/brainglobe_logo.py @@ -0,0 
+1,51 @@ +""" +Can this be imported from brainrender-napari? We can also move this to a bgutils maybe? +""" + + +from importlib.resources import files + +from qtpy.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QWidget, QVBoxLayout + +brainglobe_logo = files("brainglobe_registration").joinpath( + "resources/brainglobe.png" +) + +_logo_html = f""" + 

+ +<\h1> +""" + + +def _docs_links_widget(parent: QWidget = None): + _docs_links_html = """ +

+

Website

+

Source

+

+ """ # noqa: E501 + docs_links_widget = QLabel(_docs_links_html, parent=parent) + docs_links_widget.setOpenExternalLinks(True) + return docs_links_widget + + +def _logo_widget(parent: QWidget = None): + return QLabel(_logo_html, parent=None) + + +def header_widget(parent: QWidget = None): + box = QGroupBox(parent) + box.setFlat(True) + box.setLayout(QVBoxLayout()) + box.layout().addWidget(QLabel("

brainglobe-registration

")) + subbox = QGroupBox(parent) + subbox.setFlat(True) + subbox.setLayout(QHBoxLayout()) + subbox.layout().setSpacing(0) + subbox.layout().setContentsMargins(0, 0, 0, 0) + subbox.setStyleSheet("border: 0;") + subbox.layout().addWidget(_logo_widget(parent=box)) + subbox.layout().addWidget(_docs_links_widget(parent=box)) + box.layout().addWidget(subbox) + return box diff --git a/src/brainglobe_registration/utils/utils.py b/src/brainglobe_registration/utils/utils.py new file mode 100644 index 0000000..d680c93 --- /dev/null +++ b/src/brainglobe_registration/utils/utils.py @@ -0,0 +1,97 @@ +import napari +from pytransform3d.rotations import active_matrix_from_angle +import numpy as np +from pathlib import Path +from typing import List + + +def adjust_napari_image_layer( + image_layer: napari.layers.Image, x: int, y: int, rotate: float +): + """ + Adjusts the napari image layer by the given x, y, and rotation values. + + This function takes in a napari image layer and modifies its translate + and affine properties based on the provided x, y, and rotation values. + The rotation is performed around the center of the image layer. + + Rotation around origin code adapted from: + https://forum.image.sc/t/napari-3d-rotation-center-change-and-scaling/66347/5 + + Parameters + ---------- + image_layer : napari.layers.Layer + The napari image layer to be adjusted. + x : int + The x-coordinate for the translation. + y : int + The y-coordinate for the translation. + rotate : float + The angle of rotation in degrees. 
+ + Returns + ------- + None + """ + image_layer.translate = (y, x) + + rotation_matrix = active_matrix_from_angle(2, np.deg2rad(rotate)) + translate_matrix = np.eye(3) + origin = np.asarray(image_layer.data.shape) // 2 + np.asarray([y, x]) + translate_matrix[:2, -1] = origin + transform_matrix = ( + translate_matrix @ rotation_matrix @ np.linalg.inv(translate_matrix) + ) + image_layer.affine = transform_matrix + + +def open_parameter_file(file_path: Path) -> dict: + """ + Opens the parameter file and returns the parameter dictionary. + + This function reads a parameter file and extracts the parameters into + a dictionary. The parameter file is expected to have lines in the format + "(key value1 value2 ...)". Any line not starting with "(" is ignored. + The values are stripped of any trailing ")" and leading or trailing quotes. + + Parameters + ---------- + file_path : Path + The path to the parameter file. + + Returns + ------- + dict + A dictionary containing the parameters from the file. + """ + with open(file_path, "r") as f: + param_dict = {} + for line in f.readlines(): + if line[0] == "(": + split_line = line[1:-1].split() + cleaned_params = [] + for i, entry in enumerate(split_line[1:]): + if entry == ")" or entry[0] == "/": + break + + cleaned_params.append(entry.strip('" )')) + + param_dict[split_line[0]] = cleaned_params + + return param_dict + + +def find_layer_index(viewer: napari.Viewer, layer_name: str) -> int: + """Finds the index of a layer in the napari viewer.""" + for idx, layer in enumerate(viewer.layers): + if layer.name == layer_name: + return idx + + return -1 + + +def get_image_layer_names(viewer: napari.Viewer) -> List[str]: + """ + Returns a list of the names of the napari image layers in the viewer. 
+ """ + return [layer.name for layer in viewer.layers] diff --git a/src/brainglobe_registration/widgets/__init__.py b/src/brainglobe_registration/widgets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/brainglobe_registration/widgets/adjust_moving_image_view.py b/src/brainglobe_registration/widgets/adjust_moving_image_view.py new file mode 100644 index 0000000..c81f32f --- /dev/null +++ b/src/brainglobe_registration/widgets/adjust_moving_image_view.py @@ -0,0 +1,104 @@ +from qtpy.QtCore import Signal + +from qtpy.QtWidgets import ( + QFormLayout, + QSpinBox, + QDoubleSpinBox, + QPushButton, + QWidget, + QLabel, +) + + +class AdjustMovingImageView(QWidget): + """ + A QWidget subclass that provides controls for adjusting the moving image. + + This widget provides controls for adjusting the x and y offsets and rotation + of the moving image. It emits signals when the image is adjusted or reset. + + Attributes + ---------- + adjust_image_signal : Signal + Emitted when the image is adjusted. The signal includes the x and y offsets + and rotation as parameters. + reset_image_signal : Signal + Emitted when the image is reset. + + Methods + ------- + _on_adjust_image(): + Emits the adjust_image_signal with the current x and y offsets and rotation. + _on_reset_image_button_click(): + Resets the x and y offsets and rotation to 0 and emits the reset_image_signal. + """ + + adjust_image_signal = Signal(int, int, float) + reset_image_signal = Signal() + + def __init__(self, parent=None): + """ + Initialize the widget. 
+ + Parameters + ---------- + parent : QWidget, optional + The parent widget, by default None + """ + super().__init__(parent=parent) + + self.setLayout(QFormLayout()) + + offset_range = 2000 + rotation_range = 360 + + self.adjust_moving_image_x = QSpinBox() + self.adjust_moving_image_x.setRange(-offset_range, offset_range) + self.adjust_moving_image_x.valueChanged.connect(self._on_adjust_image) + + self.adjust_moving_image_y = QSpinBox() + self.adjust_moving_image_y.setRange(-offset_range, offset_range) + self.adjust_moving_image_y.valueChanged.connect(self._on_adjust_image) + + self.adjust_moving_image_rotate = QDoubleSpinBox() + self.adjust_moving_image_rotate.setRange( + -rotation_range, rotation_range + ) + self.adjust_moving_image_rotate.setSingleStep(0.5) + self.adjust_moving_image_rotate.valueChanged.connect( + self._on_adjust_image + ) + + self.adjust_moving_image_reset_button = QPushButton(parent=self) + self.adjust_moving_image_reset_button.setText("Reset Image") + self.adjust_moving_image_reset_button.clicked.connect( + self._on_reset_image_button_click + ) + + self.layout().addRow(QLabel("Adjust the moving image: ")) + self.layout().addRow("X offset:", self.adjust_moving_image_x) + self.layout().addRow("Y offset:", self.adjust_moving_image_y) + self.layout().addRow( + "Rotation (degrees):", self.adjust_moving_image_rotate + ) + self.layout().addRow(self.adjust_moving_image_reset_button) + + def _on_adjust_image(self): + """ + Emit the adjust_image_signal with the current x and y offsets and rotation. + """ + self.adjust_image_signal.emit( + self.adjust_moving_image_x.value(), + self.adjust_moving_image_y.value(), + self.adjust_moving_image_rotate.value(), + ) + + def _on_reset_image_button_click(self): + """ + Reset the x and y offsets and rotation to 0 and emit the reset_image_signal. 
+ """ + self.adjust_moving_image_x.setValue(0) + self.adjust_moving_image_y.setValue(0) + self.adjust_moving_image_rotate.setValue(0) + + self.reset_image_signal.emit() diff --git a/src/brainglobe_registration/widgets/parameter_list_view.py b/src/brainglobe_registration/widgets/parameter_list_view.py new file mode 100644 index 0000000..58fc063 --- /dev/null +++ b/src/brainglobe_registration/widgets/parameter_list_view.py @@ -0,0 +1,96 @@ +from qtpy.QtWidgets import ( + QTableWidget, + QTableWidgetItem, +) + + +class RegistrationParameterListView(QTableWidget): + """ + A QTableWidget subclass that displays and manages registration parameters. + + This widget displays a table of registration parameters and their values. + The parameters can be edited directly in the table. When a parameter value is + changed, the parameter dictionary is updated. + + Attributes + ---------- + param_dict : dict + The dictionary of registration parameters. + transform_type : str + The transform type. + + Methods + ------- + set_data(param_dict): + Sets the data in the table from the parameter dictionary. + _on_cell_change(row, column): + Updates the parameter dictionary when a cell in the table is changed. + """ + + def __init__(self, param_dict: dict, transform_type: str, parent=None): + """ + Initialize the widget. + + Parameters + ---------- + param_dict : dict + The dictionary of registration parameters. + transform_type : str + The type of transform being used. + parent : QWidget, optional + The parent widget, by default None + """ + super().__init__(parent) + self.param_dict = {} + self.transform_type = transform_type + self.setColumnCount(2) + + self.set_data(param_dict) + self.setHorizontalHeaderItem(0, QTableWidgetItem("Parameter")) + self.setHorizontalHeaderItem(1, QTableWidgetItem("Values")) + + self.cellChanged.connect(self._on_cell_change) + + def set_data(self, param_dict): + """ + Sets the data in the table from the parameter dictionary. 
+ + Parameters + ---------- + param_dict : dict + The dictionary of registration parameters. + """ + self.clear() + self.setRowCount(len(param_dict) + 1) + for i, k in enumerate(param_dict): + new_param = QTableWidgetItem(k) + new_value = QTableWidgetItem(", ".join(param_dict[k])) + + self.setItem(i, 0, new_param) + self.setItem(i, 1, new_value) + + self.resizeColumnsToContents() + self.resizeRowsToContents() + self.param_dict = param_dict + + def _on_cell_change(self, row, column): + """ + Updates the parameter dictionary when a cell in the table is changed. + + Parameters + ---------- + row : int + The row of the changed cell. + column : int + The column of the changed cell. + """ + if column == 1 and self.item(row, 0): + parameter = self.item(row, 0).text() + value = self.item(row, 1).text().split(", ") + self.param_dict[parameter] = value + + if row == self.rowCount() - 1: + self.setRowCount(self.rowCount() + 1) + # TODO - add a way to remove rows if they are empty removing + # them from the param dictionary (might have to save the parameter when + # it is selected) diff --git a/src/brainglobe_registration/widgets/select_images_view.py b/src/brainglobe_registration/widgets/select_images_view.py new file mode 100644 index 0000000..dfa2ab8 --- /dev/null +++ b/src/brainglobe_registration/widgets/select_images_view.py @@ -0,0 +1,121 @@ +from qtpy.QtCore import Signal +from typing import List +from qtpy.QtWidgets import ( + QWidget, + QVBoxLayout, + QLabel, + QComboBox, +) + + +class SampleImageComboBox(QComboBox): + popout_about_to_show = Signal() + + def __init__(self, parent=None): + super().__init__(parent=parent) + + def showPopup(self): + self.popout_about_to_show.emit() + super().showPopup() + + +class SelectImagesView(QWidget): + """ + A QWidget subclass that provides a dropdown menu for selecting the image and atlas + for registration. + + This widget provides two dropdown menus for selecting the atlas and the sample + to be used for registration. 
It emits signals when the selected atlas or sample + image changes. + + Attributes + ---------- + atlas_index_change : Signal + Emitted when the selected atlas changes. The signal includes the index of the + selected atlas. + moving_image_index_change : Signal + Emitted when the selected sample image changes. The signal includes the index + of the selected image. + + Methods + ------- + _on_atlas_index_change(): + Emits the atlas_index_change signal with the index of the selected atlas. + _on_moving_image_index_change(): + Emits the moving_image_index_change signal with the index of the selected image. + """ + + atlas_index_change = Signal(int) + moving_image_index_change = Signal(int) + sample_image_popup_about_to_show = Signal() + + def __init__( + self, + available_atlases: List[str], + sample_image_names: List[str], + parent: QWidget = None, + ): + """ + Initialize the widget. + + Parameters + ---------- + available_atlases : List[str] + The list of available atlases. + sample_image_names : List[str] + The list of available sample images. 
+ parent : QWidget, optional + The parent widget, by default None + """ + super().__init__(parent) + + self.setLayout(QVBoxLayout()) + self.available_atlas_dropdown_label = QLabel("Select Atlas:") + self.available_atlas_dropdown = QComboBox(parent=self) + self.available_atlas_dropdown.addItems(available_atlases) + + self.available_sample_dropdown_label = QLabel("Select sample:") + self.available_sample_dropdown = SampleImageComboBox(parent=self) + self.available_sample_dropdown.addItems(sample_image_names) + + self.available_atlas_dropdown.currentIndexChanged.connect( + self._on_atlas_index_change + ) + self.available_sample_dropdown.currentIndexChanged.connect( + self._on_moving_image_index_change + ) + self.available_sample_dropdown.popout_about_to_show.connect( + self._on_sample_popup_about_to_show + ) + + self.layout().addWidget(self.available_atlas_dropdown_label) + self.layout().addWidget(self.available_atlas_dropdown) + self.layout().addWidget(self.available_sample_dropdown_label) + self.layout().addWidget(self.available_sample_dropdown) + + def update_sample_image_names(self, sample_image_names: List[str]): + self.available_sample_dropdown.clear() + self.available_sample_dropdown.addItems(sample_image_names) + + def _on_atlas_index_change(self): + """ + Emit the atlas_index_change signal with the index of the selected atlas. + + If the selected index is invalid, emits -1. + """ + self.atlas_index_change.emit( + self.available_atlas_dropdown.currentIndex() + ) + + def _on_moving_image_index_change(self): + """ + Emit the moving_image_index_change signal with the index of the selected image. + + If the selected index is invalid, emits -1. 
+ """ + self.moving_image_index_change.emit( + self.available_sample_dropdown.currentIndex() + ) + + def _on_sample_popup_about_to_show(self): + self.sample_image_popup_about_to_show.emit() diff --git a/src/brainglobe_registration/widgets/transform_select_view.py b/src/brainglobe_registration/widgets/transform_select_view.py new file mode 100644 index 0000000..65d6b3c --- /dev/null +++ b/src/brainglobe_registration/widgets/transform_select_view.py @@ -0,0 +1,238 @@ +from qtpy.QtCore import QSignalMapper, Signal +from qtpy.QtWidgets import ( + QComboBox, + QTableWidget, +) + + +class TransformSelectView(QTableWidget): + """ + A QTableWidget subclass that provides a user interface for selecting transform + types and associated files. + + This widget displays a table of available transform types and associated default + parameter files. The user can select a transform type and an associated file from + dropdown menus. The widget emits signals when a transform type is added or removed, + or when a file option is changed. + + Attributes + ---------- + transform_type_added_signal : Signal + Emitted when a transform type is added. The signal includes the name of the + transform type and its index. + transform_type_removed_signal : Signal + Emitted when a transform type is removed. The signal includes the index of the + removed transform type. + file_option_changed_signal : Signal + Emitted when a file option is changed. The signal includes the name of the file + and its index. + + Methods + ------- + _on_transform_type_change(index): + Handles the event when a transform type is changed. Emits the + transform_type_added_signal or transform_type_removed_signal as appropriate. + _on_file_change(index): + Handles the event when the default file option is changed. + Emits the file_option_changed_signal. 
+ """ + + transform_type_added_signal = Signal(str, int) + transform_type_removed_signal = Signal(int) + file_option_changed_signal = Signal(str, int) + + def __init__(self, parent=None): + """ + Initialize the widget. + + Parameters + ---------- + parent : QWidget, optional + The parent widget, by default None + """ + super().__init__(parent=parent) + + # Define the available transform types and file options + self.file_options = [ + "elastix_default", + "ara_tools", + "brainregister_IBL", + ] + self.transform_type_options = ["", "rigid", "affine", "bspline"] + + # Create signal mappers for the transform type and file option dropdown menus + self.transform_type_signaller = QSignalMapper(self) + self.transform_type_signaller.mapped[int].connect( + self._on_transform_type_change + ) + + self.file_signaller = QSignalMapper(self) + self.file_signaller.mapped[int].connect(self._on_file_change) + + # Initialize lists to hold the dropdown menus + self.transform_type_selections = [] + self.file_selections = [] + + # Set up the table + self.setColumnCount(2) + self.setRowCount(len(self.transform_type_options)) + self.setHorizontalHeaderLabels(["Transform Type", "Default File"]) + + # Add dropdown menus to the table for each transform type option + for i in range(len(self.transform_type_options) - 1): + # Create and configure the transform type dropdown menu + self.transform_type_selections.append(QComboBox()) + self.transform_type_selections[i].addItems( + self.transform_type_options + ) + self.transform_type_selections[i].setCurrentIndex(i + 1) + self.transform_type_selections[i].currentIndexChanged.connect( + self.transform_type_signaller.map + ) + + # Create and configure the file option dropdown menu + self.file_selections.append(QComboBox()) + self.file_selections[i].addItems(self.file_options) + self.file_selections[i].setCurrentIndex(0) + self.file_selections[i].currentIndexChanged.connect( + self.file_signaller.map + ) + + # Add the dropdown menus to the signal 
mappers + self.transform_type_signaller.setMapping( + self.transform_type_selections[i], i + ) + self.file_signaller.setMapping(self.file_selections[i], i) + + # Add the dropdown menus to the table + self.setCellWidget(i, 0, self.transform_type_selections[i]) + self.setCellWidget(i, 1, self.file_selections[i]) + + # Add an extra row to the table for adding new transform types + self.transform_type_selections.append(QComboBox()) + self.transform_type_selections[-1].addItems( + self.transform_type_options + ) + self.transform_type_selections[-1].currentIndexChanged.connect( + self.transform_type_signaller.map + ) + + self.transform_type_signaller.setMapping( + self.transform_type_selections[-1], + len(self.transform_type_options) - 1, + ) + + self.file_selections.append(QComboBox()) + self.file_selections[-1].addItems(self.file_options) + self.file_selections[-2].setEnabled(True) + self.file_selections[-1].setEnabled(False) + self.file_selections[-1].currentIndexChanged.connect( + self.file_signaller.map + ) + + self.file_signaller.setMapping( + self.file_selections[-1], len(self.transform_type_options) - 1 + ) + + self.setCellWidget( + len(self.transform_type_options) - 1, + 0, + self.transform_type_selections[-1], + ) + self.setCellWidget( + len(self.transform_type_options) - 1, 1, self.file_selections[-1] + ) + self.file_selections[-1].setEnabled(False) + self.resizeRowsToContents() + self.resizeColumnsToContents() + + def _on_transform_type_change(self, index): + """ + Handle the event when a transform type is changed. + + If a new transform type is selected, emits the transform_type_added_signal + and adds a new row to the table. If the transform type is set to "", removes + the row from the table and emits the transform_type_removed_signal. + + Parameters + ---------- + index : int + The index of the changed transform type. 
+ """ + if self.transform_type_selections[index].currentIndex() != 0: + self.transform_type_added_signal.emit( + self.transform_type_selections[index].currentText(), index + ) + + self.file_selections[index].setCurrentIndex(0) + + if index >= len(self.transform_type_selections) - 1: + curr_length = self.rowCount() + self.setRowCount(self.rowCount() + 1) + + self.transform_type_selections.append(QComboBox()) + self.transform_type_selections[-1].addItems( + self.transform_type_options + ) + self.transform_type_selections[-1].currentIndexChanged.connect( + self.transform_type_signaller.map + ) + self.transform_type_signaller.setMapping( + self.transform_type_selections[-1], curr_length + ) + + self.file_selections.append(QComboBox()) + self.file_selections[-1].addItems(self.file_options) + self.file_selections[-2].setEnabled(True) + self.file_selections[-1].setEnabled(False) + self.file_selections[-1].currentIndexChanged.connect( + self.file_signaller.map + ) + self.file_signaller.setMapping( + self.file_selections[-1], curr_length + ) + + self.setCellWidget( + curr_length, 0, self.transform_type_selections[curr_length] + ) + self.setCellWidget( + curr_length, 1, self.file_selections[curr_length] + ) + + else: + self.transform_type_signaller.removeMappings( + self.transform_type_selections[index] + ) + self.transform_type_selections.pop(index) + + self.file_signaller.removeMappings(self.file_selections[index]) + self.file_selections.pop(index) + + # Update mappings + for i in range(index, len(self.transform_type_selections)): + self.transform_type_signaller.removeMappings( + self.transform_type_selections[i] + ) + self.transform_type_signaller.setMapping( + self.transform_type_selections[i], i + ) + self.file_signaller.removeMappings(self.file_selections[i]) + self.file_signaller.setMapping(self.file_selections[i], i) + + self.removeRow(index) + self.transform_type_removed_signal.emit(index) + + def _on_file_change(self, index): + """ + Handle the event when a file 
option is changed. + + Emits the file_option_changed_signal with the name of the file and its index. + + Parameters + ---------- + index : int + The index of the changed file option. + """ + self.file_option_changed_signal.emit( + self.file_selections[index].currentText(), index + ) diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/conftest.py b/src/tests/conftest.py new file mode 100644 index 0000000..d6a99c3 --- /dev/null +++ b/src/tests/conftest.py @@ -0,0 +1,102 @@ +import os +from pathlib import Path + +import pytest +from bg_atlasapi import BrainGlobeAtlas, config as bg_config +from PIL import Image +import napari +import numpy as np + + +@pytest.fixture() +def make_napari_viewer_with_images(make_napari_viewer, pytestconfig): + viewer: napari.Viewer = make_napari_viewer() + + root_path = pytestconfig.rootpath + atlas_image = Image.open( + root_path / "src/tests/test_images/Atlas_Hipp.tif" + ) + moving_image = Image.open( + root_path / "src/tests/test_images/sample_hipp.tif" + ) + + viewer.add_image(np.asarray(moving_image), name="moving_image") + viewer.add_image(np.asarray(atlas_image), name="atlas_image") + + return viewer + + +def pytest_addoption(parser): + parser.addoption( + "--runslow", + action="store_true", + dest="slow", + default=False, + help="enable runslow decorated tests", + ) + + +def pytest_configure(config): + if not config.option.slow: + setattr(config.option, "markexpr", "not slow") + + +@pytest.fixture(autouse=True) +def mock_brainglobe_user_folders(monkeypatch): + """Ensures user config and data is mocked during all local testing. + + User config and data need mocking to avoid interfering with user data. + Mocking is achieved by turning user data folders used in tests into + subfolders of a new ~/.brainglobe-tests folder instead of ~/. + + It is not sufficient to mock the home path in the tests, as this + will leave later imports in other modules unaffected. 
+ + GH actions workflow will test with default user folders. + """ + if not os.getenv("GITHUB_ACTIONS"): + home_path = Path.home() # actual home path + mock_home_path = home_path / ".brainglobe-tests" + if not mock_home_path.exists(): + mock_home_path.mkdir() + + def mock_home(): + return mock_home_path + + monkeypatch.setattr(Path, "home", mock_home) + + # also mock global variables of config.py + monkeypatch.setattr( + bg_config, "DEFAULT_PATH", mock_home_path / ".brainglobe" + ) + monkeypatch.setattr( + bg_config, "CONFIG_DIR", mock_home_path / ".config" / "brainglobe" + ) + monkeypatch.setattr( + bg_config, + "CONFIG_PATH", + bg_config.CONFIG_DIR / bg_config.CONFIG_FILENAME, + ) + mock_default_dirs = { + "default_dirs": { + "brainglobe_dir": mock_home_path / ".brainglobe", + "interm_download_dir": mock_home_path / ".brainglobe", + } + } + monkeypatch.setattr(bg_config, "TEMPLATE_CONF_DICT", mock_default_dirs) + + +@pytest.fixture(autouse=True) +def setup_preexisting_local_atlases(): + """Automatically setup all tests to have three downloaded atlases + in the test user data.""" + preexisting_atlases = [ + ("example_mouse_100um", "v1.2"), + ("allen_mouse_100um", "v1.2"), + ("osten_mouse_100um", "v1.1"), + ] + for atlas_name, version in preexisting_atlases: + if not Path.exists( + Path.home() / f".brainglobe/{atlas_name}_{version}" + ): + _ = BrainGlobeAtlas(atlas_name) diff --git a/src/tests/test_adjust_moving_image_view.py b/src/tests/test_adjust_moving_image_view.py new file mode 100644 index 0000000..dbdb4a5 --- /dev/null +++ b/src/tests/test_adjust_moving_image_view.py @@ -0,0 +1,96 @@ +import pytest + +from brainglobe_registration.widgets.adjust_moving_image_view import ( + AdjustMovingImageView, +) + +max_translate = 2000 +max_rotate = 360 + + +@pytest.fixture(scope="class") +def adjust_moving_image_view() -> AdjustMovingImageView: + adjust_moving_image_view = AdjustMovingImageView() + yield adjust_moving_image_view + + +@pytest.mark.parametrize( + 
"x_value, expected", + [ + (-100, -100), + (100, 100), + (max_translate + 1, max_translate), + (-1 * (max_translate + 1), -1 * max_translate), + ], +) +def test_x_position_changed( + qtbot, adjust_moving_image_view, x_value, expected +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.adjust_image_signal, timeout=1000 + ) as blocker: + adjust_moving_image_view.adjust_moving_image_x.setValue(x_value) + + assert blocker.args == [expected, 0, 0] + + +@pytest.mark.parametrize( + "y_value, expected", + [ + (-100, -100), + (100, 100), + (max_translate + 1, max_translate), + (-1 * (max_translate + 1), -1 * max_translate), + ], +) +def test_y_position_changed( + qtbot, adjust_moving_image_view, y_value, expected +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.adjust_image_signal, timeout=1000 + ) as blocker: + adjust_moving_image_view.adjust_moving_image_y.setValue(y_value) + + assert blocker.args == [0, expected, 0] + + +@pytest.mark.parametrize( + "rotate_value, expected", + [ + (-100, -100), + (100, 100), + (10.5, 10.5), + (max_rotate + 1, max_rotate), + (-1 * (max_rotate + 1), -1 * max_rotate), + ], +) +def test_rotation_position_changed( + qtbot, adjust_moving_image_view, rotate_value, expected +): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.adjust_image_signal, timeout=1000 + ) as blocker: + adjust_moving_image_view.adjust_moving_image_rotate.setValue( + rotate_value + ) + + assert blocker.args == [0, 0, expected] + + +def test_reset_image_button_click(qtbot, adjust_moving_image_view): + qtbot.addWidget(adjust_moving_image_view) + + with qtbot.waitSignal( + adjust_moving_image_view.reset_image_signal, timeout=1000 + ): + adjust_moving_image_view.adjust_moving_image_reset_button.click() + + assert adjust_moving_image_view.adjust_moving_image_x.value() == 0 + assert 
adjust_moving_image_view.adjust_moving_image_y.value() == 0 + assert adjust_moving_image_view.adjust_moving_image_rotate.value() == 0 diff --git a/src/tests/test_images/Atlas_Hipp.tif b/src/tests/test_images/Atlas_Hipp.tif new file mode 100644 index 0000000..9a6c44e Binary files /dev/null and b/src/tests/test_images/Atlas_Hipp.tif differ diff --git a/src/tests/test_images/Atlas_Str.tif b/src/tests/test_images/Atlas_Str.tif new file mode 100644 index 0000000..8b62e31 Binary files /dev/null and b/src/tests/test_images/Atlas_Str.tif differ diff --git a/src/tests/test_images/sample_hipp.tif b/src/tests/test_images/sample_hipp.tif new file mode 100644 index 0000000..d30ff69 Binary files /dev/null and b/src/tests/test_images/sample_hipp.tif differ diff --git a/src/tests/test_parameter_list_view.py b/src/tests/test_parameter_list_view.py new file mode 100644 index 0000000..f59df95 --- /dev/null +++ b/src/tests/test_parameter_list_view.py @@ -0,0 +1,96 @@ +import pytest + +from qtpy.QtWidgets import QTableWidgetItem +from brainglobe_registration.widgets.parameter_list_view import ( + RegistrationParameterListView, +) + +param_dict = { + "AutomaticScalesEstimation": ["true"], + "AutomaticTransformInitialization": ["true"], + "BSplineInterpolationOrder": ["1"], + "CheckNumberOfSamples": ["true"], + "Transform": ["BSplineTransform"], +} +transform_type = "bspline" + + +@pytest.fixture(scope="class") +def parameter_list_view() -> RegistrationParameterListView: + parameter_list_view = RegistrationParameterListView( + param_dict=param_dict, transform_type=transform_type + ) + return parameter_list_view + + +def test_parameter_list_view(parameter_list_view, qtbot): + qtbot.addWidget(parameter_list_view) + + assert parameter_list_view.rowCount() == len(param_dict) + 1 + assert parameter_list_view.columnCount() == 2 + + assert parameter_list_view.horizontalHeaderItem(0).text() == "Parameter" + assert parameter_list_view.horizontalHeaderItem(1).text() == "Values" + + for i, k in 
enumerate(param_dict): + assert parameter_list_view.item(i, 0).text() == k + assert parameter_list_view.item(i, 1).text() == ", ".join( + param_dict[k] + ) + + +def test_parameter_list_view_cell_change(parameter_list_view, qtbot): + qtbot.addWidget(parameter_list_view) + + with qtbot.waitSignal( + parameter_list_view.cellChanged, timeout=1000 + ) as blocker: + parameter_list_view.item(0, 1).setText("false") + + assert blocker.args == [0, 1] + assert parameter_list_view.param_dict["AutomaticScalesEstimation"] == [ + "false" + ] + + +def test_parameter_list_view_cell_change_last_row(parameter_list_view, qtbot): + qtbot.addWidget(parameter_list_view) + + curr_row_count = parameter_list_view.rowCount() + last_row_index = len(param_dict) + + parameter_list_view.setItem( + last_row_index, 0, QTableWidgetItem("TestParameter") + ) + parameter_list_view.setItem(last_row_index, 1, QTableWidgetItem("true")) + + assert parameter_list_view.param_dict["TestParameter"] == ["true"] + assert parameter_list_view.rowCount() == curr_row_count + 1 + + +def test_parameter_list_view_cell_change_last_row_no_param( + parameter_list_view, qtbot +): + qtbot.addWidget(parameter_list_view) + + curr_row_count = parameter_list_view.rowCount() + last_row_index = len(param_dict) + + parameter_list_view.setItem(last_row_index, 1, QTableWidgetItem("true")) + + assert parameter_list_view.rowCount() == curr_row_count + + +def test_parameter_list_view_cell_change_last_row_no_value( + parameter_list_view, qtbot +): + qtbot.addWidget(parameter_list_view) + + curr_row_count = parameter_list_view.rowCount() + last_row_index = len(param_dict) + + parameter_list_view.setItem( + last_row_index, 0, QTableWidgetItem("TestParameter") + ) + + assert parameter_list_view.rowCount() == curr_row_count diff --git a/src/tests/test_register.py b/src/tests/test_register.py new file mode 100644 index 0000000..5ca3298 --- /dev/null +++ b/src/tests/test_register.py @@ -0,0 +1,77 @@ +import pytest +from PIL import Image + 
+from pathlib import Path
+from brainglobe_registration.elastix.register import (
+    setup_parameter_object,
+    run_registration,
+)
+
+@pytest.fixture
+def sample_atlas_slice():
+    # Resolve relative to this test module so tests pass regardless of CWD
+    return Image.open(Path(__file__).parent / "test_images/Atlas_Hipp.tif")
+
+
+@pytest.fixture
+def sample_moving_image():
+    return Image.open(Path(__file__).parent / "test_images/sample_hipp.tif")
+
+
+@pytest.mark.slow
+def test_run_registration(sample_atlas_slice, sample_moving_image):
+    result_image, transform_parameters = run_registration(
+        sample_atlas_slice,
+        sample_moving_image,
+    )
+    assert result_image is not None
+    assert transform_parameters is not None
+
+
+def test_setup_parameter_object_empty_list():
+    parameter_list = []
+
+    param_obj = setup_parameter_object(parameter_list)
+
+    assert param_obj.GetNumberOfParameterMaps() == 0
+
+
+@pytest.mark.parametrize(
+    "parameter_list, expected",
+    [
+        (
+            [("rigid", {"Transform": ["EulerTransform"]})],
+            [("EulerTransform",)],
+        ),
+        (
+            [("affine", {"Transform": ["AffineTransform"]})],
+            [("AffineTransform",)],
+        ),
+        (
+            [("bspline", {"Transform": ["BSplineTransform"]})],
+            [("BSplineTransform",)],
+        ),
+        (
+            [
+                ("rigid", {"Transform": ["EulerTransform"]}),
+                ("affine", {"Transform": ["AffineTransform"]}),
+                ("bspline", {"Transform": ["BSplineTransform"]}),
+            ],
+            [("EulerTransform",), ("AffineTransform",), ("BSplineTransform",)],
+        ),
+        (
+            [
+                ("rigid", {"Transform": ["EulerTransform"]}),
+                ("rigid", {"Transform": ["EulerTransform"]}),
+                ("rigid", {"Transform": ["EulerTransform"]}),
+            ],
+            [("EulerTransform",), ("EulerTransform",), ("EulerTransform",)],
+        ),
+    ],
+)
+def test_setup_parameter_object_one_transform(parameter_list, expected):
+    param_obj = setup_parameter_object(parameter_list)
+
+    assert param_obj.GetNumberOfParameterMaps() == len(expected)
+
+    for index, transform_type in enumerate(expected):
+        assert param_obj.GetParameterMap(index)["Transform"] == transform_type
diff --git a/src/tests/test_registration_widget.py b/src/tests/test_registration_widget.py
new file mode 100644
index 
0000000..aa91594 --- /dev/null +++ b/src/tests/test_registration_widget.py @@ -0,0 +1,67 @@ +import pytest + +from brainglobe_registration.registration_widget import RegistrationWidget + + +@pytest.fixture() +def registration_widget(make_napari_viewer_with_images): + viewer = make_napari_viewer_with_images + + widget = RegistrationWidget(viewer) + + return widget + + +def test_registration_widget(make_napari_viewer_with_images): + widget = RegistrationWidget(make_napari_viewer_with_images) + + assert widget is not None + + +def test_atlas_dropdown_index_changed_with_valid_index( + make_napari_viewer_with_images, registration_widget +): + registration_widget._on_atlas_dropdown_index_changed(2) + + assert ( + registration_widget._atlas.atlas_name + == registration_widget._available_atlases[2] + ) + assert registration_widget.run_button.isEnabled() + assert registration_widget._viewer.grid.enabled + + +def test_atlas_dropdown_index_changed_with_zero_index( + make_napari_viewer_with_images, registration_widget +): + registration_widget._on_atlas_dropdown_index_changed(0) + + assert registration_widget._atlas is None + assert not registration_widget.run_button.isEnabled() + assert not registration_widget._viewer.grid.enabled + + +def test_atlas_dropdown_index_changed_with_existing_atlas( + make_napari_viewer_with_images, registration_widget +): + registration_widget._on_atlas_dropdown_index_changed(2) + + registration_widget._on_atlas_dropdown_index_changed(1) + + assert ( + registration_widget._atlas.atlas_name + == registration_widget._available_atlases[1] + ) + assert registration_widget.run_button.isEnabled() + assert registration_widget._viewer.grid.enabled + + +def test_sample_dropdown_index_changed_with_valid_index( + make_napari_viewer_with_images, registration_widget +): + registration_widget._on_sample_dropdown_index_changed(1) + + assert ( + registration_widget._moving_image.name + == registration_widget._sample_images[1] + ) diff --git 
a/src/tests/test_select_images_view.py b/src/tests/test_select_images_view.py new file mode 100644 index 0000000..eb87168 --- /dev/null +++ b/src/tests/test_select_images_view.py @@ -0,0 +1,146 @@ +import pytest + +from brainglobe_registration.widgets.select_images_view import SelectImagesView + + +available_atlases = [ + "------", + "allen_mouse_100um", + "allen_mouse_25um", + "example_mouse_100um", +] +sample_image_names = ["image1", "image2", "image3"] + + +@pytest.fixture(scope="class") +def select_images_view() -> SelectImagesView: + select_images_view = SelectImagesView( + available_atlases=available_atlases, + sample_image_names=sample_image_names, + ) + yield select_images_view + + +def test_select_images_view(select_images_view, qtbot): + qtbot.addWidget(select_images_view) + + assert select_images_view.available_atlas_dropdown.count() == len( + available_atlases + ) + assert select_images_view.available_sample_dropdown.count() == len( + sample_image_names + ) + assert ( + select_images_view.available_atlas_dropdown.currentText() + == available_atlases[0] + ) + assert ( + select_images_view.available_sample_dropdown.currentText() + == sample_image_names[0] + ) + assert ( + select_images_view.available_atlas_dropdown_label.text() + == "Select Atlas:" + ) + assert ( + select_images_view.available_sample_dropdown_label.text() + == "Select sample:" + ) + + +@pytest.mark.parametrize( + "atlas_index, expected", + [ + (1, 1), + (2, 2), + (len(available_atlases), -1), + ], +) +def test_select_images_view_atlas_index_change_once( + select_images_view, qtbot, atlas_index, expected +): + qtbot.addWidget(select_images_view) + + with qtbot.waitSignal( + select_images_view.atlas_index_change, timeout=1000 + ) as blocker: + select_images_view.available_atlas_dropdown.setCurrentIndex( + atlas_index + ) + assert blocker.args == [expected] + + +@pytest.mark.parametrize( + "atlas_indexes", + [ + ([1, 2, 0]), + ([1, 2, 1]), + ], +) +def 
test_select_images_view_atlas_index_change_multi(
+    select_images_view, qtbot, atlas_indexes
+):
+    qtbot.addWidget(select_images_view)
+
+    expected = -1
+
+    def check_index_valid(signal_index):
+        return signal_index == expected
+
+    with qtbot.waitSignals(
+        [select_images_view.atlas_index_change] * 3,
+        check_params_cbs=[check_index_valid] * 3,
+        timeout=1000,
+    ):
+        for index in atlas_indexes:
+            expected = index
+            select_images_view.available_atlas_dropdown.setCurrentIndex(index)
+
+
+@pytest.mark.parametrize(
+    "image_index, expected",
+    [
+        (1, 1),
+        (2, 2),
+        (len(sample_image_names), -1),
+    ],
+)
+def test_select_images_view_moving_image_index_change_once(
+    select_images_view, qtbot, image_index, expected
+):
+    qtbot.addWidget(select_images_view)
+
+    with qtbot.waitSignal(
+        select_images_view.moving_image_index_change, timeout=1000
+    ) as blocker:
+        select_images_view.available_sample_dropdown.setCurrentIndex(
+            image_index
+        )
+    assert blocker.args == [expected]
+
+
+@pytest.mark.parametrize(
+    "image_indexes",
+    [
+        ([1, 2, 0]),
+        ([1, 2, 1]),
+    ],
+)
+def test_select_images_view_moving_image_index_change_multi(
+    select_images_view, qtbot, image_indexes
+):
+    qtbot.addWidget(select_images_view)
+
+    expected = -1
+
+    def check_index_valid(signal_index):
+        return signal_index == expected
+
+    # Fixed copy-paste: exercise the moving-image signal and sample dropdown,
+    # not the atlas ones, so this test actually covers the moving-image path.
+    with qtbot.waitSignals(
+        [select_images_view.moving_image_index_change] * 3,
+        check_params_cbs=[check_index_valid] * 3,
+        timeout=1000,
+    ):
+        for index in image_indexes:
+            expected = index
+            select_images_view.available_sample_dropdown.setCurrentIndex(index)
diff --git a/src/tests/test_transform_select_view.py b/src/tests/test_transform_select_view.py
new file mode 100644
index 0000000..7416612
--- /dev/null
+++ b/src/tests/test_transform_select_view.py
@@ -0,0 +1,168 @@
+import pytest
+
+from brainglobe_registration.widgets.transform_select_view import (
+    TransformSelectView,
+)
+
+
+@pytest.fixture(scope="class")
+def transform_select_view() -> TransformSelectView:
+    
transform_select_view = TransformSelectView() + yield transform_select_view + + +def test_transform_select_view(transform_select_view, qtbot): + qtbot.addWidget(transform_select_view) + + assert ( + transform_select_view.horizontalHeaderItem(0).text() + == "Transform Type" + ) + assert ( + transform_select_view.horizontalHeaderItem(1).text() == "Default File" + ) + + assert transform_select_view.rowCount() == len( + transform_select_view.transform_type_options + ) + assert transform_select_view.columnCount() == 2 + + for i in range(len(transform_select_view.transform_type_options) - 1): + assert ( + transform_select_view.cellWidget(i, 0).currentText() + == transform_select_view.transform_type_options[i + 1] + ) + assert ( + transform_select_view.cellWidget(i, 1).currentText() + == transform_select_view.file_options[0] + ) + + last_row = len(transform_select_view.transform_type_options) - 1 + + assert ( + transform_select_view.cellWidget(last_row, 0).currentText() + == transform_select_view.transform_type_options[0] + ) + assert ( + transform_select_view.cellWidget(last_row, 1).currentText() + == transform_select_view.file_options[0] + ) + assert not transform_select_view.cellWidget(last_row, 1).isEnabled() + + +def test_transform_type_added_signal(transform_select_view, qtbot): + qtbot.addWidget(transform_select_view) + + last_index = len(transform_select_view.transform_type_options) - 1 + row_count = transform_select_view.rowCount() + + with qtbot.waitSignal( + transform_select_view.transform_type_added_signal, timeout=1000 + ) as blocker: + transform_select_view.cellWidget(0, 0).setCurrentIndex(last_index) + + assert blocker.args == [ + transform_select_view.transform_type_options[last_index], + 0, + ] + assert transform_select_view.rowCount() == row_count + + +def test_file_option_changed_signal(transform_select_view, qtbot): + qtbot.addWidget(transform_select_view) + + last_index = len(transform_select_view.file_options) - 1 + row_count = 
transform_select_view.rowCount() + + with qtbot.waitSignal( + transform_select_view.file_option_changed_signal, timeout=1000 + ) as blocker: + transform_select_view.cellWidget(0, 1).setCurrentIndex(last_index) + + assert blocker.args == [transform_select_view.file_options[last_index], 0] + assert transform_select_view.rowCount() == row_count + + +def test_transform_type_added_signal_last_row(transform_select_view, qtbot): + qtbot.addWidget(transform_select_view) + + last_index = len(transform_select_view.transform_type_options) - 1 + row_count = transform_select_view.rowCount() + + with qtbot.waitSignal( + transform_select_view.transform_type_added_signal, timeout=1000 + ) as blocker: + transform_select_view.cellWidget(row_count - 1, 0).setCurrentIndex( + last_index + ) + + assert blocker.args == [ + transform_select_view.transform_type_options[last_index], + row_count - 1, + ] + assert transform_select_view.rowCount() == row_count + 1 + assert ( + transform_select_view.cellWidget(row_count, 0).currentText() + == transform_select_view.transform_type_options[0] + ) + assert ( + transform_select_view.cellWidget(row_count, 1).currentText() + == transform_select_view.file_options[0] + ) + assert transform_select_view.cellWidget(row_count - 1, 1).isEnabled() + assert not transform_select_view.cellWidget(row_count, 1).isEnabled() + + with qtbot.waitSignal( + transform_select_view.transform_type_added_signal, timeout=1000 + ) as blocker: + transform_select_view.cellWidget(row_count, 0).setCurrentIndex( + last_index + ) + + assert blocker.args == [ + transform_select_view.transform_type_options[last_index], + row_count, + ] + + +def test_transform_type_removed(transform_select_view, qtbot): + qtbot.addWidget(transform_select_view) + + row_count = transform_select_view.rowCount() + + with qtbot.waitSignal( + transform_select_view.transform_type_removed_signal, timeout=1000 + ) as blocker: + transform_select_view.cellWidget(0, 0).setCurrentIndex(0) + + assert blocker.args == 
[0]
+    assert transform_select_view.rowCount() == row_count - 1
+
+    with qtbot.waitSignal(
+        transform_select_view.transform_type_removed_signal, timeout=1000
+    ) as blocker:
+        transform_select_view.cellWidget(0, 0).setCurrentIndex(0)
+
+    assert blocker.args == [0]
+    assert transform_select_view.rowCount() == row_count - 2
+
+
+def test_file_option_default_on_transform_change(transform_select_view, qtbot):
+    """
+    When the transform type is changed, the file option should be set to the default
+    """
+    qtbot.addWidget(transform_select_view)
+    file_option_index = 2
+
+    transform_select_view.cellWidget(0, 1).setCurrentIndex(file_option_index)
+    assert (
+        transform_select_view.cellWidget(0, 1).currentText()
+        == transform_select_view.file_options[file_option_index]
+    )
+
+    transform_select_view.cellWidget(0, 0).setCurrentIndex(2)
+
+    assert (
+        transform_select_view.cellWidget(0, 1).currentText()
+        == transform_select_view.file_options[0]
+    )
diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py
new file mode 100644
index 0000000..affb84a
--- /dev/null
+++ b/src/tests/test_utils.py
@@ -0,0 +1,93 @@
+from unittest.mock import Mock
+import pytest
+from pytransform3d.rotations import active_matrix_from_angle
+import numpy as np
+from pathlib import Path
+from brainglobe_registration.utils.utils import (
+    adjust_napari_image_layer,
+    open_parameter_file,
+    get_image_layer_names,
+    find_layer_index,
+)
+
+
+def test_adjust_napari_image_layer_no_translation_no_rotation():
+    image_layer = Mock()
+    image_layer.data.shape = (100, 100)
+    adjust_napari_image_layer(image_layer, 0, 0, 0)
+    assert image_layer.translate == (0, 0)
+    assert np.array_equal(image_layer.affine, np.eye(3))
+
+
+def test_adjust_napari_image_layer_with_translation_no_rotation():
+    image_layer = Mock()
+    image_layer.data.shape = (100, 100)
+    adjust_napari_image_layer(image_layer, 10, 20, 0)
+    assert image_layer.translate == (20, 10)
+    assert np.array_equal(image_layer.affine, np.eye(3))
+
+
+def 
test_adjust_napari_image_layer_no_translation_with_rotation():
+    image_layer = Mock()
+    image_layer.data.shape = (100, 100)
+    adjust_napari_image_layer(image_layer, 0, 0, 45)
+    rotation_matrix = active_matrix_from_angle(2, np.deg2rad(45))
+    translate_matrix = np.eye(3)
+    origin = np.asarray(image_layer.data.shape) // 2
+    translate_matrix[:2, -1] = origin
+    expected_transform_matrix = (
+        translate_matrix @ rotation_matrix @ np.linalg.inv(translate_matrix)
+    )
+    assert np.array_equal(image_layer.affine, expected_transform_matrix)
+
+
+def test_open_parameter_file_with_valid_content():
+    file_path = Path("test_file.txt")
+    with open(file_path, "w") as f:
+        f.write("(key1 value1 value2)\n(key2 value3 value4)")
+    result = open_parameter_file(file_path)
+    assert result == {
+        "key1": ["value1", "value2"],
+        "key2": ["value3", "value4"],
+    }
+    file_path.unlink()
+
+
+def test_open_parameter_file_with_invalid_content():
+    file_path = Path("test_file.txt")
+    with open(file_path, "w") as f:
+        f.write("invalid content")
+    result = open_parameter_file(file_path)
+    assert result == {}
+    file_path.unlink()
+
+
+def test_open_parameter_file_with_empty_content():
+    file_path = Path("test_file.txt")
+    with open(file_path, "w") as f:
+        f.write("")
+    result = open_parameter_file(file_path)
+    assert result == {}
+    file_path.unlink()
+
+
+@pytest.mark.parametrize(
+    "name, index",
+    [
+        ("moving_image", 0),
+        ("atlas_image", 1),
+    ],
+)
+def test_find_layer_index(make_napari_viewer_with_images, name, index):
+    viewer = make_napari_viewer_with_images
+
+    assert find_layer_index(viewer, name) == index
+
+
+def test_get_image_layer_names(make_napari_viewer_with_images):
+    viewer = make_napari_viewer_with_images
+
+    layer_names = get_image_layer_names(viewer)
+
+    assert len(layer_names) == 2
+    assert layer_names == ["moving_image", "atlas_image"]
diff --git a/tox.ini b/tox.ini
index 78d715c..948804c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -29,4 +29,4 @@ passenv =
     PYVISTA_OFF_SCREEN
 extras =
     testing
-commands = 
pytest -v --color=yes --cov=bg_elastix --cov-report=xml +commands = pytest -v --color=yes --cov=brainglobe_registration --cov-report=xml