diff --git a/.github/workflows/gpu-test-action.yml b/.github/workflows/gpu-test-action.yml
index 316532c..0590526 100644
--- a/.github/workflows/gpu-test-action.yml
+++ b/.github/workflows/gpu-test-action.yml
@@ -3,7 +3,7 @@ name: gpu-tests
 on:
   pull_request:
   push:
-    branches: main
+    branches: [dev, main]
 
 jobs:
   test-linux:
diff --git a/README.md b/README.md
index eb9ce29..b25fe76 100644
--- a/README.md
+++ b/README.md
@@ -11,13 +11,18 @@ The result is a deep-learning-based registration model that works well across da
 
 Please (currently) cite as:
 ```
-@misc{tian2024unigradicon,
-      title={uniGradICON: A Foundation Model for Medical Image Registration},
-      author={Lin Tian and Hastings Greer and Roland Kwitt and Francois-Xavier Vialard and Raul San Jose Estepar and Sylvain Bouix and Richard Rushmore and Marc Niethammer},
-      year={2024},
-      eprint={2403.05780},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV}
+@article{tian2024unigradicon,
+  title={uniGradICON: A Foundation Model for Medical Image Registration},
+  author={Tian, Lin and Greer, Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard and Niethammer, Marc},
+  journal={arXiv preprint arXiv:2403.05780},
+  year={2024}
+}
+
+@article{demir2024multigradicon,
+  title={multiGradICON: A Foundation Model for Multimodal Medical Image Registration},
+  author={Demir, Basar and Tian, Lin and Greer, Thomas Hastings and Kwitt, Roland and Vialard, Francois-Xavier and Estepar, Raul San Jose and Bouix, Sylvain and Rushmore, Richard Jarrett and Ebrahim, Ebrahim and Niethammer, Marc},
+  journal={arXiv preprint arXiv:2408.00221},
+  year={2024}
 }
 ```
 
@@ -204,12 +209,25 @@ unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=Reg
 ```
 
-To register without instance optimization
+To register without instance optimization (IO)
 ```
 unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations None
 ```
 
-To warp
+To use a different similarity measure in the IO, set the `--io_sim` flag. We currently support three similarity measures:
+- LNCC: lncc
+- Squared LNCC: lncc2
+- MIND SSC: mind
+```
+unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --io_iterations 50 --io_sim lncc2
+```
+
+To load a specific model weight for inference, set the `--model` flag. We currently support uniGradICON and multiGradICON.
+```
+unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --model multigradicon
+```
+
+To warp an image
 ```
 unigradicon-warp --fixed [fixed_image_file_name] --moving [moving_image_file_name] --transform trans.hdf5 --warped_moving_out warped.nii.gz --linear
 ```
@@ -218,6 +236,7 @@ To warp a label map
 ```
 unigradicon-warp --fixed [fixed_image_file_name] --moving [moving_image_segmentation_file_name] --transform trans.hdf5 --warped_moving_out warped_seg.nii.gz --nearest_neighbor
 ```
+
 We also provide a [colab](https://colab.research.google.com/drive/1JuFL113WN3FHCoXG-4fiBTWIyYpwGyGy?usp=sharing) demo.
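The README changes above document `--io_sim` and `--model` with separate examples. Since both are ordinary flags on the same `unigradicon-register` parser (see the `main()` changes further down), they can presumably be combined in a single call; a hedged sketch reusing the file names from the examples above:

```
unigradicon-register --fixed=RegLib_C01_2.nrrd --fixed_modality=mri --moving=RegLib_C01_1.nrrd --moving_modality=mri --transform_out=trans.hdf5 --warped_moving_out=warped_C01_1.nrrd --model multigradicon --io_sim mind --io_iterations 100
```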
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..994a988
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+icon_registration>=1.1.5
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index a0575b5..f50b557 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,8 +1,8 @@
 [metadata]
 name = unigradicon
-version = 1.0.2
+version = 1.0.3
 author = Lin Tian
-author_email = 
+author_email = lintian@cs.unc.edu
 description = a foundation model for medical image registration
 long_description = file: README.md
 long_description_content_type = text/markdown
@@ -21,7 +21,7 @@
 packages = find:
 python_requires = >=3.7
 install_requires =
-    icon_registration>=1.1.4
+    icon_registration>=1.1.5
 
 [options.packages.find]
 where = src
diff --git a/src/unigradicon/__init__.py b/src/unigradicon/__init__.py
index 4125e5e..d284ee5 100644
--- a/src/unigradicon/__init__.py
+++ b/src/unigradicon/__init__.py
@@ -159,13 +159,40 @@ def make_network(input_shape, include_last_step=False, lmbda=1.5, loss_fn=icon.L
     net.assign_identity_map(input_shape)
     return net
 
+def make_sim(similarity):
+    if similarity == "lncc":
+        return icon.LNCC(sigma=5)
+    elif similarity == "lncc2":
+        return icon.SquaredLNCC(sigma=5)
+    elif similarity == "mind":
+        return icon.MINDSSC(radius=2, dilation=2)
+    else:
+        raise ValueError(f"Similarity measure {similarity} not recognized. Choose from [lncc, lncc2, mind].")
+
+def get_multigradicon(loss_fn=icon.LNCC(sigma=5)):
+    net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
+    from os.path import exists
+    weights_location = "network_weights/multigradicon1.0/Step_2_final.trch"
+    if not exists(weights_location):
+        print("Downloading pretrained multigradicon model")
+        import urllib.request
+        import os
+        download_path = "https://github.com/uncbiag/uniGradICON/releases/download/multigradicon_weights/Step_2_final.trch"
+        os.makedirs("network_weights/multigradicon1.0/", exist_ok=True)
+        urllib.request.urlretrieve(download_path, weights_location)
+    print(f"Loading weights from {weights_location}")
+    trained_weights = torch.load(weights_location, map_location=torch.device("cpu"))
+    net.regis_net.load_state_dict(trained_weights)
+    net.to(config.device)
+    net.eval()
+    return net
 
-def get_unigradicon():
-    net = make_network(input_shape, include_last_step=True)
+def get_unigradicon(loss_fn=icon.LNCC(sigma=5)):
+    net = make_network(input_shape, include_last_step=True, loss_fn=loss_fn)
     from os.path import exists
     weights_location = "network_weights/unigradicon1.0/Step_2_final.trch"
     if not exists(weights_location):
-        print("Downloading pretrained model")
+        print("Downloading pretrained unigradicon model")
         import urllib.request
         import os
         download_path = "https://github.com/uncbiag/uniGradICON/releases/download/unigradicon_weights/Step_2_final.trch"
@@ -177,6 +204,14 @@ def get_unigradicon():
     net.eval()
     return net
 
+def get_model_from_model_zoo(model_name="unigradicon", loss_fn=icon.LNCC(sigma=5)):
+    if model_name == "unigradicon":
+        return get_unigradicon(loss_fn)
+    elif model_name == "multigradicon":
+        return get_multigradicon(loss_fn)
+    else:
+        raise ValueError(f"Model {model_name} not recognized. Choose from [unigradicon, multigradicon].")
+
 def quantile(arr: torch.Tensor, q):
     arr = arr.flatten()
     l = len(arr)
@@ -241,10 +276,14 @@ def main():
         default=None, type=str, help="The path to save the warped image.")
     parser.add_argument("--io_iterations", required=False,
        default="50", help="The number of IO iterations. Default is 50. Set to 'None' to disable IO.")
+    parser.add_argument("--io_sim", required=False,
+       default="lncc", help="The similarity measure used in IO. Default is LNCC. Choose from [lncc, lncc2, mind].")
+    parser.add_argument("--model", required=False,
+       default="unigradicon", help="The model to load. Default is unigradicon. Choose from [unigradicon, multigradicon].")
 
     args = parser.parse_args()
 
-    net = get_unigradicon()
+    net = get_model_from_model_zoo(args.model, make_sim(args.io_sim))
 
     fixed = itk.imread(args.fixed)
     moving = itk.imread(args.moving)
diff --git a/tests/test_command_arguments.py b/tests/test_command_arguments.py
new file mode 100644
index 0000000..069839d
--- /dev/null
+++ b/tests/test_command_arguments.py
@@ -0,0 +1,110 @@
+import itk
+import numpy as np
+import unittest
+import icon_registration.test_utils
+
+import subprocess
+import os
+import torch
+
+
+class TestCommandInterface(unittest.TestCase):
+    def __init__(self, methodName: str = "runTest") -> None:
+        super().__init__(methodName)
+        icon_registration.test_utils.download_test_data()
+        self.test_data_dir = icon_registration.test_utils.TEST_DATA_DIR
+        self.test_temp_dir = f"{self.test_data_dir}/temp"
+        os.makedirs(self.test_temp_dir, exist_ok=True)
+        self.device = torch.cuda.current_device()
+
+    def test_register_unigradicon_inference(self):
+        subprocess.run([
+            "unigradicon-register",
+            "--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
+            "--fixed_modality", "ct",
+            "--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz",
+            "--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
+            "--moving_modality", "ct",
+            "--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
+            "--transform_out", f"{self.test_temp_dir}/transform.hdf5",
+            "--io_iterations", "None"
+        ])
+
+        # load transform
+        phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0]
+
+        assert isinstance(phi_AB, itk.CompositeTransform)
+
+        insp_points = icon_registration.test_utils.read_copd_pointset(
+            str(
+                icon_registration.test_utils.TEST_DATA_DIR
+                / "lung_test_data/copd1_300_iBH_xyz_r1.txt"
+            )
+        )
+        exp_points = icon_registration.test_utils.read_copd_pointset(
+            str(
+                icon_registration.test_utils.TEST_DATA_DIR
+                / "lung_test_data/copd1_300_eBH_xyz_r1.txt"
+            )
+        )
+
+        dists = []
+        for i in range(len(insp_points)):
+            px, py = (
+                insp_points[i],
+                np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
+            )
+            dists.append(np.sqrt(np.sum((px - py) ** 2)))
+        print(np.mean(dists))
+        self.assertLess(np.mean(dists), 2.1)
+
+        # remove temp file
+        os.remove(f"{self.test_temp_dir}/transform.hdf5")
+
+    def test_register_multigradicon_inference(self):
+        subprocess.run([
+            "unigradicon-register",
+            "--fixed", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_img.nii.gz",
+            "--fixed_modality", "ct",
+            "--fixed_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_EXP_STD_COPD_label.nii.gz",
+            "--moving", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_img.nii.gz",
+            "--moving_modality", "ct",
+            "--moving_segmentation", f"{self.test_data_dir}/lung_test_data/copd1_highres_INSP_STD_COPD_label.nii.gz",
+            "--transform_out", f"{self.test_temp_dir}/transform.hdf5",
+            "--io_iterations", "None",
+            "--model", "multigradicon"
+        ])
+
+        # load transform
+        phi_AB = itk.transformread(f"{self.test_temp_dir}/transform.hdf5")[0]
+
+        assert isinstance(phi_AB, itk.CompositeTransform)
+
+        insp_points = icon_registration.test_utils.read_copd_pointset(
+            str(
+                icon_registration.test_utils.TEST_DATA_DIR
+                / "lung_test_data/copd1_300_iBH_xyz_r1.txt"
+            )
+        )
+        exp_points = icon_registration.test_utils.read_copd_pointset(
+            str(
+                icon_registration.test_utils.TEST_DATA_DIR
+                / "lung_test_data/copd1_300_eBH_xyz_r1.txt"
+            )
+        )
+
+        dists = []
+        for i in range(len(insp_points)):
+            px, py = (
+                insp_points[i],
+                np.array(phi_AB.TransformPoint(tuple(exp_points[i]))),
+            )
+            dists.append(np.sqrt(np.sum((px - py) ** 2)))
+        print(np.mean(dists))
+        self.assertLess(np.mean(dists), 3.8)
+
+        # remove temp file
+        os.remove(f"{self.test_temp_dir}/transform.hdf5")
+
+
+
diff --git a/tests/test_requirements_sync.py b/tests/test_requirements_sync.py
new file mode 100644
index 0000000..f56b77a
--- /dev/null
+++ b/tests/test_requirements_sync.py
@@ -0,0 +1,19 @@
+import unittest
+
+
+class TestImports(unittest.TestCase):
+
+    def test_requirements_match_cfg(self):
+        from inspect import getsourcefile
+        import os.path as path, sys
+        import configparser
+
+        current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0)))
+        parent_dir = current_dir[: current_dir.rfind(path.sep)]
+
+        with open(parent_dir + "/requirements.txt") as f:
+            requirements_txt = "\n" + f.read()
+        requirements_cfg = configparser.ConfigParser()
+        requirements_cfg.read(parent_dir + "/setup.cfg")
+        requirements_cfg = requirements_cfg["options"]["install_requires"]
+        self.assertEqual(requirements_txt, requirements_cfg)
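The new `make_sim` and `get_model_from_model_zoo` helpers are module-level functions in `src/unigradicon/__init__.py`, so the model zoo is also reachable from Python without the CLI. A minimal, hedged sketch follows; it assumes the `unigradicon` package from this branch is installed, and the weights are downloaded on first use exactly as in `get_multigradicon` above.

```python
# Hedged sketch, not part of the patch: exercise the helpers added in
# src/unigradicon/__init__.py from a Python session.
from unigradicon import get_model_from_model_zoo, make_sim

# Mirrors `unigradicon-register --model multigradicon --io_sim lncc2 ...`:
# loads the multiGradICON weights and builds the squared-LNCC similarity
# that would be used during instance optimization.
net = get_model_from_model_zoo("multigradicon", loss_fn=make_sim("lncc2"))
net.eval()  # the loader already calls .eval(); repeated here only for emphasis
```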