diff --git a/.dockerignore b/.dockerignore index 32855116..0a127dc2 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,7 +2,7 @@ **/.github checkpoints Singularity -images +doc venv Tutorial **/*.md diff --git a/.github/ISSUE_TEMPLATE/questions-help-support.md b/.github/ISSUE_TEMPLATE/questions-help-support.md index 1cb25a75..1068b9b2 100644 --- a/.github/ISSUE_TEMPLATE/questions-help-support.md +++ b/.github/ISSUE_TEMPLATE/questions-help-support.md @@ -7,28 +7,24 @@ assignees: '' --- -## Question/Support Request - -... +**IMPORTANT**: Please make sure to fill out the information about your environment (see below). This is often critical information we need to help you. -## Screenshots - -... +## Question/Support Request +A clear and concise description of a question you may have or a problem for which you would like to request support. - +## Screenshots / Log files +Please provide error messages (can be a screenshot), stack traces, log files (specifically `$SUBJECTS_DIR/$SUBJECT_ID/scripts/deep-seg.log` and `$SUBJECTS_DIR/$SUBJECT_ID/scripts/recon-surf.log`) and any snippets useful in describing your problem here. ## Environment - - FastSurfer Version: ... - - FreeSurfer Version: ... - - OS: ... - - CPU: ... - - GPU: ... - + - FastSurfer Version: please run `run_fastsurfer.sh --version all` and copy/attach the resulting output + - Installation type: official docker/custom docker/singularity/native + - FreeSurfer Version: 7.4.1/7.3.2 + - OS: Windows/Linux/macOS + - GPU: none/RTX 2080/... - + +... ### Execution - - - -Run Command: +Include the command you used to run FastSurfer that caused the problem, e.g. +`./run_fastsurfer.sh --sid test --sd /path/to/dir --t1 /path/to/file.nii`. diff --git a/.github/workflows/QUICKTEST.md b/.github/workflows/QUICKTEST.md new file mode 100644 index 00000000..b4db637d --- /dev/null +++ b/.github/workflows/QUICKTEST.md @@ -0,0 +1,58 @@ +# FastSurfer Singularity GitHub Actions Workflow + +This GitHub Actions workflow automates integration testing of new code in the FastSurfer repository using Singularity containers. The workflow is triggered whenever new code is pushed to the repository. + +The workflow runs on a self-hosted runner labeled 'ci-gpu' to ensure security. + +## Jobs + +The workflow consists of several jobs that are executed in sequence: + +### Checkout + +This job checks out the repository using the `actions/checkout@v2` action. + +### Prepare Job + +This job sets up the necessary environments for the workflow. It depends on the successful completion of the `checkout` job. The environments set up in this job include: + +- Python 3.10, using the `actions/setup-python@v3` action. +- Go, using the `actions/setup-go@v5` action with version `1.13.1`. +- Singularity, using the `eWaterCycle/setup-singularity@v7` action with version `3.8.3`. + +### Build Singularity Image + +This job builds a Docker image and converts it to a Singularity image. It depends on the successful completion of the `prepare-job` job. The Docker image is built using the Python script `Docker/build.py` with the `--device cuda --tag fastsurfer_gpu:cuda` flags. The Docker image is then converted to a Singularity image. + +### Run FastSurfer + +This job runs FastSurfer on sample MRI data using the Singularity image built in the previous job. It depends on the successful completion of the `build-singularity-image` job.
The Singularity container is executed with the `--nv`, `--no-home`, and `--bind` flags to enable GPU access, prevent home directory mounting, and bind the necessary directories, respectively. The `FASTSURFER_HOME` environment variable is set to `/fastsurfer-dev` inside the container. + +### Test File Existence + +This job tests for the existence of certain files after running FastSurfer. It depends on the successful completion of the `run-fastsurfer` job. The test is performed using the Python script `test/quick_test/test_file_existence.py`. + +### Test Error Messages + +This job tests for errors in log files after running FastSurfer. It runs on a self-hosted runner labeled `ci-gpu` and depends on the successful completion of both the `run-fastsurfer` and `test-file-existence` jobs. The test is performed using the Python script `test/quick_test/test_errors.py`. + +## Usage + +To use this workflow, you need to have a self-hosted runner labeled `ci-gpu` set up on your machine. You also need to update the runner's environment variables by editing the `/home/your_runner/.env` file and adding the following variables with the actual paths you want to use. + + +### Environment variables +`RUNNER_FS_MRI_DATA`: Path to MRI Data + +`RUNNER_FS_OUTPUT`: Path to Output directory + +`RUNNER_FS_LICENSE`: Path to License directory + +`RUNNER_SINGULARITY_IMGS`: Path to where Singularity images should be stored + +`RUNNER_FS_OUTPUT_FILES`: Path to output files to be tested + +`RUNNER_FS_OUTPUT_LOGS`: Path to output log files to check for errors + + +Once everything is set up, you can trigger the workflow manually from the GitHub Actions tab in your repository, as well as by pushing code to the repository. diff --git a/.github/workflows/code-style.yml b/.github/workflows/code-style.yml new file mode 100644 index 00000000..85e4ec0f --- /dev/null +++ b/.github/workflows/code-style.yml @@ -0,0 +1,40 @@ +name: code-style +concurrency: + group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} + cancel-in-progress: true +on: +# pull_request: +# push: +# branches: [dev] + workflow_dispatch: + +jobs: + style: + timeout-minutes: 10 + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Setup Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + architecture: 'x64' + cache: 'pip' # caching pip dependencies + - name: Install dependencies + run: | + python -m pip install --progress-bar off --upgrade pip setuptools wheel + python -m pip install --progress-bar off .[style] + - name: Run Ruff + run: ruff check . + - name: Run codespell + uses: codespell-project/actions-codespell@master + with: + check_filenames: true + check_hidden: true + skip: './.git,./build,./.mypy_cache,./.pytest_cache' + ignore_words_file: ./.codespellignore + - name: Run pydocstyle + run: pydocstyle .
+ - name: Run bibclean + run: bibclean-check doc/references.bib diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml new file mode 100644 index 00000000..3336a7fc --- /dev/null +++ b/.github/workflows/doc.yml @@ -0,0 +1,68 @@ +name: doc +concurrency: + group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} + cancel-in-progress: true +on: + pull_request: + push: + branches: [dev, stable] + workflow_dispatch: + +jobs: + build: + timeout-minutes: 10 + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: src + - name: Setup Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: '3.10' + architecture: 'x64' + cache: 'pip' # caching pip dependencies + - name: Install package + run: | + python -m pip install --progress-bar off --upgrade pip setuptools wheel + python -m pip install --progress-bar off src/.[doc] + - name: Build doc + run: PYTHONPATH=$PYTHONPATH:src TZ=UTC sphinx-build src/doc doc-build -W --keep-going + - name: Upload documentation + uses: actions/upload-artifact@v4 + with: + name: doc + path: | + doc-build + !doc-build/.doctrees + + deploy: + # only on push to dev or stable + if: ${{ github.event_name == 'push' && contains(fromJSON('["dev", "stable"]'), github.ref_name) }} + needs: build + timeout-minutes: 10 + runs-on: ubuntu-latest + permissions: + contents: write + defaults: + run: + shell: bash + steps: + - name: Download documentation + uses: actions/download-artifact@v4 + with: + name: doc + path: doc + - name: Deploy {dev,stable} documentation + uses: peaceiris/actions-gh-pages@v4 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: doc + # destination_dir: github.ref_name will be dev or stable + destination_dir: ${{ github.ref_name }} + user_name: 'github-actions[bot]' + user_email: 'github-actions[bot]@users.noreply.github.com' diff --git a/.github/workflows/quicktest.yaml b/.github/workflows/quicktest.yaml new file mode 100644 index 00000000..b6d827f3 --- /dev/null +++ b/.github/workflows/quicktest.yaml @@ -0,0 +1,90 @@ +name: FastSurfer Singularity + +on: + workflow_dispatch: + +jobs: + # Checkout repo + checkout: + runs-on: ci-gpu + steps: + - uses: actions/checkout@v2 + + # Prepare job: Set up Python, Go, Singularity + prepare-job: + runs-on: ci-gpu + needs: checkout + steps: + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '^1.13.1' # The Go version to download (if necessary) and use. + - name: Set up Singularity + uses: eWaterCycle/setup-singularity@v7 + with: + singularity-version: 3.8.3 + + # Build Docker Image and convert it to Singularity + build-singularity-image: + runs-on: ci-gpu + needs: prepare-job + steps: + - name: Build Docker Image and convert to Singularity + run: | + cd $RUNNER_SINGULARITY_IMGS + FILE="fastsurfer-gpu.sif" + if [ ! -f "$FILE" ]; then + # If the file does not exist, build the file + echo "SIF File does not exist. Building file." 
+ PYTHONPATH=$PYTHONPATH + cd $PYTHONPATH + python3 Docker/build.py --device cuda --tag fastsurfer_gpu:cuda + cd $RUNNER_SINGULARITY_IMGS + singularity build --force fastsurfer-gpu.sif docker-daemon://fastsurfer_gpu:cuda + else + echo "File already exists" + cd $PYTHONPATH + fi + + # Run FastSurfer on MRI data + run-fastsurfer: + runs-on: ci-gpu + needs: build-singularity-image + steps: + - name: Run FastSurfer + run: | + singularity exec --nv \ + --no-home \ + --bind $GITHUB_WORKSPACE:/fastsurfer-dev \ + --env FASTSURFER_HOME=/fastsurfer-dev \ + -B $RUNNER_FS_MRI_DATA:/data \ + -B $RUNNER_FS_OUTPUT:/output \ + -B $RUNNER_FS_LICENSE:/fs_license \ + $RUNNER_SINGULARITY_IMGS/fastsurfer-gpu.sif \ + /fastsurfer/run_fastsurfer.sh \ + --fs_license /fs_license/.license \ + --t1 /data/subjectx/orig.mgz \ + --sid subjectX --sd /output \ + --parallel --3T + + # Test file existence + test-file-existence: + runs-on: ci-gpu + needs: run-fastsurfer + steps: + - name: Test File Existence + run: | + python3 test/quick_test/test_file_existence.py $RUNNER_FS_OUTPUT_FILES + + # Test for errors in log files + test-error-messages: + runs-on: ci-gpu + needs: [run-fastsurfer, test-file-existence] + steps: + - name: Test Log Files For Error Messages + run: | + python3 test/quick_test/test_errors.py $RUNNER_FS_OUTPUT_LOGS diff --git a/.gitignore b/.gitignore index 6288671e..a41e3f3d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,9 @@ /BUILD.info -/.idea/** \ No newline at end of file +/.idea/** +/rough_work/** + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + diff --git a/CerebNet/README.md b/CerebNet/README.md new file mode 100644 index 00000000..88ed7f98 --- /dev/null +++ b/CerebNet/README.md @@ -0,0 +1,5 @@ +# CerebNet +Deep learning based tool for segmentation of cerebellar sub-regions. + +The training and evaluation scripts of CerebNet are currently not part of the FastSurfer repository and are only available as incompatible stubs from the authors on request via email. +The interface to realistic deformations can be found in :py:`CerebNet.apply_warp`. diff --git a/CerebNet/__init__.py b/CerebNet/__init__.py index e69de29b..67c70352 100644 --- a/CerebNet/__init__.py +++ b/CerebNet/__init__.py @@ -0,0 +1,10 @@ +__all__ = [ + "apply_warp", + "config", + "datasets", + "data_loader", + "inference", + "models", + "run_prediction", + "utils", +] \ No newline at end of file diff --git a/CerebNet/apply_warp.py b/CerebNet/apply_warp.py index 5518b363..fb5fbb85 100644 --- a/CerebNet/apply_warp.py +++ b/CerebNet/apply_warp.py @@ -1,3 +1,4 @@ +import argparse # Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn # @@ -14,20 +15,52 @@ # limitations under the License. # IMPORTS -from os.path import join import numpy as np import nibabel as nib +from os.path import join from CerebNet.datasets import utils def save_nii_image(img_data, save_path, header, affine): + """ + Save an image data array as a NIfTI file. + + Parameters + ---------- + img_data : ndarray + The image data to be saved. + save_path : str + The path (including file name) where the image will be saved. + header : nibabel.Nifti1Header + The header information for the NIfTI file. + affine : ndarray + The affine matrix for the NIfTI file. 
+ """ + img_out = nib.Nifti1Image(img_data, header=header, affine=affine) print(f"Saving {save_path}") nib.save(img_out, save_path) -def store_warped_data(img_path, lbl_path, warp_path, result_path, patch_size): +def main(img_path, lbl_path, warp_path, result_path, patch_size): + + """ + Load, warp, crop, and save both an image and its corresponding label based on a given warp field. + + Parameters + ---------- + img_path : str + Path to the T1-weighted MRI image to be warped. + lbl_path : str + Path to the label image corresponding to the T1 image, to be warped similarly. + warp_path : str + Path to the warp field file used to warp the images. + result_path : str + Directory path where the warped and cropped images will be saved. + patch_size : tuple of int + The dimensions (height, width, depth) cropped images after warping. + """ img, img_file = utils.load_reorient_rescale_image(img_path) @@ -58,8 +91,7 @@ def store_warped_data(img_path, lbl_path, warp_path, result_path, patch_size): affine=lbl_file.affine) -if __name__ == '__main__': - import argparse +def make_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument("--img_path", help="path to T1 image", @@ -75,11 +107,17 @@ def store_warped_data(img_path, lbl_path, warp_path, result_path, patch_size): help="Warp field file", default='1Warp.nii.gz', type=str) + return parser + +if __name__ == '__main__': + parser = make_parser() args = parser.parse_args() - warp_path = join(args.result_path, args.warp_filename) - store_warped_data(args.img_path, - args.lbl_path, - warp_path=warp_path, - result_path=args.result_path, - patch_size=(128, 128, 128)) + warp_path = str(join(args.result_path, args.warp_filename)) + main( + args.img_path, + args.lbl_path, + warp_path=warp_path, + result_path=args.result_path, + patch_size=(128, 128, 128), + ) diff --git a/CerebNet/config/__init__.py b/CerebNet/config/__init__.py index b3d73b6a..926ad07f 100644 --- a/CerebNet/config/__init__.py +++ b/CerebNet/config/__init__.py @@ -15,3 +15,9 @@ # IMPORTS from CerebNet.config.cerebnet import get_cfg_cerebnet from CerebNet.config.dataset import get_cfg_dataset +__all__ = [ + "cerebnet", + "dataset", + "get_cfg_cerebnet", + "get_cfg_dataset", +] \ No newline at end of file diff --git a/CerebNet/config/checkpoint_paths.yaml b/CerebNet/config/checkpoint_paths.yaml new file mode 100644 index 00000000..05db43bf --- /dev/null +++ b/CerebNet/config/checkpoint_paths.yaml @@ -0,0 +1,8 @@ +url: +- "https://zenodo.org/records/10390742/files" +- "https://b2share.fz-juelich.de/api/files/c6cf7bc6-2ae5-4d0e-814d-2a3cf0e1a8c5" + +checkpoint: + axial: "checkpoints/CerebNet_axial_v1.0.0.pkl" + coronal: "checkpoints/CerebNet_coronal_v1.0.0.pkl" + sagittal: "checkpoints/CerebNet_sagittal_v1.0.0.pkl" \ No newline at end of file diff --git a/CerebNet/data_loader/__init__.py b/CerebNet/data_loader/__init__.py index e69de29b..5d73dc8c 100644 --- a/CerebNet/data_loader/__init__.py +++ b/CerebNet/data_loader/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + "augmentation", + "data_utils", + "dataset", + "loader", +] \ No newline at end of file diff --git a/CerebNet/data_loader/augmentation.py b/CerebNet/data_loader/augmentation.py index 579d8d45..c67b2fde 100644 --- a/CerebNet/data_loader/augmentation.py +++ b/CerebNet/data_loader/augmentation.py @@ -28,9 +28,8 @@ from CerebNet.data_loader.data_utils import FLIPPED_LABELS -## + # Transformations for training -## class ToTensor(object): """ Convert ndarrays in sample to Tensors. 
@@ -71,10 +70,18 @@ class RandomAffine(object): """ Apply a random affine transformation to images, label and weight - the transformation includes translation, rotation and scaling + the transformation includes translation, rotation and scaling. """ def __init__(self, cfg): + """ + Create a Affine augmentation operation with Random Parameter initialization. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Parameters degree, img_size, scale, translate, prob are filled from its attributes AUGMENTATION and MODEL. + """ self.degree = cfg.AUGMENTATION.DEGREE self.img_size = [cfg.MODEL.HEIGHT, cfg.MODEL.WIDTH] self.scale = cfg.AUGMENTATION.SCALE @@ -85,32 +92,7 @@ def __init__(self, cfg): def _get_random_affine(self): """ Random inverse affine matrix composed of rotation matrix (of each axis)and translation. - - Parameters - ---------- - degrees : sequence or float or int, - Range of degrees to select from. - If degrees is a number instead of sequence like (min, max), the range of degrees - will be (-degrees, +degrees). - translate : tuple, float - if translate=(a,b), the value for translation is uniformly sampled - in the range - -column_size * a < dx < column_size * a - -row_size * b < dy < row_size * b - If translate is a number then a=b=c. - The value should be between 0 and 1. - img_size : tuple - img_size = (column_size, row_size) - scale: tuple, range of min and max scaling factor - seed : int - random seed - - Returns - ------- - transform_mat : 3x3 matrix - Random affine transformation """ - if isinstance(self.degree, numbers.Number): if self.degree < 0: raise ValueError("If degrees is a single number, it must be positive.") @@ -179,9 +161,8 @@ def __call__(self, sample): class RandomFlip(object): """ - Random horizontal flipping + Random horizontal flipping. """ - def __init__(self, cfg): self.prob = cfg.AUGMENTATION.PROB self.axis = cfg.AUGMENTATION.FLIP_AXIS @@ -202,7 +183,8 @@ def __call__(self, sample): class RandomBiasField: - r"""Add random MRI bias field artifact. + """ + Add random MRI bias field artifact. Based on https://github.com/fepegar/torchio @@ -210,21 +192,30 @@ class RandomBiasField: `Sudre et al., 2017, Longitudinal segmentation of age-related white matter hyperintensities `_. - - Args: - coefficients: Magnitude :math:`n` of polynomial coefficients. - If a tuple :math:`(a, b)` is specified, then - :math:`n \sim \mathcal{U}(a, b)`. - order: Order of the basis polynomial functions. - p: Probability that this transform will be applied. - seed: """ - def __init__( self, cfg, seed: Optional[int] = None, ): + """ + Initialize the RandomBiasField object with configuration and optional seed. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Node to get Config from (should include: + AUGMENTATION.BIAS_FIELD_COEFFICIENTS + Magnitude :math:`n` of polynomial coefficients. + If a tuple :math:`(a, b)` is specified, then + :math:`n \sim \mathcal{U}(a, b)`. + AUGMENTATION.BIAS_FIELD_ORDER + Order of the basis polynomial functions. + AUGMENTATION.PROB + Probability that this transform will be applied. + seed : int, optional + Seed. + """ coefficients = cfg.AUGMENTATION.BIAS_FIELD_COEFFICIENTS if isinstance(coefficients, float): coefficients = (-coefficients, coefficients) @@ -302,9 +293,8 @@ class RandomLabelsToImage(object): using the dataset intensity priors. based on Billot et al.: A Learning Strategy for Contrast-agnostic MRI Segmentation - and Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast. 
+ and Partial Volume Segmentation of Brain MRI Scans of any Resolution and Contrast. """ - def __init__(self, mean, std, cfg, blur_factor=0.3): self.means = mean self.stds = std @@ -335,22 +325,38 @@ def __call__(self, sample): def sample_intensity_stats_from_image( image, segmentation, labels_list, classes_list=None, keep_strictly_positive=True ): - """This function takes an image and corresponding segmentation as inputs. It estimates the mean and std intensity - for all specified label values. Labels can share the same statistics by being regrouped into K classes. - :param image: image from which to evaluate mean intensity and std deviation. - :param segmentation: segmentation of the input image. Must have the same size as image. - :param labels_list: list of labels for which to evaluate mean and std intensity. - Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. - :param classes_list: (optional) enables to regroup structures into classes of similar intensity statistics. - Intenstites associated to regrouped labels will thus contribute to the same Gaussian during statistics estimation. - Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. - It should have the same length as labels_list, and contain values between 0 and K-1, where K is the total number of - classes. Default is all labels have different classes (K=len(labels_list)). - :param keep_strictly_positive: (optional) whether to only keep strictly positive intensity values when - computing stats. This doesn't apply to the first label in label_list (or class if class_list is provided), for - which we keep positive and zero values, as we consider it to be the background label. - :return: a numpy array of size (2, K), the first row being the mean intensity for each structure, - and the second being the median absolute deviation (robust estimation of std). + """ + This function takes an image and corresponding segmentation as inputs. + + It estimates the mean and std intensity for all specified label values. + Labels can share the same statistics by being regrouped into K classes. + + Parameters + ---------- + image : array_like + Image from which to evaluate mean intensity and std deviation. + segmentation : array_like + Segmentation of the input image. Must have the same size as image. + labels_list : array_like + List of labels for which to evaluate mean and std intensity. + Can be a sequence, a 1d numpy array, or the path to a 1d numpy array. + classes_list : array_like, optional + Enables grouping structures into classes of similar intensity statistics. + The intensities associated with regrouped labels will contribute to the same + Gaussian during statistics estimation. Can be a sequence, a 1D numpy array, + or the path to a 1D numpy array. It should have the same length as `labels_list`, + and contain values between 0 and K-1, where K is the total number of classes. + By default, each label is considered its own class (K=len(labels_list)). + keep_strictly_positive : optional + Whether to only keep strictly positive intensity values when computing stats. + This doesn't apply to the first label in label_list (or class if class_list is provided), for + which we keep positive and zero values, as we consider it to be the background label. + + Returns + ------- + numpy.ndarray + A numpy array of size (2, K), the first row being the mean intensity for each structure, + and the second being the median absolute deviation (robust estimation of std). 
""" # reformat labels and classes if classes_list is not None: diff --git a/CerebNet/data_loader/data_utils.py b/CerebNet/data_loader/data_utils.py index f6e418c0..f8cf8915 100644 --- a/CerebNet/data_loader/data_utils.py +++ b/CerebNet/data_loader/data_utils.py @@ -15,13 +15,14 @@ # IMPORTS -from typing import Literal, TypeVar +from typing import TypeVar import numpy as np import torch from numpy import typing as npt -Plane = Literal['axial', 'coronal', 'sagittal'] +from FastSurferCNN.utils import Plane + AT = TypeVar('AT', np.ndarray, torch.Tensor) # CLASSES for final evaluation @@ -145,18 +146,26 @@ 12, 13, 12, 14, 15, 14, 16, 16 - ]) -} + ])} # Transformation for mapping def transform_axial(vol, coronal2axial=True): """ - Function to transform volume into Axial axis and back - :param np.ndarray vol: image volume to transform - :param bool coronal2axial: transform from coronal to axial = True (default), - transform from axial to coronal = False - :return: + Function to transform volume into Axial axis and back. + + Parameters + ---------- + vol : np.ndarray + Image volume to transform. + coronal2axial : bool, default = True + If True (default), transforms from coronal to axial. + If False, transforms from axial to coronal. + + Returns + ------- + np.ndarray + Transformed image volume. """ if coronal2axial: return np.moveaxis(vol, [0, 1, 2, 3], [0, 2, 3, 1]) @@ -166,11 +175,20 @@ def transform_axial(vol, coronal2axial=True): def transform_sagittal(vol, coronal2sagittal=True): """ - Function to transform volume into Sagittal axis and back - :param np.ndarray vol: image volume to transform - :param bool coronal2sagittal: transform from coronal to sagittal = True (default), - transform from sagittal to coronal = False - :return: + Transform a volume into the Sagittal axis and back. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + coronal2sagittal : bool, default = True + If True (default), transforms from coronal to sagittal. + If False, transforms from sagittal to coronal. + + Returns + ------- + np.ndarray + The transformed image volume. """ if coronal2sagittal: return np.moveaxis(vol, [0, 1, 2, 3], [0, 3, 2, 1]) @@ -180,11 +198,20 @@ def transform_sagittal(vol, coronal2sagittal=True): def transform_coronal(vol, axial2coronal=True): """ - Function to transform volume into coronal axis and back - :param np.ndarray vol: image volume to transform - :param bool axial2coronal: transform from axial to coronal = True (default), - transform from coronal to axial = False - :return: + Transform a volume into the coronal axis and back. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + axial2coronal : bool, default=True + If True (default), transforms from axial to coronal. + If False, transforms from coronal to axial. + + Returns + ------- + np.ndarray + The transformed image volume. """ if axial2coronal: if len(vol.shape) == 4: @@ -200,11 +227,20 @@ def transform_coronal(vol, axial2coronal=True): def transform_axial2sagittal(vol, axial2sagittal=True): """ - Function to transform volume into Sagittal axis and back - :param np.ndarray vol: image volume to transform - :param bool coronal2sagittal: transform from coronal to sagittal = True (default), - transform from sagittal to coronal = False - :return: + Transform a volume into the Sagittal axis and back. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + axial2sagittal : bool, default=True + If True (default), transforms from axial to sagittal. 
+ If False, transforms from sagittal to axial. + + Returns + ------- + np.ndarray + The transformed image volume. """ if axial2sagittal: if len(vol.shape) == 4: @@ -238,12 +274,18 @@ def get_plane_transform(plane, primary_slice_dir='coronal'): def filter_blank_slices_thick(data_dict, img_key="img", lbl_key="label", threshold=10): """ - Function to filter blank slices from the volume using the label volume - :param dict data_dict: dictionary containing all volumes need to be filtered - :param img_key - :param lbl_key - :param threshold - :return: + Function to filter blank slices from the volume using the label volume. + + Parameters + ---------- + data_dict : dict + A dictionary containing all volumes that need to be filtered. + img_key : str, default="img" + Name of the key with the image. + lbl_key : str, default="label" + Name of the key with the target label. + threshold : int, default=10 + Threshold for number of voxels so this slice is included (or filtered). """ # Get indices of all slices with more than threshold labels/pixels selected_slices = (np.sum(data_dict[lbl_key], axis=(1, 2)) > threshold) @@ -253,11 +295,16 @@ def filter_blank_slices_thick(data_dict, img_key="img", lbl_key="label", thresho def create_weight_mask2d(label_map, class_wise_weights, max_edge_weight=5): """ - Function to create weighted mask - with median frequency balancing and edge-weighting - :param label_map: - :param class_wise_weights: - :param max_edge_weight: - :return: + Function to create weighted mask - with median frequency balancing and edge-weighting. + + Parameters + ---------- + label_map : np.ndarray + A 2D array representing the label map. + class_wise_weights : np.ndarray + A 1D array where each element is the weight corresponding to a class in the label map. + max_edge_weight : float, default=5 + The maximum weight to be applied at the edges in the label map to emphasize boundaries. """ (h, w) = label_map.shape weights_mask = np.reshape(class_wise_weights[label_map.ravel()], (h, w)) @@ -273,13 +320,19 @@ def create_weight_mask2d(label_map, class_wise_weights, max_edge_weight=5): def map_sag2label(lbl_data, label_type='cereb_subseg'): """ - Mapping right ids to left and relabeling - Args: - lbl_data: - label_type: - - Returns: - + Mapping right ids to left and relabeling. + + Parameters + ---------- + lbl_data : np.ndarray + An array of label data. + label_type : str, default="cereb_subseg" + A string identifier for the type of labels to map to. + + Returns + ------- + np.ndarray + The remapped label array with continuous labels. """ for r_lbl, l_lbl in sag_right2left.items(): lbl_data[lbl_data == r_lbl] = l_lbl @@ -296,19 +349,40 @@ def map_sag2label(lbl_data, label_type='cereb_subseg'): def map_prediction_sagittal2full(prediction_sag, lbl_type): """ - Function to remap the prediction on the sagittal network to full label space used by coronal and axial networks - :param prediction_sag: sagittal prediction (labels) - :param lbl_type: type of label - :return: Remapped prediction + Function to remap the prediction on the sagittal network to + full label space used by coronal and axial networks. + + Parameters + ---------- + prediction_sag : np.ndarray + Sagittal prediction (labels). + lbl_type : str + Type of label. + + Returns + ------- + np.ndarray + Remapped prediction. 
""" - idx_list = SAG2FULL_MAP[lbl_type] prediction_full = prediction_sag[:, idx_list, :, :] return prediction_full def get_aseg_cereb_mask(aseg_map: npt.NDArray[int]) -> npt.NDArray[bool]: - """Get a boolean mask of the cerebellum from a segmentation image.""" + """ + Get a boolean mask of the cerebellum from a segmentation image. + + Parameters + ---------- + aseg_map : np.ndarray + A segmentation image. + + Returns + ------- + np.ndarray + A boolean mask of the cerebellum. + """ wm_cereb_mask = np.logical_or(aseg_map == 46, aseg_map == 7) gm_cereb_mask = np.logical_or(aseg_map == 47, aseg_map == 8) return np.logical_or(wm_cereb_mask, gm_cereb_mask) @@ -333,16 +407,23 @@ def get_binary_map(lbl_map, class_names): def slice_lia2ras(plane: Plane, data: AT, /, thick_slices: bool = False) -> AT: - """Maps the data from LIA to RAS orientation. - - Args: - plane: the slicing direction (usually moved into batch dimension) - data: the data array of shape [plane, Channels, H, W] - thick_slices: whether the channels are thick slices and should also be flipped (default: False). - - Returns: - data reoriented from LIA to RAS of [plane, Channels, H, W] (plane: 'sagittal' or 'coronal') or - [plane, Channels, W, H] (plane: 'axial'). + """ + Maps the data from LIA to RAS orientation. + + Parameters + ---------- + plane : Plane + The slicing direction (usually moved into batch dimension). + data : np.ndarray + The data array of shape [plane, Channels, H, W]. + thick_slices : bool, default = False + Whether the channels are thick slices and should also be flipped. + + Returns + ------- + np.ndarray + Data reoriented from LIA to RAS of [plane, Channels, H, W] (plane: 'sagittal' or 'coronal') or + [plane, Channels, W, H] (plane: 'axial'). """ if isinstance(data, np.ndarray): flip, swapaxes = np.flip, np.swapaxes @@ -362,16 +443,23 @@ def slice_lia2ras(plane: Plane, data: AT, /, thick_slices: bool = False) -> AT: def slice_ras2lia(plane: Plane, data: AT, /, thick_slices: bool = False) -> AT: - """Maps the data from RAS to LIA orientation. - - Args: - plane: the slicing direction (usually moved into batch dimension) - data: the data array of shape [plane, Channels, H, W] - thick_slices: whether the channels are thick slices and should also be flipped (default: False). - - Returns: - data reoriented from RAS to LIA of [plane, Channels, H, W] (plane: 'sagittal' or 'coronal') or - [plane, Channels, W, H] (plane: 'axial'). + """ + Maps the data from RAS to LIA orientation. + + Parameters + ---------- + plane : Plane + The slicing direction (usually moved into batch dimension). + data : np.ndarray + The data array of shape [plane, Channels, H, W]. + thick_slices : bool, default=False + Whether the channels are thick slices and should also be flipped. + + Returns + ------- + np.ndarray + Data reoriented from RAS to LIA of [plane, Channels, H, W] (plane: 'sagittal' or 'coronal') or + [plane, Channels, W, H] (plane: 'axial'). The dtype of the array is the same as data. """ if isinstance(data, np.ndarray): flip, swapaxes = np.flip, np.swapaxes diff --git a/CerebNet/data_loader/dataset.py b/CerebNet/data_loader/dataset.py index 8fff9984..f83a61af 100644 --- a/CerebNet/data_loader/dataset.py +++ b/CerebNet/data_loader/dataset.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import time # IMPORTS -from typing import Sequence, Tuple, Literal, get_args as _get_args, TypeVar, Dict +from typing import Tuple, Literal, TypeVar, Dict from numbers import Number import nibabel as nib @@ -25,8 +24,7 @@ from torch.utils.data.dataset import Dataset from torchvision.transforms import Compose -from CerebNet.data_loader.data_utils import Plane -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import logging, Plane from FastSurferCNN.data_loader.data_utils import ( get_thick_slices, transform_axial, @@ -42,7 +40,6 @@ LocalizerROI = Dict[ROIKeys, Tuple[int, ...]] NT = TypeVar("NT", bound=Number) -PLANES = _get_args(Plane) logger = logging.get_logger(__name__) diff --git a/CerebNet/datasets/__init__.py b/CerebNet/datasets/__init__.py index e69de29b..191894fe 100644 --- a/CerebNet/datasets/__init__.py +++ b/CerebNet/datasets/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + "generate_hdf5", + "load_data", + "utils", + "wm_merge_clean", +] \ No newline at end of file diff --git a/CerebNet/datasets/load_data.py b/CerebNet/datasets/load_data.py index a0329ab3..500362fb 100644 --- a/CerebNet/datasets/load_data.py +++ b/CerebNet/datasets/load_data.py @@ -24,6 +24,9 @@ class SubjectLoader: + """ + Subject loader class. + """ def __init__(self, cfg, aux_subjects_files=None): self.cfg = cfg self.patch_size = cfg.PATCH_SIZE @@ -37,22 +40,52 @@ def _process_segm_volumes( self, seg_map, label_map_func, - plane_transform=None, - ): + plane_transform=None,): """ - - :param seg_map: - :param plane_transform: - :param label_map_func: - :return: + Process segmentation volumes. + + Parameters + ---------- + seg_map : np.ndarray + The segmentation map to be processed. + label_map_func : function + A function to map labels in the segmentation map. + plane_transform : function, optional + A function to transform the segmentation map in plane. Defaults to None. + + Returns + ------- + np.ndarray + The processed segmentation map. """ - mapped_seg = label_map_func(seg_map) if plane_transform is not None: mapped_seg = plane_transform(mapped_seg) return mapped_seg def _load_volumes(self, subject_path, store_talairach=False): + """ + Loads the original image and cerebellum sub-segmentation from the given subject path. + Also loads the Talairach coordinates if store_talairach is set to True. + + Parameters + ---------- + subject_path : str + The path to the subject's data directory. + store_talairach : bool, default=False + If True, the method will attempt to load the Talairach coordinates. Defaults to False. + + Returns + ------- + orig : np.ndarray + The original image. + cereb_subseg : np.ndarray + The cerebellum sub-segmentation loaded from the subject's data directory. + img_meta_data : dict + Dictionary containing the affine transformation and header from cereb_subseg file. + If store_talairach is True and Talairach coordinates file exists, also contains the + Talairach coordinates. + """ orig_path = join(subject_path, self.cfg.IMAGE_NAME) subseg_path = join(subject_path, self.cfg.CEREB_SUBSEG_NAME) @@ -163,11 +196,21 @@ def _load_auxiliary_data(self, aux_subjects_path): def load_subject(self, current_subject, store_talairach=False, load_aux_data=False): """ - Loads and process the subject and return data in a dictionary - :param current_subject: subject ID - :param load_aux_data: to load auxiliary data or not - :return: - dictionary of processed data + Loads and processes the subject and returns data in a dictionary. + + Parameters + ---------- + current_subject : str + Subject ID. 
+ store_talairach : bool, optional + Whether to store Talairach coordinates. Defaults to False. + load_aux_data : bool, optional + Whether to load auxiliary data. Defaults to False. + + Returns + ------- + dict + Dictionary of processed data. """ in_data = {} subject_path = join(self.cfg.DATA_DIR, current_subject) diff --git a/CerebNet/datasets/utils.py b/CerebNet/datasets/utils.py index ea01254b..b5a93e05 100644 --- a/CerebNet/datasets/utils.py +++ b/CerebNet/datasets/utils.py @@ -14,15 +14,16 @@ # IMPORTS -from typing import Tuple, Union, Sequence, Optional, TypeVar +from typing import Tuple, Union, Sequence, Optional, TypeVar, TypedDict, Iterable, Type +from pathlib import Path import nibabel as nib import numpy as np +from numpy import typing as npt import torch from FastSurferCNN.data_loader.conform import getscale, scalecrop -# class names for network training and validation/testing CLASS_NAMES = { "Background": 0, "Left_I_IV": 1, @@ -54,11 +55,38 @@ "Right_Corpus_Medullare": 38, } +# class names for network training and validation/testing subseg_labels = {"cereb_subseg": np.array(list(CLASS_NAMES.values()))} AT = TypeVar("AT", np.ndarray, torch.Tensor) +class LTADict(TypedDict): + type: int + nxforms: int + mean: list[float] + sigma: float + lta: npt.NDArray[float] + src_valid: int + src_filename: str + src_volume: list[int] + src_voxelsize: list[float] + src_xras: list[float] + src_yras: list[float] + src_zras: list[float] + src_cras: list[float] + dst_valid: int + dst_filename: str + dst_volume: list[int] + dst_voxelsize: list[float] + dst_xras: list[float] + dst_yras: list[float] + dst_zras: list[float] + dst_cras: list[float] + src: npt.NDArray[float] + dst: npt.NDArray[float] + + def define_size(mov_dim, ref_dim): new_dim = np.zeros(len(mov_dim), dtype=int) borders = np.zeros((len(mov_dim), 2), dtype=int) @@ -167,7 +195,7 @@ def bounding_volume_offset( if isinstance(img, np.ndarray): from FastSurferCNN.data_loader.data_utils import bbox_3d - bbox = bbox_3d(img != 0) + bbox = bbox_3d(np.not_equal(img, 0)) bbox = bbox[::2] + bbox[1::2] else: bbox = img @@ -325,237 +353,78 @@ def apply_warp_field(dform_field, img, interpol_order=3): return deformed_img -def readLTA(file): +def read_lta(file: Path | str) -> LTADict: + """Read the LTA info.""" import re + from functools import partial import numpy as np + parameter_pattern = re.compile("^\s*([^=]+)\s*=\s*([^#]*)\s*(#.*)") + vol_info_pattern = re.compile("^(.*) volume info$") + shape_pattern = re.compile("^(\s*\d+)+$") + matrix_pattern = re.compile("^(-?\d+\.\S+\s+)+$") + + _Type = TypeVar("_Type", bound=Type) + + def _vector(_a: str, dtype: Type[_Type] = float, count: int = -1) -> list[_Type]: + return np.fromstring(_a, dtype=dtype, count=count, sep=" ").tolist() + + parameters = { + "type": int, + "nxforms": int, + "mean": partial(_vector, dtype=float, count=3), + "sigma": float, + "subject": str, + "fscale": float, + } + vol_info_par = { + "valid": int, + "filename": str, + "volume": partial(_vector, dtype=int, count=3), + "voxelsize": partial(_vector, dtype=float, count=3), + **{f"{c}ras": partial(_vector, dtype=float) for c in "xyzc"} + } with open(file, "r") as f: - lta = f.readlines() - d = dict() - i = 0 - while i < len(lta): - if re.match("type", lta[i]) is not None: - d["type"] = int( - re.sub("=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))).strip() - ) - i += 1 - elif re.match("nxforms", lta[i]) is not None: - d["nxforms"] = int( - re.sub("=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))).strip() - ) - i += 
1 - elif re.match("mean", lta[i]) is not None: - d["mean"] = [ - float(x) - for x in re.split( - " +", - re.sub( - "=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i])) - ).strip(), - ) - ] - i += 1 - elif re.match("sigma", lta[i]) is not None: - d["sigma"] = float( - re.sub("=", "", re.sub("[a-z]+", "", re.sub("#.*", "", lta[i]))).strip() - ) - i += 1 - elif ( - re.match( - "-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+", lta[i] - ) - is not None - ): - d["lta"] = np.array( - [ - [ - float(x) - for x in re.split( - " +", - re.match( - "-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+", - lta[i], - ).string.strip(), - ) - ], - [ - float(x) - for x in re.split( - " +", - re.match( - "-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+", - lta[i + 1], - ).string.strip(), - ) - ], - [ - float(x) - for x in re.split( - " +", - re.match( - "-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+", - lta[i + 2], - ).string.strip(), - ) - ], - [ - float(x) - for x in re.split( - " +", - re.match( - "-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+-*[0-9]\.\S+\W+", - lta[i + 3], - ).string.strip(), - ) - ], - ] - ) - i += 4 - elif re.match("src volume info", lta[i]) is not None: - while i < len(lta) and re.match("dst volume info", lta[i]) is None: - if re.match("valid", lta[i]) is not None: - d["src_valid"] = int( - re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - elif re.match("filename", lta[i]) is not None: - d["src_filename"] = re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - elif re.match("volume", lta[i]) is not None: - d["src_volume"] = [ - int(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("voxelsize", lta[i]) is not None: - d["src_voxelsize"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("xras", lta[i]) is not None: - d["src_xras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("yras", lta[i]) is not None: - d["src_yras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("zras", lta[i]) is not None: - d["src_zras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("cras", lta[i]) is not None: - d["src_cras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - i += 1 - elif re.match("dst volume info", lta[i]) is not None: - while i < len(lta) and re.match("src volume info", lta[i]) is None: - if re.match("valid", lta[i]) is not None: - d["dst_valid"] = int( - re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - elif re.match("filename", lta[i]) is not None: - d["dst_filename"] = re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - elif re.match("volume", lta[i]) is not None: - d["dst_volume"] = [ - int(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("voxelsize", lta[i]) is not None: - d["dst_voxelsize"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("xras", lta[i]) is not None: - d["dst_xras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("yras", lta[i]) is not None: - 
d["dst_yras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("zras", lta[i]) is not None: - d["dst_zras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - elif re.match("cras", lta[i]) is not None: - d["dst_cras"] = [ - float(x) - for x in re.split( - " +", re.sub(".*=", "", re.sub("#.*", "", lta[i])).strip() - ) - ] - i += 1 - else: - i += 1 - # create full transformation matrices - d["src"] = np.concatenate( - ( - np.concatenate( - ( - np.c_[d["src_xras"]], - np.c_[d["src_yras"]], - np.c_[d["src_zras"]], - np.c_[d["src_cras"]], - ), - axis=1, - ), - np.array([0.0, 0.0, 0.0, 1.0], ndmin=2), - ), - axis=0, - ) - d["dst"] = np.concatenate( - ( - np.concatenate( - ( - np.c_[d["dst_xras"]], - np.c_[d["dst_yras"]], - np.c_[d["dst_zras"]], - np.c_[d["dst_cras"]], - ), - axis=1, - ), - np.array([0.0, 0.0, 0.0, 1.0], ndmin=2), - ), - axis=0, - ) - # return - return d + lines = f.readlines() + + items = [] + shape_lines = [] + matrix_lines = [] + section = "" + for i, line in enumerate(lines): + if line.strip() == "": + continue + if hits := parameter_pattern.match(line): + name = hits.group(1) + if section and name in vol_info_par: + items.append((f"{section}_{name}", vol_info_par[name](hits.group(2)))) + elif name in parameters: + section = "" + items.append((name, parameters[name](hits.group(2)))) + else: + raise NotImplementedError(f"Unrecognized type string in lta-file " + f"{file}:{i+1}: '{name}'") + elif hits := vol_info_pattern.match(line): + section = hits.group(1) + # not a parameter line + elif shape_pattern.search(line): + shape_lines.append(np.fromstring(line, dtype=int, count=-1, sep=" ")) + elif matrix_pattern.search(line): + matrix_lines.append(np.fromstring(line, dtype=float, count=-1, sep=" ")) + + shape_lines = list(map(tuple, shape_lines)) + lta = dict(items) + if lta["nxforms"] != len(shape_lines): + raise IOError("Inconsistent lta format: nxforms inconsistent with shapes.") + if len(shape_lines) > 1 and np.any(np.not_equal([shape_lines[0]], shape_lines[1:])): + raise IOError(f"Inconsistent lta format: shapes inconsistent {shape_lines}") + lta_matrix = np.asarray(matrix_lines).reshape((-1,) + shape_lines[0].shape) + lta["lta"] = lta_matrix + return lta def load_talairach_coordinates(tala_path, img_shape, vox2ras): - tala_lta = readLTA(tala_path) + tala_lta = read_lta(tala_path) # create image grid p x, y, z = np.meshgrid( np.arange(img_shape[0]), @@ -567,7 +436,7 @@ def load_talairach_coordinates(tala_path, img_shape, vox2ras): p1 = np.concatenate((p, np.ones((p.shape[0], 1))), axis=1) assert tala_lta["type"] == 1, "talairach not in ras2ras" # ras2ras - m = np.matmul(tala_lta["lta"], vox2ras) + m = np.matmul(tala_lta["lta"][0, 0], vox2ras) tala_coordinates = np.matmul(m, p1.transpose()).transpose() tala_coordinates = tala_coordinates[:, :-1] @@ -587,7 +456,25 @@ def normalize_array(arr): def _crop_transform_make_indices(image_shape, offsets, target_shape): - """Create the indexing tuple. Returned pad tuples are for the last N dimensions.""" + """ + Create the indexing tuple and return padding tuples for the last N dimensions. + + Parameters + ---------- + image_shape : np.ndarray + The shape of the image from which a region is to be cropped. + offsets : Sequence[int] + Exact location within the image from which the cropping should start. + target_shape : Sequence[int], optional + The desired shape of the cropped region. 
+ + Returns + ------- + paddings: list of 2-tuples of paddings or None + A list of per-axis tuples of the padding to apply to the slice to get the target_shape. + indices : tuple of indices + A tuple of per-axis indices to index in the data to get the target_shape. + """ if len(offsets) != len(target_shape): raise ValueError( f"offsets {offsets} and target shape {target_shape} must be same length." @@ -611,7 +498,21 @@ def _crop_transform_make_indices(image_shape, offsets, target_shape): def _crop_transform_pad_fn(image, pad_tuples, pad): - """Generate a parameterized pad function.""" + """ + Generate a parameterized pad function. + + Parameters + ---------- + image : np.ndarray, torch.Tensor + Input image. + pad_tuples : List[Tuple[int, int]] + List of padding tuples for each axis. + + Returns + ------- + partial + A partial function to pad the image. + """ if all(p1 == 0 and p2 == 0 for p1, p2 in pad_tuples): return None @@ -642,35 +543,63 @@ def _crop_transform_pad_fn(image, pad_tuples, pad): def crop_transform( - image: AT, offsets=None, target_shape=None, out: Optional[AT] = None, pad=0 -): + image: AT, + offsets: Optional[Sequence[int]] = None, + target_shape: Optional[Sequence[int]] = None, + out: Optional[AT] = None, + pad: int = 0, +) -> AT: """ - Perform a crop transform of the last N dimensions on the image data. Cropping does not interpolate the image, but - "just removes" border pixels/voxels. Negative offsets lead to padding. - - Args: - image: image of size [..., D_1, D_2, ..., D_N], where D_1, D_2, ..., D_N are the N image dimensions. - offsets: offset of the cropped region for the last N dimensions (default: center crop with less crop/pad - towards index 0). - target_shape: if defined, target_shape specifies the target shape of the "cropped region", else the crop - will be centered cropping offset[dim] voxels on each side (then the shape is derived by subtracting 2x - the dimension-specific offset). target_shape should have the same number of elements as offsets. - May be implicitly defined by out. - out: Array to store the cropped image in (optional), can be a view on image for memory-efficiency. - pad: padding strategy to use when padding is required (default: zero-pad). - - Notes: - Either offsets, target_shape or out must be defined. - - Raises: - ValueError: If neither offsets nor target_shape nor out are defined. - ValueError: If out is not target_shape. - TypeError: If the type of image is not an np.ndarray or a torch.Tensor. - RuntimeError: If the dimensionality of image, out, offset or target_shape is invalid or inconsistent. - - Returns: - The image (stack) cropped in the last N dimensions by offsets to the shape target_shape, or if target_shape is - not given image.shape[i+2] - 2*offset[i]. + Perform a crop transform of the last N dimensions on the image data. + Cropping does not interpolate the image, but "just removes" border pixels/voxels. + Negative offsets lead to padding. + + Parameters + ---------- + image : np.ndarray, torch.Tensor + Image of size [..., D_1, D_2, ..., D_N], where D_1, D_2, ..., D_N are the N + image dimensions. + offsets : Sequence[int], optional + Offset of the cropped region for the last N dimensions (default: center crop + with less crop/pad towards index 0). + target_shape : Sequence[int], optional + If defined, target_shape specifies the target shape of the "cropped region", + else the crop will be centered cropping offset[dim] voxels on each side (then + the shape is derived by subtracting 2x the dimension-specific offset). 
+ target_shape should have the same number of elements as offsets. + May be implicitly defined by out. + out : np.ndarray, torch.Tensor, optional + Array to store the cropped image in (optional), can be a view on image for + memory-efficiency. + pad : int, str, default=0/zero-pad + Padding strategy to use when padding is required, if int, pad with that value. + + Returns + ------- + out : np.ndarray, torch.Tensor + The image (stack) cropped in the last N dimensions by offsets to the shape + target_shape, or if target_shape is not given image.shape[i+2] - 2*offset[i]. + + Raises + ------ + ValueError + If neither offsets nor target_shape nor out are defined. + ValueError + If out is not target_shape. + TypeError + If the type of image is not an np.ndarray or a torch.Tensor. + RuntimeError + If the dimensionality of image, out, offset or target_shape is invalid or + inconsistent. + + See Also + -------- + numpy.pad + For additional information refer to numpy.pad function. + + Notes + ----- + Either offsets, target_shape or out must be defined. """ if target_shape is None and out is not None: target_shape = out.shape @@ -688,11 +617,11 @@ def crop_transform( _target_shape = image.shape[:-len_off] + tuple( i - 2 * o for i, o in zip(image.shape[-len_off:], offsets) ) + elif len_off != len(target_shape): + raise ValueError( + "Incompatible offset and target_shape dimensionality (at least once)." + ) else: - if len_off != len(target_shape): - raise ValueError( - "Incompatible offset and target_shape dimensionality (at least once)." - ) _target_shape = tuple( i if t == -1 else t for i, t in zip(image.shape[-len_off:], target_shape) diff --git a/CerebNet/datasets/wm_merge_clean.py b/CerebNet/datasets/wm_merge_clean.py index 89ce34dd..4024e78a 100644 --- a/CerebNet/datasets/wm_merge_clean.py +++ b/CerebNet/datasets/wm_merge_clean.py @@ -30,6 +30,9 @@ def locating_unknowns(gm_binary, wm_mask): + """ + Find labels with missing labels, i.e. find holes. + """ selem = ndimage.generate_binary_structure(3, 3) wm_binary = np.array(wm_mask, dtype=np.bool) # gm_binary = (segmap != 0) ^ wm_binary @@ -41,8 +44,7 @@ def locating_unknowns(gm_binary, wm_mask): def drop_disconnected_component( - img_data: npt.NDArray[NT], classes: Iterable[NT] -) -> npt.NDArray[NT]: + img_data: npt.NDArray[NT], classes: Iterable[NT]) -> npt.NDArray[NT]: """ Dropping the smaller disconnected component of each label. """ @@ -64,7 +66,9 @@ def drop_disconnected_component( def filling_unknown_labels(segmap, unknown_mask, candidate_lbls): - + """ + For each unknown voxel in unknown_mask, find and fill it with a candidate. + """ h, w, d = segmap.shape blur_vals = np.ndarray((h, w, d, 0), dtype=np.float) for lbl in candidate_lbls: @@ -80,6 +84,9 @@ def filling_unknown_labels(segmap, unknown_mask, candidate_lbls): def cereb_subseg_lateral_mask(cereb_subseg): + """ + Create mask for left and right cerebellar gray matter. + """ left_gm_idxs = np.array([1, 3, 5, 8, 11, 14, 17, 20, 23, 26]) right_gm_idxs = np.array([2, 4, 7, 10, 13, 16, 19, 22, 25, 28]) @@ -95,6 +102,9 @@ def cereb_subseg_lateral_mask(cereb_subseg): def sphere(radius): + """ + Create a spherical binary mask. + """ shape = (2 * radius + 1,) * 3 struct = np.zeros(shape) x, y, z = np.indices(shape) @@ -260,6 +270,10 @@ def add_cereb_wm(cereb_subseg, aseg, manual_cereb): def correct_cereb_brainstem(cereb_subseg, brainstem, manual_cereb): + """ + Correct brainstem or cereb_subseg according to the + other (select which to correct by manual_cereb). 
+ """ if manual_cereb: print("Correcting brainstem according to cerebellum dzne_manual subseg.") # mapping the overlapping part to dzne_manual labels @@ -272,6 +286,9 @@ def correct_cereb_brainstem(cereb_subseg, brainstem, manual_cereb): def save_mgh_image(img_data, save_path, header, affine): + """ + Save data as mgh image. + """ mgh_out = nib.MGHImage(img_data, header=header, affine=affine) print(f"Saving {save_path}") nib.save(mgh_out, save_path) diff --git a/CerebNet/inference.py b/CerebNet/inference.py index b470d630..d597cd95 100644 --- a/CerebNet/inference.py +++ b/CerebNet/inference.py @@ -14,9 +14,8 @@ # IMPORTS import time -from os import makedirs -from os.path import join, dirname, isfile -from typing import Dict, List, Tuple, Optional +from pathlib import Path +from typing import Dict, List, Tuple, Optional, TYPE_CHECKING from concurrent.futures import Future, ThreadPoolExecutor import nibabel as nib @@ -25,50 +24,68 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import logging, Plane, PLANES from FastSurferCNN.utils.threads import get_num_threads -from FastSurferCNN.utils.mapper import JsonColorLookupTable, TSVLookupTable +from FastSurferCNN.utils.mapper import JsonColorLookupTable, TSVLookupTable, Mapper from FastSurferCNN.utils.common import ( find_device, SubjectList, SubjectDirectory, - NoParallelExecutor, + SerialExecutor, ) from CerebNet.data_loader.augmentation import ToTensorTest -from CerebNet.data_loader.dataset import SubjectDataset, Plane, PLANES +from CerebNet.data_loader.dataset import SubjectDataset from CerebNet.datasets.utils import crop_transform from CerebNet.models.networks import build_model from CerebNet.utils import checkpoint as cp +if TYPE_CHECKING: + import yacs.config + logger = logging.get_logger(__name__) class Inference: + """ + Manages inference operations, including batch processing, data loading, and model + predictions for neuroimaging data. + """ + + cerebnet_labels: Mapper[str, int] + cereb_name2fs_id: Mapper[str, int] + freesurfer_name2id: Mapper[str, int] + def __init__( self, - cfg: "yacs.ConfigNode", + cfg: "yacs.config.CfgNode", threads: int = -1, async_io: bool = False, device: str = "auto", viewagg_device: str = "auto", ): """ - Create the inference object to manage inferencing, batch processing, data loading, etc. - - Args: - cfg: yaml configuration to populate default values for parameters - threads: number of threads to use, -1 is max (all), which is also the default. - async_io: whether io is run asynchronously (default: False) - device: device to perform inference on (default: auto) - viewagg_device: device to aggregate views on (default: auto) + Create the inference object to manage inferencing, batch processing, data + loading, etc. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Yaml configuration to populate default values for parameters. + threads : int, optional + Number of threads to use, -1 is max (all), which is also the default. + async_io : bool, default=False + Whether io is run asynchronously. + device : str, default="auto" + Device to perform inference on. + viewagg_device : str, default="auto" + Device to aggregate views on. 
""" self.pool = None self._threads = None self.threads = threads - torch.set_num_threads(get_num_threads() if self._threads is None else self._threads) - self.pool = ( - ThreadPoolExecutor(self._threads) if async_io else NoParallelExecutor() - ) + _threads = get_num_threads() if self._threads is None else self._threads + torch.set_num_threads(_threads) + self.pool = ThreadPoolExecutor(self._threads) if async_io else SerialExecutor() self.cfg = cfg self._async_io = async_io @@ -81,58 +98,52 @@ def __init__( _viewagg_device = torch.device("cpu") else: _viewagg_device = find_device( - viewagg_device, flag_name="viewagg_device", min_memory=2 * (2**30) + viewagg_device, + flag_name="viewagg_device", + min_memory=2 * (2**30), ) self.batch_size = cfg.TEST.BATCH_SIZE - cerebnet_labels_file = join( - cp.FASTSURFER_ROOT, "CerebNet", "config", "CerebNet_ColorLUT.tsv" - ) - _cerebnet_mapper = self.pool.submit( - TSVLookupTable, cerebnet_labels_file, header=True - ) + _models = self._load_model(cfg) + self.device = _device + self.viewagg_device = _viewagg_device - self.freesurfer_color_lut_file = join( - cp.FASTSURFER_ROOT, "FastSurferCNN", "config", "FreeSurferColorLUT.txt" - ) - fs_color_map = self.pool.submit( - TSVLookupTable, self.freesurfer_color_lut_file, header=False - ) - cerebnet2sagittal_lut = join( - cp.FASTSURFER_ROOT, "CerebNet", "config", "CerebNet2Sagittal.json" - ) - cereb2cereb_sagittal = self.pool.submit( - JsonColorLookupTable, cerebnet2sagittal_lut - ) + def prep_lut( + file: Path, *args, **kwargs, + ) -> Future[TSVLookupTable | JsonColorLookupTable]: + _cls = TSVLookupTable + cls = {".json": JsonColorLookupTable, ".txt": _cls, ".tsv": _cls} + return self.pool.submit(cls[file.suffix], file, *args, **kwargs) - cerebnet2freesurfer_lut = join( - cp.FASTSURFER_ROOT, "CerebNet", "config", "CerebNet2FreeSurfer.json" - ) - cereb2freesurfer = self.pool.submit( - JsonColorLookupTable, cerebnet2freesurfer_lut - ) + def lut_path(module: str, file: str) -> Path: + return cp.FASTSURFER_ROOT / module / "config" / file - _models = self._load_model(cfg) + cerebnet_labels_file = lut_path("CerebNet", "CerebNet_ColorLUT.tsv") + _cerebnet_mapper = prep_lut(cerebnet_labels_file, header=True) - self.device = _device - self.viewagg_device = _viewagg_device + self.freesurfer_lut_file = lut_path("FastSurferCNN", "FreeSurferColorLUT.txt") + fs_color_map = prep_lut(self.freesurfer_lut_file, header=False) - self.cerebnet_labels = _cerebnet_mapper.result().labelname2id() + cerebnet2sagittal_lut = lut_path("CerebNet", "CerebNet2Sagittal.json") + sagittal_cereb2cereb_mapper = prep_lut(cerebnet2sagittal_lut) + cerebnet2freesurfer_lut = lut_path("CerebNet", "CerebNet2FreeSurfer.json") + cereb2freesurfer_mapper = prep_lut(cerebnet2freesurfer_lut) + + self.cerebnet_labels = _cerebnet_mapper.result().labelname2id() self.freesurfer_name2id = fs_color_map.result().labelname2id() - self.cereb_name2freesurfer_id = ( - cereb2freesurfer.result().labelname2id().chain(self.freesurfer_name2id) + cereb_name2fs_name: Mapper[str, str] = ( + cereb2freesurfer_mapper.result().labelname2id() ) - - # the id in cereb2freesurfer is also a labelname, i.e. 
cereb2freesurfer is a map of Labelname2Labelname - self.cereb2fs = self.cerebnet_labels.__reversed__().chain( - self.cereb_name2freesurfer_id + cerebsag_name2cereb_name: Mapper[str, str] = ( + sagittal_cereb2cereb_mapper.result().labelname2id() ) - self.cereb2cereb_sagittal = self.cerebnet_labels.__reversed__().chain( - cereb2cereb_sagittal.result().labelname2id() - ) + cereb_id2name = self.cerebnet_labels.__reversed__() + self.cereb_name2fs_id = cereb_name2fs_name.chain(self.freesurfer_name2id) + self.cereb_id2fs_id = cereb_id2name.chain(self.cereb_name2fs_id) + self.cerebsag_id2cereb_name = cereb_id2name.chain(cerebsag_name2cereb_name) self.models = {k: m.to(self.device) for k, m in _models.items()} @property @@ -151,16 +162,17 @@ def __del__(self): def _load_model(self, cfg) -> Dict[Plane, torch.nn.Module]: """Loads the three models per plane.""" - def __load_model(cfg: "yacs.ConfigNode", plane: Plane) -> torch.nn.Module: + def __load_model(cfg: "yacs.config.CfgNode", plane: Plane) -> torch.nn.Module: params = {k.lower(): v for k, v in dict(cfg.MODEL).items()} params["plane"] = plane if plane == "sagittal": if params["num_classes"] != params["num_classes_sag"]: params["num_classes"] = params["num_classes_sag"] - checkpoint_path = cfg.TEST[f"{plane.upper()}_CHECKPOINT_PATH"] + checkpoint_path = Path(cfg.TEST[f"{plane.upper()}_CHECKPOINT_PATH"]) model = build_model(params) - if not isfile(checkpoint_path): - # if the checkpoint path is not a file, but a folder search in there for the newest checkpoint + if not checkpoint_path.is_file(): + # if the checkpoint path is not a file, but a folder search in there for + # the newest checkpoint checkpoint_path = cp.get_checkpoint_path(checkpoint_path).pop() cp.load_from_checkpoint(checkpoint_path, model) model.eval() @@ -188,7 +200,8 @@ def _predict_single_subject( from CerebNet.data_loader.data_utils import slice_lia2ras, slice_ras2lia for img in img_loader: - # CerebNet is trained on RAS+ conventions, so we need to map between lia (FastSurfer) and RAS+ + # CerebNet is trained on RAS+ conventions, so we need to map between + # lia (FastSurfer) and RAS+ # map LIA 2 RAS img = slice_lia2ras(plane, img, thick_slices=True) batch = img.to(self.device) @@ -208,13 +221,17 @@ def _predict_single_subject( def _post_process_preds( self, preds: Dict[Plane, List[torch.Tensor]] ) -> Dict[Plane, torch.Tensor]: - """Permutes axes, so it has consistent sagittal, coronal, axial, channels format. Also maps - classes of sagittal predictions into the global label space + """ + Permutes axes, so it has consistent sagittal, coronal, axial, channels format. + Also maps classes of sagittal predictions into the global label space. - Args: - preds: predicted logits. + Parameters + ---------- + preds: + predicted logits. - Returns: + Returns + ------- dictionary of permuted logits. 
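Condensed, the lookup-table chaining set up above composes three mappers into a single CerebNet-id-to-FreeSurfer-id mapping; a sketch using the same APIs (paths relative to `cp.FASTSURFER_ROOT`, as in this file):

```python
from CerebNet.utils import checkpoint as cp
from FastSurferCNN.utils.mapper import JsonColorLookupTable, TSVLookupTable

cereb_lut = TSVLookupTable(
    cp.FASTSURFER_ROOT / "CerebNet" / "config" / "CerebNet_ColorLUT.tsv", header=True)
fs_lut = TSVLookupTable(
    cp.FASTSURFER_ROOT / "FastSurferCNN" / "config" / "FreeSurferColorLUT.txt", header=False)
cereb2fs = JsonColorLookupTable(
    cp.FASTSURFER_ROOT / "CerebNet" / "config" / "CerebNet2FreeSurfer.json")

cereb_id2name = cereb_lut.labelname2id().__reversed__()                   # CerebNet id -> name
cereb_name2fs_id = cereb2fs.labelname2id().chain(fs_lut.labelname2id())   # name -> FreeSurfer id
cereb_id2fs_id = cereb_id2name.chain(cereb_name2fs_id)                    # CerebNet id -> FreeSurfer id
```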
""" axis_permutation = { @@ -229,7 +246,7 @@ def _post_process_preds( def _convert(plane: Plane) -> torch.Tensor: pred = torch.cat(preds[plane], dim=0) if plane == "sagittal": - pred = self.cereb2cereb_sagittal.map_probs(pred, axis=1, reverse=True) + pred = self.cerebsag_id2cereb_name.map_probs(pred, axis=1, reverse=True) return pred.permute(axis_permutation[plane]) return {plane: _convert(plane) for plane in preds.keys()} @@ -264,7 +281,7 @@ def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]: if name.startswith(prefix) and not name.endswith("Medullare") ] - freesurfer_id2cereb_name = self.cereb_name2freesurfer_id.__reversed__() + freesurfer_id2cereb_name = self.cereb_name2fs_id.__reversed__() freesurfer_id2name = self.freesurfer_name2id.__reversed__() label_map = dict(freesurfer_id2cereb_name) meta_labels = { @@ -283,6 +300,7 @@ def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]: table = pv_calc( seg_data, norm_data, + norm_data, list(filter(lambda l: l != 0, label_map.keys())), vox_vol=vox_vol, threads=self.threads, @@ -309,18 +327,27 @@ def _get_ids_startswith(_label_map: Dict[int, str], prefix: str) -> List[int]: return dataframe def _save_cerebnet_seg( - self, cerebnet_seg: np.ndarray, filename: str, orig: nib.analyze.SpatialImage + self, + cerebnet_seg: np.ndarray, + filename: str | Path, + orig: nib.analyze.SpatialImage ) -> "Future[None]": """ Saving the segmentations asynchronously. - Args: - cerebnet_seg: segmentation data - filename: path and file name to the saved file - bounding_box: bounding box from the full image to fill with the segmentation - orig: file container (with header and affine) used to populate header and affine of the segmentation - - Returns: + Parameters + ---------- + cerebnet_seg : np.ndarray + Segmentation data. + filename : Path, str + Path and file name to the saved file. + orig : nib.analyze.SpatialImage + File container (with header and affine) used to populate header and affine + of the segmentation. + + Returns + ------- + Future[None] A Future to determine when the file was saved. Result is None. """ from FastSurferCNN.data_loader.data_utils import save_image @@ -334,8 +361,11 @@ def _save_cerebnet_seg( def _get_subject_dataset( self, subject: SubjectDirectory - ) -> Tuple[Optional[np.ndarray], Optional[str], SubjectDataset]: - """Load and prepare input files asynchronously, then locate the cerebellum and provide a localized patch.""" + ) -> Tuple[Optional[np.ndarray], Optional[Path], SubjectDataset]: + """ + Load and prepare input files asynchronously, then locate the cerebellum and + provide a localized patch. + """ from FastSurferCNN.data_loader.data_utils import load_image, load_maybe_conform @@ -345,8 +375,9 @@ def _get_subject_dataset( from FastSurferCNN.utils.parser_defaults import ALL_FLAGS raise ValueError( - f"Cannot resolve the intended filename {subject.get_attribute('cereb_statsfile')} " - f"for the cereb_statsfile, maybe specify an absolute path via " + f"Cannot resolve the intended filename " + f"{subject.get_attribute('cereb_statsfile')} for the " + f"cereb_statsfile, maybe specify an absolute path via " f"{ALL_FLAGS['cereb_statsfile'](dict)['flag']}." 
) if not subject.has_attribute( @@ -355,9 +386,11 @@ def _get_subject_dataset( from FastSurferCNN.utils.parser_defaults import ALL_FLAGS raise ValueError( - f"Cannot resolve the file name {subject.get_attribute('norm_name')} for the " - f"bias field corrected image, maybe specify an absolute path via " - f"{ALL_FLAGS['norm_name'](dict)['flag']} or the file does not exist." + f"Cannot resolve the file name " + f"{subject.get_attribute('norm_name')} for the bias field " + f"corrected image, maybe specify an absolute path via " + f"{ALL_FLAGS['norm_name'](dict)['flag']} or the file does not " + f"exist." ) norm_file = subject.filename_by_attribute("norm_name") @@ -369,8 +402,8 @@ def _get_subject_dataset( # localization if not subject.fileexists_by_attribute("asegdkt_segfile"): raise RuntimeError( - f"The aseg.DKT-segmentation file '{subject.asegdkt_segfile}' did not exist, " - "please run FastSurferVINN first." + f"The aseg.DKT-segmentation file '{subject.asegdkt_segfile}' did not " + f"exist, please run FastSurferVINN first." ) seg = self.pool.submit( load_image, subject.filename_by_attribute("asegdkt_segfile") @@ -397,7 +430,7 @@ def _get_subject_dataset( norm_file, _, norm_data = norm.result() return norm_data, norm_file, subject_dataset - def run(self, subject_directories: SubjectList): + def run(self, subject_dirs: SubjectList): logger.info(time.strftime("%y-%m-%d_%H:%M:%S")) from tqdm.contrib.logging import logging_redirect_tqdm @@ -408,19 +441,17 @@ def run(self, subject_directories: SubjectList): from FastSurferCNN.utils.common import pipeline as iterate else: from FastSurferCNN.utils.common import iterate - iter_subjects = iterate( - self.pool, self._get_subject_dataset, subject_directories - ) + iter_subjects = iterate(self.pool, self._get_subject_dataset, subject_dirs) futures = [] for idx, (subject, (norm, norm_file, subject_dataset)) in tqdm( - enumerate(iter_subjects), total=len(subject_directories), desc="Subject" + enumerate(iter_subjects), total=len(subject_dirs), desc="Subject", ): try: # predict CerebNet, returns logits preds = self._predict_single_subject(subject_dataset) # create the folder for the output file, if it does not exist _mkdir = self.pool.submit( - makedirs, dirname(subject.segfile), exist_ok=True + subject.segfile.parent.mkdir, exist_ok=True, parents=True, ) # postprocess logits (move axes, map sagittal to all classes) @@ -428,8 +459,9 @@ def run(self, subject_directories: SubjectList): # view aggregation in logit space and find max label cerebnet_seg = self._view_aggregation(preds_per_plane) - # map predictions into FreeSurfer Label space & move segmentation to cpu - cerebnet_seg = self.cereb2fs.map(cerebnet_seg).cpu() + # map predictions into FreeSurfer Label space & move segmentation to + # cpu + cerebnet_seg = self.cereb_id2fs_id.map(cerebnet_seg).cpu() pred_time = time.time() # uncrop the segmentation @@ -440,9 +472,8 @@ def run(self, subject_directories: SubjectList): target_shape=bounding_box["source_shape"], ).numpy() - _ = ( - _mkdir.result() - ) # this is None, but synchronizes the creation of the directory + # this is None, but synchronizes the creation of the directory + _ = _mkdir.result() futures.append( self._save_cerebnet_seg( full_cereb_seg, @@ -452,14 +483,15 @@ def run(self, subject_directories: SubjectList): ) if subject.has_attribute("cereb_statsfile"): - # vox_vol = np.prod(norm.header.get_zooms()).item() # CerebNet always has vox_vol 1 + # vox_vol = np.prod(norm.header.get_zooms()).item() + # CerebNet always has vox_vol 1 if norm is 
None: raise RuntimeError("norm not loaded as expected!") df = self._calc_segstats(full_cereb_seg, norm, vox_vol=1.0) from FastSurferCNN.segstats import write_statsfile - # in batch processing, we are finished with this subject and the output of this data can be - # outsourced to a different process + # in batch processing, we are finished with this subject and the + # output of this data can be outsourced to a different process futures.append( self.pool.submit( write_statsfile, @@ -468,13 +500,19 @@ def run(self, subject_directories: SubjectList): vox_vol=1.0, segfile=subject.segfile, normfile=norm_file, - lut=self.freesurfer_color_lut_file, + lut=self.freesurfer_lut_file, + volume_precision="3", + exclude=[0], + pvfile=norm_file, + report_empty=True, + extra_header=[], ) ) logger.info( - f"Subject {idx + 1}/{len(subject_directories)} with id '{subject.id}' " - f"processed in {pred_time - start_time :.2f} sec." + f"Subject {idx + 1}/{len(subject_dirs)} with id " + f"'{subject.id}' processed in {pred_time - start_time :.2f} " + f"sec." ) except Exception as e: logger.exception(e) diff --git a/CerebNet/models/__init__.py b/CerebNet/models/__init__.py index e69de29b..8b00d4c4 100644 --- a/CerebNet/models/__init__.py +++ b/CerebNet/models/__init__.py @@ -0,0 +1,4 @@ +__all__ = [ + "networks", + "sub_module", +] \ No newline at end of file diff --git a/CerebNet/models/networks.py b/CerebNet/models/networks.py index d98ddb94..fb8af95f 100644 --- a/CerebNet/models/networks.py +++ b/CerebNet/models/networks.py @@ -20,7 +20,6 @@ import torch.nn as nn from FastSurferCNN.utils import logging - from CerebNet.models import sub_module as sm @@ -40,6 +39,9 @@ class FastSurferCNN(nn.Module): """ def __init__(self, params): + """ + Create the FastSurferCNN model. + """ super(FastSurferCNN, self).__init__() # Parameters for the Descending Arm @@ -112,6 +114,9 @@ def forward(self, x): def build_model(params: Mapping) -> torch.nn.Module: + """ + Build the model based on the params Mapping. + """ params = {k.lower(): v for k, v in dict(params).items()} assert ( params["model_name"] in _MODELS.keys() diff --git a/CerebNet/models/sub_module.py b/CerebNet/models/sub_module.py index cd86bfeb..3833fae2 100644 --- a/CerebNet/models/sub_module.py +++ b/CerebNet/models/sub_module.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -21,31 +21,37 @@ # Building Blocks class CompetitiveDenseBlock(nn.Module): """ - Function to define a competitive dense block comprising of 3 convolutional layers, with BN/ReLU - - Inputs: - -- Params - params = {'num_channels': 1, - 'num_filters': 64, - 'kernel_h': 5, - 'kernel_w': 5, - 'stride_conv': 1, - 'pool': 2, - 'stride_pool': 2, - 'num_classes': 44 - 'kernel_c':1 - 'input':True - } + Define a competitive dense block. + + A dense block consists of 3 convolutional layers, with BN/ReLU. + + Parameters + ---------- + params = {'num_channels' : 1, + 'num_filters' : 64, + 'kernel_h' : 5, + 'kernel_w' : 5, + 'stride_conv' : 1, + 'pool' : 2, + 'stride_pool' : 2, + 'num_classes' : 44, + 'kernel_c' : 1, + 'input' : True + }. 
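A rough usage sketch for a block configured with this params dictionary; the input channel count is an assumption (inside the network the block receives batch-normed feature maps with `num_filters` channels from the previous block):

```python
import torch

from CerebNet.models.sub_module import CompetitiveDenseBlock

params = {
    "num_channels": 1, "num_filters": 64, "kernel_h": 5, "kernel_w": 5,
    "stride_conv": 1, "pool": 2, "stride_pool": 2, "num_classes": 44,
    "kernel_c": 1, "input": True,
}
block = CompetitiveDenseBlock(params)

# feature map from a previous block: (batch, num_filters, height, width)
x = torch.randn(1, params["num_filters"], 128, 128)
out = block(x)  # batch-normed output, ready for maxout across skip connections
print(out.shape)
```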
""" def __init__(self, params, outblock=False, discriminator_block=False): """ - Constructor to initialize the Competitive Dense Block - :param dict params: dictionary with parameters specifying block architecture - :param bool outblock: Flag indicating if last block (before classifier block) is set up. - Default: False - :param bool discriminator_block: Flag indicating if the block is discriminator block or not - :return None: + Constructor to initialize the Competitive Dense Block. + + Parameters + ---------- + params : dict + Dictionary with parameters specifying block architecture. + outblock : bool, default=False + Flag indicating if last block (before classifier block) is set up. + discriminator_block : bool, default=False + Flag indicating if the block is discriminator block or not. """ super(CompetitiveDenseBlock, self).__init__() @@ -99,10 +105,17 @@ def forward(self, x): """ CompetitiveDenseBlock's computational Graph {in (Conv - BN from prev. block) -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} x 2 -> {Conv -> BN} -> out - end with batch-normed output to allow maxout across skip-connections + end with batch-normed output to allow maxout across skip-connections. - :param tensor x: input tensor (image or feature map) - :return tensor out: output tensor (processed feature map) + Parameters + ---------- + x : tensor + Input tensor (image or feature map). + + Returns + ------- + tensor + Output tensor (processed feature map). """ # Activation from pooled input x0 = self.prelu(x) @@ -138,27 +151,31 @@ def forward(self, x): class CompetitiveDenseBlockInput(nn.Module): """ - Function to define a competitive dense block comprising of 3 convolutional layers, with BN/ReLU for input - - Inputs: - -- Params - params = {'num_channels': 1, - 'num_filters': 64, - 'kernel_h': 5, - 'kernel_w': 5, - 'stride_conv': 1, - 'pool': 2, - 'stride_pool': 2, - 'num_classes': 44 - 'kernel_c':1 - 'input':True - } + Function to define a competitive dense block comprising of + 3 convolutional layers, with BN/ReLU for input. + + Parameters + ---------- + params = {'num_channels' : 1, + 'num_filters' : 64, + 'kernel_h' : 5, + 'kernel_w' : 5, + 'stride_conv' : 1, + 'pool' : 2, + 'stride_pool' : 2, + 'num_classes' : 44, + 'kernel_c' : 1, + 'input' : True + }. """ - def __init__(self, params): """ - Constructor to initialize the Competitive Dense Block - :param dict params: dictionary with parameters specifying block architecture + Constructor to initialize the Competitive Dense Block. + + Parameters + ---------- + params : dict + Dictionary with parameters specifying block architecture. """ super(CompetitiveDenseBlockInput, self).__init__() @@ -208,10 +225,17 @@ def __init__(self, params): def forward(self, x): """ CompetitiveDenseBlockInput's computational Graph - in -> BN -> {Conv -> BN -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} -> {Conv -> BN} -> out + in -> BN -> {Conv -> BN -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} -> {Conv -> BN} -> out. - :param tensor x: input tensor (image or feature map) - :return tensor out: output tensor (processed feature map) + Parameters + ---------- + x : tensor + Input tensor (image or feature map). + + Returns + ------- + tensor + Output tensor (processed feature map). """ # Input batch normalization x0_bn = self.bn0(x) @@ -240,13 +264,17 @@ def forward(self, x): class CompetitiveEncoderBlock(CompetitiveDenseBlock): """ - Encoder Block = CompetitiveDenseBlock + Max Pooling + Encoder Block = CompetitiveDenseBlock + Max Pooling. 
""" def __init__(self, params): """ - Encoder Block initialization - :param dict params: parameters like number of channels, stride etc. + Encoder Block initialization. + + Parameters + ---------- + params : dict + Parameters like number of channels, stride etc. """ super(CompetitiveEncoderBlock, self).__init__(params) self.maxpool = nn.MaxPool2d( @@ -257,12 +285,23 @@ def __init__(self, params): def forward(self, x): """ - CComputational graph for Encoder Block: + CComputational graph for Encoder Block : * CompetitiveDenseBlock * Max Pooling (+ retain indices) - :param tensor x: feature map from previous block - :return: original feature map, maxpooled feature map, maxpool indices + Parameters + ---------- + x : tensor + Feature map from previous block. + + Returns + ------- + out_encoder : Tensor + Original feature map. + out_block : Tensor + Maxpooled feature map. + indicies : Tensor + Maxpool indices. """ out_block = super(CompetitiveEncoderBlock, self).forward( x @@ -275,13 +314,17 @@ def forward(self, x): class CompetitiveEncoderBlockInput(CompetitiveDenseBlockInput): """ - Encoder Block = CompetitiveDenseBlockInput + Max Pooling + Encoder Block = CompetitiveDenseBlockInput + Max Pooling. """ def __init__(self, params): """ - Encoder Block initialization - :param dict params: parameters like number of channels, stride etc. + Encoder Block initialization. + + Parameters + ---------- + params : dict + Parameters like number of channels, stride etc. """ super(CompetitiveEncoderBlockInput, self).__init__( params @@ -298,8 +341,19 @@ def forward(self, x): * CompetitiveDenseBlockInput * Max Pooling (+ retain indices) - :param tensor x: feature map from previous block - :return: original feature map, maxpooled feature map, maxpool indices + Parameters + ---------- + x : tensor + Feature map from previous block. + + Returns + ------- + Tensor + The original feature map as received by the block. + Tensor + The maxpooled feature map after applying max pooling to the original feature map. + Tensor + The indices of the maxpool operation. """ out_block = super(CompetitiveEncoderBlockInput, self).forward( x @@ -317,10 +371,15 @@ class CompetitiveDecoderBlock(CompetitiveDenseBlock): def __init__(self, params, outblock=False): """ - Decoder Block initialization - :param dict params: parameters like number of channels, stride etc. - :param bool outblock: Flag, indicating if last block of network before classifier - is created. Default: False + Decoder Block initialization. + + Parameters + ---------- + params : dict + Parameters like number of channels, stride etc. + outblock : bool, default=False + Flag, indicating if last block of network before classifier + is created. """ super(CompetitiveDecoderBlock, self).__init__(params, outblock=outblock) self.unpool = nn.MaxUnpool2d( @@ -334,10 +393,19 @@ def forward(self, x, out_block, indices): * Maxout combination of unpooled map + skip connection * Forwarding toward CompetitiveDenseBlock - :param tensor x: input feature map from lower block (gets unpooled and maxed with out_block) - :param tensor out_block: skip connection feature map from the corresponding Encoder - :param tensor indices: indices for unpooling from the corresponding Encoder (maxpool op) - :return: processed feature maps + Parameters + ---------- + x : tensor + Input feature map from lower block (gets unpooled and maxed with out_block). + out_block : tensor + Skip connection feature map from the corresponding Encoder. 
+ indices : tensor + Indices for unpooling from the corresponding Encoder (maxpool op). + + Returns + ------- + out_block + Processed feature maps. """ unpool = self.unpool(x, indices) unpool = torch.unsqueeze(unpool, 4) @@ -352,7 +420,7 @@ def forward(self, x, out_block, indices): class ClassifierBlock(nn.Module): """ - Classification Block + Classification Block. """ def __init__(self, params): diff --git a/CerebNet/run_prediction.py b/CerebNet/run_prediction.py index 5eae1776..410adcd5 100644 --- a/CerebNet/run_prediction.py +++ b/CerebNet/run_prediction.py @@ -1,4 +1,4 @@ -import os +#!/bin/python # Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn # @@ -17,23 +17,36 @@ # IMPORTS import sys import argparse +from pathlib import Path -from FastSurferCNN.utils import logging, parser_defaults - -from CerebNet.utils.load_config import get_config -from CerebNet.inference import Inference -from FastSurferCNN.utils.checkpoint import get_checkpoints +from FastSurferCNN.utils import logging, parser_defaults, Plane, PLANES +from FastSurferCNN.utils.checkpoint import ( + get_checkpoints, + load_checkpoint_config_defaults, +) from FastSurferCNN.utils.common import assert_no_root, SubjectList +from CerebNet.inference import Inference +from CerebNet.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE +from CerebNet.utils.load_config import get_config logger = logging.get_logger(__name__) -DEFAULT_CEREBELLUM_STATSFILE = "stats/cerebellum.CerebNet.stats" +DEFAULT_CEREBELLUM_STATSFILE = Path("stats/cerebellum.CerebNet.stats") def setup_options(): + """ + Configure and return an argument parser for the segmentation script. + + Returns + ------- + argparse.ArgumentParser + The configured argument parser. + """ # Training settings parser = argparse.ArgumentParser(description="Segmentation") - # 1. Directory information (where to read from, where to write from and to incl. search-tag) + # 1. Directory information (where to read from, where to write from and to incl. + # search-tag) parser = parser_defaults.add_arguments( parser, ["in_dir", "tag", "csv_file", "sd", "sid", "remove_suffix"] ) @@ -45,18 +58,22 @@ def setup_options(): parser.add_argument( "--cereb_segfile", dest="cereb_segfile", - default="mri/cerebellum.CerebNet.nii.gz", - help="Name under which segmentation will be saved. Default: mri/cerebellum.CerebNet.nii.gz.", + default=Path("mri/cerebellum.CerebNet.nii.gz"), + type=Path, + help="Name under which segmentation will be saved. " + "Default: mri/cerebellum.CerebNet.nii.gz.", ) # 3. Options for additional files and parameters parser.add_argument( "--cereb_statsfile", dest="cereb_statsfile", + type=Path, default=None, - help=f"Name under which the statsfield for the cerebellum will be saved. Default: None, do not " - f'calculate stats file. This option supports the special option "default", which saves the ' - f"stats file at {DEFAULT_CEREBELLUM_STATSFILE} in the subject directory.", + help=f"Name under which the statsfield for the cerebellum will be saved. " + f"Default: None, do not calculate stats file. 
This option supports the " + f"special option 'default', which saves the stats file at " + f"{DEFAULT_CEREBELLUM_STATSFILE} in the subject directory.", ) parser = parser_defaults.add_arguments(parser, ["seg_log"]) @@ -67,20 +84,14 @@ def setup_options(): ["device", "viewagg_device", "threads", "batch_size", "async_io", "allow_root"], ) - from CerebNet.utils.checkpoint import CEREBNET_COR, CEREBNET_AXI, CEREBNET_SAG - + files: dict[Plane, str | Path] = {k: "default" for k in PLANES} parser_defaults.add_plane_flags( advanced, "checkpoint", - {"coronal": CEREBNET_COR, "axial": CEREBNET_AXI, "sagittal": CEREBNET_SAG}, + files, + CHECKPOINT_PATHS_FILE, ) - parser.add_argument( - "--cfg", - dest="cfg_file", - help="Path to the config file", - type=str, - ) parser.add_argument( "opts", help="See CerebNet/config/cerebnet.py for additional options", @@ -93,39 +104,65 @@ def setup_options(): return parser -def main(args): +def main(args: argparse.Namespace) -> int | str: + """ + Main function to run the inference based on the given command line arguments. + This implementation is inspired by methods described in CerebNet for cerebellum + sub-segmentation. + + Parameters + ---------- + args : argparse.Namespace + Command line arguments parsed by `argparse.ArgumentParser`. + + Returns + ------- + int + Returns 0 upon successful execution to indicate success. + str + A message indicating the failure reason in case of an exception. + + References + ---------- + Faber J, Kuegler D, Bahrami E, et al. CerebNet: A fast and reliable deep-learning + pipeline for detailed cerebellum sub-segmentation. NeuroImage 264 (2022), 119703. + https://doi.org/10.1016/j.neuroimage.2022.119703 + """ cfg = get_config(args) cfg.TEST.ENABLE = True cfg.TRAIN.ENABLE = False # Warning if run as root user - args.allow_root or assert_no_root() + getattr(args, "allow_root", False) or assert_no_root() # Set up logging from FastSurferCNN.utils.logging import setup_logging - from CerebNet.utils.checkpoint import URL as CEREBNET_URL - setup_logging(args.log_name) + setup_logging(getattr(args, "log_name")) subjects_kwargs = {} cereb_statsfile = getattr(args, "cereb_statsfile", None) - if cereb_statsfile == "default": - args.cereb_statsfile = DEFAULT_CEREBELLUM_STATSFILE + if cereb_statsfile is None or str(cereb_statsfile) == "default": + cereb_statsfile = DEFAULT_CEREBELLUM_STATSFILE + args.cereb_statsfile = cereb_statsfile if cereb_statsfile is not None: subjects_kwargs["cereb_statsfile"] = "cereb_statsfile" if not hasattr(args, "norm_name"): return ( - f"Execution failed because `--cereb_statsfile {cereb_statsfile}` requires `--norm_name ` " - f"to be passed!" + f"Execution failed because `--cereb_statsfile {cereb_statsfile}` " + f"requires `--norm_name ` to be passed!" 
) subjects_kwargs["norm_name"] = "norm_name" logger.info("Checking or downloading default checkpoints ...") - get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, url=CEREBNET_URL) + + urls = load_checkpoint_config_defaults("url", filename=CHECKPOINT_PATHS_FILE) + + get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag, urls=urls) # Check input and output options and get all subjects of interest subjects = SubjectList( - args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs + args, asegdkt_segfile="pred_name", segfile="cereb_segfile", **subjects_kwargs, ) try: @@ -143,4 +180,6 @@ def main(args): if __name__ == "__main__": parser = setup_options() - sys.exit(main(parser.parse_args())) + args = parser.parse_args() + + sys.exit(main(args)) diff --git a/CerebNet/utils/__init__.py b/CerebNet/utils/__init__.py index e69de29b..203a3416 100644 --- a/CerebNet/utils/__init__.py +++ b/CerebNet/utils/__init__.py @@ -0,0 +1,8 @@ +__all__ = [ + "checkpoint", + "load_config", + "lr_scheduler", + "meters", + "metrics", + "misc", +] \ No newline at end of file diff --git a/CerebNet/utils/checkpoint.py b/CerebNet/utils/checkpoint.py index fec664bc..e8cdf1ca 100644 --- a/CerebNet/utils/checkpoint.py +++ b/CerebNet/utils/checkpoint.py @@ -14,11 +14,13 @@ # IMPORTS -import os +from typing import TYPE_CHECKING +if TYPE_CHECKING: + import yacs +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT from FastSurferCNN.utils import logging from FastSurferCNN.utils.checkpoint import ( - FASTSURFER_ROOT, load_from_checkpoint, create_checkpoint_dir, get_checkpoint, @@ -26,23 +28,22 @@ save_checkpoint, ) -logger = logging.get_logger(__name__) +# DEFAULTS +YAML_DEFAULT = FASTSURFER_ROOT / "CerebNet/config/checkpoint_paths.yaml" -# Defaults -URL = "https://b2share.fz-juelich.de/api/files/c6cf7bc6-2ae5-4d0e-814d-2a3cf0e1a8c5" -CEREBNET_AXI = os.path.join(FASTSURFER_ROOT, "checkpoints/CerebNet_axial_v1.0.0.pkl") -CEREBNET_COR = os.path.join(FASTSURFER_ROOT, "checkpoints/CerebNet_coronal_v1.0.0.pkl") -CEREBNET_SAG = os.path.join(FASTSURFER_ROOT, "checkpoints/CerebNet_sagittal_v1.0.0.pkl") +logger = logging.get_logger(__name__) -def is_checkpoint_epoch(cfg, cur_epoch): +def is_checkpoint_epoch(cfg: "yacs.config.CfgNode", cur_epoch: int) -> bool: """ - Check if checkpoint need to be saved - Check if the - :param cfg: - :param cur_epoch: - :return: + Check if checkpoint need to be saved. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The config node. + cur_epoch : int + The current epoch number to check if this is the last epoch. """ final_epoch = (cur_epoch + 1) == cfg.TRAIN.NUM_EPOCHS - is_checkpoint = (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD or final_epoch - return is_checkpoint + return (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD or final_epoch diff --git a/CerebNet/utils/load_config.py b/CerebNet/utils/load_config.py index 29a5d18d..ac54f0ba 100644 --- a/CerebNet/utils/load_config.py +++ b/CerebNet/utils/load_config.py @@ -14,28 +14,22 @@ # IMPORTS -import argparse -import os.path -import sys -from os.path import join, split, splitext +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import yacs.config -from CerebNet.utils.checkpoint import CEREBNET_AXI, CEREBNET_SAG, CEREBNET_COR from CerebNet.config import get_cfg_cerebnet +from FastSurferCNN.utils import PLANES -def get_config(args) -> "yacs.CfgNode": +def get_config(args) -> "yacs.config.CfgNode": """ Given the arguments, load and initialize the config_files. - """ # Setup cfg. 
 cfg = get_cfg_cerebnet()

     # Load config from cfg.
-    if getattr(args, "cfg_file") is not None:
-        if os.path.exists(args.cfg_file):
-            cfg.merge_from_file(args.cfg_file)
-        else:
-            raise RuntimeError(f"The config file {args.cfg_file} does not exist.")
     # Load config from command line, overwrite config from opts.
     if args.opts is not None:
         cfg.merge_from_list(args.opts)
@@ -43,34 +37,14 @@ def get_config(args) -> "yacs.CfgNode":
     if hasattr(args, "rng_seed"):
         cfg.RNG_SEED = args.rng_seed
     if hasattr(args, "out_dir"):
-        cfg.LOG_DIR = args.out_dir
-
-    if getattr(args, "cfg_file") is not None:
-        # derive some paths relative to the config file
-        cfg_file_name = splitext(split(args.cfg_file)[1])[0]
-
-        if cfg.TEST.ENABLE:
-            cfg_file_name_first = "_".join(cfg_file_name.split("_"))
-            cfg.TEST.RESULTS_DIR = join(cfg.TEST.RESULTS_DIR, cfg_file_name_first)
-
-        cfg.LOG_DIR = join(cfg.LOG_DIR, cfg_file_name)
+        cfg.LOG_DIR = str(args.out_dir)

-    # populate default paths for the checkpoints
-    default_paths = [
-        ("ax_ckpt", CEREBNET_AXI),
-        ("sag_ckpt", CEREBNET_SAG),
-        ("cor_ckpt", CEREBNET_COR),
-    ]
     path_ax, path_sag, path_cor = [
-        getattr(args, name, default_path) for name, default_path in default_paths
+        getattr(args, name) for name in ["ckpt_ax", "ckpt_sag", "ckpt_cor"]
     ]
-    for plane, path in [
-        ("axial", path_ax),
-        ("sagittal", path_sag),
-        ("coronal", path_cor),
-    ]:
-        setattr(cfg.TEST, f"{plane.upper()}_CHECKPOINT_PATH", path)
+    for plane, path in zip(PLANES, (path_ax, path_cor, path_sag)):
+        setattr(cfg.TEST, f"{plane.upper()}_CHECKPOINT_PATH", str(path))

     # overwrite the batch size if it is passed as a parameter
     batch_size = getattr(args, "batch_size", None)
@@ -78,24 +52,3 @@
     if batch_size:
         cfg.TEST.BATCH_SIZE = batch_size

     return cfg
-
-
-def setup_options():
-    parser = argparse.ArgumentParser(description="Segmentation")
-    parser.add_argument(
-        "--cfg",
-        dest="cfg_file",
-        help="Path to the config file",
-        default="config_files/CerebNet.yaml",
-        type=str,
-    )
-    parser.add_argument(
-        "opts",
-        help="See CerebNet/config/cerebnet.py for all options",
-        default=None,
-        nargs=argparse.REMAINDER,
-    )
-
-    if len(sys.argv) == 1:
-        parser.print_help()
-    return parser.parse_args()
diff --git a/CerebNet/utils/lr_scheduler.py b/CerebNet/utils/lr_scheduler.py
index f4253a26..5e3eab3d 100644
--- a/CerebNet/utils/lr_scheduler.py
+++ b/CerebNet/utils/lr_scheduler.py
@@ -29,8 +29,9 @@
 class ReduceLROnPlateauWithRestarts(ReduceLROnPlateau):
-    """Extends the ReduceLROnPlateau class with the restart ability."""
-
+    """
+    Extends the ReduceLROnPlateau class with the restart ability.
+    """
     def __init__(self, optimizer, *args, T_0=10, Tmult=1, lr_restart=None, **kwargs):
         """
         Create a ReduceLROnPlateauWithRestarts learning rate scheduler.
@@ -71,7 +72,21 @@ def __init__(self, optimizer, *args, T_0=10, Tmult=1, lr_restart=None, **kwargs)
             self.lr_restart = 1

     def step(self, metrics, epoch=None):
-
+        """
+        Performs a learning rate scheduler step.
+
+        Parameters
+        ----------
+        metrics : float
+            The metric value used to determine learning rate adjustments.
+        epoch : int, default=None
+            The current epoch number.
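A hedged usage sketch for the restart behaviour configured above; it assumes the remaining keyword arguments (e.g. `mode`, `patience`) are forwarded unchanged to `ReduceLROnPlateau`:

```python
import torch

from CerebNet.utils.lr_scheduler import ReduceLROnPlateauWithRestarts

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = ReduceLROnPlateauWithRestarts(
    optimizer, mode="min", patience=3,  # plain ReduceLROnPlateau arguments
    T_0=10, Tmult=2,                    # restart every 10 epochs, period doubles each time
)

for epoch in range(40):
    val_loss = 1.0 / (epoch + 1)        # placeholder for the monitored validation metric
    scheduler.step(val_loss)            # plateau-based reduction plus periodic restarts
```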
+ + Notes + ----- + For details, refer to the PyTorch documentation for `ReduceLROnPlateau` at: + https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html + """ self.Tcur += 1 super(ReduceLROnPlateauWithRestarts, self).step(metrics, epoch) if self.Tcur >= self.T_i: @@ -79,6 +94,9 @@ def step(self, metrics, epoch=None): self._last_lr = [group["lr"] for group in self.optimizer.param_groups] def _reset_lr(self): + """ + Internal method to reset the learning rate. + """ self.Tcur = 0 self.T_i *= self.Tmult self.i += 1 @@ -108,6 +126,9 @@ def _reset_lr(self): # https://detectron2.readthedocs.io/_modules/detectron2/solver/lr_scheduler.html class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): + """ + Learning Rate scheduler that combines a cosine schedule with a warmup phase. + """ def __init__( self, optimizer: torch.optim.Optimizer, @@ -124,6 +145,9 @@ def __init__( super().__init__(optimizer, last_epoch) def get_lr(self) -> List[float]: + """ + Get the learning rates at the current epoch. + """ warmup_factor = _get_warmup_factor_at_iter( self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor ) @@ -175,6 +199,9 @@ def _get_warmup_factor_at_iter( class CosineLR: + """ + Learning rate scheduler that follows a Cosine trajectory. + """ def __init__(self, base_lr, eta_min, max_epoch): self.base_lr = base_lr self.max_epoch = max_epoch @@ -182,8 +209,12 @@ def __init__(self, base_lr, eta_min, max_epoch): def lr_func_cosine(self, cur_epoch): """ + Get the learning rate following a cosine pattern for the epoch `cur_epoch`. - cur_epoch (float): the number of epoch of the current training stage. + Parameters + ---------- + cur_epoch : int + The number of epoch of the current training stage. """ return self.eta_min + ( (self.base_lr - self.eta_min) @@ -194,9 +225,13 @@ def lr_func_cosine(self, cur_epoch): def set_lr(self, optimizer, epoch): """ Sets the optimizer lr to the specified value. - Args: - optimizer (optim): the optimizer using to optimize the current network. - new_lr (float): the new learning rate to set. + + Parameters + ---------- + optimizer : torch.optim.Optimizer + The optimizer using to optimize the current network. + epoch : int + The epoch for which to update the learning rate. """ new_lr = self.get_epoch_lr(epoch) for param_group in optimizer.param_groups: @@ -205,13 +240,20 @@ def set_lr(self, optimizer, epoch): def get_epoch_lr(self, cur_epoch): """ Retrieves the lr for the given epoch (as specified by the lr policy). - Args: - cur_epoch (float): the number of epoch of the current training stage. + + Parameters + ---------- + cur_epoch : int + The number of epoch of the current training stage. """ return self.lr_func_cosine(cur_epoch) class CosineAnnealingWarmRestartsDecay(scheduler.CosineAnnealingWarmRestarts): + """ + Learning rate scheduler that combines a Cosine annealing with warm restarts pattern, but also adds a + decay factor for where the learning rate restarts at. + """ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): super(CosineAnnealingWarmRestartsDecay, self).__init__( optimizer, T_0, T_mult=T_mult, eta_min=eta_min, last_epoch=last_epoch @@ -219,6 +261,10 @@ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): pass def decay_base_lr(self, curr_iter, n_epochs, n_iter): + """ + Learning rate scheduler that combines a Cosine annealing with warm restarts pattern, + but also adds a decay factor for where the learning rate restarts at. 
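The cosine trajectory implemented by `CosineLR.lr_func_cosine` above follows the standard cosine annealing form; a small numerical illustration (a stand-alone re-implementation for clarity, not the class itself):

```python
import math

def cosine_lr(base_lr: float, eta_min: float, max_epoch: int, cur_epoch: float) -> float:
    """Standard cosine annealing from base_lr (epoch 0) down to eta_min (epoch max_epoch)."""
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * cur_epoch / max_epoch)) / 2.0

print(cosine_lr(0.01, 1e-5, 100, 0))    # == base_lr
print(cosine_lr(0.01, 1e-5, 100, 50))   # halfway between base_lr and eta_min
print(cosine_lr(0.01, 1e-5, 100, 100))  # == eta_min
```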
+ """ if self.T_cur + 1 == self.T_i: annealed_lrs = [] for base_lr in self.base_lrs: @@ -232,6 +278,9 @@ def decay_base_lr(self, curr_iter, n_epochs, n_iter): def get_lr_scheduler(optimizer, cfg): + """ + Build a learning rate scheduler object from the config data in cfg. + """ scheduler_type = cfg.OPTIMIZER.LR_SCHEDULER if scheduler_type == "step_lr": return scheduler.StepLR( diff --git a/CerebNet/utils/meters.py b/CerebNet/utils/meters.py index 11ac59c8..67d7ebdd 100644 --- a/CerebNet/utils/meters.py +++ b/CerebNet/utils/meters.py @@ -30,8 +30,18 @@ class TestMeter: + """ + TestMeter class. + """ def __init__(self, classname_to_ids): + """ + Constructor function. + Parameters + ---------- + classname_to_ids : dict + Dictionary containing class names and their corresponding ids. + """ # class_id: class_name self.classname_to_ids = classname_to_ids self.measure_func = lambda pred, gt: { @@ -40,6 +50,22 @@ def __init__(self, classname_to_ids): } def _compute_hd(self, pred_bin, gt_bin): + """ + Compute the Hausdorff Distance (HD) between the predicted binary segmentation map + and the ground truth binary segmentation map. + + Parameters + ---------- + pred_bin : np.array + Predicted binary segmentation map. + gt_bin : np.array + Ground truth binary segmentation map. + + Returns + ------- + hd_dict : dict + Dictionary containing the maximum HD and 95th percentile HD. + """ hd_dict = {} if np.count_nonzero(pred_bin) == 0: hd_dict["HD_Max"] = np.nan @@ -52,10 +78,40 @@ def _compute_hd(self, pred_bin, gt_bin): return hd_dict def _get_binray_map(self, lbl_map, class_names): + """ + Generate binary map based on the label map and class names. + + Parameters + ---------- + lbl_map : np.array + Label map where each pixel/voxel is assigned a class label. + class_names : list + List of class names to be considered in the binary map. + + Returns + ------- + bin_map : np.array + Binary map where True represents class and False represents its absence. + """ bin_map = np.logical_or.reduce(list(map(lambda l: lbl_map == l, class_names))) return bin_map def metrics_per_class(self, pred, gt): + """ + Compute metrics for each class in the predicted and ground truth segmentation maps. + + Parameters + ---------- + pred : np.array + Predicted segmentation map. + gt : np.array + Ground truth segmentation map. + + Returns + ------- + metrics : dict + Dict containing metrics for each class. + """ metrics = {"Label": [], "Dice": [], "HD95": [], "HD_Max": [], "VS": []} for lbl_name, lbl_id in self.classname_to_ids.items(): # ignoring background @@ -88,6 +144,9 @@ def metrics_per_class(self, pred, gt): class Meter: + """ + Meter class. + """ def __init__( self, cfg, @@ -99,6 +158,28 @@ def __init__( device=None, writer=None, ): + """ + Constructor function. + + Parameters + ---------- + cfg : object + Configuration object containing all the configuration parameters. + mode : str + Mode of operation ("Train" or "Val"). + global_step : int + The global step count. + total_iter : int, optional + Total number of iterations. + total_epoch : int, optional + Total number of epochs. + class_names : list, optional + List of class names. + device : str, optional + Device to be used for computation. + writer : object, optional + Writer object for tensorboard. + """ self._cfg = cfg self.mode = mode.capitalize() self.confusion_mat = self.mode == "Val" @@ -115,10 +196,25 @@ def __init__( self.multi_gpu = cfg.NUM_GPUS > 1 def reset(self): + """ + Reset function. 
+ """ self.batch_losses = {} self.dice_score.reset() def update_stats(self, pred, labels, loss_dict=None): + """ + Update stats. + + Parameters + ---------- + pred : torch.Tensor + Predicted labels. + labels : torch.Tensor + Ground truth labels. + loss_dict : dict, optional + Dictionary containing loss values. + """ self.dice_score.update((pred, labels)) if loss_dict is None: return @@ -126,15 +222,37 @@ def update_stats(self, pred, labels, loss_dict=None): self.batch_losses.setdefault(name, []).append(loss.item()) def write_summary(self, loss_dict): + """ + Write summary. + + Parameters + ---------- + loss_dict : dict + Dictionary containing loss values. + """ if self.writer is None: return for name, loss in loss_dict.items(): self.writer.add_scalar(f"{self.mode}/{name}", loss.item(), self.global_iter) self.global_iter += 1 - def prediction_visualize( - self, cur_iter, cur_epoch, img_batch, label_batch, pred_batch - ): + def prediction_visualize(self, cur_iter, cur_epoch, img_batch, label_batch, pred_batch): + """ + Visualize prediction results for current iteration and epoch. + + Parameters + ---------- + cur_iter : int + Current iteration number. + cur_epoch : int + Current epoch number. + img_batch : torch.Tensor + Input image batch. + label_batch : torch.Tensor + Ground truth label batch. + pred_batch : torch.Tensor + Predicted label batch. + """ if self.writer is None: return if cur_iter == 1: @@ -146,6 +264,16 @@ def prediction_visualize( plt.close("all") def log_iter(self, cur_iter, cur_epoch): + """ + Log training or validation progress at each iteration. + + Parameters + ---------- + cur_iter : int + The current iteration number. + cur_epoch : int + The current epoch number. + """ if (cur_iter + 1) % self._cfg.TRAIN.LOG_INTERVAL == 0: out_losses = {} for name, loss in self.batch_losses.items(): @@ -167,11 +295,36 @@ def log_iter(self, cur_iter, cur_epoch): ) def log_lr(self, lr, step=None): + """ + Log learning rate at each step. + + Parameters + ---------- + lr : list + Learning rate at the current step. Expected to be a list where the first + element is the learning rate. + step : int, optional + Current step number. If not provided, the global iteration + number is used. + """ if step is None: step = self.global_iter self.writer.add_scalar("Train/lr", lr[0], step) def log_epoch(self, cur_epoch): + """ + Log mean Dice score and confusion matrix at the end of each epoch. + + Parameters + ---------- + cur_epoch : int + Current epoch number. + + Returns + ------- + dice_score : float + The mean Dice score for the non-background classes. + """ dice_score_per_class, confusion_mat = self.dice_score.compute(per_class=True) dice_score = dice_score_per_class[1:].mean() if self.writer is None: diff --git a/CerebNet/utils/metrics.py b/CerebNet/utils/metrics.py index 1ef62d65..b76d1d20 100644 --- a/CerebNet/utils/metrics.py +++ b/CerebNet/utils/metrics.py @@ -30,7 +30,8 @@ class DiceScore: """ - Accumulating the component of the dice coefficient i.e. the union and intersection + Accumulating the component of the dice coefficient i.e. the union and intersection. + Args: op (callable): a callable to update accumulator. Method's signature is `(accumulator, output)`. For example, to compute arithmetic mean value, `op = lambda a, x: a + x`. @@ -43,7 +44,6 @@ class DiceScore: if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is initialized and available, device is set to `cuda`. 
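A small sketch of the accumulate-then-compute pattern this class implements, using hard label maps as in `Meter.update_stats`; constructor arguments other than `num_classes` are assumed to keep their defaults:

```python
import torch

from CerebNet.utils.metrics import DiceScore

dice = DiceScore(num_classes=4)

for _ in range(3):                             # e.g. three validation batches
    pred = torch.randint(0, 4, (2, 64, 64))    # predicted label map
    labels = torch.randint(0, 4, (2, 64, 64))  # ground-truth label map
    dice.update((pred, labels))                # accumulate per-class union/intersection

per_class, confusion = dice.compute(per_class=True)
print(per_class[1:].mean())                    # mean Dice over the non-background classes
dice.reset()                                   # clear accumulators before the next epoch
```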
""" - def __init__( self, num_classes, @@ -67,18 +67,40 @@ def __init__( self.intersection = torch.zeros(self.n_classes, self.n_classes) def reset(self): + """ + Reset the state of the object. + """ self.union = torch.zeros(self.n_classes, self.n_classes) self.intersection = torch.zeros(self.n_classes, self.n_classes) def _check_output_type(self, output): + """ + Check the type of the output and raise an error if it doesn't match expectations. + + Parameters: + ----------- + output : tuple + The output to be checked, expected to be a tuple. + """ if not (isinstance(output, tuple)): raise TypeError( - "Output should a tuple consist of of torch.Tensors, but given {}".format( + "Output should be a tuple consisting of torch.Tensors, but given {}".format( type(output) ) ) def _update_union_intersection(self, batch_output, labels_batch): + """ + Update the union and intersection matrices based on batch predictions and labels. + + Parameters: + ----------- + batch_output : torch.Tensor + Batch predictions from the model. + + labels_batch : np.ndarray or torch.Tensor + Batch labels from the dataset. + """ # self.union.to(batch_output.device) # self.intersection.to(batch_output.device) for i, c1 in enumerate(self.class_ids): @@ -91,6 +113,14 @@ def _update_union_intersection(self, batch_output, labels_batch): self.union[i, j] = self.union[i, j] + torch.sum(gt) + torch.sum(pred) def update(self, output): + """ + Update the internal state based on the output. + + Parameters + ---------- + output : tuple of torch.Tensor + Tuple of predictions and labels. + """ self._check_output_type(output) if self.out_transform is not None: @@ -111,6 +141,9 @@ def update(self, output): self._update_union_intersection(y_pred, y) def compute(self, per_class=False, class_idxs=None): + """ + Compute the Dice score. + """ dice_cm_mat = self._dice_confusion_matrix(class_idxs) dice_score_per_class = dice_cm_mat.diagonal() dice_score = dice_score_per_class.mean() @@ -120,6 +153,9 @@ def compute(self, per_class=False, class_idxs=None): return dice_score, dice_cm_mat def _dice_confusion_matrix(self, class_idxs): + """ + Compute the Dice score confusion matrix. + """ dice_intersection = self.intersection.cpu().numpy() dice_union = self.union.cpu().numpy() if class_idxs is not None: @@ -151,9 +187,9 @@ def volume_similarity(pred, gt): # https://github.com/amanbasu/3d-prostate-segmentation/blob/master/metric_eval.py def hd(result, reference, voxelspacing=None, connectivity=1): """ - Hausdorff Distance. Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two images. It is defined as the maximum surface distance between the objects. + Parameters ---------- result : array_like @@ -172,16 +208,23 @@ def hd(result, reference, voxelspacing=None, connectivity=1): of the binary objects. This value is passed to `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. Note that the connectivity influences the result in the case of the Hausdorff distance. + Returns ------- hd : float - The symmetric Hausdorff Distance between the object(s) in ```result``` and the - object(s) in ```reference```. The distance unit is the same as for the spacing of + The symmetric Hausdorff Distance between the object(s) in `result` and the + object(s) in `reference`. The distance unit is the same as for the spacing of elements along each dimension, which is usually given in mm. - See also + hd50 : float + The 50th percentile of the Hausdorff Distance. 
+ hd95 : float + The 95th percentile of the Hausdorff Distance. + + See Also -------- - :func:`assd` - :func:`asd` + assd : Average Symmetric Surface Distance, computes the average symmetric surface distance. + asd : Average Surface Distance, computes the average surface distance. + Notes ----- This is a real metric. The binary images can therefore be supplied in any order. @@ -196,13 +239,15 @@ def hd(result, reference, voxelspacing=None, connectivity=1): def hd95(result, reference, voxelspacing=None, connectivity=1): """ - 95th percentile of the Hausdorff Distance. + Computes the 95th percentile of the Hausdorff Distance. + Computes the 95th percentile of the (symmetric) Hausdorff Distance (HD) between the binary objects in two images. Compared to the Hausdorff Distance, this metric is slightly more stable to small outliers and is commonly used in Biomedical Segmentation challenges. + Parameters ---------- - result : array_like + result : Any Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. reference : array_like @@ -218,15 +263,18 @@ def hd95(result, reference, voxelspacing=None, connectivity=1): of the binary objects. This value is passed to `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. Note that the connectivity influences the result in the case of the Hausdorff distance. + Returns ------- - hd : float - The symmetric Hausdorff Distance between the object(s) in ```result``` and the + hd95 : float + The 95th percentile of the symmetric Hausdorff Distance between the object(s) in ```result``` and the object(s) in ```reference```. The distance unit is the same as for the spacing of elements along each dimension, which is usually given in mm. - See also + + See Also -------- - :func:`hd` + hd : Computes the symmetric Hausdorff Distance. + Notes ----- This is a real metric. The binary images can therefore be supplied in any order. diff --git a/CerebNet/utils/misc.py b/CerebNet/utils/misc.py index 527c7ce4..6e7ed790 100644 --- a/CerebNet/utils/misc.py +++ b/CerebNet/utils/misc.py @@ -79,6 +79,29 @@ def plot_confusion_matrix( figsize=(20, 20), file_save_name=None, ): + """ + This function prints and plots the confusion matrix. + + Parameters + ---------- + cm : np.ndarray + Confusion matrix. + classes : list + List of classes. + title : str, default="Confusion matrix" + Title of the confusion matrix. + cmap : plt.cm, default=matplotlib.pyplot.cm.Blues + Color map. + figsize : tuple, default=(20, 20) + Figure size. + file_save_name : str, optional + File save name. + + Returns + ------- + fig : plt.Figure + Figure object. + """ n_classes = len(classes) fig, ax = plt.subplots(figsize=figsize) @@ -175,9 +198,12 @@ def get_selected_class_ids(num_classes, ignored_classes=None): def set_summary_path(cfg): """ - Set last experiment number(EXPR_NUM) and updates the summary path accordingly - :param cfg: - :return: + Set last experiment number(EXPR_NUM) and updates the summary path accordingly. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Configuration node. """ summary_path = check_path(os.path.join(cfg.LOG_DIR, "summary")) cfg.EXPR_NUM = str(find_latest_experiment(os.path.join(cfg.LOG_DIR, "summary")) + 1) @@ -188,7 +214,7 @@ def set_summary_path(cfg): def load_classwise_weights(cfg): """ - Loading class-wise median frequency weights + Loading class-wise median frequency weights. 
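To make the Hausdorff helpers above concrete, a short example with two overlapping binary objects (both must be non-empty):

```python
import numpy as np

from CerebNet.utils.metrics import hd95

pred = np.zeros((32, 32, 32), dtype=bool)
gt = np.zeros((32, 32, 32), dtype=bool)
pred[8:20, 8:20, 8:20] = True   # predicted object
gt[10:22, 10:22, 10:22] = True  # reference object, shifted by two voxels

# 95th percentile of the symmetric surface distances, in voxel units here;
# pass voxelspacing (e.g. the header zooms) to get the result in millimetres.
print(hd95(pred, gt, voxelspacing=None, connectivity=1))
```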
""" dataset_dir = os.path.dirname(cfg.DATA.PATH_HDF5_TRAIN) weight_path = glob.glob(os.path.join(dataset_dir, "*.npy")) @@ -202,9 +228,12 @@ def load_classwise_weights(cfg): def update_results_dir(cfg): """ - It will update the results path by finding the last experiment number - :param cfg: - :return: + It will update the results path by finding the last experiment number. + + Parameters + ---------- + cfg : yacs.config.CfgNode + Configuration node. """ cfg.EXPR_NUM = str(find_latest_experiment(cfg.TEST.RESULTS_DIR) + 1) cfg.TEST.RESULTS_DIR = check_path( @@ -215,11 +244,11 @@ def update_results_dir(cfg): def update_split_path(cfg): """ Updating path with respect to the split number - Args: - cfg: - - Returns: - + + Parameters + ---------- + cfg : yacs.config.CfgNode + Configuration node. """ from os.path import split, join @@ -240,7 +269,7 @@ def update_split_path(cfg): def visualize_batch(img, label, idx): """ - For deubg + For deubg :param batch_dict: :return: """ diff --git a/Docker/Dockerfile b/Docker/Dockerfile index 2a1d8e10..8e4f854e 100644 --- a/Docker/Dockerfile +++ b/Docker/Dockerfile @@ -23,14 +23,15 @@ # Image to use to install freesurfer binaries from, the freesurfer binaries # should be located in /opt/freesurfer in the image. # - default: build_freesurfer -# - CONDA_BUILD_IAMGE: +# - CONDA_BUILD_IMAGE: # Image to use to install the python environment from, the python environment # should be in /venv/ in the image. # - default: build_cuda -# - CONDA_FILE: -# Which conda minifile to download to install conda -# from https://repo.continuum.io/miniconda/${CONDA_FILE} -# - default: Miniconda3-py38_4.11.0-Linux-x86_64.sh +# - MAMBA_VERSION: +# Which miniforge file to download to install mamba +# from https://github.com/conda-forge/miniforge/releases/download/ +# ${FORGE_VERSION}/Miniforge3-${FORGE_VERSION}-Linux-x86_64.sh +# - default: Miniforge3-23.11.0-0-Linux-x86_64.sh # DOCUMENTATION FOR TARGETS (use '--target '): # To select which imaged will be tagged with '-t' @@ -47,6 +48,12 @@ ARG FREESURFER_BUILD_IMAGE=build_freesurfer ARG CONDA_BUILD_IMAGE=build_conda ARG RUNTIME_BASE_IMAGE=ubuntu:22.04 ARG BUILD_BASE_IMAGE=ubuntu:22.04 +# BUILDKIT_SBOM:SCAN_CONTEXT enables buildkit to provide and scan build images +# this is active by default to provide proper SBOM manifests, however, it may also +# include parts that are not part of the distributed image (specifically build image +# parts installed in the build image, but not transfered to the runtime image such as +# git, wget, the miniconda installer, etc.) 
+ARG BUILDKIT_SBOM_SCAN_CONTEXT=true ## Start with ubuntu base to build the conda base stage FROM $BUILD_BASE_IMAGE AS build_base @@ -65,15 +72,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* ARG PYTHON_VERSION=3.10 -ARG CONDA_FILE=Miniconda3-py310_23.9.0-0-Linux-x86_64.sh +ARG FORGE_VERSION=24.3.0-0 # Install conda -RUN wget --no-check-certificate -qO ~/miniconda.sh https://repo.continuum.io/miniconda/${CONDA_FILE} && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh +RUN wget --no-check-certificate -qO ~/miniforge.sh \ + https://github.com/conda-forge/miniforge/releases/download/${FORGE_VERSION}/Miniforge3-${FORGE_VERSION}-Linux-x86_64.sh && \ + chmod +x ~/miniforge.sh && \ + ~/miniforge.sh -b -p /opt/miniforge && \ + rm ~/miniforge.sh -ENV PATH /opt/conda/bin:$PATH +ENV PATH=/opt/miniforge/bin:$PATH # create a stage for the common components used across different DEVICE settings FROM build_base AS build_common @@ -85,8 +93,9 @@ COPY ./env/fastsurfer.yml ./Docker/conda_pack.sh ./Docker/install_env.py /instal # Install conda for gpu ARG DEBUG=false -RUN python /install/install_env.py -m base -i /install/fastsurfer.yml -o /install/base-env.yml && \ - conda env create -f "/install/base-env.yml" | tee /install/env-create.log ; \ +RUN python /install/install_env.py -m base -i /install/fastsurfer.yml \ + -o /install/base-env.yml && \ + mamba env create -f "/install/base-env.yml" | tee /install/env-create.log ; \ if [ "${DEBUG}" != "true" ]; then \ rm /install/base-env.yml ; \ fi @@ -96,12 +105,14 @@ FROM build_common AS build_conda ARG DEBUG=false ARG DEVICE=cu118 # install additional packages for cuda/rocm/cpu -RUN python /install/install_env.py -m ${DEVICE} -i /install/fastsurfer.yml -o /install/${DEVICE}-env.yml && \ - conda env update -n "fastsurfer" -f "/install/${DEVICE}-env.yml" | tee /install/env-update.log && \ +RUN python /install/install_env.py -m ${DEVICE} -i /install/fastsurfer.yml \ + -o /install/${DEVICE}-env.yml && \ + mamba env update -n "fastsurfer" -f "/install/${DEVICE}-env.yml" \ + | tee /install/env-update.log && \ /install/conda_pack.sh "fastsurfer" && \ echo "DEBUG=$DEBUG\nDEVICE=$DEVICE\n" > /install/build_conda.args ; \ if [ "${DEBUG}" != "true" ]; then \ - conda env remove -n "fastsurfer" && \ + mamba env remove -n "fastsurfer" && \ rm -R /install ; \ fi @@ -112,8 +123,10 @@ FROM build_base AS build_freesurfer COPY ./Docker/install_fs_pruned.sh /install/ SHELL ["/bin/bash", "--login", "-c"] +ARG FREESURFER_URL=default + # install freesurfer and point to new python location -RUN /install/install_fs_pruned.sh /opt --upx && \ +RUN /install/install_fs_pruned.sh /opt --upx --url $FREESURFER_URL && \ rm /opt/freesurfer/bin/fspython && \ rm -R /install && \ ln -s /venv/bin/python3 /opt/freesurfer/bin/fspython @@ -122,7 +135,8 @@ ln -s /venv/bin/python3 /opt/freesurfer/bin/fspython # ======================================================= # Here, we create references to the requested build image # ======================================================= -# This is needed because COPY --from= does not accept variables as part of the image/stage name +# This is needed because COPY --from= does not accept variables as part of +# the image/stage name # selected_freesurfer_build_image -> $FREESURFER_BUILD_IMAGE FROM $FREESURFER_BUILD_IMAGE AS selected_freesurfer_build_image # selected_conda_build_image -> $CONDA_BUILD_IMAGE @@ -148,6 +162,7 @@ RUN apt-get update && 
apt-get install -y --no-install-recommends \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Add FreeSurfer and python Environment variables +# DO_NOT_SEARCH_FS_LICENSE_IN_FREESURFER_HOME=true deactivates the search for FS_LICENSE in FREESURFER_HOME ENV OS=Linux \ FS_OVERRIDE=0 \ FIX_VERTEX_AREA="" \ @@ -157,7 +172,8 @@ ENV OS=Linux \ PYTHONUNBUFFERED=0 \ MPLCONFIGDIR=/tmp \ PATH=/venv/bin:/opt/freesurfer/bin:$PATH \ - MPLCONFIGDIR=/tmp/matplotlib-config + MPLCONFIGDIR=/tmp/matplotlib-config \ + DO_NOT_SEARCH_FS_LICENSE_IN_FREESURFER_HOME="true" # create matplotlib config dir; make sure we use bash and activate conda env # (in case someone starts this interactively) @@ -167,31 +183,50 @@ SHELL ["/bin/bash", "--login", "-c"] # Copy fastsurfer venv and pruned freesurfer from build images -# Note, since COPY does not support variables in the --from parameter, so we point to a reference here, and the +# Note, since COPY does not support variables in the --from parameter, so we point to a +# reference here, and the # seletced__build_image is a only a reference to $_BUILD_IMAGE COPY --from=selected_freesurfer_build_image /opt/freesurfer /opt/freesurfer COPY --from=selected_conda_build_image /venv /venv -COPY /Docker/python-s /venv/bin/ + +# Fix for cuda11.8+cudnn8.7 bug+warning: https://github.com/pytorch/pytorch/issues/97041 +RUN if [[ "$DEVICE" == "cu118" ]] ; then cd /venv/python3.10/site-packages/torch/lib && ln -s libnvrtc-*.so.11.2 libnvrtc.so ; fi # Copy fastsurfer over from the build context and add PYTHONPATH COPY . /fastsurfer/ -ENV PYTHONPATH=/fastsurfer:/opt/freesurfer/python/packages:$PYTHONPATH \ +ENV PYTHONPATH=/fastsurfer:/opt/freesurfer/python/packages \ FASTSURFER_HOME=/fastsurfer -# Download all remote network checkpoints already, compile all FastSurfer scripts into bytecode and update -# the build file with checkpoints md5sums and pip packages. +# Download all remote network checkpoints already, compile all FastSurfer scripts into +# bytecode and update the build file with checkpoints md5sums and pip packages. RUN cd /fastsurfer ; python3 FastSurferCNN/download_checkpoints.py --all && \ python3 -m compileall * && \ - python3 FastSurferCNN/version.py --sections +git+checkpoints+pip --build_cache BUILD.info -o fullBUILD.info && \ + python3 FastSurferCNN/version.py --sections +git+checkpoints+pip \ + --build_cache BUILD.info -o fullBUILD.info && \ mv fullBUILD.info BUILD.info +# TODO: SBOM info of FastSurfer and FreeSurfer are missing, it is unclear how to add +# those at the moment, as the buildscanner syft does not find simple package.json +# or pyproject.toml files right now. The easiest option seems to be to "setup" +# fastsurfer and freesurfer via pip install. 
+#ENV BUILDKIT_SCAN_SOURCE_EXTRAS="/fastsurfer" +#ARG BUILDKIT_SCAN_SOURCE_EXTRAS="/fastsurfer" +#RUN < /fastsurfer/package.json +#{ +# "name": "fastsurfer", +# "version": "$(python3 FastSurferCNN/version.py)", +# "author": "David Kügler " +#} +#EOF + # Set FastSurfer workdir and entrypoint # the script entrypoint ensures that our conda env is active +USER nonroot WORKDIR "/fastsurfer" ENTRYPOINT ["/fastsurfer/Docker/entrypoint.sh","/fastsurfer/run_fastsurfer.sh"] CMD ["--help"] -FROM runtime as runtime_cuda +FROM runtime AS runtime_cuda # to support AWS docker images, see Issue #352 # https://sarus.readthedocs.io/en/stable/user/custom-cuda-images.html#controlling-the-nvidia-container-runtime diff --git a/Docker/README.md b/Docker/README.md index 426decb9..d5d12af8 100644 --- a/Docker/README.md +++ b/Docker/README.md @@ -1,12 +1,23 @@ -# Pull FastSurfer from DockerHub +# FastSurfer Docker Support +## Pull FastSurfer from DockerHub -We provide a number of prebuild docker images on [Docker Hub](https://hub.docker.com/r/deepmi/fastsurfer/tags). In order to get the latest cuda image (for nVidia GPUs) you simply need to execute the following command: +We provide pre-built Docker images with support for nVidia GPU-acceleration and for CPU-only use on [Docker Hub](https://hub.docker.com/r/deepmi/fastsurfer/tags). +In order to quickly get the latest Docker image, simply execute: ```bash docker pull deepmi/fastsurfer ``` -You can get a different one by simply adding the corresponding tag at the end of "deepmi/fastsurfer" as in "deepmi/fastsurfer:gpu-v#.#.#", where the # should be replaced with the version. +This will download the newest, official FastSurfer image with support for nVidia GPUs. + +Image are named and tagged as follows: `deepmi/fastsurfer:-`, where `` is `gpu` for support of nVidia GPUs and `cpu` without hardware acceleration (the latter is smaller and thus faster to download). +Similarly, `` can be a version string (`latest` or `v#.#.#`, where `#` are digits, for example `v2.2.2`), for example: + +```bash +docker pull deepmi/fastsurfer:cpu-v2.2.2 +``` + +### Running the official Docker Image After pulling the image, you can start a FastSurfer container and process a T1-weighted image (both segmentation and surface reconstruction) with the following command: ```bash @@ -20,15 +31,18 @@ docker run --gpus all -v /home/user/my_mri_data:/data \ --parallel ``` -##### Docker Flags +#### Docker Flags * `--gpus`: This flag is used to access GPU resources. With it, you can also specify how many GPUs to use. In the example above, _all_ will use all available GPUS. To use a single one (e.g. GPU 0), set `--gpus device=0`. To use multiple specific ones (e.g. GPU 0, 1 and 3), set `--gpus "device=0,1,3"`. * `-v`: This commands mount your data, output and directory with the FreeSurfer license file into the docker container. Inside the container these are visible under the name following the colon (in this case /data, /output, and /fs_license). * `--rm`: The flag takes care of removing the container once the analysis finished. * `-d`: This is optional. You can add this flag to run in detached mode (no screen output and you return to shell) -* `--user $(id -u):$(id -g)`: This part automatically runs the container with your group- (id -g) and user-id (id -u). All generated files will then belong to the specified user. Without the flag, the docker container will be run as root which is strongly discouraged. 
+* `--user $(id -u):$(id -g)`: Run the container with your account (your user-id and group-id), which are determined by `$(id -u)` and `$(id -g)`, respectively. Running the docker container as root `-u 0:0` is strongly discouraged. + +#### Advanced Docker Flags +* `--group-add `: If additional user groups are required to access files, additional groups may be added via `--group-add [,...]` or `--group-add $(id -G )`. -##### FastSurfer Flags -* The `--fs_license` points to your FreeSurfer license which needs to be available on your computer in the my_fs_license_dir that was mapped above. +#### FastSurfer Flags +* The `--fs_license` points to your FreeSurfer license which needs to be available on your computer in the `my_fs_license_dir` that was mapped above. * The `--t1` points to the t1-weighted MRI image to analyse (full path, with mounted name inside docker: /home/user/my_mri_data => /data) * The `--sid` is the subject ID name (output folder name) * The `--sd` points to the output directory (its mounted name inside docker: /home/user/my_fastsurfer_analysis => /output) @@ -47,7 +61,7 @@ All other available flags are identical to the ones explained on the main page [ How? Docker does not mount the home directory by default, so unless you manually set the `HOME` environment variable, all should be fine. -# FastSurfer Docker Image Creation +## FastSurfer Docker Image Creation Within this directory, we currently provide a build script and Dockerfile to create multiple Docker images for users (usually developers) who wish to create their own Docker images for 3 platforms: @@ -67,9 +81,12 @@ Also note, in order to run our Docker containers on a Mac, users need to increas The build script `build.py` supports additional args, targets and options, see `python Docker/build.py --help`. Note, that the build script's main function is to select parameters for build args, but also create the FastSurfer-root/BUILD.info file, which will be used by FastSurfer to document the version (including git hash of the docker container). This BUILD.info file must exist for the docker build to be successful. -In general, if you specify `--dry_run` the command will not be executed but sent to stdout, so you can run `python build.py --device cuda --dry_run | bash` as well. Note, that build.py uses some dependencies from FastSurfer, so you will need to set the PYTHONPATH environment variable to the FastSurfer root (include of `FastSurferCNN` must be possible) and we only support Python 3.10 (Python 3.8+ still seems to work, but may stop working at any time). +In general, if you specify `--dry_run` the command will not be executed but sent to stdout, so you can run `python build.py --device cuda --dry_run | bash` as well. Note, that build.py uses some dependencies from FastSurfer, so you will need to set the PYTHONPATH environment variable to the FastSurfer root (include of `FastSurferCNN` must be possible) and we only support Python 3.10. -By default, the build script will tag your image as "fastsurfer:{version_tag}[-{device}]", where {version_tag} is {version-identifer from pyproject.toml}_{current git-hash} and {device} is the value to --device (and omitted for cuda), but a custom tag can be specified by `--tag {tag_name}`. 
+By default, the build script will tag your image as `"fastsurfer:[{device}-]{version_tag}"`, where `{version_tag}` is `{version-identifer from pyproject.toml}_{current git-hash}` and `{device}` is the value to `--device` (omitted for `cuda`), but a custom tag can be specified by `--tag {tag_name}`. + +#### BuildKit +Note, we recommend using BuildKit to build docker images (e.g. `DOCKER_BUILDKIT=1` -- the build.py script already always adds this). To install BuildKit, run `wget -qO ~/.docker/cli-plugins/docker-buildx https://github.com/docker/buildx/releases/download//buildx-.`, for example `wget -qO ~/.docker/cli-plugins/docker-buildx https://github.com/docker/buildx/releases/download/v0.12.1/buildx-v0.12.1.linux-amd64`. See also https://github.com/docker/buildx#manual-download. ### Example 1: Build GPU FastSurfer Image @@ -122,7 +139,7 @@ As you can see, only the tag of the image is changed from gpu to cpu and the sta Here we build an experimental image to test performance when running on AMD GPUs. Note that you need a supported OS and Kernel version and supported GPU for the RocM to work correctly. You need to install the Kernel drivers into your host machine kernel (amdgpu-install --usecase=dkms) for the amd docker to work. For this follow: -https://docs.amd.com/en/latest/deploy/linux/quick_start.html +https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html#rocm-install-quick, https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/amdgpu-install.html#amdgpu-install-dkms and https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html ```bash PYTHONPATH= @@ -132,9 +149,8 @@ python build.py --device rocm --tag my_fastsurfer:rocm and run segmentation only: ```bash -docker run --rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ - --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host \ - --shm-size 8G \ +docker run --rm --security-opt seccomp=unconfined \ + --device=/dev/kfd --device=/dev/dri --group-add video \ -v /home/user/my_mri_data:/data \ -v /home/user/my_fastsurfer_analysis:/output \ my_fastsurfer:rocm \ @@ -142,12 +158,13 @@ docker run --rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --sid subjectX --sd /output ``` -Note, we tested on an AMD Radeon Pro W6600, which is [not officially supported](https://docs.amd.com/en/latest/release/gpu_os_support.html), but setting `HSA_OVERRIDE_GFX_VERSION=10.3.0` [inside docker did the trick](https://en.opensuse.org/AMD_OpenCL#ROCm_-_Running_on_unsupported_hardware): +In conflict with the official ROCm documentation (above), we also needed to add the group render `--group-add render` (in addition to `--group-add video`). 
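+If the `video`/`render` group names are not defined inside the container, the numeric host group IDs can be passed to `--group-add` instead. A minimal sketch (assuming both groups exist on the host system):
+```bash
+# resolve the numeric group IDs of the host's video and render groups
+video_gid=$(getent group video | cut -d: -f3)
+render_gid=$(getent group render | cut -d: -f3)
+# then run docker with: --group-add "$video_gid" --group-add "$render_gid"
+```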
+ +Note, we tested on an AMD Radeon Pro W6600, which is [not officially supported](https://docs.amd.com/en/latest/release/gpu_os_support.html), but setting `HSA_OVERRIDE_GFX_VERSION=10.3.0` [inside docker did the trick](https://en.opensuse.org/SDB:AMD_GPGPU#Using_CUDA_code_with_ZLUDA_and_ROCm): ```bash -docker run --rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ - --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host \ - --shm-size 8G \ +docker run --rm --security-opt seccomp=unconfined \ + --device=/dev/kfd --device=/dev/dri --group-add video --group-add render \ -v /home/user/my_mri_data:/data \ -v /home/user/my_fastsurfer_analysis:/output \ -e HSA_OVERRIDE_GFX_VERSION=10.3.0 \ @@ -156,3 +173,42 @@ docker run --rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ --sid subjectX --sd /output ``` +## Build docker image with attestation and provenance + +To build a docker image with attestation and provenance, i.e. Software Bill Of Materials (SBOM) information, several requirements have to be met: + +1. The image must be built with version v0.11+ of BuildKit (we recommend you [install BuildKit](#buildkit) independent of attestation). +2. You must configure a docker-container builder in buildx (`docker buildx create --use --bootstrap --name fastsurfer-bctx --driver docker-container`). Here, you can add additional configuration options such as safe registries to the builder configuration (add `--config /etc/buildkitd.toml`). + ```toml + root = "/path/to/data/for/buildkit" + [worker.containerd] + gckeepstorage=9000 + [[worker.containerd.gcpolicy]] + keepBytes = 512000000 + keepDuration = 172800 + filters = [ "type==source.local", "type==exec.cachemount", "type==source.git.checkout"] + [[worker.containerd.gcpolicy]] + all = true + keepBytes = 1024000000 + # settings to push to a "local", registry with self-signed certificates + # see for example https://tech.paulcz.net/2016/01/secure-docker-with-tls/ https://github.com/paulczar/omgwtfssl + [registry."host:5000"] + ca=["/path/to/registry/ssl/ca.pem"] + [[registry."landau.dzne.ds:5000".keypair]] + key="/path/to/registry/ssl/key.pem" + cert="/path/to/registry/ssl/cert.pem" + ``` +3. Attestation files are not supported by the standard docker image storage driver. Therefore, images cannot be tested locally. + There are two solutions to this limitation. + 1. Directly push to the registry: + Add `--action push` to the build script (the default is `--action load`, which loads the created image into the current docker context, and for the image name, also add the registry name. For example `... python Docker/build.py ... --attest --action push --tag docker.io//fastsurfer:latest`. + 2. [Install the containerd image storage driver](https://docs.docker.com/storage/containerd/#enable-containerd-image-store-on-docker-engine), which supports attestation: To implement this on Linux, make sure your docker daemon config file `/etc/docker/daemon.json` includes + ```json + { + "features": { + "containerd-snapshotter": true + } + } + ``` + Also note, that the image storage location with containerd is not defined by the docker config file `/etc/docker/daemon.json`, but by the containerd config `/etc/containerd/config.toml`, which will likely not exist. You can [create a default config](https://github.com/containerd/containerd/blob/main/docs/getting-started.md#customizing-containerd) file with `containerd config default > /etc/containerd/config.toml`, in this config file edit the `"root"`-entry (default value is `/var/lib/containerd`). 
+4. Finally, you can now build the FastSurfer image with `python Docker/build.py ... --attest`. This will add the additional flags to the docker build command. diff --git a/Docker/build.py b/Docker/build.py index 4304e6ab..6467c7bf 100755 --- a/Docker/build.py +++ b/Docker/build.py @@ -20,7 +20,6 @@ import argparse import os import subprocess -from itertools import chain from pathlib import Path from typing import Tuple, Literal, Sequence, Optional, Dict, get_args, cast, List, Callable, Union import logging @@ -31,10 +30,25 @@ Target = Literal['runtime', 'build_common', 'build_conda', 'build_freesurfer', 'build_base', 'runtime_cuda'] CacheType = Literal["inline", "registry", "local", "gha", "s3", "azblob"] -AllDeviceType = Literal["cpu", "cuda", "cu116", "cu117", "cu118", "rocm", "rocm5.1.1", - "rocm5.4.2"] -DeviceType = Literal["cpu", "cu116", "cu117", "cu118", "rocm5.1.1", "rocm5.4.2"] - +AllDeviceType = Literal["cpu", "cuda", "cu118", "cu121", "cu124", "rocm", "rocm6.1"] +DeviceType = Literal["cpu", "cu118", "cu121", "cu124", "rocm6.1"] + +CREATE_BUILDER = "Create builder with 'docker buildx create --name fastsurfer'." +CONTAINERD_MESSAGE = ( + "Attestation requires OCI images, which are not supported by the default docker " + "storage driver (your current storage driver?). Use containerd storage: " + "https://docs.docker.com/storage/containerd/\n" +) +_WGET_BUILDX_FMT = ( + "wget -qO ~/.docker/cli-plugins/docker-buildx https://github.com/docker" + "/buildx/releases/download/{0:s}/buildx-{0:s}.{1:s}" +) +INSTALL_BUILDX = ( + f"Install buildx with '{_WGET_BUILDX_FMT.format('', '')}', " + f"e.g. '{_WGET_BUILDX_FMT.format('v0.12.1', 'linux-amd64')}.\nYou may need to " + f"'chmod +x ~/.docker/cli-plugins/docker-buildx'\nSee also " + f"https://github.com/docker/buildx#manual-download." +) __import_cache = {} @@ -44,10 +58,11 @@ class DEFAULTS: # and rocm versions, if pytorch comes with new versions. 
# torch 1.12.0 comes compiled with cu113, cu116, rocm5.0 and rocm5.1.1 # torch 2.0.1 comes compiled with cu117, cu118, and rocm5.4.2 + # torch 2.4 comes compiled with cu118, cu121, cu124 and rocm6.1 MapDeviceType: Dict[AllDeviceType, DeviceType] = dict( ((d, d) for d in get_args(DeviceType)), - rocm="rocm5.1.1", - cuda="cu117", + rocm="rocm6.1", + cuda="cu124", ) BUILD_BASE_IMAGE = "ubuntu:22.04" RUNTIME_BASE_IMAGE = "ubuntu:22.04" @@ -55,39 +70,15 @@ class DEFAULTS: CONDA_BUILD_IMAGE = "build_conda" -def _import_calls(fasturfer_home: Path, token: str = "Popen") -> Callable: - # import call and call_async without importing FastSurferCNN fully - if token not in __import_cache: - def __import(file: Path, name: str, *tokens: str, **rename_tokens: str): - from importlib.util import spec_from_file_location, module_from_spec - spec = spec_from_file_location(name, file) - module = module_from_spec(spec) - spec.loader.exec_module(module) - for tok, name in chain(zip(tokens, tokens), rename_tokens.items()): - __import_cache[tok] = getattr(module, name) - - if token in ("Popen", "PyPopen"): - # import Popen and PyPopen from FastSurferCNN.utils.run_tools - __import(fasturfer_home / "FastSurferCNN/utils/run_tools.py", - "run_tools", "Popen", "PyPopen") - elif token in ("version", "parse_build_file"): - # import main as version from FastSurferCNN.version - __import(fasturfer_home / "FastSurferCNN/version.py", - "version", "parse_build_file", version="main") - - if token in __import_cache: - return __import_cache[token] - else: - raise ImportError(f"Invalid token {token}") - - def docker_image(arg) -> str: - """Returns a str with the image. + """ + Returns a str with the image. Raises ====== ArgumentTypeError - if it is not a valid docker image.""" + if it is not a valid docker image. + """ from re import match # regex from https://stackoverflow.com/questions/39671641/regex-to-parse-docker-tag pattern = r"^(?:(?=[^:\/]{1,253})(?!-)[a-zA-Z0-9-]{1,63}(? str: if match(pattern, arg): return arg else: - raise argparse.ArgumentTypeError(f"The image '{arg}' does not look like a " - f"valid image name.") + raise argparse.ArgumentTypeError( + f"The image '{arg}' does not look like a valid image name." + ) def target(arg) -> Target: @@ -112,7 +104,8 @@ def target(arg) -> Target: return cast(Target, arg) else: raise argparse.ArgumentTypeError( - f"target must be one of {', '.join(get_args(Target))}, but was {arg}.") + f"target must be one of {', '.join(get_args(Target))}, but was {arg}." + ) class CacheSpec: @@ -192,19 +185,20 @@ def make_parser() -> argparse.ArgumentParser: parser.add_argument( "--device", - choices=["cpu", "cuda", "cu117", "cu1118", "rocm", "rocm5.4.2"], + choices=["cpu", "cuda", "cu118", "cu121", "cu124", "rocm", "rocm6.1"], required=True, help="""selection of internal build stages to build for a specific platform.
- - cuda: defaults to cu118, cuda 11.8
+ - cuda: defaults to cu124, cuda 12.4
- cpu: only cpu support
- - rocm: defaults to rocm5.4.2 (experimental)""", + - rocm: defaults to rocm6.1 (experimental)""", ) parser.add_argument( "--tag", type=docker_image, dest="image_tag", metavar="image[:tag]", - help="""tag build stage/target as <image>[:<tag>]""") + help="""tag build stage/target as <image>[:<tag>]""", + ) parser.add_argument( "--target", default="runtime", @@ -214,32 +208,77 @@ help=f"""target to build (from list of targets below, defaults to runtime):
- build_conda: "finished" conda build image
- build_freesurfer: "finished" freesurfer build image
- - runtime: final fastsurfer runtime image""") - parser.add_argument( - "--rm", - action="store_true", - help="disables caching, i.e. removes all intermediate images.") + - runtime: final fastsurfer runtime image""", + ) + cache_kwargs = {} + if "FASTSURFER_BUILD_CACHE" in os.environ: + try: + cache_kwargs = { + "default": CacheSpec(os.environ["FASTSURFER_BUILD_CACHE"]) + } + except ValueError as e: + logger.warning( + f"ERROR while parsing the environment variable 'FASTSURFER_BUILD_CACHE' " + f"{os.environ['FASTSURFER_BUILD_CACHE']} (ignoring this environment " + f"variable): {e.args[0]}" + ) parser.add_argument( "--cache", type=CacheSpec, - help="""cache as defined in https://docs.docker.com/build/cache/backends/ - (using --cache-to syntax, parameters are automatically filtered for use - in --cache-to and --cache-from), e.g.: - --cache type=registry,ref=server/fastbuild,mode=max.""") + help=f"""cache as defined in https://docs.docker.com/build/cache/backends/ + (using --cache-to syntax, parameters are automatically filtered for use + in --cache-to and --cache-from), e.g.: + --cache type=registry,ref=server/fastbuild,mode=max. + Will default to the environment variable FASTSURFER_BUILD_CACHE: + {cache_kwargs.get('default', 'N/A')}""", + metavar="type={inline,local,...}[,=[,...]]", + **cache_kwargs, + ) parser.add_argument( "--dry_run", "--print", action="store_true", help="Instead of starting processes, write the commands to stdout, so they can " - "be dry_run with 'build.py ... --dry_run | bash'.") + "be dry_run with 'build.py ... --dry_run | bash'.", + ) parser.add_argument( "--tag_dev", action="store_true", - help="Also tag the resulting image as 'fastsurfer:dev'." + help="Also tag the resulting image as 'fastsurfer:dev'.", + ) + # --save_image does not work as expected right now, it cannot be imported via + # docker load, but must be transferred to a registry... 
+ # parser.add_argument( + # "--save_image", + # dest="image_path", + # default=None, + # help="Export the image to a tarball.", + # ) + parser.add_argument( + "--singularity", + type=Path, + default=None, + help="Specify a singularity file name to build a singularity image into.", ) expert = parser.add_argument_group('Expert options') + parser.add_argument( + "--attest", + action="store_true", + help="add sbom and provenance attestation (requires docker-container buildkit " + "builder created with 'docker buildx create')", + ) + parser.add_argument( + "--action", + choices=("load", "push"), + default="load", + help="Which action to perform after building the image (if a docker-container " + "is detected): " + "'load' loads the image into the current docker context (default), " + "'push' pushes the image to the registry (needs --tag /" + ":)", + ) expert.add_argument( "--freesurfer_build_image", type=docker_image, @@ -281,56 +320,285 @@ def red(skk): return "\033[91m {}\033[00m" .format(skk) +def get_builder( + Popen, + builder_type: str, + require_builder_type: bool = False, +) -> tuple[bool, str]: + """Get the builder to build the fastsurfer image.""" + from subprocess import PIPE + from re import compile + + buildx_binfo = Popen(["docker", "buildx", "ls"], stdout=PIPE, stderr=PIPE).finish() + header, *lines = buildx_binfo.out_str("utf-8").strip().split("\n") + header_pattern = compile("\\S+\\s*") + fields = {} + alternative_builder = "use_default" + pos = 0 + while pos < len(header) and (match := header_pattern.search(header, pos)): + start, pos = match.span() + fields[match.group().strip()] = slice(start, pos) + builders = {line[fields["NAME/NODE"]].strip(): line[fields["DRIVER/ENDPOINT"]].strip() + for line in lines if not line.startswith(" ")} + default_builders = [name for name in builders.keys() if name.endswith("*")] + if len(default_builders) != 1: + raise RuntimeError("Could not find default builder of buildx") + default_builder = default_builders[0][:-1].strip() + builders[default_builder] = builders[default_builders[0]] + del builders[default_builders[0]] + builder_is_correct_type = builders[default_builder] == builder_type + default_builder_is_correct_type = builder_is_correct_type + if not builder_is_correct_type: + # if the default builder is a docker builder (which may not support features) + # see if there is an alternative builder named "fastsurfer*" + for builder in builders.keys(): + if builder.startswith("fastsurfer") and builders[builder] == builder_type: + # set the default_builder to this (prefered) builder + alternative_builder = builder + break + # update is_correct_type + if alternative_builder != "use_default": + builder_is_correct_type = builders[alternative_builder] == builder_type + if not builder_is_correct_type and require_builder_type: + # did not find an appropriate builder, but is required!! + raise RuntimeError( + "Could not find an appropriate builder from the current builder " + "(see docker buildx use) or builders named fastsurfer* (searching for " + f"a builder of type {builder_type}, docker builders may not be supported " + f"with the selected export settings. {CREATE_BUILDER}" + ) + return default_builder_is_correct_type, alternative_builder + + def docker_build_image( image_name: str, dockerfile: Path, working_directory: Optional[Path] = None, - context: Path = ".", + context: Path | str = ".", dry_run: bool = False, - **kwargs): - logger.info("Building. 
This starts with sending the build context to the docker daemon, which may take a while...") + attestation: bool = False, + action: Literal["load", "push"] = "load", + image_path: Path | str | None = None, + **kwargs) -> None: + """ + Build a docker image. + + Parameters + ---------- + image_name : str + Name / target tag of the image. + dockerfile : Path, str + Path to the Dockerfile. + working_directory : Path, str, optional + Path o the working directory to perform the build operation (default: inherit). + context : Path, str, optional + Base path to the context folder to build the docker image from (default: '.'). + dry_run : bool, optional + Whether to actually trigger the build, or just print the command to the console + (default: False => actually build). + cache_to : str, optional + Forces usage of buildx over build, use docker build caching as in the --cache-to + argument to docker buildx build. + attestation : bool, default=False + Whether to create sbom and provenance attestation + action : "load", "push", default="load" + The operation to perform after the image is built (only if a docker-container + builder is detected). + image_path : Path, str, optional + A path to save the image to (experimental; currently cannot be imported into a + legacy docker storage driver). + + Additional kwargs add additional build flags to the build command in the following + manner: "_" is replaced by "-" in the keyword name and each sequence entry is passed + with its own flag, e.g. `docker_build_image(..., build_arg=["TEST=1", "VAL=2"])` is + translated to `docker [buildx] build ... --build-arg TEST=1 --build-arg VAL=2`. + """ + from itertools import chain, repeat + from shutil import which + from subprocess import PIPE + + from FastSurferCNN.utils.run_tools import Popen + + logger.info("Building. 
This starts with sending the build context to the docker " + "daemon, which may take a while...") extra_env = {"DOCKER_BUILDKIT": "1"} - from itertools import chain - def to_pair(key, values): - _values = values if isinstance(values, Sequence) and not isinstance(values, (str, bytes)) else [values] - key_dashed = key.replace("_", "-") - return list(chain(*[[f"--{key_dashed}"] + ([] if val is None else [val]) for val in _values])) + docker_cmd = which("docker") + if docker_cmd is None: + raise FileNotFoundError("Could not locate the docker executable") - # needs buildx - buildx = "cache_to" in kwargs - args = ["buildx", "build"] if buildx else ["build"] + if action not in ("load", "push"): + raise ValueError(f"Invalid Value for 'action' {action}, must be load or push.") - if buildx: - Popen = _import_calls(working_directory) # from fastsurfer dir - buildx_test = Popen(["docker", "buildx", "version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE).finish() - if "'buildx' is not a docker command" in buildx_test.err_str('utf-8').strip(): + def to_pair(key, values): + if isinstance(values, Sequence) and isinstance(values, (str, bytes)): + values = [values] + key_dashed = key.replace("_", "-") + # concatenate the --key_dashed value pairs + return list(chain(*zip(repeat(f"--{key_dashed}"), values))) + + kw = {"stdout": PIPE, "stderr": PIPE} + _buildx = Popen([docker_cmd, "buildx", "version"], **kw) + _storage = Popen([docker_cmd, "info", "-f", "{{.DriverStatus}}"], **kw) + has_buildx = "'buildx' is not a docker command" not in _buildx.finish().err_str() + has_storage = "io.containerd.snapshotter" in _storage.finish().out_str() + + def is_inline_cache(cache_kw): + inline_cache = "type=inline" + all_inline_cache = (None, "", inline_cache) + return kwargs.get(cache_kw, inline_cache) not in all_inline_cache + + # require buildx for sbom and provenance and cache != inline + require_container = (attestation or + any(is_inline_cache(f"cache_{c}") for c in ("to", "from"))) + import_after_args = [] + if dest := image_path or "": + logger.warning("Images exported with image_path cannot be imported into legacy " + "storage drivers. This feature is currently experimental. Also " + "note, that exporting to a file is incompatible with the load " + f"and push actions. Deactivating {action}-action!") + dest = f",dest={dest}" + action = "export" + if not has_buildx: + # only standard build environment arguments available + if require_container: + # not supported with builder != docker-container + raise RuntimeError( + "Using --cache_{from,to} or attestation requires docker buildx and a " + f"docker-container builder.\n{INSTALL_BUILDX}\n{CREATE_BUILDER}" + ) + if action != "load": raise RuntimeError( - "Using --cache requires docker buildx, install with 'wget -qO ~/" - ".docker/cli-plugins/docker-buildx https://github.com/docker/buildx/" - "releases/download//buildx-.'\n" - "e.g. 'wget -qO ~/.docker/cli-plugins/docker-buildx " - "https://github.com/docker/buildx/releases/download/v0.11.2/" - "buildx-v0.11.2.linux-amd64'\n" - "You may need to 'chmod +x ~/.docker/cli-plugins/docker-buildx'\n" - "See also https://github.com/docker/buildx#manual-download") - - params = [to_pair(*a) for a in kwargs.items()] - - args += ["-t", image_name, "-f", str(dockerfile)] + list(chain(*params)) + [str(context)] + "The legacy docker builder does not support pushing or exporting the " + "image." 
+ ) + args = ["build"] + kwargs_to_exclude = [f"cache_{c}" for c in ("to", "from")] + else: + # buildx argument construction + args = ["buildx", "build"] + # raises RuntimeError, if a docker-container builder is required, but not found + default_builder_is_container, alternative_builder = get_builder( + Popen, + "docker-container", + require_container, + ) + if has_storage and action == "load": + image_type = f"docker" + elif action == "push": + # with containerd storage driver or pushing to registry + image_type = f"image" + # both support attestation no problem + elif action == "export": + experimental = ". No image will be imported. This features is experimental." + if attestation: + warn_msg = (f"{CONTAINERD_MESSAGE}The build script will save the image " + f"to {image_path} (which will contain the attestation " + f"manifest files){experimental}") + else: + warn_msg = (f"The build script will save the image to {image_path}" + f"{experimental}") + logger.warning(warn_msg) + image_type = f"oci{dest}" + if dry_run: + print(f"mkdir -p {Path(image_path).parent} && ", sep="") + else: + Path(image_path).parent.mkdir(exist_ok=True) + # importing after (bock docker image import as well as docker image load + # are not supported for images exported by buildkit. + # import_after_args = ["image", "import", image_path, image_name] + elif attestation: + # also implicitly action == load + raise RuntimeError(CONTAINERD_MESSAGE) + # Future Alternative: save the image to preserve the manifest files to file + else: + # no attestation, docker builder supports this format + image_type = f"docker" + + args.extend(["--output", f"type={image_type},name={image_name}"]) + if not bool(import_after_args): + args.append(f"--{action}") + if attestation: + args.extend([ + "--attest", "type=sbom", + "--attest", "type=provenance", + ]) + if not default_builder_is_container: + args.extend(["--builder", alternative_builder]) + + kwargs_to_exclude = [] + + params = [to_pair(k, v) for k, v in kwargs.items() if k not in kwargs_to_exclude] + # arguments for standard build and buildx + args.extend([ + "-t", image_name, + "-f", str(dockerfile), + ]) + args.extend(chain(*params)) + args.append(str(context)) + if dry_run: extra_environment = [f"{k}={v}" for k, v in extra_env.items()] - print(" ".join(extra_environment + ["docker"] + args)) + print(" ".join(extra_environment + [docker_cmd] + args), sep="") + if import_after_args: + print(" && " + " ".join([docker_cmd] + import_after_args), sep="") else: - from shutil import which - docker_cmd = which("docker") - if docker_cmd is None: - raise FileNotFoundError("Could not locate the docker executable") - Popen = _import_calls(working_directory) # from fastsurfer dir env = dict(os.environ) env.update(extra_env) + + def forward_output_to_logger(process): + for msg in process: + if msg.out: + logger.info("stdout: " + msg.out.decode("utf-8")) + if msg.err: + logger.info("stderr: " + red(msg.err.decode("utf-8"))) + with Popen([docker_cmd] + args + ["--progress=plain"], cwd=working_directory, env=env, stdout=subprocess.PIPE) as proc: + forward_output_to_logger(proc) + if import_after_args: + with Popen([docker_cmd] + import_after_args, + cwd=working_directory, env=env, stdout=subprocess.PIPE) as proc: + forward_output_to_logger(proc) + + +def singularity_build_image( + image_name: str, + singularity_image: Path, + working_directory: Optional[Path] = None, + dry_run: bool = False, +): + """ + Build the singularity image from the docker image. 
+ + Parameters + ---------- + image_name : str + The name of the docker image to build the singularity image from. + singularity_image : Path + The path and file of the singularity image to build. + working_directory : Path, str, optional + Path o the working directory to perform the build operation (default: inherit). + dry_run : bool, default=False + Whether to build from python or to print the command to stdout. + """ + from shutil import which + + # Create the folder for the singularity image + singularity_image.parent.mkdir(exist_ok=True) + args = [ + which("singularity"), + "build", + "--force", + str(singularity_image), + f"docker-daemon://{image_name}", + ] + if dry_run: + print(" ".join([" &&"] + args), sep="") + else: + from FastSurferCNN.utils.run_tools import Popen + with Popen(args, + cwd=working_directory, stdout=subprocess.PIPE) as proc: for msg in proc: if msg.out: logger.info("stdout: " + msg.out.decode("utf-8")) @@ -341,21 +609,15 @@ def to_pair(key, values): def main( device: DeviceType, cache: Optional[CacheSpec] = None, - rm: bool = False, target: Target = "runtime", debug: bool = False, image_tag: Optional[str] = None, dry_run: bool = False, tag_dev: bool = True, - **keywords - ): - this_script = Path(__file__) - if not this_script.is_absolute(): - this_script = Path.cwd() / __file__ - fastsurfer_home = this_script.parent.parent - version = _import_calls(fastsurfer_home, "version") - parse_build_file = _import_calls(fastsurfer_home, "parse_build_file") - + fastsurfer_home: Optional[Path] = None, + **keywords, + ) -> int | str: + from FastSurferCNN.version import has_git, main as version kwargs: Dict[str, Union[str, List[str]]] = {} if cache is not None: if not isinstance(cache, CacheSpec): @@ -363,13 +625,15 @@ def main( logger.info(f"cache: {cache}") kwargs["cache_from"] = cache.format_cache_from() kwargs["cache_to"] = cache.format_cache_from() - elif rm is True: - kwargs["no-cache"] = None + + fastsurfer_home = Path(fastsurfer_home) if fastsurfer_home else default_home() if target not in get_args(Target): raise ValueError(f"Invalid target: {target}") if device not in get_args(AllDeviceType): raise ValueError(f"Invalid device: {device}") + if keywords.get("action", "load") == "push": + kwargs["action"] = "push" # special case to add extra environment variables to better support AWS and ROCm if device.startswith("cu") and target == "runtime": target = "runtime_cuda" @@ -377,33 +641,59 @@ def main( kwargs["build_arg"] = [f"DEVICE={DEFAULTS.MapDeviceType.get(device, 'cpu')}"] if debug: kwargs["build_arg"].append(f"DEBUG=true") - for key in ["build_base_image", "runtime_base_image", "freesurfer_build_image", - "conda_build_image"]: + build_arg_list = [ + "build_base_image", + "runtime_base_image", + "freesurfer_build_image", + "conda_build_image", + ] + for key in build_arg_list: upper_key = key.upper() value = keywords.get(key) or getattr(DEFAULTS, upper_key) kwargs["build_arg"].append(f"{upper_key}={value}") # kwargs["build_arg"] = " ".join(kwargs["build_arg"]) build_filename = fastsurfer_home / "BUILD.info" + if has_git(): + version_sections = "+git" + else: + # try creating the build file without git info + version_sections = "" + logger.warning( + "Failed to create the git_status section in the BUILD.info file. " + "The resulting build file will not have valid git information, so " + "the version command of FastSurfer in the image will not complete." 
+ ) + with open(build_filename, "w") as build_file, \ open(fastsurfer_home / "pyproject.toml") as project_file: - ret_version = version("+git", project_file=project_file, file=build_file) - if ret_version != 0: - return f"Creating the version file failed with message: {ret_version}" + ret_version = version( + version_sections, + project_file=project_file, + file=build_file, + build_cache=False, + ) + if ret_version != 0: + return f"Creating the version file failed with message: {ret_version}" with open(build_filename, "r") as build_file: + from FastSurferCNN.version import parse_build_file build_info = parse_build_file(build_file) version_tag = build_info["version_tag"] - image_suffix = "" + image_prefix = "" if device != "cuda": - image_suffix = f"-{device}" + image_prefix = f"{device}-" # image_tag is None or "" if not bool(image_tag): - image_tag = f"fastsurfer:{version_tag}{image_suffix}".replace("+", "_") + image_tag = f"fastsurfer:{image_prefix}{version_tag}".replace("+", "_") + logger.info(f"No image name/tag provided, auto-generated tag: {image_tag}") + attestation = bool(keywords.get("attest")) if tag_dev: - kwargs["tag"] = f"fastsurfer:dev{image_suffix}" + kwargs["tag"] = f"fastsurfer:dev{image_prefix}" + if keywords.get("image_path", False): + kwargs["image_path"] = keywords["image_path"] if not dry_run: logger.info("Version info added to the docker image:") @@ -417,26 +707,45 @@ def main( working_directory=fastsurfer_home, context=fastsurfer_home, dry_run=dry_run, - **kwargs + attestation=attestation, + **kwargs, ) + if singularity := keywords.get("singularity", None): + singularity_build_image( + image_tag, + Path(singularity), + dry_run=dry_run, + ) + print("") except RuntimeError as e: return e.args[0] return 0 +def default_home() -> Path: + """ + Find the fastsurfer path. + + Returns + ------- + Path + The FASTSURFER_HOME-path. 
+ """ + if "FASTSURFER_HOME" in os.environ: + return Path(os.environ["FASTSURFER_HOME"]) + else: + return Path(__file__).parent.parent + + if __name__ == "__main__": import sys logging.basicConfig(stream=sys.stdout) arguments = make_parser().parse_args() # make sure the code can run without FastSurfer being in PYTHONPATH - if "FASTSURFER_HOME" in os.environ: - fastsurfer_home = os.environ["FASTSURFER_HOME"] - else: - fastsurfer_home = str(Path(__file__).parent.parent) - - if fastsurfer_home not in sys.path: - sys.path.append(fastsurfer_home) + fastsurfer_home = default_home() + if str(fastsurfer_home) not in sys.path: + sys.path.append(str(fastsurfer_home)) logger.setLevel(logging.WARN if arguments.dry_run else logging.INFO) - sys.exit(main(**vars(arguments))) + sys.exit(main(**vars(arguments), fastsurfer_home=fastsurfer_home)) diff --git a/Docker/conda_pack.sh b/Docker/conda_pack.sh index a4c3aa9a..9ac4cbe5 100755 --- a/Docker/conda_pack.sh +++ b/Docker/conda_pack.sh @@ -8,7 +8,7 @@ set -e # Install conda-pack -conda install -c conda-forge conda-pack +mamba install -c conda-forge conda-pack # Use conda-pack to create a standalone environment in /venv conda-pack -n "$1" -o /tmp/env.tar mkdir /venv @@ -16,4 +16,4 @@ cd /venv tar xf /tmp/env.tar rm /tmp/env.tar # Finally, when venv in a new location, fix up paths -/venv/bin/conda-unpack \ No newline at end of file +/venv/bin/conda-unpack diff --git a/Docker/install_env.py b/Docker/install_env.py index 503f6435..3d03d38d 100644 --- a/Docker/install_env.py +++ b/Docker/install_env.py @@ -12,14 +12,14 @@ arg_pattern = re.compile('^(\\s*-\\s*)(--[a-zA-Z0-9\\-]+)(\\s+\\S+)?(\\s*(#.*)?)$') -package_pattern = re.compile('^(\\s*-\\s*)([a-zA-Z0-9\\-]+|pip:)(\\s*[<=>~]{1,2}\\s*\\S+)?(\\s*(#.*)?\\s*)$') +package_pattern = re.compile('^(\\s*-\\s*)([a-zA-Z0-9\\.\\_\\-]+|pip:)(\\s*[<=>~]{1,2}\\s*\\S+)?(\\s*(#.*)?\\s*)$') dependencies_pattern = re.compile('^\\s*dependencies:\\s*$') def mode(arg: str) -> str: if arg in ["base", "cpu"] or \ re.match("^cu\\d+$", arg) or \ - re.match("^rocm\\d+\\.\\d+(\\.\\d+)?$"): + re.match("^rocm\\d+\\.\\d+(\\.\\d+)?$", arg): return arg else: raise argparse.ArgumentTypeError(f"The mode was '{arg}', but should be " diff --git a/Docker/install_fs_pruned.sh b/Docker/install_fs_pruned.sh index c8bd5ca5..e6d4e1a9 100755 --- a/Docker/install_fs_pruned.sh +++ b/Docker/install_fs_pruned.sh @@ -15,21 +15,41 @@ fslink="https://surfer.nmr.mgh.harvard.edu/pub/dist/freesurfer/7.4.1/freesurfer-linux-ubuntu22_amd64-7.4.1.tar.gz" -if [ "$#" -lt 1 ]; then +if [[ "$#" -lt 1 ]]; then echo - echo "Usage: install_fs_prunded install_dir <--upx>" + echo "Usage: install_fs_prunded install_dir [--upx] [--url freesurfer_download_url]" echo echo "--upx is optional, if passed, fs/bin will be packed" - echo + echo "--url is optional, if passed, freesurfer will be downloaded from it instead of $fslink" + echo exit 2 fi - where=/opt -if [ "$#" -ge 1 ]; then +if [[ "$#" -ge 1 ]]; then where=$1 + shift fi +upx="false" +while [[ "$#" -ge 1 ]]; do + lowercase=$(echo "$1" | tr '[:upper:]' '[:lower:]') + case $lowercase in + --upx) + upx="true" + shift + ;; + --url) + if [[ "$2" != "default" ]]; then fslink=$2; fi + shift + shift + ;; + *) + echo "Invalid argument $1" + exit 1 + ;; + esac +done fss=$where/fs-tmp fsd=$where/freesurfer echo @@ -41,6 +61,42 @@ echo "$fslink" echo +function run_parallel () +{ + # param 1 num_parallel_processes + # param 2 command (printf string) + # param 3 how many entries to consume from $@ per "run" + # param ... 
parameters to format, ie. we are executing $(printf $command $@...) + i=0 + pids=() + num_parallel_processes=$1 + command=$2 + num=$3 + shift + shift + shift + args=("$@") + j=0 + while [[ "$j" -lt "${#args}" ]] + do + cmd=$(printf "$command" "${args[@]:$j:$num}") + j=$((j + num)) + $cmd & + pids=("${pids[@]}" "$!") + i=$((i + 1)) + if [[ "$i" -ge "$num_parallel_processes" ]] + then + wait "${pids[0]}" + pids=("${pids[@]:1}") + fi + done + for pid in "${pids[@]}" + do + wait "$pid" + done +} + + # get Freesurfer and upack (some of it) echo "Downloading FS and unpacking portions ..." wget --no-check-certificate -qO- $fslink | tar zxv --no-same-owner -C $where \ @@ -48,7 +104,6 @@ wget --no-check-certificate -qO- $fslink | tar zxv --no-same-owner -C $where \ --exclude='freesurfer/average/Buckner_JNeurophysiol11_MNI152' \ --exclude='freesurfer/average/Choi_JNeurophysiol12_MNI152' \ --exclude='freesurfer/average/mult-comp-cor' \ - --exclude='freesurfer/average/mult-comp-cor' \ --exclude='freesurfer/average/samseg' \ --exclude='freesurfer/average/Yeo_Brainmap_MNI152' \ --exclude='freesurfer/average/Yeo_JNeurophysiol11_MNI152' \ @@ -178,6 +233,7 @@ copy_files=" bin/mri_concat bin/mri_concatenate_lta bin/mri_convert + bin/mri_coreg bin/mri_diff bin/mri_edit_wm_with_aseg bin/mri_fill @@ -200,6 +256,7 @@ copy_files=" bin/mri_surf2volseg bin/mri_tessellate bin/mri_vol2surf + bin/mri_vol2vol bin/mris_anatomical_stats bin/mris_autodet_gwstats bin/mris_ca_label @@ -213,7 +270,6 @@ copy_files=" bin/mris_extract_main_component bin/mris_fix_topology bin/mris_inflate - bin/mris_inflate bin/mris_info bin/mris_jacobian bin/mris_label2annot @@ -338,6 +394,14 @@ do cp -r $fss/$file $fsd/$file done +# pack if desired with upx (do this before adding all the links +if [[ "$upx" == "true" ]] ; then + echo "finding executables in $fsd/bin/..." + exe=$(find $fsd/bin -exec file {} \; | grep ELF | cut -d: -f1) + echo "packing $fsd/bin/ executables (this can take a while) ..." + run_parallel 8 "upx -9 %s %s %s %s" 4 $exe +fi + # Modify fsbindings Python package to allow calling scripts like asegstats2table directly: echo "from . import legacy" > "$fsd/python/packages/fsbindings/__init__.py" @@ -372,7 +436,6 @@ link_files=" bin/mri_stats2seg bin/mri_surf2vol bin/mri_surfcluster - bin/mri_vol2vol bin/mri_voldiff bin/mri_watershed bin/mris_divide_parcellation @@ -404,7 +467,7 @@ do done # use our python (not really needed in recon-all anyway) -p3=`which python3` +p3=$(which python3) if [ "$p3" == "" ]; then echo "No python3 found, please install first!" echo @@ -413,13 +476,4 @@ fi ln -s $p3 $fsd/bin/fspython #cleanup -rm -rf $fss - -# pack if desired with upx -if [ "$#" -ge 2 ]; then - if [ "${2^^}" == "--UPX" ] ; then - echo "packing $fsd/bin/ executables (this can take a while) ..." - exe=`find $fsd/bin -exec file {} \; | grep ELF | cut -d: -f1` - upx -9 $exe - fi -fi +rm -rf $fss \ No newline at end of file diff --git a/Docker/python-s b/Docker/python-s deleted file mode 100755 index cabd6f4b..00000000 --- a/Docker/python-s +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -inputargs=("$@") -python3.10 -s ${inputargs[@]} \ No newline at end of file diff --git a/FastSurferCNN/README.md b/FastSurferCNN/README.md index b9455d84..268bff2e 100644 --- a/FastSurferCNN/README.md +++ b/FastSurferCNN/README.md @@ -1,15 +1,17 @@ # Overview This directory contains all information needed to run inference with the readily trained FastSurferVINN or train it from scratch. 
FastSurferCNN is capable of whole brain segmentation into 95 classes in under 1 minute, mimicking FreeSurfer's anatomical segmentation and cortical parcellation (DKTatlas). The network architecture incorporates local and global competition via competitive dense blocks and competitive skip pathways, as well as multi-slice information aggregation that specifically tailor network performance towards accurate segmentation of both cortical and sub-cortical structures. -![](/images/detailed_network.png) +![](../doc/images/detailed_network.png) The network was trained with conformed images (UCHAR, 1-0.7 mm voxels and standard slice orientation). These specifications are checked in the run_prediction.py script and the image is automatically conformed if it does not comply. + # 1. Inference + The *FastSurferCNN* directory contains all the source code and modules needed to run the scripts. A list of python libraries used within the code can be found in __requirements.txt__. The main script is called __run_prediction.py__ within which certain options can be selected and set via the command line: -#### General +## General * `--in_dir`: Path to the input volume directory (e.g /your/path/to/ADNI/fs60) or * `--csv_file`: Path to csv-file listing input volume directories * `--t1`: name of the T1-weighted MRI_volume (like mri_volume.mgz, __default: orig.mgz__) @@ -20,16 +22,16 @@ The *FastSurferCNN* directory contains all the source code and modules needed to * `--strip`: strip suffix from path definition of input file to yield correct subject name. (Optional, if full path is defined for `--t1`) * `--lut`: FreeSurfer-style Color Lookup Table with labels to use in final prediction. Default: ./config/FastSurfer_ColorLUT.tsv * `--seg`: Name of intermediate DL-based segmentation file (similar to aparc+aseg). -* `--cfg_cor`: Path to the coronal config file -* `--cfg_sag`: Path to the axial config file -* `--cfg_ax`: Path to the sagittal config file -#### Checkpoints +## Checkpoints and configs * `--ckpt_sag`: path to sagittal network checkpoint * `--ckpt_cor`: path to coronal network checkpoint * `--ckpt_ax`: path to axial network checkpoint +* `--cfg_cor`: Path to the coronal config file +* `--cfg_sag`: Path to the axial config file +* `--cfg_ax`: Path to the sagittal config file -#### Optional commands +## Optional commands * `--clean`: clean up segmentation after running it (optional) * `--device `:Device for processing (_auto_, _cpu_, _cuda_, _cuda:_), where cuda means Nvidia GPU; you can select which one e.g. "cuda:1". Default: "auto", check GPU and then CPU * `--viewagg_device `: Define where the view aggregation should be run on. @@ -41,7 +43,7 @@ The *FastSurferCNN* directory contains all the source code and modules needed to * `--batch_size`: Batch size for inference. Default=1 -### Example Command Evaluation Single Subject +## Example Command: Evaluation Single Subject To run the network on MRI-volumes of subjectX in ./data (specified by `--t1` flag; e.g. 
./data/subjectX/t1-weighted.nii.gz), change into the *FastSurferCNN* directory and run the following commands: ``` @@ -53,14 +55,14 @@ python3 run_prediction.py --t1 ../data/subjectX/t1-weighted.nii.gz \ The output will be stored in: -- ../output/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz (large segmentation) -- ../output/subjectX/mri/mask.mgz (brain mask) -- ../output/subjectX/mri/aseg_noCC.mgz (reduced segmentation) +- `../output/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz` (large segmentation) +- `../output/subjectX/mri/mask.mgz` (brain mask) +- `../output/subjectX/mri/aseg_noCC.mgz` (reduced segmentation) Here the logfile "temp_Competitive.log" will include the logfiles of all subjects. If left out, the logs will be written to stdout -### Example Command Evaluation whole directory +## Example Command: Evaluation whole directory To run the network on all subjects MRI-volumes in ./data, change into the *FastSurferCNN* directory and run the following command: ``` @@ -71,18 +73,19 @@ python3 run_prediction.py --in_dir ../data \ The output will be stored in: -- ../output/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz (large segmentation) -- ../output/subjectX/mri/mask.mgz (brain mask) -- ../output/subjectX/mri/aseg_noCC.mgz (reduced segmentation) -- and the log in ../output/temp_Competitive.log - +- `../output/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz` (large segmentation) +- `../output/subjectX/mri/mask.mgz` (brain mask) +- `../output/subjectX/mri/aseg_noCC.mgz` (reduced segmentation) +- and the log in `../output/temp_Competitive.log` + # 2. Hdf5-Trainingset Generation + The *FastSurferCNN* directory contains all the source code and modules needed to create a hdf5-file from given MRI volumes. Here, we use the orig.mgz output from freesurfer as the input image and the aparc.DKTatlas+aseg.mgz as the ground truth. The mapping functions are set-up accordingly as well and need to be changed if you use a different segmentation as ground truth. A list of python libraries used within the code can be found in __requirements.txt__. The main script is called __generate_hdf5.py__ within which certain options can be selected and set via the command line: -#### General +### General * `--hdf5_name`: Path and name of the to-be-created hdf5-file. Default: ../data/hdf5_set/Multires_coronal.hdf5 * `--data_dir`: Directory with images to load. Default: /data * `--pattern`: Pattern to match only certain files in the directory @@ -99,12 +102,12 @@ A list of python libraries used within the code can be found in __requirements.t The actual filename and segmentation ground truth name is specified via `--image_name` and `--gt_name` (e.g. the actual file could be sth. like /dataset/D1/subject1/mri_volume.mgz and /dataset/D1/subject1/segmentation.mgz) -#### Image Names +## Image Names * `--image_name`: Default name of original images. FreeSurfer orig.mgz is default (mri/orig.mgz) * `--gt_name`: Default name for ground truth segmentations. Default: mri/aparc.DKTatlas+aseg.mgz. * `--gt_nocc`: Segmentation without corpus callosum (used to mask this segmentation in ground truth). For a normal FreeSurfer input, use mri/aseg.auto_noCCseg.mgz. -#### Image specific options +## Image specific options * `--plane`: Which anatomical plane to use for slicing (axial, coronal or sagittal) * `--thickness`: Number of pre- and succeeding slices (we use 3 --> total of 7 slices is fed to the network; default: 3) * `--combi`: Suffixes of labels names to combine. 
Default: Left- and Right- @@ -117,7 +120,7 @@ The actual filename and segmentation ground truth name is specified via `--image * `--sizes`: Resolutions of images in the dataset. Default: 256 * `--edge_w`: Weight for edges in weight mask. Default=5 -#### Example Command Axial (Single Resolution) +## Example Command: Axial (Single Resolution) ``` python3 generate_hdf5.py \ --hdf5_name ../data/training_set_axial.hdf5 \ @@ -131,10 +134,9 @@ python3 generate_hdf5.py \ --edge_w 4 \ --hires_w 4 \ --sizes 256 - ``` -#### Example Command Coronal (Single Resolution) +## Example Command: Coronal (Single Resolution) ``` python3 generate_hdf5.py \ --hdf5_name ../data/training_set_coronal.hdf5 \ @@ -147,10 +149,9 @@ python3 generate_hdf5.py \ --edge_w 4 \ --hires_w 4 \ --sizes 256 - ``` -#### Example Command Sagittal (Multiple Resolutions) +## Example Command: Sagittal (Multiple Resolutions) ``` python3 generate_hdf5.py \ --hdf5_name ../data/training_set_sagittal.hdf5 \ @@ -163,10 +164,9 @@ python3 generate_hdf5.py \ --edge_w 4 \ --hires_w 4 \ --sizes 256 311 320 - ``` -#### Example Command Sagittal using --data_dir instead of --csv_file +## Example Command: Sagittal using --data_dir instead of --csv_file `--data_dir` specifies the path in which the data is located, with `--pattern` we can select subjects from the specified path. By default the pattern is "*" meaning all subjects will be selected. As an example, imagine you have 19 FreeSurfer processed subjects labeled subject1 to subject19 in the ../data directory: @@ -206,8 +206,9 @@ python3 generate_hdf5.py \ --gt_nocc mri/aseg.auto_noCCseg.mgz ``` - + # 3. Training + The *FastSurferCNN* directory contains all the source code and modules needed to run the scripts. A list of python libraries used within the code can be found in __requirements.txt__. The main training script is called __run_model.py__ whose options can be set through a configuration file and command line arguments: * `--cfg`: Path to the configuration file. Default: config/FastSurferVINN.yaml @@ -218,45 +219,45 @@ The `--cfg` file configures the model to be trained. See config/FastSurferVINN.y The configuration options include: -#### Model options -* MODEL_NAME: Name of model [FastSurferCNN, FastSurferVINN]. Default: FastSurferVINN -* NUM_CLASSES: Number of classes to predict including background. Axial and coronal: 79 (default), Sagittal: 51. -* NUM_FILTERS: Filter dimensions for Networks (all layers same). Default: 71 -* NUM_CHANNELS: Number of input channels (slice thickness). Default: 7 -* KERNEL_H: Height of Kernel. Default: 3 -* KERNEL_W: Width of Kernel. Default: 3 -* STRIDE_CONV: Stride during convolution. Default: 1 -* STRIDE_POOL: Stride during pooling. Default: 2 -* POOL: Size of pooling filter. Default: 2 -* BASE_RES: Base resolution of the segmentation model (after interpolation layer). Default: 1 +## Model options +* `MODEL_NAME`: Name of model [FastSurferCNN, FastSurferVINN]. Default: FastSurferVINN +* `NUM_CLASSES`: Number of classes to predict including background. Axial and coronal: 79 (default), Sagittal: 51. +* `NUM_FILTERS`: Filter dimensions for Networks (all layers same). Default: 71 +* `NUM_CHANNELS`: Number of input channels (slice thickness). Default: 7 +* `KERNEL_H`: Height of Kernel. Default: 3 +* `KERNEL_W`: Width of Kernel. Default: 3 +* `STRIDE_CONV`: Stride during convolution. Default: 1 +* `STRIDE_POOL`: Stride during pooling. Default: 2 +* `POOL`: Size of pooling filter. 
Default: 2 +* `BASE_RES`: Base resolution of the segmentation model (after interpolation layer). Default: 1 -#### Optimizer options +## Optimizer options -* BASE_LR: Base learning rate. Default: 0.01 -* OPTIMIZING_METHOD: Optimization method [sgd, adam, adamW]. Default: adamW -* MOMENTUM: Momentum for optimizer. Default: 0.9 -* NESTEROV: Enables Nesterov for optimizer. Default: True -* LR_SCHEDULER: Learning rate scheduler [step_lr, cosineWarmRestarts, reduceLROnPlateau]. Default: cosineWarmRestarts +* `BASE_LR`: Base learning rate. Default: 0.01 +* `OPTIMIZING_METHOD`: Optimization method [sgd, adam, adamW]. Default: adamW +* `MOMENTUM`: Momentum for optimizer. Default: 0.9 +* `NESTEROV`: Enables Nesterov for optimizer. Default: True +* `LR_SCHEDULER`: Learning rate scheduler [step_lr, cosineWarmRestarts, reduceLROnPlateau]. Default: cosineWarmRestarts -#### Data options +## Data options -* PATH_HDF5_TRAIN: Path to training hdf5-dataset -* PATH_HDF5_VAL: Path to validation hdf5-dataset -* PLANE: Plane to load [axial, coronal, sagittal]. Default: coronal +* `PATH_HDF5_TRAIN`: Path to training hdf5-dataset +* `PATH_HDF5_VAL`: Path to validation hdf5-dataset +* `PLANE`: Plane to load [axial, coronal, sagittal]. Default: coronal -#### Training options +## Training options -* BATCH_SIZE: Input batch size for training. Default: 16 -* NUM_EPOCHS: Number of epochs to train. Default: 30 -* SIZES: Available image sizes for the multi-scale dataloader. Default: [256, 311 and 320] -* AUG: Augmentations. Default: ["Scaling", "Translation"] +* `BATCH_SIZE`: Input batch size for training. Default: 16 +* `NUM_EPOCHS`: Number of epochs to train. Default: 30 +* `SIZES`: Available image sizes for the multi-scale dataloader. Default: [256, 311 and 320] +* `AUG`: Augmentations. Default: ["Scaling", "Translation"] -#### Misc. Options +## Misc. Options -* LOG_DIR: Log directory for run -* NUM_GPUS: Number of GPUs to use. Default: 1 -* RNG_SEED: Select random seed. Default: 1 +* `LOG_DIR`: Log directory for run +* `NUM_GPUS`: Number of GPUs to use. Default: 1 +* `RNG_SEED`: Select random seed. Default: 1 Any option can alternatively be set through the command-line by specifying the option name (as defined in config/defaults.py) followed by a value, such as: `MODEL.NUM_CLASSES 51`. 
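
To make the `KEY VALUE` override mechanism mentioned above concrete, the sketch below shows, in plain Python, how dotted option names such as `MODEL.NUM_CLASSES 51` can be merged into a nested configuration. It is a hypothetical helper for illustration only, not the actual parser behind `config/defaults.py`.

```
# Hypothetical helper, for illustration only (not FastSurfer's own config code):
# merge command-line pairs like ["MODEL.NUM_CLASSES", "51"] into a nested dict.
def apply_overrides(cfg: dict, opts: list) -> dict:
    for key, value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value  # a real parser would also cast to the option's declared type
    return cfg

cfg = {"MODEL": {"NUM_CLASSES": 79}, "DATA": {"PLANE": "coronal"}}
apply_overrides(cfg, ["MODEL.NUM_CLASSES", "51", "DATA.PLANE", "sagittal"])
print(cfg)  # {'MODEL': {'NUM_CLASSES': '51'}, 'DATA': {'PLANE': 'sagittal'}}
```

In the same spirit, appending any option name defined in `config/defaults.py` followed by a value to the `run_model.py` call overrides the corresponding entry of the loaded configuration.
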
@@ -264,14 +265,14 @@ Any option can alternatively be set through the command-line by specifying the o To train the network on a given hdf5-set, change into the *FastSurferCNN* directory and run `run_model.py` as in the following examples: -### Example Command: Training Default FastSurferVINN +## Example Command: Training Default FastSurferVINN Trains FastSurferVINN on multi-resolution images in the coronal plane: ``` python3 run_model.py \ --cfg ./config/FastSurferVINN.yaml ``` -### Example Command: Training FastSurferVINN (Single Resolution) +## Example Command: Training FastSurferVINN (Single Resolution) Trains FastSurferVINN on single-resolution images in the sagittal plane by overriding the NUM_CLASSES, SIZES, PATH_HDF5_TRAIN, and PATH_HDF5_VAL options: ``` python3 run_model.py \ @@ -282,7 +283,7 @@ DATA.PATH_HDF5_TRAIN ./hdf5_sets/training_sagittal_single_resolution.hdf5 \ DATA.PATH_HDF5_VAL ./hdf5_sets/validation_sagittal_single_resolution.hdf5 \ ``` -### Example Command: Training FastSurferCNN +## Example Command: Training FastSurferCNN Trains FastSurferCNN using a provided configuration file and specifying no augmentations: ``` python3 run_model.py \ diff --git a/FastSurferCNN/__init__.py b/FastSurferCNN/__init__.py index 14f13b55..9ac2d39c 100644 --- a/FastSurferCNN/__init__.py +++ b/FastSurferCNN/__init__.py @@ -21,6 +21,7 @@ "quick_qc", "reduce_to_aseg", "run_prediction", + "run_model", "segstats", "train", ] diff --git a/FastSurferCNN/config/FreeSurferColorLUT.txt b/FastSurferCNN/config/FreeSurferColorLUT.txt index 54336975..39bf4966 100644 --- a/FastSurferCNN/config/FreeSurferColorLUT.txt +++ b/FastSurferCNN/config/FreeSurferColorLUT.txt @@ -29,13 +29,13 @@ 25 Left-Lesion 255 165 0 0 26 Left-Accumbens-area 255 165 0 0 27 Left-Substancia-Nigra 0 255 127 0 -28 Left-VentralDC 165 42 42 0 +28 Left-VentralDC 145 42 42 0 29 Left-undetermined 135 206 235 0 30 Left-vessel 160 32 240 0 31 Left-choroid-plexus 0 200 200 0 32 Left-F3orb 100 50 100 0 -33 Left-lOg 135 50 74 0 -34 Left-aOg 122 135 50 0 +33 Left-aOg 122 135 50 0 +34 Left-WMCrowns 225 225 255 0 35 Left-mOg 51 50 135 0 36 Left-pOg 74 155 60 0 37 Left-Stellate 120 62 43 0 @@ -67,7 +67,7 @@ 63 Right-choroid-plexus 0 200 221 0 64 Right-F3orb 100 50 100 0 65 Right-lOg 135 50 74 0 -66 Right-aOg 122 135 50 0 +66 Right-WMCrowns 215 215 255 0 67 Right-mOg 51 50 135 0 68 Right-pOg 74 155 60 0 69 Right-Stellate 120 62 43 0 @@ -77,6 +77,9 @@ 73 Left-Interior 122 135 50 0 74 Right-Interior 122 135 50 0 # 75/76 removed. 
duplicates of 4/43 +# new 75/76 added +75 Left-Locus-Coeruleus 91 97 255 0 +76 Right-Locus-Coeruleus 0 7 218 0 77 WM-hypointensities 200 70 255 0 78 Left-WM-hypointensities 255 148 10 0 79 Right-WM-hypointensities 255 148 10 0 @@ -85,8 +88,8 @@ 82 Right-non-WM-hypointensities 164 108 226 0 83 Left-F1 255 218 185 0 84 Right-F1 255 218 185 0 -85 Optic-Chiasm 234 169 30 0 -192 Corpus_Callosum 250 255 50 0 +85 Optic-Chiasm 234 169 30 0 +192 Corpus_Callosum 170 255 255 0 86 Left_future_WMSA 200 120 255 0 87 Right_future_WMSA 200 121 255 0 @@ -191,6 +194,11 @@ 181 Right-Cortical-Dysplasia 73 62 139 0 182 CblumNodulus 10 100 176 0 +# Changed name to "Area" to indicate that this is not +# currently a fully vetted segmentation +183 Left-Vermis-Area 119 100 176 0 +184 Right-Vermis-Area 100 119 176 0 + #192 Corpus_Callosum listed after #85 above 193 Left-hippocampal_fissure 0 196 255 0 194 Left-CADG-head 255 164 164 0 @@ -229,6 +237,7 @@ 227 Polymorphic-Layer 128 255 128 0 228 Intracellular-Space 204 153 204 0 229 molecular_layer_DG 168 0 0 0 +230 Prosubiculum 252 132 8 0 231 HP_body 0 255 0 0 232 HP_head 255 0 0 0 @@ -265,11 +274,25 @@ 259 Eye-Fluid 60 60 60 0 260 BoneOrAir 119 159 176 0 261 PossibleFluid 120 18 134 0 -262 Sinus 196 160 128 0 +262 Sinus 85 85 127 0 263 Left-Eustachian 119 159 176 0 264 Right-Eustachian 119 159 176 0 265 Left-Eyeball 60 60 60 0 266 Right-Eyeball 60 60 60 0 +267 Pons-Belly-Area 206 195 58 0 + +270 ctx-lh-infragranular 205 70 78 0 +271 ctx-lh-layer1 210 80 78 0 +272 ctx-lh-layer2 215 90 78 0 +273 ctx-lh-layer3 220 100 78 0 +274 ctx-lh-layer4 225 110 78 0 +275 ctx-lh-layer5 230 120 78 0 +276 ctx-lh-layer6 235 135 78 0 +277 ctx-lh-supragranular 240 140 78 0 +278 SubiculumU 20 119 165 0 +279 CA1U 255 108 108 0 +280 CA2U 167 130 199 0 +281 CA3U 0 202 0 0 # lymph node and vascular labels 331 Aorta 255 0 0 0 @@ -302,6 +325,15 @@ 358 Pos-Lymph 20 130 180 0 359 Neg-Lymph 20 180 130 0 +370 ctx-rh-infragranular 205 70 130 0 +371 ctx-rh-layer1 210 80 130 0 +372 ctx-rh-layer2 215 90 130 0 +373 ctx-rh-layer3 220 100 130 0 +374 ctx-rh-layer4 225 110 130 0 +375 ctx-rh-layer5 230 120 130 0 +376 ctx-rh-layer6 235 135 130 0 +377 ctx-rh-supragranular 240 140 130 0 + 400 V1 206 62 78 0 401 V2 121 18 134 0 402 BA44 199 58 250 0 @@ -368,6 +400,8 @@ 557 left_subiculum 0 119 86 0 558 left_fornix 20 100 201 0 +559 Subcortical-Gray-Matter 123 187 221 0 + 600 Tumor 254 254 254 0 @@ -470,6 +504,56 @@ 809 R_hypothalamus_tubular_inferior 255 160 200 0 810 R_hypothalamus_tubular_superior 20 180 130 0 +# SAMSEG-CHARM + 400 +901 Air-Internal 0 170 0 0 +902 Artery 204 0 0 255 +906 EyeBalls 230 189 66 255 +907 Other-Tissues 85 85 0 0 +908 Rectus-Muscles 190 50 73 255 +909 Mucosa 255 160 188 255 +911 Skin 85 85 85 0 +912 Charm-Spinal-Cord 0 187 169 255 +914 Vein 0 84 255 255 +915 Bone-Cortical 0 85 0 0 +916 Bone-Cancellous 170 170 0 255 +917 Charm-Background 0 168 255 255 +920 Cortical-CSF 120 133 217 255 +930 Optic-Nerve 0 191 122 255 + +# Below is the color table for olfactory bulb structures generated by FastSurfer OB pipeline. 
+# See : https://doi.org/10.1016/j.neuroimage.2021.118464 + +951 L-olfactory-bulb 0 0 255 0 +952 R-olfactory-bulb 255 0 0 0 + +# Below is the color table for the hypothalamic subregions generated by FastSurfer HypVINN pipeline +#See : https://doi.org/10.1162/imag_a_00034 + +961 R-N.opticus 70 130 180 0 +962 L-N.opticus 130 180 70 0 +963 R-C.mammilare 205 62 78 0 +964 R-Optic-tract 80 120 134 0 +965 L-Optic-tract 196 58 250 0 +966 L-C.mammilare 0 148 0 0 +967 R-Chiasma-Opticum 220 248 164 0 +968 L-Chiasma-Opticum 230 148 34 0 +969 Ant-Commisure 10 180 225 0 +970 Third-Ventricle 118 0 100 0 +971 R-Fornix 122 200 120 0 +972 L-Fornix 236 13 176 0 +973 Epiphysis 204 182 142 0 +974 Hypophysis 119 159 176 0 +975 Infundibulum 220 216 20 0 +976 Tuberal-Region 120 60 110 0 +977 L-Med-Hypothalamus 165 255 0 0 +978 L-Lat-Hypothalamus 0 255 127 0 +979 L-Ant-Hypothalamus 165 42 42 0 +980 L-Post-Hypothalamus 255 215 0 0 +981 R-Med-Hypothalamus 115 255 0 0 +982 R-Lat-Hypothalamus 60 255 127 0 +983 R-Ant-Hypothalamus 165 142 42 0 +984 R-Post-Hypothalamus 255 170 20 0 + 999 SUSPICIOUS 255 100 100 0 # Below is the color table for the cortical labels of the seg volume @@ -678,6 +762,12 @@ 4206 wm-rh-parietal-lobe 35 195 35 0 4207 wm-rh-insula-lobe 20 220 160 0 +# From Jean A +5024 Low_TDP-43_Pathology_Density 170 238 44 0 +5025 Moderate_TDP-43_Pathology_Density 240 168 105 0 +5026 High_TDP-43_Pathology_Density 182 79 89 0 +5027 Severe_TDP-43_Pathology_Density 156 15 35 0 + # Below is the color table for the cortical labels of the seg volume # created by mri_aparc2aseg (with --a2005s flag) in which the aseg # cortex label is replaced by the labels in the aparc.a2005s. The @@ -1140,6 +1230,20 @@ 6070 Left-SLF2 236 14 230 0 6080 Right-SLF2 237 14 230 0 +6101 Left-Dura-MCA 34 197 246 0 +6102 Right-Dura-MCA 0 250 0 0 +6103 Left-Ento-Dura 160 32 240 0 +6104 Right-Ento-Dura 120 32 240 0 +6111 Left-Transverse-Sinus 229 11 152 0 +6112 Right-Transverse-Sinus 40 251 16 0 +6113 Left-Sigmoid-Sinus 202 3 26 0 +6114 Right-Sigmoid-Sinus 101 76 219 0 +6115 Straight-Sinus 185 210 35 0 +6116 Superior-Sinus-P 26 129 125 0 +6117 Superior-Sinus-D 144 51 173 0 +6118 Superior-Sinus-A 93 151 253 0 + + #No. Label Name: R G B A 7001 Lateral-nucleus 72 132 181 0 @@ -1163,6 +1267,9 @@ 7019 Envelope-Amygdala 141 21 100 0 7020 Extranuclear-Amydala 225 140 141 0 +7030 Left-Amygdala-Cortical-Junction 255 85 255 0 +7031 Right-Amygdala-Cortical-Junction 254 84 254 0 + 7100 Brainstem-inferior-colliculus 42 201 168 0 7101 Brainstem-cochlear-nucleus 168 104 162 0 @@ -1171,14 +1278,16 @@ 7203 PAG 153 153 255 0 7204 VTA 255 0 255 0 -7301 Left-LC 0 0 255 0 +# replaced by 75 +#7301 Left-LC 0 0 255 0 7302 Left-LDTg 255 127 0 0 7303 Left-mRt 255 0 0 0 7304 Left-PBC 255 255 0 0 7305 Left-PnO 0 127 255 0 7306 Left-PTg 127 0 255 0 -7401 Right-LC 0 0 255 0 +# replaced by 76 +#7401 Right-LC 0 0 255 0 7402 Right-LDTg 255 127 0 0 7403 Right-mRt 255 0 0 0 7404 Right-PBC 255 255 0 0 @@ -1233,6 +1342,8 @@ 8130 Left-VM 85 255 0 0 8133 Left-VPL 255 0 255 0 8134 Left-PaV 120 18 134 0 +8135 Left-PuMm 170 255 255 0 +8136 Left-PuMl 140 240 255 0 8203 Right-AV 0 85 0 0 8204 Right-CeM 170 85 0 0 @@ -1261,6 +1372,8 @@ 8230 Right-VM 85 255 0 0 8233 Right-VPL 255 0 255 0 8234 Right-PaV 120 18 134 0 +8235 Right-PuMm 170 255 255 0 +8236 Right-PuMl 140 240 255 0 # # Labels for thalamus parcellation using probabilistic tractography. 
See: @@ -1287,6 +1400,234 @@ 9505 ctx-rh-prim-sec-somatosensory 225 70 105 0 9506 ctx-rh-occipital 225 70 15 0 +# Below is the the color table for the brainstem segmentation based on the Paxinos Atlas +# It is divided into cranial nerves, tracts, and nuclei which are subdivided by regions, medulla, pons and midbrain + +10000 undefined 136 24 212 0 + +# Cranial nerves +10003 oculomotor_nerve_3n 36 124 212 0 +10004 trochlear_nerve_4n 242 189 15 0 +10005 trigeminal_nerve_5n 130 112 140 0 +10006 abducens_nerve_6n 237 150 9 0 +10007 facial_nerve_7n 45 68 145 0 +10008 vestibulocochlear_nerve_8n 43 114 122 0 +10009 glossopharygeal_nerve_9n 194 60 60 0 +10010 vagus_nerve_10n 121 150 14 0 +10011 accessory_nerve_11n 207 118 167 0 +10012 hypoglossal_nerve_12n 114 98 140 0 + +# Tracts +10021 gracile_tract_gr 113 28 232 0 +10022 cuneate_tract_cu 250 192 75 0 +10023 spinal_trigeminal_tract_sp5 137 53 161 0 +10024 dorsal_spinocerebellar_tract_dsc 250 69 8 0 +10025 ventral_spinocerebellar_tract_vsc 117 43 77 0 +10026 spinothalamic_tract_spth 71 20 134 0 +10027 ventral_corticospinal_tract_vcs 170 37 68 0 +10028 lateral_corticospinal_tract_lcs 99 98 171 0 +10029 tectospinal_tract_ts 53 105 47 0 +10030 medial_lemniscus_ml 209 170 242 0 +10031 lateral_lemniscus_ll 141 87 120 0 +10032 pyramid_py 240 47 44 0 +10033 solitary_tract_sol 185 25 198 0 +10034 amiculum_of_the_inferior_olive_ami 240 235 98 0 +10035 internal_arcuate_fibers_ia 87 215 171 0 +10036 inferior_cerebellar_peduncle_icp 153 163 255 0 +10037 mid_cerebellar_peduncle_mcp 225 237 53 0 +10038 superior_cerebellar_peduncle_scp 65 224 164 0 +10039 central_tegmental_tract_ctg 16 187 185 0 +10040 dorsal_acoustic_stria_das 2 76 56 0 +10041 transverse_fibers_of_the_pons_tfp 40 32 209 0 +10042 longitudinal_fibers_of_the_pons_lfp 59 117 42 0 +10043 genu_of_the_facial_nerve_g7 161 178 209 0 +10044 mammillotegmental_tract_mtg 187 231 162 0 +10045 trigeminothalamic_tract_tth 143 182 241 0 +10046 cerebral_peduncle_cp 13 185 164 0 +10047 corticospinal_tract_csp 205 16 164 0 +10048 corticobulbar_tract_cbu 195 209 118 0 +10049 mesencephalic_trigeminal_tract_me5 220 107 63 0 +10050 medial_longitudinal_fasciculus_mlf 40 51 218 0 + +# Nuclei + +# in the Medulla +10101 Vestibulocochlear_nucleus_8N 62 138 85 0 +10102 Vagus_nerve_nucleus_10N 247 89 10 0 +10103 Dorsal_motor_nucleus_of_vagus_caudal_part_10CA 245 118 54 0 +10104 Dorsal_motor_nucleus_of_vagus_caudointermediate_part_10CaI 247 157 111 0 +10105 Dorsal_motor_nucleus_of_vagus_centrointermediate_part_10CeI 242 194 170 0 +10106 Dorsal_motor_nucleus_of_vagus_dorsointermediate_part_10DI 184 64 9 0 +10107 Dorsal_motor_nucleus_of_vagus_dorsorostral_part_10DR 176 99 60 0 +10108 Dorsal_motor_nucleus_of_vagus_medial_fringe_10F 166 126 106 0 +10109 Dorsal_motor_nucleus_of_vagus_rostrointermediate_part_10RI 115 73 52 0 +10110 Dorsal_motor_nucleus_of_vagus_ventrointermediate_part_10VI 179 93 50 0 +10111 Dorsal_motor_nucleus_of_vagus_ventrorostral_part_10VR 209 106 54 0 +10112 Accessory_nerve_nucleus_11N 54 49 212 0 +10113 Hypoglossal_nucleus_12N 121 240 17 0 +10114 Hypoglossal_nucleus_geniohyoid_part_12GH 150 237 74 0 +10115 Hypoglossal_nucleus_lateral_part_12L 189 242 143 0 +10116 Hypoglossal_nucleus_medial_part_12M 84 166 12 0 +10117 Hypoglossal_nucleus_ventral_part_12V 114 168 67 0 +10118 Hypoglossal_nucleus_ventrolateral_part_12VL 118 140 98 0 +10119 Gracile_nucleus_Gr 238 247 106 0 +10120 Cuneate_nucleus_Cu 10 136 240 0 +10121 Cuneate_nucleus_rotundus_part_CuR 109 177 232 0 +10122 Cuneate_nucleus_triangular_part_CuT 86 
121 150 0 +10123 External_cuneate_nucleus_ECu 5 65 115 0 +10124 Lateral_pericuneate_nucleus_LPCu 61 98 219 0 +10125 Medial_pericuneate_nucleus_MPCu 9 61 230 0 +10126 Spinal_trigeminal_nucleus_Sp5 5 250 168 0 +10127 Spinal_trigeminal_nucleus_caudal_part_Sp5C 48 140 109 0 +10128 Spinal_trigeminal_nucleus_caudal_part_lamina_1_Sp5C1 165 207 193 0 +10129 Spinal_trigeminal_nucleus_caudal_part_lamina_2_Sp5C2 7 84 58 0 +10130 Spinal_trigeminal_nucleus_caudal_part_lamina_3_4_Sp5C3/4 58 140 113 0 +10131 Spinal_trigeminal_nucleus_interpolar_part_Sp5I 72 110 97 0 +10132 Spinal_trigeminal_nucleus_oral_part_Sp5O 104 227 186 0 +10133 Ambiguus_nucleus_Amb 117 25 198 0 +10134 Ambiguus_nucleus_compact_part_AmbC 171 99 235 0 +10135 Ambiguus_nucleus_loose_part_AmbL 203 161 240 0 +10136 Ambiguus_nucleus_semicompact_part_AmbSC 99 50 143 0 +10137 Retroambiguus_nucleus_RAmb 200 157 90 0 +10138 Intermediate_reticular_nucleus_IRt 242 169 51 0 +10139 Solitary_nucleus_Sol 120 165 144 0 +10140 Solitary_nucleus_commissural_part_SolC 5 247 134 0 +10141 Solitary_nucleus_dorsal_part_SolD 33 235 156 0 +10142 Solitary_nucleus_dorsolateral_part_SolDL 146 247 200 0 +10143 Solitary_nucleus_gelatinous_part_SolG 213 245 230 0 +10144 Solitary_nucleus_interstitial_part_SolI 130 173 153 0 +10145 Solitary_nucleus_intermediate_part_SolIM 67 125 98 0 +10146 Solitary_nucleus_medial_part_SolM 37 168 107 0 +10147 Solitary_nucleus_paracommissural_part_SolPaC 36 92 66 0 +10148 Solitary_nucleus_ventral_part_SolV 106 143 125 0 +10149 Solitary_nucleus_ventrolateral_SolVL 11 74 45 0 +10150 Inferior_olivary_nucleus_IO 182 10 250 0 +10151 Inferior_olivary_subnucleus_A_of_medial_nucleus_IOA 130 13 168 0 +10152 Inferior_olivary_subnucleus_B_of_medial_nucleus_IOB 203 91 240 0 +10153 Inferior_olivary_beta_subnucleus_IoBe 194 81 232 0 +10154 Inferior_olivary_subnucleus_C_of_medial_nucleus_IOC 124 84 138 0 +10155 Inferior_olivary_dorsal_nucleus_IOD 188 139 204 0 +10156 Inferior_olivary_dorsal_nucleus_caudal_part_IODC 126 53 150 0 +10157 Inferior_olivary_dorsomedial_cell_group_IODM 146 132 150 0 +10158 Inferior_olivary_cap_of_Kooy_of_the_medial_nucleus_IOK 126 32 158 0 +10159 Inferior_olivary_medial_nucleus_IOM 186 22 115 0 +10160 Inferior_olivary_principal_nucleus_IOPr 80 6 105 0 +10161 Inferior_olivary_ventrolateral_protrusion_IOVL 157 91 179 0 +10162 Intercalated_nucleus_In 223 226 196 0 +10163 Raphe_pallidus_nucleus_RPa 178 150 79 0 +10164 Raphe_obscurus_nucleus_ROb 126 233 88 0 +10165 Raphe_magnus_nucleus_RMg 30 178 165 0 +10166 Gigantocellular_reticular_nucleus_Gi 167 186 61 0 +10167 Gigantocellular_reticular_nucleus_ventral_part_GiV 195 230 7 0 +10168 Gigantocellular_reticular_nucleus_alpha_part_GiA 219 232 146 0 +10169 Dorsal_paragigantocellular_nucleus_DPGi 215 73 252 0 +10170 Lateral_paragigantocellular_nucleus_LPGi 47 224 163 0 +10171 Vestibular_nucleus_Ve 36 121 43 0 +10172 Lateral_vestibular_nucleus_LVe 172 34 15 0 +10173 Medial_vestibular_nucleus_MVe 106 80 64 0 +10174 Medial_vestibular_nucleus_magnocellular_part_MVeMC 48 28 158 0 +10175 Medial_vestibular_nucleus_parvocellular_part_MVePC 164 152 237 0 +10176 Paravestibular_nucleus_PaVe 13 76 235 0 +10177 Inferior_vestibular_nucleus_(spinal)_SpVe 249 246 217 0 +10178 Superior_vestibular_nucleus_SuVe 103 152 30 0 +10179 Nucleus_of_Roller_Ro 50 231 44 0 +10180 Arcuate_nucleus_Ar 242 168 29 0 +10181 Anteroventral_cochlear_nucleus_AVC 215 89 28 0 +10182 Dorsal_cochlear_nucleus_DC 232 140 95 0 +10183 Granular_cell_layer_of_the_cochlear_nucleus_GrC 143 77 46 0 +10184 
Ventral_cochlear_nuclear_posterior_part_VCP 148 52 6 0 + + +# in the Pons +10301 Motor_trigeminal_nucleus_5N 7 250 209 0 +10302 Motor_trigeminal_nucleus_anterior_digastric_part_5ADi 100 250 225 0 +10303 Motor_trigeminal_nucleus_masseter_part_5Ma 160 250 235 0 +10304 Mylohyoid_subnuleus_of_the_motor_trigeminal_nucleus_5MHy 215 247 242 0 +10305 Motor_trigeminal_nucleus_parvocellar_part_5PC 3 161 134 0 +10306 Motor_trigeminal_nucleus_pterygoid_part_5Pt 75 153 140 0 +10307 Motor_trigeminal_nucleus_temporalis_part_5Te 112 140 136 0 +10308 Abducens_nucleus_6N 185 212 11 0 +10309 Facial_nucleus_7N 250 0 140 0 +10310 Facial_nucleus_dorsal_intermediate_subnucleus_7DI 242 48 155 0 +10311 Facial_nucleus_dorsomedial_subnucleus_7DM 250 120 191 0 +10312 Facial_nucleus_intermediate_part_7I 245 181 216 0 +10313 Facial_nucleus_lateral_subnucleus_7L 250 215 234 0 +10314 Facial_nucleus_ventral_intermediate_subnucleus_7VI 150 3 84 0 +10315 Facial_nucleus_ventrolateral_subnucleus_7VL 150 65 112 0 +10316 Facial_nucleus_ventromedial_subnucleus_7VM 148 105 129 0 +10317 Pontine_nuclei_Pn 105 130 224 0 +10318 Pontine_reticular_nucleus_caudal_part_PRC(PnC) 246 160 93 0 +10319 Pontine_reticular_nucleus_oral_part_PRO(PnO) 133 63 9 0 +10320 Raphe_pontis_nucleus_RPn 70 153 237 0 +10321 Dorsal_raphe_nucleus_DR 41 66 83 0 +10322 Dorsal_raphe_nucleus_caudal_part_DRC 43 108 153 0 +10323 Dorsal_raphe_nucleus_dorsal_part_DRD 23 148 232 0 +10324 Dorsal_raphe_nucleus_interfascicular_part_DRI 108 177 224 0 +10325 Dorsal_raphe_nucleus_lateral_part_DRL 21 98 150 0 +10326 Dorsal_raphe_nucleus_ventral_part_DRV 174 210 235 0 +10327 Median_raphe_nucleus_MnR 248 139 133 0 +10328 Paramedian_raphe_nucleus_PMnR 195 35 162 0 +10329 Posterodorsal_raphe_nucleus_PDR 5 148 245 0 +10330 Mesencephalic_trigeminal_nucleus_Me5 91 187 197 0 +10331 Locus_coeruleus_LC 57 68 183 0 +10332 Subcoeruleus_nucleus_SubC 152 159 235 0 +10333 Subcoeruleus_nucleus_dorsal_part_SubCD 76 81 135 0 +10334 Subcoeruleus_nucleus_ventral_part_SubDV 24 35 150 0 +10335 Laterodorsal_tegmental_nucleus_LDTg 187 118 113 0 +10336 Laterodorsal_tegmental_nucleus_ventral_part_LDTgV 43 151 229 0 +10337 Dorsal_tegmental_nucleus_central_part_DTgC 255 48 167 0 +10338 Dorsal_tegmental_nucleus_pericentral_part_DTgP 173 95 44 0 +10339 Posterodorsal_tegmental_nucleus_PDTg 33 34 248 0 +10340 Reticulotegmental_nucleus_RtTg 29 244 26 0 +10341 Reticulotegmental_nucleus_lateral_part_RtTgL 24 137 128 0 +10342 Central_gray_of_the_rhombencephalon_CGPn 173 59 71 0 +10343 Superior_olivary_nucleus_SO 197 134 141 0 +10344 Medial_superior_olivary_nucleus_MSO 230 48 69 0 +10345 Lateral_superior_olivary_nucleus_LSO 232 167 175 0 +10346 Superior_paraolivary_nucleus_SPO 135 72 79 0 +10347 Medioventral_periolivary_nucleus_MVPO 158 17 34 0 +10348 Lateroventral_periolivary_nucleus_LVPO 250 80 100 0 +10349 Matrix_region_of_the_rhombencephalon_Mx 228 124 216 0 +10350 Lateral_parabrachial_nucleus_LPB 93 236 213 0 +10351 Lateral_parabrachial_nucleus_central_part_LPBC 13 163 138 0 +10352 Lateral_parabrachial_nucleus_dorsal_part_LPBD 161 237 225 0 +10353 Lateral_parabrachial_nucleus_external_part_LPBE 67 125 115 0 +10354 Lateral_parabrachial_nucleus_superior_part_LPBS 12 130 110 0 +10355 Medial_parabrachial_nucleus_MPB 224 128 49 0 +10356 Medial_parabrachial_nucleus_external_part_MPBE 222 182 149 0 + +# in the Midbrain +10401 Oculomotor_nucleus_3N 144 212 36 0 +10402 Trochlear_nucleus_4N 212 83 36 0 +10403 Inferior_colliculus_IC 148 30 154 0 +10404 Nucleus_of_the_brachium_of_the_inferior_colliculus_BIC 235 5 247 0 
+10405 Central_nucleus_of_the_inferior_colliculus_CIC 237 131 242 0 +10406 Dorsal_cortex_of_the_inferior_colliculus_DCIC 205 169 207 0 +10407 External_cortex_of_the_inferior_colliculus_ECIC 129 53 133 0 +10408 Superior_colliculus_SC 34 124 43 0 +10409 Deep_gray_layer_of_the_superior_colliculus_DpG 137 171 140 0 +10410 Deep_white_layer_of_the_superior_colliculus_DpWh 218 240 220 0 +10411 Intermediate_gray_layer_of_the_superior_colliculus_InG 7 125 19 0 +10412 Intermediate_white_layer_of_the_superior_colliculus_InWh 170 240 177 0 +10413 Superficial_gray_layer_of_the_superior_colliculus_SuG 44 77 43 0 +10414 Superficial_white_layer_of_the_superior_colliculus_SuWh 51 171 63 0 +10415 Optic_nerve_layer_of_the_superior_colliculus_Op 4 87 13 0 +10416 Periaqueductal_gray_PAG 206 78 109 0 +10417 Dorsolateral_periaqueductal_gray_DLPAG 242 5 64 0 +10418 Dorsomedial_periaqueductal_gray_DMPAG 143 13 45 0 +10419 Lateral_periaqueductal_gray_LPAG 232 151 171 0 +10420 Ventrolateral_periaqueductal_gray_VLPAG 125 57 74 0 +10421 Red_nucleus_RN 250 2 2 0 +10422 Red_nucleus_magnocellular_part_RMC 240 132 132 0 +10423 Red_nucleus_parvocellular_part_RPC 135 39 39 0 +10424 Substantia_nigra_SN 252 220 131 0 +10425 Substantia_nigra_compact_part_SNC 245 181 5 0 +10426 Substantia_nigra_compact_part_dorsal_tier_SNCD 153 113 3 0 +10427 Mesencephalic_reticular_nuclei(formation)_MRN(mRt) 240 100 156 0 +10434 Ventral_tegmental_area_VTA 143 22 141 0 +10435 Ventral_tegmental_area_rostral_part_VTAR 194 110 192 0 +10436 Anterior_pretectal_nucleus_APT 3 10 123 0 + # Below is the color table for the cortical labels of the seg volume # created by mri_aparc2aseg (with --a2009s flag) in which the aseg # cortex label is replaced by the labels in the aparc.a2009s. The @@ -1373,6 +1714,7 @@ 11173 ctx_lh_S_temporal_inf 21 180 180 0 11174 ctx_lh_S_temporal_sup 223 220 60 0 11175 ctx_lh_S_temporal_transverse 221 60 60 0 +11300 ctx_lh_high_myelin 235 62 78 0 12100 ctx_rh_Unknown 0 0 0 0 12101 ctx_rh_G_and_S_frontomargin 23 220 60 0 @@ -1450,6 +1792,7 @@ 12173 ctx_rh_S_temporal_inf 21 180 180 0 12174 ctx_rh_S_temporal_sup 223 220 60 0 12175 ctx_rh_S_temporal_transverse 221 60 60 0 +12300 ctx_rh_high_myelin 235 82 78 0 #No. 
Label Name: R G B A 13100 wm_lh_Unknown 0 0 0 0 @@ -1606,3 +1949,58 @@ 14174 wm_rh_S_temporal_sup 223 220 60 0 14175 wm_rh_S_temporal_transverse 221 60 60 0 +# Below are labels for the Yeo atlas (both 7 and 17) +# https://surfer.nmr.mgh.harvard.edu/fswiki/CorticalParcellation_Yeo2011 +15000 yeo7_lh_Unknown 0 0 0 0 +15001 yeo7_lh_Net_1 120 18 134 0 +15002 yeo7_lh_Net_2 70 130 180 0 +15003 yeo7_lh_Net_3 0 118 14 0 +15004 yeo7_lh_Net_4 196 58 250 0 +15005 yeo7_lh_Net_5 220 248 164 0 +15006 yeo7_lh_Net_6 230 148 34 0 +15007 yeo7_lh_Net_7 205 62 78 0 +15010 yeo7_rh_Net_Unknown 0 0 0 0 +15011 yeo7_rh_Net_1 120 18 134 0 +15012 yeo7_rh_Net_2 70 130 180 0 +15013 yeo7_rh_Net_3 0 118 14 0 +15014 yeo7_rh_Net_4 196 58 250 0 +15015 yeo7_rh_Net_5 220 248 164 0 +15016 yeo7_rh_Net_6 230 148 34 0 +15017 yeo7_rh_Net_7 205 62 78 0 + +15100 yeo17_lh_Net_Unknown 0 0 0 0 +15101 yeo17_lh_Net_1 120 18 134 0 +15102 yeo17_lh_Net_2 255 0 0 0 +15103 yeo17_lh_Net_3 70 130 180 0 +15104 yeo17_lh_Net_4 42 204 164 0 +15105 yeo17_lh_Net_5 74 155 60 0 +15106 yeo17_lh_Net_6 0 118 14 0 +15107 yeo17_lh_Net_7 196 58 250 0 +15108 yeo17_lh_Net_8 255 152 213 0 +15109 yeo17_lh_Net_9 220 248 164 0 +15110 yeo17_lh_Net_10 122 135 50 0 +15111 yeo17_lh_Net_11 119 140 176 0 +15112 yeo17_lh_Net_12 230 148 34 0 +15113 yeo17_lh_Net_13 135 50 74 0 +15114 yeo17_lh_Net_14 12 48 255 0 +15115 yeo17_lh_Net_15 0 0 130 0 +15116 yeo17_lh_Net_16 255 255 0 0 +15117 yeo17_lh_Net_17 205 62 78 0 +15120 yeo17_rh_Net_Unknown 0 0 0 0 +15121 yeo17_rh_Net_1 120 18 134 0 +15122 yeo17_rh_Net_2 255 0 0 0 +15123 yeo17_rh_Net_3 70 130 180 0 +15124 yeo17_rh_Net_4 42 204 164 0 +15125 yeo17_rh_Net_5 74 155 60 0 +15126 yeo17_rh_Net_6 0 118 14 0 +15127 yeo17_rh_Net_7 196 58 250 0 +15128 yeo17_rh_Net_8 255 152 213 0 +15129 yeo17_rh_Net_9 220 248 164 0 +15130 yeo17_rh_Net_10 122 135 50 0 +15131 yeo17_rh_Net_11 119 140 176 0 +15132 yeo17_rh_Net_12 230 148 34 0 +15133 yeo17_rh_Net_13 135 50 74 0 +15134 yeo17_rh_Net_14 12 48 255 0 +15135 yeo17_rh_Net_15 0 0 130 0 +15136 yeo17_rh_Net_16 255 255 0 0 +15137 yeo17_rh_Net_17 205 62 78 0 diff --git a/FastSurferCNN/config/checkpoint_paths.yaml b/FastSurferCNN/config/checkpoint_paths.yaml new file mode 100644 index 00000000..b003c102 --- /dev/null +++ b/FastSurferCNN/config/checkpoint_paths.yaml @@ -0,0 +1,13 @@ +url: +- "https://zenodo.org/records/10390573/files" +- "https://b2share.fz-juelich.de/api/files/a423a576-220d-47b0-9e0c-b5b32d45fc59" + +checkpoint: + axial: "checkpoints/aparc_vinn_axial_v2.0.0.pkl" + coronal: "checkpoints/aparc_vinn_coronal_v2.0.0.pkl" + sagittal: "checkpoints/aparc_vinn_sagittal_v2.0.0.pkl" + +config: + axial: "FastSurferCNN/config/FastSurferVINN_axial.yaml" + coronal: "FastSurferCNN/config/FastSurferVINN_coronal.yaml" + sagittal: "FastSurferCNN/config/FastSurferVINN_sagittal.yaml" diff --git a/FastSurferCNN/config/global_var.py b/FastSurferCNN/config/global_var.py index eee3c30b..88e5580d 100644 --- a/FastSurferCNN/config/global_var.py +++ b/FastSurferCNN/config/global_var.py @@ -147,14 +147,15 @@ def get_class_names(plane, options): Parameters ---------- - plane : - [MISSING] - options : - [MISSING] + plane : str + Plane of the MRI scan. + options : List[str] + List of classes to include. Returns ------- - [MISSING] + selection : List[str] + List of class names. 
""" selection = [] diff --git a/FastSurferCNN/data_loader/__init__.py b/FastSurferCNN/data_loader/__init__.py index 5d32d855..9e4a2696 100644 --- a/FastSurferCNN/data_loader/__init__.py +++ b/FastSurferCNN/data_loader/__init__.py @@ -12,4 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["augmentation", "conform", "data_utils", "dataset", "loader"] +__all__ = [ + "augmentation", + "conform", + "data_utils", + "dataset", + "loader", +] \ No newline at end of file diff --git a/FastSurferCNN/data_loader/augmentation.py b/FastSurferCNN/data_loader/augmentation.py index 7499bb3e..4e8a5407 100644 --- a/FastSurferCNN/data_loader/augmentation.py +++ b/FastSurferCNN/data_loader/augmentation.py @@ -25,27 +25,28 @@ # Transformations for evaluation ## class ToTensorTest(object): - """Convert np.ndarrays in sample to Tensors. + """ + Convert np.ndarrays in sample to Tensors. Methods ------- __call__ - converts image + Converts image. """ def __call__(self, img: npt.NDArray) -> np.ndarray: - """Convert the image to float within range [0, 1] and make it torch compatible. + """ + Convert the image to float within range [0, 1] and make it torch compatible. Parameters ---------- img : npt.NDArray - Image to be converted + Image to be converted. Returns ------- img : np.ndarray - Conformed image - + Conformed image. """ img = img.astype(np.float32) @@ -61,37 +62,37 @@ def __call__(self, img: npt.NDArray) -> np.ndarray: class ZeroPad2DTest(object): - """Pad the input with zeros to get output size. + """ + Pad the input with zeros to get output size. Attributes ---------- output_size : Union[Number, Tuple[Number, Number]] - size of the output image either as Number or tuple of two Number + Size of the output image either as Number or tuple of two Number. pos : str - position to put the input + Position to put the input. Methods ------- pad - pad zeroes of image + Pad zeroes of image. call - call _pad() + Call _pad(). """ - def __init__( self, output_size: Union[Number, Tuple[Number, Number]], pos: str = 'top_left' ): - """Construct object. + """ + Construct object. Parameters ---------- output_size : Union[Number, Tuple[Number, Number]] - size of the output image either as Number or tuple of two Number + Size of the output image either as Number or tuple of two Number. pos : Union[Number, Tuple[Number, Number]] - position to put the input. Defaults to 'top_left' - + Position to put the input. Defaults to 'top_left'. """ if isinstance(output_size, Number): output_size = (int(output_size),) * 2 @@ -99,18 +100,18 @@ def __init__( self.pos = pos def _pad(self, image: npt.NDArray) -> np.ndarray: - """Pad with zeros of the input image. + """ + Pad with zeros of the input image. Parameters ---------- image : npt.NDArray - The image to pad + The image to pad. Returns ------- padded_img : np.ndarray - original image with padded zeros - + Original image with padded zeros. """ if len(image.shape) == 2: h, w = image.shape @@ -125,18 +126,18 @@ def _pad(self, image: npt.NDArray) -> np.ndarray: return padded_img def __call__(self, img: npt.NDArray) -> np.ndarray: - """Call the _pad() function. + """ + Call the _pad() function. Parameters ---------- img : npt.NDArray - the image to pad + The image to pad. Returns ------- img : np.ndarray - original image with padded zeros - + Original image with padded zeros. 
""" img = self._pad(img) @@ -147,28 +148,28 @@ def __call__(self, img: npt.NDArray) -> np.ndarray: # Transformations for training ## class ToTensor(object): - """Convert ndarrays in sample to Tensors. + """ + Convert ndarrays in sample to Tensors. Methods ------- __call__ - Convert image - + Convert image. """ def __call__(self, sample: npt.NDArray) -> Dict[str, Any]: - """Convert the image to float within range [0, 1] and make it torch compatible. + """ + Convert the image to float within range [0, 1] and make it torch compatible. Parameters ---------- sample : npt.NDArray - sample image + Sample image. Returns ------- Dict[str, Any] - Converted image - + Converted image. """ img, label, weight, sf = ( sample["img"], @@ -196,39 +197,38 @@ def __call__(self, sample: npt.NDArray) -> Dict[str, Any]: class ZeroPad2D(object): - """Pad the input with zeros to get output size. + """ + Pad the input with zeros to get output size. Attributes ---------- output_size : Union[Number, Tuple[Number, Number]] - Size of the output image either as Number or tuple of two Number + Size of the output image either as Number or tuple of two Number. pos : str, Optional - Position to put the input + Position to put the input. Methods ------- _pad - Pads zeroes of image + Pads zeroes of image. __call__ - Cals _pad for sample - + Cals _pad for sample. """ - def __init__( self, output_size: Union[Number, Tuple[Number, Number]], pos: Union[None, str] = 'top_left' ): - """Initialize position and output_size (as Tuple[float]). + """ + Initialize position and output_size (as Tuple[float]). Parameters ---------- output_size : Union[Number, Tuple[Number, Number]] Size of the output image either as Number or - tuple of two Number + tuple of two Number. pos : str, Optional - Position to put the input. Default = 'top_left' - + Position to put the input. Default = 'top_left'. """ if isinstance(output_size, Number): output_size = (int(output_size),) * 2 @@ -236,18 +236,18 @@ def __init__( self.pos = pos def _pad(self, image: npt.NDArray) -> np.ndarray: - """Pad the input image with zeros. + """ + Pad the input image with zeros. Parameters ---------- image : npt.NDArray - The image to pad + The image to pad. Returns ------- padded_img : np.ndarray - Original image with padded zeros - + Original image with padded zeros. """ if len(image.shape) == 2: h, w = image.shape @@ -262,18 +262,18 @@ def _pad(self, image: npt.NDArray) -> np.ndarray: return padded_img def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Pad the image, label and weights. + """ + Pad the image, label and weights. Parameters ---------- - sample :Dict[str, Any] - Sample image + sample : Dict[str, Any] + Sample image. Returns ------- Dict[str, Any] - Dictionary including the padded image, label, weight and scale factor - + Dictionary including the padded image, label, weight and scale factor. """ img, label, weight, sf = ( sample["img"], @@ -290,48 +290,48 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: class AddGaussianNoise(object): - """Add gaussian noise to sample. + """ + Add gaussian noise to sample. Attributes ---------- std - Standard deviation + Standard deviation. mean - Gaussian mean + Gaussian mean. Methods ------- __call__ - Adds noise to scale factor + Adds noise to scale factor. """ - def __init__(self, mean: Real = 0, std: Real = 0.1): - """Construct object. + """ + Construct object. Parameters ---------- mean : Real - Standard deviation. Default = 0 + Standard deviation. Default = 0. std : Real - Gaussian mean. 
Default = 0.1 - + Gaussian mean. Default = 0.1. """ self.std = std self.mean = mean def __call__(self, sample: Dict[str, Real]) -> Dict[str, Real]: - """Add gaussian noise to scalefactor. + """ + Add gaussian noise to scalefactor. Parameters ---------- - sample :Dict[str, Real] - Sample data to add noise + sample : Dict[str, Real] + Sample data to add noise. Returns ------- Dict[str, Real] - Sample with noise - + Sample with noise. """ img, label, weight, sf = ( sample["img"], @@ -345,37 +345,38 @@ def __call__(self, sample: Dict[str, Real]) -> Dict[str, Real]: class AugmentationPadImage(object): - """Pad Image with either zero padding or reflection padding of img, label and weight. + """ + Pad Image with either zero padding or reflection padding of img, label and weight. Attributes ---------- - pad_size_imag - [missing] - pad_size_mask - [missing] + pad_size_image : tuple + The padding size for the image. + pad_size_mask : tuple + The padding size for the mask. + pad_type : str + The type of padding to be applied. Methods ------- __call - add zeroes - + Add zeroes. """ - def __init__( self, pad_size: Tuple[Tuple[int, int], Tuple[int, int]] = ((16, 16), (16, 16)), pad_type: str = "edge" ): - """Construct object. + """ + Construct object. - Attributes + Parameters ---------- - pad_size - [MISSING] - pad_type - [MISSING] - + pad_size : tuple + The padding size. + pad_type : str + The type of padding to be applied. """ assert isinstance(pad_size, (int, tuple)) @@ -391,13 +392,13 @@ def __init__( self.pad_type = pad_type def __call__(self, sample: Dict[str, Number]): - """Pad zeroes of sample image, label and weight. + """ + Pad zeroes of sample image, label and weight. Attributes ---------- sample : Dict[str, Number] - Sample image and data - + Sample image and data. """ img, label, weight, sf = ( sample["img"], @@ -414,7 +415,9 @@ def __call__(self, sample: Dict[str, Number]): class AugmentationRandomCrop(object): - """Randomly Crop Image to given size.""" + """ + Randomly Crop Image to given size. + """ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): """Construct object. @@ -422,9 +425,9 @@ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): Attributes ---------- output_size - Size of the output image either an integer or a tuple + Size of the output image either an integer or a tuple. crop_type - [MISSING] + The type of crop to be performed. """ assert isinstance(output_size, (int, tuple)) @@ -437,18 +440,18 @@ def __init__(self, output_size: Union[int, Tuple], crop_type: str = 'Random'): self.crop_type = crop_type def __call__(self, sample: Dict[str, Number]) -> Dict[str, Number]: - """Crops the augmentation. + """ + Crops the augmentation. Attributes ---------- sample : Dict[str, Number] - Sample image with data + Sample image with data. Returns ------- Dict[str, Number] - Cropped sample image - + Cropped sample image. """ img, label, weight, sf = ( sample["img"], diff --git a/FastSurferCNN/data_loader/conform.py b/FastSurferCNN/data_loader/conform.py index 828b6400..4e8b3224 100644 --- a/FastSurferCNN/data_loader/conform.py +++ b/FastSurferCNN/data_loader/conform.py @@ -1,4 +1,5 @@ -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# Copyright 2019 +# AI in Medical Imaging, German Center for Neurodegenerative Diseases (DZNE), Bonn # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
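
As a quick illustration of the evaluation-time transforms documented in the `augmentation.py` hunks above, the following minimal sketch pads a single 2D slice to the conformed 256x256 grid with `ZeroPad2DTest` and then hands a channel-last stack of such slices to `ToTensorTest`. The 7-slice stack and the channel-last layout are illustrative assumptions (mirroring the `NUM_CHANNELS: 7` slice-thickness default above), not a prescribed pipeline.

```
# Illustrative use of the test-time transforms from FastSurferCNN/data_loader/augmentation.py.
# Assumptions (not mandated by the module): a 7-slice, channel-last input stack.
import numpy as np
from FastSurferCNN.data_loader.augmentation import ToTensorTest, ZeroPad2DTest

pad = ZeroPad2DTest(output_size=256, pos="top_left")    # zero-pad any smaller slice
slice_2d = np.random.rand(230, 230).astype(np.float32)  # stand-in for one conformed MRI slice
padded = pad(slice_2d)
print(padded.shape)                                      # (256, 256)

stack = np.stack([padded] * 7, axis=-1)                  # H x W x C, C = slice thickness
tensor_ready = ToTensorTest()(stack)                     # float in [0, 1], torch-compatible layout
```

The training-time counterparts (`ZeroPad2D`, `AddGaussianNoise`, `AugmentationPadImage`, `AugmentationRandomCrop`) operate on the sample dictionaries produced by the data loader rather than on bare arrays.
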
@@ -14,12 +15,14 @@ # IMPORTS -import logging -from typing import Optional, Type, Tuple, Union import argparse +from enum import Enum +import logging import sys +from typing import Optional, Type, Tuple, Union, Iterable, cast import numpy as np +import numpy.typing as npt import nibabel as nib from FastSurferCNN.utils.arg_types import ( @@ -30,34 +33,54 @@ ) HELPTEXT = """ -Script to conform an MRI brain image to UCHAR, RAS orientation, and 1mm or minimal isotropic voxels +Script to conform an MRI brain image to UCHAR, RAS orientation, +and 1mm or minimal isotropic voxels + USAGE: conform.py -i -o OR conform.py -i --check_only Dependencies: - Python 3.8 + Python 3.8+ Numpy - http://www.numpy.org + https://www.numpy.org Nibabel to read and write FreeSurfer data - http://nipy.org/nibabel/ + https://nipy.org/nibabel/ Original Author: Martin Reuter Date: Jul-09-2019 """ h_input = "path to input image" h_output = "path to output image" -h_order = "order of interpolation (0=nearest,1=linear(default),2=quadratic,3=cubic)" +h_order = "order of interpolation (0=nearest, 1=linear(default), 2=quadratic, 3=cubic)" + +LIA_AFFINE = np.array([[-1, 0, 0], [0, 0, 1], [0, -1, 0]]) + + +class Criteria(Enum): + FORCE_LIA_STRICT = "lia strict" + FORCE_LIA = "lia" + FORCE_IMG_SIZE = "img size" + FORCE_ISO_VOX = "iso vox" + + +DEFAULT_CRITERIA_DICT = { + "lia": Criteria.FORCE_LIA, + "strict_lia": Criteria.FORCE_LIA_STRICT, + "iso_vox": Criteria.FORCE_ISO_VOX, + "img_size": Criteria.FORCE_IMG_SIZE, +} +DEFAULT_CRITERIA = frozenset(DEFAULT_CRITERIA_DICT.values()) def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = argparse.ArgumentParser(usage=HELPTEXT) parser.add_argument( @@ -73,25 +96,26 @@ def options_parse(): dest="check_only", default=False, action="store_true", - help="If True, only checks if the input image is conformed, and does not return an output.", + help="If True, only checks if the input image is conformed, and does not " + "return an output.", ) parser.add_argument( "--seg_input", dest="seg_input", default=False, action="store_true", - help="Specifies whether the input is a seg image. If true, the check for conformance " - "disregards the uint8 dtype criteria. Use --dtype any for equivalent results. " - "--seg_input overwrites --dtype arguments.", + help="Specifies whether the input is a seg image. If true, the check for " + "conformance disregards the uint8 dtype criteria. Use --dtype any for " + "equivalent results. --seg_input overwrites --dtype arguments.", ) parser.add_argument( "--vox_size", dest="vox_size", default=1.0, type=__vox_size, - help="Specifies the target voxel size to conform to. Also allows 'min' for conforming to the " - "minimum voxel size, otherwise similar to mri_convert's --conform_size " - "(default: 1, conform to 1mm).", + help="Specifies the target voxel size to conform to. 
Also allows 'min' for " + "conforming to the minimum voxel size, otherwise similar to mri_convert's " + "--conform_size (default: 1, conform to 1mm).", ) parser.add_argument( "--conform_min", @@ -105,16 +129,42 @@ def options_parse(): advanced.add_argument( "--conform_to_1mm_threshold", type=__conform_to_one_mm, - help="Advanced option to change the threshold beyond which images are conformed to 1" - "(default: infinity, all images are conformed to their minimum voxel size).", + help="Advanced option to change the threshold beyond which images are " + "conformed to 1 (default: infinity, all images are conformed to their " + "minimum voxel size).", ) - parser.add_argument( + advanced.add_argument( "--dtype", dest="dtype", default="uint8", type=__target_dtype, - help="Specifies the target data type of the target image or 'any' (default: 'uint8', " - "as in FreeSurfer)", + help="Specifies the target data type of the target image or 'any' (default: " + "'uint8', as in FreeSurfer)", + ) + advanced.add_argument( + "--no_strict_lia", + dest="force_strict_lia", + action="store_false", + help="Ignore the forced LIA reorientation.", + ) + advanced.add_argument( + "--no_lia", + dest="force_lia", + action="store_false", + help="Ignore the reordering of data into LIA (without interpolation). " + "Superceeds --no_strict_lia", + ) + advanced.add_argument( + "--no_iso_vox", + dest="force_iso_vox", + action="store_false", + help="Ignore the forced isometric voxel size (depends on --conform_min).", + ) + advanced.add_argument( + "--no_img_size", + dest="force_img_size", + action="store_false", + help="Ignore the forced image dimensions (depends on --conform_min).", ) parser.add_argument( "--verbose", @@ -125,48 +175,52 @@ def options_parse(): ) args = parser.parse_args() if args.input is None: - sys.exit("ERROR: Please specify input image") + raise RuntimeError("ERROR: Please specify input image") if not args.check_only and args.output is None: - sys.exit("ERROR: Please specify output image") + raise RuntimeError("ERROR: Please specify output image") if args.check_only and args.output is not None: - sys.exit( + raise RuntimeError( "ERROR: You passed in check_only. Please do not also specify output image" ) if args.seg_input and args.dtype not in ["uint8", "any"]: print("WARNING: --seg_input overwrites the --dtype arguments.") + if not args.force_lia and args.force_strict_lia: + print("INFO: --no_lia includes --no_strict_lia.") + args.force_strict_lia = False return args def map_image( img: nib.analyze.SpatialImage, out_affine: np.ndarray, - out_shape: np.ndarray, + out_shape: tuple[int, ...] | np.ndarray | Iterable[int], ras2ras: Optional[np.ndarray] = None, order: int = 1, dtype: Optional[Type] = None ) -> np.ndarray: - """Map image to new voxel space (RAS orientation). + """ + Map image to new voxel space (RAS orientation). Parameters ---------- img : nib.analyze.SpatialImage - the src 3D image with data and affine set + The src 3D image with data and affine set. out_affine : np.ndarray - trg image affine - out_shape : np.ndarray - the trg shape information - ras2ras : Optional[np.ndarray] - an additional mapping that should be applied (default=id to just reslice) - order : int - order of interpolation (0=nearest,1=linear(default),2=quadratic,3=cubic) - dtype : Optional[Type] - target dtype of the resulting image (relevant for reorientation, default=same as img) + Trg image affine. + out_shape : tuple[int, ...], np.ndarray + The trg shape information. 
+ ras2ras : np.ndarray, optional + An additional mapping that should be applied (default=id to just reslice). + order : int, default=1 + Order of interpolation (0=nearest,1=linear,2=quadratic,3=cubic). + dtype : Type, optional + Target dtype of the resulting image (relevant for reorientation, + default=keep dtype of img). Returns ------- np.ndarray - mapped image data array - + Mapped image data array. """ from scipy.ndimage import affine_transform from numpy.linalg import inv @@ -180,20 +234,40 @@ def map_image( # here we apply the inverse vox2vox (to pull back the src info to the target image) image_data = np.asanyarray(img.dataobj) # convert frames to single image - if len(image_data.shape) > 3: - if any(s != 1 for s in image_data.shape[3:]): + + out_shape = tuple(out_shape) + # if input has frames + if image_data.ndim > 3: + # if the output has no frames + if len(out_shape) == 3: + if any(s != 1 for s in image_data.shape[3:]): + raise ValueError( + f"Multiple input frames {tuple(image_data.shape)} not supported!" + ) + image_data = np.squeeze(image_data, axis=tuple(range(3, image_data.ndim))) + # if the output has the same number of frames as the input + elif image_data.shape[3:] == out_shape[3:]: + # add a frame dimension to vox2vox + _vox2vox = np.eye(5, dtype=vox2vox.dtype) + _vox2vox[:3, :3] = vox2vox[:3, :3] + _vox2vox[3:, 4:] = vox2vox[:3, 3:] + vox2vox = _vox2vox + else: raise ValueError( - f"Multiple input frames {tuple(image_data.shape)} not supported!" - ) - image_data = np.squeeze(image_data, axis=tuple(range(3, len(image_data.shape)))) + f"Input image and requested output shape have different frames:" + f"{image_data.shape} vs. {out_shape}!" + ) if dtype is not None: image_data = image_data.astype(dtype) - new_data = affine_transform( - image_data, inv(vox2vox), output_shape=out_shape, order=order + if not is_resampling_vox2vox(vox2vox): + # this is a shortcut to reordering resampling + order = 0 + + return affine_transform( + image_data, inv(vox2vox), output_shape=out_shape, order=order, ) - return new_data def getscale( @@ -203,94 +277,102 @@ def getscale( f_low: float = 0.0, f_high: float = 0.999 ) -> Tuple[float, float]: - """Get offset and scale of image intensities to robustly rescale to range dst_min..dst_max. + """ + Get offset and scale of image intensities to robustly rescale to dst_min..dst_max. Equivalent to how mri_convert conforms images. Parameters ---------- data : np.ndarray - image data (intensity values) + Image data (intensity values). dst_min : float - future minimal intensity value + Future minimal intensity value. dst_max : float - future maximal intensity value - f_low : float - robust cropping at low end (0.0 no cropping, default) - f_high : float - robust cropping at higher end (0.999 crop one thousandth of high intensity voxels, default) + Future maximal intensity value. + f_low : float, default=0.0 + Robust cropping at low end (0.0=no cropping). + f_high : float, default=0.999 + Robust cropping at higher end (0.999=crop one thousandth of highest intensity). Returns ------- float src_min - (adjusted) offset + (adjusted) offset. float - scale factor - + Scale factor. """ + + if f_low < 0. or f_high > 1. or f_low > f_high: + raise ValueError( + "Invalid values for f_low or f_high, must be within 0 and 1." + ) + # get min and max from source - src_min = np.min(data) - src_max = np.max(data) + data_min = np.min(data) + data_max = np.max(data) - if src_min < 0.0: + if data_min < 0.0: + # logger. 
warning print("WARNING: Input image has value(s) below 0.0 !") - - print("Input: min: " + format(src_min) + " max: " + format(src_max)) + # logger.info + print(f"Input: min: {data_min} max: {data_max}") if f_low == 0.0 and f_high == 1.0: - return src_min, 1.0 + return data_min, 1.0 # compute non-zeros and total vox num - nz = (np.abs(data) >= 1e-15).sum() - voxnum = data.shape[0] * data.shape[1] * data.shape[2] - - # compute histogram - histosize = 1000 - bin_size = (src_max - src_min) / histosize - hist, bin_edges = np.histogram(data, histosize) - - # compute cumulative sum - cs = np.concatenate(([0], np.cumsum(hist))) - - # get lower limit - nth = int(f_low * voxnum) - idx = np.where(cs < nth) - - if len(idx[0]) > 0: - idx = idx[0][-1] + 1 - + num_nonzero_voxels = (np.abs(data) >= 1e-15).sum() + num_total_voxels = data.shape[0] * data.shape[1] * data.shape[2] + + # compute histogram (number of samples) + bins = 1000 + hist, bin_edges = np.histogram(data, bins=bins, range=(data_min, data_max)) + + # compute cumulative histogram + cum_hist = np.concatenate(([0], np.cumsum(hist))) + + # get lower limit: f_low fraction of total voxels + lower_cutoff = int(f_low * num_total_voxels) + binindex_lt_low_cutoff = np.flatnonzero(cum_hist < lower_cutoff) + + lower_binedge_index = 0 + # if we find any voxels + if len(binindex_lt_low_cutoff) > 0: + lower_binedge_index = binindex_lt_low_cutoff[-1] + 1 + + src_min: float = bin_edges[lower_binedge_index].item() + + # get upper limit (cutoff only based on non-zero voxels, i.e. how many + # non-zero voxels to ignore) + upper_cutoff = num_total_voxels - int((1.0 - f_high) * num_nonzero_voxels) + binindex_ge_up_cutoff = np.flatnonzero(cum_hist >= upper_cutoff) + + if len(binindex_ge_up_cutoff) > 0: + upper_binedge_index = binindex_ge_up_cutoff[0] - 2 + elif np.isclose(cum_hist[-1], 1.0, atol=1e-6) or num_nonzero_voxels < 10: + # if we cannot find a cutoff, check, if we are running into numerical + # issues such that cum_hist does not properly account for the full hist + # index -1 should always yield the last element, which is data_max + upper_binedge_index = -1 else: - idx = 0 - - src_min = idx * bin_size + src_min - - # get upper limit - nth = voxnum - int((1.0 - f_high) * nz) - idx = np.where(cs >= nth) - - if len(idx[0]) > 0: - idx = idx[0][0] - 2 - - else: - print("ERROR: rescale upper bound not found") + # If no upper bound can be found, this is probably a bug somewhere + raise RuntimeError( + f"ERROR: rescale upper bound not found: f_high={f_high}" + ) - src_max = idx * bin_size + src_min + src_max: float = bin_edges[upper_binedge_index].item() # scale if src_min == src_max: + # logger.warning + print("WARNING: Scaling between src_min and src_max. The input image " + "is likely corrupted!") scale = 1.0 - else: scale = (dst_max - dst_min) / (src_max - src_min) - - print( - "rescale: min: " - + format(src_min) - + " max: " - + format(src_max) - + " scale: " - + format(scale) - ) + # logger.info + print(f"rescale: min: {src_min} max: {src_max} scale: {scale}") return src_min, scale @@ -302,26 +384,26 @@ def scalecrop( src_min: float, scale: float ) -> np.ndarray: - """Crop the intensity ranges to specific min and max values. + """ + Crop the intensity ranges to specific min and max values. Parameters ---------- data : np.ndarray - Image data (intensity values) + Image data (intensity values). dst_min : float - future minimal intensity value + Future minimal intensity value. 
dst_max : float - future maximal intensity value + Future maximal intensity value. src_min : float - minimal value to consider from source (crops below) + Minimal value to consider from source (crops below). scale : float - scale value by which source will be shifted + Scale value by which source will be shifted. Returns ------- np.ndarray - scaled image data - + Scaled image data. """ data_new = dst_min + scale * (data - src_min) @@ -341,26 +423,26 @@ def rescale( f_low: float = 0.0, f_high: float = 0.999 ) -> np.ndarray: - """Rescale image intensity values (0-255). + """ + Rescale image intensity values (0-255). Parameters ---------- data : np.ndarray - image data (intensity values) + Image data (intensity values). dst_min : float - future minimal intensity value + Future minimal intensity value. dst_max : float - future maximal intensity value - f_low : float - robust cropping at low end (0.0 no cropping, default) - f_high : float - robust cropping at higher end (0.999 crop one thousandth of high intensity voxels, default) + Future maximal intensity value. + f_low : float, default=0.0 + Robust cropping at low end (0.0=no cropping). + f_high : float, default=0.999 + Robust cropping at higher end (0.999=crop one thousandth of highest intensity). Returns ------- np.ndarray - scaled image data - + Scaled image data. """ src_min, scale = getscale(data, dst_min, dst_max, f_low, f_high) data_new = scalecrop(data, dst_min, dst_max, src_min, scale) @@ -368,24 +450,24 @@ def rescale( def find_min_size(img: nib.analyze.SpatialImage, max_size: float = 1) -> float: - """Find minimal voxel size <= 1mm. + """ + Find minimal voxel size <= 1mm. Parameters ---------- img : nib.analyze.SpatialImage - loaded source image + Loaded source image. max_size : float - maximal voxel size in mm (default: 1.0) + Maximal voxel size in mm (default: 1.0). Returns ------- float - Rounded minimal voxel size + Rounded minimal voxel size. Notes ----- This function only needs the header (not the data). - """ # find minimal voxel side length sizes = np.array(img.header.get_zooms()[:3]) @@ -399,18 +481,19 @@ def find_img_size_by_fov( vox_size: float, min_dim: int = 256 ) -> int: - """Find the cube dimension (>= 256) to cover the field of view of img. + """ + Find the cube dimension (>= 256) to cover the field of view of img. If vox_size is one, the img_size MUST always be min_dim (the FreeSurfer standard). Parameters ---------- img : nib.analyze.SpatialImage - loaded source image + Loaded source image. vox_size : float - the target voxel size in mm + The target voxel size in mm. min_dim : int - minimal image dimension in voxels (default 256) + Minimal image dimension in voxels (default 256). Returns ------- @@ -420,7 +503,6 @@ def find_img_size_by_fov( Notes ----- This function only needs the header (not the data). - """ if vox_size == 1.0: return min_dim @@ -440,9 +522,10 @@ def conform( order: int = 1, conform_vox_size: VoxSizeOption = 1.0, dtype: Optional[Type] = None, - conform_to_1mm_threshold: Optional[float] = None + conform_to_1mm_threshold: Optional[float] = None, + criteria: set[Criteria] = DEFAULT_CRITERIA, ) -> nib.MGHImage: - """Python version of mri_convert -c. + f"""Python version of mri_convert -c. mri_convert -c by default turns image intensity values into UCHAR, reslices images to standard position, fills up slices to standard @@ -451,105 +534,174 @@ def conform( Parameters ---------- img : nib.analyze.SpatialImage - loaded source image + Loaded source image. 
order : int - interpolation order (0=nearest,1=linear(default),2=quadratic,3=cubic) + Interpolation order (0=nearest,1=linear(default),2=quadratic,3=cubic). conform_vox_size : VoxSizeOption - conform image the image to voxel size 1. (default), a + Conform image the image to voxel size 1. (default), a specific smaller voxel size (0-1, for high-res), or automatically determine the 'minimum voxel size' from the image (value 'min'). This assumes the smallest of the three voxel sizes. dtype : Optional[Type] - the dtype to enforce in the image (default: UCHAR, as mri_convert -c) + The dtype to enforce in the image (default: UCHAR, as mri_convert -c). conform_to_1mm_threshold : Optional[float] - the threshold above which the image is conformed to 1mm + The threshold above which the image is conformed to 1mm (default: ignore). + criteria : set[Criteria], default={DEFAULT_CRITERIA} + Whether to force the conforming to include a LIA data layout, an image size + requirement and/or a voxel size requirement. Returns ------- nib.MGHImage - conformed image + Conformed image. Notes ----- Unlike mri_convert -c, we first interpolate (float image), and then rescale to uchar. mri_convert is doing it the other way around. However, we compute the scale factor from the input to increase similarity. - """ - from nibabel.freesurfer.mghformat import MGHHeader - conformed_vox_size, conformed_img_size = get_conformed_vox_img_size( - img, conform_vox_size, conform_to_1mm_threshold=conform_to_1mm_threshold + img, conform_vox_size, conform_to_1mm_threshold=conform_to_1mm_threshold, ) + from nibabel.freesurfer.mghformat import MGHHeader - h1 = MGHHeader.from_header( - img.header - ) # may copy some parameters if input was MGH format + # may copy some parameters if input was MGH format + h1 = MGHHeader.from_header(img.header) + mdc_affine = h1["Mdc"] + img_shape = img.header.get_data_shape() + vox_size = img.header.get_zooms() + do_interp = False + affine = img.affine[:3, :3] + if {Criteria.FORCE_LIA, Criteria.FORCE_LIA_STRICT} & criteria != {}: + do_interp = bool(Criteria.FORCE_LIA_STRICT in criteria and is_lia(affine, True)) + re_order_axes = [np.abs(affine[:, j]).argmax() for j in (0, 2, 1)] + else: + re_order_axes = [0, 1, 2] - h1.set_data_shape([conformed_img_size, conformed_img_size, conformed_img_size, 1]) - h1.set_zooms( - [conformed_vox_size, conformed_vox_size, conformed_vox_size] - ) # --> h1['delta'] - h1["Mdc"] = [[-1, 0, 0], [0, 0, -1], [0, 1, 0]] - h1["fov"] = conformed_img_size * conformed_vox_size - h1["Pxyz_c"] = img.affine.dot(np.hstack((np.array(img.shape[:3]) / 2.0, [1])))[:3] + if Criteria.FORCE_IMG_SIZE in criteria: + h1.set_data_shape([conformed_img_size] * 3 + [1]) + else: + h1.set_data_shape([img_shape[i] for i in re_order_axes] + [1]) + if Criteria.FORCE_ISO_VOX in criteria: + h1.set_zooms([conformed_vox_size] * 3) # --> h1['delta'] + do_interp |= not np.allclose(vox_size, conformed_vox_size) + else: + h1.set_zooms([vox_size[i] for i in re_order_axes]) + + if Criteria.FORCE_LIA_STRICT in criteria: + mdc_affine = LIA_AFFINE + elif Criteria.FORCE_LIA in criteria: + mdc_affine = affine[:3, re_order_axes] + if mdc_affine[0, 0] > 0: # make 0,0 negative + mdc_affine[:, 0] = -mdc_affine[:, 0] + if mdc_affine[1, 2] < 0: # make 1,2 positive + mdc_affine[:, 2] = -mdc_affine[:, 2] + if mdc_affine[2, 1] > 0: # make 2,1 negative + mdc_affine[:, 1] = -mdc_affine[:, 1] + else: + mdc_affine = img.affine[:3, :3] + + mdc_affine = mdc_affine / np.linalg.norm(mdc_affine, axis=1) + h1["Mdc"] = 
np.linalg.inv(mdc_affine) + + print(h1.get_zooms()) + h1["fov"] = max(i * v for i, v in zip(h1.get_data_shape(), h1.get_zooms())) + center = np.asarray(img.shape[:3], dtype=float) / 2.0 + h1["Pxyz_c"] = img.affine.dot(np.hstack((center, [1.0])))[:3] + + # There is a special case here, where an interpolation is triggered, but it is not + # necessary, if the position of the center could "fix this" + # condition: 1. no rotation, no vox-size resampling, + if not is_resampling_vox2vox(np.linalg.inv(h1.get_affine()) @ img.affine): + # 2. img_size changes from odd to even and vice versa + ishape = np.asarray(img.shape)[re_order_axes] + delta_shape = np.subtract(ishape, h1.get_data_shape()[:3]) + # 2. img_size changes from odd to even and vice versa + if not np.allclose(np.remainder(delta_shape, 2), 0): + # invert axis reordering + delta_shape[re_order_axes] = delta_shape + new_center = (center + delta_shape / 2.0, [1.0]) + h1["Pxyz_c"] = img.affine.dot(np.hstack(new_center))[:3] # Here, we are explicitly using MGHHeader.get_affine() to construct the affine as - # MdcD = np.asarray(h1['Mdc']).T * h1['delta'] - # vol_center = MdcD.dot(hdr['dims'][:3]) / 2 - # affine = from_matvec(MdcD, h1['Pxyz_c'] - vol_center) + # MdcD = np.asarray(h1["Mdc"]).T * h1["delta"] + # vol_center = MdcD.dot(hdr["dims"][:3]) / 2 + # affine = from_matvec(MdcD, h1["Pxyz_c"] - vol_center) affine = h1.get_affine() # from_header does not compute Pxyz_c (and probably others) when importing from nii # Pxyz is the center of the image in world coords # target scalar type and dtype - sctype = np.uint8 if dtype is None else np.obj2sctype(dtype, default=np.uint8) + #sctype = np.uint8 if dtype is None else np.obj2sctype(dtype, default=np.uint8) + sctype = np.uint8 if dtype is None else np.dtype(dtype).type target_dtype = np.dtype(sctype) src_min, scale = 0, 1.0 - # get scale for conversion on original input before mapping to be more similar to mri_convert - if ( - img.get_data_dtype() != np.dtype(np.uint8) - or img.get_data_dtype() != target_dtype - ): + # get scale for conversion on original input before mapping to be more similar to + # mri_convert + img_dtype = img.get_data_dtype() + if any(img_dtype != dtyp for dtyp in (np.dtype(np.uint8), target_dtype)): src_min, scale = getscale(np.asanyarray(img.dataobj), 0, 255) - kwargs = {"dtype": "float"} if sctype != np.uint else {} + kwargs = {} + if sctype != np.uint: + kwargs["dtype"] = "float" mapped_data = map_image(img, affine, h1.get_data_shape(), order=order, **kwargs) - if img.get_data_dtype() != np.dtype(np.uint8) or ( - img.get_data_dtype() != target_dtype and scale != 1.0 - ): + if img_dtype != np.dtype(np.uint8) or (img_dtype != target_dtype and scale != 1.0): scaled_data = scalecrop(mapped_data, 0, 255, src_min, scale) # map zero in input to zero in output (usually background) scaled_data[mapped_data == 0] = 0 mapped_data = scaled_data - mapped_data = sctype( - np.clip(np.rint(mapped_data),0,255) if target_dtype == np.dtype(np.uint8) else mapped_data - ) - new_img = nib.MGHImage(mapped_data, affine, h1) + if target_dtype == np.dtype(np.uint8): + mapped_data = np.clip(np.rint(mapped_data), 0, 255) + new_img = nib.MGHImage(sctype(mapped_data), affine, h1) # make sure we store uchar + from nibabel.freesurfer import mghformat try: new_img.set_data_dtype(target_dtype) - except nib.freesurfer.mghformat.MGHError as e: - if "not recognized" in e.args[0]: - codes = set( - k.name - for k in nib.freesurfer.mghformat.data_type_codes.code.keys() - if isinstance(k, np.dtype) - ) - print( 
- f'The data type "{options.dtype}" is not recognized for MGH images, switching ' - f'to "{new_img.get_data_dtype()}" (supported: {tuple(codes)}).' - ) + except mghformat.MGHError as e: + if "not recognized" not in e.args[0]: + raise + dtype_codes = mghformat.data_type_codes.code.keys() + codes = set(k.name for k in dtype_codes if isinstance(k, np.dtype)) + print( + f"The data type '{options.dtype}' is not recognized for MGH images, " + f"switching to '{new_img.get_data_dtype()}' (supported: {tuple(codes)})." + ) return new_img +def is_resampling_vox2vox( + vox2vox: npt.NDArray[float], + eps: float = 1e-6, +) -> bool: + """ + Check whether the affine is resampling or just reordering. + + Parameters + ---------- + vox2vox : np.ndarray + The affine matrix. + eps : float, default=1e-6 + The epsilon for the affine check. + + Returns + ------- + bool + The result. + """ + _v2v = np.abs(vox2vox[:3, :3]) + # check 1: have exactly 3 times 1/-1 rest 0, check 2: all 1/-1 or 0 + return abs(_v2v.sum() - 3) > eps or np.any(np.maximum(_v2v, abs(_v2v - 1)) > eps) + + def is_conform( img: nib.analyze.SpatialImage, conform_vox_size: VoxSizeOption = 1.0, @@ -557,36 +709,39 @@ def is_conform( check_dtype: bool = True, dtype: Optional[Type] = None, verbose: bool = True, - conform_to_1mm_threshold: Optional[float] = None + conform_to_1mm_threshold: Optional[float] = None, + criteria: set[Criteria] = DEFAULT_CRITERIA, ) -> bool: - """Check if an image is already conformed or not. + f""" + Check if an image is already conformed or not. Dimensions: 256x256x256, Voxel size: 1x1x1, LIA orientation, and data type UCHAR. Parameters ---------- img : nib.analyze.SpatialImage - Loaded source image - conform_vox_size : VoxSizeOption - which voxel size to conform to. Can either be a float between 0.0 and - 1.0 or 'min' check, whether the image is conformed to the minimal voxels size, i.e. - conforming to smaller, but isotropic voxel sizes for high-res (default: 1.0). - eps : float - allowed deviation from zero for LIA orientation check (default: 1e-06). + Loaded source image. + conform_vox_size : VoxSizeOption, default=1.0 + Which voxel size to conform to. Can either be a float between 0.0 and + 1.0 or 'min' check, whether the image is conformed to the minimal voxels size, + i.e. conforming to smaller, but isotropic voxel sizes for high-res. + eps : float, default=1e-06 + Allowed deviation from zero for LIA orientation check. Small inaccuracies can occur through the inversion operation. Already conformed images are thus sometimes not correctly recognized. The epsilon accounts for these small shifts. - check_dtype : bool - specifies whether the UCHAR dtype condition is checked for; - this is not done when the input is a segmentation (default: True). - dtype : Optional[Type] - specifies the intended target dtype (default: uint8 = UCHAR) - verbose : bool - if True, details of which conformance conditions are violated (if any) - are displayed (default: True). - conform_to_1mm_threshold : Optional[float] - the threshold above which the image is conformed to 1mm - (default: ignore). + check_dtype : bool, default=True + Specifies whether the UCHAR dtype condition is checked for; + this is not done when the input is a segmentation. + dtype : Type, optional + Specifies the intended target dtype (default or None: uint8 = UCHAR). + verbose : bool, default=True + If True, details of which conformance conditions are violated (if any) + are displayed. 
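# --- Editor's illustration (not part of the diff): the intended pairing of
# --- is_conform() and conform() above, i.e. only reslice when the input does
# --- not already satisfy the conformance criteria. Module path and file names
# --- are assumptions; signatures follow the functions shown in this diff.
# --- (The is_conform() parameter list of the diff continues below.)
import nibabel as nib
from FastSurferCNN.data_loader.conform import conform, is_conform

img = nib.load("t1.nii.gz")  # hypothetical input
if not is_conform(img, conform_vox_size=1.0, verbose=True):
    # reslice to LIA orientation, 256^3 voxels, 1 mm isotropic, UCHAR intensities
    img = conform(img, order=1, conform_vox_size=1.0)
nib.save(img, "orig.mgz")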
+ conform_to_1mm_threshold : float, optional + Above this threshold the image is conformed to 1mm (default or None: ignore). + criteria : set[Criteria], default={DEFAULT_CRITERIA} + An enum/set of criteria to check. Returns ------- @@ -596,7 +751,6 @@ def is_conform( Notes ----- This function only needs the header (not the data). - """ conformed_vox_size, conformed_img_size = get_conformed_vox_img_size( img, conform_vox_size, conform_to_1mm_threshold=conform_to_1mm_threshold @@ -605,38 +759,49 @@ def is_conform( ishape = img.shape # check 3d if len(ishape) > 3 and ishape[3] != 1: - raise ValueError( - f"ERROR: Multiple input frames ({img.shape[3]}) not supported!" - ) + raise ValueError(f"ERROR: Multiple input frames ({ishape[3]}) not supported!") - criteria = {} + checks = { + "Number of Dimensions 3": (len(ishape) == 3, f"image ndim {img.ndim}") + } # check dimensions - criteria["Dimensions {0}x{0}x{0}".format(conformed_img_size)] = all( - s == conformed_img_size for s in ishape[:3] - ) + if Criteria.FORCE_IMG_SIZE in criteria: + img_size_criteria = f"Dimensions {'x'.join([str(conformed_img_size)] * 3)}" + is_correct_img_size = all(s == conformed_img_size for s in ishape[:3]) + checks[img_size_criteria] = is_correct_img_size, f"image dimensions {ishape}" - # check voxel size + # check voxel size, drop voxel sizes of dimension 4 if available izoom = np.array(img.header.get_zooms()) - is_correct_vox_size = np.max(np.abs(izoom - conformed_vox_size) < eps) - criteria["Voxel Size {0}x{0}x{0}".format(conformed_vox_size)] = is_correct_vox_size + is_correct_vox_size = np.max(np.abs(izoom[:3] - conformed_vox_size)) < eps + _vox_sizes = conformed_vox_size if is_correct_vox_size else izoom[:3] + if Criteria.FORCE_ISO_VOX in criteria: + vox_size_criteria = f"Voxel Size {'x'.join([str(conformed_vox_size)] * 3)}" + image_vox_size = "image " + "x".join(map(str, izoom)) + checks[vox_size_criteria] = (is_correct_vox_size, image_vox_size) # check orientation LIA - LIA_affine = np.array([[-1, 0, 0], [0, 0, 1], [0, -1, 0]]) - iaffine = img.affine[0:3, 0:3] - LIA_affine * ( - conformed_vox_size if is_correct_vox_size else izoom - ) - criteria["Orientation LIA"] = np.max(np.abs(iaffine)) <= eps + if {Criteria.FORCE_LIA, Criteria.FORCE_LIA_STRICT} & criteria != {}: + is_strict = Criteria.FORCE_LIA_STRICT in criteria + lia_text = "strict" if is_strict else "lia" + if not (is_correct_lia := is_lia(img.affine, is_strict, eps)): + import re + print_options = np.get_printoptions() + np.set_printoptions(precision=2) + lia_text += ": " + re.sub("\\s+", " ", str(img.affine[:3, :3])) + np.set_printoptions(**print_options) + checks["Orientation LIA"] = (is_correct_lia, lia_text) # check dtype uchar if check_dtype: if dtype is None or (isinstance(dtype, str) and dtype.lower() == "uchar"): dtype = "uint8" else: # assume obj - dtype = np.dtype(np.obj2sctype(dtype)).name - criteria[f"Dtype {dtype}"] = img.get_data_dtype() == dtype + #dtype = np.dtype(np.obj2sctype(dtype)).name + dtype = np.dtype(dtype).type.__name__ + is_correct_dtype = img.get_data_dtype() == dtype + checks[f"Dtype {dtype}"] = (is_correct_dtype, f"dtype {img.get_data_dtype()}") - _is_conform = all(criteria.values()) - # result = (_is_conform, criteria) if return_criteria else _is_conform + _is_conform = all(map(lambda x: x[0], checks.values())) if verbose: if not _is_conform: @@ -646,44 +811,81 @@ def is_conform( "conformed" if conform_vox_size == 1.0 else f"{conform_vox_size}-conformed" ) print(f"A {conform_str} image must satisfy the following 
criteria:") - for condition, value in criteria.items(): - print(" - {:<30} {}".format(condition + ":", value)) + for condition, (value, message) in checks.items(): + print(f" - {condition:<30}: {value if value else 'BUT ' + message}") return _is_conform +def is_lia( + affine: npt.NDArray[float], + strict: bool = True, + eps: float = 1e-6, +): + """ + Checks whether the affine is LIA-oriented. + + Parameters + ---------- + affine : np.ndarray + The affine to check. + strict : bool, default=True + Whether the orientation should be "exactly" LIA or just similar to LIA (i.e. + it is more LIA than other directions). + eps : float, default=1e-6 + The threshold in strict mode. + + Returns + ------- + bool + Whether the affine is LIA-oriented. + """ + iaffine = affine[:3, :3] + lia_nonzero = LIA_AFFINE != 0 + signs = np.all(np.sign(iaffine[lia_nonzero]) == LIA_AFFINE[lia_nonzero]) + if strict: + directions = np.all(iaffine[np.logical_not(lia_nonzero)] <= eps) + else: + def get_primary_dirs(a): return np.argmax(abs(a), axis=0) + + directions = np.all(get_primary_dirs(iaffine) == get_primary_dirs(LIA_AFFINE)) + is_correct_lia = directions and signs + return is_correct_lia + + def get_conformed_vox_img_size( img: nib.analyze.SpatialImage, conform_vox_size: VoxSizeOption, conform_to_1mm_threshold: Optional[float] = None ) -> Tuple[float, int]: - """Extract the voxel size and the image size. + """ + Extract the voxel size and the image size. This function only needs the header (not the data). Parameters ---------- img : nib.analyze.SpatialImage - Loaded source image - conform_vox_size : VoxSizeOption - [MISSING] - conform_to_1mm_threshold : Optional[float] - [MISSING] + Loaded source image. + conform_vox_size : float, "min" + The voxel size parameter to use: either a voxel size as float, or the string + "min" to automatically find a suitable voxel size (smallest per-dimension voxel + size). + conform_to_1mm_threshold : float, optional + The threshold for which image voxel size should be conformed to 1mm instead of + conformed to the smallest voxel size (default: None, never apply). Returns ------- - [MISSING] - + conformed_vox_size : float + The determined voxel size to conform the image to. + conformed_img_size : int + The size of the image adjusted to the conformed voxel size. """ # this is similar to mri_convert --conform_min - if isinstance(conform_vox_size, str) and conform_vox_size.lower() in [ - "min", - "auto", - ]: + auto_values = ["min", "auto"] + if isinstance(conform_vox_size, str) and conform_vox_size.lower() in auto_values: conformed_vox_size = find_min_size(img) - if ( - conform_to_1mm_threshold is not None - and conformed_vox_size > conform_to_1mm_threshold - ): + if conform_to_1mm_threshold and conformed_vox_size > conform_to_1mm_threshold: conformed_vox_size = 1.0 # this is similar to mri_convert --conform_size elif isinstance(conform_vox_size, float) and 0.0 < conform_vox_size <= 1.0: @@ -698,7 +900,8 @@ def check_affine_in_nifti( img: Union[nib.Nifti1Image, nib.Nifti2Image], logger: Optional[logging.Logger] = None ) -> bool: - """Check the affine in nifti Image. + """ + Check the affine in nifti Image. Sets affine with qform, if it exists and differs from sform. If qform does not exist, voxel sizes between header information and information @@ -708,66 +911,52 @@ def check_affine_in_nifti( Parameters ---------- img : Union[nib.Nifti1Image, nib.Nifti2Image] - loaded nifti-image + Loaded nifti-image. 
logger : Optional[logging.Logger] Logger object or None (default) to log or print an info message to - stdout (for None) + stdout (for None). Returns ------- - True, if - affine was reset to qform voxel sizes in affine are equivalent to - voxel sizes in header - False, if - voxel sizes in affine and header differ - + bool + False, if voxel sizes in affine and header differ. """ check = True message = "" + header = cast(nib.Nifti1Header | nib.Nifti2Header, img.header) if ( - img.header["qform_code"] != 0 - and np.max(np.abs(img.get_sform() - img.get_qform())) > 0.001 + header["qform_code"] != 0 and + not np.allclose(img.get_sform(), img.get_qform(), atol=0.001) ): message = ( - "#############################################################" - "\nWARNING: qform and sform transform are not identical!\n sform-transform:\n{}\n " - "qform-transform:\n{}\n" - "You might want to check your Nifti-header for inconsistencies!" - "\n!!! Affine from qform transform will now be used !!!\n" - "#############################################################".format( - img.header.get_sform(), img.header.get_qform() - ) + f"#############################################################\n" + f"WARNING: qform and sform transform are not identical!\n" + f" sform-transform:\n{header.get_sform()}\n" + f" qform-transform:\n{header.get_qform()}\n" + f"You might want to check your Nifti-header for inconsistencies!\n" + f"!!! Affine from qform transform will now be used !!!\n" + f"#############################################################" ) - # Set sform with qform affine and update best affine in header + # Set sform with qform affine and update the best affine in header img.set_sform(img.get_qform()) img.update_header() else: - # Check if affine correctly includes voxel information and print Warning/Exit otherwise - vox_size_head = img.header.get_zooms() - aff = img.affine - xsize = np.sqrt( - aff[0][0] * aff[0][0] + aff[1][0] * aff[1][0] + aff[2][0] * aff[2][0] - ) - ysize = np.sqrt( - aff[0][1] * aff[0][1] + aff[1][1] * aff[1][1] + aff[2][1] * aff[2][1] - ) - zsize = np.sqrt( - aff[0][2] * aff[0][2] + aff[1][2] * aff[1][2] + aff[2][2] * aff[2][2] - ) + # Check if affine correctly includes voxel information and print Warning/ + # Exit otherwise + vox_size_header = header.get_zooms() - if ( - (abs(xsize - vox_size_head[0]) > 0.001) - or (abs(ysize - vox_size_head[1]) > 0.001) - or (abs(zsize - vox_size_head[2]) > 0.001) - ): + # voxel size in xyz direction from the affine + vox_size_affine = np.sqrt((img.affine[:3, :3] * img.affine[:3, :3]).sum(0)) + + if not np.allclose(vox_size_affine, vox_size_header, atol=1e-3): message = ( f"#############################################################\n" - f"ERROR: Invalid Nifti-header! Affine matrix is inconsistent with Voxel sizes. " - f"\nVoxel size (from header) vs. Voxel size in affine: " - f"{tuple(vox_size_head[:3])}, {(xsize, ysize, zsize)}\n" - f"Input Affine----------------\n{aff}\n" + f"ERROR: Invalid Nifti-header! Affine matrix is inconsistent with " + f"Voxel sizes. \nVoxel size (from header) vs. 
Voxel size in affine:\n" + f"{tuple(vox_size_header[:3])}, {tuple(vox_size_affine)}\n" + f"Input Affine----------------\n{img.affine}\n" f"#############################################################" ) check = False @@ -783,19 +972,26 @@ def check_affine_in_nifti( if __name__ == "__main__": # Command Line options are error checking done here - options = options_parse() + try: + options = options_parse() + except RuntimeError as e: + sys.exit(*e.args) print(f"Reading input: {options.input} ...") image = nib.load(options.input) + if not isinstance(image, nib.analyze.SpatialImage): + sys.exit(f"ERROR: Input image is not a spatial image: {type(image).__name__}") if len(image.shape) > 3 and image.shape[3] != 1: sys.exit(f"ERROR: Multiple input frames ({image.shape[3]}) not supported!") - target_dtype = "uint8" if options.seg_input else options.dtype - opt_kwargs = {} - check_dtype = target_dtype != "any" - if check_dtype: - opt_kwargs["dtype"] = target_dtype + _target_dtype = "uint8" if options.seg_input else options.dtype + crit = DEFAULT_CRITERIA_DICT.items() + opt_kwargs = { + "criteria": set(c for n, c in crit if getattr(options, "force_" + n, True)), + } + if check_dtype := _target_dtype != "any": + opt_kwargs["dtype"] = _target_dtype if hasattr(options, "conform_to_1mm_threshold"): opt_kwargs["conform_to_1mm_threshold"] = options.conform_to_1mm_threshold @@ -816,20 +1012,24 @@ def check_affine_in_nifti( print(f"Input {options.input} is already conformed! Exiting.\n") sys.exit(0) else: - # Note: if check_only, a non-conforming image leads to an error code, this result is needed in recon_surf.sh + # Note: if check_only, a non-conforming image leads to an error code, this + # result is needed in recon_surf.sh if options.check_only: print("check_only flag provided. Exiting without conforming input image.\n") sys.exit(1) # If image is nifti image if options.input[-7:] == ".nii.gz" or options.input[-4:] == ".nii": - - if not check_affine_in_nifti(image): + from nibabel import Nifti1Image, Nifti2Image + if not check_affine_in_nifti(cast(Nifti1Image | Nifti2Image, image)): sys.exit("ERROR: inconsistency in nifti-header. Exiting now.\n") try: new_image = conform( - image, order=options.order, conform_vox_size=_vox_size, dtype=options.dtype + image, + order=options.order, + conform_vox_size=_vox_size, + **opt_kwargs, ) except ValueError as e: sys.exit(e.args[0]) diff --git a/FastSurferCNN/data_loader/data_utils.py b/FastSurferCNN/data_loader/data_utils.py index d1098e10..dd11302b 100644 --- a/FastSurferCNN/data_loader/data_utils.py +++ b/FastSurferCNN/data_loader/data_utils.py @@ -14,7 +14,8 @@ # IMPORTS -from typing import Optional, Tuple, Union, Mapping +from pathlib import Path +from typing import Optional, Tuple, Union, Mapping, cast, Iterable import numpy as np from numpy import typing as npt @@ -39,7 +40,7 @@ ## # Global Vars ## -SUPPORTED_OUTPUT_FILE_FORMATS = ["mgz", "nii", "nii.gz"] +SUPPORTED_OUTPUT_FILE_FORMATS = ("mgz", "nii", "nii.gz") LOGGER = logging.getLogger(__name__) ## @@ -47,47 +48,52 @@ ## -# Conform an MRI brain image to UCHAR, RAS orientation, and 1mm or minimal isotropic voxels +# Conform an MRI brain image to UCHAR, RAS orientation, and 1mm or minimal isotropic +# voxels def load_and_conform_image( - img_filename: str, + img_filename: Path | str, interpol: int = 1, logger: logging.Logger = LOGGER, conform_min: bool = False -) -> Tuple[_Header, np.ndarray, np.ndarray]: - """Load MRI image and conform it to UCHAR, RAS orientation and 1mm or minimum isotropic voxels size. 
+) -> tuple[_Header, np.ndarray, np.ndarray]: + """ + Load MRI image and conform it to UCHAR, RAS orientation and 1mm or minimum isotropic + voxels size. Only, if it does not already have this format. Parameters ---------- - img_filename : str - path and name of volume to read - interpol : int - interpolation order for image conformation (0=nearest,1=linear(default),2=quadratic,3=cubic) - logger : logging.Logger - Logger to write output to (default = STDOUT) - conform_min : bool - conform image to minimal voxel size (for high-res) (Default = False) + img_filename : Path, str + Path and name of volume to read. + interpol : int, default=1 + Interpolation order for image conformation + (0=nearest, 1=linear(default), 2=quadratic, 3=cubic). + logger : logging.Logger, default= + Logger to write output to (default = STDOUT). + conform_min : bool, default=False + Conform image to minimal voxel size (for high-res). Returns ------- nibabel.Header header_info - header information of the conformed image + Header information of the conformed image. numpy.ndarray affine_info - affine information of the conformed image + Affine information of the conformed image. numpy.ndarray orig_data - conformed image data + Conformed image data. Raises ------ RuntimeError - Multiple input frames not supported + Multiple input frames not supported. RuntimeError - Inconsistency in nifti-header - + Inconsistency in nifti-header. """ - orig = nib.load(img_filename) - # is_conform and conform accept numeric values and the string 'min' instead of the bool value + img_file = Path(img_filename) + orig = nib.load(img_file) + # is_conform and conform accept numeric values and the string 'min' instead of the + # bool value _conform_vox_size = "min" if conform_min else 1.0 if not is_conform(orig, conform_vox_size=_conform_vox_size): @@ -101,7 +107,7 @@ def load_and_conform_image( ) # Check affine if image is nifti image - if any(img_filename.endswith(ext) for ext in (".nii.gz", ".nii")): + if img_file.suffix == ".nii" or img_file.suffixes[-2:] == [".nii", ".gz"]: if not check_affine_in_nifti(orig, logger=logger): raise RuntimeError("ERROR: inconsistency in nifti-header. Exiting now.") @@ -117,19 +123,21 @@ def load_and_conform_image( def load_image( - file: str, + file: str | Path, name: str = "image", - **kwargs -) -> Tuple[nib.analyze.SpatialImage, np.ndarray]: - """Load file 'file' with nibabel, including all data. + **kwargs, +) -> tuple[nib.analyze.SpatialImage, np.ndarray]: + """ + Load file 'file' with nibabel, including all data. Parameters ---------- - file : str - path to the file to load. - name : str - name of the file (optional), only effects error messages. (Default value = "image") + file : Path, str + Path to the file to load. + name : str, default="image" + Name of the file (optional), only effects error messages. **kwargs : + Additional keyword arguments. Returns ------- @@ -142,17 +150,16 @@ def load_image( Failed loading the file nibabel releases the GIL, so the following is a parallel example. 
{ - from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor() as pool: - future1 = pool.submit(load_image, filename1) - future2 = pool.submit(load_image, filename2) - image, data = future1.result() - image2, data2 = future2.result() + >>> from concurrent.futures import ThreadPoolExecutor + >>> with ThreadPoolExecutor() as pool: + >>> future1 = pool.submit(load_image, filename1) + >>> future2 = pool.submit(load_image, filename2) + >>> image, data = future1.result() + >>> image2, data2 = future2.result() } - """ try: - img = nib.load(file, **kwargs) + img = cast(nib.analyze.SpatialImage, nib.load(file, **kwargs)) except (IOError, FileNotFoundError) as e: raise IOError( f"Failed loading the {name} '{file}' with error: {e.args[0]}" @@ -162,33 +169,39 @@ def load_image( def load_maybe_conform( - file: str, - alt_file: str, + file: Path | str, + alt_file: Path | str, vox_size: VoxSizeOption = "min" -) -> Tuple[str, nib.analyze.SpatialImage, np.ndarray]: - """Load an image by file, check whether it is conformed to vox_size and conform to vox_size if it is not. +) -> tuple[Path, nib.analyze.SpatialImage, np.ndarray]: + """ + Load an image by file, check whether it is conformed to vox_size and conform to + vox_size if it is not. Parameters ---------- - file : str - path to the file to load. - alt_file : str - alternative file to interpolate from - vox_size : VoxSizeOption - Voxel Size (Default value = "min") + file : Path, str + Path to the file to load. + alt_file : Path, str + Alternative file to interpolate from. + vox_size : VoxSizeOption, default="min" + Voxel Size. Returns ------- - Tuple[str, nib.analyze.SpatialImage, np.ndarray] - [MISSING] - + Path + The path to the file. + nib.analyze.SpatialImage + The file container object including the corrected header. + np.ndarray + The data loaded from the file. """ - from os.path import isfile + file = Path(file) + alt_file = Path(alt_file) _is_conform, img = False, None - if isfile(file): + if file.is_file(): # see if the file is 1mm - img = nib.load(file) + img = cast(nib.analyze.SpatialImage, nib.load(file)) # is_conform only needs the header, not the data _is_conform = is_conform(img, conform_vox_size=vox_size, verbose=False) @@ -201,29 +214,33 @@ def load_maybe_conform( # the image is not conformed to 1mm, do this now. from nibabel.filebasedimages import FileBasedHeader as _Header - fileext = list( - filter(lambda ext: file.endswith("." + ext), SUPPORTED_OUTPUT_FILE_FORMATS) - ) + fileext = [ + ext for ext in SUPPORTED_OUTPUT_FILE_FORMATS + if file.name.endswith("." + ext) + ] if len(fileext) != 1: raise RuntimeError( f"Invalid file extension of conf_name: {file}, must be one of " f"{SUPPORTED_OUTPUT_FILE_FORMATS}." ) - file_no_fileext = file[: -len(fileext[0]) - 1] - vox_suffix = "." + ( - "min" if vox_size == "min" else str(vox_size) + "mm" - ).replace(".", "") + file_no_fileext = str(file)[:-len(fileext[0]) - 1] + if vox_size == "min": + vox_suffix = ".min" + else: + vox_suffix = f".{str(vox_size).replace('.', '')}mm" if not file_no_fileext.endswith(vox_suffix): file_no_fileext += vox_suffix - # if the orig file is neither absolute nor in the subject path, use the conformed file - src_file = alt_file if isfile(alt_file) else file - if not isfile(alt_file): + # if the orig file is neither absolute nor in the subject path, use the + # conformed file + src_file = alt_file if alt_file.is_file() else file + if not alt_file.is_file(): LOGGER.warning( - f"No valid alternative file (e.g. 
orig, here: {alt_file}) was given to interpolate from, so " - f"we might lose quality due to multiple chained interpolations." + f"No valid alternative file (e.g. orig, here: {alt_file}) was given to " + f"interpolate from, so we might lose quality due to multiple chained " + f"interpolations." ) - dst_file = file_no_fileext + "." + fileext[0] + dst_file = Path(file_no_fileext + "." + fileext[0]) # conform to 1mm header, affine, data = load_and_conform_image( src_file, conform_min=False, logger=logging.getLogger(__name__ + ".conform") @@ -238,52 +255,55 @@ def load_maybe_conform( # Save image routine def save_image( header_info: _Header, - affine_info: npt.NDArray, - img_array: npt.NDArray, - save_as: str, + affine_info: npt.NDArray[float], + img_array: np.ndarray, + save_as: str | Path, dtype: Optional[npt.DTypeLike] = None ) -> None: - """Save an image (nibabel MGHImage), according to the desired output file format. + """ + Save an image (nibabel MGHImage), according to the desired output file format. - Supported formats are defined in supported_output_file_formats. Saves predictions to save_as + Supported formats are defined in supported_output_file_formats. Saves predictions to + save_as. Parameters ---------- header_info : _Header - image header information - affine_info : npt.NDArray - image affine information - img_array : npt.NDArray - an array containing image data - save_as : str - name under which to save prediction; this determines output file format - dtype : Optional[npt.DTypeLike] - image array type; if provided, the image object is explicitly set to match this type - (Default value = None) - + Image header information. + affine_info : npt.NDArray[float] + Image affine information. + img_array : np.ndarray + An array containing image data. + save_as : Path, str + Name under which to save prediction; this determines output file format. + dtype : npt.DTypeLike, optional + Image array type; if provided, the image object is explicitly set to match this + type (Default value = None). """ - assert any( - save_as.endswith(file_ext) for file_ext in SUPPORTED_OUTPUT_FILE_FORMATS + save_as = Path(save_as) + assert ( + save_as.suffix[1:] in SUPPORTED_OUTPUT_FILE_FORMATS or + save_as.suffixes[-2:] == [".nii", ".gz"] ), ( - "Output filename does not contain a supported file format (" - + ", ".join(file_ext for file_ext in SUPPORTED_OUTPUT_FILE_FORMATS) - + ")!" + f"Output filename does not contain a supported file format " + f"{SUPPORTED_OUTPUT_FILE_FORMATS}!" 
) mgh_img = None - if save_as.endswith("mgz"): + if save_as.suffix == ".mgz": mgh_img = nib.MGHImage(img_array, affine_info, header_info) - elif any(save_as.endswith(file_ext) for file_ext in ["nii", "nii.gz"]): + elif save_as.suffix == ".nii" or save_as.suffixes[-2:] == [".nii", ".gz"]: mgh_img = nib.nifti1.Nifti1Pair(img_array, affine_info, header_info) if dtype is not None: mgh_img.set_data_dtype(dtype) - if any(save_as.endswith(file_ext) for file_ext in ["mgz", "nii"]): + if save_as.suffix in (".mgz", ".nii"): nib.save(mgh_img, save_as) - elif save_as.endswith("nii.gz"): - # For correct outputs, nii.gz files should be saved using the nifti1 sub-module's save(): - nib.nifti1.save(mgh_img, save_as) + elif save_as.suffixes[-2:] == [".nii", ".gz"]: + # For correct outputs, nii.gz files should be saved using the nifti1 + # sub-module's save(): + nib.nifti1.save(mgh_img, str(save_as)) # Transformation for mapping @@ -291,20 +311,20 @@ def transform_axial( vol: npt.NDArray, coronal2axial: bool = True ) -> np.ndarray: - """Transform volume into Axial axis and back. + """ + Transform volume into Axial axis and back. Parameters ---------- vol : npt.NDArray - image volume to transform + Image volume to transform. coronal2axial : bool - transform from coronal to axial = True (default), + Transform from coronal to axial = True (default). Returns ------- np.ndarray - Transformed image - + Transformed image. """ if coronal2axial: return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) @@ -316,20 +336,20 @@ def transform_sagittal( vol: npt.NDArray, coronal2sagittal: bool = True ) -> np.ndarray: - """Transform volume into Sagittal axis and back. + """ + Transform volume into Sagittal axis and back. Parameters ---------- vol : npt.NDArray - image volume to transform + Image volume to transform. coronal2sagittal : bool - transform from coronal to sagittal = True (default), + Transform from coronal to sagittal = True (default). Returns ------- np.ndarray: - transformed image - + Transformed image. """ if coronal2sagittal: return np.moveaxis(vol, [0, 1, 2], [2, 1, 0]) @@ -342,23 +362,23 @@ def get_thick_slices( img_data: npt.NDArray, slice_thickness: int = 3 ) -> np.ndarray: - """Extract thick slices from the image. + """ + Extract thick slices from the image. Feed slice_thickness preceding and succeeding slices to network, - label only middle one + label only middle one. Parameters ---------- img_data : npt.NDArray - 3D MRI image read in with nibabel + 3D MRI image read in with nibabel. slice_thickness : int - number of slices to stack on top and below slice of interest (default=3) + Number of slices to stack on top and below slice of interest (default=3). Returns ------- np.ndarray - image data with the thick slices of the n-th axis appended into the n+1-th axis. - + Image data with the thick slices of the n-th axis appended into the n+1-th axis. """ img_data_pad = np.pad( img_data, ((0, 0), (0, 0), (slice_thickness, slice_thickness)), mode="edge" @@ -376,28 +396,28 @@ def filter_blank_slices_thick( weight_vol: npt.NDArray, threshold: int = 50 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Filter blank slices from the volume using the label volume. + """ + Filter blank slices from the volume using the label volume. Parameters ---------- img_vol : npt.NDArray - orig image volume + Orig image volume. label_vol : npt.NDArray - label images (ground truth) + Label images (ground truth). weight_vol : npt.NDArray - weight corresponding to labels + Weight corresponding to labels. 
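# --- Editor's illustration (not part of the diff): the plane transforms above
# --- are pure axis permutations via np.moveaxis, so they are cheap and, per
# --- their docstrings, reversible with the coronal2axial/coronal2sagittal flag.
# --- (The filter_blank_slices_thick() parameter list of the diff continues below.)
import numpy as np

vol = np.random.rand(4, 5, 6)        # dummy volume in coronal ordering
axial = transform_axial(vol)         # np.moveaxis(vol, [0, 1, 2], [1, 2, 0])
print(axial.shape)                   # (6, 4, 5)
sagittal = transform_sagittal(vol)   # np.moveaxis(vol, [0, 1, 2], [2, 1, 0])
print(sagittal.shape)                # (6, 5, 4)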
threshold : int - threshold for number of pixels needed to keep slice (below = dropped). (Default value = 50) + Threshold for number of pixels needed to keep slice (below = dropped). (Default value = 50). Returns ------- filtered img_vol : np.ndarray - [MISSING] + Image volume with blank slices removed. label_vol : np.ndarray - [MISSING] + Label volume with blank slices removed. weight_vol : np.ndarray - [MISSING] - + Weight volume with blank slices removed. """ # Get indices of all slices with more than threshold labels/pixels select_slices = np.sum(label_vol, axis=(0, 1)) > threshold @@ -421,32 +441,32 @@ def create_weight_mask( cortex_mask: bool = True, gradient: bool = True ) -> np.ndarray: - """Create weighted mask - with median frequency balancing and edge-weighting. + """ + Create weighted mask - with median frequency balancing and edge-weighting. Parameters ---------- mapped_aseg : np.ndarray - segmentation to create weight mask from. + Segmentation to create weight mask from. max_weight : int - maximal weight on median weights (cap at this value). (Default value = 5) + Maximal weight on median weights (cap at this value). (Default value = 5). max_edge_weight : int - maximal weight on gradient weight (cap at this value). (Default value = 5) + Maximal weight on gradient weight (cap at this value). (Default value = 5). max_hires_weight : int - maximal weight on hires weight (cap at this value). (Default value = None) + Maximal weight on hires weight (cap at this value). (Default value = None). ctx_thresh : int - label value of cortex (above = cortical parcels). (Default value = 33) + Label value of cortex (above = cortical parcels). (Default value = 33). mean_filter : bool - flag, set to add mean_filter mask (default = False). + Flag, set to add mean_filter mask (default = False). cortex_mask : bool - flag, set to create cortex weight mask (default=True). + Flag, set to create cortex weight mask (default=True). gradient : bool - flag, set to create gradient mask (default = True). + Flag, set to create gradient mask (default = True). Returns ------- np.ndarray - Weights - + Weights. """ unique, counts = np.unique(mapped_aseg, return_counts=True) @@ -495,22 +515,22 @@ def cortex_border_mask( structure: npt.NDArray, ctx_thresh: int = 33 ) -> np.ndarray: - """Erode the cortex of a given mri image to create the inner gray matter mask (outer most cortex voxels). + """ + Erode the cortex of a given mri image to create the inner gray matter mask (outer most cortex voxels). Parameters ---------- label : npt.NDArray - ground truth labels. + Ground truth labels. structure : npt.NDArray - structuring element to erode with + Structuring element to erode with. ctx_thresh : int - label value of cortex (above = cortical parcels). Defaults to 33 + Label value of cortex (above = cortical parcels). Defaults to 33. Returns ------- np.ndarray - inner grey matter layer - + Inner grey matter layer. """ # create aseg brainmask, erode it and subtract from itself bm = np.clip(label, a_max=1, a_min=0) @@ -529,24 +549,24 @@ def deep_sulci_and_wm_strand_mask( iteration: int = 1, ctx_thresh: int = 33 ) -> np.ndarray: - """Get a binary mask of deep sulci and small white matter strands by using binary closing (erosion and dilation). + """ + Get a binary mask of deep sulci and small white matter strands by using binary closing (erosion and dilation). Parameters ---------- volume : npt.NDArray - loaded image (aseg, label space) + Loaded image (aseg, label space). structure : npt.NDArray - structuring element (e.g. 
np.ones((3, 3, 3))) + Structuring element (e.g. np.ones((3, 3, 3))). iteration : int - number of times mask should be dilated + eroded. Defaults to 1 + Number of times mask should be dilated + eroded. Defaults to 1. ctx_thresh : int - label value of cortex (above = cortical parcels). Defaults to 33 + Label value of cortex (above = cortical parcels). Defaults to 33. Returns ------- np.ndarray - sulcus + wm mask - + Sulcus + wm mask. """ # Binarize label image (cortex = 1, everything else = 0) empty_im = np.zeros(shape=volume.shape) @@ -564,47 +584,80 @@ def deep_sulci_and_wm_strand_mask( # Label mapping functions (to aparc (eval) and to label (train)) -def read_classes_from_lut(lut_file: str) -> pd.DataFrame: - """Read in FreeSurfer-like LUT table. +def read_classes_from_lut(lut_file: str | Path): + """ + Modify from datautils to allow support for FreeSurfer-distributed ColorLUTs. + + Read in **FreeSurfer-like** LUT table. Parameters ---------- - lut_file : str - path and name of FreeSurfer-style LUT file with classes of interest + lut_file : Path, str + The path and name of FreeSurfer-style LUT file with classes of interest. Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 1 Left-Cerebral-Exterior 70 130 180 0 + ... Returns ------- - pd.Dataframe - DataFrame with ids present, name of ids, color for plotting - + pandas.DataFrame + DataFrame with ids present, name of ids, color for plotting. """ + if not isinstance(lut_file, Path): + lut_file = Path(lut_file) + if lut_file.suffix == ".tsv": + return pd.read_csv(lut_file, sep="\t") + # Read in file - separator = {"tsv": "\t", "csv": ",", "txt": " "} - return pd.read_csv(lut_file, sep=separator[lut_file[-3:]]) + names = { + "ID": "int", + "LabelName": "str", + "Red": "int", + "Green": "int", + "Blue": "int", + "Alpha": "int", + } + kwargs = {} + if lut_file.suffix == ".csv": + kwargs["sep"] = "," + elif lut_file.suffix == ".txt": + kwargs["sep"] = "\\s+" + else: + raise RuntimeError( + f"Unknown LUT file extension {lut_file}, must be csv, txt or tsv." + ) + return pd.read_csv( + lut_file, + index_col=False, + skip_blank_lines=True, + comment="#", + header=None, + names=list(names.keys()), + dtype=names, + **kwargs, + ) def map_label2aparc_aseg( mapped_aseg: torch.Tensor, labels: Union[torch.Tensor, npt.NDArray] ) -> torch.Tensor: - """Perform look-up table mapping from sequential label space to LUT space. + """ + Perform look-up table mapping from sequential label space to LUT space. Parameters ---------- mapped_aseg : torch.Tensor - label space segmentation (aparc.DKTatlas + aseg) + Label space segmentation (aparc.DKTatlas + aseg). labels : Union[torch.Tensor, npt.NDArray] - list of labels defining LUT space + List of labels defining LUT space. Returns ------- torch.Tensor - labels in LUT space - + Labels in LUT space. """ if isinstance(labels, np.ndarray): labels = torch.from_numpy(labels) @@ -613,24 +666,24 @@ def map_label2aparc_aseg( def clean_cortex_labels(aparc: npt.NDArray) -> np.ndarray: - """Clean up aparc segmentations. + """ + Clean up aparc segmentations. Map undetermined and optic chiasma to BKG Map Hypointensity classes to one Vessel to WM 5th Ventricle to CSF - Remaining cortical labels to BKG + Remaining cortical labels to BKG. Parameters ---------- aparc : npt.NDArray - aparc segmentations + Aparc segmentations. Returns ------- np.ndarray - cleaned aparc - + Cleaned aparc. 
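# --- Editor's illustration (not part of the diff): reading a FreeSurfer-style
# --- colour LUT with the new read_classes_from_lut() above; the file path is
# --- hypothetical. Comment lines and blank lines are skipped for .txt LUTs.
# --- (The clean_cortex_labels() hunk of the diff continues below.)
lut = read_classes_from_lut("FreeSurferColorLUT.txt")
print(list(lut.columns))      # ['ID', 'LabelName', 'Red', 'Green', 'Blue', 'Alpha']
ids = lut["ID"].to_numpy()    # integer label ids listed in the table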
""" aparc[aparc == 80] = 77 # Hypointensities Class aparc[aparc == 85] = 0 # Optic Chiasma to BKG @@ -650,22 +703,22 @@ def fill_unknown_labels_per_hemi( unknown_label: int, cortex_stop: int ) -> np.ndarray: - """Replace label 1000 (lh unknown) and 2000 (rh unknown) with closest class for each voxel. + """ + Replace label 1000 (lh unknown) and 2000 (rh unknown) with closest class for each voxel. Parameters ---------- gt : npt.NDArray - ground truth segmentation with class unknown + Ground truth segmentation with class unknown. unknown_label : int - class label for unknown (lh: 1000, rh: 2000) + Class label for unknown (lh: 1000, rh: 2000). cortex_stop : int - class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000) + Class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000). Returns ------- np.ndarray - ground truth segmentation with all classes - + Ground truth segmentation with all classes. """ # Define shape of image and dilation element h, w, d = gt.shape @@ -700,18 +753,18 @@ class label at which cortical labels of this hemi stop (lh: 2000, rh: 3000) def fuse_cortex_labels(aparc: npt.NDArray) -> np.ndarray: - """Fuse cortical parcels on left/right hemisphere (reduce aparc classes). + """ + Fuse cortical parcels on left/right hemisphere (reduce aparc classes). Parameters ---------- aparc : npt.NDArray - anatomical segmentation with cortical parcels + Anatomical segmentation with cortical parcels. Returns ------- np.ndarray - anatomical segmentation with reduced number of cortical parcels - + Anatomical segmentation with reduced number of cortical parcels. """ aparc_temp = aparc.copy() @@ -748,18 +801,18 @@ def fuse_cortex_labels(aparc: npt.NDArray) -> np.ndarray: def split_cortex_labels(aparc: npt.NDArray) -> np.ndarray: - """Splot cortex labels to completely de-lateralize structures. + """ + Splot cortex labels to completely de-lateralize structures. Parameters ---------- aparc : npt.NDArray - anatomical segmentation and parcellation from network + Anatomical segmentation and parcellation from network. Returns ------- np.ndarray - re-lateralized aparc - + Re-lateralized aparc. """ # Post processing - Splitting classes # Quick Fix for 2026 vs 1026; 2029 vs. 1029; 2025 vs. 1025 @@ -840,24 +893,24 @@ def unify_lateralized_labels( lut: Union[str, pd.DataFrame], combi: Tuple[str, str] = ("Left-", "Right-") ) -> Mapping: - """Generate lookup dictionary of left-right labels. + """ + Generate lookup dictionary of left-right labels. Parameters ---------- lut : Union[str, pd.DataFrame] - either lut-file string to load or pandas dataframe + Either lut-file string to load or pandas dataframe Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 - 1 Left-Cerebral-Exterior 70 130 180 0 + 1 Left-Cerebral-Exterior 70 130 180 0. combi : Tuple[str, str] - Prefix or labelnames to combine. Default: Left- and Right- + Prefix or labelnames to combine. Default: Left- and Right-. Returns ------- Mapping - dictionary mapping between left and right hemispheres - + Dictionary mapping between left and right hemispheres. """ if isinstance(lut, str): lut = read_classes_from_lut(lut) @@ -873,7 +926,8 @@ def get_labels_from_lut( lut: Union[str, pd.DataFrame], label_extract: Tuple[str, str] = ("Left-", "ctx-rh") ) -> Tuple[np.ndarray, np.ndarray]: - """Extract labels from the lookup tables. + """ + Extract labels from the lookup tables. 
Parameters ---------- @@ -883,18 +937,17 @@ def get_labels_from_lut( Example entry: ID LabelName R G B A 0 Unknown 0 0 0 0 - 1 Left-Cerebral-Exterior 70 130 180 0 + 1 Left-Cerebral-Exterior 70 130 180 0. label_extract : Tuple[str, str] - suffix of label names to mask for sagittal labels - Default: "Left-" and "ctx-rh" + Suffix of label names to mask for sagittal labels + Default: "Left-" and "ctx-rh". Returns ------- np.ndarray - full label list, + Full label list. np.ndarray - sagittal label list - + Sagittal label list. """ if isinstance(lut, str): lut = read_classes_from_lut(lut) @@ -910,32 +963,32 @@ def map_aparc_aseg2label( aseg_nocc: Optional[npt.NDArray] = None, processing: str = "aparc" ) -> Tuple[np.ndarray, np.ndarray]: - """Perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space. + """ + Perform look-up table mapping of aparc.DKTatlas+aseg.mgz data to label space. Parameters ---------- aseg : npt.NDArray - ground truth aparc+aseg + Ground truth aparc+aseg. labels : npt.NDArray - labels to use (extracted from LUT with get_labels_from_lut) + Labels to use (extracted from LUT with get_labels_from_lut). labels_sag : npt.NDArray - sagittal labels to use (extracted from LUT with - get_labels_from_lut) + Sagittal labels to use (extracted from LUT with + get_labels_from_lut). sagittal_lut_dict : Mapping - left-right label mapping (can be extracted with - unify_lateralized_labels from LUT) + Left-right label mapping (can be extracted with + unify_lateralized_labels from LUT). aseg_nocc : Optional[npt.NDArray] - ground truth aseg without corpus callosum segmentation (Default value = None) + Ground truth aseg without corpus callosum segmentation (Default value = None). processing : str - should be set to "aparc" or "aseg" for additional mappings (hard-coded) (Default value = "aparc") + Should be set to "aparc" or "aseg" for additional mappings (hard-coded) (Default value = "aparc"). Returns ------- np.ndarray - mapped aseg for coronal and axial, + Mapped aseg for coronal and axial. np.ndarray - mapped aseg for sagital - + Mapped aseg for sagital. """ # If corpus callosum is not removed yet, do it now if aseg_nocc is not None: @@ -1001,18 +1054,18 @@ def map_aparc_aseg2label( def sagittal_coronal_remap_lookup(x: int) -> int: - """Convert left labels to corresponding right labels for aseg with dictionary mapping. + """ + Convert left labels to corresponding right labels for aseg with dictionary mapping. Parameters ---------- x : int - label to look up + Label to look up. Returns ------- np.ndarray - mapped label - + Mapped label. """ return { 2: 41, @@ -1037,20 +1090,20 @@ def infer_mapping_from_lut( num_classes_full: int, lut: Union[str, pd.DataFrame] ) -> np.ndarray: - """[MISSING]. + """ + Guess the mapping from a lookup table. Parameters ---------- num_classes_full : int - number of classes + Number of classes. lut : Union[str, pd.DataFrame] - look-up table listing class labels + Look-up table listing class labels. Returns ------- np.ndarray - list of indexes for - + List of indexes for. """ labels, labels_sag = unify_lateralized_labels(lut) idx_list = np.ndarray(shape=(num_classes_full,), dtype=np.int16) @@ -1072,290 +1125,69 @@ def map_prediction_sagittal2full( num_classes: int = 51, lut: Optional[str] = None ) -> np.ndarray: - """Remap the prediction on the sagittal network to full label space used by coronal and axial networks. + """ + Remap the prediction on the sagittal network to full label space used by coronal and axial networks. 
Create full aparc.DKTatlas+aseg.mgz. Parameters ---------- prediction_sag : npt.NDArray - sagittal prediction (labels) + Sagittal prediction (labels). num_classes : int - number of SAGITTAL classes (96 for full classes, 51 for hemi split, 21 for aseg) (Default value = 51) + Number of SAGITTAL classes (96 for full classes, 51 for hemi split, 21 for aseg) (Default value = 51). lut : Optional[str] - look-up table listing class labels (Default value = None) + Look-up table listing class labels (Default value = None). Returns ------- np.ndarray - Remapped prediction - + Remapped prediction. """ + r = range + _idx = [] if num_classes == 96: - idx_list = np.asarray( - [ - 0, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 1, - 2, - 3, - 14, - 15, - 4, - 16, - 17, - 18, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - ], - dtype=np.int16, - ) - + _idx = [[0], r(5, 14), r(1, 4), [14, 15, 4], r(16, 19), r(5, 51), r(20, 51)] elif num_classes == 51: - idx_list = np.asarray( - [ - 0, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 1, - 2, - 3, - 14, - 15, - 4, - 16, - 17, - 18, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39, - 40, - 41, - 42, - 43, - 44, - 45, - 46, - 47, - 48, - 49, - 50, - 20, - 22, - 27, - 29, - 30, - 31, - 33, - 34, - 38, - 39, - 40, - 41, - 42, - 45, - ], - dtype=np.int16, - ) - + _idx = [[0], r(5, 14), r(1, 4), [14, 15, 4], r(16, 19), r(5, 51)] + _idx.extend([[20, 22, 27], r(29, 32), [33, 34], r(38, 43), [45]]) elif num_classes == 21: - idx_list = np.asarray( - [ - 0, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 1, - 2, - 3, - 15, - 16, - 4, - 17, - 18, - 19, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - ], - dtype=np.int16, - ) - + _idx = [[0], r(5, 15), r(1, 4), [15, 16, 4], r(17, 20), r(5, 21)] + if _idx: + from itertools import chain + idx_list = list(chain(*_idx)) else: assert lut is not None, "lut is not defined!" idx_list = infer_mapping_from_lut(num_classes, lut) - prediction_full = prediction_sag[:, idx_list, :, :] - return prediction_full + return prediction_sag[:, idx_list, :, :] # Clean up and class separation def bbox_3d( img: npt.NDArray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Extract the three-dimensional bounding box coordinates. + """ + Extract the three-dimensional bounding box coordinates. Parameters ---------- img : npt.NDArray - mri image + Mri image. Returns ------- np.ndarray - rmin + Rmin. np.ndarray - rmax + Rmax. np.ndarray - cmin + Cmin. np.ndarray - cmax + Cmax. np.ndarray - zmin + Zmin. np.ndarray - zmax - + Zmax. """ r = np.any(img, axis=(1, 2)) c = np.any(img, axis=(0, 2)) @@ -1369,18 +1201,18 @@ def bbox_3d( def get_largest_cc(segmentation: npt.NDArray) -> np.ndarray: - """Find the largest connected component of segmentation. + """ + Find the largest connected component of segmentation. Parameters ---------- segmentation : npt.NDArray - segmentation + Segmentation. 
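# --- Editor's illustration (not part of the diff): bbox_3d() from above on a
# --- small binary mask. Based on its docstring, the six returned values should
# --- be the first and last non-empty indices along each of the three axes.
# --- (The get_largest_cc() hunk of the diff continues below.)
import numpy as np

mask = np.zeros((8, 8, 8), dtype=bool)
mask[2:5, 1:4, 3:7] = True
rmin, rmax, cmin, cmax, zmin, zmax = bbox_3d(mask)
# expected: rows 2..4, columns 1..3, slices 3..6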
Returns ------- np.ndarray - largest connected component of segmentation (binary mask) - + Largest connected component of segmentation (binary mask). """ labels = label(segmentation, connectivity=3, background=0) diff --git a/FastSurferCNN/data_loader/dataset.py b/FastSurferCNN/data_loader/dataset.py index 1434bc7b..5efe6ee3 100644 --- a/FastSurferCNN/data_loader/dataset.py +++ b/FastSurferCNN/data_loader/dataset.py @@ -31,7 +31,9 @@ # Operator to load imaged for inference class MultiScaleOrigDataThickSlices(Dataset): - """Load MRI-Image and process it to correct format for network inference.""" + """ + Load MRI-Image and process it to correct format for network inference. + """ def __init__( self, @@ -40,19 +42,19 @@ def __init__( cfg: yacs.config.CfgNode, transforms: Optional = None ): - """Construct object. + """ + Construct object. Parameters ---------- orig_data : npt.NDArray - Orignal Data + Orignal Data. orig_zoom : npt.NDArray - Original zoomfactors + Original zoomfactors. cfg : yacs.config.CfgNode - Configuration Node + Configuration Node. transforms : Optional - Transformer for the image. Defaults to None - + Transformer for the image. Defaults to None. """ assert ( orig_data.max() > 0.8 @@ -83,7 +85,8 @@ def __init__( self.transforms = transforms def _get_scale_factor(self) -> npt.NDArray[float]: - """Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. + """ + Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. ToDO: This needs to be updated based on the plane we are looking at in case we @@ -92,26 +95,25 @@ def _get_scale_factor(self) -> npt.NDArray[float]: Returns ------- npt.NDArray[float] - scale factor along x and y dimension - + Scale factor along x and y dimension. """ scale = self.base_res / np.asarray(self.zoom) return scale def __getitem__(self, index: int) -> Dict: - """Return a single image and its scale factor. + """ + Return a single image and its scale factor. Parameters ---------- index : int - Index of image to get + Index of image to get. Returns ------- dict - Dictionary of image and scale factor - + Dictionary of image and scale factor. """ img = self.images[index] @@ -122,19 +124,22 @@ def __getitem__(self, index: int) -> Dict: return {"image": img, "scale_factor": scale_factor} def __len__(self) -> int: - """Return length. + """ + Return length. Returns ------- int - count + Count. """ return self.count # Operator to load hdf5-file for training class MultiScaleDataset(Dataset): - """Class for loading aseg file with augmentations (transforms).""" + """ + Class for loading aseg file with augmentations (transforms). + """ def __init__( self, @@ -143,19 +148,19 @@ def __init__( gn_noise: bool = False, transforms: Optional = None ): - """Construct object. + """ + Construct object. Parameters ---------- dataset_path : str - Path to the dataset + Path to the dataset. cfg : yacs.config.CfgNode - Configuration node + Configuration node. gn_noise : bool - Whether to add gaussian noise (Default value = False) + Whether to add gaussian noise (Default value = False). transforms : Optional - Transformer to apply to the image (Default value = None) - + Transformer to apply to the image (Default value = None). """ self.max_size = cfg.DATA.PADDED_SIZE self.base_res = cfg.MODEL.BASE_RES @@ -223,12 +228,13 @@ def __init__( ) def get_subject_names(self): - """Get the subject name. 
+ """ + Get the subject name. Returns ------- list - list of subject names + List of subject names. """ return self.subjects @@ -237,7 +243,8 @@ def _get_scale_factor( img_zoom: torch.Tensor, scale_aug: torch.Tensor ) -> npt.NDArray[float]: - """Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. + """ + Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. @@ -247,15 +254,14 @@ def _get_scale_factor( Parameters ---------- img_zoom : torch.Tensor - Image zoom factor + Image zoom factor. scale_aug : torch.Tensor - [MISSING] + Scale augmentation factor. Returns ------- npt.NDArray[float] - scale factor along x and y dimension - + Scale factor along x and y dimension. """ if torch.all(scale_aug > 0): img_zoom *= 1 / scale_aug @@ -274,18 +280,18 @@ def _pad( self, image: npt.NDArray ) -> np.ndarray: - """Pad the image with zeros. + """ + Pad the image with zeros. Parameters ---------- image : npt.NDArray - Image to pad + Image to pad. Returns ------- padded_image - Padded image - + Padded image. """ if len(image.shape) == 2: h, w = image.shape @@ -308,26 +314,26 @@ def unify_imgs( label: npt.NDArray, weight: npt.NDArray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Pad img, label and weight. + """ + Pad img, label and weight. Parameters ---------- img : npt.NDArray - image to unify + Image to unify. label : npt.NDArray - labels of the image + Labels of the image. weight : npt.NDArray - weights of the image + Weights of the image. Returns ------- np.ndarray - img + Img. np.ndarray - label + Label. np.ndarray - weight - + Weight. """ img = self._pad(img) label = self._pad(label) @@ -336,17 +342,18 @@ def unify_imgs( return img, label, weight def __getitem__(self, index): - """[MISSING]. + """ + Retrieve processed data at the specified index. Parameters ---------- - index : - [MISSING] + index : int + Index to retrieve data for. Returns ------- - [MISSING] - + dict + Dictionary containing torch tensors for image, label, weight, and scale factor. """ padded_img, padded_label, padded_weight = self.unify_imgs( self.images[index], self.labels[index], self.weights[index] @@ -395,14 +402,17 @@ def __getitem__(self, index): } def __len__(self): - """Return count.""" + """ + Return count. + """ return self.count # Operator to load hdf5-file for validation class MultiScaleDatasetVal(Dataset): - """Class for loading aseg file with augmentations (transforms).""" - + """ + Class for loading aseg file with augmentations (transforms). + """ def __init__(self, dataset_path, cfg, transforms=None): self.max_size = cfg.DATA.PADDED_SIZE @@ -469,11 +479,14 @@ def __init__(self, dataset_path, cfg, transforms=None): ) def get_subject_names(self): - """Get subject names.""" + """ + Get subject names. + """ return self.subjects def _get_scale_factor(self, img_zoom): - """Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. + """ + Get scaling factor to match original resolution of input image to final resolution of FastSurfer base network. Input resolution is taken from voxel size in image header. @@ -482,20 +495,21 @@ def _get_scale_factor(self, img_zoom): Parameters ---------- - img_zoom : - zooming factor [MISSING] + img_zoom : np.ndarray + Voxel sizes of the image. 
Returns ------- - np.ndarray : float32 - scale factor along x and y dimension - + np.ndarray : numpy.typing.NDArray[float] + Scale factor along x and y dimension. """ scale = self.base_res / img_zoom return scale def __getitem__(self, index): - """Get item.""" + """ + Get item. + """ img = self.images[index] label = self.labels[index] weight = self.weights[index] @@ -524,5 +538,7 @@ def __getitem__(self, index): } def __len__(self): - """Get count.""" + """ + Get count. + """ return self.count diff --git a/FastSurferCNN/data_loader/loader.py b/FastSurferCNN/data_loader/loader.py index 3a838b82..1ae09d9a 100644 --- a/FastSurferCNN/data_loader/loader.py +++ b/FastSurferCNN/data_loader/loader.py @@ -24,21 +24,20 @@ def get_dataloader(cfg: yacs.config.CfgNode, mode: str): - """Create the dataset and pytorch data loader. + """ + Create the dataset and pytorch data loader. Parameters ---------- cfg : yacs.config.CfgNode - configuration node + Configuration node. mode : str - loading data for train, val and test mode + Loading data for train, val and test mode. Returns ------- torch.utils.data.DataLoader - dataloader with given configs and mode - - + Dataloader with given configs and mode. """ assert mode in ["train", "val"], f"dataloader mode is incorrect {mode}" diff --git a/FastSurferCNN/download_checkpoints.py b/FastSurferCNN/download_checkpoints.py index 5f6ad903..54e5377b 100644 --- a/FastSurferCNN/download_checkpoints.py +++ b/FastSurferCNN/download_checkpoints.py @@ -15,24 +15,46 @@ # limitations under the License. import argparse +from functools import lru_cache +from typing import Optional +from FastSurferCNN.utils import PLANES from FastSurferCNN.utils.checkpoint import ( check_and_download_ckpts, get_checkpoints, - VINN_AXI, - VINN_COR, - VINN_SAG, - URL, + load_checkpoint_config_defaults, + YAML_DEFAULT as VINN_YAML, ) + from CerebNet.utils.checkpoint import ( - CEREBNET_AXI, - CEREBNET_COR, - CEREBNET_SAG, - URL as CEREBNET_URL, + YAML_DEFAULT as CEREBNET_YAML, +) +from HypVINN.utils.checkpoint import ( + YAML_DEFAULT as HYPVINN_YAML, ) -if __name__ == "__main__": +class ConfigCache: + @lru_cache + def vinn_url(self): + return load_checkpoint_config_defaults("url", filename=VINN_YAML) + + @lru_cache + def cerebnet_url(self): + return load_checkpoint_config_defaults("url", filename=CEREBNET_YAML) + + @lru_cache + def hypvinn_url(self): + return load_checkpoint_config_defaults("url", filename=HYPVINN_YAML) + + def all_urls(self): + return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + + +defaults = ConfigCache() + + +def make_arguments(): parser = argparse.ArgumentParser( description="Check and Download Network Checkpoints" ) @@ -54,42 +76,91 @@ action="store_true", help="Check and download CerebNet default checkpoints", ) + + parser.add_argument( + "--hypvinn", + default=False, + action="store_true", + help="Check and download HypVinn default checkpoints", + ) + parser.add_argument( "--url", type=str, default=None, - help="Specify you own base URL. This is applied to all models. \n" - "Default for VINN: {} \n" - "Default for CerebNet: {}".format(URL, CEREBNET_URL), + help=f"Specify you own base URL. This is applied to all models. \n" + f"Default for VINN: {defaults.vinn_url()} \n" + f"Default for CerebNet: {defaults.cerebnet_url()} \n" + f"Default for HypVINN: {defaults.hypvinn_url()}", ) parser.add_argument( "files", nargs="*", - help="Checkpoint file paths to download, e.g. checkpoints/aparc_vinn_axial_v2.0.0.pkl ...", + help="Checkpoint file paths to download, e.g. 
" + "checkpoints/aparc_vinn_axial_v2.0.0.pkl ...", ) - args = parser.parse_args() - - if not args.vinn and not args.files and not args.cerebnet and not args.all: - print( - "Specify either files to download or --vinn, --cerebnet or --all, see help -h." - ) - exit(1) - - # FastSurferVINN checkpoints - if args.vinn or args.all: - get_checkpoints( - VINN_AXI, VINN_COR, VINN_SAG, URL if args.url is None else args.url - ) - - # CerebNet checkpoints - if args.cerebnet or args.all: - get_checkpoints( - CEREBNET_AXI, - CEREBNET_COR, - CEREBNET_SAG, - CEREBNET_URL if args.url is None else args.url, - ) - - # later we can add more defaults here (for other sub-segmentation networks, or old CNN) - for fname in args.files: - check_and_download_ckpts(fname, URL if args.url is None else args.url) + return parser.parse_args() + + +def main( + vinn: bool, + cerebnet: bool, + hypvinn: bool, + all: bool, + files: list[str], + url: Optional[str] = None, +) -> int | str: + if not vinn and not files and not cerebnet and not hypvinn and not all: + return ("Specify either files to download or --vinn, --cerebnet, " + "--hypvinn, or --all, see help -h.") + + try: + # FastSurferVINN checkpoints + if vinn or all: + vinn_config = load_checkpoint_config_defaults( + "checkpoint", + filename=VINN_YAML, + ) + get_checkpoints( + *(vinn_config[plane] for plane in PLANES), + urls=defaults.vinn_url() if url is None else [url] + ) + # CerebNet checkpoints + if cerebnet or all: + cerebnet_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CEREBNET_YAML, + ) + get_checkpoints( + *(cerebnet_config[plane] for plane in PLANES), + urls=defaults.cerebnet_url() if url is None else [url], + ) + # HypVINN checkpoints + if hypvinn or all: + hypvinn_config = load_checkpoint_config_defaults( + "checkpoint", + filename=HYPVINN_YAML, + ) + get_checkpoints( + *(hypvinn_config[plane] for plane in PLANES), + urls=defaults.hypvinn_url() if url is None else [url], + ) + for fname in files: + check_and_download_ckpts( + fname, + urls=defaults.all_urls() if url is None else [url], + ) + except Exception as e: + from traceback import print_exception + print_exception(e) + return e.args[0] + return 0 + + +if __name__ == "__main__": + import sys + from logging import basicConfig, INFO + + basicConfig(stream=sys.stdout, level=INFO) + args = make_arguments() + sys.exit(main(**vars(args))) diff --git a/FastSurferCNN/generate_hdf5.py b/FastSurferCNN/generate_hdf5.py index d14e8c0f..2b13a430 100644 --- a/FastSurferCNN/generate_hdf5.py +++ b/FastSurferCNN/generate_hdf5.py @@ -12,50 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import glob + # IMPORTS import time -import glob -from os.path import join, dirname from collections import defaultdict -from typing import Tuple, Dict +from os.path import dirname, join +from pathlib import Path +from typing import Dict, Tuple -import numpy as np -import nibabel as nib import h5py -from numpy import typing as npt, ndarray +import nibabel as nib +import numpy as np +from numpy import ndarray +from numpy import typing as npt from FastSurferCNN.data_loader.data_utils import ( - transform_axial, - transform_sagittal, - map_aparc_aseg2label, create_weight_mask, - get_thick_slices, filter_blank_slices_thick, - read_classes_from_lut, get_labels_from_lut, + get_thick_slices, + map_aparc_aseg2label, + read_classes_from_lut, + transform_axial, + transform_sagittal, unify_lateralized_labels, ) from FastSurferCNN.utils import logging +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT LOGGER = logging.getLogger(__name__) class H5pyDataset: - """Class representing H5py Dataset. + """ + Class representing H5py Dataset. - Methods - ------- - __init__ - Consturctor - _load_volumes - load image and segmentation volume - transform - Transform image along axis - _pad_image - Pad image with zeroes - create_hdf5_dataset - Create a hdf5 file - Attributes ---------- dataset_name : str @@ -99,35 +91,54 @@ class H5pyDataset: Number of subjects processing : str Use aseg, aparc or no specific mapping processing (Default: "aparc") + + Methods + ------- + __init__ + Consturctor + _load_volumes + load image and segmentation volume + transform + Transform image along axis + _pad_image + Pad image with zeroes + create_hdf5_dataset + Create a hdf5 file """ def __init__(self, params: Dict, processing: str = "aparc"): - """Construct H5pyDataset object. + """ + Construct H5pyDataset object. Parameters ---------- params : Dict - dataset_name (str): path and name of hdf5-data_loader - data_path (str): Directory with images to load - thickness (int): Number of pre- and succeeding slices - image_name (str): Default name of original images - gt_name (str): Default name for ground truth segmentations. - gt_nocc (str): Segmentation without corpus callosum (used to mask this segmentation in ground truth). - If the used segmentation was already processed, do not set this argument." - sizes (int): Sizes of images in the dataset. - max_weight (int): Overall max weight for any voxel in weight mask. - edge_weight (int): Weight for edges in weight mask. - hires_weight (int): Weight for hires elements (sulci, WM strands, cortex border) in weight mask. - gradient (bool): Turn on to only use median weight frequency (no gradient) - gm_mask (bool): Turn on to add cortex mask for hires-processing. - lut (str): FreeSurfer-style Color Lookup Table with labels to use in final prediction. + A dictionary containing the following keys: + - dataset_name (str): Path and name of hdf5-data_loader + - data_path (str): Directory with images to load + - thickness (int): Number of pre- and succeeding slices + - image_name (str): Default name of original images + - gt_name (str): Default name for ground truth segmentations. + - gt_nocc (str): Segmentation without corpus callosum (used to mask this segmentation in ground truth). + If the used segmentation was already processed, do not set this argument. + - sizes (int): Sizes of images in the dataset. + - max_weight (int): Overall max weight for any voxel in the weight mask. + - edge_weight (int): Weight for edges in the weight mask. 
+ - hires_weight (int): Weight for hires elements (sulci, WM strands, cortex border) in the weight mask. + - gradient (bool): Turn on to only use median weight frequency (no gradient) + - gm_mask (bool): Turn on to add cortex mask for hires-processing. + - lut (str): FreeSurfer-style Color Lookup Table with labels to use in the final prediction. Has to have columns: ID LabelName R G B A - sag-mask (tuple[str, str, ...]): Suffixes of labels names to mask for final sagittal labels. - combi (str): Suffixes of labels names to combine. - patter (str): Pattern to match files in directory. - processing : str - Use aseg (Default value = "aparc") + - sag_mask (tuple[str, str]): Suffixes of labels names to mask for final sagittal labels. + - combi (str): Suffixes of labels names to combine. + - pattern (str): Pattern to match files in the directory. + processing : str, optional + Use aseg (Default value = "aparc"). + Returns + ------- + None + This is a constructor function, it returns nothing. """ self.dataset_name = params["dataset_name"] self.data_path = params["data_path"] @@ -146,7 +157,7 @@ def __init__(self, params: Dict, processing: str = "aparc"): self.gm_mask = params["gm_mask"] self.lut = read_classes_from_lut(params["lut"]) - self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag-mask"]) + self.labels, self.labels_sag = get_labels_from_lut(self.lut, params["sag_mask"]) self.lateralization = unify_lateralized_labels(self.lut, params["combi"]) if params["csv_file"] is not None: @@ -159,28 +170,29 @@ def __init__(self, params: Dict, processing: str = "aparc"): self.data_set_size = len(self.subject_dirs) - def _load_volumes(self, subject_path: str - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple]: - """Load the given image and segmentation and gets the zoom values. + def _load_volumes( + self, subject_path: str + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple]: + """ + Load the given image and segmentation and gets the zoom values. Checks if an aseg-nocc file is set and loads it instead Parameters ---------- subject_path : str - path to subjectfile + Path to subject file. Returns ------- ndarray - original image + Original image. ndarray - segmentation ground truth + Segmentation ground truth. ndarray - segmentation ground truth without corpus callosum + Segmentation ground truth without corpus callosum. tuple - zoom values - + Zoom values. """ # Load the orig and extract voxel spacing information (x, y, and z dim) LOGGER.info( @@ -205,26 +217,27 @@ def _load_volumes(self, subject_path: str return orig, aseg, aseg_nocc, zoom - def transform(self, plane: str, imgs: npt.NDArray, zoom: npt.NDArray - ) -> Tuple[npt.NDArray, npt.NDArray]: - """Transform the image and zoom along the given axis. + def transform( + self, plane: str, imgs: npt.NDArray, zoom: npt.NDArray + ) -> Tuple[npt.NDArray, npt.NDArray]: + """ + Transform the image and zoom along the given axis. Parameters ---------- plane : str - plane (sagittal, axial, ) + Plane (sagittal, axial, ). imgs : npt.NDArray - input image + Input image. zoom : npt.NDArray - zoom factors + Zoom factors. Returns ------- npt.NDArray - transformed image, + Transformed image. npt.NDArray - transformed zoom facors - + Transformed zoom facors. """ for i in range(len(imgs)): if self.plane == "sagittal": @@ -238,20 +251,20 @@ def transform(self, plane: str, imgs: npt.NDArray, zoom: npt.NDArray return imgs, zooms def _pad_image(self, img: npt.NDArray, max_out: int) -> np.ndarray: - """Pad the margins of the input image with zeros. 
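For orientation, a hedged sketch of what `_load_volumes` is documented to do: read the volume with nibabel and take the voxel spacing from the header zooms. The file path is a placeholder, not a required FastSurfer layout.

```python
import nibabel as nib
import numpy as np

def load_volume(path: str = "subject/mri/orig.mgz"):
    img = nib.load(path)
    data = np.asarray(img.dataobj)        # image intensities
    zoom = img.header.get_zooms()[:3]     # (x, y, z) voxel sizes in mm
    return data, zoom
```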
+ """ + Pad the margins of the input image with zeros. Parameters ---------- img : npt.NDArray - image array + Image array. max_out : int - size of output image + Size of output image. Returns ------- np.ndarray - 0-padded image to the given size - + 0-padded image to the given size. """ # Get correct size = max along shape h, w, d = img.shape @@ -261,19 +274,18 @@ def _pad_image(self, img: npt.NDArray, max_out: int) -> np.ndarray: return padded_img def create_hdf5_dataset(self, blt: int): - """Create a hdf5 dataset. + """ + Create a hdf5 dataset. Parameters ---------- blt : int - Blank sliec threshold - + Blank slice threshold. """ data_per_size = defaultdict(lambda: defaultdict(list)) start_d = time.time() for idx, current_subject in enumerate(self.subject_dirs): - try: start = time.time() @@ -382,7 +394,7 @@ def create_hdf5_dataset(self, blt: int): ) -if __name__ == "__main__": +def make_parser(): import argparse # Training settings @@ -444,8 +456,8 @@ def create_hdf5_dataset(self, blt: int): ) parser.add_argument( "--lut", - type=str, - default=join(dirname(__file__), "/config/FastSurfer_ColorLUT.tsv"), + type=Path, + default=FASTSURFER_ROOT / "/config/FastSurfer_ColorLUT.tsv", help="FreeSurfer-style Color Lookup Table with labels to use in final prediction. " "Has to have columns: ID LabelName R G B A" "Default: FASTSURFERDIR/FastSurferCNN/config/FastSurfer_ColorLUT.tsv.", @@ -512,9 +524,10 @@ def create_hdf5_dataset(self, blt: int): default=256, help="Sizes of images in the dataset. Default: 256", ) + return parser - args = parser.parse_args() +def main(args): dataset_params = { "dataset_name": args.hdf5_name, "data_path": args.data_dir, @@ -528,9 +541,9 @@ def create_hdf5_dataset(self, blt: int): "max_weight": args.max_w, "edge_weight": args.edge_w, "plane": args.plane, - "lut": args.lut, + "lut": str(args.lut), "combi": args.combi, - "sag-mask": args.sag_mask, + "sag_mask": args.sag_mask, "hires_weight": args.hires_w, "gm_mask": args.gm, "gradient": not args.no_grad, @@ -538,3 +551,9 @@ def create_hdf5_dataset(self, blt: int): dataset_generator = H5pyDataset(params=dataset_params, processing=args.processing) dataset_generator.create_hdf5_dataset(args.blank_slice_thresh) + + +if __name__ == "__main__": + parser = make_parser() + args = parser.parse_args() + main(args) diff --git a/FastSurferCNN/inference.py b/FastSurferCNN/inference.py index 481c0528..f99fea0a 100644 --- a/FastSurferCNN/inference.py +++ b/FastSurferCNN/inference.py @@ -12,46 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + # IMPORTS import time -from typing import Optional, Dict, Tuple, Union -import os +from typing import Dict, Optional, Tuple, Union import numpy as np -from numpy import typing as npt import torch import yacs.config +from numpy import typing as npt from pandas import DataFrame from torch.utils.data import DataLoader from torchvision import transforms -from FastSurferCNN.utils import logging -from FastSurferCNN.models.networks import build_model from FastSurferCNN.data_loader.augmentation import ToTensorTest from FastSurferCNN.data_loader.data_utils import map_prediction_sagittal2full from FastSurferCNN.data_loader.dataset import MultiScaleOrigDataThickSlices - +from FastSurferCNN.models.networks import build_model +from FastSurferCNN.utils import logging logger = logging.getLogger(__name__) class Inference: """Model evaluation class to run inference using FastSurferCNN. 
- - Methods - ------- - setup_model - Set up the initial model - set_cfg - Set configuration node - to - Moves and/or casts the parameters and buffers. - load_checkpoint - Load the checkpoint - eval - Evaluate predictions - run - Run the loaded model Attributes ---------- @@ -70,9 +55,24 @@ class Inference: model_name : str Name of the model alpha : Dict[str, float] - [MISSING] + Alpha values for different planes. post_prediction_mapping_hook - [MISSING] + Hook for post prediction mapping. + + Methods + ------- + setup_model + Set up the initial model + set_cfg + Set configuration node + to + Moves and/or casts the parameters and buffers. + load_checkpoint + Load the checkpoint + eval + Evaluate predictions + run + Run the loaded model """ permute_order: Dict[str, Tuple[int, int, int, int]] @@ -80,25 +80,25 @@ class Inference: default_device: torch.device def __init__( - self, - cfg: yacs.config.CfgNode, - device: torch.device, - ckpt: str = "", - lut: Union[None, str, np.ndarray, DataFrame] = None + self, + cfg: yacs.config.CfgNode, + device: torch.device, + ckpt: str = "", + lut: Union[None, str, np.ndarray, DataFrame] = None, ): - """Construct Inference object. + """ + Construct Inference object. Parameters ---------- cfg : yacs.config.CfgNode - configuration Node + Configuration Node. device : torch.device - device specification for distributed computation usage. + Device specification for distributed computation usage. ckpt : str - string or os.PathLike object containing the name to the checkpoint file (Default value = "") - lut : Union[None, str, np.ndarray, DataFrame] - [MISSING] (Default value = None) - + String or os.PathLike object containing the name to the checkpoint file (Default value = ""). + lut : str, np.ndarray, DataFrame, optional + Lookup table for mapping (Default value = None). """ # Set random seed from configs. np.random.seed(cfg.RNG_SEED) @@ -138,15 +138,15 @@ def __init__( self.load_checkpoint(ckpt) def setup_model(self, cfg=None, device: torch.device = None): - """Set up the model. + """ + Set up the model. Parameters ---------- cfg : yacs.config.CfgNode - configuration Node (Default value = None) + Configuration Node (Default value = None). device : torch.device - device specification for distributed computation usage. (Default value = None) - + Device specification for distributed computation usage. (Default value = None). """ if cfg is not None: self.cfg = cfg @@ -161,24 +161,24 @@ def setup_model(self, cfg=None, device: torch.device = None): self.device = None def set_cfg(self, cfg: yacs.config.CfgNode): - """[MISSING]. + """ + Set the configuration node. Parameters ---------- cfg : yacs.config.CfgNode - Configuration node - + Configuration node. """ self.cfg = cfg def to(self, device: Optional[torch.device] = None): - """Move and/or cast the parameters and buffers. + """ + Move and/or cast the parameters and buffers. Parameters ---------- device : Optional[torch.device] - the desired device of the parameters and buffers in this module (Default value = None) - + The desired device of the parameters and buffers in this module (Default value = None). """ if self.model_parallel: raise RuntimeError( @@ -189,13 +189,13 @@ def to(self, device: Optional[torch.device] = None): self.model.to(device=_device) def load_checkpoint(self, ckpt: Union[str, os.PathLike]): - """Load the checkpoint and set device and model. + """ + Load the checkpoint and set device and model. 
Parameters ---------- ckpt : Union[str, os.PathLike] - string or os.PathLike object containing the name to the checkpoint file - + String or os.PathLike object containing the name to the checkpoint file. """ logger.info("Loading checkpoint {}".format(ckpt)) @@ -213,7 +213,9 @@ def load_checkpoint(self, ckpt: Union[str, os.PathLike]): # make sure the model is, where it is supposed to be self.model.to(self.device) - model_state = torch.load(ckpt, map_location=device) + # WARNING: weights_only=False can cause unsafe code execution, but here the + # checkpoint can be considered to be from a safe source + model_state = torch.load(ckpt, map_location=device, weights_only=False) self.model.load_state_dict(model_state["model_state"]) # workaround for mps (move the model back to mps) @@ -223,70 +225,123 @@ def load_checkpoint(self, ckpt: Union[str, os.PathLike]): if self.model_parallel: self.model = torch.nn.DataParallel(self.model) - def get_modelname(self): - """Return the model name.""" + def get_modelname(self) -> str: + """ + Return the model name. + + Returns + ------- + str + The name of the model. + """ return self.model_name - def get_cfg(self): - """Return the configurations.""" + def get_cfg(self) -> yacs.config.CfgNode: + """ + Return the configurations. + + Returns + ------- + yacs.config.CfgNode + Configuration node. + """ return self.cfg - def get_num_classes(self): - """Return the number of classes.""" + def get_num_classes(self) -> int: + """ + Return the number of classes. + + Returns + ------- + int + The number of classes. + """ return self.cfg.MODEL.NUM_CLASSES - def get_plane(self): - """Return the plane.""" + def get_plane(self) -> str: + """ + Return the plane. + + Returns + ------- + str + The plane used in the model. + """ return self.cfg.DATA.PLANE - def get_model_height(self): - """Return the model height.""" + def get_model_height(self) -> int: + """ + Return the model height. + + Returns + ------- + int + The height of the model. + """ return self.cfg.MODEL.HEIGHT - def get_model_width(self): - """Return the model width.""" + def get_model_width(self) -> int: + """ + Return the model width. + + Returns + ------- + int + The width of the model. + """ return self.cfg.MODEL.WIDTH - def get_max_size(self): - """Return the max size.""" + def get_max_size(self) -> int | tuple[int, int]: + """ + Return the max size. + + Returns + ------- + int | tuple[int, int] + The maximum size, either a single value or a tuple (width, height). + """ if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT: return self.cfg.MODEL.OUT_TENSOR_WIDTH else: return self.cfg.MODEL.OUT_TENSOR_WIDTH, self.cfg.MODEL.OUT_TENSOR_HEIGHT - def get_device(self): - """Return the device.""" + def get_device(self) -> torch.device: + """ + Return the device. + + Returns + ------- + torch.device + The device used for computation. + """ return self.device @torch.no_grad() def eval( - self, - init_pred: torch.Tensor, - val_loader: DataLoader, - *, - out_scale: Optional = None, - out: Optional[torch.Tensor] = None + self, + init_pred: torch.Tensor, + val_loader: DataLoader, + *, + out_scale: Optional = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Perform prediction and inplace-aggregate views into pred_prob. Parameters ---------- init_pred : torch.Tensor - initial prediction + Initial prediction. val_loader : DataLoader - value loader - * : - + Validation loader. out_scale : Optional - [MISSING] (Default value = None) + Output scale (Default value = None). 
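A small sketch of the checkpoint-loading step shown above, under the same caveat the new comment makes: `weights_only=False` deserializes arbitrary pickled objects and should only be used for checkpoints from a trusted source. Model, path and state-dict key are placeholders here.

```python
import torch

def load_weights(model: torch.nn.Module, ckpt_path: str, device: str = "cpu") -> torch.nn.Module:
    # Load the pickled checkpoint onto the requested device, then apply the weights.
    state = torch.load(ckpt_path, map_location=torch.device(device), weights_only=False)
    model.load_state_dict(state["model_state"])
    return model.to(device)
```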
out : Optional[torch.Tensor] - previous prediction tensor (Default value = None) + Previous prediction tensor (Default value = None). Returns ------- torch.Tensor - prediction probability tensor - + Prediction probability tensor. """ self.model.eval() # we should check here, whether the DataLoader is a Random or a SequentialSampler, but we cannot easily. @@ -303,8 +358,8 @@ def eval( ii = [slice(None) for _ in range(4)] pred_ii = tuple(slice(i) for i in target_shape[:3]) - from tqdm.contrib.logging import logging_redirect_tqdm from tqdm import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm if out is None: out = init_pred.detach().clone() @@ -313,7 +368,6 @@ def eval( for batch_idx, batch in tqdm( enumerate(val_loader), total=len(val_loader), unit="batch" ): - # move data to the model device images, scale_factors = batch["image"].to(self.device), batch[ "scale_factor" @@ -359,39 +413,39 @@ def eval( @torch.no_grad() def run( - self, - init_pred: torch.Tensor, - img_filename: str, - orig_data: npt.NDArray, - orig_zoom: npt.NDArray, - out: Optional[torch.Tensor] = None, - out_res: Optional[int] = None, - batch_size: int = None + self, + init_pred: torch.Tensor, + img_filename: str, + orig_data: npt.NDArray, + orig_zoom: npt.NDArray, + out: Optional[torch.Tensor] = None, + out_res: Optional[int] = None, + batch_size: int = None, ) -> torch.Tensor: - """Run the loaded model on the data (T1) from orig_data and img_filename (for messages only) with scale factors orig_zoom. + """ + Run the loaded model on the data (T1) from orig_data and img_filename (for messages only) with scale factors orig_zoom. Parameters ---------- init_pred : torch.Tensor - initial prediction + Initial prediction. img_filename : str - original image filename + Original image filename. orig_data : npt.NDArray - original image data + Original image data. orig_zoom : npt.NDArray - original zoom + Original zoom. out : Optional[torch.Tensor] - updated output tensor (Default = None) + Updated output tensor (Default = None). out_res : Optional[int] - output resolution (Default value = None) + Output resolution (Default value = None). batch_size : int - batch size (Default = None) + Batch size (Default = None). Returns ------- toch.Tensor - prediction probability tensor - + Prediction probability tensor. """ # Set up DataLoader test_dataset = MultiScaleOrigDataThickSlices( diff --git a/FastSurferCNN/models/interpolation_layer.py b/FastSurferCNN/models/interpolation_layer.py index f20a4fd1..ccafd4da 100644 --- a/FastSurferCNN/models/interpolation_layer.py +++ b/FastSurferCNN/models/interpolation_layer.py @@ -18,12 +18,11 @@ import numpy as np import torch -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import functional as _F from FastSurferCNN.utils.logging import getLogger as _getLogger - LOGGER = _getLogger(__name__) T_Scale = _T.TypeVar("T_Scale", _T.List[float], Tensor) @@ -31,46 +30,47 @@ class _ZoomNd(nn.Module): - """Abstract Class to perform a crop and interpolation on a (N+2)-dimensional Tensor respecting batch and channel. + """ + Abstract Class to perform a crop and interpolation on a (N+2)-dimensional Tensor respecting batch and channel. Attributes ---------- _mode - interpolation mode as in `torch.nn.interpolate` (default: 'neareast') + (Protected) Interpolation mode as in `torch.nn.interpolate` (default: 'neareast'). _target_shape - Target tensor size for after this module, + (Protected) Target tensor size for after this module, not including batchsize and channels. 
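To make the "inplace-aggregate views" idea concrete, here is a hedged sketch (not the actual `eval` loop) of how per-plane class probabilities can be accumulated into one prediction tensor before taking the argmax.

```python
import torch

@torch.no_grad()
def aggregate_views(per_plane_logits: list[torch.Tensor]) -> torch.Tensor:
    # per_plane_logits: one [C, H, W, D] logit volume per view (e.g. axial/coronal/sagittal)
    pred_prob = torch.zeros_like(per_plane_logits[0])
    for logits in per_plane_logits:
        pred_prob += torch.softmax(logits, dim=0)   # accumulate class probabilities
    return pred_prob.argmax(dim=0)                  # final label map [H, W, D]
```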
_N - Number of dimensions + (Protected) Internal number of dimensions. Methods ------- forward - forward propagation + Forward propagation. _fix_scale_factors - Checking and fixing the conformity of scale_factors + (Protected) Checking and fixing the conformity of scale_factors. _interpolate - abstract method - -calculate_crop_pad - Return start- and end- coordinate + (Protected) Abstract method. + _calculate_crop_pad + (Protected) Return start- and end- coordinate. """ def __init__( - self, - target_shape: _T.Optional[_T.Sequence[int]], - interpolation_mode: str = "nearest" + self, + target_shape: _T.Optional[_T.Sequence[int]], + interpolation_mode: str = "nearest", ): - """Construct Zoom object. + """ + Construct Zoom object. Parameters ---------- target_shape : _T.Optional[_T.Sequence[int]] Target tensor size for after this module, not including batchsize and channels. - interpolation_mode : str - interpolation mode as in `torch.nn.interpolate` - (default: 'neareast') - + interpolation_mode : str, default="nearest" + Interpolation mode as in `torch.nn.interpolate` + (default: 'neareast'). """ super(_ZoomNd, self).__init__() self._mode = interpolation_mode @@ -81,12 +81,16 @@ def __init__( @property def target_shape(self) -> _T.Tuple[int, ...]: - """Return the target shape.""" + """ + Return the target shape. + """ return self._target_shape @target_shape.setter def target_shape(self, target_shape: _T.Optional[_T.Sequence[int]]) -> None: - """Validate and set the target_shape.""" + """ + Validate and set the target_shape. + """ tup_target_shape = ( tuple(target_shape) if isinstance(target_shape, _T.Iterable) else tuple() ) @@ -107,16 +111,14 @@ def target_shape(self, target_shape: _T.Optional[_T.Sequence[int]]) -> None: ) def forward( - self, - input_tensor: Tensor, - scale_factors: T_ScaleAll, - rescale: bool = False + self, input_tensor: Tensor, scale_factors: T_ScaleAll, rescale: bool = False ) -> _T.Tuple[Tensor, _T.List[T_Scale]]: - """Zoom the `input_tensor` with `scale_factors`. + """ + Zoom the `input_tensor` with `scale_factors`. This is not an exact zoom, but rather an "approximate zoom". This is due to the fact that the backbone function only interpolates between integer-sized images and therefore - the target shape must be rounded to the nearest integer + the target shape must be rounded to the nearest integer. Parameters ---------- @@ -130,8 +132,8 @@ def forward( image: The first dimension corresponds to and must be equal to the batch size of the image. The second dimension is optional and may contain different values for the _scale_limits factor per axis. In consequence, this dimension can have 1 or {dim} values. - rescale : bool - (Default value = False) + rescale : bool, default="False" + (Default value = False). Returns ------- @@ -139,10 +141,9 @@ def forward( The zoomed tensor and the zoom factors that were actually used in the calculation for correct rescaling. Notes - ------- + ----- If this Module is used to zoom images of different voxelsizes to the same voxelsize, then `scale_factor` should be equal to `target_voxelsize / source_voxelsize`. - """ if self._N == -1: raise RuntimeError( @@ -186,29 +187,27 @@ def forward( return torch.cat(interp, dim=0), scales_out def _fix_scale_factors( - self, - scale_factors: T_ScaleAll, - batch_size: int + self, scale_factors: T_ScaleAll, batch_size: int ) -> _T.Iterable[_T.Tuple[T_Scale, int]]: - """Check and fix the conformity of scale_factors. + """ + Check and fix the conformity of scale_factors. 
Parameters ---------- scale_factors : T_ScaleAll - scale factors to fix dimensions + Scale factors to fix dimensions. batch_size : int - number of batches + Number of batches. Yields ------ - _T.Iterable[_T.Tuple[T_Scale, int]] - The next fixed scale factor + tuple[T_Scale, int] + The next fixed scale factor. Raises ------ ValueError - scale_factors is neither a _T.Iterable nor a Number - + Scale_factors is neither a _T.Iterable nor a Number. """ if isinstance(scale_factors, (Tensor, np.ndarray)): batch_size_sf = scale_factors.shape[0] @@ -269,39 +268,41 @@ def _fix_scale_factors( ) def _interpolate(self, *args) -> _T.Tuple[Tensor, T_Scale]: - """Abstract method. + """ + Abstract method. Parameters ---------- args - placeholder - + Placeholder """ raise NotImplementedError def _calculate_crop_pad( - self, - in_shape: _T.Sequence[int], - scale_factor: T_Scale, - dim: int, alignment: str + self, + in_shape: _T.Sequence[int], + scale_factor: T_Scale, + dim: int, + alignment: str, ) -> _T.Tuple[slice, T_Scale, _T.Tuple[int, int], int]: - """Return start- and end- coordinate given sizes, the updated scale factor [MISSING]. + """ + Return start- and end- coordinate given sizes, the updated scale factor. Parameters ---------- in_shape : _T.Sequence[int] - [MISSING] + Input shape. scale_factor : T_Scale - [MISSING] + Scale factor. dim : int - dimension to be cropped + Dimension to be cropped. alignment : str - [MISSING] + Alignment of the cropping. Returns ------- _T.Tuple[slice,T_Scale,_T.Tuple[int,int],int] - slice(start, end), new scale_factor, padding, interp_target_shape + Slice(start, end), new scale_factor, padding, interp_target_shape. """ this_in_shape = in_shape[dim + 2] source_size = self._target_shape[dim] * scale_factor[dim] @@ -366,43 +367,61 @@ def _calculate_crop_pad( class Zoom2d(_ZoomNd): - """Perform a crop and interpolation on a Four-dimensional Tensor respecting batch and channel. - - Attributes - ---------- - _N - Number of dimensions (Here 2) - _crop_position - Position to crop + """ + Perform a crop and interpolation on a Four-dimensional Tensor respecting batch and channel. Methods ------- _interpolate - Crops, interpolates and pads the tensor + (Protected) Crops, interpolates and pads the tensor. """ + _crop_position: str + def __init__( - self, - target_shape: _T.Optional[_T.Sequence[int]], - interpolation_mode: str = "nearest", - crop_position: str = "top_left" + self, + target_shape: _T.Optional[_T.Sequence[int]], + interpolation_mode: str = "nearest", + crop_position: str = "top_left", ): - """Construct Zoom2d object. + """ + Construct Zoom2d object. Parameters ---------- target_shape : _T.Optional[_T.Sequence[int]] Target tensor size for after this module, not including batchsize and channels. - interpolation_mode : str - interpolation mode as in `torch.nn.interpolate` (default: 'nearest') - crop_position : str - crop position to use from 'top_left', 'bottom_left', top_right', 'bottom_right', - 'center' (default: 'top_left') - + interpolation_mode : str, default="nearest" + Interpolation mode as in `torch.nn.interpolate` (default: 'nearest') + crop_position : str, default="top_left" + Crop position to use from 'top_left', 'bottom_left', top_right', 'bottom_right', + 'center' (default: 'top_left'). 
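A minimal sketch of the "approximate zoom" these layers describe: because `torch.nn.functional.interpolate` only produces integer-sized outputs, the requested scale is realised by rounding the target size and reporting the scale that was actually applied. The sizes below are illustrative.

```python
import torch
import torch.nn.functional as F

def approx_zoom(x: torch.Tensor, scale: float, mode: str = "bilinear"):
    # x: [N, C, H, W]
    h, w = x.shape[-2:]
    out_h, out_w = max(1, round(h * scale)), max(1, round(w * scale))
    y = F.interpolate(x, size=(out_h, out_w), mode=mode, align_corners=False)
    return y, (out_h / h, out_w / w)  # the scale factors actually used

zoomed, used_scale = approx_zoom(torch.rand(1, 1, 111, 111), scale=0.8)
```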
""" if interpolation_mode not in ["nearest", "bilinear", "bicubic", "area"]: raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}") + self._N = 2 + super(Zoom2d, self).__init__(target_shape, interpolation_mode) + self.crop_position = crop_position + + @property + def crop_position(self) -> str: + """ + Property associated with the position of the image in the data. + """ + return self._crop_position + + @crop_position.setter + def crop_position(self, crop_position: str) -> None: + """ + Set the crop position. + + Parameters + ---------- + crop_position : str + The crop position key from 'top_left', 'bottom_left', top_right', + 'bottom_right', 'center'. + """ if crop_position not in [ "top_left", "bottom_left", @@ -411,33 +430,30 @@ def __init__( "center", ]: raise ValueError(f"invalid crop_position, got {crop_position}") - - self._N = 2 - super(Zoom2d, self).__init__(target_shape, interpolation_mode) - self._crop_position = crop_position - + self._crop_position = crop_position + def _interpolate( - self, - data: Tensor, - scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[float]] + self, + data: Tensor, + scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[float]], ) -> _T.Tuple[Tensor, T_Scale]: - """Crop, interpolate and pad the tensor according to the scale_factor. + """ + Crop, interpolate and pad the tensor according to the scale_factor. Scale_factor must be 2-length sequence. Parameters ---------- data : Tensor - input, to-be-interpolated tensor + Input, to-be-interpolated tensor. scale_factor : _T.Union[Tensor, np.ndarray, _T.Sequence[float]] - zoom factor - Returns: the interpolated tensor + Zoom factor + Returns: The interpolated tensor. Returns ------- _T.Tuple[Tensor, T_Scale] - The interpolated tensor and its scaling factor - + The interpolated tensor and its scaling factor. """ scale_factor = ( scale_factor.tolist() @@ -490,33 +506,58 @@ def _interpolate( class Zoom3d(_ZoomNd): - """Perform a crop and interpolation on a Five-dimensional Tensor respecting batch and channel.""" + """ + Perform a crop and interpolation on a Five-dimensional Tensor respecting batch and channel. + """ def __init__( - self, - target_shape: _T.Optional[_T.Sequence[int]], - interpolation_mode: str = "nearest", - crop_position: str = "front_top_left" + self, + target_shape: _T.Optional[_T.Sequence[int]], + interpolation_mode: str = "nearest", + crop_position: str = "front_top_left", ): - """Construct Zoom3d object. + """ + Construct Zoom3d object. Parameters ---------- target_shape : _T.Optional[_T.Sequence[int]] Target tensor size for after this module, not including batchsize and channels. - interpolation_mode : str - interpolation mode as in `torch.nn.interpolate` - (default: 'neareast') - crop_position : str - crop position to use from 'front_top_left', 'back_top_left', + interpolation_mode : str, default="nearest" + Interpolation mode as in `torch.nn.interpolate`, + (default: 'neareast'). + crop_position : str, default="front_top_left" + Crop position to use from 'front_top_left', 'back_top_left', 'front_bottom_left', 'back_bottom_left', 'front_top_right', 'back_top_right', - 'front_bottom_right', 'back_bottom_right', 'center' (default: 'front_top_left') - + 'front_bottom_right', 'back_bottom_right', 'center' (default: 'front_top_left'). 
""" if interpolation_mode not in ["nearest", "trilinear", "area"]: raise ValueError(f"invalid interpolation_mode, got {interpolation_mode}") + self._N = 3 + super(Zoom3d, self).__init__(target_shape, interpolation_mode) + self.crop_position = crop_position + + @property + def crop_position(self) -> str: + """ + Property associated with the position of the image in the data. + """ + return self._crop_position + + @crop_position.setter + def crop_position(self, crop_position: str) -> None: + """ + Set the crop position. + + Parameters + ---------- + crop_position : str + Crop position to use from 'front_top_left', 'back_top_left', + 'front_bottom_left', 'back_bottom_left', 'front_top_right', 'back_top_right', + 'front_bottom_right', 'back_bottom_right', 'center' (default: 'front_top_left'). + """ if crop_position not in [ "front_top_left", "back_top_left", @@ -529,31 +570,27 @@ def __init__( "center", ]: raise ValueError(f"invalid crop_position, got {crop_position}") - - self._N = 3 - super(Zoom3d, self).__init__(target_shape, interpolation_mode) self._crop_position = crop_position - + def _interpolate( - self, - data: Tensor, - scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[int]] + self, data: Tensor, scale_factor: _T.Union[Tensor, np.ndarray, _T.Sequence[int]] ): - """Crop, interpolate and pad the tensor according to the scale_factor. + """ + Crop, interpolate and pad the tensor according to the scale_factor. - scale_factor must be 3-length sequence. + Scale_factor must be 3-length sequence. Parameters ---------- data : Tensor - input, to-be-interpolated tensor + Input, to-be-interpolated tensor. scale_factor : _T.Sequence[int] - zoom factor + Zoom factor. Returns ------- _T.Tuple[Tensor, T_Scale] - The interpolated tensor and its scaling factor + The interpolated tensor and its scaling factor. """ scale_factor = ( scale_factor.tolist() diff --git a/FastSurferCNN/models/losses.py b/FastSurferCNN/models/losses.py index db507b3e..7b25820d 100644 --- a/FastSurferCNN/models/losses.py +++ b/FastSurferCNN/models/losses.py @@ -13,50 +13,52 @@ # limitations under the License. + # IMPORTS import torch import yacs.config -from torch import nn, Tensor -from torch.nn.modules.loss import _Loss + +from torch import Tensor, nn from torch.nn import functional as F -from typing import Optional, Union, Tuple +from torch.nn.modules.loss import _Loss from numbers import Real - +from typing import Optional, Tuple, Union class DiceLoss(_Loss): - """Calculate Dice Loss. - + """ + Calculate Dice Loss. + Methods ------- forward - Calulate the DiceLoss + Calulate the DiceLoss. """ def forward( - self, - output: Tensor, - target: Tensor, - weights: Optional[int] = None, - ignore_index: Optional[int] = None - ) -> float: - """Calulate the DiceLoss. + self, + output: Tensor, + target: Tensor, + weights: Optional[int] = None, + ignore_index: Optional[int] = None, + ) -> torch.Tensor: + """ + Calulate the DiceLoss. Parameters ---------- output : Tensor - N x C x H x W Variable + N x C x H x W Variable. target : Tensor - N x C x W LongTensor with starting class at 0 - weights : Optional[int] - C FloatTensor with class wise weights(Default value = None) - ignore_index : Optional[int] - ignore label with index x in the loss calculation (Default value = None) + N x C x W LongTensor with starting class at 0. + weights : int, optional + C FloatTensor with class wise weights(Default value = None). + ignore_index : int, optional + Ignore label with index x in the loss calculation (Default value = None). 
Returns ------- - float - Calculated Diceloss - + torch.Tensor + Calculated Diceloss. """ eps = 0.001 @@ -92,29 +94,30 @@ def forward( class CrossEntropy2D(nn.Module): - """2D Cross-entropy loss implemented as negative log likelihood. + """ + 2D Cross-entropy loss implemented as negative log likelihood. Attributes ---------- nll_loss - calculated cross-entropy loss + Calculated cross-entropy loss. Methods ------- forward - returns calculated cross entropy + Returns calculated cross entropy. """ - def __init__(self, weight: Optional[Tensor] =None, reduction: str = "none"): - """Construct CrossEntropy2D object. + def __init__(self, weight: Optional[Tensor] = None, reduction: str = "none"): + """ + Construct CrossEntropy2D object. Parameters ---------- - weight : Optional[Tensor] - a manual rescaling weight given to each class. If given, has to be a Tensor of size `C`. Defaults to None + weight : Tensor, optional + A manual rescaling weight given to each class. If given, has to be a Tensor of size `C`. Defaults to None. reduction : str - Specifies the reduction to apply to the output, as in nn.CrossEntropyLoss. Defaults to 'None' - + Specifies the reduction to apply to the output, as in nn.CrossEntropyLoss. Defaults to 'None'. """ super(CrossEntropy2D, self).__init__() self.nll_loss = nn.CrossEntropyLoss(weight=weight, reduction=reduction) @@ -123,35 +126,38 @@ def __init__(self, weight: Optional[Tensor] =None, reduction: str = "none"): ) def forward(self, inputs, targets): - """Feedforward.""" + """ + Feedforward. + """ return self.nll_loss(inputs, targets) class CombinedLoss(nn.Module): - """For CrossEntropy the input has to be a long tensor. + """ + For CrossEntropy the input has to be a long tensor. Attributes ---------- cross_entropy_loss - Results of cross entropy loss + Results of cross entropy loss. dice_loss - Results of dice loss + Results of dice loss. weight_dice - Weight for dice loss + Weight for dice loss. weight_ce - Weight for float + Weight for float. """ def __init__(self, weight_dice: Real = 1, weight_ce: Real = 1): - """Construct CobinedLoss object. + """ + Construct CobinedLoss object. Parameters ---------- weight_dice : Real - Weight for dice loss. Defaults to 1 + Weight for dice loss. Defaults to 1. weight_ce : Real - Weight for cross entropy loss. Defaults to 1 - + Weight for cross entropy loss. Defaults to 1. """ super(CombinedLoss, self).__init__() self.cross_entropy_loss = CrossEntropy2D() @@ -160,31 +166,28 @@ def __init__(self, weight_dice: Real = 1, weight_ce: Real = 1): self.weight_ce = weight_ce def forward( - self, - inputx: Tensor, - target: Tensor, - weight: Tensor + self, inputx: Tensor, target: Tensor, weight: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: - """[MISSING]. + """ + Calculate the total loss, dice loss and cross entropy value for the given input. Parameters ---------- inputx : Tensor - A Tensor of shape N x C x H x W containing the input x values + A Tensor of shape N x C x H x W containing the input x values. target : Tensor - A Tensor of shape N x H x W of integers containing the target + A Tensor of shape N x H x W of integers containing the target. weight : Tensor - A Tensor of shape N x H x W of floats containg the weights + A Tensor of shape N x H x W of floats containg the weights. Returns ------- Tensor - Total loss + Total loss. Tensor - Dice loss + Dice loss. Tensor - Cross entropy value - + Cross entropy value. 
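For readers of the loss docstrings above, a hedged sketch of the combined objective (soft Dice plus pixel-weighted cross-entropy); it mirrors the documented behaviour but is not the verbatim FastSurferCNN implementation.

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, target, weight_map, weight_dice=1.0, weight_ce=1.0):
    # logits: [N, C, H, W]; target: [N, H, W] (long); weight_map: [N, H, W] (float)
    ce = (F.cross_entropy(logits, target, reduction="none") * weight_map).mean()
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
    intersect = (probs * one_hot).sum(dim=(2, 3))
    denom = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3)) + 1e-3
    dice = 1.0 - (2.0 * intersect / denom).mean()
    return weight_dice * dice + weight_ce * ce
```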
""" # Typecast to long tensor --> labels are bytes initially (uint8), # index operations require LongTensor in pytorch @@ -206,30 +209,30 @@ def forward( def get_loss_func( - cfg: yacs.config.CfgNode + cfg: yacs.config.CfgNode, ) -> Union[CombinedLoss, CrossEntropy2D, DiceLoss]: - """Give a default object of the loss function. + """ + Give a default object of the loss function. Parameters ---------- cfg : yacs.config.CfgNode - configuration node, containing searched loss function. - The model loss function can either be 'combined', 'ce' or 'dice' + Configuration node, containing searched loss function. + The model loss function can either be 'combined', 'ce' or 'dice'. Returns ------- CombinedLoss - Total loss + Total loss. CrossEntropy2D - Cross entropy value + Cross entropy value. DiceLoss - Dice loss + Dice loss. Raises ------ NotImplementedError - Requested loss function is not implemented - + Requested loss function is not implemented. """ if cfg.MODEL.LOSS_FUNC == "combined": return CombinedLoss() diff --git a/FastSurferCNN/models/networks.py b/FastSurferCNN/models/networks.py index ac33e95d..0c83b79a 100644 --- a/FastSurferCNN/models/networks.py +++ b/FastSurferCNN/models/networks.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - # IMPORTS -from torch import nn, Tensor +from typing import Optional, TYPE_CHECKING + import numpy as np -from typing import Optional, Union, Dict -import yacs +from torch import Tensor, nn +if TYPE_CHECKING: + import yacs.config -import FastSurferCNN.models.sub_module as sm import FastSurferCNN.models.interpolation_layer as il +import FastSurferCNN.models.sub_module as sm class FastSurferCNNBase(nn.Module): - """Network Definition of Fully Competitive Network network. + """ + Network Definition of Fully Competitive Network network. * Spatial view aggregation (input 7 slices of which only middle one gets segmented) * Same Number of filters per layer (normally 64) @@ -37,11 +39,11 @@ class FastSurferCNNBase(nn.Module): Attributes ---------- encode1, encode2, encode3, encode4 - Competitive Encoder Blocks + Competitive Encoder Blocks. decode1, decode2, decode3, decode4 - Competitive Decoder Blocks + Competitive Decoder Blocks. bottleneck - Bottleneck Block + Bottleneck Block. Methods ------- @@ -49,16 +51,17 @@ class FastSurferCNNBase(nn.Module): Feedforward through graph. """ - def __init__(self, params: Dict, padded_size: int = 256): - """Construct FastSurferCNNBase object. + def __init__(self, params: dict, padded_size: int = 256): + """ + Construct FastSurferCNNBase object. Parameters ---------- params : Dict + Parameters in dictionary format - padded_size : int - size of image when padded (Default value = 256) - + padded_size : int, default = 256 + Size of image when padded (Default value = 256). """ super(FastSurferCNNBase, self).__init__() @@ -89,27 +92,27 @@ def __init__(self, params: Dict, padded_size: int = 256): nn.init.constant_(m.bias, 0) def forward( - self, - x: Tensor, - scale_factor: Optional[Tensor] = None, - scale_factor_out: Optional[Tensor] =None + self, + x: Tensor, + scale_factor: Optional[Tensor] = None, + scale_factor_out: Optional[Tensor] = None, ) -> Tensor: - """Feedforward through graph. + """ + Feedforward through graph. 
- Parameters [MISSING] + Parameters ---------- x : Tensor - input image [N, C, H, W] - scale_factor : Optional[Tensor] - [N, 1] Defaults to None (Default value = None) - scale_factor_out : Optional[Tensor] - (Default value = None) + Input image [N, C, H, W] representing the input data. + scale_factor : Tensor, optional + [N, 1] Defaults to None. + scale_factor_out : Tensor, optional + Tensor representing the scale factor for the output. Defaults to None. Returns ------- decoder_output1 : Tensor - prediction logits - + Prediction logits. """ encoder_output1, skip_encoder_1, indices_1 = self.encode1.forward(x) encoder_output2, skip_encoder_2, indices_2 = self.encode2.forward( @@ -139,30 +142,30 @@ def forward( class FastSurferCNN(FastSurferCNNBase): - """Main Fastsurfer CNN Network. + """ + Main Fastsurfer CNN Network. Attributes ---------- classifier - Initialized Classification Block + Initialized Classification Block. Methods ------- forward - Feedforward through graph - + Feedforward through graph. """ - def __init__(self, params: Dict, padded_size: int): - """Construct FastSurferCNN object. + def __init__(self, params: dict, padded_size: int): + """ + Construct FastSurferCNN object. Parameters ---------- params : Dict - dictionary of configurations + Dictionary of configurations. padded_size : int - size of image when padded - + Size of image when padded. """ super(FastSurferCNN, self).__init__(params) params["num_channels"] = params["num_filters"] @@ -179,27 +182,27 @@ def __init__(self, params: Dict, padded_size: int): nn.init.constant_(m.bias, 0) def forward( - self, - x: Tensor, - scale_factor: Optional[Tensor] = None, - scale_factor_out: Optional[Tensor] = None + self, + x: Tensor, + scale_factor: Optional[Tensor] = None, + scale_factor_out: Optional[Tensor] = None, ) -> Tensor: - """Feedforward through graph. + """ + Feedforward through graph. Parameters ---------- x : Tensor - input image [N, C, H, W] - scale_factor : Optional[Tensor] - [N, 1] Defaults to None - scale_factor_out : Optional[Tensor] - Defaults to None + Input image [N, C, H, W]. + scale_factor : Tensor, optional + [N, 1] Defaults to None. + scale_factor_out : Tensor, optional + Tensor representing the scale factor for the output. Defaults to None. Returns ------- output : Tensor - Prediction logits - + Prediction logits. """ net_out = super().forward(x, scale_factor) output = self.classifier.forward(net_out) @@ -208,7 +211,8 @@ def forward( class FastSurferVINN(FastSurferCNNBase): - """Network Definition of Fully Competitive Network. + """ + Network Definition of Fully Competitive Network. * Spatial view aggregation (input 7 slices of which only middle one gets segmented) * Same Number of filters per layer (normally 64) @@ -221,26 +225,25 @@ class FastSurferVINN(FastSurferCNNBase): Attributes ---------- height - the height of segmentation model (after interpolation layer) + The height of segmentation model (after interpolation layer). width - the width of segmentation model + The width of segmentation model. out_tensor_shape - Out tensor dimensions for interpolation layer + Out tensor dimensions for interpolation layer. interpolation_mode - Interpolation mode for up/downsampling in flex networks + Interpolation mode for up/downsampling in flex networks. crop_position - Crop positions for up/downsampling in flex networks + Crop positions for up/downsampling in flex networks. inp_block - Initialized input dense block + Initialized input dense block. 
outp_block - Initialized output dense block + Initialized output dense block. interpol1 - Initialized 2d input interpolation block + Initialized 2d input interpolation block. interpol2 - Initialized 2d output interpolation block + Initialized 2d output interpolation block. classifier - Initialized Classification Block - + Initialized Classification Block. Methods ------- @@ -248,16 +251,16 @@ class FastSurferVINN(FastSurferCNNBase): Feedforward through graph. """ - def __init__(self, params: Dict, padded_size: int = 256): - """Construct FastSurferVINN object. + def __init__(self, params: dict, padded_size: int = 256): + """ + Construct FastSurferVINN object. Parameters ---------- params : Dict - dictionary of configurations - padded_size : int - size of image when padded (Default value = 256) - + Dictionary of configurations. + padded_size : int, default = 256 + Size of image when padded (Default value = 256). """ num_c = params["num_channels"] params["num_channels"] = params["num_filters_interpol"] @@ -326,27 +329,25 @@ def __init__(self, params: Dict, padded_size: int = 256): nn.init.constant_(m.bias, 0) def forward( - self, - x: Tensor, - scale_factor: Tensor, - scale_factor_out: Optional[Tensor] = None + self, x: Tensor, scale_factor: Tensor, scale_factor_out: Optional[Tensor] = None ) -> Tensor: - """Feedforward through graph. + """ + Feedforward through graph. Parameters ---------- x : Tensor - input image [N, C, H, W] + Input image [N, C, H, W]. scale_factor : Tensor - [MISSING] [N, 1] - scale_factor_out : Tensor, Optional - [MISSING] Defaults to None + Tensor of shape [N, 1] representing the scale factor for each image in the + batch. + scale_factor_out : Tensor, optional + Tensor representing the scale factor for the output. Defaults to None. Returns ------- logits : Tensor - prediction logits - + Prediction logits. """ # Input block + Flex to 1 mm skip_encoder_0 = self.inp_block(x) @@ -387,22 +388,22 @@ def forward( } -def build_model(cfg: yacs.config.CfgNode) -> Union[FastSurferCNN, FastSurferVINN]: - """Build requested model. +def build_model(cfg: 'yacs.config.CfgNode') -> FastSurferCNN | FastSurferVINN: + """ + Build requested model. Parameters ---------- cfg : yacs.config.CfgNode - Node of configs to be used + Node of configs to be used. Returns ------- model - Object of the initialized model - + Object of the initialized model. """ assert ( - cfg.MODEL.MODEL_NAME in _MODELS.keys() + cfg.MODEL.MODEL_NAME in _MODELS.keys() ), f"Model {cfg.MODEL.MODEL_NAME} not supported" params = {k.lower(): v for k, v in dict(cfg.MODEL).items()} model = _MODELS[cfg.MODEL.MODEL_NAME](params, padded_size=cfg.DATA.PADDED_SIZE) diff --git a/FastSurferCNN/models/optimizer.py b/FastSurferCNN/models/optimizer.py index d390f1c3..0afe8c17 100644 --- a/FastSurferCNN/models/optimizer.py +++ b/FastSurferCNN/models/optimizer.py @@ -13,32 +13,37 @@ # limitations under the License. # IMPORTS -import torch from typing import Union + +import torch import yacs -from networks import FastSurferCNN, FastSurferVINN +from FastSurferCNN.models.networks import FastSurferCNN, FastSurferVINN -def get_optimizer(model: Union[FastSurferCNN, FastSurferVINN, torch.nn.DataParallel], cfg: yacs.config.CfgNode) -> torch.optim.optimizer.Optimizer: - """Get an instance of requested optimizer. + +def get_optimizer( + model: FastSurferCNN | FastSurferVINN | torch.nn.DataParallel, + cfg: yacs.config.CfgNode, +) -> torch.optim.Optimizer: + """ + Get an instance of requested optimizer. 
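An illustrative dispatch for the optimizer selection described here, assuming a plain string/learning-rate interface rather than the actual yacs configuration node:

```python
import torch

def get_optimizer_sketch(model: torch.nn.Module, method: str = "adamw", lr: float = 1e-3):
    if method == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if method == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr)
    if method == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=lr)
    raise NotImplementedError(f"optimizer '{method}' is not supported")
```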
Parameters ---------- - model : Union[FastSurferCNN, FastSurferVINN, torch.nn.DataParallel] - The model for which an optimizer schould be chosen + model : FastSurferCNN, FastSurferVINN, torch.nn.DataParallel + The model for which an optimizer should be chosen. cfg : yacs.config.CfgNode - Configuration Node + Configuration Node. Returns ------- torch.optim.optimizer.Optimizer - SGD, Adam, AdamW or rmsprop optimizer + SGD, Adam, AdamW or rmsprop optimizer. Raises ------ NotImplementedError - Optimizer is not supported - + Optimizer is not supported. """ if cfg.OPTIMIZER.OPTIMIZING_METHOD == "sgd": return torch.optim.SGD( diff --git a/FastSurferCNN/models/sub_module.py b/FastSurferCNN/models/sub_module.py index 2924dc01..67c687d9 100644 --- a/FastSurferCNN/models/sub_module.py +++ b/FastSurferCNN/models/sub_module.py @@ -11,27 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Dict +from typing import Dict, Tuple # IMPORTS import torch -from torch import nn, Tensor +from torch import Tensor, nn # Building Blocks class InputDenseBlock(nn.Module): - """Input Dense Block. + """ + Input Dense Block. Attributes ---------- conv[0-3] - Convolution layers + Convolution layers. bn0 - Batch Normalization + Batch Normalization. gn[1-4] - Batch Normalizations + Batch Normalizations. prelu - Learnable ReLU Parameter + Learnable ReLU Parameter. Methods ------- @@ -40,12 +41,13 @@ class InputDenseBlock(nn.Module): """ def __init__(self, params: Dict): - """Construct InputDenseBlock object. + """ + Construct InputDenseBlock object. Parameters ---------- params : Dict - + Parameters in dictionary format. """ super(InputDenseBlock, self).__init__() # Padding to get output tensor of same dimensions @@ -106,18 +108,18 @@ def __init__(self, params: Dict): self.prelu = nn.PReLU() # Learnable ReLU Parameter def forward(self, x: Tensor) -> Tensor: - """Feedforward through graph. + """ + Feedforward through graph. Parameters ---------- x : Tensor - input image [N, C, H, W] + Input image [N, C, H, W] representing the input data. Returns ------- out : Tensor - [MISSING] - + Output image (processed feature map). """ # Input batch normalization x0_bn = self.bn0(x) @@ -152,7 +154,8 @@ def forward(self, x: Tensor) -> Tensor: class CompetitiveDenseBlock(nn.Module): - """Define a competitive dense block comprising 3 convolutional layers, with BN/ReLU. + """ + Define a competitive dense block comprising 3 convolutional layers, with BN/ReLU. Attributes ---------- @@ -171,19 +174,19 @@ class CompetitiveDenseBlock(nn.Module): Methods ------- forward - Feedforward through graph + Feedforward through graph. """ def __init__(self, params: Dict, outblock: bool = False): - """Construct CompetitiveDenseBlock object. + """ + Construct CompetitiveDenseBlock object. Parameters ---------- params : Dict - dictionary with parameters specifying block architecture + Dictionary with parameters specifying block architecture. outblock : bool - Flag indicating if last block (Default value = False) - + Flag indicating if last block (Default value = False). """ super(CompetitiveDenseBlock, self).__init__() @@ -245,21 +248,21 @@ def __init__(self, params: Dict, outblock: bool = False): self.outblock = outblock def forward(self, x: Tensor) -> Tensor: - """Feedforward through CompetitiveDenseBlock. + """ + Feedforward through CompetitiveDenseBlock. {in (Conv - BN from prev. 
block) -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} x 2 -> {Conv -> BN} -> out - end with batch-normed output to allow maxout across skip-connections + end with batch-normed output to allow maxout across skip-connections. Parameters ---------- x : Tensor - input tensor (image or feature map) + Input tensor (image or feature map). Returns ------- out - output tensor (processed feature map) - + Output tensor (processed feature map). """ # Activation from pooled input x0 = self.prelu(x) @@ -298,8 +301,9 @@ def forward(self, x: Tensor) -> Tensor: class CompetitiveDenseBlockInput(nn.Module): - """Define a competitive dense block comprising 3 convolutional layers, with BN/ReLU for input. - + """ + Define a competitive dense block comprising 3 convolutional layers, with BN/ReLU for input. + Attributes ---------- params (dict): {'num_channels': 1, @@ -315,17 +319,17 @@ class CompetitiveDenseBlockInput(nn.Module): Methods ------- forward - Feedforward through graph + Feedforward through graph. """ def __init__(self, params: Dict): - """Construct CompetitiveDenseBlockInput object. + """ + Construct CompetitiveDenseBlockInput object. Parameters ---------- params : Dict - dictionary with parameters specifying block architecture - + Dictionary with parameters specifying block architecture. """ super(CompetitiveDenseBlockInput, self).__init__() @@ -381,20 +385,20 @@ def __init__(self, params: Dict): self.prelu = nn.PReLU() # Learnable ReLU Parameter def forward(self, x: Tensor) -> Tensor: - """Feed forward trough CompetitiveDenseBlockInput. + """ + Feed forward trough CompetitiveDenseBlockInput. in -> BN -> {Conv -> BN -> PReLU} -> {Conv -> BN -> Maxout -> PReLU} -> {Conv -> BN} -> out Parameters ---------- x : Tensor - input tensor (image or feature map) + Input tensor (image or feature map). Returns ------- out - output tensor (processed feature map) - + Output tensor (processed feature map). """ # Input batch normalization x0_bn = self.bn0(x) @@ -428,24 +432,25 @@ def forward(self, x: Tensor) -> Tensor: class GaussianNoise(nn.Module): - """Define a Gaussian Noise Block. - + """ + Define a Gaussian Noise Block. + Methods ------- forward - Feedforward through graph + Feedforward through graph. """ def __init__(self, sigma: float = 0.1, device: str = "cuda"): - """Construct GaussianNoise object. + """ + Construct GaussianNoise object. Parameters ---------- - sigma : float - [MISSING] (Default value = 0.1) - device : str - [MISSING] (Default value = "cuda") - + sigma : float, default=0.1 + Standard deviation of the GaussianNoise (Default value = 0.1). + device : str, default="cuda" + Device to run the model on (Default value = "cuda"). """ super().__init__() self.sigma = sigma @@ -453,18 +458,18 @@ def __init__(self, sigma: float = 0.1, device: str = "cuda"): self.register_buffer("noise", torch.tensor(0)) def forward(self, x: Tensor) -> Tensor: - """Feedforward through graph. + """ + Feedforward through graph. Parameters ---------- x : Tensor - Input Tensor + Input Tensor. Returns ------- x : Tensor - output tensor (processed feature map) - + Output tensor (processed feature map). """ if self.training and self.sigma != 0: scale = self.sigma * x.detach() @@ -477,27 +482,28 @@ def forward(self, x: Tensor) -> Tensor: # Encoder/Decoder definitions ## class CompetitiveEncoderBlock(CompetitiveDenseBlock): - """Encoder Block = CompetitiveDenseBlock + Max Pooling. + """ + Encoder Block = CompetitiveDenseBlock + Max Pooling. Attributes ---------- maxpool - Maxpool layer + Maxpool layer. 
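Editorial note: the encoder keeps the max-pooling indices precisely so that the matching decoder block can later unpool feature values back into their original spatial positions. A minimal, self-contained sketch of that pattern (toy shapes, not the actual block configuration):

    import torch
    from torch import nn

    pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)   # encoder side
    unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)                     # decoder side
    x = torch.randn(1, 4, 8, 8)                                          # toy feature map
    pooled, indices = pool(x)            # indices travel alongside the skip connection
    restored = unpool(pooled, indices)   # values return to their pre-pooling positions
    assert restored.shape == x.shape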
Methods ------- forward - Feed forward trough graph + Feed forward trough graph. """ def __init__(self, params: Dict): - """Construct CompetitiveEncoderBlock object. + """ + Construct CompetitiveEncoderBlock object. Parameters ---------- params : Dict Parameters like number of channels, stride etc. - """ super(CompetitiveEncoderBlock, self).__init__(params) self.maxpool = nn.MaxPool2d( @@ -507,7 +513,8 @@ def __init__(self, params: Dict): ) # For Unpooling later on with the indices def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Feed forward trough Encoder Block. + """ + Feed forward trough Encoder Block. * CompetitiveDenseBlock * Max Pooling (+ retain indices) @@ -515,17 +522,16 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: Parameters ---------- x : Tensor - feature map from previous block + Feature map from previous block. Returns ------- out_encoder : Tensor - original feature map + Original feature map. out_block : Tensor - maxpooled feature map + Maxpooled feature map. indicies : Tensor - maxpool indices - + Maxpool indices. """ out_block = super(CompetitiveEncoderBlock, self).forward( x @@ -537,16 +543,18 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: class CompetitiveEncoderBlockInput(CompetitiveDenseBlockInput): - """Encoder Block = CompetitiveDenseBlockInput + Max Pooling.""" + """ + Encoder Block = CompetitiveDenseBlockInput + Max Pooling. + """ def __init__(self, params: Dict): - """Construct CompetitiveEncoderBlockInput object. + """ + Construct CompetitiveEncoderBlockInput object. Parameters ---------- params : Dict - parameters like number of channels, stride etc. - + Parameters like number of channels, stride etc. """ super(CompetitiveEncoderBlockInput, self).__init__( params @@ -558,7 +566,8 @@ def __init__(self, params: Dict): ) # For Unpooling later on with the indices def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """Feed forward trough Encoder Block. + """ + Feed forward trough Encoder Block. * CompetitiveDenseBlockInput * Max Pooling (+ retain indices) @@ -566,12 +575,16 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: Parameters ---------- x : Tensor - feature map from previous block + Feature map from previous block. Returns ------- - original feature map, maxpooled feature map, maxpool indices - + Tensor + The original feature map as received by the block. + Tensor + The maxpooled feature map after applying max pooling to the original feature map. + Tensor + The indices of the maxpool operation. """ out_block = super(CompetitiveEncoderBlockInput, self).forward( x @@ -583,19 +596,21 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: class CompetitiveDecoderBlock(CompetitiveDenseBlock): - """Decoder Block = (Unpooling + Skip Connection) --> Dense Block.""" + """ + Decoder Block = (Unpooling + Skip Connection) --> Dense Block. + """ def __init__(self, params: Dict, outblock: bool = False): - """Construct CompetitiveDecoderBlock object. + """ + Construct CompetitiveDecoderBlock object. Parameters ---------- params : Dict - parameters like number of channels, stride etc. + Parameters like number of channels, stride etc. 
outblock : bool Flag, indicating if last block of network before classifier is created.(Default value = False) - """ super(CompetitiveDecoderBlock, self).__init__(params, outblock=outblock) self.unpool = nn.MaxUnpool2d( @@ -603,7 +618,8 @@ def __init__(self, params: Dict, outblock: bool = False): ) def forward(self, x: Tensor, out_block: Tensor, indices: Tensor) -> Tensor: - """Feed forward trough Decoder block. + """ + Feed forward trough Decoder block. * Unpooling of feature maps from lower block * Maxout combination of unpooled map + skip connection @@ -612,17 +628,16 @@ def forward(self, x: Tensor, out_block: Tensor, indices: Tensor) -> Tensor: Parameters ---------- x : Tensor - input feature map from lower block (gets unpooled and maxed with out_block) + Input feature map from lower block (gets unpooled and maxed with out_block). out_block : Tensor - skip connection feature map from the corresponding Encoder + Skip connection feature map from the corresponding Encoder. indices : Tensor - indices for unpooling from the corresponding Encoder (maxpool op) + Indices for unpooling from the corresponding Encoder (maxpool op). Returns ------- out_block - processed feature maps - + Processed feature maps. """ unpool = self.unpool(x, indices) concat_max = torch.maximum(unpool, out_block) @@ -632,16 +647,17 @@ def forward(self, x: Tensor, out_block: Tensor, indices: Tensor) -> Tensor: class OutputDenseBlock(nn.Module): - """Dense Output Block = (Upinterpolated + Skip Connection) --> Semi Competitive Dense Block. + """ + Dense Output Block = (Upinterpolated + Skip Connection) --> Semi Competitive Dense Block. Attributes ---------- conv0, conv1, conv2, conv3 - Convolution layers + Convolution layers. gn0, gn1, gn2, gn3, gn4 - Normalization layers + Normalization layers. prelu - PReLU activation layer + PReLU activation layer. Methods ------- @@ -650,13 +666,13 @@ class OutputDenseBlock(nn.Module): """ def __init__(self, params: dict): - """Construct OutputDenseBlock object. + """ + Construct OutputDenseBlock object. Parameters ---------- params : dict - parameters like number of channels, stride etc. - + Parameters like number of channels, stride etc. """ super(OutputDenseBlock, self).__init__() @@ -711,7 +727,8 @@ def __init__(self, params: dict): self.prelu = nn.PReLU() # Learnable ReLU Parameter def forward(self, x: Tensor, out_block: Tensor) -> Tensor: - """Feed forward trough Output block. + """ + Feed forward trough Output block. * Maxout combination of unpooled map from previous block + skip connection * Forwarding toward CompetitiveDenseBlock @@ -719,15 +736,14 @@ def forward(self, x: Tensor, out_block: Tensor) -> Tensor: Parameters ---------- x : Tensor - up-interpolated input feature map from lower block (gets maxed with out_block) + Up-interpolated input feature map from lower block (gets maxed with out_block). out_block : Tensor - skip connection feature map from the corresponding Encoder + Skip connection feature map from the corresponding Encoder. Returns ------- out - processed feature maps - + Processed feature maps. """ # Concatenation along channel (different number of channels from decoder and skip connection) concat = torch.cat((x, out_block), dim=1) @@ -764,16 +780,18 @@ def forward(self, x: Tensor, out_block: Tensor) -> Tensor: class ClassifierBlock(nn.Module): - """Classification Block.""" + """ + Classification Block. + """ def __init__(self, params: dict): - """Construct ClassifierBlock object. + """ + Construct ClassifierBlock object. 
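Editorial note: CompetitiveDecoderBlock above fuses the unpooled map with the encoder skip connection by an element-wise maximum (maxout) rather than by concatenation, so the channel count stays unchanged; OutputDenseBlock concatenates instead because its two inputs carry different channel counts. A minimal sketch of the maxout fusion with toy tensors:

    import torch

    unpooled = torch.randn(1, 64, 32, 32)   # output of MaxUnpool2d in the decoder
    skip = torch.randn(1, 64, 32, 32)       # out_block from the matching encoder
    fused = torch.maximum(unpooled, skip)   # competitive selection, shape unchanged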
Parameters ---------- params : dict - parameters like number of channels, stride etc. - + Parameters like number of channels, stride etc. """ super(ClassifierBlock, self).__init__() self.conv = nn.Conv2d( @@ -784,18 +802,18 @@ def __init__(self, params: dict): ) # To generate logits def forward(self, x: Tensor) -> Tensor: - """Feed forward trough classifier. + """ + Feed forward trough classifier. Parameters ---------- x : Tensor - Output of last CompetitiveDenseDecoder Block- + Output of last CompetitiveDenseDecoder Block. Returns ------- logits - prediction logits - + Prediction logits. """ logits = self.conv(x) diff --git a/FastSurferCNN/mri_brainvol_stats.py b/FastSurferCNN/mri_brainvol_stats.py new file mode 100644 index 00000000..e90b3081 --- /dev/null +++ b/FastSurferCNN/mri_brainvol_stats.py @@ -0,0 +1,145 @@ +#!/bin/python + +# Copyright 2024 Image Analysis Lab, German Center for Neurodegenerative Diseases +# (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import argparse +from os import environ as env +from pathlib import Path + +from FastSurferCNN.segstats import HelpFormatter, main, VERSION +from FastSurferCNN.mri_segstats import print_and_exit + +DEFAULT_MEASURES_STRINGS = [ + (False, "BrainSeg"), + (False, "BrainSegNotVent"), + (False, "SupraTentorial"), + (False, "SupraTentorialNotVent"), + (False, "SubCortGray"), + (False, "lhCortex"), + (False, "rhCortex"), + (False, "Cortex"), + (False, "TotalGray"), + (False, "lhCerebralWhiteMatter"), + (False, "rhCerebralWhiteMatter"), + (False, "CerebralWhiteMatter"), + (False, "Mask"), + (False, "SupraTentorialNotVentVox"), + (False, "BrainSegNotVentSurf"), + (False, "VentricleChoroidVol"), +] +DEFAULT_MEASURES = list((False, m) for m in DEFAULT_MEASURES_STRINGS) + +USAGE = "python mri_brainvol_stats.py -s " +HELPTEXT = f""" +Dependencies: + + Python 3.10 + + Numpy + http://www.numpy.org + + Nibabel to read images + http://nipy.org/nibabel/ + + Pandas to read/write stats files etc. + https://pandas.pydata.org/ + +Original Author: David Kügler +Date: Jan-23-2024 + +Revision: {VERSION} +""" +DESCRIPTION = """ +Translates mri_brainvol_stats options for segstats.py. Options not listed here have no +equivalent representation in segstats.py. """ + + +def make_arguments() -> argparse.ArgumentParser: + """Make the argument parser.""" + parser = argparse.ArgumentParser( + usage=USAGE, + epilog=HELPTEXT.replace("\n", "
"), + description=DESCRIPTION, + formatter_class=HelpFormatter, + ) + parser.add_argument( + "--print", + action="append_const", + dest="parse_actions", + default=[], + const=print_and_exit, + help="Print the equivalent native segstats.py options and exit.", + ) + default_sd = Path(env["SUBJECTS_DIR"]) if "SUBJECTS_DIR" in env else None + parser.add_argument( + "--sd", + dest="out_dir", metavar="subjects_dir", type=Path, + default=default_sd, + required=not bool(default_sd), + help="set SUBJECTS_DIR, defaults to environment SUBJECTS_DIR, required to find " + "several files used by measures, e.g. surfaces.") + parser.add_argument( + "-s", + "--subject", + "--sid", + dest="sid", metavar="subject_id", + help="set subject_id, required to find several files used by measures, e.g. " + "surfaces.") + parser.add_argument( + "-o", + "--segstatsfile", + dest="segstatsfile", + default=Path("stats/brainvol.stats"), + help="Where to save the brainvol.stats, if relative path, this will be " + "relative to the subject directory." + ) + fs_home = "FREESURFER_HOME" + default_lut = Path(env[fs_home]) / "ASegStatsLUT.txt" if fs_home in env else None + parser.set_defaults( + segfile=Path("mri/aseg.mgz"), + measures=DEFAULT_MEASURES, + lut=default_lut, + measure_only=True, + ) + advanced = parser.add_argument_group( + "FastSurfer options (no equivalence with FreeSurfer's mri_brainvol_stats)", + ) + advanced.add_argument( + "--no_legacy", + action="store_false", + dest="legacy_freesurfer", + help="use FastSurfer algorithms instead of FastSurfer.", + ) + advanced.add_argument( + "--pvfile", + "-pv", + type=Path, + dest="pvfile", + help="Path to image used to compute the partial volume effects. This file is " + "only used in the FastSurfer algoritms (--no_legacy).", + ) + return parser + + +if __name__ == "__main__": + import sys + + args = make_arguments().parse_args() + parse_actions = getattr(args, "parse_actions", []) + for parse_action in parse_actions: + parse_action(args) + sys.exit(main(args)) diff --git a/FastSurferCNN/mri_segstats.py b/FastSurferCNN/mri_segstats.py new file mode 100644 index 00000000..bf25a71a --- /dev/null +++ b/FastSurferCNN/mri_segstats.py @@ -0,0 +1,545 @@ +#!/bin/python + +# Copyright 2024 Image Analysis Lab, German Center for Neurodegenerative Diseases +# (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +import argparse +from itertools import pairwise, chain +from pathlib import Path +from typing import TypeVar, Sequence, Any, Iterable + +from FastSurferCNN.segstats import ( + main, + HelpFormatter, + add_two_help_messages, + VERSION, + empty, +) + +_T = TypeVar("_T") + + +USAGE = "python mri_segstats.py --seg segvol [optional arguments]" +HELPTEXT = f""" +Dependencies: + + Python 3.10 + + Numpy + http://www.numpy.org + + Nibabel to read images + http://nipy.org/nibabel/ + + Pandas to read/write stats files etc. 
+ https://pandas.pydata.org/ + +Original Author: David Kügler +Date: Jan-04-2024 + +Revision: {VERSION} +""" +DESCRIPTION = """ +Translates mri_segstats options for segstats.py. Options not listed here have no +equivalent representation in segstats.py.
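As an editorial illustration only (the file names below are placeholders, not prescribed paths), a translated invocation could look like:

    python mri_segstats.py --seg mri/aseg.mgz --i mri/norm.mgz --sum stats/aseg.stats --sd /path/to/subjects_dir --subject subject01 --print

With --print, the script only prints the equivalent segstats.py command line and exits instead of computing the statistics.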
+IMPORTANT NOTES +mri_segstats uses a legacy version for the computation of measures (from FreeSurfer 6). +But mri_segstats.py implements the behavior if first mri_brainvol_stats +and then mri_segstats is run (which uses the stats/brainvol.stats generated by +mri_brainvol_stats). This reflects the output of stats files as created by FreeSurfer's +recon-all. +""" +ETIV_RATIO_KEY = "eTIV-ratios" +ETIV_RATIOS = {"BrainSegVol-to-eTIV": "BrainSeg", "MaskVol-to-eTIV": "Mask"} +ETIV_FROM_TAL = "EstimatedTotalIntraCranialVol" + + +class _ExtendConstAction(argparse.Action): + """Helper class to allow action='extend_const' by action=_ExtendConstAction.""" + def __init__( + self, + option_strings: Sequence[str], + dest: str, + const: _T | None = None, + default: _T | str | None = None, + required: bool = False, + help: str | None = None, + metavar: str | tuple[str, ...] | None = None, + ) -> None: + super(_ExtendConstAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help, + metavar=metavar, + ) + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any], + option_string: str | None = None, + ) -> None: + """ + Extend attribute `self.dest` of `namespace` with the values in `self.const`. + """ + items = getattr(namespace, self.dest, None) + if items is None: + items = [] + elif type(items) is list: + items = items[:] + else: + import copy + items = copy.copy(items) + items.extend(self.const) + setattr(namespace, self.dest, items) + + +def make_arguments() -> argparse.ArgumentParser: + """Create an argument parser object with all parameters of the script.""" + parser = argparse.ArgumentParser( + usage=USAGE, + epilog=HELPTEXT.replace("\n", "
"), + description=DESCRIPTION, + formatter_class=HelpFormatter, + add_help=False, + ) + + def add_etiv_measures(args: argparse.Namespace) -> None: + measures: list[tuple[bool, str]] = getattr(args, "measures", []) + measure_strings = list(map(lambda x: x[1], measures)) + if all(m in measure_strings for m in (ETIV_RATIO_KEY, ETIV_FROM_TAL)): + + measures = [m for m in measures if m[1] == ETIV_RATIO_KEY] + for k, v in ETIV_RATIOS.items(): + for is_imported, m in measures: + if m == v or m.startswith(v + "("): + measures.append((False, k)) + continue + setattr(args, "measures", measures) + + def _update_what_to_import(args: argparse.Namespace) -> argparse.Namespace: + """ + Update the Namespace object based on the existence of the brainvol.stats file. + """ + cachefile = Path(args.measurefile) + if not cachefile.is_absolute(): + cachefile = args.out_dir / args.sid / cachefile + if not args.explicit_no_cached and cachefile.is_file(): + from FastSurferCNN.utils.brainvolstats import read_measure_file + + measure_data = read_measure_file(cachefile) + + def update_key(measure: tuple[bool, str]) -> tuple[bool, str]: + is_imported, measure_string = measure + measure_key, _, _ = measure_string.partition("(") + return measure_key in measure_data.keys(), measure_string + + # for each measure to be computed, replace it with the value in + # brainvol.stats, if available + args.measures = list(map(update_key, getattr(args, "measures", []))) + return args + + parser.set_defaults( + measurefile="stats/brainvol.stats", + parse_actions=[(1, add_etiv_measures), (10, _update_what_to_import)], + volume_precision=1, + ) + + if "--help" in sys.argv: + from FastSurferCNN.utils.brainvolstats import Manager + manager = Manager([]) + + def help_text(keys: Iterable[str]) -> Iterable[str]: + return (manager[k].help() for k in keys) + else: + help_text = None + + def help_add_measures(message: str, keys: list[str]) -> str: + if help_text: + _keys = (k.split(' ')[0] for k in keys) + keys = [f"{k}: {text}" for k, text in zip(keys, help_text(_keys))] + return "
- ".join([message] + list(keys)) + + add_two_help_messages(parser) + parser.add_argument( + "--print", + action="append_const", + dest="parse_actions", + const=(0, print_and_exit), + help="Print the equivalent native segstats.py options and exit.", + ) + parser.add_argument( + "--version", + action="version", + version=f"%(prog)s {VERSION}", + help="Print the version of the mri_segstats.py script", + ) + parser.add_argument( + "--seg", + type=Path, + metavar="segvol", + dest="segfile", + help="Specify the segmentation file.", + ) + # --annot subject hemi parc + # --surf whitesurfname -- used with annot + # --slabel subject hemi full_path_to_label + # --label-thresh threshold + # --seg-from-input + parser.add_argument( + "--o", + "--sum", + type=Path, + metavar="file", + dest="segstatsfile", + help="Specifiy the output summary statistics file.", + ) + parser.add_argument( + "--pv", + type=Path, + metavar="pvvol", + dest="pvfile", + help="file to compensate for partial volume effects.", + ) + parser.add_argument( + "--i", + "--in", + type=Path, + metavar="invol", + dest="normfile", + help="file to compute intensity values.", + ) + + # --seg-erode Nerodes + # --frame frame + def _percent(__value) -> float: + return float(__value) / 50 + + parser.add_argument( + "--robust", + type=_percent, + metavar="percent", + dest="robust", + help="Compute stats after excluding percent from high and and low values, e.g. " + "with --robust 2, min and max are the 2nd and the 98th percentiles.", + ) + + def _add_invol_op(*flags: str, op: str, metavar: str | None = None) -> None: + if metavar: + def _optype(_a) -> str: + # test the argtype for float as well + return f"{flags[0].lstrip('-')}={float(_a)}" + kwargs = { + "action": "append", + "type": _optype, + "dest": "pvfile_preproc", + "help": f"Apply the {op} with `{metavar}` to `invol` (--in)", + } + else: + kwargs = { + "action": "append_const", + "const": flags[0].lstrip("-"), + "dest": "pvfile_preproc", + "help": f"Apply {op} to `invol` (--in)", + } + parser.add_argument(*flags, **kwargs) + + def _no_import(*args: str) -> list[tuple[bool, str]]: + return list((False, a) for a in args) + + _add_invol_op("--sqr", op="squaring") + _add_invol_op("--sqrt", op="the square root") + _add_invol_op("--mul", op="multiplication", metavar="val") + _add_invol_op("--div", op="division", metavar="val") + # --snr + _add_invol_op("--abs", op="absolute value") + # --accumulate + parser.add_argument( + "--ctab", + type=Path, + metavar="ctabfile", + dest="lut", + help="load the Color Lookup Table.", + ) + import os + env = os.environ + if "FREESURFER_HOME" in env: + default_lut = Path(env["FREESURFER_HOME"]) / "FreeSurferColorLUT.txt" + elif "FASTSURFER_HOME" in env: + default_lut = ( + Path(env["FASTSURFER_HOME"]) / "FastSurferCNN/config/FreeSurferColorLUT.txt" + ) + else: + default_lut = None + parser.add_argument( + "--ctab-default", + metavar="ctabfile", + dest="lut", + const=default_lut, + action="store_const", + help="load default Color Lookup Table (from FREESURFER_HOME or " + "FASTSURFER_HOME).", + ) + # --ctab-gca gcafile + parser.add_argument( + "--id", + type=int, + nargs="+", + metavar="segid", + action="extend", + dest="ids", + default=[], + help="Specify segmentation Exclude segmentation ids from report.", + ) + parser.add_argument( + "--excludeid", + type=int, + nargs="+", + metavar="segid", + dest="excludeid", + help="Exclude segmentation ids from report.", + ) + parser.add_argument( + "--no-cached", + action="store_true", + dest="explicit_no_cached", + 
help="Do not try to load stats/brainvol.stats.", + ) + parser.add_argument( + "--excl-ctxgmwm", + dest="excludeid", + action=_ExtendConstAction, + const=[2, 3, 41, 42], + help="Exclude cortical gray and white matter regions from volume stats.", + ) + surf_wm = ["rhCerebralWhiteMatter", "lhCerebralWhiteMatter", "CerebralWhiteMatter"] + parser.add_argument( + "--surf-wm-vol", + action=_ExtendConstAction, + dest="measures", + const=_no_import(*surf_wm), + help=help_add_measures( + "Compute cortical white matter based on the surface:", + surf_wm, + ), + ) + surf_ctx = ["rhCortex", "lhCortex", "Cortex"] + parser.add_argument( + "--surf-ctx-vol", + action=_ExtendConstAction, + dest="measures", + const=_no_import(*surf_ctx), + help=help_add_measures( + "compute cortical gray matter based on the surface:", + surf_ctx, + ), + ) + parser.add_argument( + "--no_global_stats", + action="store_const", + dest="measures", + const=[], + help="Resets the computed global stats.", + ) + parser.add_argument( + "--empty", + action="store_true", + dest="empty", + help="Report all segmentation labels in ctab, even if they are not in seg.", + ) + # --ctab-out ctaboutput + # --mask maskvol + # --maskthresh thresh + # --masksign sign + # --maskframe frame + # --maskinvert + # --maskerode nerode + brainseg = ["BrainSeg", "BrainSegNotVent"] + parser.add_argument( + "--brain-vol-from-seg", + action=_ExtendConstAction, + dest="measures", + const=_no_import(*brainseg), + help=help_add_measures("Compute measures BrainSeg measures:", brainseg), + ) + + def _mask(__value): + return False, "Mask(" + str(__value) + ")" + + parser.add_argument( + "--brainmask", + type=_mask, + metavar="brainmask", + action="append", + dest="measures", + help="Report the Volume of the brainmask", + ) + supratent = ["SupraTentorial", "SupraTentorialNotVent"] + parser.add_argument( + "--supratent", + action=_ExtendConstAction, + dest="measures", + const=_no_import(*supratent), + help=help_add_measures("Compute supratentorial measures:", supratent), + ) + parser.add_argument( + "--subcortgray", + action="append_const", + dest="measures", + const=(False, "SubCortGray"), + help=help_add_measures("Compute measure SubCortGray:", ["SubCortGray"]), + ) + parser.add_argument( + "--totalgray", + action="append_const", + dest="measures", + const=(False, "TotalGray"), + help=help_add_measures("Compute measure TotalGray:", ["TotalGray"]), + ) + etiv_measures = [f"{k} (if also --brain-vol-from-seg)" for k in ETIV_RATIOS] + parser.add_argument( + "--etiv", + action=_ExtendConstAction, + dest="measures", + const=_no_import(ETIV_FROM_TAL, ETIV_RATIO_KEY), + help=help_add_measures("Compute eTIV:", [ETIV_FROM_TAL] + etiv_measures), + ) + # --etiv-only + # --old-etiv-only + # --xfm2etiv xfm outfile + surf_holes = ["rhSurfaceHoles", "lhSurfaceHoles", "SurfaceHoles"] + parser.add_argument( + "--euler", + action=_ExtendConstAction, + dest="measures", + const=_no_import(*surf_holes), + help=help_add_measures("Compute surface holes measures:", surf_holes), + ) + # --avgwf textfile + # --sumwf testfile + # --avgwfvol mrivol + # --avgwf-remove-mean + # --sfavg textfile + # --vox C R S + # --replace ID1 ID2 + # --replace-file file + # --gtm-default-seg-merge + # --gtm-default-seg-merge-choroid + # --ga-stats subject statsfile + default_sd = Path(env["SUBJECTS_DIR"]) if "SUBJECTS_DIR" in env else None + parser.add_argument( + "--sd", + dest="out_dir", metavar="subjects_dir", type=Path, + default=default_sd, + help="set SUBJECTS_DIR, defaults to environment SUBJECTS_DIR, 
required to find " + "several files used by measures, e.g. surfaces.") + parser.add_argument( + "--subject", + dest="sid", metavar="subject_id", + help="set subject_id, required to find several files used by measures, e.g. " + "surfaces.") + parser.add_argument( + "--seed", + nargs=1, metavar="N", help="The seed has no effect") + parser.add_argument( + "--in-intensity-name", + type=str, + dest="norm_name", + default="", + help="name of the intensity image" + ) + parser.add_argument( + "--in-intensity-units", + type=str, + dest="norm_unit", + default="", + help="unit of the intensity image" + ) + parser.add_argument( + "--no_legacy", + action="store_false", + dest="legacy_freesurfer", + help="use fastsurfer algorithms instead of fastsurfer." + ) + return parser + + +def print_and_exit(args: object): + """Print the commandline arguments of the segstats script to stdout and exit.""" + print(" ".join(format_cmdline_args(args))) + import sys + sys.exit(0) + + +def format_cmdline_args(args: object) -> list[str]: + """Format the commandline arguments of the segstats script.""" + arglist = ["python", str(Path(__file__).parent / "segstats.py")] + # this entry has nargs="+" and should therefore be up top + if not empty(__ids := getattr(args, "ids", [])): + arglist.extend(["--id"] + list(map(str, __ids))) + + def _append_storetrue(name: str, flag: str = ""): + if flag == "": + flag = "--" + name + if getattr(args, name, False): + arglist.append(flag) + + def _extend_arg(name: str, flag: str = None): + if flag == "": + flag = "--" + name + if (value := getattr(args, name, None)) is not None: + arglist.extend([flag, str(value)]) + + _append_storetrue("allow_root") + _append_storetrue("legacy_freesurfer") + _extend_arg("segfile") + _extend_arg("normfile") + _extend_arg("pvfile") + _extend_arg("segstatsfile") + _extend_arg("out_dir", "--sd") + _extend_arg("sid") + _extend_arg("threads") + _extend_arg("lut") + _extend_arg("volume_precision") + + measures: list[tuple[bool, str]] = getattr(args, "measures", []) + if not empty(measures): + arglist.append("measures") + _extend_arg("measurefile", "--file") + _flag = {True: "--import", False: "--compute"} + blank_measure = (not measures[0][0], "") + flag_measure_iter = ((_flag[i], m) for i, m in [blank_measure, *measures]) + arglist.extend(chain( + (*((flag,) if flag != last_flag else ()), str(measure)) + for (last_flag, _), (flag, measure) in pairwise(flag_measure_iter) + )) + + return arglist + + +if __name__ == "__main__": + import sys + + args = make_arguments().parse_args() + parse_actions = getattr(args, "parse_actions", []) + for i, parse_action in sorted(parse_actions, key=lambda x: x[0], reverse=True): + parse_action(args) + sys.exit(main(args)) diff --git a/FastSurferCNN/quick_qc.py b/FastSurferCNN/quick_qc.py index 1ea64857..313580ff 100644 --- a/FastSurferCNN/quick_qc.py +++ b/FastSurferCNN/quick_qc.py @@ -16,12 +16,12 @@ # IMPORTS import optparse import sys -import numpy as np -import nibabel as nib +from typing import cast +import nibabel as nib +import numpy as np from skimage.morphology import binary_dilation - HELPTEXT = """ Script to perform quick qualtiy checks for the input segmentation to identify gross errors. @@ -41,13 +41,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. 
""" parser = optparse.OptionParser( version="$Id: quick_qc,v 1.0 2022/09/28 11:34:08 mreuter Exp $", usage=HELPTEXT @@ -69,23 +69,23 @@ def options_parse(): return options -def check_volume(asegdkt_segfile, voxvol, thres=0.70): - """Check if total volume is bigger or smaller than threshold. +def check_volume(asegdkt_segfile:np.ndarray, voxvol: float, thres: float = 0.70): + """ + Check if total volume is bigger or smaller than threshold. Parameters ---------- - asegdkt_segfile : - [MISSING] - voxvol : - [MISSING] - thres : - [MISSING] + asegdkt_segfile : np.ndarray + The segmentation file. + voxvol : float + The volume of a voxel. + thres : float, default=0.7 + The threshold for the total volume (Default value = 0.70). Returns ------- bool - Whether or not total volume is bigger or smaller than threshold - + Whether or not total volume is bigger or smaller than threshold. """ print("Checking total volume ...") mask = asegdkt_segfile > 0 @@ -101,27 +101,34 @@ def check_volume(asegdkt_segfile, voxvol, thres=0.70): def get_region_bg_intersection_mask( seg_array, region_labels=VENT_LABELS, bg_label=BG_LABEL ): - """Return a mask of the intersection between the voxels of a given region and background voxels. - + f""" + Return a mask of the intersection between the voxels of a given region and background voxels. + This is obtained by dilating the region by 1 voxel and computing the intersection with the background mask. - + The region can be defined by passing in the region_labels dict. Parameters ---------- seg_array : numpy.ndarray - Segmentation array - region_labels : Dict - dict whose values correspond to the desired region's labels (Default value = VENT_LABELS) - bg_label : int - (Default value = BG_LABEL) + Segmentation array. + region_labels : dict, default= + Dictionary whose values correspond to the desired region's labels (see Note). + bg_label : int, default={BG_LABEL} + (Default value = {BG_LABEL}). Returns ------- bg_intersect : numpy.ndarray - Region and background intersection mask array - + Region and background intersection mask array. + + Notes + ----- + VENT_LABELS is a dictionary containing labels for four regions related to the ventricles: + "Left-Lateral-Ventricle", "Right-Lateral-Ventricle", "Left-choroid-plexus", + "Right-choroid-plexus" along with their corresponding integer label values + (see also FreeSurferColorLUT.txt). """ region_array = seg_array.copy() conditions = np.all( @@ -145,20 +152,20 @@ def get_region_bg_intersection_mask( def get_ventricle_bg_intersection_volume(seg_array, voxvol): - """Return a volume estimate for the intersection of ventricle voxels with background voxels. + """ + Return a volume estimate for the intersection of ventricle voxels with background voxels. Parameters ---------- seg_array : numpy.ndarray - Segmentation array + Segmentation array. voxvol : float - Voxel volume + Voxel volume. Returns ------- intersection_volume : float - Estimated volume of voxels in ventricle and background intersection - + Estimated volume of voxels in ventricle and background intersection. 
""" bg_intersect_mask = get_region_bg_intersection_mask(seg_array) intersection_volume = bg_intersect_mask.sum() * voxvol @@ -169,11 +176,11 @@ def get_ventricle_bg_intersection_volume(seg_array, voxvol): if __name__ == "__main__": # Command Line options are error checking done here options = options_parse() - print("Reading in aparc+aseg: {} ...".format(options.asegdkt_segfile)) - inseg = nib.load(options.asegdkt_segfile) + print(f"Reading in aparc+aseg: {options.asegdkt_segfile} ...") + inseg = cast(nib.analyze.SpatialImage, nib.load(options.asegdkt_segfile)) inseg_data = np.asanyarray(inseg.dataobj) inseg_header = inseg.header - inseg_voxvol = np.product(inseg_header.get_zooms()) + inseg_voxvol = np.prod(inseg_header.get_zooms()) # Ventricle-BG intersection volume check: print("Estimating ventricle-background intersection volume...") @@ -186,8 +193,7 @@ def get_ventricle_bg_intersection_volume(seg_array, voxvol): # Total volume check: if not check_volume(inseg_data, inseg_voxvol): print( - "WARNING: Total segmentation volume is very small. Segmentation may be corrupted! Please check." + "WARNING: Total segmentation volume is very small. Segmentation may be " + "corrupted! Please check." ) - sys.exit(0) - else: - sys.exit(0) + sys.exit(0) diff --git a/FastSurferCNN/reduce_to_aseg.py b/FastSurferCNN/reduce_to_aseg.py index 28c5590d..c7c3cc69 100644 --- a/FastSurferCNN/reduce_to_aseg.py +++ b/FastSurferCNN/reduce_to_aseg.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. - # IMPORTS +import copy import optparse import sys -import numpy as np -import nibabel as nib -import copy +import nibabel as nib +import numpy as np import scipy.ndimage -from skimage.measure import label from skimage.filters import gaussian +from skimage.measure import label HELPTEXT = """ Script to reduce aparc+aseg to aseg by mapping cortex labels back to left/right GM. @@ -42,7 +41,7 @@ Dependencies: - Python 3.8 + Python 3.8+ Numpy http://www.numpy.org @@ -65,13 +64,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = optparse.OptionParser( version="$Id: reduce_to_aseg.py,v 1.0 2018/06/24 11:34:08 mreuter Exp $", @@ -91,18 +90,20 @@ def options_parse(): return options -def reduce_to_aseg(data_inseg): - """[MISSING]. +def reduce_to_aseg(data_inseg: np.ndarray) -> np.ndarray: + """ + Reduce the input segmentation to a simpler segmentation. Parameters ---------- - data_inseg : - [MISSING] + data_inseg : np.ndarray, torch.Tensor + The input segmentation. This should be a 3D array where the value at each position represents the segmentation + label for that position. Returns ------- - [MISSING] - + data_inseg : np.ndarray, torch.Tensor + The reduced segmentation. """ print("Reducing to aseg ...") # replace 2000... with 42 @@ -113,21 +114,22 @@ def reduce_to_aseg(data_inseg): def create_mask(aseg_data, dnum, enum): - """Create dilated mask. + """ + Create dilated mask. Parameters ---------- - aseg_data - [MISSING] - dnum - [MISSING] - enum - [MISSING] + aseg_data : npt.NDArray[int] + The input segmentation data. + dnum : int + The number of iterations for the dilation operation. + enum : int + The number of iterations for the erosion operation. Returns ------- - [MISSING] - + - + Returns aseg_data. 
""" print("Creating dilated mask ...") @@ -166,20 +168,19 @@ def create_mask(aseg_data, dnum, enum): return aseg_data -def flip_wm_islands(aseg_data): - """[MISSING]. +def flip_wm_islands(aseg_data : np.ndarray) -> np.ndarray: + """ + Flip labels of disconnected white matter islands to the other hemisphere. Parameters ---------- - aseg_data - [MISSING] - + aseg_data : numpy.ndarray + The input segmentation data. Returns ------- - flip_data - [MISSING] - + flip_data : numpy.ndarray + The segmentation data with flipped WM labels. """ # Sometimes WM is far in the other hemisphere, but with a WM label from the other hemi # These are usually islands, not connected to the main hemi WM component diff --git a/FastSurferCNN/run_model.py b/FastSurferCNN/run_model.py index 7142c4e7..4d7dcc89 100644 --- a/FastSurferCNN/run_model.py +++ b/FastSurferCNN/run_model.py @@ -12,27 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -# IMPORTS -from os.path import join -import sys import argparse import json +import sys +# IMPORTS +from os.path import join + +from FastSurferCNN.train import Trainer from FastSurferCNN.utils import misc from FastSurferCNN.utils.load_config import get_config -from FastSurferCNN.train import Trainer +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT -def setup_options(): - """Set up the options parsed from STDIN. - - Parses arguments from the STDIN, including the flags: --cfg, --aug, --opt, opts, +def make_parser() -> argparse.ArgumentParser: + """ + Set up the options parsed from STDIN. + + Parses arguments from the STDIN, including the flags: --cfg, --aug, --opt, opts. Returns ------- - options - object holding options - + argparse.ArgumentParser + The parser object for options. """ parser = argparse.ArgumentParser(description="Segmentation") @@ -40,7 +42,7 @@ def setup_options(): "--cfg", dest="cfg_file", help="Path to the config file", - default="config/FastSurferVINN.yaml", + default=FASTSURFER_ROOT / "FastSurferCNN/config/FastSurferVINN.yaml", type=str, ) parser.add_argument( @@ -53,15 +55,13 @@ def setup_options(): default=None, nargs=argparse.REMAINDER, ) - - if len(sys.argv) == 1: - parser.print_help() - return parser.parse_args() + return parser -def main(): - """[MISSING] First set variables and then runs the trainer model.""" - args = setup_options() +def main(args): + """ + First sets variables and then runs the trainer model. + """ cfg = get_config(args) if args.aug is not None: @@ -89,4 +89,8 @@ def main(): if __name__ == "__main__": - main() + parser = make_parser() + if len(sys.argv) == 1: + parser.print_help() + args = parser.parse_args() + main(args) diff --git a/FastSurferCNN/run_prediction.py b/FastSurferCNN/run_prediction.py index 8adefbf4..e6799ec6 100644 --- a/FastSurferCNN/run_prediction.py +++ b/FastSurferCNN/run_prediction.py @@ -12,126 +12,136 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This is the FastSurfer/run_prediction.py script, the backbone for whole brain +segmentation. 
+ +Usage: + +See Also +-------- +:doc:`/scripts/fastsurfercnn` +`run_prediction.py --help` +""" + + # IMPORTS -import sys -import os -import copy import argparse -from typing import Tuple, Union, Literal, Dict, Any, Optional, Iterator -from concurrent.futures import Executor +import copy +import sys +from concurrent.futures import Executor, ThreadPoolExecutor, Future +from pathlib import Path +from typing import Any, Iterator, Literal, Optional, Sequence +import nibabel as nib import numpy as np import torch -import nibabel as nib import yacs.config +import FastSurferCNN.reduce_to_aseg as rta +from FastSurferCNN.data_loader import conform as conf +from FastSurferCNN.data_loader import data_utils as du from FastSurferCNN.inference import Inference -from FastSurferCNN.utils import logging, parser_defaults -from FastSurferCNN.utils.checkpoint import get_checkpoints, VINN_AXI, VINN_COR, VINN_SAG +from FastSurferCNN.utils import logging, parser_defaults, Plane, PLANES +from FastSurferCNN.utils.arg_types import VoxSizeOption +from FastSurferCNN.utils.checkpoint import ( + get_checkpoints, + load_checkpoint_config_defaults, +) from FastSurferCNN.utils.load_config import load_config from FastSurferCNN.utils.common import ( + SerialExecutor, find_device, assert_no_root, handle_cuda_memory_exception, SubjectList, SubjectDirectory, - NoParallelExecutor, pipeline, ) -from FastSurferCNN.data_loader import data_utils as du, conform as conf +from FastSurferCNN.utils.parser_defaults import SubjectDirectoryConfig from FastSurferCNN.quick_qc import check_volume -import FastSurferCNN.reduce_to_aseg as rta ## # Global Variables ## - +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT LOGGER = logging.getLogger(__name__) +CHECKPOINT_PATHS_FILE = FASTSURFER_ROOT / "FastSurferCNN/config/checkpoint_paths.yaml" ## # Processing ## -def set_up_cfgs(cfg: str, args: argparse.Namespace) -> yacs.config.CfgNode: - """Set up configuration. - - Sets up configurations with given arguments inside the yaml file +def set_up_cfgs( + cfg_file: str | Path, + batch_size: int = 1, +) -> yacs.config.CfgNode: + """ + Set up configuration. + + Sets up configurations with given arguments inside the yaml file. Parameters ---------- - cfg : str - path to yaml file of configurations - args : argparse.Namespace - {out_dir, batch_size} arguments + cfg_file : Path, str + Path to yaml file of configurations. + batch_size : int, default=1 + The batch size to use. Returns ------- yacs.config.CfgNode - Node of configurations - + Node of configurations. """ - cfg = load_config(cfg) - cfg.OUT_LOG_DIR = args.out_dir if args.out_dir is not None else cfg.LOG_DIR + cfg = load_config(str(cfg_file)) cfg.OUT_LOG_NAME = "fastsurfer" - cfg.TEST.BATCH_SIZE = args.batch_size + cfg.TEST.BATCH_SIZE = batch_size cfg.MODEL.OUT_TENSOR_WIDTH = cfg.DATA.PADDED_SIZE cfg.MODEL.OUT_TENSOR_HEIGHT = cfg.DATA.PADDED_SIZE return cfg -def args2cfg(args: argparse.Namespace) -> Tuple[yacs.config.CfgNode, yacs.config.CfgNode, yacs.config.CfgNode, yacs.config.CfgNode]: - """Extract the configuration objects from the arguments. 
- - Parameters - ---------- - args : argparse.Namespace - arguments - - Returns - ------- - yacs.config.CfgNode - configurations for all planes - +def args2cfg( + cfg_ax: Optional[str] = None, + cfg_cor: Optional[str] = None, + cfg_sag: Optional[str] = None, + batch_size: int = 1, +) -> tuple[ + yacs.config.CfgNode, yacs.config.CfgNode, yacs.config.CfgNode, yacs.config.CfgNode +]: """ - cfg_cor = set_up_cfgs(args.cfg_cor, args) if args.cfg_cor is not None else None - cfg_sag = set_up_cfgs(args.cfg_sag, args) if args.cfg_sag is not None else None - cfg_ax = set_up_cfgs(args.cfg_ax, args) if args.cfg_ax is not None else None - cfg_fin = ( - cfg_cor if cfg_cor is not None else cfg_sag if cfg_sag is not None else cfg_ax - ) - return cfg_fin, cfg_cor, cfg_sag, cfg_ax - - -def removesuffix(string: str, suffix: str) -> str: - """Remove a suffix from a string. - - Similar to string.removesuffix in PY3.9+, + Extract the configuration objects from the arguments. Parameters ---------- - string : str - string to be cut - suffix : str - suffix to be removed + cfg_ax : str, optional + The path to the axial network YAML config file. + cfg_cor : str, optional + The path to the coronal network YAML config file. + cfg_sag : str, optional + The path to the sagittal network YAML config file. + batch_size : int, default=1 + The batch size for the network. Returns ------- - Str - Suffix removed string - + yacs.config.CfgNode + Configurations for all planes. """ - import sys - - if sys.version_info.minor >= 9: - # removesuffix is a Python3.9 feature - return string.removesuffix(suffix) - else: - return ( - string[: -len(suffix)] - if len(suffix) > 0 and string.endswith(suffix) - else string - ) + if cfg_cor is not None: + cfg_cor = set_up_cfgs(cfg_cor, batch_size) + if cfg_sag is not None: + cfg_sag = set_up_cfgs(cfg_sag, batch_size) + if cfg_ax is not None: + cfg_ax = set_up_cfgs(cfg_ax, batch_size) + cfgs = (cfg_cor, cfg_sag, cfg_ax) + # returns the first non-None cfg + try: + cfg_fin = next(filter(None, cfgs)) + except StopIteration: + raise RuntimeError("No valid configuration passed!") + return (cfg_fin,) + cfgs ## @@ -140,171 +150,175 @@ def removesuffix(string: str, suffix: str) -> str: class RunModelOnData: - """Run the model prediction on given data. - + """ + Run the model prediction on given data. + + Attributes + ---------- + vox_size : float, 'min' + current_plane : str + models : Dict[str, Inference] + view_ops : Dict[str, Dict[str, Any]] + conform_to_1mm_threshold : float, optional + threshold until which the image will be conformed to 1mm res + Methods ------- __init__() Construct object. set_and_create_outdir() - sets and creates output directory + Sets and creates output directory. conform_and_save_orig() - saves original image + Saves original image. set_subject() - setter + Setter. get_subject_name() - getter + Getter. set_model() - setter + Setter. run_model() - calculates prediction + Calculates prediction. get_img() - getter + Getter. save_img() - saves image as file + Saves image as file. set_up_model_params() - setter + Setter. get_num_classes() - getter - - Attributes - ---------- - pred_name : str - conf_name : str - orig_name : str - vox_size : Union[float, Literal["min"]] - current_plane : str - models : Dict[str, Inference] - view_ops : Dict[str, Dict[str, Any]] - conform_to_1mm_threshold : Optional[float] - threshold until which the image will be conformed to 1mm res + Getter. 
""" - pred_name: str - conf_name: str - orig_name: str - vox_size: Union[float, Literal["min"]] - current_plane: str - models: Dict[str, Inference] - view_ops: Dict[str, Dict[str, Any]] + vox_size: float | Literal["min"] + current_plane: Plane + models: dict[Plane, Inference] + view_ops: dict[Plane, dict[str, Any]] conform_to_1mm_threshold: Optional[float] + device: torch.device + viewagg_device: torch.device + _pool: Executor - def __init__(self, args: argparse.Namespace): - """Construct RunModelOnData object. + def __init__( + self, + lut: Path, + ckpt_ax: Optional[Path] = None, + ckpt_sag: Optional[Path] = None, + ckpt_cor: Optional[Path] = None, + cfg_ax: Optional[Path] = None, + cfg_sag: Optional[Path] = None, + cfg_cor: Optional[Path] = None, + device: str = "auto", + viewagg_device: str = "auto", + threads: int = 1, + batch_size: int = 1, + vox_size: VoxSizeOption = "min", + async_io: bool = False, + conform_to_1mm_threshold: float = 0.95, + ): + """ + Construct RunModelOnData object. Parameters ---------- - args : args: argparse.Namespace) [MISSING] - pred_name : str - conf_name (str - orig_name (str - remove_suffix - sf : float - Defaults to 1.0 - out_dir : str - directory of output - viewagg_device : str - device to run viewagg on. Can be auto, cuda or cpu - + viewagg_device : str, default="auto" + Device to run viewagg on. Can be auto, cuda or cpu. """ - self.pred_name = args.pred_name - self.conf_name = args.conf_name - self.orig_name = args.orig_name - self._threads = getattr(args, "threads", 1) + # TODO Fix docstring of RunModelOnData.__init__ + self._threads = threads torch.set_num_threads(self._threads) - self._async_io = getattr(args, "async_io", False) + self._async_io = async_io self.sf = 1.0 - device = find_device(args.device) + self.device = find_device(device) - if device.type == "cpu" and args.viewagg_device == "auto": - self.viewagg_device = device + if self.device.type == "cpu" and viewagg_device in ("auto", "cpu"): + self.viewagg_device = self.device else: # check, if GPU is big enough to run view agg on it # (this currently takes the memory of the passed device) - self.viewagg_device = torch.device( - find_device( - args.viewagg_device, - flag_name="viewagg_device", - min_memory=4 * (2**30), - ) + self.viewagg_device = find_device( + viewagg_device, + flag_name="viewagg_device", + min_memory=4 * (2**30), ) LOGGER.info(f"Running view aggregation on {self.viewagg_device}") try: - self.lut = du.read_classes_from_lut(args.lut) - except FileNotFoundError as e: + self.lut = du.read_classes_from_lut(lut) + except FileNotFoundError: raise ValueError( - f"Could not find the ColorLUT in {args.lut}, please make sure the --lut argument is valid." + f"Could not find the ColorLUT in {lut}, please make sure the " + f"--lut argument is valid." 
) self.labels = self.lut["ID"].values self.torch_labels = torch.from_numpy(self.lut["ID"].values) self.names = ["SubjectName", "Average", "Subcortical", "Cortical"] - self.cfg_fin, cfg_cor, cfg_sag, cfg_ax = args2cfg(args) + self.cfg_fin, cfg_cor, cfg_sag, cfg_ax = args2cfg( + cfg_ax, cfg_cor, cfg_sag, batch_size=batch_size, + ) # the order in this dictionary dictates the order in the view aggregation self.view_ops = { - "coronal": {"cfg": cfg_cor, "ckpt": args.ckpt_cor}, - "sagittal": {"cfg": cfg_sag, "ckpt": args.ckpt_sag}, - "axial": {"cfg": cfg_ax, "ckpt": args.ckpt_ax}, + "coronal": {"cfg": cfg_cor, "ckpt": ckpt_cor}, + "sagittal": {"cfg": cfg_sag, "ckpt": ckpt_sag}, + "axial": {"cfg": cfg_ax, "ckpt": ckpt_ax}, } self.num_classes = max( view["cfg"].MODEL.NUM_CLASSES for view in self.view_ops.values() ) self.models = {} for plane, view in self.view_ops.items(): - if view["cfg"] is not None and view["ckpt"] is not None: + if all(view[key] is not None for key in ("cfg", "ckpt")): self.models[plane] = Inference( - view["cfg"], ckpt=view["ckpt"], device=device, lut=self.lut + view["cfg"], ckpt=view["ckpt"], device=self.device, lut=self.lut, ) - vox_size = args.vox_size if vox_size == "min": self.vox_size = "min" elif 0.0 < float(vox_size) <= 1.0: self.vox_size = float(vox_size) else: raise ValueError( - f"Invalid value for vox_size, must be between 0 and 1 or 'min', was {vox_size}." + f"Invalid value for vox_size, must be between 0 and 1 or 'min', was " + f"{vox_size}." ) - self.conform_to_1mm_threshold = args.conform_to_1mm_threshold + self.conform_to_1mm_threshold = conform_to_1mm_threshold @property def pool(self) -> Executor: - """[MISSING].""" + """ + Return, and maybe create the objects executor object (with the number of threads + specified in __init__). + """ if not hasattr(self, "_pool"): if not self._async_io: - self._pool = NoParallelExecutor() + self._pool = SerialExecutor() else: - from concurrent.futures import ThreadPoolExecutor - self._pool = ThreadPoolExecutor(self._threads) return self._pool def __del__(self): - """[MISSING].""" + """Class destructor.""" if hasattr(self, "_pool"): - # only wait on futures, if we specifically ask (see end of the script, so we do not wait if we encounter a - # fail case) + # only wait on futures, if we specifically ask (see end of the script, so we + # do not wait if we encounter a fail case) self._pool.shutdown(True) def conform_and_save_orig( - self, - subject: SubjectDirectory - ) -> Tuple[nib.analyze.SpatialImage, np.ndarray]: - """Conform and saves original image. + self, subject: SubjectDirectory, + ) -> tuple[nib.analyze.SpatialImage, np.ndarray]: + """ + Conform and saves original image. Parameters ---------- subject : SubjectDirectory - subject directory object + Subject directory object. Returns ------- - Tuple[nib.analyze.SpatialImage, np.ndarray] - Conformed image - + tuple[nib.analyze.SpatialImage, np.ndarray] + Conformed image. """ orig, orig_data = du.load_image(subject.orig_name, "orig image") LOGGER.info(f"Successfully loaded image from {subject.orig_name}.") @@ -317,7 +331,7 @@ def conform_and_save_orig( orig, conform_vox_size=self.vox_size, check_dtype=True, - verbose=False, + verbose=True, conform_to_1mm_threshold=self.conform_to_1mm_threshold, ): LOGGER.info("Conforming image") @@ -335,37 +349,42 @@ def conform_and_save_orig( ) else: raise RuntimeError( - "Cannot resolve the name to the conformed image, please specify an absolute path." 
+ "Cannot resolve the name to the conformed image, please specify an " + "absolute path." ) return orig, orig_data - def set_model(self, plane: str): - """[MISSING].""" + def set_model(self, plane: Plane): + """ + Set the current model for the specified plane. + + Parameters + ---------- + plane : Plane + The plane for which to set the current model. + """ self.current_plane = plane def get_prediction( - self, - image_name: str, - orig_data: np.ndarray, - zoom: Union[np.ndarray, Tuple] + self, image_name: str, orig_data: np.ndarray, zoom: np.ndarray | Sequence[int], ) -> np.ndarray: - """Run and get prediction. + """ + Run and get prediction. Parameters ---------- image_name : str - original image filename + Original image filename. orig_data : np.ndarray - original image data - zoom : Union[np.ndarray, Tuple] - original zoom + Original image data. + zoom : np.ndarray, tuple + Original zoom. Returns ------- np.ndarray - predicted classes - + Predicted classes. """ shape = orig_data.shape + (self.get_num_classes(),) kwargs = { @@ -388,37 +407,42 @@ def get_prediction( del pred_prob # map to freesurfer label space pred_classes = du.map_label2aparc_aseg(pred_classes, self.labels) - # return numpy array TODO: split_cortex_labels requires a numpy ndarray input, maybe we can also use Mapper here + # return numpy array + # TODO: split_cortex_labels requires a numpy ndarray input, maybe we can also + # use Mapper here pred_classes = du.split_cortex_labels(pred_classes.cpu().numpy()) return pred_classes def save_img( - self, - save_as: str, - data: Union[np.ndarray, torch.Tensor], - orig: nib.analyze.SpatialImage, - dtype: Optional[type] = None, - ): - """Save image as file. + self, + save_as: str | Path, + data: np.ndarray | torch.Tensor, + orig: nib.analyze.SpatialImage, + dtype: Optional[type] = None, + ) -> None: + """ + Save image as a file. Parameters ---------- - save_as : str - filename to give image - data : Union[np.ndarray, torch.Tensor] - image data + save_as : str, Path + Filename to give the image. + data : np.ndarray, torch.Tensor + Image data. orig : nib.analyze.SpatialImage - original Image - dtype : Optional[type] - (Default value = None) - + Original Image. + dtype : type, optional + Data type to use for saving the image. If None, the original data type is + used (Default value = None). """ + save_as = Path(save_as) # Create output directory if it does not already exist. - if not os.path.exists(os.path.dirname(save_as)): + if not save_as.parent.exists(): LOGGER.info( - f"Output image directory {os.path.basename(save_as)} does not exist. Creating it now..." + f"Output image directory {save_as.parent} does not exist. " + f"Creating it now..." ) - os.makedirs(os.path.dirname(save_as)) + save_as.parent.mkdir(parents=True) np_data = data if isinstance(data, np.ndarray) else data.cpu().numpy() if dtype is not None: @@ -426,35 +450,81 @@ def save_img( _header.set_data_dtype(dtype) else: _header = orig.header - r = du.save_image(_header, orig.affine, np_data, save_as, dtype=dtype) + du.save_image(_header, orig.affine, np_data, save_as, dtype=dtype) LOGGER.info( f"Successfully saved image {'asynchronously ' if self._async_io else ''} as {save_as}." 
) - return r def async_save_img( - self, - save_as: str, - data: Union[np.ndarray, torch.Tensor], - orig: nib.analyze.SpatialImage, - dtype: Union[None, type] = None, - ): - """Save the image asynchronously and return a concurrent.futures.Future to track, when this finished.""" + self, + save_as: str | Path, + data: np.ndarray | torch.Tensor, + orig: nib.analyze.SpatialImage, + dtype: type | None = None, + ) -> Future[None]: + """ + Save the image asynchronously and return a concurrent.futures.Future to track, + when this finished. + + Parameters + ---------- + save_as : str, Path + Filename to give the image. + data : Union[np.ndarray, torch.Tensor] + Image data. + orig : nib.analyze.SpatialImage + Original Image. + dtype : type, optional + Data type to use for saving the image. If None, the original data type is + used. + + Returns + ------- + Future[None] + A Future object to synchronize (and catch/handle exceptions in the save_img + method). + """ return self.pool.submit(self.save_img, save_as, data, orig, dtype) - def set_up_model_params(self, plane, cfg, ckpt): - """Set up the model parameters from the configuration and checkpoint.""" + def set_up_model_params( + self, + plane: Plane, + cfg: "yacs.config.CfgNode", + ckpt: "torch.Tensor", + ) -> None: + """ + Set up the model parameters from the configuration and checkpoint. + """ self.view_ops[plane]["cfg"] = cfg self.view_ops[plane]["ckpt"] = ckpt def get_num_classes(self) -> int: - """Return the number of classes.""" + """ + Return the number of classes. + + Returns + ------- + int + The number of classes. + """ return self.num_classes def pipeline_conform_and_save_orig( - self, subjects: SubjectList - ) -> Iterator[Tuple[SubjectDirectory, Tuple[nib.analyze.SpatialImage, np.ndarray]]]: - """[MISSING].""" + self, subjects: SubjectList, + ) -> Iterator[tuple[SubjectDirectory, tuple[nib.analyze.SpatialImage, np.ndarray]]]: + """ + Pipeline for conforming and saving original images asynchronously. + + Parameters + ---------- + subjects : SubjectList + List of subjects to process. + + Yields + ------ + tuple[SubjectDirectory, tuple[nib.analyze.SpatialImage, np.ndarray]] + Subject directory and a tuple with the image and its data. + """ if not self._async_io: # do not pipeline, direct iteration and function call for subject in subjects: @@ -466,7 +536,15 @@ def pipeline_conform_and_save_orig( yield data -if __name__ == "__main__": +def make_parser(): + """ + Create the argparse object. + + Returns + ------- + argparse.ArgumentParser + The parser object. + """ parser = argparse.ArgumentParser(description="Evaluation metrics") # 1. Options for input directories and filenames @@ -489,21 +567,20 @@ def pipeline_conform_and_save_orig( ) # 3. Checkpoint to load + files: dict[Plane, str | Path] = {k: "default" for k in PLANES} parser = parser_defaults.add_plane_flags( parser, "checkpoint", - {"coronal": VINN_COR, "axial": VINN_AXI, "sagittal": VINN_SAG}, + files, + CHECKPOINT_PATHS_FILE ) # 4. CFG-file with default options for network parser = parser_defaults.add_plane_flags( parser, "config", - { - "coronal": "FastSurferCNN/config/FastSurferVINN_coronal.yaml", - "axial": "FastSurferCNN/config/FastSurferVINN_axial.yaml", - "sagittal": "FastSurferCNN/config/FastSurferVINN_sagittal.yaml", - }, + files, + CHECKPOINT_PATHS_FILE ) # 5. 
technical parameters @@ -520,46 +597,112 @@ def pipeline_conform_and_save_orig( "allow_root", ], ) - - args = parser.parse_args() - + return parser + +def main( + *, + orig_name: Path | str, + out_dir: Path, + pred_name: str, + ckpt_ax: Path, + ckpt_sag: Path, + ckpt_cor: Path, + cfg_ax: Path, + cfg_sag: Path, + cfg_cor: Path, + qc_log: str = "", + log_name: str = "", + allow_root: bool = False, + conf_name: str = "mri/orig.mgz", + in_dir: Optional[Path] = None, + sid: Optional[str] = None, + search_tag: Optional[str] = None, + csv_file: Optional[str | Path] = None, + lut: Optional[Path | str] = None, + remove_suffix: str = "", + brainmask_name: str = "mri/mask.mgz", + aseg_name: str = "mri/aseg.auto_noCC.mgz", + vox_size: VoxSizeOption = "min", + device: str = "auto", + viewagg_device: str = "auto", + batch_size: int = 1, + async_io: bool = True, + threads: int = -1, + conform_to_1mm_threshold: float = 0.95, + **kwargs, +) -> Literal[0] | str: # Warning if run as root user - args.allow_root or assert_no_root() + allow_root or assert_no_root() + + if len(kwargs) > 0: + LOGGER.warning(f"Unknown arguments {list(kwargs.keys())} in {__file__}:main.") qc_file_handle = None - if args.qc_log != "": + if qc_log != "": try: - qc_file_handle = open(args.qc_log, "w") + qc_file_handle = open(qc_log, "w") except NotADirectoryError: LOGGER.warning( "The directory in the provided QC log file path does not exist!" ) LOGGER.warning("The QC log file will not be saved.") - # Set up logging - from FastSurferCNN.utils.logging import setup_logging - - setup_logging(args.log_name) - # Download checkpoints if they do not exist # see utils/checkpoint.py for default paths LOGGER.info("Checking or downloading default checkpoints ...") - get_checkpoints(args.ckpt_ax, args.ckpt_cor, args.ckpt_sag) - - # Set Up Model - eval = RunModelOnData(args) - - args.copy_orig_name = os.path.join("mri", "orig", "001.mgz") - # Get all subjects of interest - subjects = SubjectList(args, segfile="pred_name", copy_orig_name="copy_orig_name") - subjects.make_subjects_dir() + + urls = load_checkpoint_config_defaults("url", filename=CHECKPOINT_PATHS_FILE) + + get_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag, urls=urls) + + config = SubjectDirectoryConfig( + orig_name=orig_name, + pred_name=pred_name, + conf_name=conf_name, + in_dir=in_dir, + csv_file=csv_file, + sid=sid, + search_tag=search_tag, + brainmask_name=brainmask_name, + remove_suffix=remove_suffix, + out_dir=out_dir, + ) + config.copy_orig_name = "mri/orig/001.mgz" + + try: + # Get all subjects of interest + subjects = SubjectList( + config, + segfile="pred_name", + copy_orig_name="copy_orig_name", + ) + subjects.make_subjects_dir() + + # Set Up Model + eval = RunModelOnData( + lut=lut, + ckpt_ax=ckpt_ax, + ckpt_sag=ckpt_sag, + ckpt_cor=ckpt_cor, + cfg_ax=cfg_ax, + cfg_sag=cfg_sag, + cfg_cor=cfg_cor, + device=device, + viewagg_device=viewagg_device, + threads=threads, + batch_size=batch_size, + vox_size=vox_size, + async_io=async_io, + conform_to_1mm_threshold=conform_to_1mm_threshold, + ) + except RuntimeError as e: + return e.args[0] qc_failed_subject_count = 0 iter_subjects = eval.pipeline_conform_and_save_orig(subjects) futures = [] for subject, (orig_img, data_array) in iter_subjects: - # Run model try: # The orig_t1_file is only used to populate verbose messages here @@ -574,26 +717,30 @@ def pipeline_conform_and_save_orig( # Create aseg and brainmask - # There is a funny edge case in legacy FastSurfer 2.0, where the behavior is not well-defined, if orig_name - # is an 
absolute path, but out_dir is not set. Then, we would create a sub-folder in the folder of orig_name - # using the subject_id (passed by --sid or extracted from the orig_name) and use that as the subject folder. + # There is a funny edge case in legacy FastSurfer 2.0, where the behavior is + # not well-defined, if orig_name is an absolute path, but out_dir is not + # set. Then, we would create a sub-folder in the folder of orig_name using + # the subject_id (passed by --sid or extracted from the orig_name) and use + # that as the subject folder. bm = None - store_brainmask = subject.can_resolve_filename(args.brainmask_name) - store_aseg = subject.can_resolve_filename(args.aseg_name) + store_brainmask = subject.can_resolve_filename(brainmask_name) + store_aseg = subject.can_resolve_filename(aseg_name) if store_brainmask or store_aseg: LOGGER.info("Creating brainmask based on segmentation...") bm = rta.create_mask(copy.deepcopy(pred_data), 5, 4) if store_brainmask: # get mask - mask_name = subject.filename_in_subject_folder(args.brainmask_name) + mask_name = subject.filename_in_subject_folder(brainmask_name) futures.append( eval.async_save_img(mask_name, bm, orig_img, dtype=np.uint8) ) else: LOGGER.info( - "Not saving the brainmask, because we could not figure out where to store it. Please " - "specify a subject id with {sid[flag]}, or an absolute brainmask path with " - "{brainmask_name[flag]}.".format(**subjects.flags) + "Not saving the brainmask, because we could not figure out where " + "to store it. Please specify a subject id with {sid[flag]}, or an " + "absolute brainmask path with {brainmask_name[flag]}.".format( + **subjects.flags, + ) ) if store_aseg: @@ -602,24 +749,27 @@ def pipeline_conform_and_save_orig( aseg = rta.reduce_to_aseg(pred_data) aseg[bm == 0] = 0 aseg = rta.flip_wm_islands(aseg) - aseg_name = subject.filename_in_subject_folder(args.aseg_name) + aseg_name = subject.filename_in_subject_folder(aseg_name) # Change datatype to np.uint8, else mri_cc will fail! futures.append( eval.async_save_img(aseg_name, aseg, orig_img, dtype=np.uint8) ) else: LOGGER.info( - "Not saving the aseg file, because we could not figure out where to store it. Please " - "specify a subject id with {sid[flag]}, or an absolute aseg path with " - "{aseg_name[flag]}.".format(**subjects.flags) + "Not saving the aseg file, because we could not figure out where " + "to store it. Please specify a subject id with {sid[flag]}, or an " + "absolute aseg path with {aseg_name[flag]}.".format( + **subjects.flags, + ) ) # Run QC check LOGGER.info("Running volume-based QC check on segmentation...") - seg_voxvol = np.product(orig_img.header.get_zooms()) + seg_voxvol = np.prod(orig_img.header.get_zooms()) if not check_volume(pred_data, seg_voxvol): LOGGER.warning( - "Total segmentation volume is too small. Segmentation may be corrupted." + "Total segmentation volume is too small. Segmentation may be " + "corrupted." 
) if qc_file_handle is not None: qc_file_handle.write(subject.id + "\n") @@ -627,7 +777,7 @@ def pipeline_conform_and_save_orig( qc_failed_subject_count += 1 except RuntimeError as e: if not handle_cuda_memory_exception(e): - raise e + return e.args[0] if qc_file_handle is not None: qc_file_handle.close() @@ -635,13 +785,22 @@ def pipeline_conform_and_save_orig( # Batch case: report ratio of QC warnings if len(subjects) > 1: LOGGER.info( - "Segmentations from {} out of {} processed cases failed the volume-based QC check.".format( - qc_failed_subject_count, len(subjects) - ) + f"Segmentations from {qc_failed_subject_count} out of {len(subjects)} " + f"processed cases failed the volume-based QC check." ) # wait for async processes to finish for f in futures: _ = f.result() + return 0 + + +if __name__ == "__main__": + parser = make_parser() + _args = parser.parse_args() + + # Set up logging + from FastSurferCNN.utils.logging import setup_logging + setup_logging(_args.log_name) - sys.exit(0) + sys.exit(main(**vars(_args))) diff --git a/FastSurferCNN/segstats.py b/FastSurferCNN/segstats.py index d3813914..0be39ef2 100644 --- a/FastSurferCNN/segstats.py +++ b/FastSurferCNN/segstats.py @@ -1,3 +1,5 @@ +#!/bin/python + # Copyright 2022 Image Analysis Lab, German Center for Neurodegenerative Diseases(DZNE), Bonn # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,25 +21,47 @@ from functools import partial, reduce from itertools import product from numbers import Number -from typing import Sequence, Tuple, Union, Optional, Dict, overload, cast, TypeVar, List, Iterable, Callable +from pathlib import Path +from typing import ( + Any, + Callable, + cast, + Iterable, + IO, + Literal, + Optional, + overload, + Sequence, + Sized, + Type, + TypedDict, + TypeVar, + Container, + Iterator, +) +from concurrent.futures import Executor, ThreadPoolExecutor + -import nibabel as nib import numpy as np import pandas as pd from numpy import typing as npt -from FastSurferCNN.utils.threads import get_num_threads +from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as robust_threshold +from FastSurferCNN.utils.arg_types import int_ge_zero as id_type +from FastSurferCNN.utils.arg_types import int_gt_zero as patch_size_type from FastSurferCNN.utils.parser_defaults import add_arguments -from FastSurferCNN.utils.arg_types import (int_gt_zero as patch_size, int_ge_zero as id_type, - float_gt_zero_and_le_one as robust_threshold) - -USAGE = "python seg_stats.py -norm -i -o [optional arguments]" -DESCRIPTION = "Script to calculate partial volumes and other segmentation statistics of a segmentation file." 
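# A minimal sketch (not part of the diff): with the keyword-only main() that the
# run_prediction refactor above introduces, the segmentation entry point can be
# called directly from Python instead of only via parser.parse_args(). The module
# path and all file/checkpoint paths below are assumptions for illustration only.
from pathlib import Path

from FastSurferCNN.run_prediction import main  # assumed module location

ret = main(
    orig_name=Path("/data/subject1/T1.nii.gz"),
    out_dir=Path("/output"),
    sid="subject1",
    pred_name="mri/aparc.DKTatlas+aseg.deep.mgz",
    ckpt_cor=Path("checkpoints/coronal.pkl"),   # placeholder checkpoint paths
    ckpt_sag=Path("checkpoints/sagittal.pkl"),
    ckpt_ax=Path("checkpoints/axial.pkl"),
    cfg_cor=Path("FastSurferCNN/config/FastSurferVINN_coronal.yaml"),
    cfg_sag=Path("FastSurferCNN/config/FastSurferVINN_sagittal.yaml"),
    cfg_ax=Path("FastSurferCNN/config/FastSurferVINN_axial.yaml"),
    vox_size="min",
    batch_size=1,
)
# main() returns 0 on success or an error-message string, matching the
# sys.exit(main(**vars(_args))) call in the __main__ block above.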
+from FastSurferCNN.utils.threads import get_num_threads -HELPTEXT = """ +# Constants +USAGE = ("python segstats.py (-norm|-pv) -i " + "-o [optional arguments] [{measures,mri_segstats} ...]") +DESCRIPTION = ("Script to calculate partial volumes and other segmentation statistics " + "of a segmentation file.") +VERSION = "1.1" +HELPTEXT = f""" Dependencies: - Python 3.8 + Python 3.10 Numpy http://www.numpy.org @@ -50,346 +74,1141 @@ Original Author: David Kügler Date: Dec-30-2022 -Modified: May-08-2023 +Modified: Dec-07-2023 + +Revision: {VERSION} """ +FILTER_SIZES = (3, 15) +COLUMNS = ["Index", "SegId", "NVoxels", "Volume_mm3", "StructName", "Mean", "StdDev", + "Min", "Max", "Range"] -_NumberType = TypeVar('_NumberType', bound=Number) +# Type definitions +_NumberType = TypeVar("_NumberType", bound=Number) _IntType = TypeVar("_IntType", bound=np.integer) -_DType = TypeVar('_DType', bound=np.dtype) +_DType = TypeVar("_DType", bound=np.dtype) _ArrayType = TypeVar("_ArrayType", bound=np.ndarray) -PVStats = Dict[str, Union[int, float]] -VirtualLabel = Dict[int, Sequence[int]] +SlicingTuple = tuple[slice, ...] +SlicingSequence = Sequence[slice] +VirtualLabel = dict[int, Sequence[int]] +_GlobalStats = tuple[int, int, Optional[_NumberType], Optional[_NumberType], + Optional[float], Optional[float], float, npt.NDArray[bool]] +SubparserCallback = Type[argparse.ArgumentParser.add_subparsers] -FILTER_SIZES = (3, 15) -UNITS = {"Volume_mm3": "mm^3", "normMean": "MR", "normStdDev": "MR", "normMin": "MR", "normMax": "MR", - "normRange": "MR"} -FIELDS = {"Index": "Index", "SegId": "Segmentation Id", "NVoxels": "Number of Voxels", "Volume_mm3": "Volume", - "StructName": "Structure Name", "normMean": "Intensity normMean", "normStdDev": "Intensity normStdDev", - "normMin": "Intensity normMin", "normMax": "Intensity normMax", "normRange": "Intensity normRange"} -FORMATS = {"Index": "d", "SegId": "d", "NVoxels": "d", "Volume_mm3": ".3f", "StructName": "s", "normMean": ".4f", - "normStdDev": ".4f", "normMin": ".4f", "normMax": ".4f", "normRange": ".4f"} +class _RequiredPVStats(TypedDict): + SegId: int + NVoxels: int + Volume_mm3: float + + +class _OptionalPVStats(TypedDict, total=False): + StructName: str + Mean: float + StdDev: float + Min: float + Max: float + Range: float + + +class PVStats(_RequiredPVStats, _OptionalPVStats): + """Dictionary of volume statistics for partial volume evaluation and global stats""" + pass class HelpFormatter(argparse.HelpFormatter): - """Help formatter that forces line breaks in texts where the text is
.""" + """ + Help formatter that forces line breaks in texts where the text is
. + """ def _linebreak_sub(self): + """ + Get the linebreak substitution string. + + Returns + ------- + str + The linebreak substitution string ("
"). + """ return getattr(self, "linebreak_sub", "
") - def _fill_text(self, text, width, indent): - texts = text.split(self._linebreak_sub()) - return "\n".join([super(HelpFormatter, self)._fill_text(tex, width, indent) for tex in texts]) + def _item_symbol(self): + return getattr(self, "item_symbol", "- ") - def _split_lines(self, text: str, width: int): + def _fill_text(self, text: str, width: int, indent: str) -> str: + """ + Fill text with line breaks based on the linebreak substitution string. + + Parameters + ---------- + text : str + The input text. + width : int + The width for filling the text. + indent : int + The indentation level. + + Returns + ------- + str + The formatted text with line breaks. + """ + cond_len, texts = self._itemized_lines(text) + lines = (super(HelpFormatter, self)._fill_text(t[p:], width, indent + " " * p) + for t, (c, p) in zip(texts, cond_len)) + return "\n".join("- " + t[p:] if c else t for t, (c, p) in zip(lines, cond_len)) + + def _itemized_lines(self, text): texts = text.split(self._linebreak_sub()) + item = self._item_symbol() + il = len(item) + cond_len = [(c, il if c else 0) for c in map(lambda t: t[:il] == item, texts)] + texts = [t[p:] for t, (c, p) in zip(texts, cond_len)] + return cond_len, texts + + def _split_lines(self, text: str, width: int) -> list[str]: + """ + Split lines in the text based on the linebreak substitution string. + + Parameters + ---------- + text : str + The input text. + width : int + The width for splitting lines. + + Returns + ------- + list[str] + The list of lines. + """ + def indent_list(items: list[str]) -> list[str]: + return ["- " + items[0]] + [" " + l for l in items[1:]] + + cond_len, texts = self._itemized_lines(text) from itertools import chain - return list(chain.from_iterable(super(HelpFormatter, self)._split_lines(tex, width) for tex in texts)) - - -def make_arguments() -> argparse.ArgumentParser: - """[MISSING].""" - parser = argparse.ArgumentParser(usage=USAGE, epilog=HELPTEXT.replace("\n", "
"), description=DESCRIPTION, - formatter_class=HelpFormatter) - parser.add_argument('-norm', '--normfile', type=str, required=True, dest='normfile', - help="Biasfield-corrected image in the same image space as segmentation (required).") - parser.add_argument('-i', '--segfile', type=str, dest='segfile', required=True, - help="Segmentation file to read and use for evaluation (required).") - parser.add_argument('-o', '--segstatsfile', type=str, required=True, dest='segstatsfile', - help="Path to output segstats file.") - - parser.add_argument('--excludeid', type=id_type, nargs="*", default=[0], - help="List of segmentation ids (integers) to exclude in analysis, " - "e.g. `--excludeid 0 1 10` (default: 0).") - parser.add_argument('--ids', type=id_type, nargs="*", - help="List of exclusive segmentation ids (integers) to use " - "(default: all ids in --lut or all ids in image).") - parser.add_argument('--merged_label', type=id_type, nargs="+", dest='merged_labels', default=[], action='append', - help="Add a 'virtual' label (first value) that is the combination of all following values, " - "e.g. `--merged_label 100 3 4 8` will compute the statistics for label 100 by aggregating " - "labels 3, 4 and 8.") - parser.add_argument('--robust', type=robust_threshold, dest='robust', default=None, - help="Whether to calculate robust segmentation metrics. This parameter " - "expects the fraction of values to keep, e.g. `--robust 0.95` will " - "ignore the 2.5%% smallest and the 2.5%% largest values in the " - "segmentation when calculating the statistics (default: no robust " - "statistics == `--robust 1.0`).") - advanced = parser.add_argument_group(title="Advanced options") - advanced.add_argument('--threads', dest='threads', default=get_num_threads(), type=int, - help=f"Number of threads to use (defaults to number of hardware threads: " - f"{get_num_threads()})") - advanced.add_argument('--patch_size', type=patch_size, dest='patch_size', default=32, - help="Patch size to use in calculating the partial volumes (default: 32).") - advanced.add_argument('--empty', action='store_true', dest='empty', - help="Keep ids for the table that do not exist in the segmentation (default: drop).") - advanced = add_arguments(advanced, ['device', 'lut', 'sid', 'in_dir', 'allow_root']) - advanced.add_argument('--legacy_freesurfer', action='store_true', dest='legacy_freesurfer', - help="Reproduce FreeSurfer mri_segstats numbers (default: off). \n" - "Please note, that exact agreement of numbers cannot be guaranteed, because the " - "condition number of FreeSurfers algorithm (mri_segstats) combined with the fact that " - "mri_segstats uses 'float' to measure the partial volume corrected volume. This yields " - "differences of more than 60mm3 or 0.1%% in large structures. This uniquely impacts " - "highres images with more voxels (on the boundry) and smaller voxel sizes (volume per " - "voxel).") + lines = (super(HelpFormatter, self)._split_lines(tex, width - p) + for tex, (c, p) in zip(texts, cond_len)) + lines = ((indent_list(lst) if c[0] else lst) for lst, c in zip(lines, cond_len)) + return list(chain.from_iterable(lines)) + + +def make_arguments(helpformatter: bool = False) -> argparse.ArgumentParser: + """ + Create an argument parser object with all parameters of the script. + + Returns + ------- + argparse.ArgumentParser + The configured argument parser. + """ + import sys + if helpformatter: + kwargs = { + "epilog": HELPTEXT.replace("\n", "
"), + "formatter_class": HelpFormatter, + } + else: + kwargs = {"epilog": HELPTEXT} + parser = argparse.ArgumentParser( + usage=USAGE, + description=DESCRIPTION, + add_help=False, + **kwargs, + ) + add_two_help_messages(parser) + parser.add_argument( + "--pvfile", + "-pv", + type=Path, + dest="pvfile", + help="Path to image used to compute the partial volume effects (default: the " + "file passed as normfile). This file is required, either directly or " + "indirectly via normfile.", + ) + parser.add_argument( + "-norm", + "--normfile", + type=Path, + dest="normfile", + help="Path to biasfield-corrected image (the same image space as " + "segmentation). This file is used to calculate intensity values. Also, if " + "no pvfile is defined, it is used as pvfile. One of normfile or pvfile is " + "required.", + ) + parser.add_argument( + "-i", + "--segfile", + type=Path, + dest="segfile", + required=True, + help="Segmentation file to read and use for evaluation (required).", + ) + parser.add_argument( + "-o", + "--segstatsfile", + type=Path, + required=True, + dest="segstatsfile", + help="Path to output segstats file.", + ) + + parser.add_argument( + "--excludeid", + type=id_type, + nargs="*", + default=[], + help="List of segmentation ids (integers) to exclude in analysis, " + "e.g. `--excludeid 0 1 10` (default: None).", + ) + parser.add_argument( + "--ids", + type=id_type, + nargs="*", + help="List of exclusive segmentation ids (integers) to use " + "(default: all ids in --lut or all ids in image).", + ) + parser.add_argument( + "--merged_label", + type=id_type, + nargs="+", + dest="merged_labels", + default=[], + action="append", + help="Add a 'virtual' label (first value) that is the combination of all " + "following values, e.g. `--merged_label 100 3 4 8` will compute the " + "statistics for label 100 by aggregating labels 3, 4 and 8.", + ) + parser.add_argument( + "--robust", + type=robust_threshold, + dest="robust", + default=None, + help="Whether to calculate robust segmentation metrics. This parameter " + "expects the fraction of values to keep, e.g. `--robust 0.95` will " + "ignore the 2.5%% smallest and the 2.5%% largest values in the " + "segmentation when calculating the statistics (default: no robust " + "statistics == `--robust 1.0`).", + ) + parser.add_argument( + "--measure_only", + action="store_true", + dest="measure_only", + help="Only calculate the Measures in the header, no PV table." 
+ ) + subparsers = parser.add_subparsers(title="Suboptions", dest="subparser") + add_measure_parser(subparsers.add_parser) + advanced = parser.add_argument_group(title="Advanced options (not shown in -h)") + if "-h" in sys.argv: + return parser + advanced.add_argument( + "--threads", + dest="threads", + default=get_num_threads(), + type=int, + help=f"Number of threads to use (defaults to number of hardware threads: " + f"{get_num_threads()})", + ) + advanced.add_argument( + "--patch_size", + type=patch_size_type, + dest="patch_size", + default=32, + help="Patch size to use in calculating the partial volumes (default: 32).", + ) + advanced.add_argument( + "--empty", + action="store_true", + dest="empty", + help="Keep ids for the table that do not exist in the segmentation " + "(default: drop).", + ) + advanced = add_arguments(advanced, ["device", "sid", "sd", "allow_root"]) + advanced.add_argument( + "--lut", + type=Path, + metavar="lut", + dest="lut", + help="Path and name of LUT to use.", + ) + advanced.add_argument( + "--legacy_freesurfer", + action="store_true", + dest="legacy_freesurfer", + help="Reproduce FreeSurfer mri_segstats numbers (default: off). \n" + "Please note, that exact agreement of numbers cannot be guaranteed, " + "because the condition number of FreeSurfers algorithm (mri_segstats) " + "combined with the fact that mri_segstats uses 'float' to measure the " + "partial volume corrected volume. This yields differences of more than " + "60mm3 or 0.1%% in large structures. This uniquely impacts highres images " + "with more voxels (on the boundary) and smaller voxel sizes (volume per " + "voxel).", + ) # Additional info: - # Changing the data type in mri_segstats to double can reduce this difference to nearly zero. + # Changing the data type in mri_segstats to double can reduce this difference to + # nearly zero. # mri_segstats has two operations affecting a bad condition number: # 1. pv = (val - mean_nbr) / (mean_label - mean_nbr) # 2. volume += vox_vol * pv - # This is further affected by the small vox_vol (volume per voxel) of highres images (0.7iso -> 0.343) - # Their effects stack and can result in differences of more than 60mm3 or 0.1% in a comparison between double and - # single-precision evaluations. - advanced.add_argument('--mixing_coeff', type=str, dest='mix_coeff', default='', - help="Save the mixing coefficients (default: off).") - advanced.add_argument('--alternate_labels', type=str, dest='nbr', default='', - help="Save the alternate labels (default: off).") - advanced.add_argument('--alternate_mixing_coeff', type=str, dest='nbr_mix_coeff', default='', - help="Save the alternate labels' mixing coefficients (default: off).") - advanced.add_argument('--seg_means', type=str, dest='seg_means', default='', - help="Save the segmentation labels' means (default: off).") - advanced.add_argument('--alternate_means', type=str, dest='nbr_means', default='', - help="Save the alternate labels' means (default: off).") - advanced.add_argument('--volume_precision', type=id_type, dest='volume_precision', default=None, - help="Number of digits after dot in summary stats file (default: 3). Note, " - "--legacy_freesurfer sets this to 1.") + # This is further affected by the small vox_vol (volume per voxel) of highres + # images (0.7iso -> 0.343) + # Their effects stack and can result in differences of more than 60mm3 or 0.1% in + # a comparison between double and single-precision evaluations. 
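# A minimal sketch (not part of the diff): how the extended segstats parser,
# including the new "measures" subcommand registered by add_measure_parser(),
# might be driven programmatically. All file paths are placeholders.
from FastSurferCNN.segstats import make_arguments

parser = make_arguments()
args = parser.parse_args([
    "-i", "/out/subject1/mri/aseg.deep.mgz",
    "-norm", "/out/subject1/mri/norm.mgz",
    "-o", "/out/subject1/stats/aseg.deep.stats",
    "measures",                                  # subcommand
    "--compute", "TotalGray", "SubCortGray",
    "--import", "EstimatedTotalIntraCranialVol",
])
# args.measures now holds (False, name) tuples for computed measures and
# (True, name) tuples for imported ones, as produced by the type= converters
# defined in add_measure_parser().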
+ advanced.add_argument( + "--mixing_coeff", + type=Path, + dest="mix_coeff", + default="", + help="Save the mixing coefficients (default: off).", + ) + advanced.add_argument( + "--alternate_labels", + type=Path, + dest="nbr", + default="", + help="Save the alternate labels (default: off).", + ) + advanced.add_argument( + "--alternate_mixing_coeff", + type=Path, + dest="nbr_mix_coeff", + default="", + help="Save the alternate labels' mixing coefficients (default: off).", + ) + advanced.add_argument( + "--seg_means", + type=Path, + dest="seg_means", + default="", + help="Save the segmentation labels' means (default: off).", + ) + advanced.add_argument( + "--alternate_means", + type=Path, + dest="nbr_means", + default="", + help="Save the alternate labels' means (default: off).", + ) + advanced.add_argument( + "--volume_precision", + type=id_type, + dest="volume_precision", + default=3, + help="Number of digits after dot in summary stats file (default: 3). Use 1 for " + "maximum FreeSurfer compatibility).", + ) + advanced.add_argument( + "--norm_name", + type=str, + dest="norm_name", + default="norm", + help="Option to change the name of the in volume (default: norm)." + ) + advanced.add_argument( + "--norm_unit", + type=str, + dest="norm_unit", + default="MR", + help="Option to change the unit of the in volume (default: MR)." + ) return parser -def loadfile_full(file: str, name: str) \ - -> Tuple[nib.analyze.SpatialImage, np.ndarray]: - """Load full image and data. +def empty(__arg: Any) -> bool: + """ + Checks if the argument is an empty list (or None). + """ + return __arg is None or (isinstance(__arg, Sized) and len(__arg) == 0) + + +def add_measure_parser(subparser_callback: SubparserCallback) -> None: + """ + Add a parser that supports adding measures to the parameters. + """ + measure_parser = subparser_callback( + "measures", + usage="python segstats.py (...) measures [optional arguments]", + argument_default="measures", + help="Configures options to measures", + description="Options to configure measures", + formatter_class=HelpFormatter, + add_help=False, + ) + add_two_help_messages(measure_parser) + + def __add_computed_measure(x: str) -> tuple[bool, str]: + return False, x + measure_parser.add_argument( + "--compute", + type=__add_computed_measure, + nargs="+", + action="extend", + default=[], + dest="measures", + help="Additional Measures to compute based on imported/computed measures:
" + "Cortex, CerebralWhiteMatter, SubCortGray, TotalGray, " + "BrainSegVol-to-eTIV, MaskVol-to-eTIV, SurfaceHoles, " + "EstimatedTotalIntraCranialVol", + ) + + def __add_imported_measure(x: str) -> tuple[bool, str]: + return True, x + measure_parser.add_argument( + '--import', + type=__add_imported_measure, + nargs="+", + action="extend", + default=[], + dest="measures", + help="Additional Measures to import from the measurefile.
" + "Example measures ('all' to import all measures in the measurefile):
" + "BrainSeg, BrainSegNotVent, SupraTentorial, SupraTentorialNotVent, " + "SubCortGray, lhCortex, rhCortex, Cortex, TotalGray, " + "lhCerebralWhiteMatter, rhCerebralWhiteMatter, CerebralWhiteMatter, Mask, " + "SupraTentorialNotVentVox, BrainSegNotVentSurf, VentricleChoroidVol, " + "BrainSegVol-to-eTIV, MaskVol-to-eTIV, lhSurfaceHoles, rhSurfaceHoles, " + "SurfaceHoles, EstimatedTotalIntraCranialVol
" + "Note, 'all' will always be overwritten by any explicitly mentioned " + "measures.", + ) + measure_parser.add_argument( + "--file", + type=Path, + dest="measurefile", + default="brainvol.stats", + help="Default file to read measures (--import ...) from. If the path is " + "relative, it is interpreted as relative to subjects_dir/subject_id from" + "--sd and --subject_id.", + ) + measure_parser.add_argument( + "--from_seg", + type=Path, + dest="aseg_replace", + default=None, + help="Replace the default segfile to compute measures from by -i/--segfile. " + "This will default to 'mri/aseg.mgz' for --legacy_freesurfer and to the " + "value of -i/--segfile otherwise." + ) + + +def add_two_help_messages(parser: argparse.ArgumentParser) -> None: + """ + Adds separate help flags -h and --help to the parser for simple and detailed help. + Both trigger the help action. Parameters ---------- - file : str - filename - name : - Subject name - + parser : argparse.ArgumentParser + Parser to add the flags to. """ - try: - img = nib.load(file) - except (IOError, FileNotFoundError) as e: - raise IOError(f"Failed loading the {name} '{file}' with error: {e.args[0]}") from e - data = np.asarray(img.dataobj) - return img, data + def this_msg(msg: str, flag: str) -> str: + import sys + return f"{msg} (this message)" if flag in sys.argv else msg + parser.add_argument( + "-h", action="help", + help=this_msg("show a short help message and exit", "-h")) + parser.add_argument( + "--help", action="help", + help=this_msg("show a long, detailed help message and exit", "--help")) + + +def _check_arg_path( + __args: argparse.Namespace, + __attr: str, + subjects_dir: Path | None, + subject_id: str | None, + allow_subject_dir: bool = True, + require_exist: bool = True, +) -> Path: + """ + Check an argument that is supposed to be a Path object and finding the absolute + path, which can be derived from the subject_dir. + Parameters + ---------- + __args : argparse.Namespace + The arguments object. + __attr: str + The name of the attribute in the Namespace object. + allow_subject_dir : bool, optional + Whether relative paths are supposed to be understood with respect to + subjects_dir / subject_id (default: True). + require_exist : bool, optional + Raise a ValueError, if the indicated file does not exist (default: True). -def main(args): - """[MISSING]. + Returns + ------- + Path + The resulting Path object. + + Raises + ------ + ValueError + If attribute does not exist, is not a Path (or convertible to a Path), or if + the file does not exist, but reuire_exist is True. + """ + if (_attr_val := getattr(__args, __attr), None) is None: + raise ValueError(f"No {__attr} passed.") + if isinstance(_attr_val, str): + _attr_val = Path(_attr_val) + elif not isinstance(_attr_val, Path): + raise ValueError(f"{_attr_val} is not a Path object.") + if allow_subject_dir and not _attr_val.is_absolute(): + if isinstance(subjects_dir, Path) and subject_id is not None: + _attr_val = subjects_dir / subject_id / _attr_val + if require_exist and not _attr_val.exists(): + raise ValueError(f"Path {_attr_val} did not exist for {__attr}.") + return _attr_val + + +def _check_arg_defined(attr: str, /, args: argparse.Namespace) -> bool: + """ + Check whether the attribute attr is defined in args. Parameters ---------- - args : - [MISSING] + attr: str + The name of the attribute. + args: argparse.Namespace + The argument container object. Returns ------- - [MISSING] - + bool + Whether the argument is defined (not None, not an empty container/str). 
""" - import os - import time - start = time.perf_counter_ns() - from FastSurferCNN.utils.common import assert_no_root - getattr(args, "allow_root", False) or assert_no_root() + value = getattr(args, attr, None) + return not (value is None or empty(value)) - if not hasattr(args, 'segfile') or not os.path.exists(args.segfile): - return "No segfile was passed or it does not exist." - if not hasattr(args, 'normfile') or not os.path.exists(args.normfile): - return "No normfile was passed or it does not exist." - if not hasattr(args, 'segstatsfile'): - return "No segstats file was passed" - threads = args.threads - if threads <= 0: - threads = get_num_threads() +def check_shape_affine( + img1: "nib.analyze.SpatialImage", + img2: "nib.analyze.SpatialImage", + name1: str, + name2: str, +) -> None: + """ + Check whether the shape and affine of - from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor(threads) as tpe: - # load these files in different threads to avoid waiting on IO (not parallel due to GIL though) - seg_future = tpe.submit(loadfile_full, args.segfile, 'segfile') - norm_future = tpe.submit(loadfile_full, args.normfile, 'normfile') + Parameters + ---------- + img1 : nibabel.SpatialImage + Image 1. + img2 : nibabel.SpatialImage + Image 2. + name1 : str + Name of image 1. + name2 : str + Name of image 2. + + Raises + ------ + RuntimeError + If shapes or affines are not the same. + """ + if img1.shape != img2.shape or not np.allclose(img1.affine, img2.affine): + raise RuntimeError( + f"The shapes or affines of the {name1} and the {name2} image are not " + f"similar, both must be the same!" + ) + + +def parse_files( + args: argparse.Namespace, + subjects_dir: Path | str | None = None, + subject_id: str | None = None, + require_measurefile: bool = False, + require_pvfile: bool = True, +) -> tuple[Path, Path | None, Path | None, Path, Path | None]: + """ + Parse and read paths of files. - if hasattr(args, 'lut') and args.lut is not None: - try: - lut = read_classes_from_lut(args.lut) - except FileNotFoundError as e: - return f"Could not find the ColorLUT in {args.lut}, please make sure the --lut argument is valid." + Parameters + ---------- + args : argparse.Namespace + Parameters object from make_arguments. + subjects_dir : Path, str, optional + Path to SUBJECTS_DIR, where subject directories are. + subject_id : str, optional + The subject_id string. + require_measurefile : bool, default=False + Require the measurefile to exist. + require_pvfile : bool, default=True + Require a pvfile or normfile to exist. + + Returns + ------- + segfile : Path + Path to the segmentation file, most likely an absolute path. + pvfile : Path, None + Path to the pvfile file, most likely an absolute path. + normfile : Path, None + Path to the norm file, most likely an absolute path, or None if not passed. + segstatsfile : Path + Path to the output segstats file, most likely an absolute path. + measurefile : Path, None + Path to the measure file, most likely an absolute path, not None is not passed. + + Raises + ------ + ValueError + If there is a necessary parameter missing or invalid. 
+ """ + if subjects_dir is not None: + subjects_dir = Path(subjects_dir) + check_arg_path = partial( + _check_arg_path, subjects_dir=subjects_dir, subject_id=subject_id + ) + segfile = check_arg_path(args, "segfile") + not_has_arg = partial(_check_arg_defined, args=args) + if not any(map(not_has_arg, ("normfile", "pvfile"))): + if require_pvfile: + raise ValueError("Either pvfile or normfile are required.") + pvfile = None + normfile = None + elif getattr(args, "normfile", None) is None: + pvfile = check_arg_path(args, "pvfile") + normfile = None + else: + normfile = check_arg_path(args, "normfile") + if getattr(args, "pvfile", None) is None: + pvfile = normfile else: - lut = None - try: - seg, seg_data = seg_future.result() # type: nib.analyze.SpatialImage, Union[np.ndarray, torch.IntTensor] - norm, norm_data = norm_future.result() # type: nib.analyze.SpatialImage, Union[np.ndarray, torch.Tensor] + pvfile = check_arg_path(args, "pvfile") - if seg_data.shape != norm_data.shape or not np.allclose(seg.affine, norm.affine): - return "The shapes or affines of the segmentation and the norm image are not similar, both must be " \ - "the same!" + segstatsfile = check_arg_path(args, "segstatsfile", require_exist=False) + if not segstatsfile.is_absolute(): + raise ValueError("segstatsfile must be an absolute path!") - except IOError as e: - return e.args[0] + if (measurefile := getattr(args, "measurefile", None)) is not None: + measurefile = check_arg_path( + args, + "measurefile", + require_exist=require_measurefile, + ) + + return segfile, pvfile, normfile, segstatsfile, measurefile + + +def infer_labels_excludeid( + args: argparse.Namespace, + lut: "pd.DataFrame", + data: "npt.NDArray[int]", +) -> tuple["npt.NDArray[int]", list[int]]: + """ + Infer the labels and excluded ids from command line arguments, the lookup table, or + the segmentation image. + + Parameters + ---------- + args : argparse.Namespace + The commandline arguments object. + lut : pd.DataFrame + The ColorLUT lookup table object, e.g. FreeSurferColorLUT. + data : npt.NDArray[int] + The segmentation array. + + Returns + ------- + labels : npt.NDArray[int] + The array of all labels to calculate partial volumes for. + exclude_id : list[int] + A list of labels exlicitly excluded from the output table. 
+ """ explicit_ids = False - if hasattr(args, 'ids') and args.ids is not None and len(args.ids) > 0: - labels = np.asarray(args.ids) + if __ids := getattr(args, "ids", None): + labels = np.asarray(__ids) explicit_ids = True elif lut is not None: - labels = lut['ID'] # the column ID contains all ids + labels = lut["ID"] # the column ID contains all ids else: - labels = np.unique(seg_data) + labels = np.unique(data) - if hasattr(args, 'excludeid') and args.excludeid is not None and len(args.excludeid) > 0: - exclude_id = list(args.excludeid) + # filter for excludeid entries + exclude_id = [] + if _excl_id := getattr(args, "excludeid", None): + exclude_id = list(_excl_id) + # check whether if explicit_ids: - excluded_expl_ids = np.asarray(list(filter(lambda x: x in exclude_id, labels))) + _exclude = list(filter(lambda x: x in exclude_id, labels)) + excluded_expl_ids = np.asarray(_exclude) if excluded_expl_ids.size > 0: - return "Some IDs explicitly passed via --ids are also in the list of ids to exclude (--excludeid)" - labels = np.asarray(list(filter(lambda x: x not in exclude_id, labels))) - else: - exclude_id = [] - - kwargs = { - "vox_vol": np.prod(seg.header.get_zooms()).item(), - "robust_percentage": getattr(args, 'robust', None), - "threads": threads, - "legacy_freesurfer": bool(getattr(args, 'legacy_freesurfer', False)), - "patch_size": args.patch_size - } + raise ValueError( + "Some IDs explicitly passed via --ids are also in the list of " + "ids to exclude (--excludeid)." + ) + labels = np.asarray([x for x in labels if x not in exclude_id], dtype=int) + return labels, exclude_id - if getattr(args, 'volume_precision', None) is not None: - FORMATS['Volume_mm3'] = f'.{getattr(args, "volume_precision"):d}f' - elif kwargs["legacy_freesurfer"]: - FORMATS['Volume_mm3'] = f'.1f' - if args.merged_labels is not None and len(args.merged_labels) > 0: - kwargs["merged_labels"] = {lab: vals for lab, *vals in args.merged_labels} +def main(args: argparse.Namespace) -> Literal[0] | str: + """ + Main segstats function, based on mri_segstats. - names = ['nbr', 'nbr_means', 'seg_means', 'mix_coeff', 'nbr_mix_coeff'] - var_names = ['nbr', 'nbrmean', 'segmean', 'pv', 'ipv'] - dtypes = [np.int16] + [np.float32] * 4 - if any(getattr(args, n, '') != '' for n in names): - table, maps = pv_calc(seg_data, norm_data, labels, return_maps=True, **kwargs) + Parameters + ---------- + args : object + Parameter object as defined by `make_arguments().parse_args()`. - for n, v, dtype in zip(names, var_names, dtypes): - file = getattr(args, n, '') - if file == '': - continue + Returns + ------- + Literal[0], str + Either as a successful return code or a string with an error message. 
+ """ + from time import perf_counter_ns + from FastSurferCNN.utils.common import assert_no_root + from FastSurferCNN.utils.brainvolstats import Manager, read_volume_file, ImageTuple + from FastSurferCNN.data_loader.data_utils import read_classes_from_lut + + start = perf_counter_ns() + getattr(args, "allow_root", False) or assert_no_root() + + subjects_dir = getattr(args, "out_dir", None) + if subjects_dir is not None: + subjects_dir = Path(subjects_dir) + subject_id = str(getattr(args, "sid", None)) + legacy_freesurfer = bool(getattr(args, "legacy_freesurfer", False)) + measure_only = bool(getattr(args, "measure_only", False)) + manager_kwargs = {} + + # Check filename parameters segfile, pvfile, normfile, segstatsfile, and measurefile + try: + # individual entries are: (is_this_imported, the_name_and_parameters) + measures: list[tuple[bool, str]] = getattr(args, "measures", []) + any_imported_measure = any(filter(lambda x: x[0], measures)) + segfile, pvfile, normfile, segstatsfile, measurefile = parse_files( + args, + subjects_dir, + subject_id, + require_measurefile=any_imported_measure, + require_pvfile=not legacy_freesurfer, + ) + if legacy_freesurfer and not measure_only and pvfile is None: + return (f"No files are defined via -pv/--pvfile or -norm/--normfile: " + f"This is only supported for header only in legacy mode.") + if measurefile: + manager_kwargs["measurefile"] = measurefile + except ValueError as e: + return e.args[0] + + threads = getattr(args, "threads", 0) + if threads <= 0: + threads = get_num_threads() + + compute_threads = ThreadPoolExecutor(threads) + + # the manager object supports preloading of files (see below) for io parallelization + # and calculates the measure + manager = Manager(measures, segfile=segfile, **manager_kwargs) + read_lut = manager.make_read_hook(read_classes_from_lut) + if lut_file := getattr(args, "lut", None): + read_lut(lut_file, blocking=False) + # load these files in different threads to avoid waiting on IO + # (not parallel due to GIL though) + load_image = manager.make_read_hook(read_volume_file) + preload_image = partial(load_image, blocking=False) + preload_image(segfile) + if normfile is not None: + preload_image(normfile) + needs_pv_calc = manager.needs_pv_calculation() or not measure_only + if needs_pv_calc: + preload_image(pvfile) + + with manager.with_subject(subjects_dir, subject_id): + try: + _seg: ImageTuple = load_image(segfile, blocking=True) + seg, seg_data = _seg + pv_img, pv_data = None, None + norm, norm_data = None, None + + # trigger preprocessing operations on the pvfile like --mul + pv_preproc_future = None + if needs_pv_calc: + _pv: ImageTuple = load_image(pvfile, blocking=True) + pv_img, pv_data = _pv + + if not empty(pvfile_preproc := getattr(args, "pvfile_preproc", None)): + pv_preproc_future = compute_threads.submit( + preproc_image, pvfile_preproc, pv_data, + ) + + check_shape_affine(seg, pv_img, "segmentation", "pv_guide") + if normfile is not None: + _norm: ImageTuple = load_image(normfile, blocking=True) + norm, norm_data = _norm + check_shape_affine(seg, norm, "segmentation", "norm") + + except (IOError, RuntimeError, FileNotFoundError) as e: + return e.args[0] + + lut: Optional[pd.DataFrame] = None + if lut_file: + try: + lut = read_lut(lut_file) + # manager.lut = lut + except FileNotFoundError: + return ( + f"Could not find the ColorLUT in {lut_file}, make sure the --lut " + f"argument is valid." 
+ ) + except Exception as exception: + return exception.args[0] + + if measure_only: + # in this mode, we do not output a data tabel anyways, so no need to compute + # all these PV values. + labels, exclude_id = np.zeros((0,), dtype=int), [] + else: try: - print(f'Saving {n} to {file}') - from FastSurferCNN.data_loader.data_utils import save_image - _header = seg.header.copy() - _header.set_data_dtype(dtype) - save_image(_header, seg.affine, maps[v], file, dtype) - except Exception: - import traceback - traceback.print_exc() + # construct the list of labels to calculate PV for + labels, exclude_id = infer_labels_excludeid(args, lut, seg_data) + except ValueError as e: + return e.args[0] + + if (_merged_labels := getattr(args, "merged_labels", None)) is None: + _merged_labels: Sequence[Sequence[int]] = () + merged_labels, measure_labels = infer_merged_labels( + manager, + labels, + merged_labels=_merged_labels, + merge_labels_start=10000, + ) + vox_vol = np.prod(seg.header.get_zooms()).item() + # more args to pass to pv_calc + kwargs = { + "vox_vol": vox_vol, + "legacy_freesurfer": legacy_freesurfer, + "threads": compute_threads, + "robust_percentage": getattr(args, "robust", None), + "patch_size": getattr(args, "patch_size", 16), + "merged_labels": merged_labels, + } + # more args to pass to write_segstatsfile + write_kwargs = { + "vox_vol": vox_vol, + "legacy_freesurfer": legacy_freesurfer, + "exclude": exclude_id, + "segfile": segfile, + "normfile": normfile, + "lut": lut_file, + "volume_precision": getattr(args, "volume_precision", "1"), + } + # ------ + # finished manager io here + # ------ + manager.compute_non_derived_pv(compute_threads) + + names = ["nbr", "nbr_means", "seg_means", "mix_coeff", "nbr_mix_coeff"] + save_maps_paths = (getattr(args, n, "") for n in names) + save_maps = any(bool(path) and path != Path() for path in save_maps_paths) + save_maps = save_maps and not measure_only + + if needs_pv_calc: + if pv_preproc_future is not None: + # wait for preprocessing options on pvfile + pv_data = pv_preproc_future.result() + out = pv_calc(seg_data, pv_data, norm_data, labels, return_maps=save_maps, **kwargs) + else: + out = None + if measure_only: + # if we are not computing partial volume effects, do not perform pv_calc + try: + if needs_pv_calc: + # make sure required PV measures get computed + dataframe = table_to_dataframe( + out, + bool(getattr(args, "empty", False)), + must_keep_ids=merged_labels.keys(), + ) + manager.update_pv_from_table(dataframe, measure_labels) + + manager.wait_write_brainvolstats(segstatsfile) + except RuntimeError as e: + return e.args[0] + print(f"Brain volume stats written to {segstatsfile}.") + duration = (perf_counter_ns() - start) / 1e9 + print(f"Calculation took {duration:.2f} seconds using up to {threads} threads.") + return 0 + + _io_futures = [] + if save_maps: + table, maps = out + dtypes = [np.int16] + [np.float32] * 4 + for name, dtype in zip(names, dtypes): + if not bool(file := getattr(args, name, "")) or file == Path(): + # skip "fullview"-files that are not defined + continue + print(f"Saving {name} to {file}...") + from FastSurferCNN.data_loader.data_utils import save_image + + _header = seg.header.copy() + _header.set_data_dtype(dtype) + _io_futures.append( + manager.executor.submit( + save_image, + _header, + seg.affine, + maps[name], + file, + dtype, + ), + ) + print("Done.") else: - table: List[PVStats] = pv_calc(seg_data, norm_data, labels, **kwargs) + table: list[PVStats] = out if lut is not None: - for i in range(len(table)): 
- lut_idx = lut["ID"] == table[i]["SegId"] - if lut_idx.any(): - table[i]["StructName"] = lut[lut_idx]["LabelName"].item() - elif "merged_labels" in kwargs and table[i]["SegId"] in kwargs["merged_labels"].keys(): - # noinspection PyTypeChecker - table[i]["StructName"] = "Merged-Label-" + str(table[i]["SegId"]) - else: - # make the label unknown - table[i]["StructName"] = "Unknown-Label" - lut_idx = {i: lut["ID"] == i for i in exclude_id} - exclude = {i: lut[lut_idx[i]]["LabelName"].item() if lut_idx[i].any() else "" for i in exclude_id} - else: - exclude = {i: "" for i in exclude_id} - dataframe = pd.DataFrame(table, index=np.arange(len(table))) - if not bool(getattr(args, "empty", False)): - dataframe = dataframe[dataframe["NVoxels"] != 0] - dataframe = dataframe.sort_values("SegId") - dataframe.index = np.arange(1, len(dataframe) + 1) - lines = [] - if getattr(args, 'in_dir', None): - lines.append(f'SUBJECTS_DIR {getattr(args, "in_dir")}') - if getattr(args, 'sid', None): - lines.append(f'subjectname {getattr(args, "sid")}') - lines.append("compatibility with freesurfer's mri_segstats: " + - ("legacy" if kwargs["legacy_freesurfer"] else "fixed")) - - write_statsfile(args.segstatsfile, dataframe, - exclude=exclude, vox_vol=kwargs["vox_vol"], segfile=args.segfile, - normfile=args.normfile, lut=getattr(args, "lut", None), extra_header=lines) - print(f"Partial volume stats for {dataframe.shape[0]} labels written to {args.segstatsfile}.") - duration = (time.perf_counter_ns() - start) / 1e9 + update_structnames(table, lut, merged_labels) + + dataframe = table_to_dataframe( + table, + bool(getattr(args, "empty", False)), + must_keep_ids=merged_labels.keys(), + ) + lines = format_parameters(SUBJECT_DIR=subjects_dir, subjectname=subject_id) + + # wait for computation of measures and return an error message if errors occur + errors = list(manager.wait_compute()) + if not empty(errors): + error_messages = ["Some errors occurred during measure computation:"] + error_messages.extend(map(lambda e: f"{type(e).__name__}: {e.args[0]}", errors)) + return "\n - ".join(error_messages) + dataframe = manager.update_pv_from_table(dataframe, measure_labels) + lines.extend(manager.format_measures()) + + write_statsfile( + segstatsfile, + dataframe, + extra_header=lines, + **write_kwargs, + ) + print(f"Partial volume stats for {dataframe.shape[0]} labels written to " + f"{segstatsfile}.") + duration = (perf_counter_ns() - start) / 1e9 print(f"Calculation took {duration:.2f} seconds using up to {threads} threads.") + + for _io_fut in _io_futures: + if (e := _io_fut.exception()) is not None: + logging.getLogger(__name__).exception(e) + return 0 -def write_statsfile(segstatsfile: str, dataframe: pd.DataFrame, vox_vol: float, exclude: Optional[Dict[int, str]] = None, - segfile: str = None, normfile: str = None, lut: str = None, extra_header: Sequence[str] = ()): - """Write a segstatsfile very similar and compatible with mri_segstats output. +def infer_merged_labels( + manager: "Manager", + used_labels: Iterable[int], + merged_labels: Sequence[Sequence[int]] = (), + merge_labels_start: int = 0, +) -> tuple[dict[int, Sequence[int]], dict[int, Sequence[int]]]: + """ + + Parameters + ---------- + manager : Manager + The brainvolstats Manager object to get virtual labels. + used_labels : Iterable[int] + A list of labels at that are already in use. + merged_labels : Sequence[Sequence[int]], default=() + The list of merge labels (first value is SegId, then SegIds it sums across). 
+ merge_labels_start : int, default=0 + Start index to start at for finding multi-class merged label groups. + + Returns + ------- + all_merged_labels : dict[int, Sequence[int]] + The dictionary of all merged labels (via :class:`PVMeasure`s as well as + `merged_labels`). + """ + _merged_labels = {} + if not empty(merged_labels): + _merged_labels = {lab: vals for lab, *vals in merged_labels} + all_labels = list(_merged_labels.keys()) + list(used_labels) + _pv_merged_labels = manager.get_virtual_labels( + i for i in range(merge_labels_start, np.iinfo(int).max) if i not in all_labels + ) + + all_merged_labels = _merged_labels.copy() + all_merged_labels.update(_pv_merged_labels) + return all_merged_labels, _pv_merged_labels + + +def table_to_dataframe( + table: list[PVStats], + report_empty: bool = True, + must_keep_ids: Optional[Container[int]] = None, +) -> pd.DataFrame: + """ + Convert the list of PVStats dictionaries into a dataframe. + + Parameters + ---------- + table : list[PVStats] + List of partial volume stats dictionaries. + report_empty : bool, default=True + Whether empty regions should be part of the dataframe. + must_keep_ids : Container[int], optional + Specifies a list of segids to never remove from the table. + + Returns + ------- + pandas.DataFrame + The DataFrame object of all columns and rows in table. + """ + df = pd.DataFrame(table, index=np.arange(len(table))) + if not report_empty: + df_mask = df["NVoxels"] != 0 + if must_keep_ids and isinstance(must_keep_ids, Container): + df_mask |= df["SegId"].map(lambda x: x in must_keep_ids) + df = df[df_mask] + df = df.sort_values("SegId") + df.index = np.arange(1, len(df) + 1) + return df + + +def update_structnames( + table: list[PVStats], + lut: pd.DataFrame, + merged_labels: Optional[dict[_IntType, Sequence[_IntType]]] = None +) -> None: + """ + Update StructNames from `lut` and `merged_labels` in `table`. Parameters ---------- - segstatsfile : str - path to the output file + table : list[PVStats] + List of partial volume stats dictionaries. + lut : pandas.DataFrame + A pandas DataFrame object containing columns 'ID' and 'LabelName', which serves + as a lookup table for the structure names. + merged_labels : dict[int, Sequence[int]], optional + The dictionary with merged labels. + """ + # table is a list of dicts, so we can add the StructName to the dict + for i in range(len(table)): + lut_idx = lut["ID"] == table[i]["SegId"] + if lut_idx.any(): + # get the label name from the lut, if it is in there + table[i]["StructName"] = lut[lut_idx]["LabelName"].item() + elif merged_labels is not None and table[i]["SegId"] in merged_labels.keys(): + # auto-generate a name for merged labels + table[i]["StructName"] = "Merged-Label-" + str(table[i]["SegId"]) + else: + # make the label unknown + table[i]["StructName"] = "Unknown-Label" + # lut_idx = {i: lut["ID"] == i for i in exclude_id} + # _ids = [(i, lut_idx[i]) for i in exclude_id] + + +def format_parameters(**kwargs) -> list[str]: + """ + Formats each keyword argument passed as a pair of key and value. + + Returns + ------- + list[str] + A list of one string per keyword arg formatted as a string. 
+ """ + return [f"{k} {v}" for k, v in kwargs.items() if v] + + +def write_statsfile( + segstatsfile: Path | str, + dataframe: pd.DataFrame, + vox_vol: float, + exclude: Optional[Sequence[int | str]] = None, + segfile: Optional[Path | str] = None, + normfile: Optional[Path | str] = None, + pvfile: Optional[Path | str] = None, + lut: Optional[Path | str] = None, + report_empty: bool = False, + extra_header: Sequence[str] = (), + norm_name: str = "norm", + norm_unit: str = "MR", + volume_precision: str = "1", + legacy_freesurfer: bool = False, +) -> None: + """ + Write a segstatsfile very similar and compatible with mri_segstats output. + + Parameters + ---------- + segstatsfile : Path, str + Path to the output file. dataframe : pd.DataFrame - data to write into the file + Data to write into the file. vox_vol : float - voxel volume for the header - exclude : Optional[Dict[int, str]] - dictionary of ids and class names that were excluded from the pv analysis (default: None) - segfile : str - path to the segmentation file (default: empty) - normfile : str - path to the bias-field corrected image (default: empty) - lut : str - path to the lookup table to find class names for label ids (default: empty) - extra_header : Sequence[str] - sequence of additional lines to add to the header. The initial # and newline characters will be - added. Should not include newline characters (expect at the end of strings). (default: empty sequence) - + Voxel volume for the header. + exclude : Sequence[Union[int, str]], optional + Sequence of ids and class names that were excluded from the pv analysis + (default: None). + segfile : Path, str, optional + Path to the segmentation file (default: empty). + normfile : Path, str, optional + Path to the bias-field corrected image (default: empty). + pvfile : Path, str, optional + Path to file used to compute the PV effects (default: empty). + lut : Path, str, optional + Path to the lookup table to find class names for label ids (default: empty). + report_empty : bool, default=False + Do not skip non-empty regions in the lut. + extra_header : Sequence[str], default=() + Sequence of additional lines to add to the header. The initial # and newline + characters will be added. Should not include newline characters (expect at the + end of strings). + norm_name : str, default="norm" + Name of the intensity image. + norm_unit : str, default="MR" + Unit of the intensity image. + volume_precision : str, default="1" + Number of digits after the comma for volume. Forced to 1 for legacy_freesurfer. + legacy_freesurfer : bool, default=False + Whether the script ran with the legacy freesurfer option. """ - import sys - import os import datetime - def file_annotation(_fp, name: str, file: Optional[str]) -> None: - if file is not None: - _fp.write(f"# {name} {file}\n") - stat = os.stat(file) - if stat.st_mtime: - mtime = datetime.datetime.fromtimestamp(stat.st_mtime) - _fp.write(f"# {name}Timestamp {mtime:%Y/%m/%d %H:%M:%S}\n") + def _title(file: IO) -> None: + """ + Write the file title to a file. + """ + file.write("# Title Segmentation Statistics\n#\n") - os.makedirs(os.path.dirname(segstatsfile), exist_ok=True) - with open(segstatsfile, "w") as fp: - fp.write("# Title Segmentation Statistics\n#\n" - "# generating_program segstats.py\n" - "# cmdline " + " ".join(sys.argv) + "\n") + def _system_info(file: IO) -> None: + """ + Write the call and system information comments of the header to a file. 
+ """ + import os + import sys + from FastSurferCNN.version import read_and_close_version + file.write( + "# generating_program segstats.py\n" + "# FastSurfer_version " + read_and_close_version() + "\n" + "# cmdline " + " ".join(sys.argv) + "\n" + ) if os.name == 'posix': - fp.write(f"# sysname {os.uname().sysname}\n" - f"# hostname {os.uname().nodename}\n" - f"# machine {os.uname().machine}\n") + file.write( + f"# sysname {os.uname().sysname}\n" + f"# hostname {os.uname().nodename}\n" + f"# machine {os.uname().machine}\n" + ) else: from socket import gethostname - fp.write(f"# platform {sys.platform}\n" - f"# hostname {gethostname()}\n") + file.write( + f"# platform {sys.platform}\n" + f"# hostname {gethostname()}\n" + ) from getpass import getuser + try: - fp.write(f"# user {getuser()}\n") + file.write(f"# user {getuser()}\n") except KeyError: - fp.write(f"# user UNKNOWN\n") - - fp.write(f"# anatomy_type volume\n#\n") + file.write(f"# user UNKNOWN\n") - file_annotation(fp, "SegVolFile", segfile) - file_annotation(fp, "ColorTable", lut) - file_annotation(fp, "PVVolFile", normfile) - if exclude is not None and len(exclude) > 0: - if any(len(e) > 0 for e in exclude.values()): - fp.write(f"# Excluding {', '.join(filter(lambda x: len(x) > 0, exclude.values()))}\n") - fp.write("".join([f"# ExcludeSegId {id}\n" for id in exclude.keys()])) + def _extra_header(file: IO, lines_extra_header: Iterable[str]) -> None: + """ + Write the extra_header (including measures) to a file. + """ warn_msg_sent = False - for i, line in enumerate(extra_header): + for i, line in enumerate(lines_extra_header): if line.endswith("\n"): line = line[:-1] if line.startswith("# "): @@ -399,135 +1218,260 @@ def file_annotation(_fp, name: str, file: Optional[str]) -> None: if "\n" in line: line = line.replace("\n", " ") from warnings import warn - warn_msg_sent or warn(f"extra_header[{i}] includes embedded newline characters. " - "Replacing all newline characters with .") + + warn_msg_sent or warn( + f"extra_header[{i}] includes embedded newline characters. " + "Replacing all newline characters with ." + ) warn_msg_sent = True - fp.write(f"# {line}\n") - fp.write(f"#\n") - if lut is not None: - fp.write("# Only reporting non-empty segmentations\n") - fp.write(f"# VoxelVolume_mm3 {vox_vol}\n") + file.write(f"# {line}\n") + + def _file_annotation(file: IO, name: str, path_to_annotate: Optional[Path]) -> None: + """ + Write the annotation to file/path to a file. + """ + if path_to_annotate is not None: + file.write(f"# {name} {path_to_annotate}\n") + stat = path_to_annotate.stat() + if stat.st_mtime: + mtime = datetime.datetime.fromtimestamp(stat.st_mtime) + file.write(f"# {name}Timestamp {mtime:%Y/%m/%d %H:%M:%S}\n") + + def _extra_parameters( + file: IO, + _voxvol: float, + _exclude: Sequence[int | str], + _report_empty: bool = False, + _lut: Optional[Path] = None, + _leg_freesurfer: bool = False, + ) -> None: + """ + Write the comments of the table header to a file. 
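+        For example: ``# Excluding Ventricles``, ``# ExcludeSegId 4 43`` and
+        ``# VoxelVolume_mm3 1.0`` (the structure names and ids shown here are
+        illustrative).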
+ """ + if _exclude is not None and len(_exclude) > 0: + exclude_str = list(filter(lambda x: isinstance(x, str), _exclude)) + exclude_int = list(filter(lambda x: isinstance(x, int), _exclude)) + if len(exclude_str) > 0: + excl_names = ', '.join(exclude_str) + file.write(f"# Excluding {excl_names}\n") + if len(exclude_int) > 0: + file.write(f"# ExcludeSegId {' '.join(map(str, exclude_int))}\n") + if _lut is not None and not _report_empty: + file.write("# Only reporting non-empty segmentations\n") + file.write("# compatibility with freesurfer's mri_segstats: " + + ("legacy" if _leg_freesurfer else "fixed") + "\n") + file.write(f"# VoxelVolume_mm3 {_voxvol}\n") + + def _is_norm_column(name: str) -> bool: + """Check whether the column `name` is a norm-column.""" + return name in ("Mean", "StdDev", "Min", "Max", "Range") + + def _column_name(name: str) -> str: + """Convert the column name""" + return norm_name + name if _is_norm_column(name) else name + + def _column_unit(name: str) -> str: + if _is_norm_column(name): + return norm_unit + elif name == "Volume_mm3": + return "mm^3" + elif name == "NVoxels": + return "unitless" + return "NA" + + def _column_description(name: str) -> str: + if _is_norm_column(name): + return f"Intensity {_column_name(name)}" + return { + "Index": "Index", "SegId": "Segmentation Id", "NVoxels": "Number of Voxels", + "Volume_mm3": "Volume", "StructName": "Structure Name" + }.get(name, "Unknown Column") + + def _column_format(name: str) -> str: + if _is_norm_column(name): + return "{: >10.4f}" + elif name == "Volume_mm3": + return f"{{: >10.{volume_precision}f}}" + elif name in ("Index", "SegId"): + return "{: >3d}" + elif name == "NVoxels": + return "{: >9d}" + return " {: <30s}" + + def _table_header(file: IO, _dataframe: pd.DataFrame) -> None: + """Write the comments of the table header to a file.""" + columns = [col for col in COLUMNS if col in _dataframe.columns] + for i, col in enumerate(columns): + file.write(f"# TableCol {i + 1: >2d} ColHeader {_column_name(col)}\n" + f"# TableCol {i + 1: >2d} FieldName {_column_description(col)}\n" + f"# TableCol {i + 1: >2d} Units {_column_unit(col)}\n") + file.write(f"# NRows {len(_dataframe): >2d}\n" + f"# NTableCols {len(columns): >2d}\n") + file.write("# ColHeaders " + " ".join(map(_column_name, columns)) + "\n") + + def _table_body(file: IO, _dataframe: pd.DataFrame) -> None: + """Write the volume stats from _dataframe to a file.""" + columns = [col for col in COLUMNS if col in _dataframe.columns] + fmt = " ".join(_column_format(k) for k in columns) + for index, row in _dataframe.iterrows(): + data = [row[k] for k in columns] + file.write(fmt.format(*data) + "\n") + + if not isinstance(segstatsfile, Path): + segstatsfile = Path(segstatsfile) + if normfile is not None and not isinstance(normfile, Path): + normfile = Path(normfile) + if segfile is not None and not isinstance(segfile, Path): + segfile = Path(segfile) + + if exclude is not None and not isinstance(exclude, Sequence): + raise RuntimeError("exclude must be a sequence of ints or None!") + + segstatsfile.parent.mkdir(exist_ok=True) + with open(segstatsfile, "w") as fp: + _title(fp) + _system_info(fp) + fp.write(f"# anatomy_type volume\n#\n") + _extra_header(fp, extra_header) + + _file_annotation(fp, "SegVolFile", segfile) + # Annot subject hemi annot + # Label subject hemi LabelFile + _file_annotation(fp, "ColorTable", lut) + # ColorTableFromGCA + # GCATimeStamp + # masking applies to PV, not to the Measure Mask + # MaskVolFile MaskThresh MaskSign MaskFrame 
MaskInvert + _file_annotation(fp, "InVolFile", normfile) + _file_annotation(fp, "PVVolFile", pvfile) + _extra_parameters(fp, vox_vol, exclude, report_empty, lut, legacy_freesurfer) # add the Index column, if it is not in dataframe if "Index" not in dataframe.columns: index_df = pd.DataFrame.from_dict({"Index": dataframe.index}) index_df.index = dataframe.index dataframe = index_df.join(dataframe) + _table_header(fp, dataframe) + _table_body(fp, dataframe) - for i, col in enumerate(dataframe.columns): - for v, name in zip((col, FIELDS.get(col, "Unknown Column"), UNITS.get(col, "NA")), - ("ColHeader", "FieldName", "Units ")): - fp.write(f"# TableCol {i+1: 2d} {name} {v}\n") - fp.write(f"# NRows {len(dataframe)}\n" - f"# NTableCols {len(dataframe.columns)}\n") - fp.write("# ColHeaders " + " ".join(dataframe.columns) + "\n") - max_index = int(np.ceil(np.log10(np.max(dataframe.index)))) - def fmt_field(code: str, data) -> str: - is_s, is_f, is_d = code[-1] == "s", code[-1] == "f", code[-1] == "d" - filler = "<" if is_s else " >" - prec = int(data.dropna().map(len).max() if is_s else np.ceil(np.log10(data.max()))) - if is_f: - prec += int(code[-2]) + 1 - return filler + str(prec) + code - - fmts = ("{:" + fmt_field(FORMATS[k], dataframe[k]) + "}" for k in dataframe.columns) - fmt = " ".join(fmts) + "\n" - for index, row in dataframe.iterrows(): - data = [row[k] for k in dataframe.columns] - fp.write(fmt.format(*data)) - - -# Label mapping functions (to aparc (eval) and to label (train)) -def read_classes_from_lut(lut_file): - """Modify from datautils to allow support for FreeSurfer-distributed ColorLUTs. - - Read in **FreeSurfer-like** LUT table +def preproc_image( + ops: Sequence[str], + data: npt.NDArray[_NumberType] +) -> npt.NDArray[_NumberType]: + """ + Apply preprocessing operations to data. Performs, --mul, --abs, --sqr, --sqrt + operations in that order. Parameters ---------- - lut_file : - path and name of FreeSurfer-style LUT file with classes of interest - Example entry: - ID LabelName R G B A - 0 Unknown 0 0 0 0 - 1 Left-Cerebral-Exterior 70 130 180 0 + ops : Sequence[str] + Sequence of operations to perform from 'mul=', 'div=', 'sqr', + 'abs', and 'sqrt'. + data : np.ndarray + Data to perform operations on. + Returns ------- - DataFrame with ids present, name of ids, color for plotting - + np.ndarray + Data after ops are performed on it. """ - if lut_file.endswith(".tsv"): - return pd.read_csv(lut_file, sep="\t") - - # Read in file - names = { - "ID": "int", - "LabelName": "str", - "Red": "int", - "Green": "int", - "Blue": "int", - "Alpha": "int" - } - return pd.read_csv(lut_file, delim_whitespace=True, index_col=False, skip_blank_lines=True, - comment="#", header=None, names=names.keys(), dtype=names) + mul_ops = np.asarray([o.startswith("mul=") or o.startswith("div=") for o in ops]) + if np.any(mul_ops): + mul_op = ops[mul_ops.nonzero()[0][-1].item()] + factor = float(mul_op[4:]) + data = (np.multiply if mul_op.startswith("mul=") else np.divide)(data, factor) + if "abs" in ops: + data = np.abs(data) + if "sqr" in ops: + data = data * data + if "sqrt" in ops: + data = np.sqrt(data) + return data + + +def seg_borders( + _array: _ArrayType, + label: np.integer | bool, + out: Optional[npt.NDArray[bool]] = None, + cmp_dtype: npt.DTypeLike = "int8", +) -> npt.NDArray[bool]: + """ + Handle to fast 6-connected border computation. + Parameters + ---------- + _array : numpy.ndarray + Image to compute borders from, typically either a label image or a binary mask. 
+ label : int, bool + Which classes to consider for border computation (True/False for binary mask). + out : nt.NDArray[bool], optional + The array for inplace computation. + cmp_dtype : npt.DTypeLike, default=int8 + The data type to use for border laplace computation. -def seg_borders(_array: _ArrayType, label: Union[np.integer, bool], - out: Optional[_ArrayType] = None, cmp_dtype: npt.DTypeLike = "int8") -> _ArrayType: - """Handle to fast 6-connected border computation.""" + Returns + ------- + npt.NDArray[bool] + A binary mask with border voxels as True. + """ # binarize - bin_array = _array if np.issubdtype(_array.dtype, bool) else _array == label + bin_array: npt.NDArray[bool] + bin_array = _array if np.issubdtype(_array.dtype, bool) else np.equal(_array, label) # scipy laplace is about 20% faster than skimage laplace on cpu from scipy.ndimage import laplace - def _laplace(data): - """[MISSING]. - - Parameters - ---------- - data : - [MISSING] - - Returns - ------- - bool - [MISSING] - - """ - return laplace(data.astype(cmp_dtype)) != np.asarray(0., dtype=cmp_dtype) - # laplace - if out is not None: - out[:] = _laplace(bin_array) - return out + if np.issubdtype(cmp_dtype, bool): + laplace_data = laplace(bin_array).astype(bool) + if out is not None: + out[:] = laplace_data + laplace_data = out + return laplace_data else: - return _laplace(bin_array) - + zeros = np.asarray(0., dtype=cmp_dtype) + # laplace + laplace_data = laplace(bin_array.astype(cmp_dtype)) + return np.not_equal(laplace_data, zeros, out=out) + + +def borders( + _array: _ArrayType, + labels: Iterable[np.integer] | bool, + max_label: Optional[np.integer] = None, + six_connected: bool = True, + out: Optional[npt.NDArray[bool]] = None, +) -> npt.NDArray[bool]: + """ + Handle to fast border computation. -def borders(_array: _ArrayType, labels: Union[Iterable[np.integer], bool], max_label: Optional[np.integer] = None, - six_connected: bool = True, out: Optional[_ArrayType] = None) -> _ArrayType: - """Handle to fast border computation. + This is an efficient implementation, for multiple/many classes between which borders + should be computed. Parameters ---------- _array : _ArrayType - [MISSING] - labels : Union[Iterable[np.integer], bool] - [MISSING] - max_label : Optional[np.integer], Optional - [MISSING] - six_connected : bool - [MISSING] - out : Optional[_ArrayType] - [MISSING] + Input labeled image or binary image. + labels : Iterable[int], bool + List of labels for which borders will be computed. + If labels is True, _array is treated as a binary mask. + max_label : int, optional + The maximum label ot consider. If None, the maximum label in the array is used. + six_connected : bool, default=True + If True, 6-connected borders (must share a face) are computed, + otherwise 26-connected borders (must share a vertex) are computed. + out : npt.NDArray[bool], optional + Output array to store the computed borders. Returns ------- - _ArrayType - [MISSING] + npt.NDArray[bool] + Binary mask of border voxels. + Raises + ------ + ValueError + If labels does not fit to _array (binary mask and integer and vice-versa). 
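+
+    Examples
+    --------
+    Basic usage, where ``seg_array`` is an integer label image and ``mask_array``
+    a boolean mask (both names are illustrative)::
+
+        # 6-connected borders between labels 2 and 41 and any other label
+        border_mask = borders(seg_array, [2, 41])
+        # borders of a binary mask
+        border_mask = borders(mask_array, True)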
""" dim = _array.ndim - array_alloc = partial(np.full, dtype=_array.dtype) _shape_plus2 = [s + 2 for s in _array.shape] if labels is True: # already binarized @@ -536,14 +1480,16 @@ def borders(_array: _ArrayType, labels: Union[Iterable[np.integer], bool], max_l cmp = np.logical_xor else: if np.issubdtype(_array, bool): - raise ValueError("If labels is a list/iterable, the array should not be boolean.") + raise ValueError( + "If labels is a list/iterable, the array should not be boolean." + ) def cmp(a, b): return a == b if max_label is None: max_label = _array.max().item() - lookup = array_alloc((max_label + 1,), fill_value=0) + lookup = np.zeros((max_label + 1,), dtype=_array.dtype) # filter labels from labels that are bigger than max_label labels = list(filter(lambda x: x <= max_label, labels)) if 0 not in labels: @@ -551,335 +1497,555 @@ def cmp(a, b): lookup[labels] = np.arange(len(labels), dtype=lookup.dtype) _array = lookup[_array] logical_or = np.logical_or - __array = array_alloc(_shape_plus2, fill_value=0) - __array[(slice(1, -1),) * dim] = _array + # pad array by 1 voxel of zeros all around + padded = np.pad(_array, 1) - mid = (slice(1, -1),) * dim if six_connected: - def ii(axis: int, off: int, is_mid: bool) -> Tuple[slice, ...]: - other_slices = mid[:1] if is_mid else (slice(None),) - return other_slices * axis + (slice(off, -1 if off == 0 else None),) + other_slices * ( - dim - axis - 1) - - nbr_same = [cmp(__array[ii(i, 0, True)], __array[ii(i, 1, True)]) for i in range(dim)] - nbr_same = [logical_or(_ns[ii(i, 0, False)], _ns[ii(i, 1, False)]) for i, _ns in enumerate(nbr_same)] + def indexer(axis: int, is_mid: bool) -> tuple[SlicingTuple, SlicingTuple]: + full_slice = (slice(1, -1),) if is_mid else (slice(None),) + more_axes = dim - axis - 1 + return ((full_slice * axis + (slice(0, -1),) + full_slice * more_axes), + (full_slice * axis + (slice(1, None),) + full_slice * more_axes)) + + # compare the [padded] image/array in all directions, x, y, z... + # ([0], 0, 2, 2, 2, [0]) ==> (False, True, False, False, True) for each dim + # is_mid=True: drops padded values in unaffected axes + indexes = (indexer(i, is_mid=True) for i in range(dim)) + nbr_same = [cmp(padded[i], padded[j]) for i, j in indexes] + # merge neighbors so each border is 2 thick (left and right of change) + # (False, True, False, False, True) ==> + # ((False, True), (True, False), (False, False), (False, True)) for each dim + # is_mid=False: padded values already dropped + indexes = (indexer(i, is_mid=False) for i in range(dim)) + nbr_same = [(nbr_[i], nbr_[j]) for (i, j), nbr_ in zip(indexes, nbr_same)] + from itertools import chain + nbr_same = list(chain.from_iterable(nbr_same)) else: - def ii(off: Iterable[int]) -> Tuple[slice, ...]: - return tuple(slice(o, None if o == 2 else o - 3) for o in off) - - nbr_same = [cmp(__array[mid], __array[ii(i - 1)]) for i in np.ndindex((3,) * dim) if np.all(i != 1)] + # all indexes of the neighbors: ((0, 0, 0), (0, 0, 1) ... (2, 2, 2)) + ndindexes = tuple(np.ndindex((3,) * dim)) + + def nbr_i(__array: _ArrayType, neighbor_index: int) -> _ArrayType: + """Assuming a padded array __array, returns just the neighbor_index-th + neighbors throughout the array.""" + # sample from 1d neighbor index to ndindex + nbr_ndid = ndindexes[neighbor_index] # e.g. 
(1, 0, 2) + slice_ndindex = tuple(slice(o, None if o == 2 else o - 3) for o in nbr_ndid) + return __array[slice_ndindex] + + # compare the array (center point) with all neighboring voxels + # neighbor samples the neighboring voxel in the padded array + nbr_same = [cmp(_array, nbr_i(padded, i)) for i in range(3**dim) if i != 2**dim] + + # reduce the per-direction/per-neighbor binary arrays into one array return np.logical_or.reduce(nbr_same, out=out) -def unsqueeze(matrix, axis: Union[int, Sequence[int]] = -1): - """Unsqueeze the matrix. - - Allows insertions of axis into the data/tensor, see numpy.expand_dims. This expands the torch.unsqueeze - syntax to allow unsqueezing multiple axis at the same time. - - Parameters - ---------- - matrix : - Matrix to unsqueeze - axis : Union[int, Sequence[int]] - Axis for unsqueezing - +def pad_slicer( + slicer: Sequence[slice], + whalf: int, + img_size: np.ndarray | Sequence[float], +) -> tuple[SlicingTuple, SlicingTuple]: """ - if isinstance(matrix, np.ndarray): - return np.expand_dims(matrix, axis=axis) - - -def grow_patch(patch: Sequence[slice], whalf: int, img_size: Union[np.ndarray, Sequence[float]]) -> Tuple[ - Tuple[slice, ...], Tuple[slice, ...]]: - """Create two slicing tuples for indexing ndarrays/tensors that 'grow' and re-'ungrow' the patch `patch` by `whalf` (also considering the image shape). + Create two slicing tuples for indexing ndarrays/tensors that 'grow' and + re-'ungrow' the patch `patch` by `whalf` (also considering the image shape). Parameters ---------- - patch : Sequence[slice] - [MISSING] + slicer : Sequence[slice] + Input slicing tuple. whalf : int - [MISSING] - img_size : Union[np.ndarray, Sequence[float]] - [MISSING] + How much to pad/grow the slicing tuple all around. + img_size : np.ndarray, Sequence[float] + Shape of the image. Returns ------- - Tuple[Tuple[slice, ...], Tuple[slice, ...]] - [MISSING] + SlicingTuple + Tuple of slice-objects to go from image to padded patch. + SlicingTuple + Tuple of slice-objects to go from padded patch to patch. """ # patch start/stop - _patch = np.asarray([(s.start, s.stop) for s in patch]) + _patch = np.asarray([(s.start, s.stop) for s in slicer]) start, stop = _patch.T # grown patch start/stop _start, _stop = np.maximum(0, start - whalf), np.minimum(stop + whalf, img_size) + def _slice(start_end: npt.NDArray[int]) -> slice: + _start, _end = start_end + return slice(_start.item(), None if _end.item() == 0 else _end.item()) # make grown patch and grown patch to patch - grown_patch = tuple(slice(s.item(), e.item()) for s, e in zip(_start, _stop)) - ungrow_patch = tuple( - slice(s.item(), None if e.item() == 0 else e.item()) for s, e in zip(start - _start, stop - _stop)) - return grown_patch, ungrow_patch + padded_slicer = tuple(slice(s.item(), e.item()) for s, e in zip(_start, _stop)) + unpadded_slicer = tuple(map(_slice, zip(start - _start, stop - _stop))) + return padded_slicer, unpadded_slicer -def uniform_filter(arr: _ArrayType, filter_size: int, fillval: float, - patch: Optional[Tuple[slice, ...]] = None, out: Optional[_ArrayType] = None) -> _ArrayType: - """Apply a uniform filter (with kernel size `filter_size`) to `input`. +def uniform_filter( + data: _ArrayType, + filter_size: int, + fillval: float = 0., + slicer_patch: Optional[SlicingTuple] = None, +) -> _ArrayType: + """ + Apply a uniform filter (with kernel size `filter_size`) to `input`. The uniform filter is normalized (weights add to one). 
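+    For example, with ``filter_size=3`` every output voxel is the average of its
+    3x3x3 neighborhood, where positions outside the image contribute ``fillval``.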
Parameters ---------- - arr : _ArrayType - [MISSING] + data : _ArrayType + Data to perform uniform filter on. filter_size : int - [MISSING] - fillval : float - [MISSING] - patch : Optional[Tuple[slice, ...]] - [MISSING] - out : Optional[_ArrayType] - [MISSING] + Size of the filter. + fillval : float, default=0 + Value to fill around the image. + slicer_patch : SlicingTuple, optional + Sub_region of data to crop to (e.g. to undo the padding (default: full image). Returns ------- _ArrayType - [MISSING] + The filtered data. """ - _patch = (slice(None),) if patch is None else patch - arr = arr.astype(float) + _patch = (slice(None),) if slicer_patch is None else slicer_patch + data = data.astype(float) from scipy.ndimage import uniform_filter def _uniform_filter(_arr, out=None): - return uniform_filter(_arr, size=filter_size, mode='constant', cval=fillval, output=out)[_patch] + uni_filt = uniform_filter( + _arr, + size=filter_size, + mode="constant", + cval=fillval, + output=out, + ) + return uni_filt[_patch] - if out is not None: - _uniform_filter(arr, out) - return out - return _uniform_filter(arr) + return _uniform_filter(data) @overload -def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntType], patch_size: int = 32, - vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, - merged_labels: Optional[VirtualLabel] = None, threads: int = -1, return_maps: False = False, - legacy_freesurfer: bool = False) -> List[PVStats]: - """[MISSING].""" +def pv_calc( + seg: npt.NDArray[_IntType], + pv_guide: np.ndarray, + norm: np.ndarray, + labels: npt.ArrayLike, + patch_size: int = 32, + vox_vol: float = 1.0, + eps: float = 1e-6, + robust_percentage: Optional[float] = None, + merged_labels: Optional[VirtualLabel] = None, + threads: int | Executor = -1, + return_maps: False = False, + legacy_freesurfer: bool = False, +) -> list[PVStats]: ... @overload -def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntType], - patch_size: int = 32, vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, - merged_labels: Optional[VirtualLabel] = None, threads: int = -1, return_maps: True = True, - legacy_freesurfer: bool = False) \ - -> Tuple[List[PVStats], Dict[str, Dict[int, np.ndarray]]]: - """[MISSING].""" +def pv_calc( + seg: npt.NDArray[_IntType], + pv_guide: np.ndarray, + norm: np.ndarray, + labels: npt.ArrayLike, + patch_size: int = 32, + vox_vol: float = 1.0, + eps: float = 1e-6, + robust_percentage: Optional[float] = None, + merged_labels: Optional[VirtualLabel] = None, + threads: int | Executor = -1, + return_maps: True = True, + legacy_freesurfer: bool = False, +) -> tuple[list[PVStats], dict[str, dict[int, np.ndarray]]]: ... -def pv_calc(seg: npt.NDArray[_IntType], norm: np.ndarray, labels: Sequence[_IntType], - patch_size: int = 32, vox_vol: float = 1.0, eps: float = 1e-6, robust_percentage: Optional[float] = None, - merged_labels: Optional[VirtualLabel] = None, threads: int = -1, return_maps: bool = False, - legacy_freesurfer: bool = False) \ - -> Union[List[PVStats], Tuple[List[PVStats], Dict[str, np.ndarray]]]: - """Compute volume effects. 
+def pv_calc( + seg: npt.NDArray[_IntType], + pv_guide: np.ndarray, + norm: Optional[np.ndarray], + labels: npt.ArrayLike, + patch_size: int = 32, + vox_vol: float = 1.0, + eps: float = 1e-6, + robust_percentage: float | None = None, + merged_labels: VirtualLabel | None = None, + threads: int | Executor = -1, + return_maps: bool = False, + legacy_freesurfer: bool = False, +) -> list[PVStats] | tuple[list[PVStats], dict[str, np.ndarray]]: + """ + Compute volume effects. Parameters ---------- - seg : npt.NDArray[_IntType] - Segmentation array with segmentation labels + seg : np.ndarray + Segmentation array with segmentation labels. + pv_guide : np.ndarray + Image to use to calculate partial volume effects from. norm : np.ndarray - bias - labels : Sequence[_IntType] - Which labels are of interest - patch_size : int - Size of patches (Default value = 32) - vox_vol : float - volume per voxel (Default value = 1.0) - eps : float - threshold for computation of equality (Default value = 1e-6) - robust_percentage : Optional[float] - fraction for robust calculation of statistics (Default value = None) - merged_labels : Optional[VirtualLabel] - defines labels to compute statistics for that are (Default value = None) - threads : int - Number of parallel threads to use in calculation (Default value = -1) - return_maps : bool - returns a dictionary containing the computed maps (Default value = False) - legacy_freesurfer : bool - whether to use a freesurfer legacy compatibility mode to exactly replicate freesurfer (Default value = False) + Intensity image to use to calculate image statistics from. + labels : array_like + Which labels are of interest. + patch_size : int, default=32 + Size of patches. + vox_vol : float, default=1.0 + Volume per voxel. + eps : float, default=1e-6 + Threshold for computation of equality. + robust_percentage : float, optional + Fraction for robust calculation of statistics. + merged_labels : VirtualLabel, optional + Defines labels to compute statistics for that are. + threads : int, concurrent.futures.Executor, default=-1 + Number of parallel threads to use in calculation, alternatively an executor + object. + return_maps : bool, default=False + Returns a dictionary containing the computed maps. + legacy_freesurfer : bool, default=False + Whether to use a freesurfer legacy compatibility mode to exactly replicate + freesurfer. Returns ------- - Union[List[PVStats],Tuple[List[PVStats],Dict[str,np.ndarray]]] - Table (list of dicts) with keys SegId, NVoxels, Volume_mm3, StructName, normMean, normStdDev, - normMin, normMax, and normRange. (Note: StructName is unfilled) - if return_maps: a dictionary with the 5 meta-information pv-maps: - nbr: An image of alternative labels that were considered instead of the voxel's label - nbrmean: The local mean intensity of the label nbr at the specific voxel - segmean: The local mean intensity of the primary label at the specific voxel - pv: The partial volume of the primary label at the location - ipv: The partial volume of the alternative (nbr) label at the location - ipv: The partial volume of the alternative (nbr) label at the location - + pv_stats : list[PVStats] + Table (list of dicts) with keys SegId, NVoxels, Volume_mm3, Mean, StdDev, Min, + Max, and Range. + maps : dict[str, np.ndarray], optional + Only returned, if return_maps is True: + A dictionary with the 5 meta-information pv-maps: + nbr: The alternative labels that were considered instead of the voxel's label. 
+ nbr_means: The local mean intensity of the label nbr at the specific voxel. + seg_means: The local mean intensity of the primary label at the specific voxel. + mixing_coeff: The partial volume of the primary label at the location. + nbr_mixing_coeff: The partial volume of the alternative (nbr) label. """ - if not isinstance(seg, np.ndarray) and np.issubdtype(seg.dtype, np.integer): - raise TypeError("The seg object is not a numpy.ndarray of int type.") - if not isinstance(norm, np.ndarray) and np.issubdtype(seg.dtype, np.numeric): - raise TypeError("The norm object is not a numpy.ndarray of numeric type.") - if not isinstance(labels, Sequence) and all(isinstance(lab, int) for lab in labels): - raise TypeError("The labels list is not a sequence of ints.") + from math import ceil - if seg.shape != norm.shape: - raise RuntimeError(f"The shape of the segmentation and the norm must be identical, but shapes are {seg.shape} " - f"and {norm.shape}!") - - mins, maxes, voxel_counts, __voxel_counts, sums, sums_2, volumes = [{} for _ in range(7)] - loc_border = {} - - if merged_labels is not None: + input_checker = { + "seg": (seg, np.integer), + "pv_guide": (pv_guide, np.number), + "norm": (norm, np.number), + } + for name, (img, _type) in input_checker.items(): + if (img is not None and + not (isinstance(img, np.ndarray) and np.issubdtype(img.dtype, _type))): + raise TypeError(f"The {name} object is not a numpy.ndarray of {_type}.") + _labels = np.asarray(labels) + if not isinstance(labels, Sequence): + labels = _labels.tolist() + if not np.issubdtype(_labels.dtype, np.integer): + raise TypeError("The labels list is not an arraylike of ints.") + + if seg.shape != pv_guide.shape: + raise RuntimeError( + f"The shapes of the segmentation and the pv_guide must be identical, but " + f"shapes are {seg.shape} and {pv_guide.shape}!" + ) + + has_norm = isinstance(norm, np.ndarray) + if has_norm and seg.shape != norm.shape: + raise RuntimeError( + f"The shape of the segmentation and the norm must be identical, but shapes " + f"are {seg.shape} and {norm.shape}!" + ) + + mins, maxes, voxel_counts, robust_voxel_counts = [{} for _ in range(4)] + borders, sums, sums_2, volumes = [{} for _ in range(4)] + + if isinstance(merged_labels, dict) and len(merged_labels) > 0: + _more_labels = list(merged_labels.values()) all_labels = set(labels) - all_labels = all_labels | reduce(lambda i, j: i | j, (set(s) for s in merged_labels.values())) + all_labels |= reduce(set.union, _more_labels[1:], set(_more_labels[0])) else: all_labels = labels # initialize global_crop with the full image - global_crop: Tuple[slice, ...] 
= tuple(slice(0, _shape) for _shape in seg.shape) + global_crop: SlicingTuple = tuple(slice(0, _shape) for _shape in seg.shape) # ignore all regions of the image that are background only if 0 not in all_labels: # crop global_crop to the data (plus one extra voxel) - any_in_global, global_crop = crop_patch_to_mask(seg != 0, sub_patch=global_crop) + not_background = cast(npt.NDArray[bool], seg != 0) + any_in_global, global_crop = crop_patch_to_mask( + not_background, + sub_patch=global_crop + ) # grow global_crop by one, so all border voxels are included - global_crop = grow_patch(global_crop, 1, seg.shape)[0] + global_crop = pad_slicer(global_crop, 1, seg.shape)[0] if not any_in_global: raise RuntimeError("Segmentation map only consists of background") - global_stats_filled = partial(global_stats, - norm=norm[global_crop], seg=seg[global_crop], - robust_percentage=robust_percentage) - if threads < 0: - threads = get_num_threads() - elif threads == 0: - raise ValueError("Zero is not a valid number of threads.") - map_kwargs = {"chunksize": np.ceil(len(labels) / threads)} + global_stats_filled = partial( + global_stats, + norm=norm[global_crop] if has_norm else None, + seg=seg[global_crop], + robust_percentage=robust_percentage, + ) - from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor(threads) as pool: + if threads == 0: + raise ValueError("Zero is not a valid number of threads.") + elif isinstance(threads, int) and threads > 0: + nthreads = threads + elif isinstance(threads, (Executor, int)): + nthreads: int = get_num_threads() + else: + raise TypeError("threads must be int or concurrent.futures.Executor object.") + executor = ThreadPoolExecutor(nthreads) if isinstance(threads, int) else threads + map_kwargs = {"chunksize": 1 if nthreads < 0 else ceil(len(labels) / nthreads)} - global_stats_future = pool.map(global_stats_filled, all_labels, **map_kwargs) + global_stats_future = executor.map(global_stats_filled, all_labels, **map_kwargs) - if return_maps: - _ndarray_alloc = np.full - full_nbr_label = _ndarray_alloc(seg.shape, fill_value=0, dtype=seg.dtype) - full_nbr_mean = _ndarray_alloc(norm.shape, fill_value=0, dtype=np.dtype(float)) - full_seg_mean = _ndarray_alloc(norm.shape, fill_value=0, dtype=np.dtype(float)) - full_pv = _ndarray_alloc(norm.shape, fill_value=1, dtype=np.dtype(float)) - full_ipv = _ndarray_alloc(norm.shape, fill_value=0, dtype=np.dtype(float)) - else: - full_nbr_label, full_seg_mean, full_nbr_mean, full_pv, full_ipv = [None] * 5 - - for lab, *data in global_stats_future: - if data[0] != 0: - voxel_counts[lab], __voxel_counts[lab] = data[:2] - mins[lab], maxes[lab], sums[lab], sums_2[lab] = data[2:-2] - volumes[lab], loc_border[lab] = data[-2] * vox_vol, data[-1] - - # un_global_crop border here - _border = np.any(list(loc_border.values()), axis=0) - border = np.pad(_border, tuple((slc.start, shp - slc.stop) for slc, shp in zip(global_crop, seg.shape))) - if not np.array_equal(border.shape, seg.shape): - raise RuntimeError("border and seg_array do not have same shape.") - - # iterate through patches of the image - patch_iters = [range(slice_.start, slice_.stop, patch_size) for slice_ in global_crop] # for 3D - - map_kwargs["chunksize"] = int(np.ceil(len(voxel_counts) / get_num_threads() / 4)) # 4 chunks per core - _patches = pool.map(partial(patch_filter, mask=border, global_crop=global_crop, patch_size=patch_size), - product(*patch_iters), **map_kwargs) - patches = (patch for has_pv_vox, patch in _patches if has_pv_vox) - - for vols in 
pool.map(partial(pv_calc_patch, global_crop=global_crop, loc_border=loc_border, border=border, - seg=seg, norm=norm, full_nbr_label=full_nbr_label, full_seg_mean=full_seg_mean, - full_pv=full_pv, full_ipv=full_ipv, full_nbr_mean=full_nbr_mean, eps=eps, - legacy_freesurfer=legacy_freesurfer), - patches, **map_kwargs): - for lab in volumes.keys(): - volumes[lab] += vols.get(lab, 0.) * vox_vol - - means = {lab: s / __voxel_counts[lab] for lab, s in sums.items() if __voxel_counts.get(lab, 0) > eps} - # *std = sqrt((sum * (*mean) - 2 * (*mean) * sum + sum2) / (nvoxels - 1)); - stds = {lab: np.sqrt((sums_2[lab] - means[lab] * sums[lab]) / (nvox - 1)) for lab, nvox in - __voxel_counts.items() if nvox > 1} - # ColHeaders Index SegId NVoxels Volume_mm3 StructName normMean normStdDev normMin normMax normRange - table = [{"SegId": lab, "NVoxels": voxel_counts.get(lab, 0), "Volume_mm3": volumes.get(lab, 0.), - "StructName": "", "normMean": means.get(lab, 0.), "normStdDev": stds.get(lab, 0.), - "normMin": mins.get(lab, 0.), "normMax": maxes.get(lab, 0.), - "normRange": maxes.get(lab, 0.) - mins.get(lab, 0.)} for lab in labels] + if return_maps: + from concurrent.futures import ProcessPoolExecutor + if isinstance(executor, ProcessPoolExecutor): + raise NotImplementedError( + "The ProcessPoolExecutor is not compatible with return_maps=True!" + ) + full_nbr_label = np.zeros(seg.shape, dtype=seg.dtype) + full_nbr_mean = np.zeros(pv_guide.shape, dtype=float) + full_seg_mean = np.zeros(pv_guide.shape, dtype=float) + full_pv = np.ones(pv_guide.shape, dtype=float) + full_ipv = np.zeros(pv_guide.shape, dtype=float) + else: + full_nbr_label, full_seg_mean, full_nbr_mean, full_pv, full_ipv = [None] * 5 + + for lab, data in global_stats_future: + if data[0] != 0: + voxel_counts[lab], robust_voxel_counts[lab] = data[:2] + mins[lab], maxes[lab], sums[lab], sums_2[lab] = data[2:-2] + volumes[lab], borders[lab] = data[-2] * vox_vol, data[-1] + + # un_global_crop border here + any_border = np.any(list(borders.values()), axis=0) + pad_width = np.asarray( + [(slc.start, shp - slc.stop) for slc, shp in zip(global_crop, seg.shape)], + dtype=int, + ) + any_border = np.pad(any_border, pad_width) + if not np.array_equal(any_border.shape, seg.shape): + raise RuntimeError("border and seg_array do not have same shape.") + + # iterate through patches of the image + patch_iters = [range(slc.start, slc.stop, patch_size) for slc in global_crop] + # 4 chunks per core + num_valid_labels = len(voxel_counts) + map_kwargs["chunksize"] = np.ceil(num_valid_labels / nthreads / 4).item() + patch_filter_func = partial(patch_filter, mask=any_border, + global_crop=global_crop, patch_size=patch_size) + _patches = executor.map(patch_filter_func, product(*patch_iters), **map_kwargs) + patches = (patch for has_pv_vox, patch in _patches if has_pv_vox) + + patchwise_pv_calc_func = partial( + pv_calc_patch, + global_crop=global_crop, + borders=borders, + border=any_border, + seg=seg, + pv_guide=pv_guide, + full_nbr_label=full_nbr_label, + full_seg_mean=full_seg_mean, + full_pv=full_pv, + full_ipv=full_ipv, + full_nbr_mean=full_nbr_mean, + eps=eps, + legacy_freesurfer=legacy_freesurfer, + ) + for vols in executor.map(patchwise_pv_calc_func, patches, **map_kwargs): + for lab in volumes.keys(): + volumes[lab] += vols.get(lab, 0.0) * vox_vol + + # ColHeaders: Index SegId NVoxels Volume_mm3 StructName Mean StdDev Min Max Range + def prep_dict(lab: int): + nvox = voxel_counts.get(lab, 0) + vol = volumes.get(lab, 0.) 
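+        # intensity stats (Mean, StdDev, Min, Max, Range) are added to these rows
+        # further below, if a norm image was provided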
+ return {"SegId": lab, "NVoxels": nvox, "Volume_mm3": vol} + + table = list(map(prep_dict, labels)) + if has_norm: + robust_vc_it = robust_voxel_counts.items() + means = {lab: sums.get(lab, 0.) / cnt for lab, cnt in robust_vc_it if cnt > eps} + + def get_std(lab: _IntType, nvox: int) -> float: + # *std = sqrt((sum * (*mean) - 2 * (*mean) * sum + sum2) / (nvoxels - 1)); + return np.sqrt((sums_2[lab] - means[lab] * sums[lab]) / (nvox - 1)) + + stds = {lab: get_std(lab, nvox) for lab, nvox in robust_vc_it if nvox > eps} + + for lab, this in zip(labels, table): + this.update( + Mean=means.get(lab, 0.0), + StdDev=stds.get(lab, 0.0), + Min=mins.get(lab, 0.0), + Max=maxes.get(lab, 0.0), + Range=maxes.get(lab, 0.0) - mins.get(lab, 0.0), + ) if merged_labels is not None: - def agg(f: Callable[..., np.ndarray], source: Dict[int, _NumberType], merge_labels: Iterable[int]) -> _NumberType: - return f([source.get(l, 0) for l in merge_labels if __voxel_counts.get(l) is not None]).item() - - for lab, merge in merged_labels.items(): - if all(__voxel_counts.get(l) is None for l in merge): - logging.getLogger(__name__).warning(f"None of the labels {merge} for merged label {lab} exist in the " - f"segmentation.") - continue - - nvoxels, _min, _max = agg(np.sum, voxel_counts, merge), agg(np.min, mins, merge), agg(np.max, maxes, merge) - _sums = [(l, sums.get(l, 0)) for l in merge] - _std_tmp = np.sum([s * s / __voxel_counts.get(l, 0) for l, s in _sums if __voxel_counts.get(l, 0) > 0]) - _std = np.sqrt((agg(np.sum, sums_2, merge) - _std_tmp) / (nvoxels - 1)).item() - merge_row = {"SegId": lab, "NVoxels": nvoxels, "Volume_mm3": agg(np.sum, volumes, merge), - "StructName": "", "normMean": agg(np.sum, sums, merge) / nvoxels, "normStdDev": _std, - "normMin": _min, "normMax": _max, "normRange": _max - _min} - table.append(merge_row) + labs_vol_args = (merged_labels, voxel_counts, robust_voxel_counts, volumes) + intensity_args = (mins, maxes, sums, sums_2) if has_norm else () + table.extend(calculate_merged_labels(*labs_vol_args, *intensity_args, eps=eps)) if return_maps: - return table, {"nbr": full_nbr_label, "segmean": full_seg_mean, "nbrmean": full_nbr_mean, "pv": full_pv, - "ipv": full_ipv} + maps = { + "nbr": full_nbr_label, + "seg_means": full_seg_mean, + "nbr_means": full_nbr_mean, + "mixing_coeff": full_pv, + "nbr_mixing_coeff": full_ipv, + } + return table, maps return table -def global_stats(lab: _IntType, norm: npt.NDArray[_NumberType], seg: npt.NDArray[_IntType], - out: Optional[npt.NDArray[bool]] = None, robust_percentage: Optional[float] = None) \ - -> Union[Tuple[_IntType, int], - Tuple[_IntType, int, int, _NumberType, _NumberType, float, float, float, npt.NDArray[bool]]]: - """Compute Label, Number of voxels, 'robust' number of voxels, norm minimum, maximum, sum, sum of squares and 6-connected border of label lab (out references the border). +def calculate_merged_labels( + merged_labels: VirtualLabel, + voxel_counts: dict[_IntType, int], + robust_voxel_counts: dict[_IntType, int], + volumes: dict[_IntType, float], + mins: Optional[dict[_IntType, float]] = None, + maxes: Optional[dict[_IntType, float]] = None, + sums: Optional[dict[_IntType, float]] = None, + sums_of_squares: Optional[dict[_IntType, float]] = None, + eps: float = 1e-6, +) -> Iterator[PVStats]: + """ + Calculate the statistics for meta-labels, i.e. labels based on other labels + (`merge_labels`). Add respective items to `table`. 
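+
+    For example, ``merged_labels={900: (11, 50)}`` yields one additional entry with
+    ``SegId`` 900 that aggregates voxel counts, volumes and, if the intensity dicts
+    are passed, intensity statistics over labels 11 and 50 (ids are illustrative).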
+
+    Parameters
+    ----------
+    merged_labels : VirtualLabel
+        A dictionary of key 'merged id' to value list of ids it references.
+    voxel_counts : dict[int, int]
+        A dict of voxel counts for labels in the image/referenced in `merged_labels`.
+    robust_voxel_counts : dict[int, int]
+        A dict of the robust number of voxels referenced in `merged_labels`.
+    volumes : dict[int, float]
+        A dict of the volumes associated with each label.
+    mins : dict[int, float], optional
+        A dict of the minimum intensity associated with each label.
+    maxes : dict[int, float], optional
+        A dict of the maximum intensity associated with each label.
+    sums : dict[int, float], optional
+        A dict of the sums of voxel intensities associated with each label.
+    sums_of_squares : dict[int, float], optional
+        A dict of the sums of squares of voxel intensities associated with each label.
+    eps : float, default=1e-6
+        An epsilon value for numeric stability.
+
+    Yields
+    ------
+    PVStats
+        A dictionary per entry in `merged_labels`.
+    """
+    def num_robust_voxels(lab):
+        return robust_voxel_counts.get(lab, 0)
+
+    def aggregate(source, merge_labels, f: Callable[..., np.ndarray] = np.sum):
+        """aggregate labels `merge_labels` from `source` with function `f`"""
+        _data = [source.get(l, 0) for l in merge_labels if num_robust_voxels(l) > eps]
+        return f(_data).item()
+
+    def aggregate_std(sums, sums2, merge_labels, nvox):
+        """aggregate the std of labels `merge_labels` from `sums` and `sums2`"""
+        s2 = [(s := sums.get(l, 0)) * s / r for l in group
+              if (r := num_robust_voxels(l)) > eps]
+        return np.sqrt((aggregate(sums2, merge_labels) - np.sum(s2)) / nvox).item()
+
+    for lab, group in merged_labels.items():
+        stats = {"SegId": lab}
+        if all(l not in robust_voxel_counts for l in group):
+            logging.getLogger(__name__).warning(
+                f"None of the labels {group} for merged label {lab} exist in the "
+                f"segmentation."
+            )
+            stats.update(NVoxels=0, Volume_mm3=0.0)
+            for k, v in {"Min": mins, "Max": maxes, "Mean": sums}.items():
+                if v is not None:
+                    stats[k] = 0.
+            if all(v is not None for v in (mins, maxes)):
+                stats["Range"] = 0.
+            if all(v is not None for v in (sums, sums_of_squares)):
+                stats["StdDev"] = 0.
+        else:
+            num_voxels = aggregate(voxel_counts, group)
+            stats.update(NVoxels=num_voxels, Volume_mm3=aggregate(volumes, group))
+            if mins is not None:
+                stats["Min"] = aggregate(mins, group, np.min)
+            if maxes is not None:
+                stats["Max"] = aggregate(maxes, group, np.max)
+                if "Min" in stats:
+                    stats["Range"] = stats["Max"] - stats["Min"]
+            if sums is not None:
+                stats["Mean"] = aggregate(sums, group) / num_voxels
+            if sums_of_squares is not None:
+                stats["StdDev"] = aggregate_std(
+                    sums,
+                    sums_of_squares,
+                    group,
+                    num_voxels - 1,
+                )
+        yield stats
+
+
+def global_stats(
+    lab: _IntType,
+    norm: npt.NDArray[_NumberType] | None,
+    seg: npt.NDArray[_IntType],
+    out: Optional[npt.NDArray[bool]] = None,
+    robust_percentage: Optional[float] = None,
+) -> tuple[_IntType, _GlobalStats]:
+    """
+    Compute Label, Number of voxels, 'robust' number of voxels, norm minimum, maximum,
+    sum, sum of squares and 6-connected border of label lab (out references the border).
 
     Parameters
     ----------
     lab : _IntType
-        [MISSING]
-    norm : pt.NDArray[_NumberType]
-        [MISSING]
+        Label to compute statistics for.
+    norm : npt.NDArray[_NumberType], optional
+        The intensity image (default: None, do not compute intensity stats such as
+        normMin, normMax, etc.).
seg : npt.NDArray[_IntType] - [MISSING] - out : npt.NDArray[bool], Optional - [MISSING] - robust_percentage : float, Optional - [MISSING] + The segmentation image. + out : npt.NDArray[bool], optional + Output array to store the computed borders. + robust_percentage : float, optional + A robustness percentile to compute the statistics with (default: None/off = 1). Returns ------- - _IntType and int - [MISSING] - or _IntType, int, int, _NumberType, _NumberType, float, float, float and npt.NDArray[bool] - [MISSING] + label : int + The label the stats belong to (input). + stats : _GlobalStats + A tuple of number_of_voxels, number_of_within_robustness_thresholds, + minimum_intensity, maximum_intensity, sum_of_intensities, + sum_of_intensity_squares, and border with respect to the label. """ - bin_array = cast(npt.NDArray[bool], seg == lab) - data = norm[bin_array].astype(int if np.issubdtype(norm.dtype, np.integer) else float) + def __compute_borders(out: Optional[np.ndarray]) -> np.ndarray: + # compute/update the border + if out is None: + out = seg_borders(label_mask, True, cmp_dtype="int8").astype(bool) + else: + out[:] = seg_borders(label_mask, True, cmp_dtype="int").astype(bool) + return out + + label_mask = cast(npt.NDArray[bool], seg == lab) + if norm is None: + nvoxels = int(label_mask.sum()) + out = __compute_borders(out) + return lab, (nvoxels, nvoxels, None, None, None, None, 0., out) + + data_dtype = int if np.issubdtype(norm.dtype, np.integer) else float + data = norm[label_mask].astype(data_dtype) nvoxels: int = data.shape[0] # if lab is not in the image at all if nvoxels == 0: - return lab, 0 - # compute/update the border - if out is None: - out = seg_borders(bin_array, True, cmp_dtype="int8").astype(bool) - else: - out[:] = seg_borders(bin_array, True, cmp_dtype="int").astype(bool) + return lab, (0, 0, None, None, None, None, 0., out) + out = __compute_borders(out) if robust_percentage is not None: data = np.sort(data) @@ -895,309 +2061,450 @@ def global_stats(lab: _IntType, norm: npt.NDArray[_NumberType], seg: npt.NDArray _sum: float = data.sum().item() sum_2: float = (data * data).sum().item() # this is independent of the robustness criterium - volume: float = np.sum(np.logical_and(bin_array, np.logical_not(out))).astype(float).item() - return lab, nvoxels, __voxel_count, _min, _max, _sum, sum_2, volume, out + _volume_mask = np.logical_and(label_mask, np.logical_not(out)) + volume: float = np.sum(_volume_mask).astype(float).item() + return lab, (nvoxels, __voxel_count, _min, _max, _sum, sum_2, volume, out) -def patch_filter(pos: Tuple[int, int, int], mask: npt.NDArray[bool], - global_crop: Tuple[slice, ...], patch_size: int = 32) \ - -> Tuple[bool, Sequence[slice]]: - """Return, whether there are mask-True voxels in the patch starting at pos with size patch_size and the resulting patch shrunk to mask-True regions. +def patch_filter( + patch_corner: tuple[int, int, int], + mask: npt.NDArray[bool], + global_crop: SlicingTuple, + patch_size: int = 32, +) -> tuple[bool, SlicingSequence]: + """ + Return, whether there are mask-True voxels in the patch starting at pos with size + patch_size and the resulting patch shrunk to mask-True regions. Parameters ---------- - pos : Tuple[int, int, int] - [MISSING] + patch_corner : tuple[int, int, int] + The top left corner of the patch. mask : npt.NDArray[bool] - [MISSING] - global_crop : Tuple[slice, ...] - [MISSING] - patch_size : int - [MISSING]. Defaults to 32 + The mask of interest in the patch. 
+ global_crop : SlicingTuple + A image-wide slicing mask to constrain the 'search space'. + patch_size : int, default=32 + The size of the patch. Returns ------- bool - [MISSING] - Sequence[slice] - [MISSING] - + Whether there is any data in the patch at all. + SlicingSequence + Sequence of slice objects that describe patches with patch_corner and patch_size. """ + + def _slice(patch_start, _patch_size, image_stop): + return slice(patch_start, min(patch_start + _patch_size, image_stop)) + # create slices for current patch context (constrained by the global_crop) - patch = [slice(p, min(p + patch_size, slice_.stop)) for p, slice_ in zip(pos, global_crop)] + patch = [_slice(pc, patch_size, s.stop) for pc, s in zip(patch_corner, global_crop)] # crop patch context to the image content return crop_patch_to_mask(mask, sub_patch=patch) -def crop_patch_to_mask(mask: npt.NDArray[_NumberType], - sub_patch: Optional[Sequence[slice]] = None) \ - -> Tuple[bool, Sequence[slice]]: - """Crop the patch to regions of the mask that are non-zero. +def crop_patch_to_mask( + mask: npt.NDArray[_NumberType], + sub_patch: Optional[SlicingSequence] = None, +) -> tuple[bool, SlicingSequence]: + """ + Crop the patch to regions of the mask that are non-zero. - Assumes mask is always positive. Returns whether there - is any mask>0 in the patch and a patch shrunk to mask>0 regions. The optional subpatch constrains this operation to - the sub-region defined by a sequence of slicing operations. + Assumes mask is always positive. Returns whether there is any mask>0 in the patch + and a slicer/patch shrunk to mask>0 regions. The optional subpatch constrains this + operation to the sub-region defined by a sequence of slicing operations. Parameters ---------- mask : npt.NDArray[_NumberType] - to crop to + Mask to crop to. sub_patch : Optional[Sequence[slice]] - subregion of mask to only consider (default: full mask) + Subregion of mask to only consider (default: full mask). Returns ------- - bool - [MISSING] - Sequence[slice] - [MISSING] - - Note - ---- - This function requires device synchronization. - + not_empty : bool + Whether there is any voxel in the patch at all. + target_slicer : SlicingSequence + Sequence of slice-objects to extract the subregion of mask that is 'True'. """ - _patch = [] - patch = tuple([slice(0, s) for s in mask.shape] if sub_patch is None else sub_patch) - patch_in_patch_coords = tuple([slice(0, slice_.stop - slice_.start) for slice_ in patch]) + _target_slicer = [] + if sub_patch is None: + slicer_context = tuple(slice(0, s) for s in mask.shape) + else: + slicer_context = tuple(sub_patch) + slicer_in_patch_coords = tuple([slice(0, s.stop - s.start) for s in slicer_context]) in_mask = True - _mask = mask[patch].sum(axis=2) - for i, pat in enumerate(patch_in_patch_coords): + _mask = mask[slicer_context].sum(axis=2) + for i, pat in enumerate(slicer_in_patch_coords): p = pat.start if in_mask: if i == 2: - _mask = mask[patch][tuple(_patch)].sum(axis=0) + _mask = mask[slicer_context][tuple(_target_slicer)].sum(axis=0) + slicer_ith_axis = tuple(_target_slicer[1:] if i != 2 else []) # can we shrink the patch context in i-th axis? 
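+            # entries > 0 mark indices along the i-th axis that still contain mask
+            # voxels; the patch is shrunk to the first/last such index below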
- pat_has_mask_in_axis = _mask[tuple(_patch[1:] if i != 2 else [])].sum(axis=int(i == 0)) > 0 + pat_has_mask_in_axis = _mask[slicer_ith_axis].sum(axis=int(i == 0)) > 0 # modify both the _patch_size and the coordinate p to shrink the patch - _pat_mask = np.argwhere(pat_has_mask_in_axis) - if _pat_mask.shape[0] == 0: + pat_mask_indices = np.argwhere(pat_has_mask_in_axis) + if pat_mask_indices.shape[0] == 0: + # none in here _patch_size = 0 in_mask = False else: - offset = _pat_mask[0].item() + # some in the mask, find first and distance to last + offset = pat_mask_indices[0].item() p += offset - _patch_size = _pat_mask[-1].item() - offset + 1 + _patch_size = pat_mask_indices[-1].item() - offset + 1 else: _patch_size = 0 - _patch.append(slice(p, p + _patch_size)) - - out_patch = [slice(_p.start + p.start, p.start + _p.stop) for _p, p in zip(_patch, patch)] - return _patch[0].start != _patch[0].stop, out_patch - - -def pv_calc_patch(patch: Tuple[slice, ...], global_crop: Tuple[slice, ...], - loc_border: Dict[_IntType, npt.NDArray[bool]], - seg: npt.NDArray[_IntType], norm: np.ndarray, border: npt.NDArray[bool], - full_pv: Optional[npt.NDArray[float]] = None, full_ipv: Optional[npt.NDArray[float]] = None, - full_nbr_label: Optional[npt.NDArray[_IntType]] = None, - full_seg_mean: Optional[npt.NDArray[float]] = None, - full_nbr_mean: Optional[npt.NDArray[float]] = None, eps: float = 1e-6, - legacy_freesurfer: bool = False) \ - -> Dict[_IntType, float]: - """Calculate PV for patch. + _target_slicer.append(slice(p, p + _patch_size)) + + def _move_slice(the_slice: slice, offset: int) -> slice: + return slice(the_slice.start + offset, the_slice.stop + offset) + + target_slicer = [_move_slice(ts, sc.start) for ts, sc in zip(_target_slicer, + slicer_context)] + return _target_slicer[0].start != _target_slicer[0].stop, target_slicer + + +def pv_calc_patch( + slicer_patch: SlicingTuple, + global_crop: SlicingTuple, + borders: dict[_IntType, npt.NDArray[bool]], + seg: npt.NDArray[_IntType], + pv_guide: npt.NDArray, + border: npt.NDArray[bool], + full_pv: Optional[npt.NDArray[float]] = None, + full_ipv: Optional[npt.NDArray[float]] = None, + full_nbr_label: Optional[npt.NDArray[_IntType]] = None, + full_seg_mean: Optional[npt.NDArray[float]] = None, + full_nbr_mean: Optional[npt.NDArray[float]] = None, + eps: float = 1e-6, + legacy_freesurfer: bool = False, +) -> dict[_IntType, float]: + """ + Calculate PV for patch. - If full* keyword arguments are passed, also fills, per voxel results for the respective - voxels in the patch. + If full* keyword arguments are passed, the function also fills in per voxel results + for the respective voxels in the patch. Parameters ---------- - patch : Tuple[slice, ...] - [MISSING] - global_crop : Tuple[slice, ...] - [MISSING] - loc_border : Dict[_IntType, npt.NDArray[bool]] - [MISSING] - seg : npt.NDArray[_IntType] - [MISSING] - norm : np.ndarray - [MISSING] + slicer_patch : SlicingTuple + Tuple of slice-objects, with indexing origin at the image origin. + global_crop : SlicingTuple + Tuple of slice-objects, a global mask to limit computing to relevant parts of + the image. + borders : dict[int, npt.NDArray[bool]] + Dictionary containing the borders for each label. + seg : numpy.typing.NDArray[int] + The segmentation (full image) defining the labels. + pv_guide : numpy.ndarray + The (full) image with intensities to guide the PV calculation. 
border : npt.NDArray[bool] - [MISSING] - full_pv : npt.NDArray[float], Optional - [MISSING] - full_ipv : npt.NDArray[float], Optional - [MISSING] - full_nbr_label : npt.NDArray[_IntType], Optional - [MISSING] - full_seg_mean : npt.NDArray[float], Optional - [MISSING] - full_nbr_mean : npt.NDArray[float], Optional - [MISSING] - eps : float - [MISSING]. Defaults to 1e-6 - legacy_freesurfer : bool - [MISSING] + Binary mask, True, where a voxel is considered to be a border voxel. + full_pv : npt.NDArray[float], optional + PV image to fill with values for debugging. + full_ipv : npt.NDArray[float], optional + IPV image to fill with values for debugging. + full_nbr_label : npt.NDArray[_IntType], optional + NBR image to fill with values for debugging. + full_seg_mean : npt.NDArray[float], optional + Mean pv_guide-values for current segmentation label-image to fill with values + for debugging. + full_nbr_mean : npt.NDArray[float], optional + Mean pv_guide-values for nbr label-image to fill with values for debugging. + eps : float, default=1e-6 + Epsilon for considering a voxel being in the neighborhood. + legacy_freesurfer : bool, default=False + Whether to use the legacy freesurfer mri_segstats formula or the corrected + formula. Returns ------- - Dict[_IntType, float] - [MISSING] + dict[int, float] + Dictionary of per-label PV-corrected volume of affected voxels in the patch. """ - log_eps = -int(np.log10(eps)) - patch = tuple(patch) - patch_grow1, ungrow1_patch = grow_patch(patch, (FILTER_SIZES[0]-1)//2, seg.shape) - patch_grow7, ungrow7_patch = grow_patch(patch, (FILTER_SIZES[1]-1)//2, seg.shape) - patch_shrink6 = tuple( - slice(ug7.start - ug1.start, None if ug7.stop == ug1.stop else ug7.stop - ug1.stop) for ug1, ug7 in - zip(ungrow1_patch, ungrow7_patch)) - patch_in_gc = tuple(slice(p.start - gc.start, p.stop - gc.start) for p, gc in zip(patch, global_crop)) + # Variable conventions: + # pat_* : *, but sliced to the patch, i.e. 
a 3D/4D array + # pat1d_* : like pat_*, but only those voxels, that are part of the border and + # flattened - label_lookup = np.unique(seg[patch_grow1]) + log_eps = -int(np.log10(eps)) + + slicer_patch = tuple(slicer_patch) + slicer_small_patch, slicer_small_to_patch = pad_slicer(slicer_patch, + (FILTER_SIZES[0] - 1) // 2, + seg.shape) + slicer_large_patch, slicer_large_to_patch = pad_slicer(slicer_patch, + (FILTER_SIZES[1] - 1) // 2, + seg.shape) + slicer_large_to_small = tuple( + slice(l2p.start - s2p.start, + None if l2p.stop == s2p.stop else l2p.stop - s2p.stop) + for s2p, l2p in zip(slicer_small_to_patch, slicer_large_to_patch)) + patch_in_gc = tuple( + slice(p.start - gc.start, + p.stop - gc.start) + for p, gc in zip(slicer_patch, global_crop)) + + label_lookup = np.unique(seg[slicer_small_patch]) maxlabels = label_lookup[-1] + 1 if maxlabels > 100_000: raise RuntimeError("Maximum number of labels above 100000!") # create a view for the current patch border - pat_border = border[patch] - pat_is_border, pat_is_nbr, pat_label_counts, pat_label_sums \ - = patch_neighbors(label_lookup, norm, seg, pat_border, loc_border, - patch_grow7, patch_in_gc, patch_shrink6, ungrow1_patch, ungrow7_patch, - ndarray_alloc=np.full, eps=eps, legacy_freesurfer=legacy_freesurfer) + pat_border = border[slicer_patch] + pat_is_border, pat_is_nbr, pat_label_counts, pat_label_sums = patch_neighbors( + label_lookup, + pv_guide, + seg, + pat_border, + borders, + slicer_large_patch, + patch_in_gc, + slicer_large_to_small, + slicer_small_to_patch, + slicer_large_to_patch, + eps=eps, + legacy_freesurfer=legacy_freesurfer, + ) # both counts and sums are "normalized" by the local neighborhood size (15**3) label_lookup_fwd = np.zeros((maxlabels,), dtype="int") label_lookup_fwd[label_lookup] = np.arange(label_lookup.shape[0]) # shrink 3d patch to 1d list of border voxels - pat1d_norm, pat1d_seg = norm[patch][pat_border], seg[patch][pat_border] + pat1d_pv = pv_guide[slicer_patch][pat_border] + pat1d_seg = seg[slicer_patch][pat_border] pat1d_label_counts = pat_label_counts[:, pat_border] - # both sums and counts are normalized by n-hood-size**3, so the output is not anymore - pat1d_label_means = (pat_label_sums[:, pat_border] / np.maximum(pat1d_label_counts, eps * 0.0003)).round(log_eps + 4) # float + pat1d_robust_lblcnt = np.maximum(pat1d_label_counts, eps * 3e-4) + # both sums and counts are normalized by neighborhood-size**3, both are float + pat1d_label_means = pat_label_sums[:, pat_border] / pat1d_robust_lblcnt + pat1d_label_means = pat1d_label_means.round(log_eps + 4) # get the mean label intensity of the "local label" - mean_label = np.take_along_axis(pat1d_label_means, unsqueeze(label_lookup_fwd[pat1d_seg], 0), axis=0)[0] + pat1d_seg_reindexed = np.expand_dims(label_lookup_fwd[pat1d_seg], 0) + _mean_label = np.take_along_axis(pat1d_label_means, pat1d_seg_reindexed, axis=0) + mean_label = _mean_label[0] # get the index of the "alternative label" pat1d_is_this_6border = pat_is_border[:, pat_border] # calculate which classes to consider: - is_valid = np.all( - # 1. considered (mean of) alternative label must be on the other side of norm as the (mean of) the segmentation - # label of the current voxel - [np.logical_xor(pat1d_label_means > unsqueeze(pat1d_norm, 0), unsqueeze(mean_label > pat1d_norm, 0)), - # 2. considered (mean of) alternative label must be different to norm of voxel - pat1d_label_means != unsqueeze(pat1d_norm, 0), - # 3. 
(mean of) segmentation label must be different to norm of voxel - np.broadcast_to(unsqueeze(np.abs(mean_label - pat1d_norm) > eps, 0), pat1d_label_means.shape), - # 4. label must be a neighbor - pat_is_nbr[:, pat_border], - # 3. label must not be the segmentation - pat1d_seg[np.newaxis] != label_lookup[:, np.newaxis]], axis=0) - - none_valid = ~is_valid.any(axis=0, keepdims=False) - # select the label, that is valid or not valid but also exists and is not the current label - max_counts_index = np.round(pat1d_label_counts * is_valid, log_eps).argmax(axis=0, keepdims=False) - - nbr_label = label_lookup[max_counts_index] # label with max_counts - nbr_label[none_valid] = 0 + pat1d_mean_intensity_higher = pat1d_label_means > np.expand_dims(pat1d_pv, 0) + pat1d_mean_intensity_lower = np.expand_dims(mean_label > pat1d_pv, 0) + pat1d_mean_different = np.expand_dims(np.abs(mean_label - pat1d_pv) > eps, 0) + pat1d_is_valid = np.all( + [ + # 1. considered (mean of) alternative label must be on the other side of pv + # as the (mean of) the segmentation label of the current voxel + np.logical_xor(pat1d_mean_intensity_higher, pat1d_mean_intensity_lower), + # 2. considered (mean of) alternative label must be different to pv of voxel + pat1d_label_means != np.expand_dims(pat1d_pv, 0), + # 3. (mean of) segmentation label must be different to pv of voxel + np.broadcast_to(pat1d_mean_different, pat1d_label_means.shape), + # 4. label must be a neighbor + pat_is_nbr[:, pat_border], + # 3. label must not be the segmentation + pat1d_seg[np.newaxis] != label_lookup[:, np.newaxis], + ], + axis=0, + ) + + pat1d_none_valid = ~pat1d_is_valid.any(axis=0, keepdims=False) + # select the label, that is valid or not valid but also exists and is not the + # current label + pat1d_label_frequency = np.round(pat1d_label_counts * pat1d_is_valid, log_eps) + pat1d_max_frequency_index = pat1d_label_frequency.argmax(axis=0, keepdims=False) + + pat1d_nbr_label = label_lookup[pat1d_max_frequency_index] # label with max_counts + pat1d_nbr_label[pat1d_none_valid] = 0 # get the mean label intensity of the "alternative label" - mean_nbr = np.take_along_axis(pat1d_label_means, unsqueeze(label_lookup_fwd[nbr_label], 0), axis=0)[0] + pat1d_label_lookup_nbr = np.expand_dims(label_lookup_fwd[pat1d_nbr_label], 0) + mean_nbr = np.take_along_axis(pat1d_label_means, pat1d_label_lookup_nbr, axis=0)[0] # interpolate between the "local" and "alternative label" mean_to_mean_nbr = mean_label - mean_nbr delta_gt_eps = np.abs(mean_to_mean_nbr) > eps - pat1d_pv = (pat1d_norm - mean_nbr) / np.where(delta_gt_eps, mean_to_mean_nbr, eps) # make sure no division by zero + # make sure no division by zero + pat1d_pv = (pat1d_pv - mean_nbr) / np.where(delta_gt_eps, mean_to_mean_nbr, eps) - pat1d_pv[~delta_gt_eps] = 1. # set pv fraction to 1 if division by zero - pat1d_pv[none_valid] = 1. # set pv fraction to 1 for voxels that have no 'valid' nbr - pat1d_pv[pat1d_pv > 1.] = 1. - pat1d_pv[pat1d_pv < 0.] = 0. + # set pv fraction to 1 if division by zero + pat1d_pv[~delta_gt_eps] = 1.0 + # set pv fraction to 1 for voxels that have no valid nbr + pat1d_pv[pat1d_none_valid] = 1.0 + pat1d_pv[pat1d_pv > 1.0] = 1.0 + pat1d_pv[pat1d_pv < 0.0] = 0.0 - pat1d_inv_pv = 1. 
- pat1d_pv + pat1d_inv_pv = 1.0 - pat1d_pv if legacy_freesurfer: - # re-create the "supposed" freesurfer inconsistency that does not count vertex neighbors, if the voxel label - # is not of question - mask_by_6border = np.take_along_axis(pat1d_is_this_6border, unsqueeze(label_lookup_fwd[nbr_label], 0), axis=0)[0] - pat1d_inv_pv = pat1d_inv_pv * mask_by_6border + # re-create the "supposed" freesurfer inconsistency that does not count vertex + # neighbors, if the voxel label is not of question + mask_by_6border = np.take_along_axis( + pat1d_is_this_6border, pat1d_label_lookup_nbr, axis=0 + ) + pat1d_inv_pv = pat1d_inv_pv * mask_by_6border[0] if full_pv is not None: - full_pv[patch][pat_border] = pat1d_pv + full_pv[slicer_patch][pat_border] = pat1d_pv if full_nbr_label is not None: - full_nbr_label[patch][pat_border] = nbr_label + full_nbr_label[slicer_patch][pat_border] = pat1d_nbr_label if full_ipv is not None: - full_ipv[patch][pat_border] = pat1d_inv_pv + full_ipv[slicer_patch][pat_border] = pat1d_inv_pv if full_nbr_mean is not None: - full_nbr_mean[patch][pat_border] = mean_nbr + full_nbr_mean[slicer_patch][pat_border] = mean_nbr if full_seg_mean is not None: - full_seg_mean[patch][pat_border] = mean_label - - return {lab: (pat1d_pv.sum(where=pat1d_seg == lab) + pat1d_inv_pv.sum(where=nbr_label == lab)).item() for lab in - label_lookup} + full_seg_mean[slicer_patch][pat_border] = mean_label + def _vox_calc_pv(lab: _IntType) -> float: + """ + Compute the PV of voxels labels lab and voxels not labeled lab, but chosen as + mixing label. + """ + pv_sum = pat1d_pv.sum(where=pat1d_seg == lab).item() + inv_pv_sum = pat1d_inv_pv.sum(where=pat1d_nbr_label == lab).item() + return pv_sum + inv_pv_sum + + return {lab: _vox_calc_pv(lab) for lab in label_lookup} + + +def patch_neighbors( + labels: Sequence[_IntType], + pv_guide: npt.NDArray, + seg: npt.NDArray[_IntType], + border_patch: npt.NDArray[bool], + borders: dict[_IntType, npt.NDArray[bool]], + slicer_large_patch: SlicingTuple, + slicer_patch: SlicingTuple, + slicer_large_to_small: SlicingTuple, + slicer_small_to_patch: SlicingTuple, + slicer_large_to_patch: SlicingTuple, + eps: float = 1e-6, + legacy_freesurfer: bool = False, +) -> tuple[ + "npt.NDArray[bool]", + "npt.NDArray[bool]", + "npt.NDArray[float]", + "npt.NDArray[float]", +]: + """ + Calculate the neighbor statistics of labels for a specific patch. -def patch_neighbors(labels, norm, seg, pat_border, loc_border, patch_grow7, patch_in_gc, patch_shrink6, - ungrow1_patch, ungrow7_patch, ndarray_alloc, eps, legacy_freesurfer = False): - """Calculate the neighbor statistics of labels, etc.. + The patch is defined by `slicer_large_patch`, `slicer_large_to_small`, + `slicer_small_to_patch`, and `slicer_large_to_patch`. Parameters ---------- - labels : - [MISSING] - norm : - [MISSING] - seg : - [MISSING] - pat_border : - [MISSING] - loc_border : - [MISSING] - patch_grow7 : - [MISSING] - patch_in_gc : - [MISSING] - patch_shrink6 : - [MISSING] - ungrow1_patch : - [MISSING] - ungrow7_patch : - [MISSING] - ndarray_alloc : - [MISSING] - eps : - [MISSING] - legacy_freesurfer : bool - [MISSING]. Defaults to False + labels : Sequence[int] + A sequence of all labels that we want to compute the PV for. + pv_guide : numpy.ndarray + The (full) image with intensities to guide the PV calculation. + seg : numpy.typing.NDArray[int] + The segmentation (full image) defining the labels. 
+ border_patch : npt.NDArray[bool] + Binary mask for the current patch, True, where a voxel is considered to be a + border voxel. + borders : dict[_IntType, npt.NDArray[bool]] + Dictionary containing the borders for each label. + slicer_large_patch : SlicingTuple + Slicing tuple to obtain a patch of shape like the patch but padded to the large + filter size. + slicer_patch : SlicingTuple + Tuple of slice-objects to extract the patch from the full image. + slicer_large_to_small : SlicingTuple + Tuple of slice-objects to extract the small patch (patch plus small filter + window) from the large patch (patch plus large filter window). + slicer_small_to_patch : SlicingTuple + Tuple of slice-objects to extract the patch from the patch padded by the small + filter size. + slicer_large_to_patch : SlicingTuple + Tuple of slice-objects to extract the patch from the patch padded by the large + filter size. + eps : float, default=1e-6 + Epsilon for considering a voxel being in the neighborhood. + legacy_freesurfer : bool, default=False + Whether to use the legacy freesurfer mri_segstats formula or the corrected + formula. Returns ------- - [MISSING] - + pat_is_border : npt.NDArray[bool] + Array indicating whether each label is on the patch border. + pat_is_nbr : npt.NDArray[bool] + Array indicating whether each label is a neighbor in the patch. + pat_label_count : npt.NDArray[float] + Array containing label counts in the patch (divided by the neighborhood size). + pat_label_sums : npt.NDArray[float] + Array containing the sum of normalized values for each label in the patch. """ - loc_shape = (len(labels),) + pat_border.shape + shape_of_patch = (len(labels),) + border_patch.shape - pat_label_counts, pat_label_sums = ndarray_alloc((2,) + loc_shape, fill_value=0., dtype=float) - pat_is_nbr, pat_is_border = ndarray_alloc((2,) + loc_shape, fill_value=False, dtype=bool) + pat_label_counts, pat_label_sums = np.zeros((2,) + shape_of_patch, dtype=float) + pat_is_nbr, pat_is_border = np.zeros((2,) + shape_of_patch, dtype=bool) # all False for i, lab in enumerate(labels): - # in legacy freesurfer mode, we want to fill the binary labels with True if we are looking at the background - fill_binary_label = float(legacy_freesurfer and lab == 0) + # in legacy freesurfer mode, we want to fill the binary labels with True if we + # are looking at the background + fillvalue_binary_label = float(legacy_freesurfer and lab == 0) - pat7_bin_array = cast(npt.NDArray[bool], seg[patch_grow7] == lab) + same_label_large_patch = cast(npt.NDArray[bool], seg[slicer_large_patch] == lab) + same_label_small_patch = same_label_large_patch[slicer_large_to_small] # implicitly also a border detection: is lab a neighbor of the "current voxel" - tmp_nbr_label_counts = uniform_filter(pat7_bin_array[patch_shrink6], FILTER_SIZES[0], fill_binary_label) # as float (*filter_size**3) + # returns 'small patch'-array of float (shape: (patch_size + filter_size)**3) + # for label 'lab' + tmp_nbr_label_counts = uniform_filter( + same_label_small_patch, + FILTER_SIZES[0], + fillvalue_binary_label, + ) if tmp_nbr_label_counts.sum() > eps: # lab is at least once a nbr in the patch (grown by one) - if lab in loc_border: - pat_is_border[i] = loc_border[lab][patch_in_gc] + if lab in borders: + pat_is_border[i] = borders[lab][slicer_patch] else: - pat7_is_border = seg_borders(pat7_bin_array[patch_shrink6], True, cmp_dtype="int8") - pat_is_border[i] = pat7_is_border[ungrow1_patch].astype(bool) - - pat_is_nbr[i] = tmp_nbr_label_counts[ungrow1_patch] > eps - 
pat_label_counts[i] = uniform_filter(pat7_bin_array, FILTER_SIZES[1], fill_binary_label)[ungrow7_patch] # as float (*filter_size**3) - pat7_filtered_norm = norm[patch_grow7] * pat7_bin_array - pat_label_sums[i] = uniform_filter(pat7_filtered_norm, FILTER_SIZES[1], 0)[ungrow7_patch] + pat7_is_border = seg_borders( + same_label_small_patch, + label=True, + cmp_dtype="int8", + ) + pat_is_border[i] = pat7_is_border[slicer_small_to_patch].astype(bool) + + pat_is_nbr[i] = tmp_nbr_label_counts[slicer_small_to_patch] > eps + # as float (*filter_size**3) + pat_label_counts[i] = uniform_filter( + same_label_large_patch, + FILTER_SIZES[1], + fillvalue_binary_label, + slicer_patch=slicer_large_to_patch + ) + pat_large_filter_pv = pv_guide[slicer_large_patch] * same_label_large_patch + pat_label_sums[i] = uniform_filter( + pat_large_filter_pv, + FILTER_SIZES[1], + fillval=0, slicer_patch=slicer_large_to_patch + ) # else: lab is not present in the patch return pat_is_border, pat_is_nbr, pat_label_counts, pat_label_sums +# timeit cmd arg: +# python -m timeit < None: - """Train the network to the given training data. + """ + Train the network to the given training data. Parameters ---------- train_loader : loader.DataLoader - data loader for the training - optimizer : torch.optim.optimizer.Optimizer - optimizer for the training - scheduler : Union[None, scheduler.StepLR, scheduler.CosineAnnealingWarmRestarts] - lr scheduler for the training + Data loader for the training. + optimizer : torch.optim.Optimizer + Optimizer for the training. + scheduler : None, scheduler.StepLR, scheduler.CosineAnnealingWarmRestarts + LR scheduler for the training. train_meter : Meter - [MISSING] + Meter to keep track of the training stats. epoch : int - [MISSING] - + Current epoch. + """ self.model.train() logger.info("Training started ") @@ -120,7 +123,6 @@ def train( loss_batch = np.zeros(1) for curr_iter, batch in tqdm(enumerate(train_loader), total=len(train_loader)): - images, labels, weights, scale_factors = ( batch["image"].to(self.device), batch["label"].to(self.device), @@ -147,8 +149,8 @@ def train( loss_total.backward() if ( - not self.subepoch - or (curr_iter + 1) % (16 / self.cfg.TRAIN.BATCH_SIZE) == 0 + not self.subepoch + or (curr_iter + 1) % (16 / self.cfg.TRAIN.BATCH_SIZE) == 0 ): optimizer.step() # every second epoch to get batchsize of 16 if using 8 if scheduler is not None: @@ -178,27 +180,24 @@ def train( @torch.no_grad() def eval( - self, - val_loader: loader.DataLoader, - val_meter: Meter, - epoch: int + self, val_loader: loader.DataLoader, val_meter: Meter, epoch: int ) -> np.ndarray: - """Evaluate model and calculates stats. + """ + Evaluate model and calculates stats. Parameters ---------- val_loader : loader.DataLoader - Value loader + Value loader. val_meter : Meter - Meter for the values + Meter for the values. epoch : int - epoch to evaluate + Epoch to evaluate. Returns ------- int, float, ndarray - median miou [value] - + median miou [value]. 
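+        Examples
+        --------
+        A usage sketch (``trainer``, ``val_loader`` and ``val_meter`` are assumed to
+        be constructed elsewhere; the epoch value is illustrative):
+
+        >>> miou = trainer.eval(val_loader, val_meter, epoch=5)  # doctest: +SKIP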
""" logger.info(f"Evaluating model at epoch {epoch}") self.model.eval() @@ -218,7 +217,6 @@ def eval( val_start = time.time() for curr_iter, batch in tqdm(enumerate(val_loader), total=len(val_loader)): - images, labels, weights, scale_factors = ( batch["image"].to(self.device), batch["label"].to(self.device), @@ -307,10 +305,12 @@ def eval( return np.mean(np.mean(miou)) def run(self): - """Transfer the model to devices, create a tensor board summary writer and then perform the training loop.""" + """ + Transfer the model to devices, create a tensor board summary writer and then perform the training loop. + """ if self.cfg.NUM_GPUS > 1: assert ( - self.cfg.NUM_GPUS <= torch.cuda.device_count() + self.cfg.NUM_GPUS <= torch.cuda.device_count() ), "Cannot use more GPU devices than available" print("Using ", self.cfg.NUM_GPUS, "GPUs!") self.model = torch.nn.DataParallel(self.model) diff --git a/FastSurferCNN/utils/__init__.py b/FastSurferCNN/utils/__init__.py index 6bfc8189..da1c5dcf 100644 --- a/FastSurferCNN/utils/__init__.py +++ b/FastSurferCNN/utils/__init__.py @@ -24,4 +24,17 @@ "misc", "parser_defaults", "threads", + "Plane", + "PlaneAxial", + "PlaneCoronal", + "PlaneSagittal", + "PLANES", ] + +from typing import Literal, get_args + +PlaneAxial = Literal["axial"] +PlaneCoronal = Literal["coronal"] +PlaneSagittal = Literal["sagittal"] +Plane = PlaneAxial | PlaneCoronal | PlaneSagittal +PLANES: tuple[PlaneAxial, PlaneCoronal, PlaneSagittal] = ("axial", "coronal", "sagittal") diff --git a/FastSurferCNN/utils/arg_types.py b/FastSurferCNN/utils/arg_types.py index bb540535..faeb7fa2 100644 --- a/FastSurferCNN/utils/arg_types.py +++ b/FastSurferCNN/utils/arg_types.py @@ -13,7 +13,7 @@ # limitations under the License. import argparse -from typing import Union, Literal, Optional +from typing import Literal, Optional, Union import nibabel as nib import numpy as np @@ -22,22 +22,24 @@ def vox_size(a: str) -> VoxSizeOption: - """Convert the vox_size argument to 'min' or a valid voxel size. + """ + Convert the vox_size argument to 'min' or a valid voxel size. Parameters ---------- a : str - vox size type. Can be auto, bin or a number between 1 an 0 + Vox size type. Can be auto, bin or a number between 1 an 0. Returns ------- - [MISSING] + str or float + If 'auto' or 'min' is provided, it returns a string('auto' or 'min'). + If a valid voxel size (between 0 and 1) is provided, it returns a float. Raises ------ argparse.ArgumentTypeError - An error from creating or using an argument. Additionally, vox_sizes may be 'min'. - + If the arguemnt is not "min", "auto" or convertible to a float between 0 and 1. """ if a.lower() in ["auto", "min"]: return "min" @@ -50,17 +52,24 @@ def vox_size(a: str) -> VoxSizeOption: def float_gt_zero_and_le_one(a: str) -> Optional[float]: - """Check whether a parameters are a float between 0 and one. + """ + Check whether a parameters are a float between 0 and one. Parameters ---------- a : str - String of a number or none, infinity + String of a number or none, infinity. Returns ------- - [MISSING] + float or None + If `a` is a valid float between 0 and 1, return the float value. + If `a` is 'none' or 'infinity', return None. + Raises + ------ + argparse.ArgumentTypeError + If `a` is neither a float between 0 and 1. """ if a is None or a.lower() in ["none", "infinity"]: return None @@ -72,22 +81,28 @@ def float_gt_zero_and_le_one(a: str) -> Optional[float]: def target_dtype(a: str) -> str: - """Check for valid dtypes. + """ + Check for valid dtypes. 

     Parameters
     ----------
     a : str
-        datatype
+        Datatype descriptor.

     Returns
     -------
-        [MISSING]
+    str
+        The validated data type.

     Raises
     ------
     argparse.ArgumentTypeError
-        Invalid dtype
+        Invalid dtype.
+    See Also
+    --------
+    numpy.dtype
+        For more information on numpy data types and their properties.
     """
     dtypes = nib.freesurfer.mghformat.data_type_codes.value_set("label")
     dtypes.add("any")
@@ -107,23 +122,23 @@ def target_dtype(a: str) -> str:


 def int_gt_zero(value: Union[str, int]) -> int:
-    """Convert to positive integers.
+    """
+    Convert to positive integers.

     Parameters
     ----------
     value : Union[str, int]
-        integer to convert
+        Integer to convert.

     Returns
     -------
     val : int
-        converted integer
+        Converted integer.

     Raises
     ------
     argparse ArgumentTypeError:
         Invalid value, must not be negative.
-
     """
     val = int(value)
     if val <= 0:
@@ -131,24 +146,24 @@ def int_gt_zero(value: Union[str, int]) -> int:
     return val


-def int_ge_zero(value) -> int:
-    """Convert to integers greater 0.
+def int_ge_zero(value: str) -> int:
+    """
+    Convert to integers greater than or equal to 0.

     Parameters
     ----------
-    value :
-        integer to convert
+    value : str
+        String to convert to int.

     Returns
     -------
     val : int
-        given value if bigger or equal to zero
+        Given value if greater than or equal to zero.

     Raises
     ------
     argparse ArgumentTypeError:
         Invalid value, must be greater than 0.
-
     """
     val = int(value)
     if val < 0:
@@ -157,18 +172,18 @@ def int_ge_zero(value) -> int:


 def unquote_str(value) -> str:
-    """Unquote a (single quoted) string.
+    """
+    Unquote a (single quoted) string, i.e. remove one level of single-quotes.

     Parameters
     ----------
-    value :
-        String to be unquoted
+    value : str
+        String to be unquoted.

     Returns
     -------
     val : str
-        A string of the value without quoting with '''
-
+        A string of the value without leading and trailing single-quotes.
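+
+    Examples
+    --------
+    A minimal illustration of the quote stripping (inputs are hypothetical):
+
+    >>> unquote_str("'--t1'")
+    '--t1'
+    >>> unquote_str("already unquoted")
+    'already unquoted'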
""" val = str(value) if val.startswith("'") and val.endswith("'"): diff --git a/FastSurferCNN/utils/brainvolstats.py b/FastSurferCNN/utils/brainvolstats.py new file mode 100644 index 00000000..40451bb5 --- /dev/null +++ b/FastSurferCNN/utils/brainvolstats.py @@ -0,0 +1,2442 @@ +import abc +import logging +import re +from concurrent.futures import Executor +from contextlib import contextmanager +from pathlib import Path +from typing import (TYPE_CHECKING, Sequence, cast, Literal, Iterable, Callable, Union, + Optional, overload, TextIO, Protocol, TypeVar, Generic, Type) +from concurrent.futures import Future + +import numpy as np + +if TYPE_CHECKING: + from numpy import typing as npt + import lapy + import nibabel as nib + import pandas as pd + + from CerebNet.datasets.utils import LTADict + +MeasureTuple = tuple[str, str, int | float, str] +ImageTuple = tuple["nib.analyze.SpatialImage", "np.ndarray"] +UnitString = Literal["unitless", "mm^3"] +MeasureString = Union[str, "Measure"] +AnyBufferType = Union[ + dict[str, MeasureTuple], + ImageTuple, + "lapy.TriaMesh", + "npt.NDArray[float]", + "pd.DataFrame", +] +T_BufferType = TypeVar( + "T_BufferType", + bound=Union[ + ImageTuple, + dict[str, MeasureTuple], + "lapy.TriaMesh", + "np.ndarray", + "pd.DataFrame", + ]) +DerivedAggOperation = Literal["sum", "ratio", "by_vox_vol"] +AnyMeasure = Union["AbstractMeasure", str] +PVMode = Literal["vox", "pv"] +ClassesType = Sequence[int] +ClassesOrCondType = ClassesType | Callable[["npt.NDArray[int]"], "npt.NDArray[bool]"] +MaskSign = Literal["abs", "pos", "neg"] +_ToBoolCallback = Callable[["npt.NDArray[int]"], "npt.NDArray[bool]"] + + +class ReadFileHook(Protocol[T_BufferType]): + + @overload + def __call__(self, file: Path, blocking: True = True) -> T_BufferType: ... + + @overload + def __call__(self, file: Path, blocking: False) -> None: ... + + def __call__(self, file: Path, b: bool = True) -> Optional[T_BufferType]: ... + + +class _DefaultFloat(float): + pass + + +def read_measure_file(path: Path) -> dict[str, MeasureTuple]: + """ + Read '# Measure '-entries from stats files. + + Parameters + ---------- + path : Path + The path to the file to read from. + + Returns + ------- + A dictionary of Measure keys to tuple of descriptors like + {'': ('', '', , '')}. + """ + if not path.exists(): + raise IOError(f"Measures could not be imported from {path}, " + f"the file does not exist.") + with open(path, "r") as fp: + lines = list(fp.readlines()) + vox_line = list(filter(lambda l: l.startswith("# VoxelVolume_mm3 "), lines)) + lines = filter(lambda l: l.startswith("# Measure "), lines) + + def to_measure(line: str) -> tuple[str, MeasureTuple]: + data_tup = line.removeprefix("# Measure ").strip() + import re + key, name, desc, sval, unit = re.split("\\s*,\\s*", data_tup) + value = float(sval) if "." in sval else int(sval) + return key, (name, desc, value, unit) + + data = dict(map(to_measure, lines)) + if len(vox_line) > 0: + vox_vol = float(vox_line[-1].split(" ")[2].strip()) + data["vox_vol"] = ("Voxel volume", "The volume of a voxel", vox_vol, "mm^3") + + return data + + +def read_volume_file(path: Path) -> ImageTuple: + """ + Read a volume from disk. + + Parameters + ---------- + path : Path + The path to the file to read from. + + Returns + ------- + A tuple of nibabel image object and the data. 
+ """ + try: + import nibabel as nib + img = cast(nib.analyze.SpatialImage, nib.load(path)) + if not isinstance(img, nib.analyze.SpatialImage): + raise RuntimeError( + f"Loading the file '{path}' for Measure was invalid, no SpatialImage." + ) + except (IOError, FileNotFoundError) as e: + args = e.args[0] + raise IOError(f"Failed loading the file '{path}' with error: {args}") from e + data = np.asarray(img.dataobj) + return img, data + + +def read_mesh_file(path: Path) -> "lapy.TriaMesh": + """ + Read a mesh from disk. + + Parameters + ---------- + path : Path + The path to the file. + + Returns + ------- + lapy.TriaMesh + The mesh object read from the file. + """ + try: + import lapy + mesh = lapy.TriaMesh.read_fssurf(str(path)) + except (IOError, FileNotFoundError) as e: + args = e.args[0] + raise IOError( + f"Failed loading the file '{path}' with error: {args}") from e + return mesh + + +def read_lta_transform_file(path: Path) -> "npt.NDArray[float]": + """ + Read and extract the first lta transform from an LTA file. + + Parameters + ---------- + path : Path + The path of the LTA file. + + Returns + ------- + matrix : npt.NDArray[float] + Matrix of shape (4, 4). + """ + from CerebNet.datasets.utils import read_lta + return read_lta(path)["lta"][0, 0] + + +def read_xfm_transform_file(path: Path) -> "npt.NDArray[float]": + """ + Read XFM talairach transform. + + Parameters + ---------- + path : str | Path + The filename/path of the transform file. + + Returns + ------- + tal + The talairach transform matrix. + + Raises + ------ + ValueError + If the file is of an invalid format. + """ + with open(path) as f: + lines = f.readlines() + + try: + transf_start = [l.lower().startswith("linear_") for l in lines].index(True) + 1 + tal_str = [l.replace(";", " ") for l in lines[transf_start:transf_start + 3]] + tal = np.genfromtxt(tal_str) + tal = np.vstack([tal, [0, 0, 0, 1]]) + + return tal + except Exception as e: + err = ValueError(f"Could not find taiairach transform in {path}.") + raise err from e + + +def read_transform_file(path: Path) -> "npt.NDArray[float]": + """ + Read xfm or lta transform file. + + Parameters + ---------- + path : Path + The path to the file. + + Returns + ------- + tal + The talairach transform matrix. + """ + if path.suffix == ".lta": + return read_lta_transform_file(path) + elif path.suffix == ".xfm": + return read_xfm_transform_file(path) + else: + raise NotImplementedError( + f"The extension {path.suffix} is not '.xfm' or '.lta' and not recognized.") + + +def mask_in_array(arr: "npt.NDArray", items: "npt.ArrayLike") -> "npt.NDArray[bool]": + """ + Efficient function to generate a mask of elements in `arr`, which are also in items. + + Parameters + ---------- + arr : npt.NDArray + An array with data, most likely int. + items : npt.ArrayLike + Which elements of `arr` in arr should yield True. + + Returns + ------- + mask : npt.NDArray[bool] + A binary array, true, where elements in `arr` are in `items`. + + See Also + -------- + mask_not_in_array + """ + _items = np.asarray(items) + if _items.size == 0: + return np.zeros_like(arr, dtype=bool) + elif _items.size == 1: + return np.asarray(arr == _items.flat[0]) + else: + max_index = max(np.max(items), np.max(arr)) + if max_index >= 2 ** 16: + logging.getLogger(__name__).warning( + f"labels in arr are larger than {2 ** 16 - 1}, this is not recommended!" 
+ ) + lookup = np.zeros(max_index + 1, dtype=bool) + lookup[_items] = True + return lookup[arr] + + +def mask_not_in_array( + arr: "npt.NDArray", + items: "npt.ArrayLike", +) -> "npt.NDArray[bool]": + """ + Inverse of mask_in_array. + + Parameters + ---------- + arr : npt.NDArray + An array with data, most likely int. + items : npt.ArrayLike + Which elements of `arr` in arr should yield False. + + Returns + ------- + mask : npt.NDArray[bool] + A binary array, true, where elements in `arr` are not in `items`. + + See Also + -------- + mask_in_array + """ + _items = np.asarray(items) + if _items.size == 0: + return np.ones_like(arr, dtype=bool) + elif _items.size == 1: + return np.asarray(arr != _items.flat[0]) + else: + max_index = max(np.max(items), np.max(arr)) + if max_index >= 2 ** 16: + logging.getLogger(__name__).warning( + f"labels in arr are larger than {2 ** 16 - 1}, this is not recommended!" + ) + lookup = np.ones(max_index + 1, dtype=bool) + lookup[_items] = False + return lookup[arr] + + +class AbstractMeasure(metaclass=abc.ABCMeta): + """ + The base class of all measures, which implements the name, description, and unit + attributes as well as the methods as_tuple(), __call__(), read_subject(), + set_args(), parse_args(), help(), and __str__(). + """ + + __PATTERN = re.compile("^([^\\s=]+)\\s*=\\s*(\\S.*)$") + + def __init__(self, name: str, description: str, unit: str): + self._name: str = name + self._description: str = description + self._unit: str = unit + self._subject_dir: Path | None = None + + def as_tuple(self) -> MeasureTuple: + return self._name, self._description, self(), self.unit + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._description + + @property + def unit(self) -> str: + return self._unit + + @property + def subject_dir(self) -> Path: + return self._subject_dir + + @abc.abstractmethod + def __call__(self) -> int | float: + ... + + def read_subject(self, subject_dir: Path) -> bool: + """ + Perform IO required to compute/fill the Measure. + + Parameters + ---------- + subject_dir : Path + Path to the directory of the subject_dir (often subject_dir/subject_id). + + Returns + ------- + bool + Whether there was an update. + """ + updated = subject_dir != self.subject_dir + if updated: + self._subject_dir = subject_dir + return updated + + @abc.abstractmethod + def _parsable_args(self) -> list[str]: + ... + + def set_args(self, **kwargs: str) -> None: + """ + Set the arguments of the Measure. + + Raises + ------ + ValueError + If there are unrecognized keyword arguments. + """ + if len(kwargs) > 0: + raise ValueError(f"Invalid args {tuple(kwargs.keys())}") + + def parse_args(self, *args: str) -> None: + """ + Parse additional args defining the behavior of the Measure. + + Parameters + ---------- + *args : str + Each args can be a string of '' (arg-style) and '=' + (keyword-arg-style), arg-style cannot follow keyword-arg-style args. + + Raises + ------ + ValueError + If there are more arguments than registered argument names. + RuntimeError + If an arg-style follows a keyword-arg-style argument, or if a keyword value + is redefined, or a keyword is not valid. + """ + + def kwerror(i, args, msg) -> RuntimeError: + return RuntimeError(f"Error parsing arg {i} in {args}: {msg}") + + _pargs = self._parsable_args() + if len(args) > len(_pargs): + raise ValueError( + f"The measure {self.name} can have up to {len(_pargs)} arguments, but " + f"parsing {len(args)}: {args}." 
+ ) + _kwargs = {} + _kwmode = False + for i, (arg, default_key) in enumerate(zip(args, _pargs)): + if (hit := self.__PATTERN.match(arg)) is None: + # non-keyword mode + if _kwmode: + raise kwerror(i, args, f"non-keyword after keyword") + _kwargs[default_key] = arg + else: + # keyword mode + _kwmode = True + k = hit.group(1) + if k in _kwargs: + raise kwerror(i, args, f"keyword '{k}' already assigned") + if k not in _pargs: + raise kwerror(i, args, f"keyword '{k}' not in {_pargs}") + _kwargs[k] = hit.group(2) + self.set_args(**_kwargs) + + def help(self) -> str: + """ + Compiles a help message for the measure describing the measure's settings. + + Returns + ------- + A help string describing the Measure settings. + """ + return f"{self.name}=" + + @abc.abstractmethod + def __str__(self) -> str: + ... + + +class NullMeasure(AbstractMeasure): + """ + A Measure that supports no operations, always returns a value of zero. + """ + + def _parsable_args(self) -> list[str]: + return [] + + def __call__(self) -> int | float: + return 0 if self.unit == "unitless" else 0.0 + + def help(self) -> str: + return super().help() + "NULL" + + def __str__(self) -> str: + return "NullMeasure()" + + +class Measure(AbstractMeasure, Generic[T_BufferType], metaclass=abc.ABCMeta): + """ + Class to buffer computed values, buffers computed values. Implements a value + buffering interface for computed measure values and implement the read_subject + pattern. + """ + + __buffer: float | int | None + __token: str = "" + __PATTERN = re.compile("^([^\\s=]*file)\\s*=\\s*(\\S.*)$") + + def __call__(self) -> int | float: + token = str(self._subject_dir) + if self.__buffer is None or self.__token != token: + self.__token = token + self.__buffer = self._compute() + return self.__buffer + + @abc.abstractmethod + def _compute(self) -> int | float: + ... + + def __init__( + self, + file: Path, + name: str, + description: str, + unit: str, + read_hook: ReadFileHook[T_BufferType], + ): + self._file = file + self._callback = read_hook + self._data: Optional[T_BufferType] = None + self.__buffer = None + super().__init__(name, description, unit) + + def _load_error(self, name: str = "data") -> RuntimeError: + return RuntimeError( + f"The '{name}' is not available for {self.name} ({type(self).__name__}), " + f"maybe the subject has not been loaded or the cache been invalidated." + ) + + def _filename(self) -> Path: + return self._subject_dir / self._file + + def read_subject(self, subject_dir: Path) -> bool: + """ + Perform IO required to compute/fill the Measure. Delegates file reading to + read_hook (set in __init__). + + Parameters + ---------- + subject_dir : Path + Path to the directory of the subject_dir (often subject_dir/subject_id). + + Returns + ------- + bool + Whether there was an update to the data. + """ + if super().read_subject(subject_dir): + try: + self._data = self._callback(self._filename()) + except Exception as e: + e.args = f"{e.args[0]} ... during reading for measure {self}.", + raise e + return True + return False + + def _parsable_args(self) -> list[str]: + return ["file"] + + def set_args(self, file: str | None = None, **kwargs: str) -> None: + if file is not None: + self._file = Path(file) + return super().set_args(**kwargs) + + def __str__(self) -> str: + return f"{type(self).__name__}(file={self._file})" + + +class ImportedMeasure(Measure[dict[str, MeasureTuple]]): + """ + A Measure that implements reading measure values from a statsfile. 
+ """ + + PREFIX = "__IMPORTEDMEASURE-prefix__" + read_file = staticmethod(read_measure_file) + + def __init__( + self, + key: str, + measurefile: Path, + name: str = "N/A", + description: str = "N/A", + unit: UnitString = "unitless", + read_file: Optional[ReadFileHook[dict[str, MeasureTuple]]] = None, + vox_vol: Optional[float] = None, + ): + self._key: str = key + super().__init__( + measurefile, + name, + description, + unit, + self.read_file if read_file is None else read_file, + ) + self._vox_vol: Optional[float] = vox_vol + + def _compute(self) -> int | float: + """ + Will also update the name, description and unit from the strings in the file. + + Returns + ------- + value : int | float + value of the measure (as read from the file) + """ + try: + self._name, self._description, out, self._unit = self._data[self._key] + except KeyError as e: + raise KeyError(f"Could not find {self._key} in {self._file}.") from e + return out + + def _parsable_args(self) -> list[str]: + return ["key", "measurefile"] + + def set_args( + self, + key: str | None = None, + measurefile: str | None = None, + **kwargs: str, + ) -> None: + if measurefile is not None: + kwargs["file"] = measurefile + if key is not None: + self._key = key + return super().set_args(**kwargs) + + def help(self) -> str: + return super().help() + f" imported from {self._file}" + + def __str__(self) -> str: + return f"ImportedMeasure(key={self._key}, measurefile={self._file})" + + def assert_measurefile_absolute(self): + """ + Assert that the Measure can be imported without a subject and subject_dir. + + Raises + ------ + AssertionError + """ + if not self._file.is_absolute() or not self._file.exists(): + raise AssertionError( + f"The ImportedMeasures {self.name} is defined for import, but the " + f"associated measure file {self._file} is not an absolute path or " + f"does not exist and no subjects dir or subject id are defined." + ) + + def get_vox_vol(self) -> float: + """ + Returns the voxel volume. + + Returns + ------- + float + The voxel volume associated with the imported measure. + + Raises + ------ + RuntimeError + If the voxel volume was not defined. + """ + if self._vox_vol is None: + raise RuntimeError(f"The voxel volume of {self} has never been specified.") + return self._vox_vol + + def set_vox_vol(self, value: float): + self._vox_vol = value + + def read_subject(self, subject_dir: Path) -> bool: + if super().read_subject(subject_dir): + vox_vol_tup = self._data.get("vox_vol", None) + if isinstance(vox_vol_tup, tuple) and len(vox_vol_tup) > 2: + self._vox_vol = vox_vol_tup[2] + return True + return False + + +class SurfaceMeasure(Measure["lapy.TriaMesh"], metaclass=abc.ABCMeta): + """ + Class to implement default Surface io. 
+ """ + + read_file = staticmethod(read_mesh_file) + + def __init__( + self, + surface_file: Path, + name: str, + description: str, + unit: UnitString, + read_mesh: Optional[ReadFileHook["lapy.TriaMesh"]] = None, + ): + super().__init__( + surface_file, + name, + description, + unit, + self.read_file if read_mesh is None else read_mesh, + ) + + def __str__(self) -> str: + return f"{type(self).__name__}(surface_file={self._file})" + + def _parsable_args(self) -> list[str]: + return ["surface_file"] + + def set_args(self, surface_file: str | None = None, **kwargs: str) -> None: + if surface_file is not None: + kwargs["file"] = surface_file + return super().set_args(**kwargs) + + +class SurfaceHoles(SurfaceMeasure): + """Class to compute surfaces holes for surfaces.""" + + def _compute(self) -> int: + return int(1 - self._data.euler() / 2) + + def help(self) -> str: + return super().help() + f"surface holes from {self._file}" + + +class SurfaceVolume(SurfaceMeasure): + """Class to compute surface volume for surfaces.""" + + def _compute(self) -> float: + return self._data.volume() + + def help(self) -> str: + return super().help() + f"volume from {self._file}" + + +class PVMeasure(AbstractMeasure): + """Class to compute volume for segmentations (includes PV-correction).""" + + read_file = None + + def __init__( + self, + classes: ClassesType, + name: str, + description: str, + unit: Literal["mm^3"] = "mm^3", + ): + if unit != "mm^3": + raise ValueError("unit must be mm^3 for PVMeasure!") + self._classes = classes + super().__init__(name, description, unit) + self._pv_value = None + + @property + def vox_vol(self) -> float: + return self._vox_vol + + @vox_vol.setter + def vox_vol(self, v: float): + self._vox_vol = v + + def labels(self) -> list[int]: + return list(self._classes) + + def update_data(self, value: "pd.Series"): + self._pv_value = value + + def __call__(self) -> float: + if self._pv_value is None: + raise RuntimeError( + f"The partial volume of {self._name} has not been updated in the " + f"PVMeasure object yet!" + ) + col = "NVoxels" if self.unit == "unitless" else "Volume_mm3" + return self._pv_value[col].item() + + def _parsable_args(self) -> list[str]: + return ["classes"] + + def set_args(self, classes: str | None = None, **kwargs: str) -> None: + if classes is not None: + self._classes = classes + return super().set_args(**kwargs) + + def __str__(self) -> str: + return f"PVMeasure(classes={list(self._classes)})" + + def help(self) -> str: + help_str = f"partial volume of {format_classes(self._classes)} in seg file" + return super().help() + help_str + + +def format_classes(_classes: Iterable[int]) -> str: + """ + Formats an iterable of classes. This compresses consecutive integers into ranges. + >>> format_classes([1, 2, 3, 6]) # '1-3,6' + + Parameters + ---------- + _classes : Iterable[int] + An iterable of integers. + + Returns + ------- + A string of sorted integers and integer ranges, '()' if iterable is empty, or just + the string conversion of _classes, if _classes is not an iterable. + + Notes + ----- + This function will likely be moved to a different file. 
+ """ + # TODO move this function to a more appropriate module + if not isinstance(_classes, Iterable): + return str(_classes) + from itertools import pairwise + sorted_list = list(sorted(_classes)) + if len(sorted_list) == 0: + return "()" + prev = "" + out = str(sorted_list[0]) + + for a, b in pairwise(sorted_list): + if a != b - 1: + out += f"{prev},{b}" + prev = "" + else: + prev = f"-{b}" + return out + prev + + +class VolumeMeasure(Measure[ImageTuple]): + """ + Counts Voxels belonging to a class or condition. + """ + + read_file = staticmethod(read_volume_file) + + def __init__( + self, + segfile: Path, + classes_or_cond: ClassesOrCondType, + name: str, + description: str, + unit: UnitString = "unitless", + read_file: Optional[ReadFileHook[ImageTuple]] = None, + ): + if callable(classes_or_cond): + self._classes: Optional[ClassesType] = None + self._cond: _ToBoolCallback = classes_or_cond + else: + if len(classes_or_cond) == 0: + raise ValueError(f"No operation passed to {type(self).__name__}.") + self._classes = classes_or_cond + from functools import partial + self._cond = partial(mask_in_array, items=self._classes) + if unit not in ["unitless", "mm^3"]: + raise ValueError("unit must be either 'mm^3' or 'unitless' for " + + type(self).__name__) + super().__init__(segfile, name, description, unit, + self.read_file if read_file is None else read_file) + + def get_vox_vol(self) -> float: + return np.prod(self._data[0].header.get_zooms()).item() + + def _compute(self) -> int | float: + if not isinstance(self._data, tuple) or len(self._data) != 2: + raise self._load_error("data") + vox_vol = 1 if self._unit == "unitless" else self.get_vox_vol() + return np.sum(self._cond(self._data[1]), dtype=int).item() * vox_vol + + def _parsable_args(self) -> list[str]: + return ["segfile", "classes"] + + def _set_classes(self, classes: str | None, attr_name: str, cond_name: str) -> None: + """Helper method for set_args.""" + if classes is not None: + from functools import partial + _classes = re.split("\\s+", classes.lstrip("[ ").rstrip("] ")) + items = list(map(int, _classes)) + setattr(self, attr_name, items) + setattr(self, cond_name, partial(mask_in_array, items=items)) + + def set_args( + self, + segfile: str | None = None, + classes: str | None = None, + **kwargs: str, + ) -> None: + if segfile is not None: + kwargs["file"] = segfile + self._set_classes(classes, "_classes", "_cond") + return super().set_args(**kwargs) + + def __str__(self) -> str: + return f"{type(self).__name__}(segfile={self._file}, {self._param_string()})" + + def help(self) -> str: + return f"{self._name}={self._param_help()} in {self._file}" + + def _param_help(self, prefix: str = ""): + """Helper method for format classes and cond.""" + cond = getattr(self, prefix + "_cond") + classes = getattr(self, prefix + "_classes") + return prefix + (f"cond={cond}" if classes is None else format_classes(classes)) + + def _param_string(self, prefix: str = ""): + """Helper method to convert classes and cond to string.""" + cond = getattr(self, prefix + "_cond") + classes = getattr(self, prefix + "_classes") + return prefix + (f"cond={cond}" if classes is None else f"classes={classes}") + + +class MaskMeasure(VolumeMeasure): + + def __init__( + self, + maskfile: Path, + name: str, + description: str, + unit: UnitString = "unitless", + threshold: float = 0.5, + # sign: MaskSign = "abs", frame: int = 0, + # erode: int = 0, invert: bool = False, + read_file: Optional[ReadFileHook[ImageTuple]] = None, + ): + self._threshold: float = 
threshold + # self._sign: MaskSign = sign + # self._invert: bool = invert + # self._frame: int = frame + # self._erode: int = erode + super().__init__(maskfile, self.mask, name, description, unit, read_file) + + def mask(self, data: "npt.NDArray[int]") -> "npt.NDArray[bool]": + """Generates a mask from data similar to mri_binarize + erosion.""" + # if self._sign == "abs": + # data = np.abs(data) + # elif self._sign == "neg": + # data = -data + out = np.greater(data, self._threshold) + # if self._invert: + # out = np.logical_not(out) + # if self._erode != 0: + # from scipy.ndimage import binary_erosion + # binary_erosion(out, iterations=self._erode, output=out) + return out + + def set_args( + self, + maskfile: Path | None = None, + threshold: float | None = None, + **kwargs: str, + ) -> None: + if threshold is not None: + self._threshold = float(threshold) + if maskfile is not None: + kwargs["file"] = maskfile + return super().set_args(**kwargs) + + def _parsable_args(self) -> list[str]: + return ["maskfile", "threshold"] + + def __str__(self) -> str: + return ( + f"{type(self).__name__}(maskfile={self._file}, threshold={self._threshold})" + ) + + def _param_help(self, prefix: str = ""): + return f"voxel > {self._threshold}" + + +AnyParentsTuple = tuple[float, AnyMeasure] +ParentsTuple = tuple[float, AnyMeasure] + + +class TransformMeasure(Measure, metaclass=abc.ABCMeta): + read_file = staticmethod(read_transform_file) + + def __init__( + self, + lta_file: Path, + name: str, + description: str, + unit: str, + read_lta: Optional[ReadFileHook["npt.NDArray[float]"]] = None, + ): + super().__init__( + lta_file, + name, + description, + unit, + self.read_file if read_lta is None else read_lta, + ) + + def _parsable_args(self) -> list[str]: + return ["lta_file"] + + def set_args(self, lta_file: str | None = None, **kwargs: str) -> None: + if lta_file is not None: + kwargs["file"] = lta_file + return super().set_args(**kwargs) + + def __str__(self) -> str: + return f"{type(self).__name__}(lta_file={self._file})" + + +class ETIVMeasure(TransformMeasure): + """ + Compute the eTIV based on the freesurfer talairach registration and lta. + + Notes + ----- + Reimplemneted from freesurfer/mri_sclimbic_seg + https://github.com/freesurfer/freesurfer/blob/ + 3296e52f8dcffa740df65168722b6586adecf8cc/mri_sclimbic_seg/mri_sclimbic_seg#L627 + """ + + def __init__( + self, + lta_file: Path, + name: str, + description: str, + unit: str, + read_lta: Optional[ReadFileHook["LTADict"]] = None, + etiv_scale_factor: float | None = None, + ): + if etiv_scale_factor is None: + self._etiv_scale_factor = 1948106. 
# 1948.106 cm^3 * 1e3 mm^3/cm^3 + else: + self._etiv_scale_factor = etiv_scale_factor + super().__init__(lta_file, name, description, unit, read_lta) + + def _parsable_args(self) -> list[str]: + return super()._parsable_args() + ["etiv_scale_factor"] + + def set_args(self, etiv_scale_factor: str | None = None, **kwargs: str) -> None: + if etiv_scale_factor is not None: + self._etiv_scale_factor = float(etiv_scale_factor) + return super().set_args(**kwargs) + + def _compute(self) -> float: + # this scale factor is a fixed number derived by freesurfer + return self._etiv_scale_factor / np.linalg.det(self._data).item() + + def help(self) -> str: + return super().help() + f"eTIV from {self._file}" + + def __str__(self) -> str: + return f"{super().__str__()[:-1]}, etiv_scale_factor={self._etiv_scale_factor})" + + +class DerivedMeasure(AbstractMeasure): + + def __init__( + self, + parents: Iterable[tuple[float, AnyMeasure] | AnyMeasure], + name: str, + description: str, + unit: str = "from parents", + operation: DerivedAggOperation = "sum", + measure_host: Optional[dict[str, AbstractMeasure]] = None, + ): + """ + Create the Measure, which depends on other measures, called parent measures. + + Parameters + ---------- + parents : Iterable[tuple[float, AbstractMeasure] | AbstractMeasure] + Iterable of either the measures (or a tuple of a float and a measure), the + float is the factor by which the value of the respective measure gets + weighted and defaults to 1. + name : str + Name of the Measure. + description : str + Description text of the measure + unit : str, optional + Unit of the measure, typically 'mm^3' or 'unitless', autogenerated from + parents' unit. + operation : "sum", "ratio", "by_vox_vol", optional + How to aggregate multiple `parents`, default = 'sum' + 'ratio' only supports exactly 2 parents. + 'by_vox_vol' only supports exactly one parent. + measure_host : dict[str, AbstractMeasure], optional + A dict-like to provide AbstractMeasure objects for strings. + """ + + def to_tuple( + value: tuple[float, AnyMeasure] | AnyMeasure, + ) -> tuple[float, AnyMeasure]: + if isinstance(value, Sequence) and not isinstance(value, str): + if len(value) != 2: + raise ValueError("A tuple was not length 2.") + factor, measure = value + else: + factor, measure = 1., value + + if not isinstance(measure, (str, AbstractMeasure)): + raise ValueError(f"Expected a str or AbstractMeasure, not " + f"{type(measure).__name__}!") + if not isinstance(factor, float): + factor = float(factor) + return factor, measure + + self._parents: list[AnyParentsTuple] = [to_tuple(p) for p in parents] + if len(self._parents) == 0: + raise ValueError("No parents defined in DerivedMeasure.") + self._measure_host = measure_host + if operation in ("sum", "ratio", "by_vox_vol"): + self._operation: DerivedAggOperation = operation + else: + raise ValueError("operation must be 'sum', 'ratio' or 'by_vox_vol'.") + super().__init__(name, description, unit) + + @property + def unit(self) -> str: + """ + Property to access the unit attribute, also implements auto-generation of unit, + if the stored unit is 'from parents'. + + Returns + ------- + str + A string that identifies the unit of the Measure. + + Raises + ------ + RuntimeError + If unit is 'from parents' and some parent measures are inconsistent with + each other. 
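+
+        Examples
+        --------
+        A sketch of the auto-generated unit when it is set to 'from parents'
+        (parent units are illustrative):
+
+        >>> # operation='sum',        parent units ('mm^3', 'mm^3') -> 'mm^3'
+        >>> # operation='ratio',      parent units ('mm^3', 'mm^3') -> 'unitless'
+        >>> # operation='by_vox_vol', parent unit 'mm^3'            -> 'unitless'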
+ """ + if self._unit == "from parents": + units = list(map(lambda x: x.unit, self.parents)) + if self._operation == "sum": + if len(units) == 0: + raise ValueError("DerivedMeasure has no parent measures.") + elif len(units) == 1 or all(units[0] == u for u in units[1:]): + return units[0] + elif self._operation == "ratio": + if len(units) != 2: + raise self.invalid_len_ratio() + elif units[0] == units[1]: + return "unitless" + elif units[1] == "unitless": + return units[0] + elif self._operation == "by_vox_vol": + if len(units) != 1: + raise self.invalid_len_vox_vol() + elif units[0] == "mm^3": + return "unitless" + else: + raise RuntimeError("Invalid value of parent, must be mm^3, but " + f"was {units[0]}.") + raise RuntimeError( + f"unit is set to auto-generate from parents, but the parents' units " + f"are not consistent: {units}!" + ) + else: + return super().unit + + def invalid_len_ratio(self) -> RuntimeError: + return RuntimeError(f"Invalid number of parents ({len(self._parents)}) for " + f"operation 'ratio'.") + + def invalid_len_vox_vol(self) -> RuntimeError: + return RuntimeError(f"Invalid number of parents ({len(self._parents)}) for " + f"operation 'by_vox_vol'.") + + @property + def parents(self) -> Iterable[AbstractMeasure]: + """Iterable of the measures this measure depends on.""" + return (p for _, p in self.parents_items()) + + def parents_items(self) -> Iterable[tuple[float, AbstractMeasure]]: + """Iterable of the measures this measure depends on.""" + return ((f, self._measure_host[p] if isinstance(p, str) else p) + for f, p in self._parents) + + def __read_subject(self, subject_dir: Path) -> bool: + """Default implementation for the read_subject_on_parents function hook.""" + return any(m.read_subject(subject_dir) for m in self.parents) + + @property + def read_subject_on_parents(self) -> Callable[[Path], bool]: + """read_subject_on_parents function hook property""" + if (self._measure_host is not None and + hasattr(self._measure_host, "read_subject_parents")): + from functools import partial + return partial(self._measure_host.read_subject_parents, self.parents) + else: + return self.__read_subject + + def read_subject(self, subject_dir: Path) -> bool: + """ + Perform IO required to compute/fill the Measure. Will trigger the + read_subject_on_parents function hook to populate the values of parent measures. + + Parameters + ---------- + subject_dir : Path + Path to the directory of the subject_dir (often subject_dir/subject_id). + + Returns + ------- + bool + Whether there was an update. + + Notes + ----- + Might trigger a race condition if the function hook `read_subject_on_parents` + depends on this method finishing first, e.g. because of thread availability. + """ + if super().read_subject(subject_dir): + return self.read_subject_on_parents(self._subject_dir) + return False + + def __call__(self) -> int | float: + """ + Compute dependent measures and accumulate them according to the operation. 
+ """ + factor_value = [(s, m()) for s, m in self.parents_items()] + isint = all(isinstance(v, int) for _, v in factor_value) + isint &= all(np.isclose(s, np.round(s)) for s, _ in factor_value) + values = [s * v for s, v in factor_value] + if self._operation == "sum": + # sum should be an int, if all contributors are int + # and all factors are integers (but not necessarily int) + out = np.sum(values) + target_type = int if isint else float + return target_type(out) + elif self._operation == "by_vox_vol": + if len(self._parents) != 1: + raise self.invalid_len_vox_vol() + vox_vol = self.get_vox_vol() + if isinstance(vox_vol, _DefaultFloat): + logging.getLogger(__name__).warning( + f"The vox_vol in {self} was unexpectedly not initialized; using " + f"{vox_vol}!" + ) + # ratio should always be float / could be partial voxels + return float(values[0]) / vox_vol + else: # operation == "ratio" + if len(self._parents) != 2: + raise self.invalid_len_ratio() + # ratio should always be float + return float(values[0]) / float(values[1]) + + def get_vox_vol(self) -> float | None: + """ + Return the voxel volume of the first parent measure. + + Returns + ------- + float, None + voxel volume of the first parent + """ + _types = (VolumeMeasure, DerivedMeasure) + _type = ImportedMeasure + fallback = None + for p in self.parents: + if isinstance(p, _types) and (_vvol := p.get_vox_vol()) is not None: + return _vvol + if isinstance(p, _type) and (_vvol := p.get_vox_vol()) is not None: + if isinstance(_vvol, _DefaultFloat): + fallback = _vvol + else: + return _vvol + return fallback + + def _parsable_args(self) -> list[str]: + return ["parents", "operation"] + + def set_args( + self, + parents: str | None = None, + operation: str | None = None, + **kwargs: str, + ) -> None: + if parents is not None: + pat = re.compile("^(\\d+\\.?\\d*\\s+)?(\\s.*)") + stripped = parents.lstrip("[ ").rstrip("] ") + + def parse(p: str) -> tuple[float, str]: + hit = pat.match(p) + if hit is None: + return 1., p + return 1. if hit.group(1).strip() else float(hit.group(1)), hit.group(2) + + self._parents = list(map(parse, re.split("\\s+", stripped))) + if operation is not None: + from typing import get_args as args + if operation in args(DerivedAggOperation): + self._operation = operation + else: + raise ValueError(f"operation can only be {args(DerivedAggOperation)}") + return super().set_args(**kwargs) + + def __str__(self) -> str: + return f"DerivedMeasure(parents={self._parents}, operation={self._operation})" + + def help(self) -> str: + sign = {True: "+", False: "-"} + + def format_factor(f: float) -> str: + return f"{sign[f >= 0]} " + ((str(abs(f)) + " ") if abs(f) != 1. else '') + + def format_parent(measure: str | AnyMeasure) -> str: + if isinstance(measure, str): + measure = self._measure_host[measure] + return measure if isinstance(measure, str) else measure.help() + + if self._operation == "sum": + par = "".join(f" {format_factor(f)}({format_parent(p)})" + for f, p in self._parents) + return par.lstrip(' +') + elif self._operation == "by_vox_vol": + f, measure = self._parents[0] + return f"{sign[f >= 0]} {format_factor(f)} [{format_parent(measure)}]" + elif self._operation == "ratio": + f = self._parents[0][0] / self._parents[1][0] + return (f" {sign[f >= 0]} {format_factor(f)} (" + + ") / (".join(format_parent(p[1]) for p in self._parents) + ")") + else: + return f"invalid operation {self._operation}" + + +class VoxelClassGenerator(Protocol): + """ + Generator for voxel-based metric Measures. 
+ """ + + def __call__( + self, + classes: Sequence[int], + name: str, + description: str, + unit: str, + ) -> PVMeasure | VolumeMeasure: + ... + + +def format_measure(key: str, data: MeasureTuple) -> str: + value = data[2] if isinstance(data[2], int) else ("%.6f" % data[2]) + return f"# Measure {key}, {data[0]}, {data[1]}, {value}, {data[3]}" + + +class Manager(dict[str, AbstractMeasure]): + _PATTERN_NO_ARGS = re.compile("^\\s*([^(]+?)\\s*$") + _PATTERN_ARGS = re.compile("^\\s*([^(]+)\\(\\s*([^)]*)\\s*\\)\\s*$") + _PATTERN_DELIM = re.compile("\\s*,\\s*") + + _compute_futures: list[Future] + __DEFAULT_MEASURES = ( + "BrainSeg", + "BrainSegNotVent", + "VentricleChoroidVol", + "lhCortex", + "rhCortex", + "Cortex", + "lhCerebralWhiteMatter", + "rhCerebralWhiteMatter", + "CerebralWhiteMatter", + "SubCortGray", + "TotalGray", + "SupraTentorial", + "SupraTentorialNotVent", + "Mask", + "BrainSegVol-to-eTIV", + "MaskVol-to-eTIV", + "lhSurfaceHoles", + "rhSurfaceHoles", + "SurfaceHoles", + "EstimatedTotalIntraCranialVol", + ) + + def __init__( + self, + measures: Sequence[tuple[bool, str]], + measurefile: Optional[Path] = None, + segfile: Optional[Path] = None, + on_missing: Literal["fail", "skip", "fill"] = "fail", + executor: Optional[Executor] = None, + legacy_freesurfer: bool = False, + aseg_replace: Optional[Path] = None, + ): + """ + + Parameters + ---------- + measures : Sequence[tuple[bool, str]] + The measures to be included as whether it is computed and name/measure str. + measurefile : Path, optional + The path to the file to import measures from (other stats file, absolute or + relative to subject_dir). + segfile : Path, optional + The path to the file to use for segmentation (other stats file, absolute or + relative to subject_dir). + on_missing : Literal["fail", "skip", "fill"], optional + behavior to follow if a requested measure does not exist in path. + executor : concurrent.futures.Executor, optional + thread pool to parallelize io + legacy_freesurfer : bool, default=False + FreeSurfer compatibility mode. + """ + from concurrent.futures import ThreadPoolExecutor, Future + from copy import deepcopy + + def _check_measures(x): + return not (isinstance(x, tuple) and len(x) == 2 or + isinstance(x[0], bool) or isinstance(x[1], str)) + super().__init__() + self._default_measures = deepcopy(self.__DEFAULT_MEASURES) + if not isinstance(measures, Sequence) or any(map(_check_measures, measures)): + raise ValueError("measures must be sequences of str.") + if executor is None: + self._executor = ThreadPoolExecutor(8) + elif isinstance(executor, ThreadPoolExecutor): + self._executor = executor + else: + raise TypeError( + "executor must be a futures.concurrent.ThreadPoolExecutor to ensure " + "proper multitask behavior." + ) + self._io_futures: list[Future] = [] + self.__update_context: list[AbstractMeasure] = [] + self._on_missing = on_missing + self._import_all_measures: list[Path] = [] + self._subject_all_imported: list[Path] = [] + self._exported_measures: list[str] = [] + self._cache: dict[Path, Future[AnyBufferType] | AnyBufferType] = {} + # self._lut: Optional[pd.DataFrame] = None + self._fs_compat: bool = legacy_freesurfer + self._seg_from_file = Path("mri/aseg.mgz") + if aseg_replace: + # explicitly defined a file to reduce the aseg for segmentation mask with + logging.getLogger(__name__).info( + f"Replacing segmentation volume to compute volume measures from with " + f"the explicitly defined {aseg_replace}." 
+ ) + self._seg_from_file = Path(aseg_replace) + elif not self._fs_compat and segfile and Path(segfile) != self._seg_from_file: + # not in freesurfer compatibility mode, so implicitly use segfile + logging.getLogger(__name__).info( + f"Replacing segmentation volume to compute volume measures from with " + f"the segmentation file {segfile}." + ) + self._seg_from_file = Path(segfile) + + import_kwargs = {"vox_vol": _DefaultFloat(1.0)} + if any(filter(lambda x: x[0], measures)): + if measurefile is None: + raise ValueError( + "Measures defined to import, but no measurefile specified. " + "A default must always be defined." + ) + import_kwargs["measurefile"] = Path(measurefile) + import_kwargs["read_file"] = self.make_read_hook(read_measure_file) + import_kwargs["read_file"](Path(measurefile), blocking=False) + for is_imported, measure_string in measures: + if is_imported: + self.add_imported_measure(measure_string, **import_kwargs) + else: + self.add_computed_measure(measure_string) + self.instantiate_measures(self.values()) + + @property + def executor(self) -> Executor: + return self._executor + + # @property + # def lut(self) -> Optional["pd.DataFrame"]: + # return self._lut + # + # @lut.setter + # def lut(self, lut: Optional["pd.DataFrame"]): + # self._lut = lut + + def assert_measure_need_subject(self) -> None: + """ + Assert whether the measure expects a definition of the subject_dir. + + Raises + ------ + AssertionError + """ + any_computed = False + for key, measure in self.items(): + if isinstance(measure, DerivedMeasure): + pass + elif isinstance(measure, ImportedMeasure): + measure.assert_measurefile_absolute() + else: + any_computed = True + if any_computed: + raise AssertionError( + "Computed measures are defined, but no subjects dir or subject id." + ) + + def instantiate_measures(self, measures: Iterable[AbstractMeasure]) -> None: + """ + Make sure all measures that dependent on `measures` are instantiated. + """ + for measure in list(measures): + if isinstance(measure, DerivedMeasure): + self.instantiate_measures(measure.parents) + + def add_imported_measure(self, measure_string: str, **kwargs) -> None: + """ + Add an imported measure from the measure_string definition and default + measurefile. + + Parameters + ---------- + measure_string : str + Definition of the measure. + + Other Parameters + ---------------- + measurefile : Path + Path to the default measurefile to import from (ImportedMeasure argument). + read_file : ReadFileHook[dict[str, MeasureTuple]] + Function handle to read and parse the file (argument to ImportedMeasure). + vox_vol: float, optional + The voxel volume to associate the measure with. + + Raises + ------ + RuntimeError + If trying to replace a computed Measure of the same key. + """ + # currently also extracts args, this maybe should be removed for simpler code + key, args = self.extract_key_args(measure_string) + if key == "all": + _mfile = kwargs["measurefile"] if len(args) == 0 else Path(args[0]) + self._import_all_measures.append(_mfile) + elif key not in self.keys() or isinstance(self[key], ImportedMeasure): + # note: name, description and unit are always updated from the input file + self[key] = ImportedMeasure(key, **kwargs) + # parse the arguments (inplace) + self[key].parse_args(*args) + self._exported_measures.append(key) + else: + raise RuntimeError( + "Illegal operation: Trying to replace the computed measure at " + f"{key} ({self[key]}) with an imported measure." 
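# Standalone sketch of the caching/preloading pattern behind
# Manager.make_read_hook (used just above to preload the measurefile with
# blocking=False): the first call submits the read to a thread pool, later
# blocking calls reuse the cached result instead of re-reading the file.
# `read_lines` is a stand-in for the real readers such as read_measure_file.
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path

_cache: dict[Path, object] = {}
_pool = ThreadPoolExecutor(8)

def read_lines(path: Path) -> list[str]:
    return path.read_text().splitlines()

def read_hook(file: Path, blocking: bool = True):
    out = _cache.get(file, None)
    if out is None:
        out = read_lines(file) if blocking else _pool.submit(read_lines, file)
        _cache[file] = out
    if not blocking:
        return None
    if isinstance(out, Future):
        _cache[file] = out = out.result()  # wait for the preloaded read
    return out

read_hook(Path(__file__), blocking=False)   # warm the cache in a worker thread
print(len(read_hook(Path(__file__))))       # blocks on and reuses the cached read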
+ ) + + def add_computed_measure( + self, + measure_string: str, + ) -> None: + """Add a computed measure from the measure_string definition.""" + # currently also extracts args, this maybe should be removed for simpler code + key, args = self.extract_key_args(measure_string) + # also overwrite prior definition + if key in self._exported_measures: + self[key] = self.default(key) + else: + self._exported_measures.append(key) + # load the default config of the measure and copy, overwriting other measures + # with the same key (only keep computed versions or the last) parse the + # arguments (inplace) + self[key].parse_args(*args) + + def __getitem__(self, key: str) -> AbstractMeasure: + """ + Get the value of the key. + + Parameters + ---------- + key : str + A string naming the Measure, may also include extra parameters as format + '()', e.g. 'Mask(maskfile=/path/to/mask.mgz)'. + + Returns + ------- + AbstractMeasure + The measure associated with the '' + """ + if "(" in key: + key, args = key.split("(", 1) + args = list(map(str.strip, args.rstrip(") ").split(","))) + else: + args = [] + try: + out = super().__getitem__(key) + except KeyError: + out = self.default(key) + if out is not None: + self[key] = out + else: + raise + if len(args) > 0: + out.parse_args(*args) + return out + + def start_read_subject(self, subject_dir: Path) -> None: + """ + Start the threads to read the subject in subject_dir, pairs with + `wait_read_subject`. + + Parameters + ---------- + subject_dir : Path + The path to the directory of the subject (with folders 'mri', 'stats', ...). + """ + if len(self._io_futures) != 0: + raise RuntimeError("Did not process/wait on finishing the processing for " + "the previous start_read_subject run. Needs call to " + "`wait_read_subject`.") + self.__update_context = [] + self._subject_all_imported = [] + read_file = self.make_read_hook(read_measure_file) + for file in self._import_all_measures: + path = file if file.is_absolute() else subject_dir / file + read_file(path, blocking=False) + self._subject_all_imported.append(path) + self.read_subject_parents(self.values(), subject_dir, False) + + @contextmanager + def with_subject(self, subjects_dir: Path | None, subject_id: str | None) -> None: + """ + Contextmanager for the `start_read_subject` and the `wait_read_subject` pair. + + If one value is None, it is assumed the subject_dir and subject_id are not + needed, for example because all file names are given by absolute paths. + + Parameters + ---------- + subjects_dir : Path, None + The path to the directory of the subject (with folders 'mri', 'stats', ...). + subject_id : str, None + The subject_id identifying folder of the subjects_dir. + + Raises + ------ + AssertionError + If subjects_dir and or subject_id are needed. + """ + if subjects_dir is None or subject_id is None: + yield self.assert_measure_need_subject() + # no reading the subject required, we have no measures to include + return + else: + # the subject is defined, we read it. + yield self.start_read_subject(subjects_dir / subject_id) + return self.wait_read_subject() + + def wait_read_subject(self) -> None: + """ + Wait for all threads to finish reading the 'current' subject. + + Raises + ------ + Exception + The first exception encountered during the read operation. 
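# Sketch of how the with_subject() context manager above pairs
# start_read_subject (on enter) with wait_read_subject (on exit). The manager,
# the output directory and the subject id are assumed to exist, and no
# partial-volume (PVMeasure) based measures are assumed to be configured.
from pathlib import Path

def measures_for_subject(manager, subjects_dir: Path, subject_id: str) -> dict[str, float | int]:
    with manager.with_subject(subjects_dir, subject_id):
        pass                              # reads are started here, awaited on exit
    manager.compute_non_derived_pv()      # queue the measure computations
    errors = manager.wait_compute()       # collect any exceptions raised while computing
    if errors:
        raise errors[0]
    return manager.update_measures()      # e.g. {'BrainSeg': ..., 'Mask': ..., ...}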
+ """ + for f in self._io_futures: + exception = f.exception() + if exception is not None: + raise exception + self._io_futures.clear() + vox_vol = None + + def check_needs_init(m: AbstractMeasure) -> bool: + return isinstance(m, ImportedMeasure) and isinstance(m.get_vox_vol(), + _DefaultFloat) + + # and an ImportedMeasure is present, but not initialized + for m in filter(check_needs_init, self.values()): + # lazily load a value for vox_vol + if vox_vol is None: + # if the _seg_from_file file is loaded into the cache (should be) + if self._seg_from_file in self._cache: + read_func = self.make_read_hook(read_volume_file) + img, _ = read_func(self._seg_from_file, blocking=True) + vox_vol = np.prod(img.header.get_zooms()) + if vox_vol is not None: + m.set_vox_vol(vox_vol) + + def read_subject_parents( + self, + measures: Iterable[AbstractMeasure], + subject_dir: Path, + blocking: bool = False, + ) -> True: + """ + Multi-threaded iteration through measures and application of read_subject, also + implementation for the read_subject_on_parents function hook. Guaranteed to + return + independent of state and thread availability to avoid a race condition. + + Parameters + ---------- + measures : Iterable[AbstractMeasure] + iterable of Measures to read + subject_dir : Path + Path to the subject directory (often subjects_dir/subject_id). + blocking : bool, optional + whether the execution should be parallel or not (default: False/parallel). + + Returns + ------- + True + """ + + def _read(measure: AbstractMeasure) -> bool: + """Callback so files for measures are loaded in other threads.""" + return measure.read_subject(subject_dir) + + _update_context = set( + filter(lambda m: m not in self.__update_context, measures) + ) + # __update_context is the structure that holds measures that have read_subject + # already called / submitted to the executor + self.__update_context.extend(_update_context) + for x in _update_context: + # DerivedMeasure.read_subject calls Manager.read_subject_parents (this + # method) to read the data from dependent measures (through the callback + # DerivedMeasure.read_subject_on_parents, and DerivedMeasure.measure_host). + if blocking or isinstance(x, DerivedMeasure): + x.read_subject(subject_dir) + else: + # calls read_subject on all measures, redundant io operations are + # handled/skipped through Manager.make_read_hook and the internal + # caching of files within the _cache attribute of Manager. + self._io_futures.append(self._executor.submit(_read, x)) + return True + + def extract_key_args(self, measure: str) -> tuple[str, list[str]]: + """ + Extract the name and options from a string like '()'. + + The '' is optional and is similar to python parameters. It starts + with numbered parameters, followed by key-value pairs. + Examples are: + - 'Mask(mri/aseg.mgz)' + returns: ('BrainSeg', ['mri/aseg.mgz', 'classes=[2, 4]']) + - 'TotalGray(mri/aseg.mgz, classes=[2, 4])' + returns: ('BrainSeg', ['mri/aseg.mgz', 'classes=[2, 4]']) + - 'BrainSeg(segfile=mri/aseg.mgz, classes=[2, 4])' + returns: ('BrainSeg', ['segfile=mri/aseg.mgz', 'classes=[2, 4]']) + + Parameters + ---------- + measure : str + The measure string of the format '' or '()'. + + Returns + ------- + key : str + the name of the measure + args : list[str] + a list of options + + Raises + ------ + ValueError + If the string `measure` does not conform to the format requirements. 
+ """ + hits_no_args = self._PATTERN_NO_ARGS.match(measure) + if hits_no_args is not None: + key = hits_no_args.group(1) + args = [] + elif (hits_args := self._PATTERN_ARGS.match(measure)) is not None: + key = hits_args.group(1) + args = self._PATTERN_DELIM.split(hits_args.group(2)) + else: + extra = "" + if any(q in measure for q in "\"'"): + extra = ", watch out for quotes" + raise ValueError(f"Invalid Format of Measure \"{measure}\"{extra}!") + return key, args + + def make_read_hook( + self, + read_func: Callable[[Path], T_BufferType], + ) -> ReadFileHook[T_BufferType]: + """ + Wraps an io function to buffer results, multi-thread calls, etc. + + Parameters + ---------- + read_func : Callable[[Path], T_BufferType] + Function to read Measure entries/ images/ surfaces from a file. + + Returns + ------- + wrapped_func : ReadFileHook[T_BufferType] + The returned function takes a path and whether to wait for the io to finish. + file : Path + the path to the read from (path can be used for buffering) + blocking : bool, optional + do not return the data, do not wait for the io to finish, just preload + (default: False) + The function returns None or the output of the wrapped function. + """ + + def read_wrapper(file: Path, blocking: bool = True) -> Optional[T_BufferType]: + out = self._cache.get(file, None) + if out is None: + # not already in cache + if blocking: + out = read_func(file) + else: + out = self._executor.submit(read_func, file) + self._cache[file] = out + if not blocking: + return + elif isinstance(out, Future): + self._cache[file] = out = out.result() + return out + + return read_wrapper + + def clear(self): + """ + Clear the file buffers. + """ + self._cache = {} + + def update_measures(self) -> dict[str, float | int]: + """ + Get the values to all measures (including imported via 'all'). + + Returns + ------- + dict[str, Union[float, int]] + A dictionary of '' (the Measure key) and the associated value. + """ + m = {key: v[2] for key, v in self.get_imported_all_measures().items()} + m.update({key: self[key]() for key in self._exported_measures}) + return m + + def print_measures(self, file: Optional[TextIO] = None) -> None: + """ + Print the measures to stdout or file. + + Parameters + ---------- + file: TextIO, optional + The file object to write to. If None, writes to stdout. + """ + kwargs = {} if file is None else {"file": file} + for line in self.format_measures(): + print(line, **kwargs) + + def get_imported_all_measures(self) -> dict[str, MeasureTuple]: + """ + Get the measures imported through the 'all' keyword. + + Returns + ------- + dict[str, MeasureTuple] + A dictionary of Measure keys and tuples of name, description, value, unit. + """ + if len(self._subject_all_imported) == 0: + return {} + measures = {} + read_file = self.make_read_hook(ImportedMeasure.read_file) + for path in self._subject_all_imported: + measures.update(read_file(path)) + return measures + + def format_measures( + self, /, + fmt_func: Callable[[str, MeasureTuple], str] = format_measure, + ) -> Iterable[str]: + """ + Formats all measures as strings and returns them as an iterable of str. + + In the output, measures are ordered in the order they are added to the Manager + object. Finally, the "all"-imported Measures are appended. + + Parameters + ---------- + fmt_func: callable, default=fmt_measure + Function to format the key and a MeasureTuple object into a string. + + Returns + ------- + Iterable[str] + An iterable of the measure strings. 
+ """ + measures = {key: self[key].as_tuple() for key in self._exported_measures} + for k, v in self.get_imported_all_measures().items(): + measures.setdefault(k, v) + + return map(lambda x: fmt_func(*x), measures.items()) + + @property + def default_measures(self) -> Iterable[str]: + """ + Iterable over measures typically included stats files in correct order. + + Returns + ------- + Iterable[str] + An ordered iterable of the default Measure keys. + """ + return self._default_measures + + @default_measures.setter + def default_measures(self, values: Iterable[str]): + """ + Sets the iterable over measure keys in correct order. + + Parameters + ---------- + values : Iterable[str] + An ordered iterable of the default Measure keys. + """ + self._default_measures = values + + @property + def voxel_class(self) -> VoxelClassGenerator: + """ + A callable initializing a Volume-based Measure object with the legacy mode. + + Returns + ------- + type[AbstractMeasure] + A callable to create an object to perform a Volume-based Measure. + """ + from functools import partial + if self._fs_compat: + return partial( + VolumeMeasure, + self._seg_from_file, + read_file=self.make_read_hook(VolumeMeasure.read_file), + ) + else: # FastSurfer compat == None + return partial(PVMeasure) + + def default(self, key: str) -> AbstractMeasure: + """ + Returns the default Measure object for the measure with key `key`. + + Parameters + ---------- + key : str + The key name of the Measure. + + Returns + ------- + AbstractMeasure + The Measure object initialized with default values. + + Supported keys are: + - `lhSurfaceHoles`, `rhSurfaceHoles`, and `SurfaceHoles` + The number of holes in the surfaces. + - `lhPialTotal`, and `rhPialTotal` + The volume enclosed in the pial surfaces. + - `lhWhiteMatterVol`, and `rhWhiteMatterVol` + The Volume of the white matter in the segmentation (incl. lateralized + WM-hypo). + - `lhWhiteMatterTotal`, and `rhWhiteMatterTotal` + The volume enclosed in the white matter surfaces. + - `lhCortex`, `rhCortex`, and `Cortex` + The volume between the pial and the white matter surfaces. + - `CorpusCallosumVol` + The volume of the corpus callosum in the segmentation. + - `lhWM-hypointensities`, and `rhWM-hypointensities` + The volume of unlateralized the white matter hypointensities in the + segmentation, but lateralized by neigboring voxels + (FreeSurfer uses talairach coordinates to re-lateralize). + - `lhCerebralWhiteMatter`, `rhCerebralWhiteMatter`, and `CerebralWhiteMatter` + The volume of the cerebral white matter in the segmentation (including corpus + callosum split evenly into left and right and white matter and WM-hypo). + - `CerebellarGM` + The volume of the cerbellar gray matter in the segmentation. + - `CerebellarWM` + The volume of the cerbellar white matter in the segmentation. + - `SubCortGray` + The volume of the subcortical gray matter in the segmentation. + - `TotalGray` + The total gray matter volume in the segmentation. + - `TFFC` + The volume of the 3rd-5th ventricles and CSF in the segmentation. + - `VentricleChoroidVol` + The volume of the choroid plexus and inferiar and lateral ventricles and CSF. + - `BrainSeg` + The volume of all brains structres in the segmentation. + - `BrainSegNotVent`, and `BrainSegNotVentSurf` + The brain segmentation volume without ventricles. + - `Cerebellum` + The total cerebellar volume. + - `SupraTentorial`, `SupraTentorialNotVent`, and `SupraTentorialNotVentVox` + The supratentorial brain volume/voxel count (without centricles and CSF). 
+ - `Mask` + The volume of the brain mask. + - `EstimatedTotalIntraCranialVol` + The eTIV estimate (via talairach registration). + - `BrainSegVol-to-eTIV`, and `MaskVol-to-eTIV` + The ratios of the brain segmentation volume and the mask volume with respect + to the eTIV estimate. + """ + + hemi = key[:2] + side = "Left" if hemi != "rh" else "Right" + cc_classes = tuple(range(251, 256)) + if key in ("lhSurfaceHoles", "rhSurfaceHoles"): + # FastSurfer and FS7 are same + # l/rSurfaceHoles: (1-lheno/2) -- Euler number of /surf/l/rh.orig.nofix + return SurfaceHoles( + Path(f"surf/{hemi}.orig.nofix"), + f"{hemi}SurfaceHoles", + f"Number of defect holes in {hemi} surfaces prior to fixing", + "unitless", + ) + elif key == "SurfaceHoles": + # sum of holes in left and right surfaces + return DerivedMeasure( + ["rhSurfaceHoles", "lhSurfaceHoles"], + "SurfaceHoles", + "Total number of defect holes in surfaces prior to fixing", + measure_host=self, + ) + elif key in ("lhPialTotal", "rhPialTotal"): + # FastSurfer and FS7 are same + return SurfaceVolume( + Path(f"surf/{hemi}.pial"), + f"{hemi}PialTotalVol", + f"{side} hemisphere total pial volume", + "mm^3", + ) + elif key in ("lhWhiteMatterVol", "rhWhiteMatterVol"): + # This is volume-based in FS7 (ComputeBrainVolumeStats2) + if key[:1] == "l": + classes = (2, 78) + else: # r + classes = (41, 79) + return self.voxel_class( + classes, + f"{hemi}WhiteMatterVol", + f"{side} hemisphere total white matter volume", + "mm^3", + ) + elif key in ("lhWhiteMatterTotal", "rhWhiteMatterTotal"): + return SurfaceVolume( + Path(f"surf/{hemi}.white"), + f"{hemi}WhiteMatterSurfVol", + f"{side} hemisphere total white matter volume", + "mm^3", + ) + elif key in ("lhCortex", "rhCortex"): + # From https://github.com/freesurfer/freesurfer/blob/ + # 3753f8a1af484ac2507809c0edf0bc224bb6ccc1/utils/cma.cpp#L1190C1-L1192C52 + # CtxGM = everything inside pial surface minus everything in white surface. + parents = [f"{hemi}PialTotal", (-1, f"{hemi}WhiteMatterTotal")] + # With version 7, don't need to do a correction because the pial surface is + # pinned to the white surface in the medial wall + return DerivedMeasure( + parents, + f"{hemi}CortexVol", + f"{side} hemisphere cortical gray matter volume", + measure_host=self, + ) + elif key == "Cortex": + # 7 => lhCtxGM + rhCtxGM: sum of left and right cerebral GM + return DerivedMeasure( + ["lhCortex", "rhCortex"], + "CortexVol", + f"Total cortical gray matter volume", + measure_host=self, + ) + elif key == "CorpusCallosumVol": + # FastSurfer and FS7 are same + # CCVol: + # CC_Posterior CC_Mid_Posterior CC_Central CC_Mid_Anterior CC_Anterior + return self.voxel_class( + cc_classes, + "CorpusCallosumVol", + "Volume of the Corpus Callosum", + "mm^3", + ) + elif key in ("lhWM-hypointensities", "rhWM-hypointensities"): + # lateralized counting of class 77 WM hypo intensities + def mask_77_lat(arr): + """ + This function returns a lateralized mask of hypo-WM (class 77). + + This is achieved by looking at surrounding labels and associating them + with left or right (this is not 100% robust when there is no clear + classes with left aseg labels present, but it is cheap to perform. 
+ """ + mask = arr == 77 + left_aseg = (2, 4, 5, 7, 8, 10, 11, 12, 13, 17, 18, 26, 28, 30, 31) + is_left = mask_in_array(arr, left_aseg) + from scipy.ndimage import uniform_filter + is_left = uniform_filter(is_left.astype(np.float32), size=7) > 0.2 + is_side = np.logical_not(is_left) if hemi == "rh" else is_left + return np.logical_and(mask, is_side) + + return VolumeMeasure( + self._seg_from_file, + mask_77_lat, + f"{side}WhiteMatterHypoIntensities", + f"Volume of {side} White matter hypointensities", + "mm^3" + ) + elif key in ("lhCerebralWhiteMatter", "rhCerebralWhiteMatter"): + # SurfaceVolume + # 9/10 => l/rCerebralWM + parents = [ + f"{hemi}WhiteMatterVol", + f"{hemi}WM-hypointensities", + (0.5, "CorpusCallosumVol"), + ] + return DerivedMeasure( + parents, + f"{hemi}CerebralWhiteMatterVol", + f"{side} hemisphere cerebral white matter volume", + measure_host=self, + ) + elif key == "CerebralWhiteMatter": + # 11 => lhCtxWM + rhCtxWM: sum of left and right cerebral WM + return DerivedMeasure( + ["rhCerebralWhiteMatter", "lhCerebralWhiteMatter"], + "CerebralWhiteMatterVol", + "Total cerebral white matter volume", + measure_host=self, + ) + elif key == "CerebellarGM": + # Left-Cerebellum-Cortex Right-Cerebellum-Cortex Cbm_Left_I_IV + # Cbm_Right_I_IV Cbm_Left_V Cbm_Right_V Cbm_Left_VI Cbm_Vermis_VI + # Cbm_Right_VI Cbm_Left_CrusI Cbm_Vermis_CrusI Cbm_Right_CrusI + # Cbm_Left_CrusII Cbm_Vermis_CrusII Cbm_Right_CrusII Cbm_Left_VIIb + # Cbm_Vermis_VIIb Cbm_Right_VIIb Cbm_Left_VIIIa Cbm_Vermis_VIIIa + # Cbm_Right_VIIIa Cbm_Left_VIIIb Cbm_Vermis_VIIIb Cbm_Right_VIIIb + # Cbm_Left_IX Cbm_Vermis_IX Cbm_Right_IX Cbm_Left_X Cbm_Vermis_X Cbm_Right_X + # Cbm_Vermis_VII Cbm_Vermis_VIII Cbm_Vermis + cerebellum_classes = [8, 47] + cerebellum_classes.extend(range(601, 629)) + cerebellum_classes.extend(range(630, 633)) + return self.voxel_class( + cerebellum_classes, + "CerebellarGMVol", + "Cerebellar gray matter volume", + "mm^3", + ) + elif key == "CerebellarWM": + # Left-Cerebellum-White-Matter Right-Cerebellum-White-Matter + cerebellum_classes = [7, 46] + return self.voxel_class( + cerebellum_classes, + "CerebellarWMVol", + "Cerebellar white matter volume", + "mm^3", + ) + elif key == "SubCortGray": + # 4 => SubCortGray + # Left-Thalamus Right-Thalamus Left-Caudate Right-Caudate Left-Putamen + # Right-Putamen Left-Pallidum Right-Pallidum Left-Hippocampus + # Right-Hippocampus Left-Amygdala Right-Amygdala Left-Accumbens-area + # Right-Accumbens-area Left-VentralDC Right-VentralDC Left-Substantia-Nigra + # Right-Substantia-Nigra + subcortgray_classes = [17, 18, 26, 27, 28, 58, 59, 60] + subcortgray_classes.extend(range(10, 14)) + subcortgray_classes.extend(range(49, 55)) + return self.voxel_class( + subcortgray_classes, + "SubCortGrayVol", + "Subcortical gray matter volume", + "mm^3", + ) + elif key == "TotalGray": + # FastSurfer, FS6 and FS7 are same + # 8 => TotalGMVol: sum of SubCortGray., Cortex and Cerebellar GM + return DerivedMeasure( + ["SubCortGray", "Cortex", "CerebellarGM"], + "TotalGrayVol", + "Total gray matter volume", + measure_host=self, + ) + elif key == "TFFC": + # FastSurfer, FS6 and FS7 are same + # TFFC: + # 3rd-Ventricle 4th-Ventricle 5th-Ventricle CSF + tffc_classes = (14, 15, 72, 24) + return self.voxel_class( + tffc_classes, + "Third-Fourth-Fifth-CSF", + "volume of 3rd, 4th, 5th ventricle and CSF", + "mm^3", + ) + elif key == "VentricleChoroidVol": + # FastSurfer, FS6 and FS7 are same, except FS7 adds a KeepCSF flag, which + # excludes CSF (but not by default) + # 15 => 
VentChorVol: + # Left-Choroid-Plexus Right-Choroid-Plexus Left-Lateral-Ventricle + # Right-Lateral-Ventricle Left-Inf-Lat-Vent Right-Inf-Lat-Vent + ventchor_classes = (4, 5, 31, 43, 44, 63) + return self.voxel_class( + ventchor_classes, + "VentricleChoroidVol", + "Volume of ventricles and choroid plexus", + "mm^3", + ) + elif key in "BrainSeg": + # 0 => BrainSegVol: + # FS7 (does mot use ribbon any more, just ) + # not background, in aseg ctab, not Brain stem, not optic chiasm, + # aseg undefined in aseg ctab and not cortex or WM (L/R Cerebral + # Ctx/WM) + # ComputeBrainStats2 also removes any regions that are not part of the + # AsegStatsLUT.txt + # background, brainstem, optic chiasm: 0, 16, 85 + brain_seg_classes = [2, 3, 4, 5, 7, 8] + brain_seg_classes.extend(range(10, 16)) + brain_seg_classes.extend([17, 18, 24, 26, 28, 30, 31]) + brain_seg_classes.extend(range(41, 55)) + brain_seg_classes.remove(45) + brain_seg_classes.remove(48) + brain_seg_classes.extend([58, 60, 62, 63, 72]) + brain_seg_classes.extend(range(77, 83)) + brain_seg_classes.extend(cc_classes) + if not self._fs_compat: + # also add asegdkt regions 1002-1035, 2002-2035 + brain_seg_classes.extend(range(1002, 1032)) + brain_seg_classes.remove(1004) + brain_seg_classes.extend((1034, 1035)) + brain_seg_classes.extend(range(2002, 2032)) + brain_seg_classes.remove(2004) + brain_seg_classes.extend((2034, 2035)) + return self.voxel_class( + brain_seg_classes, + "BrainSegVol", + "Brain Segmentation Volume", + "mm^3", + ) + elif key in ("BrainSegNotVent", "BrainSegNotVentSurf"): + # FastSurfer, FS6 and FS7 are same + # 1 => BrainSegNotVent: BrainSegVolNotVent (BrainSegVol-VentChorVol-TFFC) + return DerivedMeasure( + ["BrainSeg", (-1, "VentricleChoroidVol"), (-1, "TFFC")], + key.replace("SegNot", "SegVolNot"), + "Brain Segmentation Volume Without Ventricles", + measure_host=self, + ) + elif key == "Cerebellum": + return DerivedMeasure( + ("CerebellarGM", "CerebellarWM"), + "CerebellumVol", + "Cerebellar volume", + measure_host=self, + ) + elif key == "SupraTentorial": + parents = ["BrainSeg", (-1.0, "Cerebellum")] + return DerivedMeasure( + parents, + "SupraTentorialVol", + "Supratentorial volume", + measure_host=self, + ) + elif key == "SupraTentorialNotVent": + # 3 => SupraTentVolNotVent: SupraTentorial w/o Ventricles & Choroid Plexus + parents = ["SupraTentorial", (-1, "VentricleChoroidVol"), (-1, "TFFC")] + return DerivedMeasure( + parents, + "SupraTentorialVolNotVent", + "Supratentorial volume", + measure_host=self, + ) + elif key == "SupraTentorialNotVentVox": + # 3 => SupraTentVolNotVent: SupraTentorial w/o Ventricles & Choroid Plexus + return DerivedMeasure( + ["SupraTentorialNotVent"], + "SupraTentorialVolNotVentVox", + "Supratentorial volume voxel count", + operation="by_vox_vol", + measure_host=self, + ) + elif key == "Mask": + # 12 => MaskVol: Any voxel in mask > 0 + return MaskMeasure( + Path("mri/brainmask.mgz"), + "MaskVol", + "Mask Volume", + "mm^3", + ) + elif key == "EstimatedTotalIntraCranialVol": + # atlas_icv: eTIV from talairach transform determinate + return ETIVMeasure( + Path("mri/transforms/talairach.xfm"), + "eTIV", + "Estimated Total Intracranial Volume", + "mm^3", + ) + elif key == "BrainSegVol-to-eTIV": + # 0/atlas_icv: ratio BrainSegVol to eTIV + return DerivedMeasure( + ["BrainSeg", "EstimatedTotalIntraCranialVol"], + "BrainSegVol-to-eTIV", + "Ratio of BrainSegVol to eTIV", + measure_host=self, + operation="ratio", + ) + elif key == "MaskVol-to-eTIV": + # 12/atlas_icv: ratio Mask to eTIV + return 
DerivedMeasure( + ["Mask", "EstimatedTotalIntraCranialVol"], + "MaskVol-to-eTIV", + "Ratio of MaskVol to eTIV", + measure_host=self, + operation="ratio", + ) + + def __iter__(self) -> list[AbstractMeasure]: + """ + Iterate through all measures that are exported directly or indirectly. + """ + + out = [self[name] for name in self._exported_measures] + i = 0 + while i < len(out): + this = out[i] + if isinstance(this, DerivedMeasure): + out.extend(filter(lambda x: x not in out, this.parents_items())) + i += 1 + return out + + def compute_non_derived_pv( + self, + compute_threads: Executor | None = None + ) -> "list[Future[int | float]]": + """ + Trigger computation of all non-derived, non-pv measures that are required. + + Parameters + ---------- + compute_threads : concurrent.futures.Executor, optional + An Executor object to perform the computation of measures, if an Executor + object is passed, the computation of measures is submitted to the Executor + object. If not, measures are computed in the main thread. + + Returns + ------- + list[Future[int | float]] + For each non-derived and non-PV measure, a future object that is associated + with the call to the measure. + """ + + def run(f: Callable[[], int | float]) -> Future[int | float]: + out = Future() + out.set_result(f()) + return out + + if isinstance(compute_threads, Executor): + run = compute_threads.submit + + invalid_types = (DerivedMeasure, PVMeasure) + self._compute_futures = [ + run(this) for this in self.values() if not isinstance(this, invalid_types) + ] + return self._compute_futures + + def needs_pv_calculation(self) -> bool: + """ + Returns whether the manager has PV-dependent measures. + + Returns + ------- + bool + Whether the manager has PVMeasure children. + """ + return any(isinstance(this, PVMeasure) for this in self.values()) + + def get_virtual_labels(self, label_pool: Iterable[int]) -> dict[int, list[int]]: + """ + Get the virtual substitute labels that are required. + + Parameters + ---------- + label_pool : Iterable[int] + An iterable over available labels. + + Returns + ------- + dict[int, list[int]] + A dictionary of key-value pairs of new label and a list of labels this + represents. + """ + lbls = (this.labels() for this in self.values() if isinstance(this, PVMeasure)) + no_duplicate_dict = {self.__to_lookup(labs): labs for labs in lbls} + return dict(zip(label_pool, no_duplicate_dict.values())) + + @staticmethod + def __to_lookup(labels: Sequence[int]) -> str: + return str(list(sorted(set(map(int, labels))))) + + def update_pv_from_table( + self, + dataframe: "pd.DataFrame", + merged_labels: dict[int, list[int]], + ) -> "pd.DataFrame": + """ + Update pv measures from dataframe and remove corresponding entries from the + dataframe. + + Parameters + ---------- + dataframe : pd.DataFrame + The dataframe object with the PV values. + merged_labels : dict[int, list[int]] + Mapping from PVMeasure proxy label to list of labels it merges. + + Returns + ------- + pd.DataFrame + A dataframe object, where label 'groups' used for updates and in + `merged_labels` are removed, i.e. those labels added for PVMeasure objects. 
+ + Raises + ------ + RuntimeError + """ + _lookup = {self.__to_lookup(ml): vl for vl, ml in merged_labels.items()} + filtered_df = dataframe + # go through the pv measures and find a measure that has the same list + for this in self.values(): + if isinstance(this, PVMeasure): + virtual_label = _lookup.get(self.__to_lookup(this.labels()), None) + if virtual_label is None: + raise RuntimeError(f"Could not find the virtual label for {this}.") + row = dataframe[dataframe["SegId"] == virtual_label] + if row.shape[0] != 1: + raise RuntimeError( + f"The search results in the dataframe for {this} failed: " + f"shape {row.shape}" + ) + this.update_data(row) + filtered_df = filtered_df[filtered_df["SegId"] != virtual_label] + + return filtered_df + + def wait_compute(self) -> Sequence[BaseException]: + """ + Wait for all pending computation processes and return their errors. + + Also resets the internal compute futures. + + Returns + ------- + Sequence[BaseException] + The errors raised in the computations. + """ + errors = [future.exception() for future in self._compute_futures] + self._compute_futures = [] + return [error for error in errors if error is not None] + + def wait_write_brainvolstats(self, brainvol_statsfile: Path): + """ + Wait for measure computation to finish and write results to brainvol_statsfile. + + Parameters + ---------- + brainvol_statsfile: Path + The file to write the measures to. + + Raises + ------ + RuntimeError + If errors occurred during measure computation. + """ + errors = list(self.wait_compute()) + if len(errors) != 0: + error_messages = ["Some errors occurred during measure computation:"] + error_messages.extend(map(lambda e: str(e.args[0]), errors)) + raise RuntimeError("\n - ".join(error_messages)) + + def fmt_measure(key: str, data: MeasureTuple) -> str: + return f"# Measure {key}, {data[0]}, {data[1]}, {data[2]:.12f}, {data[3]}" + + lines = self.format_measures(fmt_func=fmt_measure) + + with open(brainvol_statsfile, "w") as file: + for line in lines: + print(line, file=file) diff --git a/FastSurferCNN/utils/checkpoint.py b/FastSurferCNN/utils/checkpoint.py index 1832fa2f..6663a9e0 100644 --- a/FastSurferCNN/utils/checkpoint.py +++ b/FastSurferCNN/utils/checkpoint.py @@ -14,41 +14,131 @@ # IMPORTS import os -import glob -from typing import Union, Iterable, Optional, Collection, MutableSequence +from functools import lru_cache +from pathlib import Path +from typing import MutableSequence, Optional, Union, Literal, TypedDict, cast, overload import requests import torch import yacs.config +import yaml -from FastSurferCNN.utils import logging +from FastSurferCNN.utils import logging, Plane +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT Scheduler = "torch.optim.lr_scheduler" LOGGER = logging.getLogger(__name__) # Defaults -URL = "https://b2share.fz-juelich.de/api/files/a423a576-220d-47b0-9e0c-b5b32d45fc59" -FASTSURFER_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) -VINN_AXI = os.path.join(FASTSURFER_ROOT, "checkpoints/aparc_vinn_axial_v2.0.0.pkl") -VINN_COR = os.path.join(FASTSURFER_ROOT, "checkpoints/aparc_vinn_coronal_v2.0.0.pkl") -VINN_SAG = os.path.join(FASTSURFER_ROOT, "checkpoints/aparc_vinn_sagittal_v2.0.0.pkl") +YAML_DEFAULT = FASTSURFER_ROOT / "FastSurferCNN/config/checkpoint_paths.yaml" + + +class CheckpointConfigDict(TypedDict, total=False): + url: list[str] + checkpoint: dict[Plane, Path] + config: dict[Plane, Path] + + +CheckpointConfigFields = Literal["checkpoint", "config", "url"] + + +@lru_cache +def 
load_checkpoint_config(filename: Path | str = YAML_DEFAULT) -> CheckpointConfigDict: + """ + Load the plane dictionary from the yaml file. + + Parameters + ---------- + filename : Path, str + Path to the yaml file. Either absolute or relative to the FastSurfer root + directory. + + Returns + ------- + CheckpointConfigDict + A dictionary representing the contents of the yaml file. + """ + if not filename.absolute(): + filename = FASTSURFER_ROOT / filename + + with open(filename, "r") as file: + data = yaml.load(file, Loader=yaml.FullLoader) + + required_fields = ("url", "checkpoint") + checks = [k not in data for k in required_fields] + if any(checks): + missing = tuple(k for k, c in zip(required_fields, checks) if c) + message = f"The file {filename} is not valid, missing key(s): {missing}" + raise IOError(message) + if isinstance(data["url"], str): + data["url"] = [data["url"]] + else: + data["url"] = list(data["url"]) + for key in ("config", "checkpoint"): + if key in data: + data[key] = {k: Path(v) for k, v in data[key].items()} + return data + + +@overload +def load_checkpoint_config_defaults( + filetype: Literal["checkpoint", "config"], + filename: str | Path = YAML_DEFAULT, +) -> dict[Plane, Path]: ... + + +@overload +def load_checkpoint_config_defaults( + configtype: Literal["url"], + filename: str | Path = YAML_DEFAULT, +) -> list[str]: ... + + +def load_checkpoint_config_defaults( + configtype: CheckpointConfigFields, + filename: str | Path = YAML_DEFAULT, +) -> dict[Plane, Path] | list[str]: + """ + Get the default value for a specific plane or the url. + + Parameters + ---------- + configtype : "checkpoint", "config", "url" + Type of value. + filename : str, Path + The path to the yaml file. Either absolute or relative to the FastSurfer root + directory. + + Returns + ------- + dict[Plane, Path], list[str] + Default value for the plane. + """ + if not isinstance(filename, Path): + filename = Path(filename) + + configtype = cast(CheckpointConfigFields, configtype.lower()) + if configtype not in ("url", "checkpoint", "config"): + raise ValueError("Type must be 'url', 'checkpoint' or 'config'") + + return load_checkpoint_config(filename)[configtype] def create_checkpoint_dir(expr_dir: Union[os.PathLike], expr_num: int): - """Create the checkpoint dir if not exists. + """ + Create the checkpoint dir if not exists. Parameters ---------- expr_dir : Union[os.PathLike] - directory to create + Directory to create. expr_num : int - number of expr [MISSING] + Experiment number. Returns ------- checkpoint_dir - directory of the checkpoint - + Directory of the checkpoint. """ checkpoint_dir = os.path.join(expr_dir, "checkpoints", str(expr_num)) os.makedirs(checkpoint_dir, exist_ok=True) @@ -56,20 +146,21 @@ def create_checkpoint_dir(expr_dir: Union[os.PathLike], expr_num: int): def get_checkpoint(ckpt_dir: str, epoch: int) -> str: - """Find the standardizes checkpoint name for the checkpoint in the directory ckpt_dir for the given epoch. + """ + Find the standardizes checkpoint name for the checkpoint in the directory + ckpt_dir for the given epoch. Parameters ---------- ckpt_dir : str - Checkpoint directory + Checkpoint directory. epoch : int - Number of the epoch + Number of the epoch. Returns ------- checkpoint_dir - Standardizes checkpoint name - + Standardizes checkpoint name. 
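# Sketch of the structure load_checkpoint_config above expects in
# FastSurferCNN/config/checkpoint_paths.yaml: one or more download urls plus
# per-plane checkpoint (and optionally config) paths. The urls, plane names and
# file names below are illustrative and not copied from the repository.
import yaml

example = """
url:
  - https://example.org/fastsurfer/checkpoints
checkpoint:
  axial: checkpoints/aparc_vinn_axial.pkl
  coronal: checkpoints/aparc_vinn_coronal.pkl
  sagittal: checkpoints/aparc_vinn_sagittal.pkl
"""

data = yaml.safe_load(example)
missing = [k for k in ("url", "checkpoint") if k not in data]
assert not missing, f"missing required key(s): {missing}"
print(sorted(data["checkpoint"]))   # ['axial', 'coronal', 'sagittal']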
""" checkpoint_dir = os.path.join( ckpt_dir, "Epoch_{:05d}_training_state.pkl".format(epoch) @@ -78,66 +169,68 @@ def get_checkpoint(ckpt_dir: str, epoch: int) -> str: def get_checkpoint_path( - log_dir: str, resume_experiment: Union[str, int, None] = None -) -> Optional[MutableSequence[str]]: - """Find the paths to checkpoints from the experiment directory. + log_dir: Path | str, resume_experiment: Union[str, int, None] = None +) -> MutableSequence[Path]: + """ + Find the paths to checkpoints from the experiment directory. Parameters ---------- - log_dir : str - experiment directory + log_dir : Path, str + Experiment directory. resume_experiment : Union[str, int, None] - sub-experiment to search in for a model (Default value = None) + Sub-experiment to search in for a model (Default value = None). Returns ------- - prior_model_paths : Optional[MutableSequence[str]] - None, if no models are found, or a list of filenames for checkpoints. - + prior_model_paths : MutableSequence[Path] + A list of filenames for checkpoints. """ if resume_experiment == "Default" or resume_experiment is None: - return None - checkpoint_path = os.path.join(log_dir, "checkpoints", str(resume_experiment)) + return [] + if not isinstance(log_dir, Path): + log_dir = Path(log_dir) + checkpoint_path = log_dir / "checkpoints" / str(resume_experiment) prior_model_paths = sorted( - glob.glob(os.path.join(checkpoint_path, "Epoch_*")), key=os.path.getmtime + checkpoint_path.glob("Epoch_*"), key=lambda p: p.stat().st_mtime ) - if len(prior_model_paths) == 0: - return None - return prior_model_paths + return list(prior_model_paths) def load_from_checkpoint( - checkpoint_path: str, - model: torch.nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[Scheduler] = None, - fine_tune: bool = False, - drop_classifier: bool = False, + checkpoint_path: str | Path, + model: torch.nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Scheduler] = None, + fine_tune: bool = False, + drop_classifier: bool = False, ): - """Load the model from the given experiment number. + """ + Load the model from the given experiment number. Parameters ---------- - checkpoint_path : str - path to the checkpoint + checkpoint_path : str, Path + Path to the checkpoint. model : torch.nn.Module - Network model + Network model. optimizer : Optional[torch.optim.Optimizer] - Network optimizer (Default value = None) + Network optimizer (Default value = None). scheduler : Optional[Scheduler] - Network scheduler (Default value = None) + Network scheduler (Default value = None). fine_tune : bool - Whether to fine tune or not (Default value = False) + Whether to fine tune or not (Default value = False). drop_classifier : bool - Whether to drop the classifier or not (Default value = False) + Whether to drop the classifier or not (Default value = False). Returns ------- loaded_epoch : int - epoch number - + Epoch number. 
""" - checkpoint = torch.load(checkpoint_path, map_location="cpu") + # WARNING: weights_only=False can cause unsafe code execution, but here the + # checkpoint can be considered to be from a safe source + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) if drop_classifier: classifier_conv = ["classifier.conv.weight", "classifier.conv.bias"] @@ -159,39 +252,39 @@ def load_from_checkpoint( def save_checkpoint( - checkpoint_dir: str, - epoch: int, - best_metric, - num_gpus: int, - cfg: yacs.config.CfgNode, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: Optional[Scheduler] = None, - best: bool = False, + checkpoint_dir: str | Path, + epoch: int, + best_metric, + num_gpus: int, + cfg: yacs.config.CfgNode, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: Optional[Scheduler] = None, + best: bool = False, ) -> None: - """Save the state of training for resume or fine-tune. + """ + Save the state of training for resume or fine-tune. Parameters ---------- - checkpoint_dir : str - path to the checkpoint directory + checkpoint_dir : str, Path + Path to the checkpoint directory. epoch : int - current epoch - best_metric : - best calculated metric + Current epoch. + best_metric : best_metric + Best calculated metric. num_gpus : int - number of used gpus + Number of used gpus. cfg : yacs.config.CfgNode - configuration node + Configuration node. model : torch.nn.Module - used network model + Used network model. optimizer : torch.optim.Optimizer - used network optimizer + Used network optimizer. scheduler : Optional[Scheduler] - used network scheduler. Optional (Default value = None) - best : bool - Whether this was the best checkpoint so far [MISSING] (Default value = False) - + Used network scheduler. Optional (Default value = None). + best : bool, default=False + Whether this was the best checkpoint so far (Default value = False). """ save_name = f"Epoch_{epoch:05d}_training_state.pkl" saving_model = model.module if num_gpus > 1 else model @@ -205,95 +298,116 @@ def save_checkpoint( if scheduler is not None: checkpoint["scheduler_state"] = scheduler.state_dict() + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) - torch.save(checkpoint, checkpoint_dir + "/" + save_name) + torch.save(checkpoint, checkpoint_dir / save_name) if best: - remove_ckpt(checkpoint_dir + "/Best_training_state.pkl") - torch.save(checkpoint, checkpoint_dir + "/Best_training_state.pkl") + remove_ckpt(checkpoint_dir / "Best_training_state.pkl") + torch.save(checkpoint, checkpoint_dir / "Best_training_state.pkl") -def remove_ckpt(ckpt: str): - """Remove the checkpoint. +def remove_ckpt(ckpt: str | Path): + """ + Remove the checkpoint. Parameters ---------- - ckpt : str - Path and filename to the checkpoint - + ckpt : str, Path + Path and filename to the checkpoint. """ try: - os.remove(ckpt) + Path(ckpt).unlink() except FileNotFoundError: pass def download_checkpoint( - download_url: str, checkpoint_name: str, - checkpoint_path: str + checkpoint_path: str | Path, + urls: list[str], ) -> None: - """Download a checkpoint file. + """ + Download a checkpoint file. Raises an HTTPError if the file is not found or the server is not reachable. Parameters ---------- - download_url : str - URL of checkpoint hosting site checkpoint_name : str - name of checkpoint - checkpoint_path : str - path of the file in which the checkpoint will be saved - + Name of checkpoint. 
+ checkpoint_path : Path, str + Path of the file in which the checkpoint will be saved. + urls : list[str] + List of URLs of checkpoint hosting sites. """ - try: - response = requests.get(download_url + "/" + checkpoint_name, verify=True) - # Raise error if file does not exist: - response.raise_for_status() - except requests.exceptions.HTTPError as e: - LOGGER.info("Response code: {}".format(e.response.status_code)) - response = requests.get(download_url + "/" + checkpoint_name, verify=False) - response.raise_for_status() + response = None + for url in urls: + try: + LOGGER.info(f"Downloading checkpoint {checkpoint_name} from {url}") + response = requests.get( + url + "/" + checkpoint_name, + verify=True, + timeout=(5, None), # (connect timeout: 5 sec, read timeout: None) + ) + # Raise error if file does not exist: + response.raise_for_status() + break + + except requests.exceptions.RequestException as e: + LOGGER.warning(f"Server {url} not reachable ({type(e).__name__}): {e}") + if isinstance(e, requests.exceptions.HTTPError): + LOGGER.warning(f"Response code: {e.response.status_code}") + + if response is None: + links = ', '.join(u.removeprefix('https://')[:22] + "..." for u in urls) + raise requests.exceptions.RequestException( + f"Failed downloading the checkpoint {checkpoint_name} from {links}." + ) + else: + response.raise_for_status() # Raise error if no server is reachable with open(checkpoint_path, "wb") as f: f.write(response.content) -def check_and_download_ckpts(checkpoint_path: str, url: str) -> None: - """Check and download a checkpoint file, if it does not exist. +def check_and_download_ckpts(checkpoint_path: Path | str, urls: list[str]) -> None: + """ + Check and download a checkpoint file, if it does not exist. Parameters ---------- - checkpoint_path : str - path of the file in which the checkpoint will be saved - url : str - URL of checkpoint hosting site - + checkpoint_path : Path, str + Path of the file in which the checkpoint will be saved. + urls : list[str] + URLs of checkpoint hosting site. """ + if not isinstance(checkpoint_path, Path): + checkpoint_path = Path(checkpoint_path) # Download checkpoint file from url if it does not exist - if not os.path.exists(checkpoint_path): - ckptdir, ckptname = os.path.split(checkpoint_path) - if not os.path.exists(ckptdir) and ckptdir: - os.makedirs(ckptdir) - download_checkpoint(url, ckptname, checkpoint_path) + if not checkpoint_path.exists(): + # create dir if it does not exist + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + download_checkpoint(checkpoint_path.name, checkpoint_path, urls) -def get_checkpoints(axi: str, cor: str, sag: str, url: str = URL) -> None: - """Check and download checkpoint files if not exist. +def get_checkpoints(*checkpoints: Path | str, urls: list[str]) -> None: + """ + Check and download checkpoint files if not exist. Parameters ---------- - axi : str - Axial path of the file in which the checkpoint will be saved - cor : str - Coronal path of the file in which the checkpoint will be saved - sag : str - Sagittal path of the file in which the checkpoint will be saved - url : str - URL of checkpoint hosting site (Default value = URL) - + *checkpoints : Path, str + Paths of the files in which the checkpoint will be saved. + urls : Path, str + URLs of checkpoint hosting sites. 
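# Sketch of the intended call sequence for the checkpoint helpers: read the
# default per-plane checkpoint paths and the download urls from
# checkpoint_paths.yaml, then fetch whichever checkpoint files are missing.
# Assumes the functions are imported from FastSurferCNN.utils.checkpoint.
ckpt_paths = load_checkpoint_config_defaults("checkpoint")   # {plane: Path, ...}
urls = load_checkpoint_config_defaults("url")                # [url, ...]
get_checkpoints(*ckpt_paths.values(), urls=urls)             # downloads only missing files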
""" - check_and_download_ckpts(axi, url) - check_and_download_ckpts(cor, url) - check_and_download_ckpts(sag, url) + try: + for file in map(Path, checkpoints): + if not file.is_absolute() and file.parts[0] != ".": + file = FASTSURFER_ROOT / file + check_and_download_ckpts(file, urls) + except requests.exceptions.HTTPError: + LOGGER.error(f"Could not find nor download checkpoints from {urls}") + raise diff --git a/FastSurferCNN/utils/common.py b/FastSurferCNN/utils/common.py index 4ab68039..d2577b85 100644 --- a/FastSurferCNN/utils/common.py +++ b/FastSurferCNN/utils/common.py @@ -11,21 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import builtins + # IMPORTS import os +from collections import namedtuple from concurrent.futures import Executor, Future +from dataclasses import dataclass +from pathlib import Path from typing import ( - List, - Union, - TypeVar, + Any, Callable, + Dict, Iterable, - Any, Iterator, + List, Optional, Tuple, - Dict, + TypeVar, ) import torch @@ -34,38 +36,39 @@ __all__ = [ "assert_no_root", - "Executor", "find_device", "handle_cuda_memory_exception", "iterate", - "NoParallelExecutor", + "SerialExecutor", "pipeline", - "removesuffix", "SubjectList", "SubjectDirectory", ] +from FastSurferCNN.utils.parser_defaults import SubjectDirectoryConfig + LOGGER = logging.getLogger(__name__) _T = TypeVar("_T") _Ti = TypeVar("_Ti") def find_device( - device: Union[torch.device, str] = "auto", + device: torch.device | str = "auto", flag_name: str = "device", min_memory: int = 0, ) -> torch.device: - """Create a device object from the device string passed. + """ + Create a device object from the device string passed. Includes detection of devices if device is not defined or "auto". Parameters ---------- - device : Union[torch.device, str] - the device to search for and test following pytorch device naming + device : torch.device, str + The device to search for and test following pytorch device naming conventions, e.g. 'cuda:0', 'cpu', etc. (default: 'auto'). flag_name : str - name of the corresponding flag for error messages (default: 'device') + Name of the corresponding flag for error messages (default: 'device'). min_memory : int The minimum memory in bytes required for cuda-devices to be valid (default: 0, works always). @@ -73,8 +76,7 @@ def find_device( Returns ------- device: torch.device - The torch.device object - + The torch.device object. """ logger = logging.get_logger(__name__ + ".auto_device") # if specific device is requested, check and stop if not available: @@ -103,7 +105,8 @@ def find_device( if total_gpu_memory < min_memory: giga = 1024**3 logger.info( - f"Found {total_gpu_memory/giga:.1f} GB GPU memory, but {min_memory/giga:.f} GB was required." + f"Found {total_gpu_memory/giga:.1f} GB GPU memory, but " + f"{min_memory/giga:.1f} GB was required." ) device = torch.device("cpu") @@ -113,13 +116,13 @@ def find_device( def assert_no_root() -> bool: - """Check whether the user is the root user and raises an error message is so. + """ + Check whether the user is the root user and raises an error message is so. Returns ------- bool - Whether the user is root or not - + Whether the user is root or not. """ if os.name == "posix" and os.getuid() == 0: import sys @@ -127,10 +130,12 @@ def assert_no_root() -> bool: sys.exit( """---------------------------- - ERROR: You are trying to run '{0}' as root. 
We advice to avoid running - FastSurfer as root, because it will lead to files and folders created as root. - If you are running FastSurfer in a docker container, you can specify the user with - '-u $(id -u):$(id -g)' (see https://docs.docker.com/engine/reference/run/#user). + ERROR + You are trying to run '{0}' as root. We advice to avoid running FastSurfer + as root, because it will lead to files and folders created as root. + If you are running FastSurfer in a docker container, you can specify the + user with '-u $(id -u):$(id -g)' + (see https://docs.docker.com/engine/reference/run/#user). If you want to force running as root, you may pass --allow_root to {0}. """.format( os.path.basename(__main__.__file__) @@ -139,19 +144,19 @@ def assert_no_root() -> bool: return True -def handle_cuda_memory_exception(exception: builtins.BaseException) -> bool: - """Handle CUDA out of memory exception and print a help text. +def handle_cuda_memory_exception(exception: BaseException) -> bool: + """ + Handle CUDA out of memory exception and print a help text. Parameters ---------- exception : builtins.BaseException - Received exception + Received exception. Returns ------- bool - Whether th exception was a RuntimeError caused by Cuda out memory - + Whether the exception was a RuntimeError caused by Cuda out memory. """ if not isinstance(exception, RuntimeError): return False @@ -159,12 +164,15 @@ def handle_cuda_memory_exception(exception: builtins.BaseException) -> bool: if message.startswith("CUDA out of memory. "): LOGGER.critical("ERROR - INSUFFICIENT GPU MEMORY") LOGGER.info( - "The memory requirements exceeds the available GPU memory, try using a smaller batch size " - "(--batch_size ) and/or view aggregation on the cpu (--viewagg_device 'cpu')." - "Note: View Aggregation on the GPU is particularly memory-hungry at approx. 5 GB for standard " - "256x256x256 images." + "The memory requirements exceeds the available GPU memory, try using a " + "smaller batch size (--batch_size ) and/or view aggregation on the " + "cpu (--viewagg_device 'cpu')." + ) + LOGGER.info( + "Note: View Aggregation on the GPU is particularly memory-hungry at " + "approx. 5 GB for standard 256x256x256 images." ) - memory_message = message[message.find("(") + 1 : message.find(")")] + memory_message = message[message.find("(") + 1:message.find(")")] LOGGER.info(f"Using {memory_message}.") return True else: @@ -178,31 +186,29 @@ def pipeline( *, pipeline_size: int = 1, ) -> Iterator[Tuple[_Ti, _T]]: - """Pipeline a function to be executed in the pool. + """ + Pipeline a function to be executed in the pool. Analogous to iterate, but run func in a different thread for the next element while the current element is returned. - Parameters [MISSING] + Parameters ---------- pool : Executor + Thread pool executor for parallel execution. + func : callable + Function to use. + iterable : Iterable + Iterable containing input elements. + pipeline_size : int, default=1 + Size of the processing pipeline. - func : Callable[[_Ti], _T] : - function to use - - iterable : Iterable[_Ti] - - * : - [MISSING] - - pipeline_size : int - size of the pipeline - (Default value = 1) - - Returns - ------- - [MISSING] - + Yields + ------ + element : _Ti + Elements + _T + Results of func corresponding to element: func(element). 
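# Sketch of the try/except pattern handle_cuda_memory_exception above is meant
# for: if a RuntimeError from inference turns out to be a CUDA out-of-memory
# error, the helper logs advice (smaller --batch_size, view aggregation on the
# cpu) and returns True; any other error is re-raised. `model` and `batch` are
# assumed to be provided by the surrounding inference code.
import sys
import torch

def predict(model: torch.nn.Module, batch: torch.Tensor):
    try:
        with torch.no_grad():
            return model(batch)
    except RuntimeError as e:
        if handle_cuda_memory_exception(e):
            sys.exit("Out of GPU memory, see the log messages above.")
        raise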
""" # do pipeline loading the next element from collections import deque @@ -222,64 +228,35 @@ def pipeline( def iterate( - pool: Executor, func: Callable[[_Ti], _T], iterable: Iterable[_Ti] + pool: Executor, func: Callable[[_Ti], _T], iterable: Iterable[_Ti], ) -> Iterator[Tuple[_Ti, _T]]: - """Iterate over iterable, yield pairs of elements and func(element). + """ + Iterate over iterable, yield pairs of elements and func(element). Parameters ---------- pool : Executor - [MISSING] - func : Callable[[_Ti], _T] - function to use - iterable : Iterable[_Ti] - iterable + The Executor object (dummy object to have a common API with pipeline). + func : callable + Function to use. + iterable : Iterable + Iterable to draw objects to process with func from. Yields ------ - element : _Ti - elements + element : _Ti + Elements _T - [MISSING] - + Results of func corresponding to element: func(element). """ for element in iterable: yield element, func(element) -def removesuffix(string: str, suffix: str) -> str: - """Remove a suffix from a string. - - Similar to string.removesuffix in PY3.9+. - - Parameters - ---------- - string : str - string that should be edited - suffix : str - suffix to remove - - Returns - ------- - str - input string with removed suffix - - """ - import sys - - if sys.version_info.minor >= 9: - # removesuffix is a Python3.9 feature - return string.removesuffix(suffix) - else: - return ( - string[: -len(suffix)] - if len(suffix) > 0 and string.endswith(suffix) - else string - ) - - class SubjectDirectory: - """Represent a subject.""" + """ + Represent a subject directory. + """ _orig_name: str _copy_orig_name: str @@ -291,144 +268,152 @@ class SubjectDirectory: _id: str def __init__(self, **kwargs): - """Create a subject, supports generic attributes. + """ + Create a subject, supports generic attributes. Parameters ---------- - **kwargs : - id: the subject id - orig_name: relative or absolute filename of the orig filename - conf_name: relative or absolute filename of the conformed filename - segfile: relative or absolute filename of the segmentation filename - main_segfile: relative or absolute filename of the main segmentation filename - asegdkt_segfile: relative or absolute filename of the aparc+aseg segmentation filename - subject_dir: path to the subjects directory (containing subject folders) - + id : str + The subject id. + orig_name : str + Relative or absolute filename of the orig filename. + conf_name : str + Relative or absolute filename of the conformed filename. + segfile : str + Relative or absolute filename of the segmentation filename. + main_segfile : str + Relative or absolute filename of the main segmentation filename. + asegdkt_segfile : str + Relative or absolute filename of the aparc+aseg segmentation filename. + subject_dir : Path + Path to the subjects directory (containing subject folders). """ for k, v in kwargs.items(): + if k == "subject_dir": + v = Path(v) setattr(self, "_" + k, v) - def filename_in_subject_folder(self, filepath: str) -> str: - """Return the full path to the file. + def filename_in_subject_folder(self, filepath: str | Path) -> Path: + """ + Return the full path to the file. Parameters ---------- - filepath : str - abs path to the file or name of the file + filepath : str, Path + Absolute to the file or name of the file. Returns ------- - str - Path to the file - + Path + Path to the file. 
""" - return ( - filepath - if os.path.isabs(filepath) - else os.path.join(self.subject_dir, self._id, filepath) - ) + if Path(filepath).is_absolute(): + return Path(filepath) + else: + return self.subject_dir / self._id / filepath - def filename_by_attribute(self, attr_name: str) -> str: - """[MISSING]. + def filename_by_attribute(self, attr_name: str) -> Path: + """ + Retrieve a filename based on the provided attribute name. Parameters ---------- attr_name : str - [MISSING] + The name of the attribute associated with the desired filename. Returns ------- - str - [MISSING] - + Path + The filename corresponding to the provided attribute name. """ return self.filename_in_subject_folder(self.get_attribute(attr_name)) - def fileexists_in_subject_folder(self, filepath: str) -> bool: - """Check if file exists in the subject folder. + def fileexists_in_subject_folder(self, filepath: str | Path) -> bool: + """ + Check if file exists in the subject folder. Parameters ---------- - filepath : str - Path to the file + filepath : Path, str + Path to the file. Returns ------- bool - Whether the file exists or not - + Whether the file exists or not. """ - return os.path.exists(self.filename_in_subject_folder(filepath)) + return self.filename_in_subject_folder(filepath).exists() def fileexists_by_attribute(self, attr_name: str) -> bool: - """[MISSING]. + """ + Check if a file exists based on the provided attribute name. Parameters ---------- attr_name : str - [MISSING] + The name of the attribute associated with the file existence check. Returns ------- bool - Whether the file exists or not - + Whether the file exists or not. """ return self.fileexists_in_subject_folder(self.get_attribute(attr_name)) @property - def subject_dir(self) -> str: - """Gets the subject directory name. + def subject_dir(self) -> Path: + """ + Gets the subject directory name. Returns ------- - str - The set subject directory - + Path + The set subject directory. """ assert hasattr(self, "_subject_dir") or "The folder attribute has not been set!" - return self._subject_dir + return Path(self._subject_dir) @subject_dir.setter - def subject_dir(self, _folder: str): - """Set the subject directory name. + def subject_dir(self, _folder: str | Path): + """ + Set the subject directory name. Parameters ---------- - _folder : str - The subject directory - + _folder : str, Path + The subject directory. """ self._subject_dir = _folder @property def id(self) -> str: - """Get the id. + """ + Get the id. Returns ------- str - The id - + The id. """ assert hasattr(self, "_id") or "The id attribute has not been set!" return self._id @id.setter def id(self, _id: str): - """Set the id. + """ + Set the id. Parameters ---------- _id : str - The id - + The id. """ self._id = _id @property def orig_name(self) -> str: - """Try to return absolute path. + """ + Try to return absolute path. If the native_t1_file is a relative path, it will be interpreted as relative to folder. @@ -436,8 +421,7 @@ def orig_name(self) -> str: Returns ------- str - The orig name - + The orig name. """ assert ( hasattr(self, "_orig_name") or "The orig_name attribute has not been set!" @@ -446,28 +430,28 @@ def orig_name(self) -> str: @orig_name.setter def orig_name(self, _orig_name: str): - """Set the orig name. + """ + Set the orig name. Parameters ---------- _orig_name : str - The orig name - + The orig name. """ self._orig_name = _orig_name @property - def copy_orig_name(self) -> str: - """Try to return absolute path. 
+ def copy_orig_name(self) -> Path: + """ + Try to return absolute path. If the copy_orig_t1_file is a relative path, it will be interpreted as relative to folder. Returns ------- - str - The copy of orig name - + Path + The copy of orig name. """ assert ( hasattr(self, "_copy_orig_name") @@ -477,33 +461,33 @@ def copy_orig_name(self) -> str: @copy_orig_name.setter def copy_orig_name(self, _copy_orig_name: str): - """Set the copy of orig name. + """ + Set the copy of orig name. Parameters ---------- _copy_orig_name : str - [MISSING] + The copy of the orig name. Returns ------- str - original name - + Original name. """ self._copy_orig_name = _copy_orig_name @property - def conf_name(self) -> str: - """Try to return absolute path. + def conf_name(self) -> Path: + """ + Try to return absolute path. If the conformed_t1_file is a relative path, it will be interpreted as relative to folder. Returns ------- - str - [MISSING] - + Path + The path to the conformed image file. """ assert ( hasattr(self, "_conf_name") or "The conf_name attribute has not been set!" @@ -512,59 +496,56 @@ def conf_name(self) -> str: @conf_name.setter def conf_name(self, _conf_name: str): - """[MISSING]. + """ + Set the path to the conformed image. Parameters ---------- _conf_name : str - [MISSING] - - Returns - ------- - str - [MISSING] + Path to the conformed image. """ self._conf_name = _conf_name @property - def segfile(self) -> str: - """Try to return absolute path. + def segfile(self) -> Path: + """ + Try to return absolute path. If the segfile is a relative path, it will be interpreted as relative to folder. Returns ------- - str - Path to the segfile - + Path + Path to the segfile. """ assert hasattr(self, "_segfile") or "The _segfile attribute has not been set!" return self.filename_in_subject_folder(self._segfile) @segfile.setter def segfile(self, _segfile: str): - """Set segfile. + """ + Set segfile. Parameters ---------- _segfile : str - [MISSING] - + Path to the segmentation file. """ self._segfile = _segfile @property - def asegdkt_segfile(self) -> str: - """Try to return absolute path. + def asegdkt_segfile(self) -> Path: + """ + Try to return absolute path. If the asegdkt_segfile is a relative path, it will be interpreted as relative to folder. Returns ------- - str - Path to segmentation file + Path + Path to segmentation file. """ assert ( hasattr(self, "_segfile") @@ -573,27 +554,28 @@ def asegdkt_segfile(self) -> str: return self.filename_in_subject_folder(self._asegdkt_segfile) @asegdkt_segfile.setter - def asegdkt_segfile(self, _asegdkt_segfile: str): - """Set path to segmentation file. + def asegdkt_segfile(self, _asegdkt_segfile: str | Path): + """ + Set path to segmentation file. Parameters ---------- - _asegdkt_segfile : str - Path to segmentation file - + _asegdkt_segfile : Path, str + Path to segmentation file. """ - self._asegdkt_segfile = _asegdkt_segfile + self._asegdkt_segfile = str(_asegdkt_segfile) @property - def main_segfile(self) -> str: - """Try to return absolute path. + def main_segfile(self) -> Path: + """ + Try to return absolute path. If the main_segfile is a relative path, it will be interpreted as relative to folder. Returns ------- - str + Path Path to the main segfile. """ @@ -605,76 +587,82 @@ def main_segfile(self) -> str: @main_segfile.setter def main_segfile(self, _main_segfile: str): - """Set the main segfile. + """ + Set the main segfile. Parameters ---------- _main_segfile : str - Path to the main_segfile - + Path to the main_segfile. 
""" self._main_segfile = _main_segfile def can_resolve_filename(self, filename: str) -> bool: - """Check whether we can resolve the file name. + """ + Check whether we can resolve the file name. Parameters ---------- filename : str - Name of the filename to check + Name of the filename to check. Returns ------- bool - Whether we can resolve the file name - + Whether we can resolve the file name. """ return os.path.isabs(filename) or self._subject_dir is not None def can_resolve_attribute(self, attr_name: str) -> bool: - """Check whether we can resolve the attribute. + """ + Check whether we can resolve the attribute. Parameters ---------- attr_name : str - Name of the attribute to check + Name of the attribute to check. Returns ------- bool - Whether we can resolve the attribute - + Whether we can resolve the attribute. """ return self.can_resolve_filename(self.get_attribute(attr_name)) def has_attribute(self, attr_name: str) -> bool: - """Check if the attribute is set. + """ + Check if the attribute is set. Parameters ---------- attr_name : str - Name of the attribute to check + Name of the attribute to check. Returns ------- bool - Whether the attribute exists or not - + Whether the attribute exists or not. """ return getattr(self, "_" + attr_name, None) is not None - def get_attribute(self, attr_name: str): - """Give the requested attribute. + def get_attribute(self, attr_name: str) -> str | Path: + """ + Give the requested attribute. Parameters ---------- attr_name : str - Name of the attribute to return + Name of the attribute to return. Returns ------- - Value of the attribute + str, Path + The value of the requested attribute. + Raises + ------ + AttributeError + If the subject has no attribute with the given name. """ if not self.has_attribute(attr_name): raise AttributeError(f"The subject has no attribute named {attr_name}.") @@ -682,61 +670,77 @@ def get_attribute(self, attr_name: str): class SubjectList: - """Represent a list of subjects.""" + """ + Represent a list of subjects. + """ - _subjects: List[str] + _subjects: List[Path] _orig_name_: str _conf_name_: str _segfile_: str - _flags: Dict[str, Dict] + _flags: Dict[str, dict] DEFAULT_FLAGS = {k: v(dict) for k, v in parser_defaults.ALL_FLAGS.items()} - def __init__(self, args, flags: Optional[Dict[str, Dict]] = None, **assign): - """Create an iterate-able list of subjects from the arguments passed. + def __init__( + self, + args: SubjectDirectoryConfig, + flags: Optional[dict[str, dict]] = None, + **assign, + ): + """ + Create an iterate-able list of subjects from the arguments passed. - Parameters - ---------- - args : - The Namespace object (object with attributes to define parameters) with the following 'required' - definitions. - orig_name (str): the path to the input t1 file. - conf_name (str): the path to the conformed t1 file. - segfile (str): the path to the main output file. - in_dir (str) or csv_file (str), if orig_name is not an absolute path. - flags : Optional[Dict[str, Dict]] - dictionary of flags used to generate args (used to populate messages). Default: - `SubjectList.DEFAULT_FLAGS`, which get initialized from `FastSurferCNN.utils.-parser_defaults.ALL_FLAGS` - `SubjectList.DEFAULT_FLAGS`, which get initialized from `FastSurferCNN.utils.-parser_defaults.ALL_FLAGS` There are three modes of operation: - There are three modes of operation: - If args has a non-empty csv_file attribute (cf. 
{csv_file[flag]} flag), read subjects from a subject list file - The subject listfile is a textfile with one subject per line, where each line can be an absolute or relative - path. If they are relative paths, they are interpreted as relative to args.in_dir, so args.in_dir needs to - be defined. Paths can either point to subject directories (file is path + the t1 image name in + If args has a non-empty csv_file attribute (cf. {csv_file[flag]} flag), + read subjects from a subject list file: The subject listfile is a + textfile with one subject per line, where each line can be an absolute + or relative path. If they are relative paths, they are interpreted as + relative to args.in_dir, so args.in_dir needs to be defined. Paths can + either point to subject directories (file is path + the t1 image name in args.orig_name) or directly to the t1 image. - Else, if args has a non-empty in_dir attribute (c.f. {in_dir[flag]} flag), list the folder contents of in_dir - The search pattern can be specified by the search_tag attribute of args (cf. {tag[flag]} flag), which is + Else, if args has a non-empty in_dir attribute (c.f. {in_dir[flag]} flag), + list the folder contents of in_dir: The search pattern can be specified + by the search_tag attribute of args (cf. {tag[flag]} flag), which is {tag[default]} (all files and folders) by default. - For containing objects that are folders, these folders are interpreted as subject directories and the t1 - image is loaded according to the (necessarily relative) {t1[flag]} (args.orig_name), which defaults to - {t1[default]}. The folder name is treated as the subject id, if no {sid[flag]} is passed (args.sid). - For the containing objects that are files, these files are interpreted are loaded as to-be analyzed data. - Finally, if an absolute path is specified with the orig_name attribute of args (cf. {t1[flag]}), only this - specific file is processed. - If args is passed without a sid attribute (cf. {sid[flag]}), subject ids are extracted from the subject details - (excluding potentially added relative paths). Suffixes can be removed from this by use of the remove_suffix - attribute of args (cf. {remove_suffix[flag]}) including file extensions or subfolders (e.g. `{tag[flag]} - */anat {remove_suffix[flag]} /anat` or `{tag[flag]} *_t1.nii.gz {remove_suffix[flag]} _t1.nii.gz`). + For containing objects that are folders, these folders are interpreted + as subject directories and the t1 image is loaded according to the + (necessarily relative) {t1[flag]} (args.orig_name), which defaults to + {t1[default]}. The folder name is treated as the subject id, if no + {sid[flag]} is passed (args.sid). + For the containing objects that are files, these files are interpreted + and loaded as to-be-analyzed data. + Finally, if an absolute path is specified with the orig_name attribute of + args (cf. {t1[flag]}), only this specific file is processed. + + Parameters + ---------- + args : SubjectDirectoryConfig + The namedtuple/Namespace object (object with attributes to define + parameters) with the following 'required' definitions. + - orig_name (str): the path to the input t1 file. + - conf_name (str): the path to the conformed t1 file. + - segfile (str): the path to the main output file. + - in_dir (str) or csv_file (str), if orig_name is not an absolute path. + If args is passed without a sid attribute (cf. {sid[flag]}), subject ids are + extracted from the subject details (excluding potentially added relative + paths).
Suffixes can be removed from this by use of the remove_suffix + attribute of args (cf. {remove_suffix[flag]}) including file extensions or + subfolders (e.g. `{tag[flag]} */anat {remove_suffix[flag]} /anat` or + `{tag[flag]} *_t1.nii.gz {remove_suffix[flag]} _t1.nii.gz`). + flags : dict[str, Dict], optional + Dictionary of flags used to generate args (used to populate messages). + Default: `SubjectList.DEFAULT_FLAGS`, which gets initialized from + `FastSurferCNN.utils.parser_defaults.ALL_FLAGS`. **assign : Raises ------ RuntimeError - For invalid configurations, e.g. no 'in_dir', 'csv_file', or absolute 'orig_name'. + For invalid configurations, e.g. no 'in_dir', 'csv_file', or absolute + 'orig_name'. RuntimeError When using {sid[flag]} with multiple subjects. - """ # populate _flags with DEFAULT_FLAGS self._flags = flags.copy() if flags is not None else {} @@ -749,13 +753,13 @@ def __init__(self, args, flags: Optional[Dict[str, Dict]] = None, **assign): and getattr(args, "csv_file", None) is None and not os.path.isabs(getattr(args, "orig_name", "undefined")) ): - raise RuntimeError( - ( - "One of the following three options has to be passed {in_dir[flag]}, {csv_file[flag]} " - "or {t1[flag]} with an absolute file path. Please specify the data input directory, " - "the subject list file or the full path to input volume" - ).format(**self._flags) + msg = ( + "One of the following three options has to be passed {in_dir[flag]}, " + "{csv_file[flag]} or {t1[flag]} with an absolute file path. Please " + "specify the data input directory, the subject list file or the full " + "path to input volume" ) + raise RuntimeError(msg.format(**self._flags)) assign.setdefault("segfile", "segfile") assign.setdefault("orig_name", "orig_name") assign.setdefault("conf_name", "conf_name") @@ -764,94 +768,99 @@ def __init__(self, args, flags: Optional[Dict[str, Dict]] = None, **assign): for subject_attribute, args_attribute in assign.items(): if not hasattr(args, args_attribute): raise ValueError( - f"You have defined {args_attribute} as a attribute of `args`via keyword argument to " - f"SubjectList.__init__ or {args_attribute} is required, but `args` does not have " - f"{args_attribute} as an attribute." + f"You have defined {args_attribute} as an attribute of `args` via " + f"keyword argument to SubjectList.__init__ or {args_attribute} is " + f"required, but `args` does not have {args_attribute} as an " + f"attribute." ) setattr(self, "_" + subject_attribute + "_", getattr(args, args_attribute)) self._out_segfile = getattr(self, "_segfile_", None) if self._out_segfile is None: raise RuntimeError( - f"The segmentation output file is not set, it should be either 'segfile' (which gets " - f"populated from args.segfile), or a keyword argument to __init__, e.g. " - f"`SubjectList(args, subseg='subseg_param', out_filename='subseg')`." + f"The segmentation output file is not set, it should be either " + f"'segfile' (which gets populated from args.segfile), or a keyword " + f"argument to __init__, e.g. `SubjectList(args, subseg='subseg_param', " + f"out_filename='subseg')`."
) # if out_dir is not set, fall back to in_dir by default self._out_dir = getattr(args, "out_dir", None) or getattr(args, "in_dir", None) if self._out_dir in [None, ""] and not os.path.isabs(self._out_segfile): - raise RuntimeError( - ( - "Please specify, where the segmentation output should be stored by either the " - "{sd[flag]} flag (output subject directory, this can be same as input directory) or an " - "absolute path to the {asegdkt_segfile[flag]} output segmentation volume." - ).format(**self._flags) + msg = ( + "Please specify, where the segmentation output should be stored by " + "either the {sd[flag]} flag (output subject directory, this can be " + "same as input directory) or an absolute path to the " + "{asegdkt_segfile[flag]} output segmentation volume." ) + raise RuntimeError(msg.format(**self._flags)) # 1. are we doing a csv file of subjects if getattr(args, "csv_file") is not None: with open(args.csv_file, "r") as s_dirs: - self._subjects = [line.strip() for line in s_dirs.readlines()] - if any(not os.path.isabs(d) for d in self._subjects): + self._subjects = [Path(line.strip()) for line in s_dirs.readlines()] + if any(not d.is_absolute() for d in self._subjects): msg = f"At least one path in {args.csv_file} was relative, but the " if getattr(args, "in_dir") is None: raise RuntimeError( - msg - + "in_dir was not in args (no {in_dir[flag]} flag).".format( - **self._flags + "{}in_dir was not in args (no {in_dir[flag]} flag).".format( + msg, **self._flags ) ) elif not os.path.isdir(args.in_dir): raise RuntimeError( msg + f"input directory {args.in_dir} does not exist." ) + base = Path(args.in_dir) self._subjects = [ - os.path.join(args.in_dir, d) if os.path.isabs(d) else d + base / d if not d.is_absolute() else d for d in self._subjects ] self._num_subjects = len(self._subjects) LOGGER.info( - f"Analyzing all {self._num_subjects} subjects from csv_file {args.csv_file}." + f"Analyzing all {self._num_subjects} subjects from csv_file " + f"{args.csv_file}." ) # 2. are we doing a single file (absolute path to the file) - elif os.path.isabs(self._orig_name_): + elif (orig_name := Path(self._orig_name_)).is_absolute(): LOGGER.info("Single subject with absolute file path for input.") - if not os.path.isfile(self._orig_name_): + if not orig_name.is_file(): raise RuntimeError( - f"The input file {self._orig_name_} does not exist (is not a file)." + f"The input file {orig_name} does not exist (is not a file)." ) if self._out_dir is None: sid = "" - if os.path.isabs(self._out_segfile): - # try to extract the subject directory from the absolute out filename by, containing folder is 'mri' - # or the subject id - out_dirname = os.path.dirname(self._out_segfile) - parent_dir = os.path.basename(out_dirname) + if (out_segfile := Path(self._out_segfile)).is_absolute(): + # try to extract the subject directory from the absolute out + # filename by, containing folder is 'mri' or the subject id + out_dirname = out_segfile.parent + parent_dir = out_dirname.name + msg = ( + f"No subjects directory specified, but the parent directory " + f"of the output file {out_segfile} is" + ) if parent_dir == "mri": LOGGER.info( - f"No subjects directory specified, but the parent directory of the output file " - f"{self._out_segfile} is 'mri', so we are assuming this is the 'mri' folder in " - f"the subject directory." 
- ) - self._out_dir, sid = os.path.split(os.path.dirname(out_dirname)) - self._out_segfile = os.path.join( - "mri", os.path.basename(self._out_segfile) + f"{msg} 'mri', so we are assuming this is the 'mri' folder " + f"in the subject directory." ) + self._out_dir = out_dirname.parent.parent + sid = out_dirname.parent.name + self._out_segfile = "mri/" + out_segfile.name elif parent_dir == getattr(args, "sid", ""): LOGGER.info( - f"No subjects directory specified, but the parent directory of the output file " - f"{self._out_segfile} is the subject id, so we are assuming this is the subject " - f"directory." + f"{msg} the subject id, so we are assuming this is the " + f"subject directory." ) - self._out_dir, sid = os.path.split(out_dirname) - self._out_segfile = os.path.basename(self._out_segfile) + self._out_dir = out_dirname.parent + sid = out_dirname.name + self._out_segfile = out_segfile.name def _not_abs(subj_attr): return not os.path.isabs(getattr(self, f"_{subj_attr}_")) - if getattr(args, "sid", "") in [None, ""]: + if getattr(args, "sid", "") in (None, ""): args.sid = sid elif getattr(args, "sid", "") != sid and any( map(_not_abs, self.__attr_assign.keys()) @@ -862,75 +871,79 @@ def _not_abs(subj_attr): if _not_abs(k) ] msg = ( - "Could not extract the subject id from the command line and the output file '{0}', while at " - "the same time, not all output files are absolute. Try passing the subjects directory in " + "Could not extract the subject id from the command line and " + "the output file '{0}', while at the same time, not all output " + "files are absolute. Try passing the subjects directory in " "args (c.f. {sd[flag]}), or absolute paths for {1}.".format( self._segfile_, ", ".join(relative_files), **self._flags ) ) raise RuntimeError(msg) - self._subjects = [self._orig_name_] + self._subjects = [Path(self._orig_name_)] self._num_subjects = 1 LOGGER.info(f"Analyzing single subject {self._orig_name_}") # 3. do we search in a directory elif getattr(args, "search_tag", None) is not None: search_tag = args.search_tag - if not os.path.isabs(search_tag) and getattr(args, "in_dir") is not None: - if not os.path.isdir(args.in_dir): - raise RuntimeError( - f"The input directory {args.in_dir} does not exist." - ) - search_tag = os.path.join(args.in_dir, search_tag) - where = f"in_dir {args.in_dir}" + _in_dir = getattr(args, "in_dir", None) + if not Path(search_tag).is_absolute() and _in_dir: + base = Path(_in_dir) + if not base.is_dir(): + raise RuntimeError(f"The input directory {base} does not exist.") + where = f"in_dir {base}" else: - where = f"the working directory {os.getcwd()}" - from glob import glob + base = Path.cwd() + where = f"the working directory {Path.cwd()}" - self._subjects = glob(search_tag) + self._subjects = list(base.glob(search_tag)) self._num_subjects = len(self._subjects) LOGGER.info( - f"Analyzing all {self._num_subjects} subjects from {where} with search pattern " - f"{search_tag}." + f"Analyzing all {self._num_subjects} subjects from {where} with search " + f"pattern {search_tag}." ) else: - raise RuntimeError( - "Could not identify how to find images to segment. Options are:\n1. Provide a text " - "file with one subject directory or image file per line via args.csv (cf. " - "{csv_file[flag]});\n2. specify an absolute path for relevant files, specifically the " - "t1 file via args.orig_name (cf. {t1[flag]}), but ideally also for expected output " - "files such as the segmentation output file,\n 3. 
provide a search pattern to search " - "for subject directories or images via args.search_tag (c.f. {tag[flag]}).\n Note also, " - "that the input directory (specified via {in_dir[flag]}) will be used as the base path " - "for relative file paths of input files.".format(**self._flags) + msg = ( + "Could not identify how to find images to segment. Options are:\n1. " + "Provide a text file with one subject directory or image file per line " + "via args.csv (cf. {csv_file[flag]});\n2. specify an absolute path for " + "relevant files, specifically the t1 file via args.orig_name (cf. " + "{t1[flag]}), but ideally also for expected output files such as the " + "segmentation output file,\n 3. provide a search pattern to search " + "for subject directories or images via args.search_tag (c.f. " + "{tag[flag]}).\n Note also, that the input directory (specified via " + "{in_dir[flag]}) will be used as the base path for relative file paths " + "of input files." ) + raise RuntimeError(msg.format(**self._flags)) self._remove_suffix = getattr(args, "remove_suffix", "") if self._num_subjects > 1: if getattr(args, "sid", "") not in ["", None]: - raise RuntimeError( - "The usage of args.sid (cf. {sid[flag]}) with multiple subjects is undefined.".format( - **self._flags - ) + msg = ( + "The usage of args.sid (cf. {sid[flag]}) with multiple subjects is " + "undefined." ) + raise RuntimeError(msg.format(**self._flags)) if self._remove_suffix == "": all_subject_files = self.are_all_subject_files() common_suffix = self.get_common_suffix() + msg = ( + "We detected that the subjects share the common suffix {0} in the " + "subject name. You can remove trailing parts of the filename such " + "as file extensions and/or other characters by passing this suffix " + "in args.remove_suffix (cf. {remove_suffix[flag]} , e.g. " + "{remove_suffix[flag]} '{0}'." + ) if all_subject_files and common_suffix != "": - LOGGER.info( - "We detected that the subjects share the common suffix {0} in the subject name. You " - "can remove trailing parts of the filename such as file extensions and/or other " - "characters by passing this suffix in args.remove_suffix (cf. {remove_suffix[flag]} " - ", e.g. {remove_suffix[flag]} '{0}'.".format( - common_suffix, **self._flags - ) + LOGGER.info(msg.format(common_suffix, **self._flags)) + if os.path.isabs(self._out_segfile): + raise RuntimeError( + f"An absolute path was passed for the output segmentation " + f"{self._out_segfile}, but more than one input image fits the " + f"input definition." ) - if os.path.isabs(self._out_segfile): - raise RuntimeError( - f"An absolute path was passed for the output segmentation {self._out_segfile}, " - f"but more than one input image fits the input definition." - ) self._sid = getattr(args, "sid", "") @@ -938,32 +951,35 @@ def _not_abs(subj_attr): @property def flags(self) -> Dict[str, Dict]: - """Give the flags. + """ + Give the flags. Returns ------- dict[str, dict] - Flags - + Flags. """ return self._flags def __len__(self) -> int: - """Give length of subject list. + """ + Give length of subject list. Returns ------- int - Number of subjects - + Number of subjects. """ return self._num_subjects def make_subjects_dir(self): - """Try to create the subject directory.""" + """ + Try to create the subject directory. + """ if self._out_dir is None: LOGGER.info( - "No Subjects directory found, absolute paths for filenames are required." + "No Subjects directory found, absolute paths for filenames are " + "required." 
) return @@ -973,19 +989,22 @@ def make_subjects_dir(self): LOGGER.info("Output directory does not exist. Creating it now...") os.makedirs(self._out_dir) - def __getitem__(self, item: Union[int, str]) -> SubjectDirectory: - """Return a SubjectDirectory object for the i-th subject (if item is an int) or for the subject with name/folder (if item is a str). + def __getitem__(self, item: int | str) -> SubjectDirectory: + """ + Return a SubjectDirectory object for the i-th subject (if item is an int) or for + the subject with name/folder (if item is a str). Parameters ---------- - item : Union[int, str] - [MISSING] + item : int, str + The index or name of the subject. + If integer, it is treated as an index and corresponding subject is returned. + If string, it is treated as the subject. Returns ------- SubjectDirectory - [MISSING] - + A SubjectDirectory object corresponding to the provided index or name. """ if isinstance(item, int): if item < 0 or item >= self._num_subjects: @@ -993,15 +1012,16 @@ def __getitem__(self, item: Union[int, str]) -> SubjectDirectory: f"The index {item} is out of bounds for the subject list." ) - # subject is always an absolute path (or relative to the working directory) ... of the input file + # subject is always an absolute path (or relative to the working directory) + # ... of the input file subject = self._subjects[item] sid = ( - os.path.basename(removesuffix(subject, self._remove_suffix)) + Path(str(subject).removesuffix(self._remove_suffix)).name if self._sid is None else self._sid ) elif isinstance(item, str): - subject = item + subject = Path(item) sid = item else: raise TypeError("Invalid type of the item, must be int or str.") @@ -1015,24 +1035,28 @@ def __getitem__(self, item: Union[int, str]) -> SubjectDirectory: } orig_name = ( subject - if os.path.isfile(subject) - else os.path.join(subject, self._orig_name_) + if subject.is_file() + else subject / self._orig_name_ ) return SubjectDirectory( - subject_dir=self._out_dir, id=sid, orig_name=orig_name, **subject_parameters + subject_dir=self._out_dir, + id=sid, + orig_name=orig_name, + **subject_parameters, ) def get_common_suffix(self) -> str: - """Find common suffix, if all entries in the subject list share a common suffix. + """ + Find common suffix, if all entries in the subject list share a common suffix. Returns ------- str - The suffix the entries share - + The suffix the entries share. """ suffix = self._subjects[0] - for subj in self._subjects[1:]: + for subject_path in self._subjects[1:]: + subj = str(subject_path) if subj.endswith(suffix): continue for i in range(1 - len(suffix), 1): @@ -1044,23 +1068,23 @@ def get_common_suffix(self) -> str: return suffix def are_all_subject_files(self): - """Check if all entries in subjects are actually files. - - This is performed asynchronously internally """ - from asyncio import run, gather - - async def is_file(path): - return os.path.isfile(path) + Check if all entries in subjects are actually files. - async def check_files(files): - return await gather(*[is_file(s) for s in files]) + This is performed asynchronously internally. + """ + from concurrent.futures import ThreadPoolExecutor - return all(run(check_files(self._subjects))) + def is_file(p: Path): + return p.is_file() + with ThreadPoolExecutor(len(self._subjects)) as pool: + return all(pool.map(is_file, self._subjects)) -class NoParallelExecutor(Executor): - """Represent a serial executor.""" +class SerialExecutor(Executor): + """ + Represent a serial executor. 
+ """ def map( self, @@ -1069,44 +1093,45 @@ def map( timeout: Optional[float] = None, chunksize: int = -1, ) -> Iterator[_T]: - """[MISSING]. + """ + The map function. Parameters ---------- fn : Callable[..., _T] - [MISSING] + A callable function to be applied to the items in the iterables. *iterables : Iterable[Any] - [MISSING] + One or more iterable objects. timeout : Optional[float] - [MISSING] (Default value = None) + Maximum number of seconds to wait for a result. Default is None. chunksize : int - [MISSING] (Default value = -1) + The size of the chunks, default value is -1. Returns ------- Iterator[_T] - [MISSING] - + An iterator that yields the results of applying 'fn' to the items of + 'iterables'. """ return map(fn, *iterables) def submit(self, __fn: Callable[..., _T], *args, **kwargs) -> "Future[_T]": - """[MISSING]. + """ + A callable function that returns a Future representing the result. Parameters ---------- __fn : Callable[..., _T] - [MISSING] + A callable function to be executed. *args : - [MISSING] + Potential arguments to be passed to the callable function. **kwargs : - [MISSING] + Keyword arguments to be passed to the callable function. Returns ------- "Future[_T]" - [MISSING] - + A Future object representing the execution result of the callable function. """ f = Future() try: diff --git a/FastSurferCNN/utils/dataclasses.py b/FastSurferCNN/utils/dataclasses.py new file mode 100644 index 00000000..c78abdfc --- /dev/null +++ b/FastSurferCNN/utils/dataclasses.py @@ -0,0 +1,160 @@ +from typing import Mapping, TypeVar, overload, Any, Callable, Optional + +from dataclasses import ( + field as _field, + asdict, + astuple, + dataclass, + fields, + Field, + FrozenInstanceError, + is_dataclass, + InitVar, + make_dataclass, + MISSING, + KW_ONLY, + replace, +) + +__all__ = [ + "field", + "asdict", + "astuple", + "dataclass", + "fields", + "Field", + "FrozenInstanceError", + "get_field", + "is_dataclass", + "InitVar", + "make_dataclass", + "MISSING", + "KW_ONLY", + "replace", +] + +_T = TypeVar("_T") + + +@overload +def field( + *, + default: _T, + help: str = "", + flags: tuple[str] = (), + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = ..., +) -> _T: ... + + +@overload +def field( + *, + default_factory: Callable[[], _T], + help: str = "", + flags: tuple[str] = (), + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = ..., +) -> _T: ... + + +@overload +def field( + *, + help: str = "", + flags: tuple[str] = (), + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = ..., +) -> Any: ... + + +def field( + *, + default: _T = MISSING, + default_factory: Callable[[], _T] = MISSING, + help: str = "", + flags: tuple[str] = (), + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = False, +) -> _T: + """ + Extends :py:`dataclasses.field` to adds `help` and `flags` to the metadata. + + Parameters + ---------- + help : str, default="" + A help string to be used in argparse description of parameters. + flags : tuple of str, default=() + A list of default flags to add for this attribute. + + Returns + ------- + When used in dataclasses, returns . 
+ + See Also + -------- + :py:func:`dataclasses.field` + """ + if isinstance(metadata, Mapping): + metadata = dict(metadata) + elif metadata is None: + metadata = {} + else: + raise TypeError(f"Invalid type of metadata, must be a Mapping!") + if help: + if not isinstance(help, str): + raise TypeError("help must be a str!") + metadata["help"] = help + if flags: + if not isinstance(flags, tuple): + raise TypeError("flags must be a tuple!") + metadata["flags"] = flags + + kwargs = dict(init=init, repr=repr, hash=hash, compare=compare, kw_only=kw_only) + if default is not MISSING: + kwargs["default"] = default + if default_factory is not MISSING: + kwargs["default_factory"] = default_factory + return _field(**kwargs, metadata=metadata) + + +def get_field(dc, fieldname: str) -> Field | None: + """ + Return a specific Field object associated with a dataclass class or object. + + Parameters + ---------- + dc : dataclass, type[dataclass] + The dataclass containing the field. + fieldname : str + The name of the field. + + Returns + ------- + Field, None + The Field object associated with `fieldname` or None if the field does not exist. + + See Also + -------- + :py:`dataclasses.fields` + """ + for field in fields(dc): + if field.name == fieldname: + return field + return None diff --git a/FastSurferCNN/utils/load_config.py b/FastSurferCNN/utils/load_config.py index b0753796..bee0e987 100644 --- a/FastSurferCNN/utils/load_config.py +++ b/FastSurferCNN/utils/load_config.py @@ -20,18 +20,19 @@ def get_config(args: argparse.Namespace) -> yacs.config.CfgNode: - """Given the arguments, load and initialize the configs. + """ + Given the arguments, load and initialize the configs. Parameters ---------- args : argparse.Namespace - Object holding args + Object holding args. Returns ------- yacs.config.CfgNode - Configuration node - + Configuration node. + """ # Setup cfg. cfg = get_cfg_defaults() @@ -54,18 +55,18 @@ def get_config(args: argparse.Namespace) -> yacs.config.CfgNode: def load_config(cfg_file: str) -> yacs.config.CfgNode: - """Load a yaml config file. + """ + Load a yaml config file. Parameters ---------- cfg_file : str - Configuration filepath + Configuration filepath. Returns ------- yacs.config.CfgNode - configuration node - + Configuration node. """ # setup base cfg = get_cfg_defaults() diff --git a/FastSurferCNN/utils/logging.py b/FastSurferCNN/utils/logging.py index 9ac71c66..3b98752b 100644 --- a/FastSurferCNN/utils/logging.py +++ b/FastSurferCNN/utils/logging.py @@ -14,37 +14,30 @@ # IMPORTS from logging import * -from logging import ( - getLogger as get_logger, - StreamHandler, - FileHandler, - INFO, - DEBUG, - getLogger, - basicConfig, -) -from os import path, makedirs +from logging import DEBUG, INFO, FileHandler, StreamHandler, basicConfig +from logging import getLogger +from logging import getLogger as get_logger +from pathlib import Path as _Path from sys import stdout as _stdout -def setup_logging(log_file_path: str): - """Set up the logging. +def setup_logging(log_file_path: _Path | str): + """ + Set up the logging. Parameters ---------- - log_file_path : str - Path to the logfile - + log_file_path : Path, str + Path to the logfile. """ # Set up logging format. 
_FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" handlers = [StreamHandler(_stdout)] if log_file_path: - log_dir_path = path.dirname(log_file_path) - log_file_name = path.basename(log_file_path) - if not path.exists(log_dir_path): - makedirs(log_dir_path) + if not isinstance(log_file_path, _Path): + log_file_path = _Path(log_file_path) + log_file_path.parent.mkdir(parents=True, exist_ok=True) handlers.append(FileHandler(filename=log_file_path, mode="a")) diff --git a/FastSurferCNN/utils/lr_scheduler.py b/FastSurferCNN/utils/lr_scheduler.py index 01a09c31..490a1080 100644 --- a/FastSurferCNN/utils/lr_scheduler.py +++ b/FastSurferCNN/utils/lr_scheduler.py @@ -11,36 +11,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import torch.optim + # IMPORTS import torch.optim.lr_scheduler as scheduler -from typing import Union - import yacs.config def get_lr_scheduler( - optimzer: torch.optim.Optimizer, - cfg: yacs.config.CfgNode + optimzer: torch.optim.Optimizer, cfg: yacs.config.CfgNode ) -> Union[None, scheduler.StepLR, scheduler.CosineAnnealingWarmRestarts]: - """Give a schedular for left-right scheduling. + """ + Give a schedular for left-right scheduling. Parameters ---------- optimzer : torch.optim.Optimizer - Optimizer for the scheduler + Optimizer for the scheduler. cfg : yacs.config.CfgNode - configuration node + Configuration node. Returns ------- - [MISSING] + Union[None, scheduler.StepLR, scheduler.CosineAnnealingWarmRestarts] + A learning rate scheduler configured according to `cfg`, or None if no scheduling is required. Raises ------ ValueError - lr scheduler is not supported - + lr scheduler is not supported. """ scheduler_type = cfg.OPTIMIZER.LR_SCHEDULER if scheduler_type == "step_lr": diff --git a/FastSurferCNN/utils/mapper.py b/FastSurferCNN/utils/mapper.py index 611776cf..4603b693 100644 --- a/FastSurferCNN/utils/mapper.py +++ b/FastSurferCNN/utils/mapper.py @@ -19,40 +19,39 @@ """ import json -from functools import partialmethod, partial, reduce -from numbers import Number, Integral +import os.path +from functools import partial, partialmethod, reduce +from numbers import Integral, Number from typing import ( + Any, Callable, - TypeVar, + Collection, + Container, + Dict, Generic, + Hashable, + Iterable, + Iterator, + List, + Literal, Mapping, - Tuple, Optional, - Union, - Set, - overload, - cast, - Hashable, - Dict, Sequence, - List, - Iterable, - Any, + Set, TextIO, - Literal, - Container, - Collection, - Iterator, + Tuple, + TypeVar, + Union, + cast, + overload, ) -import os.path - import numpy as np -from numpy import typing as npt +import pandas import torch -from matplotlib.cm import get_cmap +from matplotlib.pyplot import get_cmap from matplotlib.colors import Colormap -import pandas +from numpy import typing as npt from FastSurferCNN.utils import logging @@ -80,18 +79,18 @@ def is_int(a_object) -> bool: - """Check whether the array_or_tensor is an integer. + """ + Check whether the array_or_tensor is an integer. Parameters ---------- - a_object : - The object to check its type + a_object : Any + The object to check its type. Returns ------- bool - Whether the data type of the object is int or not - + Whether the data type of the object is int or not. 
""" from collections.abc import Collection @@ -108,20 +107,20 @@ def is_int(a_object) -> bool: def to_same_type(data, type_hint: AT) -> AT: - """Convert data to the same type as type_hint. + """ + Convert data to the same type as type_hint. Parameters ---------- - data : - the data to convert + data : Any + The data to convert. type_hint : AT - hint for the data type + Hint for the data type. Returns ------- AT - [MISSING] - + Data converted to the same type as specified by type_hint. """ if torch.is_tensor(type_hint) and not torch.is_tensor(data): return torch.as_tensor(data, dtype=type_hint.dtype, device=type_hint.device) @@ -134,7 +133,9 @@ def to_same_type(data, type_hint: AT) -> AT: class Mapper(Generic[KT, VT]): - """Map from one label space to a generic 'label'-space.""" + """ + Map from one label space to a generic 'label'-space. + """ _map_dict: Dict[KT, npt.NDArray[VT]] _label_shape: Tuple[int, ...] @@ -146,15 +147,15 @@ class Mapper(Generic[KT, VT]): def __init__( self, mappings: Mapping[KT, Union[VT, npt.NDArray[VT]]], name: str = "undefined" ): - """Construct `Mapper` object from a mappings dictionary. + """ + Construct `Mapper` object from a mappings dictionary. Parameters ---------- mappings : Mapping[KT, Union[VT, npt.NDArray[VT]]] - a dictionary of labels from -> to mappings + A dictionary of labels from -> to mappings. name : str - name for messages (default: "undefined"). - + Name for messages (default: "undefined"). """ if len(mappings) == 0: raise RuntimeError("The mappings object is empty.") @@ -182,27 +183,37 @@ def __init__( @property def name(self) -> str: - """Return the name of the mapper.""" + """ + Return the name of the mapper. + """ return self._name @name.setter def name(self, name: str): - """Set the name.""" + """ + Set the name. + """ self._name = name @property def source_space(self) -> Set[KT]: - """Return a set of labels the mapper accepts.""" + """ + Return a set of labels the mapper accepts. + """ return set(self._map_dict.keys()) @property def target_space(self) -> Collection[VT]: - """Return the set of labels the mapper converts to as a set of python-natives (if possible), arrays expanded to tuples.""" + """ + Return the set of labels the mapper converts to as a set of python-natives (if possible), arrays expanded to tuples. + """ return self._map_dict.values() @property def max_label(self) -> int: - """Return the max label.""" + """ + Return the max label. + """ if self._max_label is None: raise RuntimeError("max_label is only valid for integer keys.") return self._max_label @@ -210,20 +221,20 @@ def max_label(self) -> int: def update( self, other: "Mapper[KT, VT]", overwrite: bool = True ) -> "Mapper[KT, VT]": - """Merge another map into this mapper. + """ + Merge another map into this mapper. Parameters ---------- - other : "Mapper[KT, VT]" - [MISSING] - overwrite : bool - [MISSING] (Default value = True) + other : Mapper[KT, VT] + The other Mapper object whose key-value pairs are to be added to this Mapper object. + overwrite : bool, default=True + Flag to overwrite value if key already exists in Mapper object (Default value = True). Returns ------- - "Mapper[KT, VT]" - [MISSING] - + Mapper[KT, VT] + Mapper after merge. """ for key, value in iter(other): if overwrite or key not in self._map_dict: @@ -239,21 +250,21 @@ def update( __iadd__ = partialmethod(update, overwrite=True) def map(self, image: AT, out: Optional[AT] = None) -> AT: - """Forward map the labels from prediction to internal space. 
+ """ + Forward map the labels from prediction to internal space. Parameters ---------- image : AT - data to map to internal space + Data to map to internal space. out : Optional[AT] - output array for performance - Returns an `numpy.ndarray` with mapped values. (Default value = None) + Output array for performance. + Returns an `numpy.ndarray` with mapped values. (Default value = None). Returns ------- AT - [MISSING] - + Data after being mapped to the internal space. """ # torch sparse tensors can't index with images # self._map = _b.torch.sparse_coo_tensor(src_labels, labels, (self._max_label,) + self._label_shape) @@ -327,20 +338,20 @@ def map(self, image: AT, out: Optional[AT] = None) -> AT: return to_same_type(mapped, type_hint=image) def _map_py(self, image: AT, out: Optional[AT] = None) -> AT: - """Map internally by python, for example for strings. + """ + Map internally by python, for example for strings. Parameters ---------- image : AT - image data + Image data. out : Optional[AT] - output data. Optional (Default value = None) + Output data. Optional (Default value = None). Returns ------- AT - [MISSING] - + Image data after being mapped. """ out_type = image if out is None else out if out is None: @@ -369,28 +380,30 @@ def _internal_map(img, o): def __call__( self, image: AT, label_image: Union[npt.NDArray[KT], torch.Tensor] ) -> Tuple[AT, Union[npt.NDArray, torch.Tensor]]: - """Transform a dataset from prediction to internal space for sets of image and segmentation. + """ + Transform a dataset from prediction to internal space for sets of image and segmentation. Parameters ---------- image : AT - image - will stay same + Image - will stay same. label_image : Union[npt.NDArray[KT], torch.Tensor] - data to map to internal space + Data to map to internal space Returns two `numpy.ndarray`s with image and mapped values. Returns ------- image : image - image + Image. Union[npt.NDArray, torch.Tensor] - mapped values - + Mapped values. """ return image, self.map(label_image) def reversed_dict(self) -> Mapping[VT, KT]: - """Map dictionary from the target space to the source space.""" + """ + Map dictionary from the target space to the source space. + """ rev_mappings = {} for src in sorted(self.source_space): a = self._map_dict[src] @@ -402,29 +415,40 @@ def reversed_dict(self) -> Mapping[VT, KT]: return rev_mappings def __reversed__(self) -> "Mapper[VT, KT]": - """Reverse map the original transformation (with non-bijective mappings mapping to the lower key).""" + """ + Reverse map the original transformation (with non-bijective mappings mapping to the lower key). + """ return Mapper(self.reversed_dict(), name="reverse-" + self.name) def is_bijective(self) -> bool: - """Return, whether the Mapper is bijective.""" + """ + Return, whether the Mapper is bijective. + """ return len(self.source_space) == len(self.target_space) def __getitem__(self, item: KT) -> VT: - """Return the value of the item.""" + """ + Return the value of the item. + """ return self._map_dict[item] def __iter__(self) -> Iterator[Tuple[KT, VT]]: - """[MISSING].""" + """ + Create an iterator for the Mapper object. + """ return iter(self._map_dict.items()) def __contains__(self, item: KT) -> bool: - """Check whether the mapping contains the item.""" + """ + Check whether the mapping contains the item. + """ return self._map_dict.__contains__(item) def chain( self, other_mapper: "Mapper[VT, T_OtherValue]" ) -> "Mapper[KT, T_OtherValue]": - """Chain the current mapper with the `other_mapper`. 
+ """ + Chain the current mapper with the `other_mapper`. This effectively is an optimization to first applying this mapper and then applying the `other_mapper`. @@ -437,23 +461,24 @@ def chain( Returns ------- Mapper : "Mapper[KT, T_OtherValue]" - A mapper mapping from the input space of this mapper to the target-space of the `other_mapper`. - + A mapper mapping from the input space of this mapper to the target-space of + the `other_mapper`. """ target_space = list(self.target_space) is_target_set = [not isinstance(t, Hashable) for t in target_space] if any(is_target_set): index = is_target_set.index(True) raise ValueError( - f"The target space must be hashable, but {is_target_set.count(True)} values are not " - f"hashable, for example {index}: {target_space[index]}." + f"The target space must be hashable, but {is_target_set.count(True)} " + f"values are not hashable, for example {index}: {target_space[index]}." ) target_space = set(target_space) if not target_space <= other_mapper.source_space: - # test whether every element in self.target_space is also in other_mapper.source_space + # test whether every element in self.target_space is also in + # other_mapper.source_space raise ValueError( - f"The first set ({self.name}) maps to the following keys, that the second mapper " - f"({other_mapper.name}) does not map from:\n " + f"The first set ({self.name}) maps to the following keys, that the " + f"second mapper ({other_mapper.name}) does not map from:\n " + ", ".join(f"'{v}'" for v in target_space - other_mapper.source_space) ) return Mapper( @@ -472,35 +497,35 @@ def make_classmapper( compress_out_space: bool = False, name: str = "undefined", ) -> "Mapper[int, int]": - """Map from one label space (int) to another (also int) using a mappings function. + """ + Map from one label space (int) to another (also int) using a mappings function. Can also be used as a transform. - + Creates a :class:`Mapper` object from a mappings dictionary and a list of labels to keep. Parameters ---------- mappings : Dict[int, int] - a dictionary of labels from -> to mappings + A dictionary of labels from -> to mappings. keep_labels : Sequence[int] - a list of classes to keep after mapping, where all not included classes are not changed - (default: empty) + A list of classes to keep after mapping, where all not included classes are not changed + (default: empty). compress_out_space : bool - whether to reassign labels to reduce the maximum label (default: False) + Whether to reassign labels to reduce the maximum label (default: False). name : str - name for messages (default: "undefined"). + Mame for messages (default: "undefined"). Returns ------- "Mapper[int, int]" - [MISSING] + A Mapper object that provides a mapping from one label space to another. Raises ------ ValueError If keep_labels contains an entry > 65535. - """ if any(v not in keep_labels for v in mappings.values()): mappings.update(dict((k, k) for k in keep_labels)) @@ -525,7 +550,9 @@ def _map_logits( out: Optional[AT] = None, mode: Literal["logit", "prob"] = "logit", ) -> AT: - """Map logits or probabilities with the Mapper.""" + """ + Map logits or probabilities with the Mapper. 
+ """ if not is_int(self.source_space) or not is_int(self.target_space): raise ValueError("map_logits/map_probs requires a mapping from int to int.") @@ -601,7 +628,9 @@ def _map_logits( class ColorLookupTable(Generic[KT]): - """This class provides utility in creating color palettes from colormaps.""" + """ + This class provides utility in creating color palettes from colormaps. + """ _color_palette: Optional[npt.NDArray[float]] _colormap: Union[str, Colormap, ColormapGenerator] @@ -615,22 +644,22 @@ def __init__( colormap: Union[str, Colormap, ColormapGenerator] = "gist_ncar", name: Optional[str] = None, ): - """Construct a LookupTable object. + """ + Construct a LookupTable object. Parameters ---------- classes : Optional[Iterable[KT]] - Iterable of the classes. (Default value = None) + Iterable of the classes. (Default value = None). color_palette : Union[Dict[KT, npt.ArrayLike], npt.ArrayLike], Optional colors associated with each class, either indexed by a dictionary (class -> Color) or by the - order of classes in classes (default: None). (Default value = None) + order of classes in classes (default: None). (Default value = None). colormap : Union[str, Colormap, ColormapGenerator] Alternative to color_palette, uses a colormap to generate a color_palette automatically. Colormap can be string, matplotlib.Colormap or a function (num_classes -> NDArray of shape (num_classes, 3 or 4)) (default: 'gist_ncar'). name : Optional[str] - name for messages (default: "unnamed lookup table"). - + Name for messages (default: "unnamed lookup table"). """ self._name = "unnamed lookup table" if name is None else name @@ -648,30 +677,36 @@ def __init__( @property def name(self) -> str: - """Return the name of the mapper.""" + """ + Return the name of the mapper. + """ return self._name @name.setter def name(self, name: str): - """Set the name.""" + """ + Set the name. + """ self._name = name @property def classes(self) -> Optional[List[KT]]: - """Return the classes.""" + """ + Return the classes. + """ return self._classes @classes.setter def classes(self, classes: Optional[Iterable[KT]]): - """Set the classes and generates a color palette for the given classes. - + """ + Set the classes and generates a color palette for the given classes. + Will override a manually set color_palette. Parameters ---------- classes : Optional[Iterable[KT]] Iterable of the classes. - """ if classes is None: # resetting the classes @@ -684,14 +719,18 @@ def classes(self, classes: Optional[Iterable[KT]]): @property def color_palette(self) -> Optional[npt.NDArray[float]]: - """Return the color palette if it exists.""" + """ + Return the color palette if it exists. + """ return self._color_palette @color_palette.setter def color_palette( self, color_palette: Union[Dict[KT, npt.ArrayLike], npt.ArrayLike, None] ): - """Set (or reset) the color palette of the LookupTable.""" + """ + Set (or reset) the color palette of the LookupTable. + """ if color_palette is None: self._color_palette = None else: @@ -713,18 +752,18 @@ def color_palette( self._color_palette = color_palette def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any]: - """Return index, key, colors and additional values for the key. + """ + Return index, key, colors and additional values for the key. Parameters ---------- key : KT - [MISSING] + The key for which the information is to be retrieved. Raises ------- ValueError - If key is not in _classes - + If key is not in _classes. 
""" index = self._classes.index(key) return self.getitem_by_index(index) @@ -732,12 +771,16 @@ def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any] def getitem_by_index( self, index: int ) -> Tuple[int, KT, Tuple[int, int, int, int], Any]: - """Return index, key, colors and additional values for the key.""" + """ + Return index, key, colors and additional values for the key. + """ color = self.get_color_by_index(index, 255) return index, self._classes[index], color, None def get_color_by_index(self, index: int, base: NT = 1.0) -> Tuple[NT, NT, NT, NT]: - """Return the color (r, g, b, a) tuple associated with the index in the passed base.""" + """ + Return the color (r, g, b, a) tuple associated with the index in the passed base. + """ if self._color_palette is None: raise RuntimeError("No color_palette set") base_type = type(base) @@ -755,7 +798,9 @@ def get_color_by_index(self, index: int, base: NT = 1.0) -> Tuple[NT, NT, NT, NT return color def colormap(self) -> Mapper[KT, ColorTuple]: - """[MISSING].""" + """ + Generate a Mapper object that maps classes to their corresponding colors. + """ if self._color_palette is None: raise RuntimeError("No color_palette set") return Mapper( @@ -763,7 +808,8 @@ def colormap(self) -> Mapper[KT, ColorTuple]: ) def labelname2index(self) -> Mapper[KT, int]: - """Return a mapping between the key and the (consecutive) index it is associated with. + """ + Return a mapping between the key and the (consecutive) index it is associated with. This is the inverse of ColorLookupTable.classes. """ @@ -773,19 +819,24 @@ def labelname2index(self) -> Mapper[KT, int]: ) def labelname2id(self) -> Mapper[KT, Any]: - """Return a mapping between the key and the value it is associated with. + """ + Return a mapping between the key and the value it is associated with. + + Mapper[KT, Any] + Not implemented in the base class. Raises ------ RuntimeError If no value is associated. - """ raise RuntimeError("The base class keeps no ids (only indexes).") class JsonColorLookupTable(ColorLookupTable[KT]): - """[MISSING].""" + """ + This class extends the ColorLookupTable to handle JSON data. + """ _data: Any @@ -796,7 +847,8 @@ def __init__( colormap: Union[str, Colormap, ColormapGenerator] = "gist_ncar", name: Optional[str] = None, ) -> None: - """Construct a JsonLookupTable object from `file_or_buffer` passed. + """ + Construct a JsonLookupTable object from `file_or_buffer` passed. Parameters ---------- @@ -810,8 +862,7 @@ def __init__( can be string, matplotlib.Colormap or a function (num_classes -> NDArray of shape (num_classes, 3 or 4)) (default: 'gist_ncar'). name : Optional[str] - name for messages (default: fallback to file_or_buffer, if possible). - + Name for messages (default: fallback to file_or_buffer, if possible). """ if isinstance(file_or_buffer, str) and file_or_buffer.lstrip().startswith("{"): self._data = json.loads(file_or_buffer) @@ -853,7 +904,9 @@ def __init__( ) def _get_labels(self) -> Union[Dict[KT, Any], Iterable[KT]]: - """Return labels.""" + """ + Return labels. + """ return ( self._data["labels"] if isinstance(self._data, dict) and "labels" in self._data @@ -861,12 +914,16 @@ def _get_labels(self) -> Union[Dict[KT, Any], Iterable[KT]]: ) def dataframe(self) -> pandas.DataFrame: - """[MISSING].""" + """ + Converts the labels from the internal data dictionary to a pandas DataFrame. 
+ """ if isinstance(self._data, dict) and "labels" in self._data: return pandas.DataFrame.from_dict(self._data["labels"]) def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any]: - """Index by the index position, unless either key or value are int.""" + """ + Index by the index position, unless either key or value are int. + """ labels = self._get_labels() index, key, color, _other = super(JsonColorLookupTable, self).__getitem__(key) if isinstance(labels, dict): @@ -874,18 +931,18 @@ def __getitem__(self, key: KT) -> Tuple[int, KT, Tuple[int, int, int, int], Any] return index, key, color, _other def labelname2id(self) -> Mapper[KT, Any]: - """Return a mapping between the key and the value it is associated with. + """ + Return a mapping between the key and the value it is associated with. Returns ------- Mapper[KT, Any] - [MISSING] + A Mapper object that provides a mapping between label names (keys) and their corresponding IDs (values). Raises ------ RuntimeError If no value is associated. - """ labels = self._get_labels() if not isinstance(labels, dict): @@ -896,7 +953,9 @@ def labelname2id(self) -> Mapper[KT, Any]: class TSVLookupTable(ColorLookupTable[str]): - """[MISSING].""" + """ + This class extends the ColorLookupTable to handle TSV (Tab Separated Values) data. + """ _data: pandas.DataFrame @@ -907,7 +966,8 @@ def __init__( header: bool = False, add_background: bool = True, ) -> None: - """Create a CSVLookupTable object from `file_or_buffer` passed. + """ + Create a CSVLookupTable object from `file_or_buffer` passed. Parameters ---------- @@ -915,12 +975,11 @@ def __init__( A `pandas`-compatible object to read from. Refer to :func:`pandas.read_csv` for additional documentation. name : str, Optional - name for messages (default: fallback to file_or_buffer, if possible). + Name for messages (default: fallback to file_or_buffer, if possible). header : bool - whether the TSV file has a header line (default: False). + Whether the TSV file has a header line (default: False). add_background : bool - whether to add a label for background (default: True) - + Whether to add a label for background (default: True). """ if name is None: if isinstance(file_or_buffer, str): @@ -942,7 +1001,7 @@ def __init__( self._data = pandas.read_csv( file_or_buffer, - delim_whitespace=True, + sep='\s+', index_col=0, skip_blank_lines=True, comment="#", @@ -967,45 +1026,49 @@ def __init__( def getitem_by_index( self, index: int ) -> Tuple[int, str, Tuple[int, int, int, int], int]: - """Find the Entry associated by a No. + """ + Find the Entry associated by a No. Parameters ---------- index : int - the index + The index Returns a tuple of the index, the label, and a tuple of the RGBA color label. Returns ------- index : int - [MISSING] + The index of the entry. key : str - [MISSING] + The label name associated with the entry. color : Tuple[int, int, int, int] - [MISSING] + The RGBA color label associated with the entry. int - [MISSING] - + The data index associated with the entry. """ index, key, color, _ = super(TSVLookupTable, self).getitem_by_index(index) return index, key, color, self._data.iloc[index].name def dataframe(self) -> pandas.DataFrame: - """Return the raw panda data object.""" + """ + Return the raw panda data object. + """ return self._data def labelname2id(self) -> Mapper[KT, Any]: - """Return a Mapper between the key and the value it is associated with. + """ + Return a Mapper between the key and the value it is associated with. 
Returns ------- Mapper[KT, Any] + A Mapper object that links keys to their corresponding values based on the + class and data index. Raises ------ RuntimeError If no value is associated. - """ return Mapper( dict(zip(self._classes, self._data.index)), name="value-" + self.name diff --git a/FastSurferCNN/utils/meters.py b/FastSurferCNN/utils/meters.py index b1688df3..e0a7b902 100644 --- a/FastSurferCNN/utils/meters.py +++ b/FastSurferCNN/utils/meters.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Any +from typing import Any, Optional + +import matplotlib.pyplot as plt # IMPORTS import numpy as np -import matplotlib.pyplot as plt import torch import yacs.config @@ -23,12 +24,13 @@ from FastSurferCNN.utils.metrics import DiceScore from FastSurferCNN.utils.misc import plot_confusion_matrix - logger = logging.getLogger(__name__) class Meter: - """[MISSING].""" + """ + Meter class to keep track of the losses and scores during training and validation. + """ def __init__( self, @@ -39,28 +41,29 @@ def __init__( total_epoch: Optional[int] = None, class_names: Optional[Any] = None, device: Optional[Any] = None, - writer: Optional[Any] = None, + writer: Optional[Any] = None, ): - """Construct a Meter object. + """ + Construct a Meter object. Parameters ---------- cfg - [MISSING] + Configuration Node. mode - [MISSING] + Meter mode (Train or Val). global_step - [MISSING] + Global step. total_iter - [MISSING] + Total iterations (Default value = None). total_epoch - [MISSING] + Total epochs (Default value = None). class_names - [MISSING] + Class names (Default value = None). device - [MISSING] + Device (Default value = None). writer - [MISSING] + Writer (Default value = None). """ self._cfg = cfg @@ -78,37 +81,45 @@ def __init__( self.total_epochs = total_epoch def reset(self): - """Reset bach losses and dice scores.""" + """ + Reset bach losses and dice scores. + """ self.batch_losses = [] self.dice_score.reset() def enable_confusion_mat(self): - """[MISSING].""" + """ + Enable confusion matrix. + """ self.confusion_mat = True def disable_confusion_mat(self): - """[MISSING].""" + """ + Disable confusion matrix. + """ self.confusion_mat = False def update_stats(self, pred, labels, batch_loss): - """[MISSING].""" + """ + Update the statistics. + """ self.dice_score.update((pred, labels), self.confusion_mat) self.batch_losses.append(batch_loss.item()) def write_summary(self, loss_total, lr=None, loss_ce=None, loss_dice=None): - """Write a summary of the losses and scores. + """ + Write a summary of the losses and scores. Parameters ---------- - loss_total : - [MISSING] - lr : - [MISSING] (Default value = None) - loss_ce : - [MISSING] (Default value = None) - loss_dice : - [MISSING] (Default value = None) - + loss_total : torch.Tensor + Total loss. + lr : default = None + Learning rate (Default value = None). + loss_ce : default = None + Cross entropy loss (Default value = None). + loss_dice : default = None + Dice loss (Default value = None). """ self.writer.add_scalar( f"{self.mode}/total_loss", loss_total.item(), self.global_iter @@ -127,15 +138,15 @@ def write_summary(self, loss_total, lr=None, loss_ce=None, loss_dice=None): self.global_iter += 1 def log_iter(self, cur_iter: int, cur_epoch: int): - """Log the current iteration. + """ + Log the current iteration. 
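A hedged sketch of how a training loop might drive the `Meter` documented above. `cfg`, `writer`, `train_loader`, `batches`, and the per-batch tensors are placeholders for the usual yacs config node, a TensorBoard `SummaryWriter`, and the training data; none of them come from this diff.

```python
# Sketch only: cfg, writer, train_loader, class_names, num_epochs, cur_epoch and
# the (pred, labels, batch_loss) tuples are assumed to be provided by the trainer.
from FastSurferCNN.utils.meters import Meter

meter = Meter(
    cfg, mode="Train", global_step=0,
    total_iter=len(train_loader), total_epoch=num_epochs,
    class_names=class_names, device="cuda", writer=writer,
)
meter.reset()
for cur_iter, (pred, labels, batch_loss) in enumerate(batches):
    meter.update_stats(pred, labels, batch_loss)  # accumulate Dice and loss
    meter.write_summary(batch_loss)               # push scalars to the writer
    meter.log_iter(cur_iter, cur_epoch)           # honours cfg.TRAIN.LOG_INTERVAL
meter.log_epoch(cur_epoch)                        # mean Dice score for the epoch
```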
Parameters ---------- cur_iter : int - current iteration + Current iteration. cur_epoch : int - current epoch - + Current epoch. """ if (cur_iter + 1) % self._cfg.TRAIN.LOG_INTERVAL == 0: logger.info( @@ -150,13 +161,13 @@ def log_iter(self, cur_iter: int, cur_epoch: int): ) def log_epoch(self, cur_epoch: int): - """Log the current epoch. + """ + Log the current epoch. Parameters ---------- cur_epoch : int - current epoch - + Current epoch. """ dice_score = self.dice_score.compute_dsc() self.writer.add_scalar(f"{self.mode}/mean_dice_score", dice_score, cur_epoch) diff --git a/FastSurferCNN/utils/metrics.py b/FastSurferCNN/utils/metrics.py index db6844c9..0b0cadf2 100644 --- a/FastSurferCNN/utils/metrics.py +++ b/FastSurferCNN/utils/metrics.py @@ -12,37 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Optional, Tuple + +import numpy as np + # IMPORTS import torch -import numpy as np -from typing import Tuple, Optional, Any from FastSurferCNN.utils import logging logger = logging.getLogger(__name__) -def iou_score(pred_cls: torch.Tensor, true_cls: torch.Tensor, nclass: int =79) -> Tuple[np.ndarray, np.ndarray]: - """Compute the intersection-over-union score. +def iou_score( + pred_cls: torch.Tensor, true_cls: torch.Tensor, nclass: int = 79 +) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute the intersection-over-union score. - Both inputs should be categorical (as opposed to one-hot) + Both inputs should be categorical (as opposed to one-hot). Parameters ---------- pred_cls : torch.Tensor - network prediction (categorical) + Network prediction (categorical). true_cls : torch.Tensor - ground truth (categorical) + Ground truth (categorical). nclass : int - number of classes (Default value = 79) + Number of classes (Default value = 79). Returns ------- np.ndarray - [MISSING] + An array containing the intersection for each class. np.ndarray - [MISSING] - + An array containing the union for each class. """ intersect_ = [] union_ = [] @@ -59,28 +63,28 @@ def iou_score(pred_cls: torch.Tensor, true_cls: torch.Tensor, nclass: int =79) - def precision_recall( - pred_cls: torch.Tensor, - true_cls: torch.Tensor, - nclass: int = 79 + pred_cls: torch.Tensor, true_cls: torch.Tensor, nclass: int = 79 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Calculate recall (TP/(TP + FN) and precision (TP/(TP+FP) per class. + """ + Calculate recall (TP/(TP + FN) and precision (TP/(TP+FP) per class. Parameters ---------- pred_cls : torch.Tensor - network prediction (categorical) + Network prediction (categorical). true_cls : torch.Tensor - ground truth (categorical) + Ground truth (categorical). nclass : int - number of classes (Default value = 79) + Number of classes (Default value = 79). Returns ------- np.ndarray - [MISSING] + An array containing the number of true positives for each class. np.ndarray - [MISSING] - + An array containing the sum of true positives and false negatives for each class. + np.ndarray + An array containing the sum of true positives and false positives for each class. """ tpos_fneg = [] tpos_fpos = [] @@ -98,20 +102,21 @@ def precision_recall( class DiceScore: - """Accumulate the component of the dice coefficient i.e. the union and intersection. + """ + Accumulate the component of the dice coefficient i.e. the union and intersection. Attributes ---------- op : callable - a callable to update accumulator. Method's signature is `(accumulator, output)`. + A callable to update accumulator. 
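A tiny, hedged example of the per-class metrics above on categorical tensors. The ratio lines only restate what the documented return values imply (intersection over union and the TP-based sums); the exact class indexing follows the implementation, which is not fully shown here.

```python
import torch

from FastSurferCNN.utils.metrics import iou_score, precision_recall

pred = torch.tensor([0, 1, 1, 2, 2, 2])  # categorical predictions
true = torch.tensor([0, 1, 2, 2, 2, 1])  # categorical ground truth

intersect, union = iou_score(pred, true, nclass=3)
iou_per_class = intersect / union        # per-class IoU (empty classes may be nan)

tp, tp_fn, tp_fp = precision_recall(pred, true, nclass=3)
recall = tp / tp_fn                      # TP / (TP + FN)
precision = tp / tp_fp                   # TP / (TP + FP)
```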
Method's signature is `(accumulator, output)`. For example, to compute arithmetic mean value, `op = lambda a, x: a + x`. output_transform : callable, optional - a callable that is used to transform the + A callable that is used to transform the :class:`~ignite.engine.Engine`'s `process_function`'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. device : str of torch.device, optional - device specification in case of distributed computation usage. + Device specification in case of distributed computation usage. In most of the cases, it can be defined as "cuda:local_rank" or "cuda" if already set `torch.cuda.set_device(local_rank)`. By default, if a distributed process group is initialized and available, device is set to `cuda`. @@ -120,10 +125,12 @@ class DiceScore: def __init__( self, num_classes: int, - device: Optional[str] =None, + device: Optional[str] = None, output_transform=lambda y_pred, y: (y_pred.data.max(1)[1], y), ): - """Construct DiceScore object.""" + """ + Construct DiceScore object. + """ self._device = device self.out_transform = output_transform self.n_classes = num_classes @@ -131,14 +138,18 @@ def __init__( self.intersection = torch.zeros(self.n_classes, self.n_classes, device=device) def reset(self): - """[MISSING].""" + """ + Reset the union and intersection matrices to zero. + """ self.union = torch.zeros(self.n_classes, self.n_classes, device=self._device) self.intersection = torch.zeros( self.n_classes, self.n_classes, device=self._device ) def _check_output_type(self, output): - """Check the output type.""" + """ + Check the output type. + """ if not (isinstance(output, tuple)): raise TypeError( "Output should a tuple consist of of torch.Tensors, but given {}".format( @@ -146,16 +157,18 @@ def _check_output_type(self, output): ) ) - def _update_union_intersection_matrix(self, batch_output: torch.Tensor, labels_batch: torch.Tensor): - """Update the union intersection matrix. + def _update_union_intersection_matrix( + self, batch_output: torch.Tensor, labels_batch: torch.Tensor + ): + """ + Update the union intersection matrix. Parameters ---------- batch_output : torch.Tensor - output tensor + Output tensor. labels_batch : torch.Tensor - label batch - + Label batch. """ for i in range(self.n_classes): gt = (labels_batch == i).float() @@ -164,15 +177,17 @@ def _update_union_intersection_matrix(self, batch_output: torch.Tensor, labels_b self.intersection[i, j] += torch.sum(torch.mul(gt, pred)) self.union[i, j] += torch.sum(gt) + torch.sum(pred) - def _update_union_intersection(self, batch_output: torch.Tensor, labels_batch: torch.Tensor): - """Update the union intersection. + def _update_union_intersection( + self, batch_output: torch.Tensor, labels_batch: torch.Tensor + ): + """ + Update the union intersection. Parameters ---------- batch_output : torch.Tensor - batch output (prediction, labels) + Batch output (prediction, labels). labels_batch : torch.Tensor - """ for i in range(self.n_classes): gt = (labels_batch == i).float() @@ -181,15 +196,15 @@ def _update_union_intersection(self, batch_output: torch.Tensor, labels_batch: t self.union[i, i] += torch.sum(gt) + torch.sum(pred) def update(self, output: Tuple[Any, Any], cnf_mat: bool): - """Update the intersection. + """ + Update the intersection. Parameters ---------- output : Tuple[Any, Any] - Network output tensor + Network output tensor. 
cnf_mat : bool - Confusion matrix - + Confusion matrix. """ self._check_output_type(output) @@ -205,32 +220,50 @@ def update(self, output: Tuple[Any, Any], cnf_mat: bool): self._update_union_intersection(y_pred, y) def compute_dsc(self) -> float: - """Compute the dice score. + """ + Compute the dice score. Returns ------- dsc : float - dice score - + Dice score. """ dsc_per_class = self._dice_calculation() dsc = dsc_per_class.mean() return dsc def comput_dice_cnf(self): - """Compute the dice cnf.""" + """ + Compute the dice cnf. + """ dice_cm_mat = self._dice_confusion_matrix() return dice_cm_mat def _dice_calculation(self): - """[MISSING].""" + """ + Calculate the Dice Score. + + The Dice Score is calculated as 2 * (intersection / union). + + Returns + ------- + dsc : torch.Tensor + The Dice Score for each class. + """ intersection = self.intersection.diagonal() union = self.union.diagonal() dsc = 2 * torch.div(intersection, union) return dsc def _dice_confusion_matrix(self): - """[MISSING].""" + """ + Calculate the Dice confusion matrix. + + Returns + ------- + dice_cnf_matrix : numpy.ndarray + The Dice confusion matrix for each class. + """ dice_intersection = self.intersection.cpu().numpy() dice_union = self.union.cpu().numpy() if not (dice_union > 0).all(): @@ -240,5 +273,7 @@ def _dice_confusion_matrix(self): def dice_score(cm): - """[MISSING].""" + """ + [MISSING]. + """ pass diff --git a/FastSurferCNN/utils/misc.py b/FastSurferCNN/utils/misc.py index 7ddb7050..310caeaa 100644 --- a/FastSurferCNN/utils/misc.py +++ b/FastSurferCNN/utils/misc.py @@ -17,40 +17,41 @@ from itertools import product from typing import List -import FastSurferCNN.data_loader.loader import matplotlib.figure -import torch +import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt -import matplotlib.pyplot as plt +import torch import yacs.config from mpl_toolkits.axes_grid1 import make_axes_locatable -from torchvision import utils from skimage import color +from torchvision import utils + +import FastSurferCNN.data_loader.loader def plot_predictions( - images_batch: torch.Tensor, - labels_batch: torch.Tensor, - batch_output: torch.Tensor, - plt_title: str, - file_save_name: str + images_batch: torch.Tensor, + labels_batch: torch.Tensor, + batch_output: torch.Tensor, + plt_title: str, + file_save_name: str, ) -> None: - """Plot predictions from validation set. + """ + Plot predictions from validation set. Parameters ---------- images_batch : torch.Tensor - batch of images + Batch of images. labels_batch : torch.Tensor - batch of labels + Batch of labels. batch_output : torch.Tensor - batch of output + Batch of output. plt_title : str - plot title + Plot title. file_save_name : str - name the plot should be saved tp - + Name the plot should be saved tp. """ f = plt.figure(figsize=(20, 10)) n, c, h, w = images_batch.shape @@ -81,32 +82,32 @@ def plot_predictions( def plot_confusion_matrix( - cm: npt.NDArray, - classes: List[str], - title: str = "Confusion matrix", - cmap: plt.cm.ColormapRegistry = plt.cm.Blues, - file_save_name: str = "temp.pdf" + cm: npt.NDArray, + classes: List[str], + title: str = "Confusion matrix", + cmap: plt.cm.ColormapRegistry = plt.cm.Blues, + file_save_name: str = "temp.pdf", ) -> matplotlib.figure.Figure: - """Plot the confusion matrix. + """ + Plot the confusion matrix. Parameters ---------- cm : npt.NDArray - confusion matrix + Confusion matrix. classes : List[str] - list of class names + List of class names. 
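A hedged sketch of accumulating the Dice score with the class above. The tensor shapes assume the default `output_transform`, which reduces N x C x H x W logits to categorical predictions; treat this as an illustration rather than the canonical training usage.

```python
import torch

from FastSurferCNN.utils.metrics import DiceScore

dice = DiceScore(num_classes=3, device="cpu")
logits = torch.randn(2, 3, 8, 8)              # N x C x H x W network output
labels = torch.randint(0, 3, (2, 8, 8))       # N x H x W categorical ground truth

dice.update((logits, labels), cnf_mat=False)  # accumulate intersection/union
mean_dice = dice.compute_dsc()                # mean Dice over classes
dice.reset()                                  # clear before the next epoch/split
```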
title : str - (Default value = "Confusion matrix") + (Default value = "Confusion matrix"). cmap : plt.cm.ColormapRegistry - colour map (Default value = plt.cm.Blues) + Colour map (Default value = plt.cm.Blues). file_save_name : str - (Default value = "temp.pdf") + (Default value = "temp.pdf"). Returns ------- fig : matplotlib.figure.Figure - [MISSING] - + Matplotlib Figure object with the confusion matrix plot. """ n_classes = len(classes) @@ -149,18 +150,18 @@ def plot_confusion_matrix( def find_latest_experiment(path: str) -> int: - """Find and load latest experiment. + """ + Find and load latest experiment. Parameters ---------- path : str - path to the latest experiment + Path to the latest experiment. Returns ------- int - latest experiments - + Latest experiments. """ list_of_experiments = os.listdir(path) list_of_int_experiments = [] @@ -178,21 +179,24 @@ def find_latest_experiment(path: str) -> int: def check_path(path: str): - """Create path.""" + """ + Create path. + """ os.makedirs(path, exist_ok=True) return path -def update_num_steps(dataloader: FastSurferCNN.data_loader.loader.DataLoader, - cfg: yacs.config.CfgNode): - """Update the number of steps. +def update_num_steps( + dataloader: FastSurferCNN.data_loader.loader.DataLoader, cfg: yacs.config.CfgNode +): + """ + Update the number of steps. Parameters ---------- dataloader : FastSurferCNN.data_loader.loader.DataLoader - [MISSING] + The dataloader object that contains the training data. cfg : yacs.config.CfgNode - [MISSING] - + The configuration object that contains the training configuration. """ cfg.TRAIN.NUM_STEPS = len(dataloader) diff --git a/FastSurferCNN/utils/parser_defaults.py b/FastSurferCNN/utils/parser_defaults.py index a69fa2f6..9b90a4cc 100644 --- a/FastSurferCNN/utils/parser_defaults.py +++ b/FastSurferCNN/utils/parser_defaults.py @@ -22,22 +22,25 @@ Values can also be extracted by >>> print(ALL_FLAGS["allow_root"](dict, dest="root") ->>> # {'flag': '--allow_root', 'flags': ('--allow_root',), 'action': 'store_true', 'dest': 'root', ->>> # 'help': 'Allow execution as root user.'} +>>> # {'flag': '--allow_root', 'flags': ('--allow_root',), 'action': 'store_true', +>>> # 'dest': 'root', 'help': 'Allow execution as root user.'} """ import argparse -from os import path -from typing import Iterable, Mapping, Union, Literal, Dict, Protocol, TypeVar, Type +import types +from dataclasses import dataclass, Field, fields +from pathlib import Path +from typing import (Dict, Iterable, Literal, Mapping, Protocol, Type, TypeVar, Union, + Optional, get_origin, get_args) +from FastSurferCNN.utils import Plane, PLANES +from FastSurferCNN.utils.arg_types import float_gt_zero_and_le_one as __conform_to_one +from FastSurferCNN.utils.arg_types import unquote_str +from FastSurferCNN.utils.arg_types import vox_size as __vox_size +from FastSurferCNN.utils.dataclasses import field, get_field from FastSurferCNN.utils.threads import get_num_threads -from FastSurferCNN.utils.arg_types import ( - vox_size as __vox_size, - float_gt_zero_and_le_one as __conform_to_one_mm, - unquote_str, -) -FASTSURFER_ROOT = path.dirname(path.dirname(path.dirname(__file__))) +FASTSURFER_ROOT = Path(__file__).parents[2] PLANE_SHORT = {"checkpoint": "ckpt", "config": "cfg"} PLANE_HELP = { "checkpoint": "{} checkpoint to load", @@ -47,211 +50,276 @@ class CanAddArguments(Protocol): - """[MISSING].""" + """ + + """ def add_argument(self, *args, **kwargs): - """[MISSING].""" + """ + Add an argument to the object. + """ ... 
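The flag factories collected in `ALL_FLAGS` below can be applied to anything that satisfies `CanAddArguments`, most commonly an `argparse.ArgumentParser`. A short, hedged sketch follows; the chosen flags and override values are arbitrary examples.

```python
import argparse

from FastSurferCNN.utils.parser_defaults import ALL_FLAGS

parser = argparse.ArgumentParser()
ALL_FLAGS["device"](parser)               # add --device with its default settings
ALL_FLAGS["threads"](parser, default=4)   # override a single default value
ALL_FLAGS["sd"](parser, required=True)    # extra argparse keywords pass through

flag_info = ALL_FLAGS["vox_size"](dict)   # or inspect the defaults as a plain dict
```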
-def __arg(*default_flags, **default_kwargs): - """Create stub function, which sets default settings for argparse arguments. +def __arg( + *default_flags: str, + dcf: Optional[Field] = None, + dc=None, + fieldname: str = "", + **default_kwargs, +): + """ + Create stub function, which sets default settings for argparse arguments. + + The positional and keyword arguments function as if they were directly passed to + parser.add_arguments(). + + The result will be a stub function, which has as first argument a parser (or other + object with an add_argument method) to which the argument is added. The stub + function also accepts positional and keyword arguments, which overwrite the default + arguments. Additionally, these specific values can be callables, which will be + called upon the default values (to alter the default value). - The positional and keyword arguments function as if they were directly passed to parser.add_arguments(). - - The result will be a stub function, which has as first argument a parser (or other object with an - add_argument method) to which the argument is added. The stub function also accepts positional and - keyword arguments, which overwrite the default arguments. Additionally, these specific values can be callables, - which will be called upon the default values (to alter the default value). - This function is private for this module. """ + # TODO Update the Parameters section of this function. + if dcf is None and dc is not None: + if not bool(fieldname) and default_flags[0].startswith("--"): + fieldname = default_flags[0].removeprefix("--") + if fieldname: + dcf = get_field(dc, fieldname) + + if dcf is not None: + for kw, name in (("dest", "name"), ("default",) * 2): + default_kwargs.setdefault(kw, getattr(dcf, name)) + if "type" not in default_kwargs: + if str(get_origin(dcf.type)) == "typing.Union": + _types = list(t for t in get_args(dcf.type) if t is not types.NoneType) + if len(_types) == 0: + default_kwargs["type"] = None + elif len(_types) == 1: + default_kwargs["type"] = _types[0] + else: + raise TypeError( + "A Union Type cannot be used to generate a argparse command " + "from a dataclasses.Field, must pass a type to __arg!" 
+ ) + else: + default_kwargs["type"] = dcf.type + for kw, default in dcf.metadata.items(): + if kw == "flags": + if isinstance(default, tuple) and len(default_flags) == 0: + default_flags = default + else: + default_kwargs.setdefault(kw, default) + def _stub(parser: Union[CanAddArguments, Type[Dict]], *flags, **kwargs): # prefer the value passed to the "new" call for kw, arg in kwargs.items(): - if callable(arg) and kw in default_kwargs.keys(): + if callable(arg) and kw in default_kwargs: kwargs[kw] = arg(default_kwargs[kw]) - # if no new value is provided to _stub (which is the callable in ALL_FLAGS), use the - # default value (stored in the callable/passed to the default below) + # if no new value is provided to _stub (which is the callable in ALL_FLAGS), use + # the default value (stored in the callable/passed to the default below) for kw, default in default_kwargs.items(): - if kw not in kwargs.keys(): - kwargs[kw] = default + kwargs.setdefault(kw, default) _flags = flags if len(flags) != 0 else default_flags if hasattr(parser, "add_argument"): return parser.add_argument(*_flags, **kwargs) - elif parser == dict: + elif parser == dict or isinstance(parser, dict): return {"flag": _flags[0], "flags": _flags, **kwargs} else: raise ValueError( - f"Unclear parameter, should be dict or argparse.ArgumentParser, not {type(parser).__name__}." + f"Unclear parameter, should be dict or argparse.ArgumentParser, not " + f"{type(parser).__name__}." ) return _stub -ALL_FLAGS = { - "t1": __arg( - "--t1", - type=str, - dest="orig_name", +# TODO add Attributes section to SubjectDirectoryConfig. SubjectDirectoryConfig should +# probably be moved to a different file (as part of the refactoring effort). + + +@dataclass +class SubjectDirectoryConfig: + """ + This class describes the 'minimal' parameters used by SubjectList. + """ + orig_name: str = field( + help="Name of T1 full head MRI. Absolute path if single image else common " + "image name. Default: `mri/orig.mgz`.", default="mri/orig.mgz", - help="Name of T1 full head MRI. Absolute path if single image else " - "common image name. Default: mri/orig.mgz", - ), - "remove_suffix": __arg( - "--remove_suffix", - type=str, - dest="remove_suffix", + flags=("--t1",), + ) + pred_name: str = field( + default="mri/aparc.DKTatlas+aseg.deep.mgz", + help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). " + "When using FastSurfer, this segmentation is already conformed, since " + "inference is always based on a conformed image. Absolute path if single " + "image else common image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz", + ) + conf_name: str = field( + default="mri/orig.mgz", + help="Name under which the conformed input image will be saved, in the same " + "directory as the segmentation (the input image is always conformed " + "first, if it is not already conformed). The original input image is " + "saved in the output directory as $id/mri/orig/001.mgz. Default: " + "mri/orig.mgz.", + flags=("--conformed_name",), + ) + in_dir: Optional[Path] = field( + flags=("--in_dir",), + default=None, + help="Directory in which input volume(s) are located. Optional, if full path " + "is defined for --t1.", + ) + csv_file: Optional[Path] = field( + flags=("--csv_file",), + default=None, + help="Csv-file with subjects to analyze (alternative to --tag)", + ) + sid: Optional[str] = field( + flags=("--sid",), + default=None, + help="Optional: directly set the subject id to use. Can be used for single " + "subject input. 
For multi-subject processing, use remove suffix if sid is " + "not second to last element of input file passed to --t1", + ) + search_tag: str = field( + flags=("--tag",), + default="*", + help="Search tag to process only certain subjects. If a single image should be " + "analyzed, set the tag with its id. Default: processes all.", + ) + brainmask_name: str = field( + default="mri/mask.mgz", + help="Name under which the brainmask image will be saved, in the same " + "directory as the segmentation. The brainmask is created from the " + "aparc_aseg segmentation (dilate 5, erode 4, largest component). Default: " + "`mri/mask.mgz`.", + flags=("--brainmask_name",), + ) + remove_suffix: str = field( + flags=("--remove_suffix",), default="", - help="Optional: remove suffix from path definition of input file to yield correct subject name " - "(e.g. /ses-x/anat/ for BIDS or /mri/ for FreeSurfer input). Default: do not remove anything.", - ), - "sid": __arg( - "--sid", - type=str, - dest="sid", + help="Optional: remove suffix from path definition of input file to yield " + "correct subject name (e.g. /ses-x/anat/ for BIDS or /mri/ for FreeSurfer " + "input). Default: do not remove anything.", + ) + out_dir: Optional[Path] = field( default=None, - help="Optional: directly set the subject id to use. Can be used for single subject input. For multi-subject " - "processing, use remove suffix if sid is not second to last element of input file passed to --t1", - ), + help="Directory in which evaluation results should be written. Will be created " + "if it does not exist. Optional if full path is defined for --pred_name.", + ) + + +ALL_FLAGS = { + "t1": __arg("--t1", dc=SubjectDirectoryConfig, fieldname="orig_name"), + "remove_suffix": __arg("--remove_suffix", dc=SubjectDirectoryConfig), + "sid": __arg("--sid", dc=SubjectDirectoryConfig), "asegdkt_segfile": __arg( "--asegdkt_segfile", "--aparc_aseg_segfile", - type=str, - dest="pred_name", - default="mri/aparc.DKTatlas+aseg.deep.mgz", - help="Name of intermediate DL-based segmentation file (similar to aparc+aseg). " - "When using FastSurfer, this segmentation is already conformed, since inference " - "is always based on a conformed image. Absolute path if single image else common " - "image name. Default: mri/aparc.DKTatlas+aseg.deep.mgz", - ), - "conformed_name": __arg( - "--conformed_name", - type=str, - dest="conf_name", - default="mri/orig.mgz", - help="Name under which the conformed input image will be saved, in the same directory " - "as the segmentation (the input image is always conformed first, if it is not " - "already conformed). The original input image is saved in the output directory " - "as $id/mri/orig/001.mgz. Default: mri/orig.mgz.", + dc=SubjectDirectoryConfig, + fieldname="pred_name", ), + "conformed_name": __arg("--conformed_name", dc=SubjectDirectoryConfig, fieldname="conf_name"), "norm_name": __arg( "--norm_name", type=str, dest="norm_name", default="mri/norm.mgz", - help="Name under which the bias field corrected image is stored. Default: mri/norm.mgz.", - ), - "brainmask_name": __arg( - "--brainmask_name", - type=str, - dest="brainmask_name", - default="mri/mask.mgz", - help="Name under which the brainmask image will be saved, in the same directory " - "as the segmentation. The brainmask is created from the aparc_aseg segmentation " - "(dilate 5, erode 4, largest component). Default: mri/mask.mgz.", + help="Name under which the bias field corrected image is stored. 
Default: " + "mri/norm.mgz.", ), + "brainmask_name": __arg("--brainmask_name", dc=SubjectDirectoryConfig), "aseg_name": __arg( "--aseg_name", type=str, dest="aseg_name", default="mri/aseg.auto_noCCseg.mgz", - help="Name under which the reduced aseg segmentation will be saved, in the same directory " - "as the aparc-aseg segmentation (labels of full aparc segmentation are reduced to aseg). " - "Default: mri/aseg.auto_noCCseg.mgz.", + help="Name under which the reduced aseg segmentation will be saved, in the " + "same directory as the aparc-aseg segmentation (labels of full aparc " + "segmentation are reduced to aseg). Default: mri/aseg.auto_noCCseg.mgz.", ), "seg_log": __arg( "--seg_log", type=str, dest="log_name", default="", - help="Absolute path to file in which run logs will be saved. If not set, logs will " - "not be saved.", + help="Absolute path to file in which run logs will be saved. If not set, logs " + "will not be saved.", ), "device": __arg( "--device", default="auto", - help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or specify a certain gpu " - "(e.g. cuda:1), default: auto", + help="Select device to run inference on: cpu, or cuda (= Nvidia gpu) or " + "specify a certain gpu (e.g. cuda:1), default: auto", ), "viewagg_device": __arg( "--viewagg_device", dest="viewagg_device", type=str, default="auto", - help="Define the device, where the view aggregation should be run. By default, the program checks " - "if you have enough memory to run the view aggregation on the gpu (cuda). The total memory is " - "considered for this decision. If this fails, or you actively overwrote the check with setting " - "> --viewagg_device cpu <, view agg is run on the cpu. Equivalently, if you define " - "> --viewagg_device cuda <, view agg will be run on the gpu (no memory check will be done).", - ), - "in_dir": __arg( - "--in_dir", - type=str, - default=None, - help="Directory in which input volume(s) are located. " - "Optional, if full path is defined for --t1.", + help="Define the device, where the view aggregation should be run. By default, " + "the program checks if you have enough memory to run the view aggregation " + "on the gpu (cuda). The total memory is considered for this decision. If " + "this fails, or you actively overwrote the check with setting " + "> --viewagg_device cpu <, view agg is run on the cpu. Equivalently, if " + "you define > --viewagg_device cuda <, view agg will be run on the gpu " + "(no memory check will be done).", ), + "in_dir": __arg("--in_dir", dc=SubjectDirectoryConfig, fieldname="in_dir"), "tag": __arg( "--tag", type=unquote_str, - dest="search_tag", - default="*", - help="Search tag to process only certain subjects. If a single image should be analyzed, " - "set the tag with its id. Default: processes all.", - ), - "csv_file": __arg( - "--csv_file", - type=str, - help="Csv-file with subjects to analyze (alternative to --tag)", - default=None, + dc=SubjectDirectoryConfig, + fieldname="search_tag", ), + "csv_file": __arg("--csv_file", dc=SubjectDirectoryConfig), "batch_size": __arg( - "--batch_size", type=int, default=1, help="Batch size for inference. Default=1" - ), - "sd": __arg( - "--sd", - type=str, - default=None, - dest="out_dir", - help="Directory in which evaluation results should be written. " - "Will be created if it does not exist. Optional if full path is defined for --pred_name.", + "--batch_size", + type=int, + default=1, + help="Batch size for inference. 
Default=1" ), + "sd": __arg("--sd", dc=SubjectDirectoryConfig, fieldname="out_dir"), "qc_log": __arg( "--qc_log", type=str, dest="qc_log", default="", - help="Absolute path to file in which a list of subjects that failed QC check (when processing multiple " - "subjects) will be saved. If not set, the file will not be saved.", + help="Absolute path to file in which a list of subjects that failed QC check " + "(when processing multiple subjects) will be saved. If not set, the file " + "will not be saved.", ), "vox_size": __arg( "--vox_size", type=__vox_size, default="min", dest="vox_size", - help="Choose the primary voxelsize to process, must be either a number between 0 and 1 (below 0.7 is " - "experimental) or 'min' (default). A number forces processing at that specific voxel size, 'min' " - "determines the voxel size from the image itself (conforming to the minimum voxel size, or 1 if " - "the minimum voxel size is above 0.95mm). ", + help="Choose the primary voxelsize to process, must be either a number between " + "0 and 1 (below 0.7 is experimental) or 'min' (default). A number forces " + "processing at that specific voxel size, 'min' determines the voxel size " + "from the image itself (conforming to the minimum voxel size, or 1 if the " + "minimum voxel size is above 0.95mm). ", ), "conform_to_1mm_threshold": __arg( "--conform_to_1mm_threshold", - type=__conform_to_one_mm, + type=__conform_to_one, default=0.95, dest="conform_to_1mm_threshold", - help="The voxelsize threshold, above which images will be conformed to 1mm isotropic, if the --vox_size " - "argument is also 'min' (the --vox_size default setting). Contrary to conform.py, the default behavior" - "of %(prog)s is to resample all images _above 0.95mm_ to 1mm.", + help="The voxelsize threshold, above which images will be conformed to 1mm " + "isotropic, if the --vox_size argument is also 'min' (the --vox_size " + "default setting). Contrary to conform.py, the default behavior of " + "%(prog)s is to resample all images above 0.95mm to 1mm.", ), "lut": __arg( "--lut", - type=str, + type=Path, help="Path and name of LUT to use.", - default=path.join( - FASTSURFER_ROOT, "FastSurferCNN/config/FastSurfer_ColorLUT.tsv" - ), + default=FASTSURFER_ROOT / "FastSurferCNN/config/FastSurfer_ColorLUT.tsv", ), "allow_root": __arg( "--allow_root", @@ -264,14 +332,16 @@ def _stub(parser: Union[CanAddArguments, Type[Dict]], *flags, **kwargs): dest="threads", default=get_num_threads(), type=int, - help=f"Number of threads to use (defaults to number of hardware threads: {get_num_threads()})", + help=f"Number of threads to use (defaults to number of hardware threads: " + f"{get_num_threads()})", ), "async_io": __arg( "--async_io", dest="async_io", action="store_true", - help="Allow asynchronous file operations (default: off). Note, this may impact the order of " - "messages in the log, but speed up the segmentation specifically for slow file systems.", + help="Allow asynchronous file operations (default: off). Note, this may impact " + "the order of messages in the log, but speed up the segmentation " + "specifically for slow file systems.", ), } @@ -279,21 +349,29 @@ def _stub(parser: Union[CanAddArguments, Type[Dict]], *flags, **kwargs): def add_arguments(parser: T_AddArgs, flags: Iterable[str]) -> T_AddArgs: - """Add default flags to the parser from the flags list in order. + """ + Add default flags to the parser from the flags list in order. Parameters ---------- parser : T_AddArgs The parser to add flags to. 
flags : Iterable[str] - the flags to add from 'device', 'viewagg_device'. + The flags to add from 'device', 'viewagg_device'. Returns ------- T_AddArgs - The parser object + The parser object. + Raises + ------ + RuntimeError + If parser does not support a call to add_argument. """ + if not hasattr(parser, "add_argument") or not callable(parser.add_argument): + raise RuntimeError("parser does not support add_argument!") + for flag in flags: if flag.startswith("--"): flag = flag[2:] @@ -302,17 +380,19 @@ def add_arguments(parser: T_AddArgs, flags: Iterable[str]) -> T_AddArgs: add_flag(parser) else: raise ValueError( - f"The flag '{flag}' is not defined in FastSurferCNN.utils.parse.add_arguments()." + f"The flag '{flag}' is not defined in {add_arguments.__qualname__}." ) return parser def add_plane_flags( - parser: argparse.ArgumentParser, - type: Literal["checkpoint", "config"], - files: Mapping[str, str], -) -> argparse.ArgumentParser: - """Add plane arguments. + parser: T_AddArgs, + configtype: Literal["checkpoint", "config"], + files: Mapping[Plane, Path | str], + defaults_path: Path | str, +) -> T_AddArgs: + """ + Add plane arguments. Arguments will be added for each entry in files, where the key is the "plane" and the values is the file name (relative for path relative to FASTSURFER_HOME. @@ -321,34 +401,42 @@ def add_plane_flags( ---------- parser : argparse.ArgumentParser The parser to add flags to. - type : Literal["checkpoint", "config"] + configtype : Literal["checkpoint", "config"] The type of files (for help text and prefix from "checkpoint" and "config". - "checkpoint" will lead to flags like "--ckpt_{plane}", "config" to "--cfg_{plane}" - files : Mapping[str, str] - A dictionary of plane to filename. Relative files are assumed to be relative to the FastSurfer root - directory. + "checkpoint" will lead to flags like "--ckpt_{plane}", "config" to + "--cfg_{plane}". + files : Mapping[Plane, Path | str] + A dictionary of plane to filename. Relative files are assumed to be relative to + the FastSurfer root directory. + defaults_path : Path, str + A path to the file to load defaults from. Returns ------- argparse.ArgumentParser The parser object. 
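A hedged sketch of wiring several of the default flags into a parser with `add_arguments`; the selected flags and paths are made up for illustration.

```python
import argparse

from FastSurferCNN.utils.parser_defaults import add_arguments

parser = argparse.ArgumentParser(description="illustrative FastSurfer-style CLI")
add_arguments(parser, ["--t1", "--sd", "--sid", "--device", "--threads"])

# Destinations follow the defaults above, e.g. orig_name for --t1, out_dir for --sd.
args = parser.parse_args(["--t1", "/data/sub-01/orig.mgz", "--sd", "/output"])
```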
- """ - if type not in PLANE_SHORT: + if configtype not in PLANE_SHORT: raise ValueError("type must be either config or checkpoint.") - for key, filepath in files.items(): - if not path.isabs(filepath): - filepath = path.join(FASTSURFER_ROOT, filepath) + from FastSurferCNN.utils.checkpoint import load_checkpoint_config_defaults + defaults = load_checkpoint_config_defaults(configtype, defaults_path) + + for plane, filepath in files.items(): + path = defaults[plane] if str(filepath) == "default" else Path(filepath) + if not path.is_absolute(): + path = FASTSURFER_ROOT / path # find the first vowel in the key - flag = key.strip().lower() - index = min(i for i in (flag.find(v) for v in "aeiou") if i >= 0) - flag = flag[: index + 2] + plane = plane.strip().lower() + if plane not in PLANES: + raise ValueError(f"Invalid key in files, no plane: {plane}") + index = min(i for i in (plane.find(v) for v in "aeiou") if i >= 0) + plane_short = plane[: index + 2] parser.add_argument( - f"--{PLANE_SHORT[type]}_{flag}", - type=str, - dest=f"{PLANE_SHORT[type]}_{flag}", - help=PLANE_HELP[type].format(key), - default=filepath, + f"--{PLANE_SHORT[configtype]}_{plane_short}", + type=Path, + dest=f"{PLANE_SHORT[configtype]}_{plane_short}", + help=PLANE_HELP[configtype].format(plane), + default=path, ) return parser diff --git a/FastSurferCNN/utils/run_tools.py b/FastSurferCNN/utils/run_tools.py index 354be024..a206420c 100644 --- a/FastSurferCNN/utils/run_tools.py +++ b/FastSurferCNN/utils/run_tools.py @@ -1,8 +1,26 @@ +#!/bin/python + +# Copyright 2023 Image Analysis Lab, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from concurrent.futures import Executor, Future import subprocess +from concurrent.futures import Executor, Future from dataclasses import dataclass from functools import partialmethod -from typing import Optional, Generator, Sequence +from typing import Generator, Optional, Sequence, Callable, Any, Collection, Iterable +from datetime import datetime # TODO: python3.9+ # from collections.abc import Generator @@ -10,80 +28,119 @@ @dataclass class MessageBuffer: - out: bytes = b'' - err: bytes = b'' + """ + MessageBuffer class. + """ + + out: bytes = b"" + err: bytes = b"" retcode: Optional[int] = None + runtime: float = 0. - def __add__(self, other: 'MessageBuffer') -> 'MessageBuffer': + def __add__(self, other: "MessageBuffer") -> "MessageBuffer": if not isinstance(other, MessageBuffer): raise ValueError("Can only append another MessageBuffer!") - return MessageBuffer(out=self.out + other.out, err=self.err + other.err, retcode=other.retcode) + return MessageBuffer( + out=self.out + other.out, + err=self.err + other.err, + retcode=other.retcode, + runtime=max(self.runtime or 0.0, other.runtime or 0.0), + ) - def __iadd__(self, other: 'MessageBuffer'): - """Append another MessageBuffer's content to this MessageBuffer.""" + def __iadd__(self, other: "MessageBuffer"): + """ + Append another MessageBuffer's content to this MessageBuffer. 
+ """ if not isinstance(other, MessageBuffer): raise ValueError("other must be a MessageBuffer!") self.out += other.out self.err += other.err self.retcode = other.retcode + self.runtime = max(self.runtime or 0.0, other.runtime or 0.0) return self - def out_str(self, encoding=None): + def out_str(self, encoding="utf-8"): return self.out.decode(encoding=encoding) - def err_str(self, encoding=None): + def err_str(self, encoding="utf-8"): return self.err.decode(encoding=encoding) class Popen(subprocess.Popen): - """Extension of subprocess.Popen for convenience.""" + """ + Extension of subprocess.Popen for convenience. + """ + _starttime: Optional[datetime] = None + + def __init__(self, *args, **kwargs): + self._starttime = datetime.now() + super().__init__(*args, **kwargs) def messages(self, timeout: float) -> Generator[MessageBuffer, None, None]: from subprocess import TimeoutExpired + + start = self._starttime or datetime.now() while self.poll() is None: try: - stdout, stderr = self.communicate(timeout=timeout) - yield MessageBuffer( - out=stdout if stdout else b'', - err=stderr if stderr else b'', - retcode=self.returncode) + stdout, stderr = self.communicate(timeout=timeout) + yield MessageBuffer( + out=stdout if stdout else b"", + err=stderr if stderr else b"", + retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), + ) except TimeoutExpired: pass - _stdout = b'' if self.stdout is None or self.stdout.closed else self.stdout.read() - _stderr = b'' if self.stderr is None or self.stderr.closed else self.stderr.read() - if _stderr != b'' or _stdout != b'': + _stdout = ( + b"" if self.stdout is None or self.stdout.closed else self.stdout.read() + ) + _stderr = ( + b"" if self.stderr is None or self.stderr.closed else self.stderr.read() + ) + if _stderr != b"" or _stdout != b"": yield MessageBuffer( out=_stdout, err=_stderr, - retcode=self.returncode) + retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), + ) def next_message(self, timeout: float) -> MessageBuffer: - + start = self._starttime or datetime.now() if self.poll() is None: stdout, stderr = self.communicate(timeout=timeout) return MessageBuffer( - out=stdout if stdout else b'', - err=stderr if stderr else b'', - retcode=self.returncode) + out=stdout if stdout else b"", + err=stderr if stderr else b"", + retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), + ) else: - _stdout = b'' if self.stdout is None or self.stdout.closed else self.stdout.read() - _stderr = b'' if self.stderr is None or self.stderr.closed else self.stderr.read() + _stdout = ( + b"" if self.stdout is None or self.stdout.closed else self.stdout.read() + ) + _stderr = ( + b"" if self.stderr is None or self.stderr.closed else self.stderr.read() + ) if _stderr or _stdout: return MessageBuffer( out=_stdout, err=_stderr, - retcode=self.returncode) + retcode=self.returncode, + runtime=(datetime.now() - start).total_seconds(), + ) else: raise StopIteration() - __next__ = partialmethod(next_message, timeout=0.) - __iter__ = partialmethod(messages, timeout=0.) + __next__ = partialmethod(next_message, timeout=0.0) + __iter__ = partialmethod(messages, timeout=0.0) def finish(self, timeout: float = None) -> MessageBuffer: - """`finish`'s behavior is similar to `subprocess.dry_run`. + """ + `finish`'s behavior is similar to `subprocess.dry_run`. `finish` waits `timeout` seconds, and forces termination after. By default, waits unlimited `timeout=None`. 
In either case, all messages in stdout and @@ -93,31 +150,35 @@ def finish(self, timeout: float = None) -> MessageBuffer: Parameters ---------- timeout : float, optional - seconds to wait before forcing termination + Seconds to wait before forcing termination. Returns ------- - A MessageBuffer object with the content of the stdout and stderr pipes. + MessageBuffer + A MessageBuffer object with the content of the stdout and stderr pipes. """ try: self.wait(timeout) except subprocess.TimeoutExpired: self.terminate() - msg = MessageBuffer() + msg = MessageBuffer(runtime=0.0) i = 0 for _msg in self.messages(timeout=0.25): msg += _msg if i > 0: self.kill() - raise RuntimeError("The process {} did not stop properly in Popen.finish, " - "abandoning.".format(self)) + raise RuntimeError( + "The process {} did not stop properly in Popen.finish, " + "abandoning.".format(self) + ) i += 1 if i == 0: msg.retcode = self.returncode return msg def as_future(self, pool: Executor, timeout: float = None) -> Future: - """Similar to `finish` in its application, but as non-blocking Future. + """ + Similar to `finish` in its application, but as non-blocking Future. Parameters ---------- @@ -127,11 +188,12 @@ def as_future(self, pool: Executor, timeout: float = None) -> Future: Returns ------- Future[MessageBuffer] - A Future object which will contain the result + A Future object which will contain the result. See Also -------- finish + The `finish` method provides similar functionality. """ return pool.submit(self.finish, timeout=timeout) @@ -140,7 +202,6 @@ async def async_finish(self, timeout: float = None) -> MessageBuffer: class PyPopen(Popen): - def __init__(self, args: Sequence[str], *_args, **kwargs): """ Create a python process with same flags, and additional args. @@ -148,18 +209,32 @@ def __init__(self, args: Sequence[str], *_args, **kwargs): Parameters ---------- args : Sequence[str] - arguments to python process + Arguments to python process. additional arguments as in subprocess.Popen See Also -------- Popen - subprocess.Popen + subprocess.Popen. """ import sys - all_flags = {"d": "debug", "i": "inspect", "I": "isolated", "0": "optimize", "B": "dont_write_bytecode", - "s": "no_user_site", "S": "no_site", "E": "ignore_environment", "v": "verbose", - "b": "bytes_warning", "q": "quiet", "R": "hash_randomization"} - flags = ''.join(k for k, v in all_flags.items() if getattr(sys.flags, v) == 1) + + all_flags = { + "d": "debug", + "i": "inspect", + "I": "isolated", + "0": "optimize", + "B": "dont_write_bytecode", + "s": "no_user_site", + "S": "no_site", + "E": "ignore_environment", + "v": "verbose", + "b": "bytes_warning", + "q": "quiet", + "R": "hash_randomization", + } + flags = "".join(k for k, v in all_flags.items() if getattr(sys.flags, v) == 1) flags = [] if len(flags) == 0 else ["-" + flags] - super(PyPopen, self).__init__([sys.executable] + flags + list(args), *_args, **kwargs) + super(PyPopen, self).__init__( + [sys.executable] + flags + list(args), *_args, **kwargs + ) diff --git a/FastSurferCNN/utils/threads.py b/FastSurferCNN/utils/threads.py index af4117cc..aa2c7591 100644 --- a/FastSurferCNN/utils/threads.py +++ b/FastSurferCNN/utils/threads.py @@ -12,10 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. + def get_num_threads(): + """ + Determine the number of available threads. + + Tries to get the process's CPU affinity for usable thread count; defaults + to total CPU count on failure. 
+ + Returns + ------- + int + Number of threads available to the process or total CPU count. + """ try: from os import sched_getaffinity as __getaffinity + return len(__getaffinity(0)) except ImportError: from os import cpu_count + return cpu_count() diff --git a/FastSurferCNN/version.py b/FastSurferCNN/version.py index 6275b8ef..bcd27d7e 100644 --- a/FastSurferCNN/version.py +++ b/FastSurferCNN/version.py @@ -1,14 +1,12 @@ #!/bin/python -# - import argparse import re import shutil import subprocess -from io import TextIOWrapper from pathlib import Path -from typing import Optional, Union, Dict, Any +from typing import Any, cast, get_args, Literal, Optional, TypedDict, Sequence, TextIO +from concurrent.futures import ThreadPoolExecutor, Future class DEFAULTS: @@ -22,76 +20,165 @@ class DEFAULTS: } +class _RequiredVersionDict(TypedDict): + """ + Dictionary with keys 'version_line', 'version', 'git_hash', 'git_branch'. + """ + git_branch: str + git_hash: str + version: str + version_line: str + + +class _OptionalVersionDict(TypedDict, total=False): + """ + Dictionary with optional keys 'checkpoints', 'git_status', and 'pypackages'. + """ + content: str + checkpoints: str + git_status: str + pypackages: str + version_tag: str + + +class VersionDict(_RequiredVersionDict, _OptionalVersionDict): + """ + Dictionary with keys 'version_line', 'version', 'git_hash', 'git_branch', + 'checkpoints', 'git_status', and 'pypackages'. The last 3 are optional and may + be missing depending on the content of the file. + """ + pass + + +VersionDictKeys = Literal[ + "git_branch", + "git_hash", + "version", + "version_line", + "checkpoints", + "git_status", + "pypackages", +] + + def section(arg: str) -> str: - """Validate the argument is a valid sections string. + """ + Validate the argument is a valid sections string. A valid sections string is either 'all' or a concatenated list of '+checkpoints', - '+git', and '+pip', e.g. '+git+checkpoints'. The order does not matter.""" + '+git', and '+pip', e.g. '+git+checkpoints'. The order does not matter. + + Parameters + ---------- + arg : str + The input string to be validated. + + Returns + ------- + str + The validated sections string. + """ from re import match + if arg == "all": return "+checkpoints+git+pip" elif match("^(\\+branch|\\+checkpoints|\\+git|\\+pip)*$", arg): return arg else: - raise argparse.ArgumentTypeError("The section argument must be 'all', or any combination of " - "'+branch', '+checkpoints', '+git' and '+pip'.") + raise argparse.ArgumentTypeError( + "The section argument must be 'all', or any combination of " + "'+branch', '+checkpoints', '+git' and '+pip'." + ) def make_parser(): - """Generate the argument parser for the version script.""" - parser = argparse.ArgumentParser(description="Helper script to read and write version information") - parser.add_argument("--sections", - default="", - type=section, - help="Sections to include from +checkpoints, +git, +pip. 
If not passed, will " - "only have the version number.") - parser.add_argument("--build_cache", - type=argparse.FileType("r"), - help=f"File to read version info to read from (default: {DEFAULTS.BUILD_TXT}).") - parser.add_argument("--project_file", - type=argparse.FileType("r"), - help=f"File to project detail / version info from (default: {DEFAULTS.PROJECT_TOML}).") - parser.add_argument("-o", "--file", - default=None, - type=argparse.FileType("w"), - help=f"File to write version info to (default: write to stdout).") - parser.add_argument('--prefer_cache', - action="store_true", - help="Avoid running commands and only read the build file ") + """ + Generate the argument parser for the version script. + + Returns + ------- + argparse.ArgumentParser + The argument parser for the version script. + """ + parser = argparse.ArgumentParser( + description="Helper script to read and write version information" + ) + parser.add_argument( + "--sections", + default="", + type=section, + help="Sections to include from +checkpoints, +git, +pip. If not passed, will " + "only have the version number.", + ) + parser.add_argument( + "--build_cache", + type=argparse.FileType("r"), + help=f"File to read version info to read from (default: {DEFAULTS.BUILD_TXT}).", + ) + parser.add_argument( + "--project_file", + type=argparse.FileType("r"), + help=f"File to project detail / version info from (default: {DEFAULTS.PROJECT_TOML}).", + ) + parser.add_argument( + "-o", + "--file", + default=None, + type=argparse.FileType("w"), + help=f"File to write version info to (default: write to stdout).", + ) + parser.add_argument( + "--prefer_cache", + action="store_true", + help="Avoid running commands and only read the build file.", + ) return parser +def has_git(): + """ + Determine whether FastSurfer is installed as a git directory. + + Returns + ------- + bool + Whether git commands on the FastSurfer dir likely work. + """ + return shutil.which("git") is not None and (DEFAULTS.PROJECT_ROOT / ".git").is_dir() + + def print_build_file( - version: str, - git_hash: str = "", - git_branch: str = "", - git_status: str = "", - checkpoints: str = "", - pypackages: str = "", - file: Optional[TextIOWrapper] = None + version: str, + git_hash: str = "", + git_branch: str = "", + git_status: str = "", + checkpoints: str = "", + pypackages: str = "", + file: Optional[TextIO] = None, ) -> None: - """Format and print the build file. + """ + Format and print the build file. Parameters ---------- version : str - The version number to print + The version number to print. git_hash : str, optional - The git hash associated with the build, may be empty - git_branch: str, optional - The git branch associated with the build, may be empty - checkpoints : str, optional - The md5sums of the checkpoint files, leave empty to skip. + The git hash associated with the build, may be empty. + git_branch : str, optional + The git branch associated with the build, may be empty. git_status : str, optional The md5sums of the checkpoint files, leave empty to skip. + checkpoints : str, optional + The md5sums of the checkpoint files, leave empty to skip. pypackages : str, optional The md5sums of the checkpoint files, leave empty to skip. - file : TextIOWrapper, optional - A file-like object to write the build file to, default: stdout + file : TextIO, optional + A file-like object to write the build file to, default: stdout. - See Also - -------- - main + Notes + ----- + See also main. 
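A small, hedged sketch of the parser and section validation defined above; whether the real script is driven exactly like this is not shown in this diff.

```python
from FastSurferCNN.version import make_parser, section

parser = make_parser()
args = parser.parse_args(["--sections", "+git+checkpoints", "--prefer_cache"])
assert args.sections == "+git+checkpoints"        # validated by section()
assert args.prefer_cache is True
assert section("all") == "+checkpoints+git+pip"   # 'all' expands to every section
```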
""" if file is None: @@ -119,12 +206,14 @@ def print_header(section_name: str) -> None: def main( - sections: str = "", - project_file: Optional[TextIOWrapper] = None, - build_cache: Optional[TextIOWrapper] = None, - file: Optional[TextIOWrapper] = None, - prefer_cache: bool = False) -> Union[str, int]: - """Print version info to stdout or file. + sections: str = "", + project_file: Optional[TextIO] = None, + build_cache: Optional[TextIO | bool] = None, + file: Optional[TextIO] = None, + prefer_cache: bool = False, +) -> str | int: + """ + Print version info to stdout or file. Prints/writes version info of FastSurfer in the style: ``` @@ -157,104 +246,133 @@ def main( '+git+checkpoints'. The order does not matter, '+checkpoints', '+git' or '+pip' also implicitly activate '+branch'. - project_file : TextIOWrapper, optional + project_file : TextIO, optional A file-like object to read the projects toml file, with the '[project]' section - with a 'version' attribute. Defaults to $PROJECT_ROOT/pyproject.toml - build_cache : TextIOWrapper, optional + with a 'version' attribute. Defaults to $PROJECT_ROOT/pyproject.toml. + build_cache : False, TextIO, optional A file-like object to read cached version information, the format should be - formatted like the output of `main`. Defaults to $PROJECT_ROOT/BUILD.info - file : TextIOWrapper, optional + formatted like the output of `main`. Defaults to $PROJECT_ROOT/BUILD.info. + If build_cache is False, it is ignored. + file : TextIO, optional A file-like object to write the output to, defaults to stdout if None or not - passed + passed. prefer_cache : bool, default=False Whether to prefer information from the `build_cache` over online generation. Returns ------- int or str - Returns 0, if the function was successful, a error message if a problem occurred. + Returns 0, if the function was successful, or an error message. """ - has_git = shutil.which("git") is not None and (DEFAULTS.PROJECT_ROOT / ".git").is_dir() - has_build_cache = build_cache is not None or DEFAULTS.BUILD_TXT.is_file() + has_build_cache = False + # ignore build_cache, if it is False + if build_cache is not False: + # if build_cache is not False, use the passed value or the default + has_build_cache = build_cache is not None or DEFAULTS.BUILD_TXT.is_file() if prefer_cache and not has_build_cache: - return "Trying to force the use cached version information, but no build information file " \ - f"was passed found at the default location ({DEFAULTS.BUILD_TXT})." + return ( + "Trying to force the use of cached version information (--prefer_cache), " + "but no build information file was passed found at the default location " + f"({DEFAULTS.BUILD_TXT})." 
+ ) if sections == "all": sections = "+checkpoints+git+pip" - from FastSurferCNN.utils.run_tools import Popen, PyPopen + from FastSurferCNN.utils.run_tools import Popen, PyPopen, MessageBuffer + build_cache_required = prefer_cache kw_root = {"cwd": DEFAULTS.PROJECT_ROOT, "stdout": subprocess.PIPE} - futures = {} - - from concurrent.futures import ThreadPoolExecutor + futures: dict[str, Future[str | MessageBuffer | VersionDict]] = {} with ThreadPoolExecutor() as pool: futures["version"] = pool.submit(read_and_close_version, project_file) # if we do not have git, try VERSION file else git sha and branch - if has_git and not prefer_cache: - futures["git_hash"] = Popen(["git", "rev-parse", "--short", "HEAD"], **kw_root).as_future(pool) + if has_git() and not prefer_cache: + futures["git_hash"] = Popen( + ["git", "rev-parse", "--short", "HEAD"], **kw_root + ).as_future(pool) if sections != "": - futures["git_branch"] = Popen(["git", "branch", "--show-current"], **kw_root).as_future(pool) + futures["git_branch"] = Popen( + ["git", "branch", "--show-current"], **kw_root + ).as_future(pool) if "+git" in sections: - - futures["git_status"] = pool.submit(filter_git_status, Popen(["git", "status", "-s", "-b"], **kw_root)) + futures["git_status"] = pool.submit( + filter_git_status, Popen(["git", "status", "-s", "-b"], **kw_root) + ) else: + # we go not have git, try loading the build cache build_cache_required = True - if build_cache_required: + if build_cache_required and build_cache is not False: futures["build_cache"] = pool.submit(parse_build_file, build_cache) if "+checkpoints" in sections and not prefer_cache: - def calculate_md5_for_checkpoints() -> 'MessageBuffer': + + def calculate_md5_for_checkpoints() -> "MessageBuffer": from glob import glob + files = glob(str(DEFAULTS.PROJECT_ROOT / "checkpoints" / "*")) shorten = len(str(DEFAULTS.PROJECT_ROOT)) + 1 files = [f[shorten:] for f in files] return Popen(["md5sum"] + files, **kw_root).finish() + futures["checkpoints"] = pool.submit(calculate_md5_for_checkpoints) if "+pip" in sections and not prefer_cache: - futures["pypackages"] = PyPopen(["-m", "pip", "list", "--verbose"], **kw_root).as_future(pool) + futures["pypackages"] = PyPopen( + ["-m", "pip", "list", "--verbose"], **kw_root + ).as_future(pool) - if build_cache_required: - build_cache = futures.pop("build_cache").result() + if build_cache_required and build_cache is not False: + build_cache: VersionDict = futures.pop("build_cache").result() else: - build_cache = {} + build_cache: VersionDict = get_default_version_info() build_file_kwargs = {} try: version = futures.pop("version").result() except IOError: - version = build_cache["version_no"] + version = build_cache["version"] - def __future_or_cache(key: str, futures: Dict[str, Any], cache: Dict[str, Any]) -> str: - future = futures.get(key, None) - if future: + def __future_or_cache( + key: VersionDictKeys, futures: dict[str, Future[Any]], cache: VersionDict, + ) -> str: + future: None | Future[Any] = futures.get(key, None) + if future is not None: returnmsg = future.result() if isinstance(returnmsg, str): return returnmsg elif returnmsg.retcode != 0: - raise RuntimeError(f"The calculation/determination of {key} has failed.") - return returnmsg.out_str('utf-8').strip() + raise RuntimeError( + f"The calculation/determination of {key} has failed." 
+ ) + return returnmsg.out_str("utf-8").strip() elif key in cache: # fill from cache return cache[key] else: add_msg = "" + if key == "git_status": + add_msg += (" --sections all or --sections with +git require a build " + "cache file or a FastSurfer git directory and git.") if prefer_cache: - add_msg = " The cached build file seems to not contain this info?" + add_msg += " The cached build file seems to not contain this info?" # ERROR raise RuntimeError(f"Could not find a valid value for {key}!" + add_msg) try: - build_file_kwargs["git_hash"] = __future_or_cache("git_hash", futures, build_cache) + build_file_kwargs["git_hash"] = __future_or_cache( + "git_hash", futures, build_cache + ) if sections != "": - build_file_kwargs["git_branch"] = __future_or_cache("git_branch", futures, build_cache) - for key in ("git_status", "checkpoints", "pypackages"): + build_file_kwargs["git_branch"] = __future_or_cache( + "git_branch", futures, build_cache + ) + keys: Sequence[VersionDictKeys] = ("git_status", "checkpoints", "pypackages") + for key in keys: if DEFAULTS.VERSION_SECTIONS[key][0] in sections: # stuff that is needed build_file_kwargs[key] = __future_or_cache(key, futures, build_cache) @@ -266,7 +384,24 @@ def __future_or_cache(key: str, futures: Dict[str, Any], cache: Dict[str, Any]) return 0 -def parse_build_file(build_file: Optional[TextIOWrapper]) -> Dict[str, str]: +def get_default_version_info() -> VersionDict: + """ + Get the blank version information. + + Returns + ------- + VersionDict + A dictionary with blank version information. + """ + return { + "version_line": "N/A", + "version": "N/A", + "git_hash": "0000000", + "git_branch": "release", + } + + +def parse_build_file(build_file: Optional[TextIO]) -> VersionDict: """Read and parse a build file (same as output of `main`). Read and parse a file with version information in the format that is also the @@ -274,37 +409,48 @@ def parse_build_file(build_file: Optional[TextIOWrapper]) -> Dict[str, str]: Parameters ---------- - build_file : TextIOWrapper, optional - file-like object, will be closed. + build_file : TextIO, optional + File-like object, will be closed. Returns ------- - dict + VersionDict Dictionary with keys 'version_line', 'version', 'git_hash', 'git_branch', - 'checkpoints', 'git_status', and 'pip'. The last 3 are optional and may + 'checkpoints', 'git_status', and 'pypackages'. The last 3 are optional and may be missing depending on the content of the file. - See Also - -------- - main + Notes + ----- + See also main. 
""" - file_cache: Dict[str, str] = {} - try: - if build_file is None: + file_cache: VersionDict = {} + if build_file is None: + try: build_file = open(DEFAULTS.BUILD_TXT, "r") - file_cache["content"] = "".join(build_file.readlines()) - finally: + except FileNotFoundError as e: + return get_default_version_info() + file_cache["content"] = "".join(build_file.readlines()) + if not build_file.closed: build_file.close() section_pattern = re.compile("\n={3,}\n") file_cache["version_line"], *rest = section_pattern.split(file_cache["content"], 1) - version_regex = re.compile("([a-zA-Z.0-9\\-]+)(\\+([0-9A-Fa-f]+))?(\\s+\\(([^)]+)\\))?\\s*") + version_regex = re.compile( + "([a-zA-Z.0-9\\-]+)(\\+([0-9A-Fa-f]+))?(\\s+\\(([^)]+)\\))?\\s*" + ) hits = version_regex.search(file_cache["version_line"]) if hits is None: - raise RuntimeError("The build file has invalid formatting, version tag not " - "recognized!", - f"First line was '{file_cache['version_line']}' and did " - f"not fit the pattern '{version_regex.pattern}.") - file_cache["version"], _, file_cache["git_hash"], _, file_cache["git_branch"] = hits.groups("") + raise RuntimeError( + f"The build file {build_file.name} has invalid formatting, version tag not " + f"recognized! First line was '{file_cache['version_line']}' and did " + f"not fit the pattern '{version_regex.pattern}'.", + ) + ( + file_cache["version"], + _, + file_cache["git_hash"], + _, + file_cache["git_branch"], + ) = hits.groups("") if file_cache["git_hash"]: file_cache["version_tag"] = file_cache["version"] + "+" + file_cache["git_hash"] else: @@ -318,13 +464,14 @@ def get_section_name_by_header(header: str) -> Optional[str]: while len(rest) > 0: section_header, section_content, *rest = section_pattern.split(rest[0], 2) section_name = get_section_name_by_header(section_header) - if section_name: - file_cache[section_name] = section_content + if section_name and section_name in get_args(VersionDictKeys): + file_cache[cast(VersionDictKeys, section_name)] = section_content return file_cache -def read_version_from_project_file(project_file: TextIOWrapper) -> str: - """Read the version entry from the pyproject file. +def read_version_from_project_file(project_file: TextIO) -> str: + """ + Read the version entry from the pyproject file. Searches for the [project] section in project_file, therein the version attribute. Extracts the Value. The file pointer is right after the version attribute at return @@ -332,15 +479,16 @@ def read_version_from_project_file(project_file: TextIOWrapper) -> str: Parameters ---------- - project_file : - file pointer to the project file to read from. + project_file : TextIO + File pointer to the project file to read from. Returns ------- - the version string + str + The version string. """ project_pattern = re.compile(r"\[project]") - version_pattern = re.compile("version\\s*=\\s*(\\\")?([^\\\"]+)\\1") + version_pattern = re.compile('version\\s*=\\s*([\\"\']?)([^\\"]+)\\1') version = "unspecified" seek_to_project = True @@ -355,40 +503,53 @@ def read_version_from_project_file(project_file: TextIOWrapper) -> str: if hits is not None: seek_to_version = False version = hits.group(2) - if version[0] == "\"": - version = version.strip("\"") + if version[0] == '"': + version = version.strip('"') return version -def filter_git_status(git_process: 'FastSurferCNN.utils.run_tools.Popen') -> str: +def filter_git_status(git_process: "FastSurferCNN.utils.run_tools.Popen") -> str: """ - Takes a running git status process and filters the output. 
+ Filter the output of a running git status process. Parameters ---------- git_process : FastSurferCNN.utils.run_tools.Popen - The Popen process object, that will return the git status output + The Popen process object that will return the git status output. Returns ------- - The git status string filtered for __pycache__ + str + The git status string filtered to exclude lines containing "__pycache__". """ - from FastSurferCNN.utils.run_tools import Popen finished_process = git_process.finish() if finished_process.retcode != 0: raise RuntimeError("Failed git status command") - git_status_text = finished_process.out_str('utf-8') - return "\n".join(filter(lambda x: "__pycache__" not in x, git_status_text.split("\n"))) + git_status_text = finished_process.out_str("utf-8") + return "\n".join( + filter(lambda x: "__pycache__" not in x, git_status_text.split("\n")) + ) -def read_and_close_version(project_file: Optional[TextIOWrapper] = None) -> str: - """Read and close the version from the pyproject file. Also fill default. +def read_and_close_version(project_file: Optional[TextIO] = None) -> str: + """ + Read and close the version from the pyproject file. Also fill default. Always closes the file pointer. - See Also - -------- - read_version_from_project_file + Parameters + ---------- + project_file : TextIO, optional + Project file. + + Returns + ------- + str + The version read from the pyproject file. + + Notes + ----- + See also FastSurferCNN.version.read_version_from_project_file """ if project_file is None: project_file = open(DEFAULTS.PROJECT_TOML, "r") @@ -401,5 +562,6 @@ def read_and_close_version(project_file: Optional[TextIOWrapper] = None) -> str: if __name__ == "__main__": import sys + args = make_parser().parse_args() sys.exit(main(**vars(args))) diff --git a/HypVINN/README.md b/HypVINN/README.md new file mode 100644 index 00000000..66ae036d --- /dev/null +++ b/HypVINN/README.md @@ -0,0 +1,125 @@ +# Hypothalamus pipeline + +Hypothalamic subfields segmentation pipeline + +### Input +* T1w image, a T2w image, or both images. Note: Input images to the tool need to be Bias-Field corrected. + +### Requirements +* Same as FastSurfer. +* If the T1w and T2w images are available and not co-registered, FreeSurfer should be sourced to run the registration code, and the mri_coreg and mri_vol2vol binaries should also be available. + +### Model weights +* EUDAT (FZ Jülich) data repository: https://b2share.fz-juelich.de/records/2af6da63d5c1414b832c1f606bbd068a +* Zenodo data repository: https://zenodo.org/records/11184216 + +Note: These weights (version 1.1) are retrained compared to paper ([version 1.0](https://b2share.fz-juelich.de/records/27ab0a28c11741558679c819d608f1e7)) for better rotation generalization, performance is equivalent. + +### Pipeline Steps +1. Registration (optional, only required for multi-modal input) +2. Hypothalamus Segmentation + +### Running the tool +- The HypVINN output can be obtained by running the default `run_fastsurfer.sh` script (for more information see [FastSurfer documentation](../README.md)). +- HypVINN can also be run independently by running `HypVINN/run_prediction.py`, however we recommend running the whole FastSurfer pipeline as it includes all the required pre-processing steps. +- HypVINN has the following arguments: +### Input and output arguments + * `--sid ` : Subject ID, the subject data upon which to operate + * `--sd ` : Directory in which evaluation results should be written. 
+ * `--t1 ` : T1 image path + * `--t2 ` : T2 image path + * `--seg_log` : Path to file in which run logs will be saved. If not set logs will be stored in `/sd/sid/scripts/hypvinn_seg.log` +### Image processing options + * `--reg_mode` : Ignored, if no T2 image is passed. Specifies the registration method used to register T1 and T2 images. Options are 'coreg' (default) for mri_coreg, 'robust' for mri_robust_register, and 'none' to skip registration (this requires T1 and T2 are externally co-registered). + * `--qc_snap`: Activate the creation of QC snapshots of the predicted HypVINN segmentation in `/sd/sid/qc_snapshots`. The created QC snapshots are created to simplify the visual quality control process. +### FastSurfer Technical parameters (see FastSurfer documentation) + * `--device` + * `--viewgg_device` + * `--threads` + * `--batch_size` + * `--async_io` + * `--allow_root` + +### Checkpoint to load + * `--ckpt_cor ` : Coronal checkpoint to load, default = $FASTSURFER_ROOT/checkpoints/HypVINN_axial_v1.1.0.pkl + * `--ckpt_ax ` : Axial checkpoint to load, default = $FASTSURFER_ROOT/checkpoints/HypVINN_coronal_v1.1.0.pkl + * `--ckpt_sag ` : Sagittal checkpoint to load, default = $FASTSURFER_ROOT/checkpoints/HypVINN_sagittal_v1.1.0.pkl + +### CFG-file with default options for network + * `--cfg_cor ` : Coronal config file to load, default = $FASTSURFER_ROOT/HypVINN/config/HypVINN_coronal_v1.1.0.yaml + * `--cfg_ax ` : Axial config file to load, default = $FASTSURFER_ROOT/HypVINN/config/HypVINN_axial_v1.1.0.yaml + * `--cfg_sag ` : Sagittal config file to load, default = $FASTSURFER_ROOT/HypVINN/config/HypVINN_sagittal_v1.1.0.yaml + +### Usage +- The Hypothalamus pipeline can be run by using a T1 a T2 or both images. +- Is recommended that all input images are bias field corrected and when passing both T1 and T2 they need to be co-registered. +- The Hypvinn pipeline can do the registration by itself (step 1). This step can be skipped if images are already registered externally. +- Bias field-corrected images are generated by default by the FastSurfer pipeline; therefore, we recommend running the entire FastSurfer pipeline. If you already have a subject's FastSurfer output without Hypvinn output, check example 5. +1. Run HypVINN pipeline + ``` + python HypVINN/run_prediction.py --sid test_subject --sd /output \ + --t1 /data/test_subject_t1_bias_field_corrected.nii.gz \ + --t2 /data/test_subject_t2_bias_field_corrected.nii.gz \ + --reg_mode coreg \ + --batch_size 6 + ``` +2. Run HypVINN pipeline only using a t1 + ``` + python HypVINN/run_prediction.py --sid test_subject --sd /output \ + --t1 /data/test_subject_t1_bias_field_corrected.nii.gz \ + --batch_size 6 + ``` + +3. Run HypVINN pipeline without the registration step + ``` + python HypVINN/run_prediction.py --sid test_subject --sd /output \ + --t1 /data/test_subject_t1_bias_field_corrected.nii.gz \ + --t2 /data/test_subject_t2_bias_field_corrected_and_coregistered_to_t1.nii.gz \ + --reg_mode none \ + --batch_size 6 + ``` + +4. Run HypVINN pipeline with creation of qc snapshots + ``` + python HypVINN/run_prediction.py --sid test_subject --sd /output \ + --t1 /data/test_subject_t1_bias_field_corrected.nii.gz \ + --t2 /data/test_subject_t2_bias_field_corrected.nii.gz \ + --reg_mode coreg \ + --batch_size 6 --qc_snap + ``` +5. 
Run HypVINN pipeline from an existing FastSurfer subject output -- recommended when FastSurfer output is there without the Hypothalamus module + ``` + python HypVINN/run_prediction.py --sid test_subject --sd /output \ + --t1 /output/test_subject/mri/orig_nu.mgz \ + --seg_log /output/test_subject/scripts/deep-seg.log \ + --batch_size 6 --qc_snap + ``` + +### Output +``` bash +#Output Scheme +|-- output_dir + |--sid + |-- mri : MRI outputs + |--hypothalamus.HypVINN.nii.gz(Hypothalamus Segmentation) + |-- hypothalamus_mask.HypVINN.nii.gz (Hypothalamus Segmentation Mask) + |-- transforms + |-- t2tot1.lta (FreeSurfer registration file, only available if registration is performed) + |-- qc_snapshots : QC outputs (optional) + |-- hypothalamus.HypVINN_qc_screenshoot.png (Coronal quality control image) + |-- stats : Statistics outputs + |-- hypothalamus.HypVINN.stats (Segmentation stats) + ``` + + +### Developer + +Santiago Estrada : santiago.estrada@dzne.de + +### Citation +If you use the HypVINN module please cite +``` +Santiago Estrada, David Kügler, Emad Bahrami, Peng Xu, Dilshad Mousa, Monique M.B. Breteler, N. Ahmad Aziz, Martin Reuter; +FastSurfer-HypVINN: Automated sub-segmentation of the hypothalamus and adjacent structures on high-resolutional brain MRI. +Imaging Neuroscience 2023; 1 1–32. doi: https://doi.org/10.1162/imag_a_00034 +``` diff --git a/HypVINN/__init__.py b/HypVINN/__init__.py new file mode 100644 index 00000000..870c1df8 --- /dev/null +++ b/HypVINN/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "config", + "data_loader", + "models", + "utils", + "inference", + "run_prediction", +] diff --git a/HypVINN/config/HypVINN_ColorLUT.txt b/HypVINN/config/HypVINN_ColorLUT.txt new file mode 100644 index 00000000..b458e4bb --- /dev/null +++ b/HypVINN/config/HypVINN_ColorLUT.txt @@ -0,0 +1,28 @@ +#No. 
Label Name: R G B A +0 Background 0 0 0 0 +1 R-N.opticus 70 130 180 0 +2 L-N.opticus 130 180 70 0 +3 R-C.mammilare 205 62 78 0 +4 R-Optic-tract 80 120 134 0 +5 L-Optic-tract 196 58 250 0 +6 L-C.mammilare 0 148 0 0 +7 R-Chiasma-Opticum 220 248 164 0 +8 L-Chiasma-Opticum 230 148 34 0 +9 Ant-Commisure 10 180 225 0 +10 3rd-Ventricle 118 0 100 0 +11 R-Fornix 122 200 120 0 +12 L-Fornix 236 13 176 0 +13 R-Globus-Pallidus 48 255 0 0 +14 Epiphysis 204 182 142 0 +16 Hypophysis 119 159 176 0 +17 Infundibulum 220 216 20 0 +20 L-Globus-pallidus 60 58 210 0 +122 Tuberal-Region 120 60 110 0 +126 L-Med-Hypothalamus 165 255 0 0 +127 L-Lat-Hypothalamus 0 255 127 0 +128 L-Ant-Hypothalamus 165 42 42 0 +129 L-Post-Hypothalamus 255 215 0 0 +226 R-Med-Hypothalamus 115 255 0 0 +227 R-Lat-Hypothalamus 60 255 127 0 +228 R-Ant-Hypothalamus 165 142 42 0 +229 R-Post-Hypothalamus 255 170 20 0 \ No newline at end of file diff --git a/HypVINN/config/HypVINN_axial_v1.0.0.yaml b/HypVINN/config/HypVINN_axial_v1.0.0.yaml new file mode 100644 index 00000000..45f4c948 --- /dev/null +++ b/HypVINN/config/HypVINN_axial_v1.0.0.yaml @@ -0,0 +1,44 @@ +MODEL: + MODEL_NAME : "HypVinn" + NUM_CLASSES : 27 + NUM_CHANNELS : 14 + LOSS_FUNC: "combined" + KERNEL_H: 3 + KERNEL_W: 3 + BASE_RES: 1.0 + NUM_FILTERS_INTERPOL : 64 + NUM_FILTERS : 80 + OUT_TENSOR_WIDTH: 320 + OUT_TENSOR_HEIGHT: 320 + HEIGHT: 256 + WIDTH: 256 + MODE : 't1t2' + MULTI_AUTO_W : True + HETERO_INPUT : True + +DATA: + SIZES : [320] + PATH_HDF5_TRAIN : '/data/train_data_08mm/split5/train_split5_axial.hdf5' + PATH_HDF5_VAL : '' + PLANE : 'axial' + AUG: ['Scaling','Gaussian','Rotation','Translation','BiasField'] + PADDED_SIZE : 320 + REF_FRAME : -1 + VAL_REF_FRAME : 0 + +TRAIN : + BATCH_SIZE : 16 + NUM_EPOCHS : 100 + RUN_VAL : FALSE + +OPTIMIZER: + OPTIMIZING_METHOD : 'adamW' + WEIGHT_DECAY : 1e-4 + BASE_LR : 0.05 + GAMMA: 0.1 + LR_SCHEDULER : 'multiStep' + MILESTONES: [70] + MOMENTUM : 0.95 + +NUM_GPUS : 2 +LOG_DIR : "/src/FSVINN/hetero_models/split5" \ No newline at end of file diff --git a/HypVINN/config/HypVINN_axial_v1.1.0.yaml b/HypVINN/config/HypVINN_axial_v1.1.0.yaml new file mode 100644 index 00000000..8bb6af60 --- /dev/null +++ b/HypVINN/config/HypVINN_axial_v1.1.0.yaml @@ -0,0 +1,44 @@ +MODEL: + MODEL_NAME : "HypVinn" + NUM_CLASSES : 27 + NUM_CHANNELS : 14 + LOSS_FUNC: "combined" + KERNEL_H: 3 + KERNEL_W: 3 + BASE_RES: 1.0 + NUM_FILTERS_INTERPOL : 64 + NUM_FILTERS : 80 + OUT_TENSOR_WIDTH: 320 + OUT_TENSOR_HEIGHT: 320 + HEIGHT: 256 + WIDTH: 256 + MODE : 't1t2' + MULTI_AUTO_W : True + HETERO_INPUT : True + +DATA: + SIZES : [320] + PATH_HDF5_TRAIN : '/data/train_data_08mm/split5/train_split5_conf_axial.hdf5' + PATH_HDF5_VAL : '' + PLANE : 'axial' + AUG: ['Scaling','Gaussian','Rotation','Translation','BiasField'] + PADDED_SIZE : 320 + REF_FRAME : -1 + VAL_REF_FRAME : 0 + +TRAIN : + BATCH_SIZE : 16 + NUM_EPOCHS : 100 + RUN_VAL : FALSE + +OPTIMIZER: + OPTIMIZING_METHOD : 'adamW' + WEIGHT_DECAY : 1e-4 + BASE_LR : 0.05 + GAMMA: 0.1 + LR_SCHEDULER : 'multiStep' + MILESTONES: [70] + MOMENTUM : 0.95 + +NUM_GPUS : 2 +LOG_DIR : "/src/FSVINN/hetero_models/split5" \ No newline at end of file diff --git a/HypVINN/config/HypVINN_coronal_v1.0.0.yaml b/HypVINN/config/HypVINN_coronal_v1.0.0.yaml new file mode 100644 index 00000000..828caaec --- /dev/null +++ b/HypVINN/config/HypVINN_coronal_v1.0.0.yaml @@ -0,0 +1,44 @@ +MODEL: + MODEL_NAME : "HypVinn" + NUM_CLASSES : 27 + NUM_CHANNELS : 14 + LOSS_FUNC: "combined" + KERNEL_H: 3 + KERNEL_W: 3 + BASE_RES: 1.0 + NUM_FILTERS_INTERPOL : 64 
+ NUM_FILTERS : 80 + OUT_TENSOR_WIDTH: 320 + OUT_TENSOR_HEIGHT: 320 + HEIGHT: 256 + WIDTH: 256 + MODE : 't1t2' + MULTI_AUTO_W : True + HETERO_INPUT : True + +DATA: + SIZES : [320] + PATH_HDF5_TRAIN : '/data/train_data_08mm/split5/train_split5_coronal.hdf5' + PATH_HDF5_VAL : '' + PLANE : 'coronal' + AUG: ['Scaling','Gaussian','Rotation','Translation','BiasField'] + PADDED_SIZE : 320 + REF_FRAME : -1 + VAL_REF_FRAME : 0 + +TRAIN : + BATCH_SIZE : 16 + NUM_EPOCHS : 100 + RUN_VAL : FALSE + +OPTIMIZER: + OPTIMIZING_METHOD : 'adamW' + WEIGHT_DECAY : 1e-4 + BASE_LR : 0.05 + GAMMA: 0.1 + LR_SCHEDULER : 'multiStep' + MILESTONES: [70] + MOMENTUM : 0.95 + +NUM_GPUS : 2 +LOG_DIR : "/src/FSVINN/hetero_models/split5" diff --git a/HypVINN/config/HypVINN_coronal_v1.1.0.yaml b/HypVINN/config/HypVINN_coronal_v1.1.0.yaml new file mode 100644 index 00000000..19d3713c --- /dev/null +++ b/HypVINN/config/HypVINN_coronal_v1.1.0.yaml @@ -0,0 +1,44 @@ +MODEL: + MODEL_NAME : "HypVinn" + NUM_CLASSES : 27 + NUM_CHANNELS : 14 + LOSS_FUNC: "combined" + KERNEL_H: 3 + KERNEL_W: 3 + BASE_RES: 1.0 + NUM_FILTERS_INTERPOL : 64 + NUM_FILTERS : 80 + OUT_TENSOR_WIDTH: 320 + OUT_TENSOR_HEIGHT: 320 + HEIGHT: 256 + WIDTH: 256 + MODE : 't1t2' + MULTI_AUTO_W : True + HETERO_INPUT : True + +DATA: + SIZES : [320] + PATH_HDF5_TRAIN : '/data/train_data_08mm/split5/train_split5_conf_coronal.hdfs' + PATH_HDF5_VAL : '' + PLANE : 'coronal' + AUG: ['Scaling','Gaussian','Rotation','Translation','BiasField'] + PADDED_SIZE : 320 + REF_FRAME : -1 + VAL_REF_FRAME : 0 + +TRAIN : + BATCH_SIZE : 16 + NUM_EPOCHS : 100 + RUN_VAL : FALSE + +OPTIMIZER: + OPTIMIZING_METHOD : 'adamW' + WEIGHT_DECAY : 1e-4 + BASE_LR : 0.05 + GAMMA: 0.1 + LR_SCHEDULER : 'multiStep' + MILESTONES: [70] + MOMENTUM : 0.95 + +NUM_GPUS : 2 +LOG_DIR : "/src/FSVINN/hetero_models/split5" diff --git a/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml b/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml new file mode 100644 index 00000000..ee63ab31 --- /dev/null +++ b/HypVINN/config/HypVINN_sagittal_v1.0.0.yaml @@ -0,0 +1,44 @@ +MODEL: + MODEL_NAME : "HypVinn" + NUM_CLASSES : 17 + NUM_CHANNELS : 14 + LOSS_FUNC: "combined" + KERNEL_H: 3 + KERNEL_W: 3 + BASE_RES: 1.0 + NUM_FILTERS_INTERPOL : 64 + NUM_FILTERS : 80 + OUT_TENSOR_WIDTH: 320 + OUT_TENSOR_HEIGHT: 320 + HEIGHT: 256 + WIDTH: 256 + MODE : 't1t2' + MULTI_AUTO_W : True + HETERO_INPUT : True + +DATA: + SIZES : [320] + PATH_HDF5_TRAIN : '/data/train_data_08mm/split5/train_split5_sagittal.hdf5' + PATH_HDF5_VAL : '' + PLANE : 'sagittal' + AUG: ['Scaling','Gaussian','Rotation','Translation','BiasField'] + PADDED_SIZE : 320 + REF_FRAME : -1 + VAL_REF_FRAME : 0 + +TRAIN : + BATCH_SIZE : 16 + NUM_EPOCHS : 100 + RUN_VAL : FALSE + +OPTIMIZER: + OPTIMIZING_METHOD : 'adamW' + WEIGHT_DECAY : 1e-4 + BASE_LR : 0.05 + GAMMA: 0.1 + LR_SCHEDULER : 'multiStep' + MILESTONES: [70] + MOMENTUM : 0.95 + +NUM_GPUS : 2 +LOG_DIR : "/src/FSVINN/hetero_models/split5" diff --git a/HypVINN/config/HypVINN_sagittal_v1.1.0.yaml b/HypVINN/config/HypVINN_sagittal_v1.1.0.yaml new file mode 100644 index 00000000..ac8baaba --- /dev/null +++ b/HypVINN/config/HypVINN_sagittal_v1.1.0.yaml @@ -0,0 +1,44 @@ +MODEL: + MODEL_NAME : "HypVinn" + NUM_CLASSES : 17 + NUM_CHANNELS : 14 + LOSS_FUNC: "combined" + KERNEL_H: 3 + KERNEL_W: 3 + BASE_RES: 1.0 + NUM_FILTERS_INTERPOL : 64 + NUM_FILTERS : 80 + OUT_TENSOR_WIDTH: 320 + OUT_TENSOR_HEIGHT: 320 + HEIGHT: 256 + WIDTH: 256 + MODE : 't1t2' + MULTI_AUTO_W : True + HETERO_INPUT : True + +DATA: + SIZES : [320] + PATH_HDF5_TRAIN : 
'/data/train_data_08mm/split5/train_split5_conf_sagittal.hdfs' + PATH_HDF5_VAL : '' + PLANE : 'sagittal' + AUG: ['Scaling','Gaussian','Rotation','Translation','BiasField'] + PADDED_SIZE : 320 + REF_FRAME : -1 + VAL_REF_FRAME : 0 + +TRAIN : + BATCH_SIZE : 16 + NUM_EPOCHS : 100 + RUN_VAL : FALSE + +OPTIMIZER: + OPTIMIZING_METHOD : 'adamW' + WEIGHT_DECAY : 1e-4 + BASE_LR : 0.05 + GAMMA: 0.1 + LR_SCHEDULER : 'multiStep' + MILESTONES: [70] + MOMENTUM : 0.95 + +NUM_GPUS : 2 +LOG_DIR : "/src/FSVINN/hetero_models/split5" diff --git a/HypVINN/config/__init__.py b/HypVINN/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/HypVINN/config/checkpoint_paths.yaml b/HypVINN/config/checkpoint_paths.yaml new file mode 100644 index 00000000..7071534b --- /dev/null +++ b/HypVINN/config/checkpoint_paths.yaml @@ -0,0 +1,13 @@ +url: +- "https://zenodo.org/records/11184216/files" +- "https://b2share.fz-juelich.de/api/files/d9e37247-5455-4c83-853d-21e31fb5bea5" + +checkpoint: + axial: "checkpoints/HypVINN_axial_v1.1.0.pkl" + coronal: "checkpoints/HypVINN_coronal_v1.1.0.pkl" + sagittal: "checkpoints/HypVINN_sagittal_v1.1.0.pkl" + +config: + axial: "HypVINN/config/HypVINN_axial_v1.1.0.yaml" + coronal: "HypVINN/config/HypVINN_coronal_v1.1.0.yaml" + sagittal: "HypVINN/config/HypVINN_sagittal_v1.1.0.yaml" \ No newline at end of file diff --git a/HypVINN/config/hypvinn.py b/HypVINN/config/hypvinn.py new file mode 100644 index 00000000..8df5807e --- /dev/null +++ b/HypVINN/config/hypvinn.py @@ -0,0 +1,271 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
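The plane-specific YAML files above override only a subset of the defaults defined in `HypVINN/config/hypvinn.py` below. As a rough, hypothetical usage sketch (not part of this patch), a plane config is typically assembled by cloning the defaults and merging the YAML on top via the standard yacs API:

```python
# Hypothetical sketch: build the coronal HypVINN config from defaults + YAML.
# Assumes a checked-out FastSurfer tree with this module on the PYTHONPATH.
from HypVINN.config.hypvinn import get_cfg_hypvinn

cfg = get_cfg_hypvinn()  # clone of the default CfgNode defined below
cfg.merge_from_file("HypVINN/config/HypVINN_coronal_v1.1.0.yaml")
cfg.freeze()  # lock the config before handing it to inference code

print(cfg.DATA.PLANE, cfg.MODEL.NUM_CLASSES)  # coronal 27
```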
+ +from yacs.config import CfgNode as CN + +_C = CN() + +# ---------------------------------------------------------------------------- # +# Model options +# ---------------------------------------------------------------------------- # +_C.MODEL = CN() + +# Name of model +_C.MODEL.MODEL_NAME = "" + +#modalities 't1', 't2' or 't1t2' +_C.MODEL.MODE = "t1" + +# Number of classes to predict, including background +_C.MODEL.NUM_CLASSES = 79 + +# Loss function, combined = dice loss + cross entropy, combined2 = dice loss + boundary loss +_C.MODEL.LOSS_FUNC = "combined" + +# Filter dimensions for DenseNet (all layers same) +_C.MODEL.NUM_FILTERS = 64 + +# Filter dimensions for Input Interpolation block (currently all the same) +_C.MODEL.NUM_FILTERS_INTERPOL = 32 + +# Number of input channels (slice thickness) +_C.MODEL.NUM_CHANNELS = 7 + +# Number of branches for attention mechanism +_C.MODEL.NUM_BRANCHES = 5 + +# Height of convolution kernels +_C.MODEL.KERNEL_H = 5 + +# Width of convolution kernels +_C.MODEL.KERNEL_W = 5 + +# size of Classifier kernel +_C.MODEL.KERNEL_C = 1 + +# Stride during convolution +_C.MODEL.STRIDE_CONV = 1 + +# Stride during pooling +_C.MODEL.STRIDE_POOL = 2 + +# Size of pooling filter +_C.MODEL.POOL = 2 + +# The height of segmentation model (after interpolation layer) +_C.MODEL.HEIGHT = 256 + +# The width of segmentation model +_C.MODEL.WIDTH = 256 + +# The base resolution of the segmentation model (after interpolation layer) +_C.MODEL.BASE_RES = 1.0 + +# Interpolation mode for up/downsampling in Flex networks +_C.MODEL.INTERPOLATION_MODE = "bilinear" + +# Crop positions for up/downsampling in Flex networks +_C.MODEL.CROP_POSITION = "top_left" + +# Out Tensor dimensions for interpolation layer +_C.MODEL.OUT_TENSOR_WIDTH = 320 +_C.MODEL.OUT_TENSOR_HEIGHT = 320 + +# Flag, for smoothing testing (double number of feature maps before/after interpolation block) +_C.MODEL.SMOOTH = False + +# Options for attention +_C.MODEL.ATTENTION_BASE = False +_C.MODEL.ATTENTION_INPUT = False +_C.MODEL.ATTENTION_OUTPUT = False + +# Options for addition instead of Maxout +_C.MODEL.ADDITION = False + +#Options for multi modalitie +_C.MODEL.MULTI_AUTO_W = False # weight per modalitiy +_C.MODEL.MULTI_AUTO_W_CHANNELS = False #weight per channel +# Flag, for smoothing testing (double number of feature maps before the input interpolation block) +_C.MODEL.MULTI_SMOOTH = False +# Brach weights can be aleatory set to zero +_C.MODEL.HETERO_INPUT = False +# Flag for replicating any given modality into the two branches. This branch require that the hetero_input also set to TRUE +_C.MODEL.DUPLICATE_INPUT = False +# ---------------------------------------------------------------------------- # +# Training options +# ---------------------------------------------------------------------------- # +_C.TRAIN = CN() + +# input batch size for training +_C.TRAIN.BATCH_SIZE = 16 + +# how many batches to wait before logging training status +_C.TRAIN.LOG_INTERVAL = 50 + +# Resume training from the latest checkpoint in the output directory. 
+_C.TRAIN.RESUME = False + +# The experiment number to resume from +_C.TRAIN.RESUME_EXPR_NUM = 1 + +# number of epochs to train +_C.TRAIN.NUM_EPOCHS = 30 + +# number of steps (iteration) which depends on dataset +_C.TRAIN.NUM_STEPS = 10 + +# To fine tune model or not +_C.TRAIN.FINE_TUNE = False + +# checkpoint period +_C.TRAIN.CHECKPOINT_PERIOD = 2 + +# number of worker for dataloader +_C.TRAIN.NUM_WORKERS = 8 + +# run validation +_C.TRAIN.RUN_VAL = True + +# ---------------------------------------------------------------------------- # +# Testing options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() + +# input batch size for testing +_C.TEST.BATCH_SIZE = 16 + +# ---------------------------------------------------------------------------- # +# Data options +# ---------------------------------------------------------------------------- # + +_C.DATA = CN() + +# path to training hdf5-dataset +_C.DATA.PATH_HDF5_TRAIN = "" + +# path to validation hdf5-dataset +_C.DATA.PATH_HDF5_VAL = "" + +# The plane to load ['axial', 'coronal', 'sagittal'] +_C.DATA.PLANE = "coronal" + +# Reference volume frame during training, -1 value randomly select the frame : input data B,H,W,C,FRAME +_C.DATA.REF_FRAME = 0 + +# Reference volume frame during validation : input data B,H,W,C,FRAME +_C.DATA.VAL_REF_FRAME = 0 + +# Available size for dataloader +# This for the multi-scale dataloader +_C.DATA.SIZES = [256, 311, 320] + +# the size that all inputs are padded to +_C.DATA.PADDED_SIZE = 320 + +# classes to consider in the Boundary loss (default: all -> 79) +_C.DATA.BOUNDARY_CLASSES = "None" + +# Augmentations +_C.DATA.AUG = ["Flip", "Elastic", "Scaling", "Rotation", "Translation", "RAnisotropy", "BiasField", "RGamma"] + +#Frequency of the hetero augmentations [both t1 and t2,only t1, only t2 ] +_C.DATA.HETERO_FREQ = [0.5, 0.25, 0.25] + +# ---------------------------------------------------------------------------- # +# DataLoader options (common for test and train) +# ---------------------------------------------------------------------------- # +_C.DATA_LOADER = CN() + +# the split number in cross validation +_C.DATA.SPLIT_NUM = 1 +# Number of data loader workers +_C.DATA_LOADER.NUM_WORKERS = 8 + +# Load data to pinned host memory. +_C.DATA_LOADER.PIN_MEMORY = True + +# ---------------------------------------------------------------------------- # +# Optimizer options +# ---------------------------------------------------------------------------- # +_C.OPTIMIZER = CN() + +# Base learning rate. 
+_C.OPTIMIZER.BASE_LR = 0.01 + +# Learning rate scheduler, step_lr, cosineWarmRestarts +_C.OPTIMIZER.LR_SCHEDULER = "step_lr" + +# Multiplicative factor of learning rate decay in step_lr +_C.OPTIMIZER.GAMMA = 0.3 + +# Period of learning rate decay in step_lr +_C.OPTIMIZER.STEP_SIZE = 5 + +# minimum learning in cosine lr policy +_C.OPTIMIZER.ETA_MIN = 0.0001 + +# number of iterations for the first restart in cosineWarmRestarts +_C.OPTIMIZER.T_ZERO = 10 + +# A factor increases T_i after a restart in cosineWarmRestarts +_C.OPTIMIZER.T_MULT = 2 + +# MultiStep lr scheduler params ----------------------------- +_C.OPTIMIZER.MILESTONES = [20, 40] + +# Momentum +_C.OPTIMIZER.MOMENTUM = 0.9 + +# Momentum dampening +_C.OPTIMIZER.DAMPENING = 0.0 + +# Nesterov momentum +_C.OPTIMIZER.NESTEROV = True + +# L2 regularization +_C.OPTIMIZER.WEIGHT_DECAY = 1e-4 + +# Optimization method [sgd, adam] +_C.OPTIMIZER.OPTIMIZING_METHOD = "adam" + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # + +# Number of GPUs to use +_C.NUM_GPUS = 1 + +# log directory for run +_C.LOG_DIR = "./experiments" + +# experiment number +_C.EXPR_NUM = "Default" + +# Note that non-determinism may still be present due to non-deterministic +# operator implementations in GPU operator libraries. +_C.RNG_SEED = 1 + + +def get_cfg_hypvinn(): + """ + Get a yacs CfgNode object with default values for HypVINN project. + + Returns + ------- + _C : yacs.config.CfgNode + A clone of the default configuration node for the HypVINN project. + """ + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() \ No newline at end of file diff --git a/HypVINN/config/hypvinn_files.py b/HypVINN/config/hypvinn_files.py new file mode 100644 index 00000000..db444260 --- /dev/null +++ b/HypVINN/config/hypvinn_files.py @@ -0,0 +1,27 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# IMPORTS +from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT + + +HYPVINN_LUT = FASTSURFER_ROOT / "HypVINN/config/HypVINN_ColorLUT.txt" + +HYPVINN_STATS_NAME = "hypothalamus.HypVINN.stats" + +HYPVINN_MASK_NAME = "hypothalamus_mask.HypVINN.nii.gz" + +HYPVINN_SEG_NAME = "hypothalamus.HypVINN.nii.gz" + +HYPVINN_QC_IMAGE_NAME = "hypothalamus.HypVINN_qc_screenshoot.png" diff --git a/HypVINN/config/hypvinn_global_var.py b/HypVINN/config/hypvinn_global_var.py new file mode 100644 index 00000000..04695653 --- /dev/null +++ b/HypVINN/config/hypvinn_global_var.py @@ -0,0 +1,133 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import numpy as np + +Plane = Literal["axial", "coronal", "sagittal"] + + +HYPVINN_CLASS_NAMES = { + 0: "Background", + + 1: "R-N.opticus", + 2: "L-N.opticus", + + 3: "R-C.mammilare", + 6: "L-C.mammilare", + + 4: "R-Optic-tract", + 5: "L-Optic-tract", + + 7: "R-Chiasma-Opticum", + 8: "L-Chiasma-Opticum", + + 9: "Ant-Commisure", + 10: "Third-Ventricle", + + 11: "R-Fornix", + 12: "L-Fornix", + + 14: "Epiphysis", + 16: "Hypophysis", + 17: "Infundibulum", + + 13: "R-Globus-Pallidus", + 20: "L-Globus-pallidus", + + 122: "Tuberal-Region", + + 126: "L-Med-Hypothalamus", + 226: "R-Med-Hypothalamus", + + 127: "L-Lat-Hypothalamus", + 227: "R-Lat-Hypothalamus", + + 128: "L-Ant-Hypothalamus", + 228: "R-Ant-Hypothalamus", + + 129: "L-Post-Hypothalamus", + 229: "R-Post-Hypothalamus", +} + +FS_CLASS_NAMES = { + "Background": 0, + + "R-N.opticus": 961, + "L-N.opticus": 962, + "R-C.mammilare": 963, + "R-Optic-tract": 964, + "L-Optic-tract": 965, + "L-C.mammilare": 966, + "R-Chiasma-Opticum": 967, + "L-Chiasma-Opticum": 968, + "Ant-Commisure": 969, + "Third-Ventricle": 970, + "R-Fornix": 971, + "L-Fornix": 972, + "Epiphysis": 973, + "Hypophysis": 974, + "Infundibulum": 975, + "Tuberal-Region": 976, + "L-Med-Hypothalamus": 977, + "L-Lat-Hypothalamus": 978, + "L-Ant-Hypothalamus": 979, + "L-Post-Hypothalamus": 980, + "R-Med-Hypothalamus": 981, + "R-Lat-Hypothalamus": 982, + "R-Ant-Hypothalamus": 983, + "R-Post-Hypothalamus": 984, + #excluded ids + "R-Globus-Pallidus": 985, + "L-Globus-pallidus": 986, +} + +planes = ("axial", "coronal", "sagittal") + +hyposubseg_labels = ( + np.array(list(HYPVINN_CLASS_NAMES.keys())), + np.array([0, 1, 3, 4, 7, 9, 10, 11, 14, 16, 17, 13, 122, 226, 227, 228, 229]), +) + +SAG2FULL_MAP = { + # lbl: sag_lbl_index + 0: 0, + 1: 1, + 2: 1, + 3: 2, + 6: 2, + 4: 3, + 5: 3, + 7: 4, + 8: 4, + 9: 5, + 10: 6, + 11: 7, + 12: 7, + 14: 8, + 16: 9, + 17: 10, + 13: 11, + 20: 11, + 122: 12, + 126: 13, + 226: 13, + 127: 14, + 227: 14, + 128: 15, + 228: 15, + 129: 16, + 229: 16 + } diff --git a/HypVINN/data_loader/__init__.py b/HypVINN/data_loader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/HypVINN/data_loader/data_utils.py b/HypVINN/data_loader/data_utils.py new file mode 100644 index 00000000..c7caee71 --- /dev/null +++ b/HypVINN/data_loader/data_utils.py @@ -0,0 +1,271 @@ +# Copyright 2024 +# AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
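The label tables in `hypvinn_global_var.py` above are consumed as plain lookup tables. A minimal, self-contained sketch of that pattern (toy label IDs, not FastSurfer data), mirroring `hypo_map_label2subseg` defined later in this file:

```python
import numpy as np

# Toy lookup-table remapping: class indices predicted by the network are
# replaced by their label IDs via fancy indexing, as hypo_map_label2subseg does.
labels = np.array([0, 1, 3, 4])       # hypothetical label IDs
mapped = np.array([[0, 1], [2, 3]])   # network output in index space
subseg = labels[mapped.ravel()].reshape(mapped.shape)
print(subseg)  # [[0 1]
               #  [3 4]]
```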
+ +# IMPORTS +import nibabel as nib +import numpy as np +from numpy import typing as npt + +from FastSurferCNN.data_loader.conform import getscale, scalecrop +from HypVINN.config.hypvinn_global_var import ( + hyposubseg_labels, SAG2FULL_MAP, HYPVINN_CLASS_NAMES, FS_CLASS_NAMES, +) + + +## +# Helper Functions +## + + +def calculate_flip_orientation(iornt: np.ndarray, base_ornt: np.ndarray) -> np.ndarray: + """ + Compute the flip orientation transform. + + ornt[N, 1] is flip of axis N, where 1 means no flip and -1 means flip. + + Parameters + ---------- + iornt : np.ndarray + Initial orientation. + base_ornt : np.ndarray + Base orientation. + + Returns + ------- + new_iornt : np.ndarray + New orientation. + """ + new_iornt = iornt.copy() + + # Find the axis to compared and then compared orientation, where 1 means no flip + # and -1 means flip. + for axno, direction in np.asarray(base_ornt): + idx = np.where(iornt[:, 0] == axno) + idirection = iornt[int(idx[0][0]), 1] + if direction == idirection: + new_iornt[int(idx[0][0]), 1] = 1.0 + else: + new_iornt[int(idx[0][0]), 1] = -1.0 + + return new_iornt + + +def reorient_img(img, ref_img): + """ + Reorient a Nibabel image based on the orientation of a reference nibabel image. + + Parameters + ---------- + img : nibabel.Nifti1Image + Nibabel Image to reorient. + ref_img : nibabel.Nifti1Image + Reference orientation nibabel image. + + Returns + ------- + img : nibabel.Nifti1Image + Reoriented image. + """ + ref_ornt = nib.io_orientation(ref_img.affine) + iornt = nib.io_orientation(img.affine) + + if not np.array_equal(iornt, ref_ornt): + # first flip orientation + fornt = calculate_flip_orientation(iornt, ref_ornt) + img = img.as_reoriented(fornt) + # the transpose axis + tornt = np.ones_like(ref_ornt) + tornt[:, 0] = ref_ornt[:, 0] + img = img.as_reoriented(tornt) + + return img + + +def transform_axial2coronal(vol: np.ndarray, axial2coronal: bool = True) -> np.ndarray: + """ + Transforms a volume into the coronal axis and back. + + This function is used to transform a volume into the coronal axis and back. The + transformation is done by moving the axes of the volume. If the `axial2coronal` + parameter is set to True, the function will transform from axial to coronal. If it + is set to False, the function will transform from coronal to axial. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + axial2coronal : bool, optional + A flag to determine the direction of the transformation. If True, transform from + axial to coronal. If False, transform from coronal to axial. (Default: True). + + Returns + ------- + np.ndarray + The transformed volume. + """ + # TODO check compatibility with axis transform from CerebNet + if axial2coronal: + return np.moveaxis(vol, [0, 1, 2], [0, 2, 1]) + else: + return np.moveaxis(vol, [0, 1, 2], [0, 2, 1]) + + +def transform_axial2sagittal(vol: np.ndarray, + axial2sagittal: bool = True) -> np.ndarray: + """ + Transforms a volume into the sagittal axis and back. + + This function is used to transform a volume into the sagittal axis and back. The + transformation is done by moving the axes of the volume. If the `axial2sagittal` + parameter is set to True, the function will transform from axial to sagittal. If it + is set to False, the function will transform from sagittal to axial. + + Parameters + ---------- + vol : np.ndarray + The image volume to transform. + axial2sagittal : bool, default=True + A flag to determine the direction of the transformation. 
If True, transform from + axial to sagittal. If False, transform from sagittal to axial. (Default: True). + + Returns + ------- + np.ndarray + The transformed volume. + """ + # TODO check compatibility with axis transform from CerebNet + if axial2sagittal: + return np.moveaxis(vol, [0, 1, 2], [2, 0, 1]) + else: + return np.moveaxis(vol, [0, 1, 2], [1, 2, 0]) + + +def rescale_image(img_data: np.ndarray) -> np.ndarray: + """ + Rescale the image data to the range [0, 255]. + + This function rescales the input image data to the range [0, 255]. + + Parameters + ---------- + img_data : np.ndarray + The image data to rescale. + + Returns + ------- + np.ndarray + The rescaled image data. + """ + # Conform intensities + # TODO move function to FastSurferCNN similar: CerebNet.datasets.utils.rescale_image + src_min, scale = getscale(img_data, 0, 255) + mapped_data = img_data + + # this used to rescale, if the image was not uint8 and any intensity was > 255 + if not np.allclose([src_min, scale], [0, 1]): + mapped_data = scalecrop(img_data, 0, 255, src_min, scale) + + return np.uint8(np.rint(mapped_data)) + + +def hypo_map_label2subseg(mapped_subseg: npt.NDArray[int]) -> npt.NDArray[int]: + """ + Perform look-up table mapping from label space to subseg space. + + This function is used to perform a look-up table mapping from label space to subseg + space. + + Parameters + ---------- + mapped_subseg : npt.NDArray[int] + The input array in label space to be mapped to subseg space. + + Returns + ------- + npt.NDArray[int] + The mapped array in subseg space. + """ + # TODO can this function be replaced by a Mapper and a mapping file? + labels, _ = hyposubseg_labels + subseg = np.zeros_like(mapped_subseg) + h, w, d = subseg.shape + subseg = labels[mapped_subseg.ravel()] + + return subseg.reshape((h, w, d)) + + +def hypo_map_prediction_sagittal2full( + prediction_sag: npt.NDArray[int], +) -> npt.NDArray[int]: + """ + Remap the prediction on the sagittal network to full label space. + + This function is used to remap the prediction on the sagittal network to the full + label space used by the coronal and axial networks. + + Parameters + ---------- + prediction_sag : npt.NDArray[int] + The sagittal prediction in label space to be remapped to full label space. + + Returns + ------- + npt.NDArray[int] + The remapped prediction in full label space. + """ + # TODO can this function be replaced by a Mapper and a mapping file? + + idx_list = list(SAG2FULL_MAP.values()) + prediction_full = prediction_sag[:, idx_list, :, :] + return prediction_full + + +def hypo_map_subseg_2_fsseg( + subseg: npt.NDArray[int], + reverse: bool = False, +) -> npt.NDArray[int]: + """ + Remap HypVINN internal labels to FastSurfer Labels and vice versa. + + This function is used to remap HypVINN internal labels to FastSurfer Labels and vice + versa. If the `reverse` parameter is set to False, the function will map HypVINN + labels to FastSurfer labels. If it is set to True, the function will map FastSurfer + labels to HypVINN labels. + + Parameters + ---------- + subseg : npt.NDArray[int] + The input array with HypVINN or FastSurfer labels to be remapped. + reverse : bool, default=False + A flag to determine the direction of the remapping. If False, remap HypVINN + labels to FastSurfer labels. If True, remap FastSurfer labels to HypVINN labels. + + Returns + ------- + npt.NDArray[int] + The remapped array with FastSurfer or HypVINN labels. + """ + # TODO can this function be replaced by a Mapper and a mapping file? 
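+    # Example (values taken from HYPVINN_CLASS_NAMES / FS_CLASS_NAMES above):
+    # HypVINN label 1 ("R-N.opticus") becomes FreeSurfer label 961, and with
+    # reverse=True, FreeSurfer label 961 is mapped back to HypVINN label 1.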
+ + fsseg = np.zeros_like(subseg, dtype=np.int16) + + if not reverse: + for value, name in HYPVINN_CLASS_NAMES.items(): + fsseg[subseg == value] = FS_CLASS_NAMES[name] + else: + reverse_hypvinn = dict(map(reversed, HYPVINN_CLASS_NAMES.items())) + for name, value in FS_CLASS_NAMES.items(): + fsseg[subseg == value] = reverse_hypvinn[name] + return fsseg diff --git a/HypVINN/data_loader/dataset.py b/HypVINN/data_loader/dataset.py new file mode 100644 index 00000000..776ddb56 --- /dev/null +++ b/HypVINN/data_loader/dataset.py @@ -0,0 +1,227 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from numpy import typing as npt +import torch +from torch.utils.data import Dataset + + +from HypVINN.data_loader.data_utils import transform_axial2sagittal,transform_axial2coronal +from FastSurferCNN.data_loader.data_utils import get_thick_slices + +import FastSurferCNN.utils.logging as logging +from HypVINN.utils import ModalityDict, ModalityMode + +logger = logging.get_logger(__name__) + + +# Operator to load imaged for inference +class HypVINNDataset(Dataset): + """ + Class to load MRI-Image and process it to correct format for HypVINN network inference. + + The HypVINN Dataset passed during Inference the input images,the scale factor for the VINN layer and a weight factor + (wT1,wT2). + The Weight factor determines the running mode of the HypVINN model. + if wT1 =1 and wT2 =0. The HypVINN model will only allow the flow of the T1 information (mode = t1). + if wT1 =0 and wT2 =1. The HypVINN model will only allow the flow of the T2 information (mode = t2). + if wT1 !=1 and wT2 !=1. The HypVINN model will automatically weigh the T1 information and the T2 information based + on the learned modality weights (mode = t1t2). + + Methods + ------- + _standarized_img(orig_data: np.ndarray, orig_zoom: npt.NDArray[float], modality: np.ndarray) -> np.ndarray + Standardize the image based on the original data, original zoom, and modality. + _get_scale_factor() -> npt.NDArray[float] + Get the scaling factor to match the original resolution of the input image to the final resolution of the + FastSurfer base network. + __getitem__(index: int) -> dict[str, torch.Tensor | np.ndarray] + Retrieve the image, scale factor, and weight factor for a given index. + __len__() + Return the number of images in the dataset. + """ + def __init__( + self, + subject_name: str, + modalities: ModalityDict, + orig_zoom: npt.NDArray[float], + cfg, + mode: ModalityMode = "t1t2", + transforms=None, + ): + """ + Initialize the HypVINN Dataset. + + Parameters + ---------- + subject_name : str + The name of the subject. + modalities : ModalityDict + The modalities of the subject. + orig_zoom : npt.NDArray[float] + The original zoom of the subject. + cfg : CfgNode + The configuration object. + mode : ModalityMode, default="t1t2" + The running mode of the HypVINN model. (Default: "t1t2"). 
+ transforms : Callable, optional + The transformations to apply to the images. (Default: None). + + """ + self.subject_name = subject_name + self.plane = cfg.DATA.PLANE + #Inference Mode + self.mode = mode + #set thickness base on train paramters + if cfg.MODEL.MODE in ["t1", "t2"]: + self.slice_thickness = cfg.MODEL.NUM_CHANNELS//2 + else: + self.slice_thickness = cfg.MODEL.NUM_CHANNELS//4 + + self.base_res = cfg.MODEL.BASE_RES + + if self.mode == "t1": + orig_thick = self._standarized_img(modalities["t1"], orig_zoom, modality="t1") + orig_thick = np.concatenate((orig_thick, orig_thick), axis=-1) + self.weight_factor = torch.from_numpy(np.asarray([1.0, 0.0])) + + elif self.mode == "t2": + orig_thick = self._standarized_img(modalities["t2"], orig_zoom, modality="t2") + orig_thick = np.concatenate((orig_thick, orig_thick), axis=-1) + self.weight_factor = torch.from_numpy(np.asarray([0.0, 1.0])) + else: + t1_orig_thick = self._standarized_img(modalities["t1"], orig_zoom, modality="t1") + t2_orig_thick = self._standarized_img(modalities["t2"], orig_zoom, modality="t2") + orig_thick = np.concatenate((t1_orig_thick, t2_orig_thick), axis=-1) + self.weight_factor = torch.from_numpy(np.asarray([0.5, 0.5])) + + # Transpose from W,H,N,C to N,W,H,C + orig_thick = np.transpose(orig_thick, (2, 0, 1, 3)) + self.images = orig_thick + self.count = self.images.shape[0] + self.transforms = transforms + + logger.info( + f"Successfully loaded Image from {subject_name} for {self.plane} " + f"model" + ) + + if ((cfg.MODEL.MULTI_AUTO_W or cfg.MODEL.MULTI_AUTO_W_CHANNELS) and + (self.mode == 't1t2' or cfg.MODEL.DUPLICATE_INPUT)) : + logger.info( + f"For inference T1 block weight and the T2 block are set to " + f"the weights learn during training" + ) + else: + logger.info( + f"For inference T1 block weight was set to: " + f"{self.weight_factor.numpy()[0]} and the T2 block was set to: " + f"{self.weight_factor.numpy()[1]}") + + def _standarized_img(self, orig_data: np.ndarray, orig_zoom: npt.NDArray[float], + modality: np.ndarray) -> np.ndarray: + """ + Standardize the image based on the original data, original zoom, and modality. + + Parameters + ---------- + orig_data : np.ndarray + The original data of the image. + orig_zoom : npt.NDArray[float] + The original zoom of the image. + modality : np.ndarray + The modality of the image. + + Returns + ------- + orig_thick : np.ndarray + The standardized image. + """ + if self.plane == "sagittal": + orig_data = transform_axial2sagittal(orig_data) + self.zoom = orig_zoom[::-1][:2] + logger.info( + f"Loading {modality} sagittal with input voxelsize {self.zoom}" + ) + + elif self.plane == "coronal": + orig_data = transform_axial2coronal(orig_data) + self.zoom = orig_zoom[1:] + logger.info( + f"Loading {modality} coronal with input voxelsize {self.zoom}" + ) + + else: + self.zoom = orig_zoom[:2] + logger.info( + f"Loading {modality} axial with input voxelsize {self.zoom}" + ) + + # Create thick slices + orig_thick = get_thick_slices(orig_data, self.slice_thickness) + + return orig_thick + + def _get_scale_factor(self) -> npt.NDArray[float]: + """ + Get the scaling factor to match the original resolution of the input image to + the final resolution of the FastSurfer base network. The input resolution is + taken from the voxel size in the image header. + + Returns + ------- + scale : npt.NDArray[float] + The scaling factor along the x and y dimensions. This is a numpy array of float values. 
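+        For example, with ``BASE_RES = 1.0`` and an in-plane voxel size of
+        0.8 mm x 0.8 mm, the returned factor is ``[1.25, 1.25]`` (illustrative
+        values only).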
+ """ + # TODO: This needs to be updated based on the plane we are looking at in case we + # are dealing with non-isotropic images as inputs. + + scale = self.base_res / np.asarray(self.zoom) + + return scale + + def __getitem__(self, index: int) -> dict[str, torch.Tensor | np.ndarray]: + """ + Retrieve the image, scale factor, and weight factor for a given index. + + This method retrieves the image at the given index from the images attribute, calculates the scale factor, + applies any transformations to the image if they are defined, and returns a dictionary containing the image, + scale factor, and weight factor. + + Parameters + ---------- + index : int + The index of the image to retrieve. + + Returns + ------- + dict[str, torch.Tensor | np.ndarray] + A dictionary containing the image, scale factor, and weight factor. + """ + img = self.images[index] + + scale_factor = self._get_scale_factor() + if self.transforms is not None: + img = self.transforms(img) + + return { + "image": img, + "scale_factor": scale_factor, + "weight_factor": self.weight_factor, + } + + def __len__(self): + return self.count + diff --git a/HypVINN/inference.py b/HypVINN/inference.py new file mode 100644 index 00000000..07953dae --- /dev/null +++ b/HypVINN/inference.py @@ -0,0 +1,415 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from time import time +from typing import Optional + +import torch +import numpy as np +import yacs.config +from tqdm import tqdm +from torch.utils.data import DataLoader +from torchvision import transforms + +import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils.common import find_device +from FastSurferCNN.data_loader.augmentation import ToTensorTest, ZeroPad2DTest +from HypVINN.models.networks import build_model +from HypVINN.data_loader.data_utils import hypo_map_prediction_sagittal2full +from HypVINN.data_loader.dataset import HypVINNDataset +from HypVINN.utils import ModalityMode + +logger = logging.get_logger(__name__) + + +class Inference: + """ + Class for running inference on a single subject. + + Attributes + ---------- + model : torch.nn.Module + The model to use for inference. + model_name : str + The name of the model. + + Methods + ------- + setup_model(cfg) + Set up the model. + """ + def __init__( + self, + cfg, + threads: int = -1, + async_io: bool = False, + device: str = "auto", + viewagg_device: str = "auto", + ): + """ + Initialize the Inference class. + + This method initializes the Inference class with the provided configuration, number of threads, async IO flag, + device, and view aggregation device. It sets the random seed, switches on denormal flushing, defines the device, + and sets up the initial model. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node containing the parameters for the model. + threads : int, optional + The number of threads to use. Default is -1, which uses all available threads. 
+ async_io : bool, optional + Whether to use asynchronous IO. Default is False. + device : str, optional + The device to use for computations. Can be 'auto', 'cpu', or 'cuda'. Default is 'auto'. + viewagg_device : str, optional + The device to use for view aggregation. Can be 'auto', 'cpu', or 'cuda'. Default is 'auto'. + """ + self._threads = threads + torch.set_num_threads(self._threads) + self._async_io = async_io + + # Set random seed from configs. + np.random.seed(cfg.RNG_SEED) + torch.manual_seed(cfg.RNG_SEED) + self.cfg = cfg + + # Switch on denormal flushing for faster CPU processing + # seems to have less of an effect on VINN than old CNN + torch.set_flush_denormal(True) + + # Define device and transfer model + self.device = find_device(device) + + if self.device.type == "cpu" and viewagg_device == "auto": + self.viewagg_device = self.device + else: + # check, if GPU is big enough to run view agg on it + # (this currently takes the memory of the passed device) + self.viewagg_device = torch.device( + find_device( + viewagg_device, + flag_name="viewagg_device", + min_memory=4 * (2 ** 30), + ) + ) + + logger.info(f"Running view aggregation on {self.viewagg_device}") + + # Initial model setup + self.model = self.setup_model(cfg) + self.model_name = self.cfg.MODEL.MODEL_NAME + + def setup_model( + self, + cfg: Optional["yacs.config.CfgNode"] = None, + ) -> torch.nn.Module: + """ + Set up the model. + + This method sets up the model for inference. + + Parameters + ---------- + cfg : yacs.config.CfgNode, optional + The configuration node containing the parameters for the model. + + Returns + ------- + model : torch.nn.Module + The model set up for inference. + """ + if cfg is not None: + self.cfg = cfg + + # Set up model + model = build_model(self.cfg) # + model.to(self.device) + + return model + + def set_cfg(self, cfg): + """ + Set the configuration node. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node containing the parameters for the model. + """ + self.cfg = cfg + + def set_model(self, cfg: yacs.config.CfgNode = None): + """ + Set the model for the Inference instance. + + Parameters + ---------- + cfg : yacs.config.CfgNode, optional + The configuration node containing the parameters for the model. (Default = None). + """ + if cfg is not None: + self.cfg = cfg + + # Set up model + model = build_model(self.cfg) + model.to(self.device) + self.model = model + + def load_checkpoint(self, ckpt: str): + """ + Load a model checkpoint. + + This method loads a model checkpoint from a .pth file containing a state dictionary of a model. + + Parameters + ---------- + ckpt : str + The path to the checkpoint file. The checkpoint file should be a .pth file containing a state dictionary + of a model. + """ + logger.info("Loading checkpoint {}".format(ckpt)) + # WARNING: weights_only=False can cause unsafe code execution, but here the + # checkpoint can be considered to be from a safe source + model_state = torch.load(ckpt, map_location=self.device, weights_only=False) + self.model.load_state_dict(model_state["model_state"]) + + def get_modelname(self): + """ + Get the name of the model. + + This method returns the name of the model used in the Inference instance. + + Returns + ------- + str + The name of the model. + """ + return self.model_name + + def get_cfg(self): + """ + Get the configuration node. + + This method returns the configuration node used in the Inference instance. 
+ + Returns + ------- + yacs.config.CfgNode + The configuration node containing the parameters for the model. + """ + return self.cfg + + def get_num_classes(self): + """ + Get the number of classes. + + This method returns the number of classes defined in the model configuration. + + Returns + ------- + int + The number of classes. + """ + return self.cfg.MODEL.NUM_CLASSES + + def get_plane(self): + """ + Get the plane. + + This method returns the plane defined in the data configuration. + + Returns + ------- + str + The plane. + """ + return self.cfg.DATA.PLANE + + def get_model_height(self): + """ + Get the model height. + + This method returns the height of the model defined in the model configuration. + + Returns + ------- + int + The height of the model. + """ + return self.cfg.MODEL.HEIGHT + + def get_model_width(self): + """ + Get the model width. + + This method returns the width of the model defined in the model configuration. + + Returns + ------- + int + The width of the model. + """ + return self.cfg.MODEL.WIDTH + + def get_max_size(self): + """ + Get the maximum size of the output tensor. + + Returns + ------- + int or tuple + The maximum size. If the width and height of the output tensor are equal, it returns the width. Otherwise, it + returns both the width and height. + """ + if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT: + return self.cfg.MODEL.OUT_TENSOR_WIDTH + else: + return self.cfg.MODEL.OUT_TENSOR_WIDTH, self.cfg.MODEL.OUT_TENSOR_HEIGHT + + def get_device(self): + """ + Get the device. + + This method returns the device and view aggregation device used in the Inference instance. + + Returns + ------- + tuple + The device and view aggregation device. + """ + return self.device,self.viewagg_device + + #TODO check is possible to modify to CerebNet inference mode from RAS directly to LIA (CerebNet.Inference._predict_single_subject) + @torch.no_grad() + def eval(self, val_loader: DataLoader, pred_prob: torch.Tensor, out_scale: float = None) -> torch.Tensor: + """ + Evaluate the model on a HypVINN dataset. + + This method runs the model in evaluation mode on a HypVINN Dataset. It iterates over the given dataset and + computes the model's predictions. + + Parameters + ---------- + val_loader : DataLoader + The DataLoader for the validation set. + pred_prob : torch.Tensor + The tensor to update with the prediction probabilities. + out_scale : float, optional + The scale factor for the output. Default is None. + + Returns + ------- + pred_prob: torch.Tensor + The updated prediction probabilities. 
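+
+        Notes
+        -----
+        Axial and coronal predictions are accumulated into ``pred_prob`` with
+        a weight of 0.4 each and sagittal predictions with a weight of 0.2,
+        so the same tensor should be reused for all three per-plane runs.
+
+        Examples
+        --------
+        Minimal sketch of a single per-plane evaluation; ``inference``,
+        ``val_loader`` and the spatial size 320 are placeholders that depend
+        on the actual configuration.
+
+        >>> pred_prob = torch.zeros((320, 320, 320, inference.get_num_classes()))
+        >>> pred_prob = inference.eval(val_loader, pred_prob)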
+ """ + self.model.eval() + + start_index = 0 + for batch_idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)): + + images = batch["image"].to(self.device) + scale_factors = batch["scale_factor"].to(self.device) + weight_factors = batch["weight_factor"].to(self.device, dtype=torch.float32) + + pred = self.model(images, scale_factors, weight_factors, out_scale) + + if self.cfg.DATA.PLANE == "axial": + pred = pred.permute((2, 3, 0, 1)).to(self.viewagg_device) + pred_prob[:, :, start_index:start_index + pred.shape[2], :] += torch.mul(pred, 0.4) + start_index += pred.shape[2] + + elif self.cfg.DATA.PLANE == "coronal": + pred = pred.permute(2, 0, 3, 1).to(self.viewagg_device) + pred_prob[:, start_index:start_index + pred.shape[1], :, :] += torch.mul(pred, 0.4) + start_index += pred.shape[1] + + else: + pred = hypo_map_prediction_sagittal2full(pred).permute(0, 2, 3, 1).to(self.viewagg_device) + pred_prob[start_index:start_index + pred.shape[0],:, :, :] += torch.mul(pred, 0.2) + start_index += pred.shape[0] + + logger.info("---> {} Model Testing Done.".format(self.cfg.DATA.PLANE)) + + return pred_prob + + def run( + self, + subject_name: str, + modalities, + orig_zoom, + pred_prob, + out_res=None, + mode: ModalityMode = "t1t2", + ): + """ + Run the inference process on a single subject. + + This method sets up the HypVINN DataLoader for the subject, runs the model in evaluation mode on the subject's + data, + and returns the updated prediction probabilities. + + Parameters + ---------- + subject_name : str + The name of the subject. + modalities : ModalityDict + The modalities of the subject. + orig_zoom : npt.NDArray[float] + The original zoom of the subject. + pred_prob : torch.Tensor + The tensor to update with the prediction probabilities. + out_res : float, optional + The resolution of the output. Default is None. + mode : ModalityMode, default="t1t2" + The mode of the modalities. Default is 't1t2'. + + Returns + ------- + pred_prob: torch.Tensor + The updated prediction probabilities. + """ + # Set up DataLoader + test_dataset = HypVINNDataset( + subject_name, + modalities, + orig_zoom, + self.cfg, + mode=mode, + transforms=transforms.Compose( + [ + ZeroPad2DTest( + (self.cfg.DATA.PADDED_SIZE, self.cfg.DATA.PADDED_SIZE), + ), + ToTensorTest(), + ], + ), + ) + + test_data_loader = DataLoader( + dataset=test_dataset, + shuffle=False, + batch_size=self.cfg.TEST.BATCH_SIZE, + ) + + # Run evaluation + start = time() + pred_prob = self.eval(test_data_loader, pred_prob, out_scale=out_res) + logger.info( + f"{self.cfg.DATA.PLANE} Inference on {subject_name} finished in " + f"{time()-start:0.4f} seconds" + ) + + return pred_prob diff --git a/HypVINN/models/__init__.py b/HypVINN/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/HypVINN/models/networks.py b/HypVINN/models/networks.py new file mode 100644 index 00000000..ca2cfbfd --- /dev/null +++ b/HypVINN/models/networks.py @@ -0,0 +1,264 @@ + +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS +from typing import Dict + +import yacs.config +from torch import Tensor, nn +import torch +import FastSurferCNN.models.sub_module as sm +import FastSurferCNN.models.interpolation_layer as il +from FastSurferCNN.models.networks import FastSurferCNNBase +import numpy as np + + +class HypVINN(FastSurferCNNBase): + """ + HypVINN class that extends the FastSurferCNNBase class. + + This class represents a HypVINN model. It includes methods for initializing the model, setting up the layers, + and performing forward propagation. + + Attributes + ---------- + height : int + The height of the output tensor. + width : int + The width of the output tensor. + out_tensor_shape : tuple + The shape of the output tensor. + interpolation_mode : str + The interpolation mode to use when resizing the images. This can be 'nearest', 'bilinear', 'bicubic', or 'area'. + crop_position : str + The position to crop the images from. This can be 'center', 'top_left', 'top_right', 'bottom_left', or 'bottom_right'. + m1_inp_block : InputDenseBlock + The input block for the first modality. + m2_inp_block : InputDenseBlock + The input block for the second modality. + mod_weights : nn.Parameter + The weights for the two modalities. + normalize_weights : nn.Softmax + A softmax function to normalize the modality weights. + outp_block : OutputDenseBlock + The output block of the model. + interpol1 : Zoom2d + The first interpolation layer. + interpol2 : Zoom2d + The second interpolation layer. + classifier : ClassifierBlock + The final classifier block of the model. + + Methods + ------- + forward(x, scale_factor, weight_factor, scale_factor_out=None) + Perform forward propagation through the model. + """ + def __init__(self, params, padded_size=256): + """ + Initialize the HypVINN model. + + This method initializes the HypVINN model by calling the super class constructor and setting up the layers. + + Parameters + ---------- + params : Dict + A dictionary containing the configuration parameters for the model. + padded_size : int, optional + The size of the image when padded. (Default = 256). + + Raises + ------ + ValueError + If the interpolation mode or crop position is invalid. 
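+
+        Examples
+        --------
+        The model is typically built from a configuration node via
+        :func:`build_model` rather than instantiated directly (sketch;
+        ``cfg`` is assumed to be a loaded HypVINN ``yacs`` configuration).
+
+        >>> from HypVINN.models.networks import build_model
+        >>> model = build_model(cfg)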
+ """ + num_c = params["num_channels"] + + params["num_channels"] = params["num_filters_interpol"] + + super(HypVINN, self).__init__(params) + + # Flex options + self.height = params["height"] + self.width = params["width"] + + self.out_tensor_shape = tuple( + params.get("out_tensor_" + k, padded_size) for k in ["width", "height"] + ) + + self.interpolation_mode = ( + params["interpolation_mode"] + if "interpolation_mode" in params + else "bilinear" + ) + if self.interpolation_mode not in ["nearest", "bilinear", "bicubic", "area"]: + raise ValueError("Invalid interpolation mode") + + self.crop_position = ( + params["crop_position"] if "crop_position" in params else "top_left" + ) + if self.crop_position not in [ + "center", + "top_left", + "top_right", + "bottom_left", + "bottom_right", + ]: + raise ValueError("Invalid crop position") + + # Reset input channels to two modalities head number (overwritten in super call) + params["num_channels"] = num_c // 2 + + self.m1_inp_block = sm.InputDenseBlock(params) + self.m2_inp_block = sm.InputDenseBlock(params) + + # Initialize learneble modality weights + self.mod_weights = nn.Parameter(torch.ones(2) * 0.5) + self.normalize_weights = nn.Softmax(dim=0) + + params["num_channels"] = params["num_filters"] + params["num_filters_interpol"] + + self.outp_block = sm.OutputDenseBlock(params) + + self.interpol1 = il.Zoom2d((self.width, self.height), + interpolation_mode=self.interpolation_mode, + crop_position=self.crop_position) + + self.interpol2 = il.Zoom2d(self.out_tensor_shape, + interpolation_mode=self.interpolation_mode, + crop_position=self.crop_position) + + # Classifier logits options + params['num_channels'] = params['num_filters'] + self.classifier = sm.ClassifierBlock(params) + + # Code for Network Initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, scale_factor: torch.Tensor, weight_factor: torch.Tensor, + scale_factor_out: torch.Tensor = None) -> torch.Tensor: + """ + Forward propagation method for the HypVINN model. + + This method takes an input tensor, a scale factor, a weight factor, and an optional output scale factor. + It performs forward propagation through the model, applying the input blocks, interpolation layers, output + block, and classifier block. It also handles the weighting of the two modalities and the rescaling of the + output. + + Parameters + ---------- + x : torch.Tensor + The input tensor. It should have a shape of (batch_size, num_channels, height, width). + scale_factor : torch.Tensor + The scale factor for the input images. It should have a shape of (batch_size, 2). + weight_factor : torch.Tensor + The weight factor for the two modalities. It should have a shape of (batch_size, 2). + scale_factor_out : torch.Tensor, optional + The scale factor for the output images. If not provided, it defaults to the scale factor of the input images. + + Returns + ------- + logits : torch.Tensor + The output logits from the classifier block. It has a shape of (batch_size, num_classes, height, width). + + Raises + ------ + ValueError + If the interpolation mode or crop position is invalid. 
+ """ + # Weight factor [wT1,wT2] has 3 stages [1,0],[0.5,0.5],[0,1], + # if the weight factor is [0.5,0.5] the automatically weights (s_weights) are passed + # If there is a 1 in the comparison the automatically weights will be replace by the first weight_factors pass + comparison = weight_factor[0] + + x = torch.tensor_split(x, 2, dim=1) + # Input block + Flex to 1 mm + skip_encoder_01 = self.m1_inp_block(x[0]) + skip_encoder_02 = self.m2_inp_block(x[1]) + + s_weights = self.normalize_weights(self.mod_weights) + + # If one weight 1 it means modality is not available + if 1 in comparison: + s_weights = comparison + + mw1 = s_weights[0].float() + mw2 = s_weights[1].float() + + # Shared latent space + skip_encoder_0 = mw1 * skip_encoder_01 + mw2 * skip_encoder_02 + + encoder_output0, rescale_factor = self.interpol1(skip_encoder_0, scale_factor) # instead of maxpool = encoder_output_0 + + # FastSurferCNN Base + decoder_output1 = super().forward(encoder_output0, scale_factor=scale_factor) + + # Flex to original res + if scale_factor_out is None: + scale_factor_out = rescale_factor + else: + scale_factor_out = np.asarray(scale_factor_out) * np.asarray(rescale_factor) / np.asarray(scale_factor) + + prior_target_shape = self.interpol2.target_shape + self.interpol2.target_shape = skip_encoder_0.shape[2:] + try: + decoder_output0, sf = self.interpol2( + decoder_output1, scale_factor_out, rescale=True + ) + finally: + self.interpol2.target_shape = prior_target_shape + + outblock = self.outp_block(decoder_output0, skip_encoder_0) + # Final logits layer + logits = self.classifier.forward(outblock) # 1x1 convolution + + return logits + + +_MODELS = { + "HypVinn": HypVINN, +} + + +def build_model(cfg: yacs.config.CfgNode) -> HypVINN: + """ + Build and return the requested model. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node containing the parameters for the model. + + Returns + ------- + HypVINN + An instance of the requested model. + + Raises + ------ + AssertionError + If the model specified in the configuration is not supported. + """ + if cfg.MODEL.MODEL_NAME not in _MODELS: + raise AssertionError(f"Model {cfg.MODEL.MODEL_NAME} not supported") + params = {k.lower(): v for k, v in dict(cfg.MODEL).items()} + model_type = _MODELS[cfg.MODEL.MODEL_NAME] + return model_type(params, padded_size=cfg.DATA.PADDED_SIZE) diff --git a/HypVINN/run_prediction.py b/HypVINN/run_prediction.py new file mode 100644 index 00000000..d6945914 --- /dev/null +++ b/HypVINN/run_prediction.py @@ -0,0 +1,678 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# IMPORTS +from typing import TYPE_CHECKING, Optional, cast, Literal +import argparse +from pathlib import Path +from time import time + +import numpy as np +from numpy import typing as npt +import torch + +if TYPE_CHECKING: + import yacs.config + from nibabel.filebasedimages import FileBasedHeader + +from FastSurferCNN.utils import PLANES, Plane, logging, parser_defaults +from FastSurferCNN.utils.checkpoint import ( + get_checkpoints, + load_checkpoint_config_defaults, +) +from FastSurferCNN.utils.common import assert_no_root, SerialExecutor + +from HypVINN.config.hypvinn_files import HYPVINN_SEG_NAME, HYPVINN_MASK_NAME +from HypVINN.data_loader.data_utils import hypo_map_label2subseg, rescale_image +from HypVINN.inference import Inference +from HypVINN.utils import ModalityDict, ModalityMode, ViewOperations +from HypVINN.utils.checkpoint import YAML_DEFAULT as CHECKPOINT_PATHS_FILE +from HypVINN.utils.img_processing_utils import save_segmentation +from HypVINN.utils.load_config import load_config +from HypVINN.utils.misc import create_expand_output_directory +from HypVINN.utils.mode_config import get_hypinn_mode +from HypVINN.utils.preproc import hypvinn_preproc +from HypVINN.utils.stats_utils import compute_stats +from HypVINN.utils.visualization_utils import plot_qc_images + +logger = logging.get_logger(__name__) + +## +# Input array preparation +## + + +def optional_path(a: Path | str) -> Optional[Path]: + """ + Convert a string to a Path object or None. + + Parameters + ---------- + a : Path, str + The input to convert. + + Returns + ------- + Path, None + The converted Path object. + """ + if isinstance(a, Path): + return a + if a.lower() in ("none", ""): + return None + return Path(a) + + +def option_parse() -> argparse.ArgumentParser: + """ + A function to create an ArgumentParser object and parse the command line arguments. + + Returns + ------- + argparse.ArgumentParser + The parser object to parse arguments from the command line. + """ + parser = argparse.ArgumentParser( + description="Script for Hypothalamus Segmentation.", + ) + + # 1. Directory information (where to read from, where to write from and to incl. search-tag) + parser = parser_defaults.add_arguments( + parser, ["sd", "sid"], + ) + + parser = parser_defaults.add_arguments(parser, ["seg_log"]) + + # 2. Options for the MRI volumes + parser = parser_defaults.add_arguments( + parser, ["t1"] + ) + parser.add_argument( + '--t2', + type=optional_path, + default=None, + required=False, + help="Path to the T2 image to process.", + ) + + # 3. Image processing options + parser.add_argument( + "--qc_snap", + action='store_true', + dest="qc_snapshots", + help="Create qc snapshots in //qc_snapshots.", + ) + parser.add_argument( + "--reg_mode", + type=str, + default="coreg", + choices=["none", "coreg", "robust"], + help="Freesurfer Registration type to run. coreg: mri_coreg, " + "robust : mri_robust_register, none: entirely deactivates " + "registration of T2 to T1, if both images are passed, " + "images need to be register properly externally.", + ) + + parser.add_argument( + "--hypo_segfile", + type=str, + default=HYPVINN_SEG_NAME, + dest="hypo_segfile", + help=f"File pattern on where to save the hypothalamus segmentation file " + f"(default: {HYPVINN_SEG_NAME})." + ) + + # 4. 
Options for advanced, technical parameters + advanced = parser.add_argument_group(title="Advanced options") + parser_defaults.add_arguments( + advanced, + ["device", "viewagg_device", "threads", "batch_size", "async_io", "allow_root"], + ) + + files: dict[Plane, str | Path] = {k: "default" for k in PLANES} + # 5. Checkpoint to load + parser_defaults.add_plane_flags( + advanced, + "checkpoint", + files, + CHECKPOINT_PATHS_FILE, + ) + + parser_defaults.add_plane_flags( + advanced, + "config", + { + "coronal": Path("HypVINN/config/HypVINN_coronal_v1.1.0.yaml"), + "axial": Path("HypVINN/config/HypVINN_axial_v1.1.0.yaml"), + "sagittal": Path("HypVINN/config/HypVINN_sagittal_v1.1.0.yaml"), + }, + CHECKPOINT_PATHS_FILE, + ) + return parser + + +def main( + out_dir: Path, + t2: Optional[Path], + orig_name: Optional[Path], + sid: str, + ckpt_ax: Path, + ckpt_cor: Path, + ckpt_sag: Path, + cfg_ax: Path, + cfg_cor: Path, + cfg_sag: Path, + hypo_segfile: str = HYPVINN_SEG_NAME, + hypo_maskfile: str = HYPVINN_MASK_NAME, + allow_root: bool = False, + qc_snapshots: bool = False, + reg_mode: Literal["coreg", "robust", "none"] = "coreg", + threads: int = -1, + batch_size: int = 1, + async_io: bool = False, + device: str = "auto", + viewagg_device: str = "auto", +) -> int | str: + f""" + Main function of the hypothalamus segmentation module. + + Parameters + ---------- + out_dir : Path + The output directory where the results will be stored. + t2 : Path, optional + The path to the T2 image to process. + orig_name : Path, optional + The path to the T1 image to process or FastSurfer orig image. + sid : str + The subject ID. + ckpt_ax : Path + The path to the axial checkpoint file. + ckpt_cor : Path + The path to the coronal checkpoint file. + ckpt_sag : Path + The path to the sagittal checkpoint file. + cfg_ax : Path + The path to the axial configuration file. + cfg_cor : Path + The path to the coronal configuration file. + cfg_sag : Path + The path to the sagittal configuration file. + hypo_segfile : str, default="{HYPVINN_SEG_NAME}" + The name of the hypothalamus segmentation file. Default is {HYPVINN_SEG_NAME}. + hypo_maskfile : str, default="{HYPVINN_MASK_NAME}" + The name of the hypothalamus mask file. Default is {HYPVINN_MASK_NAME}. + allow_root : bool, default=False + Whether to allow running as root user. Default is False. + qc_snapshots : bool, optional + Whether to create QC snapshots. Default is False. + reg_mode : "coreg", "robust", "none", default="coreg" + The registration mode to use. Default is "coreg". + threads : int, default=-1 + The number of threads to use. Default is -1, which uses all available threads. + batch_size : int, default=1 + The batch size to use. Default is 1. + async_io : bool, default=False + Whether to use asynchronous I/O. Default is False. + device : str, default="auto" + The device to use. Default is "auto", which automatically selects the device. + viewagg_device : str, default="auto" + The view aggregation device to use. Default is "auto", which automatically + selects the device. + + Returns + ------- + int, str + 0, if successful, an error message describing the cause for the + failure otherwise. 
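+
+    Examples
+    --------
+    ``main`` is normally called through the command-line entry point of this
+    module; a direct call mirrors the parsed arguments (sketch, all paths and
+    file names are placeholders).
+
+    >>> from pathlib import Path
+    >>> ret = main(
+    ...     out_dir=Path("/output"), sid="subjectX",
+    ...     orig_name=Path("/data/subjectX/T1.mgz"), t2=Path("/data/subjectX/T2.mgz"),
+    ...     ckpt_ax=Path("ckpt_axial"), ckpt_cor=Path("ckpt_coronal"), ckpt_sag=Path("ckpt_sagittal"),
+    ...     cfg_ax=Path("cfg_axial.yaml"), cfg_cor=Path("cfg_coronal.yaml"), cfg_sag=Path("cfg_sagittal.yaml"),
+    ... )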
+ """ + from concurrent.futures import ProcessPoolExecutor, Future + if threads != 1: + pool = ProcessPoolExecutor(threads) + else: + pool = SerialExecutor() + prep_tasks: dict[str, Future] = {} + + # mapped freesurfer orig input name to the hypvinn t1 name + t1_path = orig_name + t2_path = t2 + subject_name = sid + subject_dir = out_dir / sid + # Warning if run as root user + allow_root or assert_no_root() + start = time() + try: + # Set up logging + prep_tasks["cp"] = pool.submit(prepare_checkpoints, ckpt_ax, ckpt_cor, ckpt_sag) + + kwargs = {} + if t1_path is not None: + kwargs["t1_path"] = Path(t1_path) + if t2_path: + kwargs["t2_path"] = Path(t2_path) + # Get configuration to run multi-modal or uni-modal + mode = get_hypinn_mode(**kwargs) + + if not mode: + return ( + f"Failed Evaluation on {subject_name} couldn't determine the " + f"processing mode. Please check that T1 or T2 images are " + f"available.\nT1 image path: {t1_path}\nT2 image path " + f"{t2_path}.\nNo T1 or T2 image available." + ) + + # Create output directory if it does not already exist. + create_expand_output_directory(subject_dir, qc_snapshots) + logger.info( + f"Running HypVINN segmentation pipeline on subject {sid}" + ) + logger.info(f"Output will be stored in: {subject_dir}") + logger.info(f"T1 image input {t1_path}") + logger.info(f"T2 image input {t2_path}") + + # Pre-processing -- T1 and T2 registration + if mode == "t1t2": + # Note, that t1_path and t2_path are guaranteed to be not None via + # get_hypvinn_mode, which only returns t1t2, if t1 and t2 exist. + # hypvinn_preproc returns the path to the t2 that is registered to the t1 + prep_tasks["reg"] = pool.submit( + hypvinn_preproc, + mode, + reg_mode, + subject_dir=Path(subject_dir), + threads=threads, + **kwargs, + ) + + # Segmentation pipeline + seg = time() + view_ops: ViewOperations = {a: None for a in PLANES} + logger.info("Setting up HypVINN run") + + cfgs = (cfg_ax, cfg_cor, cfg_sag) + ckpts = (ckpt_ax, ckpt_cor, ckpt_sag) + for plane, _cfg_file, _ckpt_file in zip(PLANES, cfgs, ckpts): + logger.info(f"{plane} model configuration from {_cfg_file}") + view_ops[plane] = { + "cfg": set_up_cfgs(_cfg_file, subject_dir, batch_size), + "ckpt": _ckpt_file, + } + + model = view_ops[plane]["cfg"].MODEL + if mode != model.MODE and "HypVinn" not in model.MODEL_NAME: + raise AssertionError( + f"Modality mode different between input arg: " + f"{mode} and axial train cfg: {model.MODE}" + ) + + cfg_fin, ckpt_fin = view_ops["coronal"].values() + + if "reg" in prep_tasks: + t2_path = prep_tasks["reg"].result() + kwargs["t2_path"] = t2_path + prep_tasks["load"] = pool.submit(load_volumes, mode=mode, **kwargs) + + # Set up model + model = Inference( + cfg=cfg_fin, + async_io=async_io, + threads=threads, + viewagg_device=viewagg_device, + device=device, + ) + + logger.info('----' * 30) + logger.info(f"Evaluating hypothalamus model on {subject_name}") + + # wait for all prep tasks to finish + for ptask in prep_tasks.values(): + if e := ptask.exception(): + raise e + + # Load Images + image_data, affine, header, orig_zoom, orig_size = prep_tasks["load"].result() + logger.info(f"Scale factor: {orig_zoom}") + + pred = time() + pred_classes = get_prediction( + subject_name, + image_data, + orig_zoom, + model, + target_shape=orig_size, + view_opts=view_ops, + out_scale=None, + mode=mode, + ) + logger.info(f"Model prediction finished in {time() - pred:0.4f} seconds") + logger.info(f"Saving results in {subject_dir}") + + if mode == 't1t2' or mode == 't1': + orig_path = t1_path + 
else: + orig_path = t2_path + + time_needed = save_segmentation( + pred_classes, + orig_path=orig_path, + ras_affine=affine, + ras_header=header, + subject_dir=subject_dir, + seg_file=hypo_segfile, + mask_file=hypo_maskfile, + save_mask=True, + ) + logger.info(f"Prediction successfully saved in {time_needed} seconds.") + if qc_snapshots: + qc_future: Optional[Future] = pool.submit( + plot_qc_images, + subject_qc_dir=subject_dir / "qc_snapshots", + orig_path=orig_path, + prediction_path=Path(subject_dir / "mri" /hypo_segfile), + ) + qc_future.add_done_callback( + lambda x: logger.info(f"QC snapshots saved in {x.result()} seconds."), + ) + else: + qc_future = None + + logger.info("Computing stats") + return_value = compute_stats( + orig_path=orig_path, + prediction_path=Path(subject_dir / "mri" /hypo_segfile), + stats_dir=subject_dir / "stats", + threads=threads, + ) + if return_value != 0: + logger.error(return_value) + + logger.info( + f"Processing segmentation finished in {time() - seg:0.4f} seconds." + ) + except (FileNotFoundError, RuntimeError) as e: + logger.info(f"Failed Evaluation on {subject_name}:") + logger.exception(e) + else: + if qc_future: + # finish qc + qc_future.result() + + logger.info( + f"Processing whole pipeline finished in {time() - start:.4f} seconds." + ) + + +def prepare_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag): + """ + Prepare the checkpoints for the Hypothalamus Segmentation model. + + This function checks if the checkpoint files for the axial, coronal, and sagittal planes exist. + If they do not exist, it downloads them from the default URLs specified in the configuration file. + + Parameters + ---------- + ckpt_ax : str + The path to the axial checkpoint file. + ckpt_cor : str + The path to the coronal checkpoint file. + ckpt_sag : str + The path to the sagittal checkpoint file. + """ + logger.info("Checking or downloading default checkpoints ...") + urls = load_checkpoint_config_defaults( + "url", + filename=CHECKPOINT_PATHS_FILE, + ) + get_checkpoints(ckpt_ax, ckpt_cor, ckpt_sag, urls=urls) + + +def load_volumes( + mode: ModalityMode, + t1_path: Optional[Path] = None, + t2_path: Optional[Path] = None, +) -> tuple[ + ModalityDict, + npt.NDArray[float], + "FileBasedHeader", + tuple[float, float, float], + tuple[int, int, int], +]: + """ + Load the volumes of T1 and T2 images. + + This function loads the T1 and T2 images, checks their compatibility based on the mode, and returns the loaded + volumes along with their affine transformations, headers, zoom levels, and sizes. + + Parameters + ---------- + mode : ModalityMode + The mode of operation. Can be 't1', 't2', or 't1t2'. + t1_path : Path, optional + The path to the T1 image. Default is None. + t2_path : Path, optional + The path to the T2 image. Default is None. + + Returns + ------- + tuple + A tuple containing the following elements: + - modalities: A dictionary with keys 't1' and/or 't2' and values being the corresponding loaded and rescaled images. + - affine: The affine transformation of the loaded image(s). + - header: The header of the loaded image(s). + - zoom: The zoom level of the loaded image(s). + - size: The size of the loaded image(s). + + Raises + ------ + RuntimeError + If the mode is inconsistent with the provided image paths, or if the number of dimensions of the data is invalid. + ValueError + If the mode is invalid, or if a header is missing. + AssertionError + If the mode is 't1t2' but the T1 and T2 images have different resolutions or sizes. 
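+
+    Examples
+    --------
+    Sketch for the multi-modal case (paths are placeholders; the T2 image is
+    expected to already be registered to the T1 image):
+
+    >>> from pathlib import Path
+    >>> modalities, affine, header, zoom, size = load_volumes(
+    ...     mode="t1t2",
+    ...     t1_path=Path("/data/subjectX/T1.mgz"),
+    ...     t2_path=Path("/data/subjectX/T2_reg.mgz"),
+    ... )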
+ """ + import nibabel as nib + modalities: ModalityDict = {} + + t1_size = () + t2_size = () + t1_zoom = () + t2_zoom = () + affine: npt.NDArray[float] = np.ndarray([0]) + header: Optional["FileBasedHeader"] = None + zoom: tuple[float, float, float] = (0.0, 0.0, 0.0) + size: tuple[int, ...] = (0, 0, 0) + + if t1_path: + logger.info(f'Loading T1 image from : {t1_path}') + t1 = nib.load(t1_path) + t1 = nib.as_closest_canonical(t1) + if mode in ('t1t2', 't1'): + affine = t1.affine + header = t1.header + else: + raise RuntimeError(f"Invalid mode {mode}, or inconsistent with t1_path!") + t1_zoom = t1.header.get_zooms() + zoom = cast(tuple[float, float, float], tuple(np.round(t1_zoom, 3))) + # Conform Intensities + modalities["t1"] = rescale_image(np.asarray(t1.dataobj)) + t1_size: tuple[int, ...] = modalities["t1"].shape + size = t1_size + if t2_path: + logger.info(f"Loading T2 image from {t2_path}") + t2 = nib.load(t2_path) + t2 = nib.as_closest_canonical(t2) + t2_zoom = t2.header.get_zooms() + if mode == "t2": + affine = t2.affine + header = t2.header + zoom = cast(tuple[float, float, float], tuple(np.round(t2_zoom, 3))) + elif mode == "t1t2": + pass + else: + raise RuntimeError(f"Invalid mode {mode}, or inconsistent with t2_path!") + # Conform Intensities + modalities["t2"] = np.asarray(rescale_image(t2.get_fdata()), dtype=np.uint8) + t2_size = modalities["t2"].shape + size = t2_size + + if mode == "t1t2": + if not np.allclose(np.array(t1_zoom), np.array(t2_zoom), rtol=0.05): + raise AssertionError( + f"T1 {t1_zoom} and T2 {t2_zoom} images have different resolutions!" + ) + if not np.allclose(np.array(t1_size), np.array(t2_size), rtol=0.05): + raise AssertionError( + f"T1 {t1_size} and T2 {t2_size} images have different size!" + ) + elif mode not in ("t1", "t2"): + raise ValueError(f"Invalid mode {mode}, vs. 't1', 't2', 't1t2'") + + if header is None: + raise ValueError("Missing a header!") + if len(size) != 3: + raise RuntimeError("Invalid ndims of data!") + _size = cast(tuple[int, int, int], size) + + return modalities, affine, header, zoom, _size + + +def get_prediction( + subject_name: str, + modalities: ModalityDict, + orig_zoom, + model: Inference, + target_shape: tuple[int, int, int], + view_opts: ViewOperations, + out_scale=None, + mode: ModalityMode = "t1t2", +) -> npt.NDArray[int]: + """ + Run the prediction for the Hypothalamus Segmentation model. + + This function sets up the prediction process for the Hypothalamus Segmentation model. It runs the model for each + plane (axial, coronal, sagittal), accumulates the prediction probabilities, and then generates the final prediction. + + Parameters + ---------- + subject_name : str + The name of the subject. + modalities : ModalityDict + A dictionary containing the modalities (T1 and/or T2) and their corresponding images. + orig_zoom : npt.NDArray[float] + The original zoom of the subject. + model : Inference + The Inference object of the model. + target_shape : tuple[int, int, int] + The target shape of the output prediction. + view_opts : ViewOperations + A dictionary containing the configurations for each plane. + out_scale : optional + The output scale. Default is None. + mode : ModalityMode, default="t1t2" + The mode of operation. Can be 't1', 't2', or 't1t2'. Default is 't1t2'. + + Returns + ------- + pred_classes: npt.NDArray[int] + The final prediction of the model. + """ + # TODO There are probably several possibilities to accelerate this script. + # FastSurferVINN takes 7-8s vs. HypVINN 10+s per slicing direction. 
+ # Solution: make this script/function more similar to the optimized FastSurferVINN + device, viewagg_device = model.get_device() + dim = model.get_max_size() + + pred_shape = (dim, dim, dim, model.get_num_classes()) + # Set up tensor to hold probabilities and run inference + pred_prob = torch.zeros(pred_shape, dtype=torch.float, device=viewagg_device) + for plane, opts in view_opts.items(): + logger.info(f"Evaluating {plane} model, cpkt :{opts['ckpt']}") + model.set_model(opts["cfg"]) + model.load_checkpoint(opts["ckpt"]) + pred_prob += model.run(subject_name, modalities, orig_zoom, pred_prob, out_scale, mode=mode) + + # Post processing + h, w, d = target_shape # final prediction shape equivalent to input ground truth shape + + if np.any(target_shape < pred_prob.shape[:3]): + # if orig was padded before running through model (difference in + # aseg_size and pred_shape), select slices of interest only. + # This currently works only for "top_left" padding (see augmentation) + pred_prob = pred_prob[0:h, 0:w, 0:d, :] + + # Get hard predictions and map to freesurfer label space + _, pred_classes = torch.max(pred_prob, 3) + del pred_prob + pred_classes = pred_classes.cpu().numpy() + pred_classes = hypo_map_label2subseg(pred_classes) + + return pred_classes + + +## +# Processing +## +def set_up_cfgs( + cfg: "yacs.config.CfgNode", + out_dir: Path, + batch_size: int = 1, +) -> "yacs.config.CfgNode": + """ + Set up the configuration for the Hypothalamus Segmentation model. + + This function loads the configuration, sets the output directory and batch size, and adjusts the output tensor + dimensions based on the padded size specified in the configuration. + + Parameters + ---------- + cfg : yacs.config.CfgNode + The configuration node to load. + out_dir : Path + The output directory where the results will be stored. + batch_size : int, default=1 + The batch size to use. Default is 1. + + Returns + ------- + yacs.config.CfgNode + The loaded and adjusted configuration node. 
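+
+    Examples
+    --------
+    Sketch using one of the configuration files shipped with HypVINN
+    (``out_dir`` is a placeholder):
+
+    >>> from pathlib import Path
+    >>> cfg = set_up_cfgs(
+    ...     Path("HypVINN/config/HypVINN_coronal_v1.1.0.yaml"),
+    ...     out_dir=Path("/output/subjectX"),
+    ...     batch_size=1,
+    ... )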
+ + """ + cfg = load_config(cfg) + cfg.OUT_LOG_DIR = str(out_dir or cfg.LOG_DIR) + cfg.TEST.BATCH_SIZE = batch_size + + out_dims = cfg.DATA.PADDED_SIZE + if out_dims > cfg.DATA.PADDED_SIZE: + cfg.MODEL.OUT_TENSOR_WIDTH = out_dims + cfg.MODEL.OUT_TENSOR_HEIGHT = out_dims + else: + cfg.MODEL.OUT_TENSOR_WIDTH = cfg.DATA.PADDED_SIZE + cfg.MODEL.OUT_TENSOR_HEIGHT = cfg.DATA.PADDED_SIZE + return cfg + + +if __name__ == "__main__": + # arguments + parser = option_parse() + args = vars(parser.parse_args()) + log_name = (args["log_name"] or + args["out_dir"] / args["sid"] / "scripts/hypvinn_seg.log") + del args["log_name"] + + from FastSurferCNN.utils.logging import setup_logging + setup_logging(log_name) + + import sys + sys.exit(main(**args)) diff --git a/HypVINN/utils/__init__.py b/HypVINN/utils/__init__.py new file mode 100644 index 00000000..2f2b06fe --- /dev/null +++ b/HypVINN/utils/__init__.py @@ -0,0 +1,10 @@ +from typing import Any, Literal, Optional + +from numpy import ndarray + +from FastSurferCNN.utils import Plane + +ViewOperations = dict[Plane, Optional[dict[Literal["cfg", "ckpt"], Any]]] +ModalityMode = Literal["t1", "t2", "t1t2"] +ModalityDict = dict[Literal["t1", "t2"], ndarray] +RegistrationMode = Literal["robust", "coreg", "none"] diff --git a/HypVINN/utils/checkpoint.py b/HypVINN/utils/checkpoint.py new file mode 100644 index 00000000..f2e122b5 --- /dev/null +++ b/HypVINN/utils/checkpoint.py @@ -0,0 +1,17 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +YAML_DEFAULT = FASTSURFER_ROOT / "HypVINN/config/checkpoint_paths.yaml" diff --git a/HypVINN/utils/img_processing_utils.py b/HypVINN/utils/img_processing_utils.py new file mode 100644 index 00000000..f96adc29 --- /dev/null +++ b/HypVINN/utils/img_processing_utils.py @@ -0,0 +1,293 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import numpy as np +from numpy import typing as npt +import nibabel as nib +from skimage.measure import label +from scipy import ndimage + +import FastSurferCNN.utils.logging as logging +from HypVINN.data_loader.data_utils import hypo_map_subseg_2_fsseg + +LOGGER = logging.get_logger(__name__) + + +def img2axcodes(img: nib.Nifti1Image) -> tuple: + """ + Convert the affine matrix of an image to axis codes. 
+ + This function takes an image as input and returns the axis codes corresponding to the affine matrix of the image. + + Parameters + ---------- + img : nibabel image object + The input image. + + Returns + ------- + tuple + The axis codes corresponding to the affine matrix of the image. + """ + return nib.aff2axcodes(img.affine) + + +def save_segmentation( + prediction: np.ndarray, + orig_path: Path, + ras_affine: npt.NDArray[float], + ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader, + subject_dir: Path, + seg_file: str, + mask_file: str, + save_mask: bool = False, +) -> float: + """ + Save the segmentation results. + + This function takes the prediction results, cleans the labels, maps them to FreeSurfer Hypvinn Labels, and saves + the results. It also reorients the mask and prediction images to match the original image's orientation. + + Parameters + ---------- + prediction : np.ndarray + The prediction results. + orig_path : Path + The path to the original image. + ras_affine : npt.NDArray[float] + The affine transformation of the RAS orientation. + ras_header : nibabel header object + The header of the RAS orientation. + subject_dir : Path + The directory where the subject's data is stored. + seg_file : Path + The file where the segmentation will be saved (relative to subject_dir/mri). + mask_file : str + The file where the mask will be saved (relative to subject_dir/mri). + save_mask : bool, default=False + Whether to save the mask or not. Default is False. + + Returns + ------- + float + The time taken to save the segmentation. + + """ + from time import time + starttime = time() + from HypVINN.data_loader.data_utils import reorient_img + + pred_arr, labels_cc = get_clean_labels(np.array(prediction, dtype=np.uint8)) + # Mapped HypVINN labelst to FreeSurfer Hypvinn Labels + pred_arr = hypo_map_subseg_2_fsseg(pred_arr) + orig_img = nib.load(orig_path) + LOGGER.info(f"Orig data orientation : {img2axcodes(orig_img)}") + + if save_mask: + mask_img = nib.Nifti1Image(labels_cc, affine=ras_affine, header=ras_header) + LOGGER.info(f"HypVINN Mask orientation: {img2axcodes(mask_img)}") + mask_img = reorient_img(mask_img, orig_img) + LOGGER.info( + f"HypVINN Mask after re-orientation: {img2axcodes(mask_img)}" + ) + nib.save(mask_img, subject_dir / "mri" / mask_file) + + pred_img = nib.Nifti1Image(pred_arr, affine=ras_affine, header=ras_header) + LOGGER.info(f"HypVINN Prediction orientation: {img2axcodes(pred_img)}") + pred_img = reorient_img(pred_img, orig_img) + LOGGER.info( + f"HypVINN Prediction after re-orientation: {img2axcodes(pred_img)}" + ) + pred_img.set_data_dtype(np.int16) # Maximum value 984 + nib.save(pred_img, subject_dir / "mri" / seg_file) + return time() - starttime + + +def save_logits( + logits: npt.NDArray[float], + orig_path: Path, + ras_affine: npt.NDArray[float], + ras_header: nib.nifti1.Nifti1Header | nib.nifti2.Nifti2Header | nib.freesurfer.mghformat.MGHHeader, + save_dir: Path, + mode: str, +) -> Path: + """ + Save the logits (raw model outputs) as a NIfTI image. + + This function takes the logits, reorients the image to match the original image's orientation, and saves the + results. + + Parameters + ---------- + logits : npt.NDArray[float] + The raw model outputs. + orig_path : Path + The path to the original image. + ras_affine : npt.NDArray[float] + The affine transformation of the RAS orientation. + ras_header : nib.nifti1.Nifti1Header + The header of the RAS orientation. 
+ save_dir : Path + The directory where the logits will be saved. + mode : str + The mode of operation. + + Returns + ------- + save_as: Path + The path where the logits were saved. + + """ + from HypVINN.data_loader.data_utils import reorient_img + orig_img = nib.load(orig_path) + LOGGER.info(f"Orig data orientation: {img2axcodes(orig_img)}") + nifti_img = nib.Nifti1Image( + logits.astype(np.float32), + affine=ras_affine, + header=ras_header, + ) + LOGGER.info(f"HypVINN logits orientation: {img2axcodes(nifti_img)}") + nifti_img = reorient_img(nifti_img, orig_img) + LOGGER.info( + f"HypVINN logits after re-orientation: {img2axcodes(nifti_img)}" + ) + nifti_img.set_data_dtype(np.float32) + save_as = save_dir / f"HypVINN_logits_{mode}.nii.gz" + nib.save(nifti_img, save_as) + return save_as + + +def get_clean_mask(segmentation: np.ndarray, optic=False) \ + -> tuple[np.ndarray, np.ndarray, bool]: + """ + Get a clean mask by removing non-connected components from a dilated mask. + + This function takes a segmentation mask and an optional boolean flag indicating whether to consider optic labels. + It removes not connected components from the segmentation mask and returns the cleaned segmentation mask, the + labels of the connected components, and a flag indicating whether to save the mask. + + Parameters + ---------- + segmentation : np.ndarray + The input segmentation mask. + optic : bool, default=False + A flag indicating whether to consider optic labels. Default is False. + + Returns + ------- + clean_seg : np.ndarray + The cleaned segmentation mask. + labels_cc : np.ndarray + The labels of the connected components in the segmentation mask. + savemask : bool + A flag indicating whether to save the mask. + + """ + savemask = False + + # Remove not connected components + if optic: + iterations = 7 + # Remove not connected from optics components + copy_segmentation = np.zeros_like(segmentation) + copy_segmentation[segmentation == 1] = 1 + copy_segmentation[segmentation == 2] = 2 + copy_segmentation[segmentation == 4] = 4 + copy_segmentation[segmentation == 5] = 5 + else: + iterations = 5 + copy_segmentation = segmentation.copy() + # remove optic structures + copy_segmentation[segmentation == 1] = 0 + copy_segmentation[segmentation == 2] = 0 + copy_segmentation[segmentation == 4] = 0 + copy_segmentation[segmentation == 5] = 0 + + struct1 = ndimage.generate_binary_structure(3, 3) + mask = ndimage.binary_dilation( + copy_segmentation, + structure=struct1, + iterations=iterations, + ).astype(np.uint8) + labels_cc = label(mask, connectivity=3, background=0) + bincount = np.bincount(labels_cc.flat) + + if len(bincount) > 2: + if optic: + LOGGER.info("Check Optic Labels") + else: + LOGGER.info("Check Hypothalamus Labels") + savemask = True + + background = np.argmax(bincount) + bincount[background] = -1 + largest_cc = labels_cc == np.argmax(bincount) + clean_seg = copy_segmentation * largest_cc + + # remove globus pallidus + clean_seg[clean_seg == 13] = 0 + clean_seg[clean_seg == 20] = 0 + + return clean_seg, labels_cc, savemask + + +def get_clean_labels(segmentation: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Get clean labels by removing non-connected components from a dilated mask and any connected component with size + less than 3. + + Parameters + ---------- + segmentation: np.ndarray + The segmentation mask. + + Returns + ------- + clean_seg: np.ndarray + The cleaned segmentation mask. + labels_cc: np.ndarray + The labels of the connected components in the segmentation mask. 
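+
+    Examples
+    --------
+    Sketch on a small synthetic label map (label value 3 is used purely for
+    illustration; real inputs follow the internal HypVINN sub-segmentation
+    labels):
+
+    >>> import numpy as np
+    >>> seg = np.zeros((32, 32, 32), dtype=np.uint8)
+    >>> seg[10:14, 10:14, 10:14] = 3
+    >>> clean_seg, labels_cc = get_clean_labels(seg)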
+ """ + + # Mask largest CC without optic labels + clean_seg, labels_cc, savemask = get_clean_mask(segmentation) + # Mask largest CC from optic labels + optic_clean_seg, optic_labels_cc, optic_savemask = get_clean_mask(segmentation, optic=True) + + # clean segmentation from both largest_cc + clean_seg = clean_seg + optic_clean_seg + + # mask from both largest_cc + optic_mask = optic_labels_cc > 0 + other_mask = labels_cc > 0 + # multiplication times one to change from boolean + non_intersect = (optic_mask * 1 - other_mask * 1) * optic_mask + + optic_labels_cc += np.max(np.unique(labels_cc)) + labels_cc = labels_cc + optic_labels_cc * non_intersect + + # remove small group of voxels less than 3 + small_mask = clean_seg > 0 + labels_small = label(small_mask, connectivity=3, background=0) + bincount_small = np.bincount(labels_small.flat) + idx = np.where(bincount_small <= 3) + if idx[0].any(): + for i in idx[0]: + small_mask[labels_small == i] = False + + clean_seg = clean_seg * small_mask + + return clean_seg, labels_cc diff --git a/HypVINN/utils/load_config.py b/HypVINN/utils/load_config.py new file mode 100644 index 00000000..ad9fd348 --- /dev/null +++ b/HypVINN/utils/load_config.py @@ -0,0 +1,74 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from os.path import join, split, splitext + +from HypVINN.config.hypvinn import get_cfg_hypvinn + + +def get_config(args): + """ + Given the arguments, load and initialize the configs. + + Parameters + ---------- + args : object + The arguments object. + Returns + ------- + cfg : yacs.config.CfgNode + The configuration node. + """ + # Setup cfg. + cfg = get_cfg_hypvinn() + # Load config from cfg. + if args.cfg_file is not None: + cfg.merge_from_file(args.cfg_file) + # Load config from command line, overwrite config from opts. + if args.opts is not None: + cfg.merge_from_list(args.opts) + + if hasattr(args, "rng_seed"): + cfg.RNG_SEED = args.rng_seed + if hasattr(args, "output_dir"): + cfg.LOG_DIR = args.LOG_dir + + cfg_file_name = splitext(split(args.cfg_file)[1])[0] + cfg.LOG_DIR = join(cfg.LOG_DIR, cfg_file_name) + + return cfg + +def load_config(cfg_file): + """ + Load and initialize the configuration from a given file. + + Parameters + ---------- + cfg_file : str + The path to the configuration file. The function will load configurations from this file. + + Returns + ------- + cfg : yacs.config.CfgNode + The configuration node, loaded and initialized with the given file. 
+ """ + # setup base + cfg = get_cfg_hypvinn() + cfg.EXPR_NUM = None + cfg.SUMMARY_PATH = "" + cfg.CONFIG_LOG_PATH = "" + cfg.TRAIN.RESUME_EXPR_NUM = None + # Overwrite with stored arguments + cfg.merge_from_file(cfg_file) + return cfg \ No newline at end of file diff --git a/HypVINN/utils/misc.py b/HypVINN/utils/misc.py new file mode 100644 index 00000000..c018f0bd --- /dev/null +++ b/HypVINN/utils/misc.py @@ -0,0 +1,27 @@ +from pathlib import Path + + +def create_expand_output_directory( + subject_dir: Path, + qc_snapshots: bool = False, +) -> None: + """ + Create the output directories for HypVINN. + + Parameters + ---------- + subject_dir : Path + The path to the subject directory. + qc_snapshots : bool, default=False + Whether the qc_snapshots directory should be created. + """ + paths = [ + subject_dir, + subject_dir / "mri" / "transforms", + subject_dir / "stats", + ] + if qc_snapshots: + paths.append(subject_dir / "qc_snapshots") + + for path in paths: + path.mkdir(parents=True, exist_ok=True) diff --git a/HypVINN/utils/mode_config.py b/HypVINN/utils/mode_config.py new file mode 100644 index 00000000..6f140265 --- /dev/null +++ b/HypVINN/utils/mode_config.py @@ -0,0 +1,87 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from typing import Optional + +from FastSurferCNN.utils import logging +from HypVINN.utils import ModalityMode + +LOGGER = logging.get_logger(__name__) + + +def get_hypinn_mode( + t1_path: Optional[Path] = None, + t2_path: Optional[Path] = None, +) -> ModalityMode: + """ + Determine the input mode for HypVINN based on the existence of T1 and T2 files. + + This function checks the existence of T1 and T2 files based on the provided paths. + + Parameters + ---------- + t1_path : Path, optional + The path to the T1 file. + t2_path : Path, optional + The path to the T2 file. + + Returns + ------- + ModalityMode + The input mode for HypVINN, which can be "t1t2", "t1", or "t2". + + Raises + ------ + RuntimeError + If neither T1 nor T2 files exist, or if the corresponding flags were passed but the files do not exist. + """ + LOGGER.info("Setting up input mode...") + if t1_path is not None and t2_path is not None: + if t1_path.is_file() and t2_path.is_file(): + return "t1t2" + msg = [] + if not t1_path.is_file(): + msg.append(f"the t1 file does not exist ({t1_path})") + if not t2_path.is_file(): + msg.append(f"the t2 file does not exist ({t2_path})") + raise RuntimeError( + f"ERROR: Both the t1 and the t2 flags were passed, but " + f"{' and '.join(msg)}." + ) + + elif t1_path: + if t1_path.is_file(): + return "t1" + raise RuntimeError( + f"ERROR: The t1 flag was passed, but the t1 file does not exist " + f"({t1_path})." + ) + elif t2_path: + if t2_path.is_file(): + LOGGER.info( + "Warning: T2 mode selected. The quality of segmentations based " + "on only a T2 image is significantly worse than when T1 images " + "are included." 
+ ) + return "t2" + raise RuntimeError( + f"ERROR: The t2 flag was passed, but the t1 file does not exist " + f"({t1_path})." + ) + else: + raise RuntimeError( + "No t1 or t2 flags were passed, invalid configuration." + ) diff --git a/HypVINN/utils/preproc.py b/HypVINN/utils/preproc.py new file mode 100644 index 00000000..c4dbbc19 --- /dev/null +++ b/HypVINN/utils/preproc.py @@ -0,0 +1,242 @@ +# Copyright 2024 +# AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from pathlib import Path +import os +from typing import cast + +import nibabel as nib +import numpy as np + +from FastSurferCNN.utils import logging +from HypVINN.utils import ModalityMode, RegistrationMode + +LOGGER = logging.get_logger(__name__) + + +def t1_to_t2_registration( + t1_path: Path, + t2_path: Path, + output_path: Path, + lta_path: Path, + registration_type: RegistrationMode = "coreg", + threads: int = -1, +) -> Path: + """ + Register T1 to T2 images using either mri_coreg or mri_robust_register. + + Parameters + ---------- + t1_path : Path + The path to the T1 image. + t2_path : Path + The path to the T2 image. + output_path : Path + The path to the output/registered image. + lta_path : Path + The path to the lta transform. + registration_type : RegistrationMode, default="coreg" + The type of registration to be used. It can be either "coreg" or "robust". + threads : int, default=-1 + The number of threads to be used. If it is less than or equal to 0, the number + of threads will be automatically determined. + + Returns + ------- + Path + The path to the registered T2 image. + + Raises + ------ + RuntimeError + If mri_coreg, mri_vol2vol, or mri_robust_register fails to run or if they cannot + be found. + """ + from FastSurferCNN.utils.run_tools import Popen + from FastSurferCNN.utils.threads import get_num_threads + import shutil + + if threads <= 0: + threads = get_num_threads() + + if registration_type == "coreg": + exe = shutil.which("mri_coreg") + if not bool(exe): + if os.environ.get("FREESURFER_HOME", ""): + exe = os.environ["FREESURFER_HOME"] + "/bin/mri_coreg" + else: + raise RuntimeError( + "Could not find mri_coreg, source FreeSurfer or set the " + "FREESURFER_HOME environment variable" + ) + args = [exe, "--mov", t2_path, "--targ", t1_path, "--reg", lta_path] + args = list(map(str, args)) + ["--threads", str(threads)] + LOGGER.info("Running " + " ".join(args)) + retval = Popen(args).finish() + if retval.retcode != 0: + LOGGER.error(f"mri_coreg failed with error code {retval.retcode}. 
") + raise RuntimeError("mri_coreg failed registration") + + else: + LOGGER.info(f"{exe} finished in {retval.runtime}!") + exe = shutil.which("mri_vol2vol") + if not bool(exe): + if os.environ.get("FREESURFER_HOME", ""): + exe = os.environ["FREESURFER_HOME"] + "/bin/mri_vol2vol" + else: + raise RuntimeError( + "Could not find mri_vol2vol, source FreeSurfer or set " + "the FREESURFER_HOME environment variable" + ) + args = [ + exe, + "--mov", t2_path, + "--targ", t1_path, + "--reg", lta_path, + "--o", output_path, + "--cubic", + "--keep-precision", + ] + args = list(map(str, args)) + LOGGER.info("Running " + " ".join(args)) + retval = Popen(args).finish() + if retval.retcode != 0: + LOGGER.error( + f"mri_vol2vol failed with error code {retval.retcode}." + ) + raise RuntimeError("mri_vol2vol failed applying registration") + LOGGER.info(f"{exe} finished in {retval.runtime}!") + else: + exe = shutil.which("mri_robust_register") + if not bool(exe): + if os.environ.get("FREESURFER_HOME", ""): + exe = os.environ["FREESURFER_HOME"] + "/bin/mri_robust_register" + else: + raise RuntimeError( + "Could not find mri_robust_register, source FreeSurfer or " + "set the FREESURFER_HOME environment variable" + ) + args = [ + exe, + "--mov", t2_path, + "--dst", t1_path, + "--lta", lta_path, + "--mapmov", output_path, + "--cost NMI", + ] + args = list(map(str, args)) + LOGGER.info("Running " + " ".join(args)) + retval = Popen(args).finish() + if retval.retcode != 0: + LOGGER.error( + f"mri_robust_register failed with error code {retval.retcode}." + ) + raise RuntimeError("mri_robust_register failed registration") + LOGGER.info(f"{exe} finished in {retval.runtime}!") + + return output_path + + +def hypvinn_preproc( + mode: ModalityMode, + reg_mode: RegistrationMode, + t1_path: Path, + t2_path: Path, + subject_dir: Path, + threads: int = -1, +) -> Path: + """ + Preprocess the input images for HypVINN. + + Parameters + ---------- + mode : ModalityMode + The mode for HypVINN. It should be "t1t2". + reg_mode : RegistrationMode + The registration mode. If it is not "none", the function will register T1 to T2 + images. + t1_path : Path + The path to the T1 image. + t2_path : Path + The path to the T2 image. + subject_dir : Path + The directory of the subject. + threads : int, default=-1 + The number of threads to be used. If it is less than or equal to 0, the number + of threads will be automatically determined. + + Returns + ------- + Path + The path to the preprocessed T2 image. + + Raises + ------ + RuntimeError + If the mode is not "t1t2", or if the registration mode is not "none" and the + registration fails. + """ + if mode != "t1t2": + raise RuntimeError( + "hypvinn_preproc should only be called for t1t2 mode." + ) + registered_t2_path = subject_dir / "mri/T2_nu_reg.mgz" + if reg_mode != "none": + from nibabel.analyze import AnalyzeImage + load_res = time.time() + # Print Warning if Resolution from both images is different + t1_zoom = cast(AnalyzeImage, nib.load(t1_path)).header.get_zooms() + t2_zoom = cast(AnalyzeImage, nib.load(t2_path)).header.get_zooms() + + if not np.allclose(np.array(t1_zoom), np.array(t2_zoom), rtol=0.05): + LOGGER.warning( + f"Resolution from T1 {t1_zoom} and T2 {t2_zoom} image are different.\n" + f"T2 image will be interpolated to the resolution of the T1 image." 
+ ) + + LOGGER.info("Registering T1 to T2 ...") + t1_to_t2_registration( + t1_path=t1_path, + t2_path=t2_path, + output_path=registered_t2_path, + lta_path=subject_dir / "mri/transforms/t2tot1.lta", + registration_type=reg_mode, + threads=threads, + ) + LOGGER.info( + f"Registration finish in {time.time() - load_res:0.4f} seconds!" + ) + else: + LOGGER.info( + "No registration step, registering T1w and T2w is required when running " + "the multi-modal input mode.\nUnregistered images can generate wrong " + "predictions. Ignore this message, if input images are already registered." + ) + try: + registered_t2_path.symlink_to(os.path.relpath(t2_path, registered_t2_path)) + except FileNotFoundError as e: + msg = (f"Could not create symlink. " + f"Does the folder {registered_t2_path.parent} exist?") + LOGGER.error(msg) + raise FileNotFoundError(msg) from e + except (RuntimeError, OSError): + LOGGER.info(f"Could not create symlink for {registered_t2_path}, copying.") + from shutil import copy + copy(t2_path, registered_t2_path) + + LOGGER.info("---" * 30) + + return registered_t2_path diff --git a/HypVINN/utils/stats_utils.py b/HypVINN/utils/stats_utils.py new file mode 100644 index 00000000..5fee8ca5 --- /dev/null +++ b/HypVINN/utils/stats_utils.py @@ -0,0 +1,77 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + + +def compute_stats( + orig_path: Path, + prediction_path: Path, + stats_dir: Path, + threads: int, +) -> int | str: + """ + Compute statistics for the segmentation results. + + Parameters + ---------- + orig_path : Path + The path to the original image. + prediction_path : Path + The path to the predicted segmentation. + stats_dir : Path + The directory for storing the statistics. + threads : int + The number of threads to be used. + + Returns + ------- + int, str + The return value of the main function from FastSurferCNN.segstats. + Exit code. Returns 0 upon successful execution. + + Raises + ------ + RuntimeError + If the main function from FastSurferCNN.segstats fails to run. 
+ """ + from collections import namedtuple + + from FastSurferCNN.utils.checkpoint import FASTSURFER_ROOT + from FastSurferCNN.segstats import main + from HypVINN.config.hypvinn_files import HYPVINN_STATS_NAME + from HypVINN.config.hypvinn_global_var import FS_CLASS_NAMES + + args = namedtuple( + "ArgNamespace", + ["normfile", "i", "o", "excludedid", "ids", "merged_labels", + "robust", "threads", "patch_size", "device", "lut", "allow_root"]) + + labels = [v for v in FS_CLASS_NAMES.values() if v != 0] + + args.normfile = orig_path + args.segfile = prediction_path + args.segstatsfile = stats_dir / HYPVINN_STATS_NAME + args.excludeid = [0] + args.ids = labels + args.merged_labels = [] + args.robust = None + args.threads = threads + args.patch_size = 32 + args.device = "auto" + args.lut = FASTSURFER_ROOT / "FastSurferCNN/config/FreeSurferColorLUT.txt" + # We check for this in the parent code + # TODO: it would be better to populate this properly + args.allow_root = True + return main(args) diff --git a/HypVINN/utils/visualization_utils.py b/HypVINN/utils/visualization_utils.py new file mode 100644 index 00000000..a43aaf9b --- /dev/null +++ b/HypVINN/utils/visualization_utils.py @@ -0,0 +1,322 @@ +# Copyright 2024 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +from pathlib import Path + +import numpy as np +import nibabel as nib +import matplotlib.pyplot as plt + +from HypVINN.config.hypvinn_files import HYPVINN_LUT +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +_doc_HYPVINN_LUT = os.path.relpath(HYPVINN_LUT, FASTSURFER_ROOT) + + +def remove_values_from_list(the_list, val): + """ + Removes values from a list. + + Parameters + ---------- + the_list : list + The original list from which values will be removed. + val : any + The value to be removed from the list. + + Returns + ------- + list + A new list with the specified value removed. + """ + return [value for value in the_list if value != val] + + +def get_lut(lookup_table_path: Path = HYPVINN_LUT): + f""" + Retrieve a color lookup table (LUT) from a file. + + This function reads a file and constructs a lookup table (LUT) from it. + + Parameters + ---------- + lookup_table_path: Path, default="{_doc_HYPVINN_LUT}" + The path to the file from which the LUT will be constructed. + + Returns + ------- + lut: OrderedDict + The constructed LUT as an ordered dictionary. + """ + from collections import OrderedDict + lut = OrderedDict() + with open(lookup_table_path, "r") as f: + for line in f: + if line[0] == "#" or line[0] == "\n": + pass + else: + clean_line = remove_values_from_list(line.split(" "), "") + rgb = [int(clean_line[2]), int(clean_line[3]), int(clean_line[4])] + lut[str(clean_line[0])] = rgb + return lut + + +def map_hyposeg2label(hyposeg: np.ndarray, lut_file: Path = HYPVINN_LUT): + f""" + Map a HypVINN segmentation to a continuous label space using a lookup table. 
+ + Parameters + ---------- + hyposeg : np.ndarray + The original segmentation map. + lut_file : Path, default="{_doc_HYPVINN_LUT}" + The path to the lookup table file. + + Returns + ------- + mapped_hyposeg : ndarray + The mapped segmentation. + cmap : ListedColormap + The colormap for the mapped segmentation. + """ + import matplotlib.colors + + labels = np.unique(hyposeg) + + labels = np.int16(labels) + # retrieve freesurfer color map lookup table + cdict = get_lut(lut_file) + colors = np.zeros((len(labels), 3)) + # colors = list() + mapped_hyposeg = np.zeros_like(hyposeg) + + for idx, value in enumerate(labels): + mapped_hyposeg[hyposeg == value] = idx + r, g, b = cdict[str(value)] + colors[idx] = [r, g, b] + + colors = np.divide(colors, 255) + cmap = matplotlib.colors.ListedColormap(colors) + + return mapped_hyposeg, cmap + + +def plot_coronal_predictions(cmap, images_batch=None, pred_batch=None, img_per_row=8): + """ + Plot the predicted segmentations on a grid layout. + + Parameters + ---------- + cmap : matplotlib.colors.Colormap + The colormap to be used for the predicted segmentations. + images_batch : np.ndarray, optional + The batch of input images. If not provided, the function will not plot anything. + pred_batch : np.ndarray, optional + The batch of predicted segmentations. If not provided, the function will not plot anything. + img_per_row : int, default=8 + The number of images to be plotted per row in the grid layout. + + Returns + ------- + fig: matplotlib.figure.Figure + The figure containing the plotted images and predictions. + + """ + import matplotlib.pyplot as plt + import torch + from torchvision import utils + plt.ioff() + + FIGSIZE = 3 + # FIGDPI = dpi + + ncols = 1 + nrows = 2 + + fig, ax = plt.subplots(nrows, ncols) + + grid_size = (images_batch.shape[0] / img_per_row, img_per_row) + + # adjust layout + fig.set_size_inches([FIGSIZE * ncols * grid_size[1], FIGSIZE * nrows * grid_size[0]]) + # fig.set_dpi(FIGDPI) + fig.set_facecolor("black") + fig.set_tight_layout({"pad": 0}) + fig.subplots_adjust(wspace=0, hspace=0) + + pos = 0 + + images = torch.from_numpy(images_batch.copy()) + images = torch.unsqueeze(images, 1) + grid = utils.make_grid(images.cpu(), nrow=img_per_row, normalize=True) + # ax[pos].imshow(grid.numpy().transpose(1, 2, 0), cmap="gray",origin="lower") + ax[pos].imshow(grid.numpy().transpose(1, 2, 0), cmap="gray", origin="lower") + ax[pos].set_axis_off() + ax[pos].set_aspect("equal") + ax[pos].margins(0, 0) + ax[pos].set_title("T1w input image (1 to N). Coronal orientation from right (R) to left (L).", color="white") + pos += 1 + + pred = torch.from_numpy(pred_batch.copy()) + pred = torch.unsqueeze(pred, 1) + pred_grid = utils.make_grid(pred.cpu(), nrow=img_per_row)[0] # dont take the channels axis from grid + # pred_grid=color.label2rgb(pred_grid.numpy(),grid.numpy().transpose(1, 2, 0),alpha=0.6,bg_label=0,colors=DEFAULT_COLORS) + # pred_grid = color.label2rgb(pred_grid.numpy(), grid.numpy().transpose(1, 2, 0), alpha=0.6, bg_label=0,bg_color=None,colors=DEFAULT_COLORS) + + alphas = np.ones(pred_grid.numpy().shape) * 0.8 + alphas[pred_grid.numpy() == 0] = 0 + + ax[pos].imshow(grid.numpy().transpose(1, 2, 0), cmap="gray", origin="lower") + ax[pos].imshow(pred_grid.numpy(), cmap=cmap, interpolation="none", alpha=alphas, origin="lower") + ax[pos].set_axis_off() + ax[pos].set_aspect("equal") + ax[pos].margins(0, 0) + ax[pos].set_title("Predictions (1 to N). 
Coronal orientation from right (R) to left (L).", color="white") + ax[pos].margins(0, 0) + + return fig + + +def select_index_to_plot(hyposeg, slice_step=2): + """ + Select indices to plot based on the given segmentation map. + + Parameters + ---------- + hyposeg : np.ndarray + The segmentation map from which indices will be selected. + slice_step : int, default=2 + The step size for selecting indices from the remaining indices after removing certain indices. + + Returns + ------- + list + The selected indices, sorted in ascending order. + """ + # slices with labels + idx = np.where(hyposeg > 0) + idx = np.unique(idx[0]) + # get slices with 3rd ventricle + idx_with_third_ventricle = np.unique(np.where(hyposeg == 10)[0]) + # get slices with only 3rd ventricle + idx_only_third_ventricle = [] + for i in idx_with_third_ventricle: + label = np.unique(hyposeg[i]) + # Background is allways at position 0 + if label[1] == 10: + idx_only_third_ventricle.append(i) + # Remove slices with only third ventricle from the total + idx = list(set(list(idx)) - set(idx_only_third_ventricle)) + # get slices with hyppthalamus variables + idx_hypo = np.where(hyposeg > 100) + idx_hypo = np.unique(idx_hypo[0]) + # remove hypo_varaibles from the list + idx = list(set(list(idx)) - set(idx_hypo)) + # optic nerve index + idx_with_optic_nerve = np.unique(np.where((hyposeg <= 2) & (hyposeg > 0))[0]) + # remove index from list + idx = list(set(list(idx)) - set(idx_with_optic_nerve)) + # take optic nerve every 4 slices + idx_with_optic_nerve = idx_with_optic_nerve[::4] + # from the remaining slices only take increments by slice step default 2 + idx = idx[::slice_step] + # Add slices with hypothalamus variables and optic nerve + idx.extend(idx_hypo) + idx.extend(idx_with_optic_nerve) + + return sorted(idx) + + +def plot_qc_images( + subject_qc_dir: Path, + orig_path: Path, + prediction_path: Path, + padd: int = 45, + lut_file: Path = HYPVINN_LUT, + slice_step: int = 2): + f""" + Plot the quality control images for the subject. + + Parameters + ---------- + subject_qc_dir : Path + The directory for the subject. + orig_path : Path + The path to the original image. + prediction_path : Path + The path to the predicted image. + padd : int, default=45 + The padding value for cropping the images and segmentations. + lut_file : Path, default="{_doc_HYPVINN_LUT}" + The path to the lookup table file. + slice_step : int, default=2 + The step size for selecting indices from the predicted segmentation. 
+ """ + from scipy import ndimage + + from HypVINN.data_loader.data_utils import transform_axial2coronal, hypo_map_subseg_2_fsseg + from HypVINN.config.hypvinn_files import HYPVINN_QC_IMAGE_NAME + + subject_qc_dir.mkdir(exist_ok=True, parents=True) + + image = nib.as_closest_canonical(nib.load(orig_path)) + pred = nib.as_closest_canonical(nib.load(prediction_path)) + pred_arr = hypo_map_subseg_2_fsseg(np.asarray(pred.dataobj, dtype=np.int16), reverse=True) + + mod_image = transform_axial2coronal(image.get_fdata()) + mod_image = np.transpose(mod_image, (2, 0, 1)) + mod_pred = transform_axial2coronal(pred_arr) + mod_pred = np.transpose(mod_pred, (2, 0, 1)) + + idx = select_index_to_plot(hyposeg=mod_pred, slice_step=slice_step) + + hypo_seg, cmap = map_hyposeg2label(hyposeg=mod_pred, lut_file=lut_file) + + if len(idx) > 0: + + crop_image = mod_image[idx, :, :] + + crop_seg = hypo_seg[idx, :, :] + + cm = ndimage.center_of_mass(crop_seg > 0) + + cm = np.asarray(cm).astype(int) + + crop_image = crop_image[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] + crop_seg = crop_seg[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] + + else: + depth = hypo_seg.shape[0] // 2 + crop_image = mod_image[depth - 8:depth + 8, :, :] + crop_seg = hypo_seg[depth - 8:depth + 8, :, :] + + cm = [crop_image.shape[0] // 2, crop_image.shape[1] // 2, crop_image.shape[2] // 2] + cm = np.array(cm).astype(int) + + crop_image = crop_image[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] + crop_seg = crop_seg[:, cm[1] - padd:cm[1] + padd, cm[2] - padd:cm[2] + padd] + + crop_image = np.rot90(np.flip(crop_image, axis=0), k=-1, axes=(1, 2)) + crop_seg = np.rot90(np.flip(crop_seg, axis=0), k=-1, axes=(1, 2)) + + fig = plot_coronal_predictions( + cmap=cmap, + images_batch=crop_image, + pred_batch=crop_seg, + img_per_row=crop_image.shape[0], + ) + + fig.savefig(subject_qc_dir / HYPVINN_QC_IMAGE_NAME, transparent=False) + + plt.close(fig) diff --git a/LICENSE b/LICENSE index 261eeb9e..dd5b3a58 100644 --- a/LICENSE +++ b/LICENSE @@ -172,30 +172,3 @@ defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
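The `t1_to_t2_registration` helper added above simply shells out to FreeSurfer. If you need to debug a failed registration outside of Python, the equivalent command lines look roughly as follows; this is a sketch with placeholder file names (`T1.mgz`, `T2.mgz`), while the helper itself additionally resolves the binaries via `$FREESURFER_HOME` and checks their return codes:

```bash
# "coreg" mode (default): estimate the T2->T1 transform, then resample T2 onto the T1 grid.
mri_coreg --mov T2.mgz --targ T1.mgz --reg t2tot1.lta --threads 4
mri_vol2vol --mov T2.mgz --targ T1.mgz --reg t2tot1.lta --o T2_nu_reg.mgz --cubic --keep-precision

# "robust" mode: estimate and apply the transform in a single call.
mri_robust_register --mov T2.mgz --dst T1.mgz --lta t2tot1.lta --mapmov T2_nu_reg.mgz --cost NMI
```

The output names mirror what `hypvinn_preproc` writes into the subject directory (`mri/T2_nu_reg.mgz` and `mri/transforms/t2tot1.lta`).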
diff --git a/README.md b/README.md index 499710cc..8eb33540 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,9 @@ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deep-MI/FastSurfer/blob/stable/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deep-MI/FastSurfer/blob/stable/Tutorial/Complete_FastSurfer_Tutorial.ipynb) -# Overview + +# Welcome to FastSurfer! +## Overview This README contains all information needed to run FastSurfer - a fast and accurate deep-learning based neuroimaging pipeline. FastSurfer provides a fully compatible [FreeSurfer](https://freesurfer.net/) alternative for volumetric analysis (within minutes) and surface-based thickness analysis (within only around 1h run time). FastSurfer is transitioning to sub-millimeter resolution support throughout the pipeline. @@ -13,41 +15,52 @@ The FastSurfer pipeline consists of two main parts for segmentation and surface - the segmentation sub-pipeline (`seg`) employs advanced deep learning networks for fast, accurate segmentation and volumetric calculation of the whole brain and selected substructures. - the surface sub-pipeline (`recon-surf`) reconstructs cortical surfaces, maps cortical labels and performs a traditional point-wise and ROI thickness analysis. + ### Segmentation Modules -- approximately 5 minutes (GPU), `--seg_only` only runs this part -- Modules (all by default): - 1. `asegdkt:` FastSurferVINN for whole brain segmentation (deactivate with `--no_asegdkt`) - - the core, outputs anatomical segmentation and cortical parcellation and statistics of 95 classes, mimics FreeSurfer’s DKTatlas. - - requires a T1w image ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, experimental beyond that). - - performs bias-field correction and calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). - 2. `cereb:` CerebNet for cerebellum sub-segmentation (deactivate with `--no_cereb`) - - requires `asegdkt_segfile`, outputs cerebellar sub-segmentation with detailed WM/GM delineation. - - requires a T1w image ([notes on input images](#requirements-to-input-images)), which will be resampled to 1mm isotropic images (no native high-res support). - - calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). +- approximately 5 minutes (GPU), `--seg_only` only runs this part. + +Modules (all run by default): +1. `asegdkt:` [FastSurferVINN](FastSurferCNN/README.md) for whole brain segmentation (deactivate with `--no_asegdkt`) + - the core, outputs anatomical segmentation and cortical parcellation and statistics of 95 classes, mimics FreeSurfer’s DKTatlas. + - requires a T1w image ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, experimental beyond that). + - performs bias-field correction and calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). +2. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) + - requires `asegdkt_segfile`, outputs cerebellar sub-segmentation with detailed WM/GM delineation. + - requires a T1w image ([notes on input images](#requirements-to-input-images)), which will be resampled to 1mm isotropic images (no native high-res support). 
+   - calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed).
+3. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`)
+   - outputs a hypothalamic subsegmentation including the 3rd ventricle, mammillary bodies, fornix and optic tracts.
+   - a T1w image is highly recommended ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, but experimental beyond that).
+   - allows the additional passing of a T2w image with `--t2`, which will be registered to the T1w image (see the `--reg_mode` option and the example call below).
+   - calculates volume statistics corrected for partial volume effects based on the T1w image (skipped if `--no_biasfield` is passed).
 
 ### Surface reconstruction
-- approximately 60-90 minutes, `--surf_only` runs only the surface part
-- supports high-resolution images (up to 0.7mm, experimental beyond that)
+- approximately 60-90 minutes, `--surf_only` runs only [the surface part](recon_surf/README.md).
+- supports high-resolution images (up to 0.7mm, experimental beyond that).
+
 
 ### Requirements to input images
 
 All pipeline parts and modules require good quality MRI images, preferably from a 3T MR scanner. FastSurfer expects a similar image quality as FreeSurfer, so what works with FreeSurfer should also work with FastSurfer. Notwithstanding module-specific limitations, resolution should be between 1mm and 0.7mm isotropic (slice thickness should not exceed 1.5mm). Preferred sequence is Siemens MPRAGE or multi-echo MPRAGE. GE SPGR should also work. See `--vox_size` flag for high-res behaviour.
 
+
+![](doc/images/teaser.png)
 
-![](/images/teaser.png)
 
+
+## Getting started
 
-# Getting started
-## Installation
+### Installation
 
 There are two ways to run FastSurfer (links are to installation instructions):
 
-1. In a container ([Singularity](INSTALL.md#singularity) or [Docker](INSTALL.md#docker)) (OS: [Linux](INSTALL.md#linux), [Windows](INSTALL.md#windows), [MacOS on Intel](INSTALL.md#docker--currently-only-supported-for-intel-cpus-)),
-2. As a [native install](INSTALL.md#native--ubuntu-2004-) (all OS for segmentation part).
+1. In a container ([Singularity](doc/overview/INSTALL.md#singularity) or [Docker](doc/overview/INSTALL.md#docker)) (OS: [Linux](doc/overview/INSTALL.md#linux), [Windows](doc/overview/INSTALL.md#windows), [MacOS on Intel](doc/overview/INSTALL.md#docker-currently-only-supported-for-intel-cpus)),
+2. As a [native install](doc/overview/INSTALL.md#native-ubuntu-2004-or-ubuntu-2204) (all OS for segmentation part).
 
-We recommended you use Singularity or Docker, especially if either is already installed on your system, because the [images we provide](https://hub.docker.com/r/deepmi/fastsurfer) conveniently include everything needed for FastSurfer, expect a [FreeSurfer license file](https://surfer.nmr.mgh.harvard.edu/fswiki/License). We have detailed, per-OS Installation instructions in the [INSTALL.md file](INSTALL.md).
+We recommend using Singularity or Docker on a Linux host system with a GPU. The images we provide on [DockerHub](https://hub.docker.com/r/deepmi/fastsurfer) conveniently include everything needed for FastSurfer. You will also need a [FreeSurfer license file](https://surfer.nmr.mgh.harvard.edu/fswiki/License) for the [Surface pipeline](#surface-reconstruction). We have detailed per-OS Installation instructions in the [INSTALL.md file](doc/overview/INSTALL.md).
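To make the module options above concrete, the following is a minimal sketch of a segmentation-only run that also passes a T2w image to the hypothalamus module. The paths, subject ID and thread count are placeholders; for container installations the same flags are appended to the docker/singularity call described in the Usage section:

```bash
# Minimal sketch (native install): asegdkt + cereb + hypothal segmentation with an
# additional T2w input; all paths and the subject ID are placeholders.
./run_fastsurfer.sh --sid subjectX --sd /output \
    --t1 /data/subjectX/t1.nii.gz \
    --t2 /data/subjectX/t2.nii.gz \
    --reg_mode coreg \
    --seg_only --threads 4
```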
-## Usage +### Usage -All installation methods use the `run_fastsurfer.sh` call interface (replace `*fastsurfer-flags*` with [FastSurfer flags](#required-arguments)), which is the general starting point for FastSurfer. However, there are different ways to call this script depending on the installation, which we explain here: +All installation methods use the `run_fastsurfer.sh` call interface (replace `*fastsurfer-flags*` with [FastSurfer flags](doc/overview/FLAGS.md#required-arguments)), which is the general starting point for FastSurfer. However, there are different ways to call this script depending on the installation, which we explain here: 1. For container installations, you need to define the hardware and mount the folders with the input (`/data`) and output data (`/output`): (a) For __singularity__, the syntax is @@ -63,11 +76,11 @@ All installation methods use the `run_fastsurfer.sh` call interface (replace `*f ``` The `--nv` flag is needed to allow FastSurfer to run on the GPU (otherwise FastSurfer will run on the CPU). - The `--no-home` flag tells singularity to not mount the home directory (see [Singularity README](Singularity/README.md#mounting-home) for more info). + The `--no-home` flag tells singularity to not mount the home directory (see [Singularity documentation](Singularity/README.md#mounting-home) for more info). The `-B` flag is used to tell singularity, which folders FastSurfer can read and write to. - See also __[Example 2](#example-2--fastSurfer-singularity)__ for a full singularity FastSurfer run command and [the Singularity README](Singularity/README.md#fastsurfer-singularity-image-usage) for details on more singularity flags. + See also __[Example 2](doc/overview/EXAMPLES.md#example-2-fastsurfer-singularity)__ for a full singularity FastSurfer run command and [the Singularity documentation](Singularity/README.md#fastsurfer-singularity-image-usage) for details on more singularity flags. (b) For __docker__, the syntax is ``` @@ -83,358 +96,37 @@ All installation methods use the `run_fastsurfer.sh` call interface (replace `*f The `-v` flag is used to tell docker, which folders FastSurfer can read and write to. - See also __[Example 1](#example-1--fastSurfer-docker)__ for a full FastSurfer run inside a Docker container and [the Docker README](Docker/README.md#docker-flags-) for more details on the docker flags including `--rm` and `--user`. + See also __[Example 1](doc/overview/EXAMPLES.md#example-1-fastsurfer-docker)__ for a full FastSurfer run inside a Docker container and [the Docker documentation](Docker/README.md#docker-flags) for more details on the docker flags including `--rm` and `--user`. 2. For a __native install__, you need to activate your FastSurfer environment (e.g. `conda activate fastsurfer_gpu`) and make sure you have added the FastSurfer path to your `PYTHONPATH` variable, e.g. `export PYTHONPATH=$(pwd)`. You will then be able to run fastsurfer with `./run_fastsurfer.sh *fastsurfer-flags*`. - See also [Example 3](#example-3--native-fastsurfer-on-subjectx--with-parallel-processing-of-hemis-) for an illustration of the commands to run the entire FastSurfer pipeline (FastSurferCNN + recon-surf) natively. - -### FastSurfer Flags -Next, you will need to select the `*fastsurfer-flags*` and replace `*fastsurfer-flags*` with your options. Please see the Examples below for some example flags. 
- -The `*fastsurfer-flags*` will usually include the subject directory (`--sd`; Note, this will be the mounted path - `/output` - for containers), the subject name/id (`--sid`) and the path to the input image (`--t1`). For example: - -```bash -... --sd /output --sid test_subject --t1 /data/test_subject_t1.nii.gz --3T -``` -Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb` and `--no_asegdkt` to switch off some segmentation modules (see above). -Here, we have also added the `--3T` flag, which tells fastsurfer to register against the 3T atlas for ICV estimation (eTIV). - -In the following, we give an overview of the most important options, but you can view a full list of options with - -```bash -./run_fastsurfer.sh --help -``` - - -#### Required arguments -* `--sd`: Output directory \$SUBJECTS_DIR (equivalent to FreeSurfer setup --> $SUBJECTS_DIR/sid/mri; $SUBJECTS_DIR/sid/surf ... will be created). -* `--sid`: Subject ID for directory inside \$SUBJECTS_DIR to be created ($SUBJECTS_DIR/sid/...) -* `--t1`: T1 full head input (not bias corrected, global path). The network was trained with conformed images (UCHAR, 256x256x256, 1-0.7 mm voxels and standard slice orientation). These specifications are checked in the run_prediction.py script and the image is automatically conformed if it does not comply. Note, outputs will be in the conformed space (as in FreeSurfer). - -#### Required for docker when running surface module -* `--fs_license`: Path to FreeSurfer license key file (only needed for the surface module). Register (for free) at https://surfer.nmr.mgh.harvard.edu/registration.html to obtain it if you do not have FreeSurfer installed so far. Strictly necessary if you use Docker, optional for local install (your local FreeSurfer license will automatically be used). The license file is usually located in $FREESURFER_HOME/license.txt or $FREESURFER_HOME/.license . - -#### Segmentation pipeline arguments (optional) -* `--seg_only`: only run FastSurferCNN (generate segmentation, do not run the surface pipeline) -* `--seg_log`: Name and location for the log-file for the segmentation (FastSurferCNN). Default: $SUBJECTS_DIR/$sid/scripts/deep-seg.log -* `--viewagg_device`: Define where the view aggregation should be run on. Can be "auto" or a device (see --device). By default, the program checks if you have enough memory to run the view aggregation on the gpu. The total memory is considered for this decision. If this fails, or you actively overwrote the check with setting with "cpu" view agg is run on the cpu. Equivalently, if you pass a different device, view agg will be run on that device (no memory check will be done). -* `--device`: Select device for NN segmentation (_auto_, _cpu_, _cuda_, _cuda:_, _mps_), where cuda means Nvidia GPU, you can select which one e.g. "cuda:1". Default: "auto", check GPU and then CPU. "mps" is for native MAC installs to use the Apple silicon (M-chip) GPU. -* `--asegdkt_segfile`: Name of the segmentation file, which includes the aparc+DKTatlas-aseg segmentations. Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz -* `--no_cereb`: Switch of the cerebellum sub-segmentation -* `--cereb_segfile`: Name of the cerebellum segmentation file. If not provided, this intermediate DL-based segmentation will not be stored, but only the merged segmentation will be stored (see --main_segfile ). Requires an ABSOLUTE Path! 
Default location: \$SUBJECTS_DIR/\$sid/mri/cerebellum.CerebNet.nii.gz -* `--no_biasfield`: Deactivate the calculation of partial volume-corrected statistics. - -#### Surface pipeline arguments (optional) -* `--surf_only`: only run the surface pipeline recon_surf. The segmentation created by FastSurferCNN must already exist in this case. -* `--3T`: for Talairach registration, use the 3T atlas instead of the 1.5T atlas (which is used if the flag is not provided). This gives better (more consistent with FreeSurfer) ICV estimates (eTIV) for 3T and better Talairach registration matrices, but has little impact on standard volume or surface stats. -* `--fstess`: Use mri_tesselate instead of marching cube (default) for surface creation -* `--fsqsphere`: Use FreeSurfer default instead of novel spectral spherical projection for qsphere -* `--fsaparc`: Use FS aparc segmentations in addition to DL prediction (slower in this case and usually the mapped ones from the DL prediction are fine) -* `--parallel`: Run both hemispheres in parallel -* `--no_fs_T1`: Do not generate T1.mgz (normalized nu.mgz included in standard FreeSurfer output) and create brainmask.mgz directly from norm.mgz instead. Saves 1:30 min. -* `--no_surfreg`: Skip the surface registration (`sphere.reg`), which is generated automatically by default. To safe time, use this flag to turn this off. - -#### Other -* `--threads`: Target number of threads for all modules (segmentation, surface pipeline), `1` (default) forces FastSurfer to only really use one core. Note, that the default value may change in the future for better performance on multi-core architectures. -* `--vox_size`: Forces processing at a specific voxel size. If a number between 0.7 and 1 is specified (below is experimental) the T1w image is conformed to that isotropic voxel size and processed. - If "min" is specified (default), the voxel size is read from the size of the minimal voxel size (smallest per-direction voxel size) in the T1w image: - If the minimal voxel size is bigger than 0.98mm, the image is conformed to 1mm isometric. - If the minimal voxel size is smaller or equal to 0.98mm, the T1w image will be conformed to isometric voxels of that voxel size. - The voxel size (whether set manually or derived) determines whether the surfaces are processed with highres options (below 1mm) or not. -* `--py`: Command for python, used in both pipelines. Default: python3.8 -* `--conformed_name`: Name of the file in which the conformed input image will be saved. Default location: \$SUBJECTS_DIR/\$sid/mri/orig.mgz -* `--ignore_fs_version`: Switch on to avoid check for FreeSurfer version. Program will terminate if the supported version (see recon-surf.sh) is not sourced. Can be used for testing dev versions. -* `-h`, `--help`: Prints help text - -### Example 1: FastSurfer Docker -After pulling one of our images from Dockerhub, you do not need to have a separate installation of FreeSurfer on your computer (it is already included in the Docker image). However, if you want to run ___more than just the segmentation CNN___, you need to register at the FreeSurfer website (https://surfer.nmr.mgh.harvard.edu/registration.html) to acquire a valid license for free. The directory containing the license needs to be mounted and passed to the script via the `--fs_license` flag. Basically for Docker (as for Singularity below) you are starting a container image (with the run command) and pass several parameters for that, e.g. 
if GPUs will be used and mounting (linking) the input and output directories to the inside of the container image. In the second half of that call you pass parameters to our run_fastsurfer.sh script that runs inside the container (e.g. where to find the FreeSurfer license file, and the input data and other flags). - -To run FastSurfer on a given subject using the provided GPU-Docker, execute the following command: - -```bash -# 1. get the fastsurfer docker image (if it does not exist yet) -docker pull deepmi/fastsurfer - -# 2. Run command -docker run --gpus all -v /home/user/my_mri_data:/data \ - -v /home/user/my_fastsurfer_analysis:/output \ - -v /home/user/my_fs_license_dir:/fs_license \ - --rm --user $(id -u):$(id -g) deepmi/fastsurfer:latest \ - --fs_license /fs_license/license.txt \ - --t1 /data/subjectX/t1-weighted.nii.gz \ - --sid subjectX --sd /output \ - --parallel --3T -``` - -Docker Flags: -* The `--gpus` flag is used to allow Docker to access GPU resources. With it, you can also specify how many GPUs to use. In the example above, _all_ will use all available GPUS. To use a single one (e.g. GPU 0), set `--gpus device=0`. To use multiple specific ones (e.g. GPU 0, 1 and 3), set `--gpus 'device=0,1,3'`. -* The `-v` commands mount your data, output, and directory with the FreeSurfer license file into the docker container. Inside the container these are visible under the name following the colon (in this case /data, /output, and /fs_license). -* The `--rm` flag takes care of removing the container once the analysis finished. -* The `--user $(id -u):$(id -g)` part automatically runs the container with your group- (id -g) and user-id (id -u). All generated files will then belong to the specified user. Without the flag, the docker container will be run as root which is discouraged. - -FastSurfer Flags: -* The `--fs_license` points to your FreeSurfer license which needs to be available on your computer in the my_fs_license_dir that was mapped above. -* The `--t1` points to the t1-weighted MRI image to analyse (full path, with mounted name inside docker: /home/user/my_mri_data => /data) -* The `--sid` is the subject ID name (output folder name) -* The `--sd` points to the output directory (its mounted name inside docker: /home/user/my_fastsurfer_analysis => /output) -* The `--parallel` activates processing left and right hemisphere in parallel -* The `--3T` changes the atlas for registration to the 3T atlas for better Talairach transforms and ICV estimates (eTIV) - -Note, that the paths following `--fs_license`, `--t1`, and `--sd` are __inside__ the container, not global paths on your system, so they should point to the places where you mapped these paths above with the `-v` arguments (part after colon). - -A directory with the name as specified in `--sid` (here subjectX) will be created in the output directory if it does not exist. So in this example output will be written to /home/user/my_fastsurfer_analysis/subjectX/ . Make sure the output directory is empty, to avoid overwriting existing files. - -If you do not have a GPU, you can also run our CPU-Docker by dropping the `--gpus all` flag and specifying `--device cpu` at the end as a FastSurfer flag. See [Docker/README.md](Docker/README.md) for more details. 
- - -### Example 2: FastSurfer Singularity -After building the Singularity image (see below or instructions in ./Singularity/README.md), you also need to register at the FreeSurfer website (https://surfer.nmr.mgh.harvard.edu/registration.html) to acquire a valid license (for free) - same as when using Docker. This license needs to be passed to the script via the `--fs_license` flag. This is not necessary if you want to run the segmentation only. - -To run FastSurfer on a given subject using the Singularity image with GPU access, execute the following commands from a directory where you want to store singularity images. This will create a singularity image from our Dockerhub image and execute it: - -```bash -# 1. Build the singularity image (if it does not exist) -singularity build fastsurfer-gpu.sif docker://deepmi/fastsurfer - -# 2. Run command -singularity exec --nv \ - --no-home \ - -B /home/user/my_mri_data:/data \ - -B /home/user/my_fastsurfer_analysis:/output \ - -B /home/user/my_fs_license_dir:/fs_license \ - ./fastsurfer-gpu.sif \ - /fastsurfer/run_fastsurfer.sh \ - --fs_license /fs_license/license.txt \ - --t1 /data/subjectX/t1-weighted.nii.gz \ - --sid subjectX --sd /output \ - --parallel --3T -``` - -#### Singularity Flags -* The `--nv` flag is used to access GPU resources. -* The `--no-home` flag stops mounting your home directory into singularity. -* The `-B` commands mount your data, output, and directory with the FreeSurfer license file into the Singularity container. Inside the container these are visible under the name following the colon (in this case /data, /output, and /fs_license). - -#### FastSurfer Flags -* The `--fs_license` points to your FreeSurfer license which needs to be available on your computer in the my_fs_license_dir that was mapped above. -* The `--t1` points to the t1-weighted MRI image to analyse (full path, with mounted name inside docker: /home/user/my_mri_data => /data) -* The `--sid` is the subject ID name (output folder name) -* The `--sd` points to the output directory (its mounted name inside docker: /home/user/my_fastsurfer_analysis => /output) -* The `--parallel` activates processing left and right hemisphere in parallel -* The `--3T` changes the atlas for registration to the 3T atlas for better Talairach transforms and ICV estimates (eTIV) - -Note, that the paths following `--fs_license`, `--t1`, and `--sd` are __inside__ the container, not global paths on your system, so they should point to the places where you mapped these paths above with the `-v` arguments (part after colon). - -A directory with the name as specified in `--sid` (here subjectX) will be created in the output directory. So in this example output will be written to /home/user/my_fastsurfer_analysis/subjectX/ . Make sure the output directory is empty, to avoid overwriting existing files. - -You can run the Singularity equivalent of CPU-Docker by building a Singularity image from the CPU-Docker image and excluding the `--nv` argument in your Singularity exec command. Also append `--device cpu` as a FastSurfer flag. - - -### Example 3: Native FastSurfer on subjectX (with parallel processing of hemis) - -For a native install you may want to make sure that you are on our stable branch, as the default dev branch is for development and could be broken at any time. For that you can directly clone the stable branch: - -```bash -git clone --branch stable https://github.com/Deep-MI/FastSurfer.git -``` - -More details (e.g. 
you need all dependencies in the right versions and also FreeSurfer locally) can be found in our [INSTALL.md file](INSTALL.md). -Given you want to analyze data for subject which is stored on your computer under /home/user/my_mri_data/subjectX/t1-weighted.nii.gz, run the following command from the console (do not forget to source FreeSurfer!): - -```bash -# Source FreeSurfer -export FREESURFER_HOME=/path/to/freesurfer -source $FREESURFER_HOME/SetUpFreeSurfer.sh - -# Define data directory -datadir=/home/user/my_mri_data -fastsurferdir=/home/user/my_fastsurfer_analysis - -# Run FastSurfer -./run_fastsurfer.sh --t1 $datadir/subjectX/t1-weighted-nii.gz \ - --sid subjectX --sd $fastsurferdir \ - --parallel --threads 4 --3T -``` - -The output will be stored in the $fastsurferdir (including the aparc.DKTatlas+aseg.deep.mgz segmentation under $fastsurferdir/subjectX/mri (default location)). Processing of the hemispheres will be run in parallel (--parallel flag) to significantly speed-up surface creation. Omit this flag to run the processing sequentially, e.g. if you want to save resources on a compute cluster. - - -### Example 4: FastSurfer on multiple subjects - -In order to run FastSurfer on multiple cases, you may use the helper script `brun_subjects.sh`. This script accepts multiple ways to define the subjects, for example a subjects_list file. -Prepare the subjects_list file as follows: -``` -subject_id1=path_to_t1\n -subject2=path_to_t1\n -subject3=path_to_t1\n -... -subject10=path_to_t1\n -``` -Note, that all paths (`path_to_t1`) are as if you passed them to the `run_fastsurfer.sh` script via `--t1 ` so they may be with respect to the singularity or docker file system. Absolute paths are recommended. - -The `brun_fastsurfer.sh` script can then be invoked in docker, singularity or on the native platform as follows: - -#### Docker -```bash -docker run --gpus all -v /home/user/my_mri_data:/data \ - -v /home/user/my_fastsurfer_analysis:/output \ - -v /home/user/my_fs_license_dir:/fs_license \ - --entrypoint "/fastsurfer/brun_fastsurfer.sh" \ - --rm --user $(id -u):$(id -g) deepmi/fastsurfer:latest \ - --fs_license /fs_license/license.txt \ - --sd /output --subjects_list /data/subjects_list.txt \ - --parallel --3T -``` -#### Singularity -```bash -singularity exec --nv \ - --no-home \ - -B /home/user/my_mri_data:/data \ - -B /home/user/my_fastsurfer_analysis:/output \ - -B /home/user/my_fs_license_dir:/fs_license \ - ./fastsurfer-gpu.sif \ - /fastsurfer/run_fastsurfer.sh \ - --fs_license /fs_license/license.txt \ - --sd /output \ - --subjects_list /data/subjects_list.txt \ - --parallel --3T -``` -#### Native -```bash -export FREESURFER_HOME=/path/to/freesurfer -source $FREESURFER_HOME/SetUpFreeSurfer.sh - -cd /home/user/FastSurfer -datadir=/home/user/my_mri_data -fastsurferdir=/home/user/my_fastsurfer_analysis - -# Run FastSurfer -./brun_fastsurfer.sh --subjects_list $datadir/subjects_list.txt \ - --sd $fastsurferdir \ - --parallel --threads 4 --3T -``` - -#### Flags -The `brun_fastsurfer.sh` script accepts almost all `run_fastsurfer.sh` flags (exceptions are `--t1` and `--sid`). In addition, -* the `--parallel_subjects` runs all subjects in parallel (experimental, parameter may change in future releases). This is particularly useful for surfaces computation `--surf_only`. -* to run segmentation in series, but surfaces in parallel, you may use `--parallel_subjects surf`. -* these options are in contrast (and in addition) to `--parallel`, which just parallelizes the hemispheres of one case. 
- -### Example 5: Quick Segmentation - -For many applications you won't need the surfaces. You can run only the aparc+DKT segmentation (in 1 minute on a GPU) via - -```bash -./run_fastsurfer.sh --t1 $datadir/subject1/t1-weighted.nii.gz \ - --asegdkt_segfile $outputdir/subject1/aparc.DKTatlas+aseg.deep.mgz \ - --conformed_name $outputdir/subject1/conformed.mgz \ - --threads 4 --seg_only --no_cereb -``` - -This will produce the segmentation in a conformed space (just as FreeSurfer would do). It also writes the conformed image that fits the segmentation. -Conformed means that the image will be isotropic in LIA orientation. -It will furthermore output a brain mask (`mri/mask.mgz`), a simplified segmentation file (`mri/aseg.auto_noCCseg.mgz`), the biasfield corrected image (`mri/orig_nu.mgz`), and the volume statistics (without eTIV) based on the FastSurferVINN segmentation (without the corpus callosum) (`stats/aseg+DKT.stats`). - -If you do not even need the biasfield corrected image and the volume statistics, you may add `--no_biasfield`. These steps especially benefit from larger assigned core counts `--threads 32`. - -The above ```run_fastsurfer.sh``` commands can also be called from the Docker or Singularity images by passing the flags and adjusting input and output directories to the locations inside the containers (where you mapped them via the -v flag in Docker or -B in Singularity). - -```bash -# Docker -docker run --gpus all -v $datadir:/data \ - -v $outputdir:/output \ - --rm --user $(id -u):$(id -g) deepmi/fastsurfer:latest \ - --t1 /data/subject1/t1-weighted.nii.gz \ - --asegdkt_segfile /output/subject1/aparc.DKTatlas+aseg.deep.mgz \ - --conformed_name $outputdir/subject1/conformed.mgz \ - --threads 4 --seg_only --3T -``` - -### Example 6: Running FastSurfer on a SLURM cluster via Singularity - -Starting with version 2.2, FastSurfer comes with a script that helps orchestrate FastSurfer optimally on a SLURM cluster: `srun_fastsurfer.sh`. - -This script distributes GPU-heavy and CPU-heavy workloads to different SLURM partitions and manages intermediate files in a work directory for IO performance. - -```bash -srun_fastsurfer.sh --partition seg=GPU_Partition \ - --partition surf=CPU_Partition \ - --sd $outputdir \ - --data $datadir \ - --singularity_image $HOME/images/fastsurfer-singularity.sif \ - [...] # fastsurfer flags -``` - -This will create three dependent SLURM jobs, one to segment, one for surface reconstruction and one for cleanup (which moves the data from the work directory to the `$outputdir`). -There are many intricacies and options, so it is advised to use `--help`, `--debug` and `--dry` to inspect, what will be scheduled as well as run a test on a small subset. More control over subjects is available with `--subject_list`s. - -The `$outputdir` and the `$datadir` need to be accessible from cluster nodes. Most IO is performed on a work directory (automatically generated from `$HPCWORK` environment variable: `$HPCWORK/fastsurfer-processing/$(date +%Y%m%d-%H%M%S)`). Alternatively, an empty directory can be manually defined via `--work`. On successful cleanup, this directory will be removed. - -## Output files - - -### Segmentation module - -The segmentation module outputs the files shown in the table below. The two primary output files are the `aparc.DKTatlas+aseg.deep.mgz` file, which contains the FastSurfer segmentation of cortical and subcortical structures based on the DKT atlas, and the `aseg+DKT.stats` file, which contains summary statistics for these structures. 
Note, that the surface model (downstream) corrects these segmentations along the cortex with the created surfaces. So if the surface model is used, it is recommended to use the updated segmentations and stats (see below). - -| directory | filename | module | description | -|:------------|-------------------------------|-----------|-------------| -| mri | aparc.DKTatlas+aseg.deep.mgz | asegdkt | cortical and subcortical segmentation| -| mri | aseg.auto_noCCseg.mgz | asegdkt | simplified subcortical segmentation without corpus callosum labels| -| mri | mask.mgz | asegdkt | brainmask| -| mri | orig.mgz | asegdkt | conformed image| -| mri | orig_nu.mgz | asegdkt | biasfield-corrected image| -| mri/orig | 001.mgz | asegdkt | original image| -| scripts | deep-seg.log | asegdkt | logfile| -| stats | aseg+DKT.stats | asegdkt | table of cortical and subcortical segmentation statistics| - -### Cerebnet module - -The cerebellum module outputs the files in the table shown below. Unless switched off by the `--no_cereb` argument, this module is automatically run whenever the segmentation module is run. It adds two files, an image with the sub-segmentation of the cerebellum and a text file with summary statistics. + See also [Example 3](doc/overview/EXAMPLES.md#example-3-native-fastsurfer-on-subjectx-with-parallel-processing-of-hemis) for an illustration of the commands to run the entire FastSurfer pipeline (FastSurferCNN + recon-surf) natively. + +### FastSurfer_Flags +Please refer to [FASTSURFER_FLAGS](doc/overview/FLAGS.md). -| directory | filename | module | description | -|:------------|-------------------------------|-----------|-------------| -| mri | cerebellum.CerebNet.nii.gz | cerebnet | cerebellum sub-segmentation| -| stats | cerebellum.CerebNet.stats | cerebnet | table of cerebellum segmentation statistics| +## Examples +All the examples can be found here: [FASTSURFER_EXAMPLES](doc/overview/EXAMPLES.md) +- [Example 1: FastSurfer Docker](doc/overview/EXAMPLES.md#example-1-fastsurfer-docker) +- [Example 2: FastSurfer Singularity](doc/overview/EXAMPLES.md#example-2-fastsurfer-singularity) +- [Example 3: Native FastSurfer on subjectX with parallel processing of hemis](doc/overview/EXAMPLES.md#example-3-native-fastsurfer-on-subjectx-with-parallel-processing-of-hemis) +- [Example 4: FastSurfer on multiple subjects](doc/overview/EXAMPLES.md#example-4-fastsurfer-on-multiple-subjects) +- [Example 5: Quick Segmentation](doc/overview/EXAMPLES.md#example-5-quick-segmentation) +- [Example 6: Running FastSurfer on a SLURM cluster via Singularity](doc/overview/EXAMPLES.md#example-6-running-fastsurfer-on-a-slurm-cluster-via-singularity) -### Surface module -The surface module is run unless switched off by the `--seg_only` argument. It outputs a large number of files, which generally correspond to the FreeSurfer nomenclature and definition. A selection of important output files is shown in the table below, for the other files, we refer to the [FreeSurfer documentation](https://surfer.nmr.mgh.harvard.edu/fswiki). In general, the "mri" directory contains images, including segmentations, the "surf" folder contains surface files (geometries and vertex-wise overlay data), the "label" folder contains cortical parcellation labels, and the "stats" folder contains tabular summary statistics. Many files are available for the left ("lh") and right ("rh") hemisphere of the brain. 
Symbolic links are created to map FastSurfer files to their FreeSurfer equivalents, which may need to be present for further processing (e.g., with FreeSurfer downstream modules). - -After running this module, some of the initial segmentations and corresponding volume estimates are fine-tuned (e.g., surface-based partial volume correction, addition of corpus callosum labels). Specifically, this concerns the `aseg.mgz `, `aparc.DKTatlas+aseg.mapped.mgz`, `aparc.DKTatlas+aseg.deep.withCC.mgz`, which were originally created by the segmentation module or have earlier versions resulting from that module. - -The primary output files are pial, white, and inflated surface files, the thickness overlay files, and the cortical parcellation (annotation) files. The preferred way of assessing this output is the [FreeView](https://surfer.nmr.mgh.harvard.edu/fswiki/FreeviewGuide) software. Summary statistics for volume and thickness estimates per anatomical structure are reported in the stats files, in particular the `aseg.stats`, and the left and right `aparc.DKTatlas.mapped.stats` files. - -| directory | filename | module | description | -|:------------|-------------------------------|-----------|-------------| -| mri | aparc.DKTatlas+aseg.deep.withCC.mgz| surface | cortical and subcortical segmentation incl. corpus callosum after running the surface module| -| mri | aparc.DKTatlas+aseg.mapped.mgz| surface | cortical and subcortical segmentation after running the surface module| -| mri | aparc.DKTatlas+aseg.mgz | surface | symlink to aparc.DKTatlas+aseg.mapped.mgz| -| mri | aparc+aseg.mgz | surface | symlink to aparc.DKTatlas+aseg.mapped.mgz| -| mri | aseg.mgz | surface | subcortical segmentation after running the surface module| -| mri | wmparc.DKTatlas.mapped.mgz | surface | white matter parcellation| -| mri | wmparc.mgz | surface | symlink to wmparc.DKTatlas.mapped.mgz| -| surf | lh.area, rh.area | surface | surface area overlay file| -| surf | lh.curv, rh.curv | surface | curvature overlay file| -| surf | lh.inflated, rh.inflated | surface | inflated cortical surface| -| surf | lh.pial, rh.pial | surface | pial surface| -| surf | lh.thickness, rh.thickness | surface | cortical thickness overlay file| -| surf | lh.volume, rh.volume | surface | gray matter volume overlay file| -| surf | lh.white, rh.white | surface | white matter surface| -| label | lh.aparc.DKTatlas.annot, rh.aparc.DKTatlas.annot| surface | symlink to lh.aparc.DKTatlas.mapped.annot| -| label | lh.aparc.DKTatlas.mapped.annot, rh.aparc.DKTatlas.mapped.annot| surface | annotation file for cortical parcellations, mapped from ASEGDKT segmentation to the surface| -| stats | aseg.stats | surface | table of cortical and subcortical segmentation statistics after running the surface module| -| stats | lh.aparc.DKTatlas.mapped.stats, rh.aparc.DKTatlas.mapped.stats| surface | table of cortical parcellation statistics, mapped from ASEGDKT segmentation to the surface| -| stats | lh.curv.stats, rh.curv.stats | surface | table of curvature statistics| -| stats | wmparc.DKTatlas.mapped.stats | surface | table of white matter segmentation statistics| -| scripts | recon-all.log | surface | logfile| +## Output files +Modules output can be found here: [FastSurfer_Output_Files](doc/overview/OUTPUT_FILES.md) +- [Segmentation module](doc/overview/OUTPUT_FILES.md#segmentation-module) +- [Cerebnet module](doc/overview/OUTPUT_FILES.md#cerebnet-module) +- [Surface module](doc/overview/OUTPUT_FILES.md#surface-module) + ## System Requirements Recommendation: At 
least 8 GB system memory and 8 GB NVIDIA graphics memory ``--viewagg_device gpu``
@@ -454,20 +146,24 @@ Minimum CPU-only: 8 GB system memory (much slower, not recommended) ``--device c
| 0.7mm | gpu | 8 | 6 |
| 0.7mm | cpu | 3 | 9 |
+
## Expert usage
-Individual modules and the surface pipeline can be run independently of the full pipeline script documented in this README.
-This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation (in progress)](CerebNet/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md).
+Individual modules and the surface pipeline can be run independently of the full pipeline script described in this documentation.
+This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md).
Specifically, the segmentation modules feature options for optimized parallelization of batch processing.
+
## FreeSurfer Downstream Modules
-FreeSurfer provides several Add-on modules for downstream processing, such as subfield segmentation ( [hippocampus/amygdala](https://surfer.nmr.mgh.harvard.edu/fswiki/HippocampalSubfieldsAndNucleiOfAmygdala), [brainstrem](https://surfer.nmr.mgh.harvard.edu/fswiki/BrainstemSubstructures), [thalamus](https://freesurfer.net/fswiki/ThalamicNuclei) and [hypothalamus](https://surfer.nmr.mgh.harvard.edu/fswiki/HypothalamicSubunits) ) as well as [TRACULA](https://surfer.nmr.mgh.harvard.edu/fswiki/Tracula). We now provide symlinks to the required files, as FastSurfer creates them with a different name (e.g. using "mapped" or "DKT" to make clear that these file are from our segmentation using the DKT Atlas protocol, and mapped to the surface). Most subfield segmentations require `wmparc.mgz` and work very well with FastSurfer, so feel free to run those pipelines after FastSurfer. TRACULA requires `aparc+aseg.mgz` which we now link, but have not tested if it works, given that [DKT-atlas](https://mindboggle.readthedocs.io/en/latest/labels.html) merged a few labels. You should source FreeSurfer 7.3.2 to run these modules.
+FreeSurfer provides several Add-on modules for downstream processing, such as subfield segmentation ( [hippocampus/amygdala](https://surfer.nmr.mgh.harvard.edu/fswiki/HippocampalSubfieldsAndNucleiOfAmygdala), [brainstem](https://surfer.nmr.mgh.harvard.edu/fswiki/BrainstemSubstructures), [thalamus](https://freesurfer.net/fswiki/ThalamicNuclei) and [hypothalamus](https://surfer.nmr.mgh.harvard.edu/fswiki/HypothalamicSubunits) ) as well as [TRACULA](https://surfer.nmr.mgh.harvard.edu/fswiki/Tracula). We now provide symlinks to the required files, as FastSurfer creates them with a different name (e.g. using "mapped" or "DKT" to make clear that these files are from our segmentation using the DKT Atlas protocol, and mapped to the surface). Most subfield segmentations require `wmparc.mgz` and work very well with FastSurfer, so feel free to run those pipelines after FastSurfer. TRACULA requires `aparc+aseg.mgz`, which we now link, but we have not tested whether it works, given that the [DKT atlas](https://mindboggle.readthedocs.io/en/latest/labels.html) merged a few labels. You should source FreeSurfer 7.3.2 to run these modules.
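As a rough illustration, running one of these FreeSurfer add-on modules (here the hippocampus/amygdala subfield module, `segmentHA_T1.sh`, which ships with FreeSurfer 7) on a FastSurfer subject might look like the sketch below. It assumes FreeSurfer 7.3.2 is installed, uses the FastSurfer output directory as `SUBJECTS_DIR`, and the paths and the subject ID `subjectX` are placeholders:

```bash
# Sketch only: run a FreeSurfer subfield module on a subject that was already
# processed with the full FastSurfer pipeline (segmentation + surfaces).
export FREESURFER_HOME=/path/to/freesurfer        # placeholder install path
source $FREESURFER_HOME/SetUpFreeSurfer.sh
export SUBJECTS_DIR=/home/user/my_fastsurfer_analysis   # FastSurfer output dir

# segmentHA_T1.sh reads the subject directory; FastSurfer provides the
# required files there (partly via the symlinks mentioned above).
segmentHA_T1.sh subjectX "$SUBJECTS_DIR"
```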
+ ## Intended Use This software can be used to compute statistics from an MR image for research purposes. Estimates can be used to aggregate population data, compare groups etc. The data should not be used for clinical decision support in individual cases and, therefore, does not benefit the individual patient. Be aware that for a single image, produced results may be unreliable (e.g. due to head motion, imaging artefacts, processing errors etc). We always recommend to perform visual quality checks on your data, as also your MR-sequence may differ from the ones that we tested. No contributor shall be liable to any damages, see also our software [LICENSE](LICENSE). + ## References If you use this for research publications, please cite: @@ -478,8 +174,15 @@ _Henschel L*, Kuegler D*, Reuter M. (*co-first). FastSurferVINN: Building Resolu _Faber J*, Kuegler D*, Bahrami E*, et al. (*co-first). CerebNet: A fast and reliable deep-learning pipeline for detailed cerebellum sub-segmentation. NeuroImage 264 (2022), 119703. https://doi.org/10.1016/j.neuroimage.2022.119703_ -Stay tuned for updates and follow us on Twitter: https://twitter.com/deepmilab +_Estrada S, Kuegler D, Bahrami E, Xu P, Mousa D, Breteler MMB, Aziz NA, Reuter M. FastSurfer-HypVINN: Automated sub-segmentation of the hypothalamus and adjacent structures on high-resolutional brain MRI. Imaging Neuroscience 2023; 1 1–32. https://doi.org/10.1162/imag_a_00034_ + +Stay tuned for updates and follow us on [X/Twitter](https://twitter.com/deepmilab). + ## Acknowledgements -The recon-surf pipeline is largely based on FreeSurfer -https://surfer.nmr.mgh.harvard.edu/fswiki/FreeSurferMethodsCitation + +This project is partially funded by: +- [Chan Zuckerberg Initiative](https://chanzuckerberg.com/eoss/proposals/fastsurfer-ai-based-neuroimage-analysis-package/) +- [German Federal Ministry of Education and Research](https://www.gesundheitsforschung-bmbf.de/de/deepni-innovative-deep-learning-methoden-fur-die-rechnergestutzte-neuro-bildgebung-10897.php) + +The recon-surf pipeline is largely based on [FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/FreeSurferMethodsCitation). diff --git a/Singularity/README.md b/Singularity/README.md index 51e7267e..4d532d0d 100644 --- a/Singularity/README.md +++ b/Singularity/README.md @@ -1,29 +1,29 @@ -# FastSurfer Singularity Image Creation +# FastSurfer Singularity Support -We host our releases as docker images on [Dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags) -For use on HPCs or in other cases where Docker is not preferred you can easily create a Singularity image from the Docker images. - -# FastSurfer Singularity Image Creation -For creating a singularity image from the Dockerhub just run: +For use on HPCs (or in other cases where Docker is not available or preferred) you can easily create a Singularity image from the Docker image. +Singularity uses its own image format, so the Docker images must be converted (we publish our releases as docker images available on [Dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags)). +## Singularity with the official FastSurfer Image +To create a Singularity image from the official FastSurfer image hosted on Dockerhub just run: ```bash -cd /home/user/my_singlarity_images -singularity build fastsurfer-latest.sif docker://deepmi/fastsurfer:latest +singularity build /home/user/my_singlarity_images/fastsurfer-latest.sif docker://deepmi/fastsurfer:latest ``` +Singularity images are files - usually with the extension `.sif`. 
Here, we save the image in `/home/user/my_singlarity_images`.
+If you want to pick a specific FastSurfer version, you can also change the tag (`latest`) in `deepmi/fastsurfer:latest` to any tag. For example, to use the cpu image hosted on [Dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags), use the tag `cpu-latest`.
-Singularity Images are saved as files. Here the _/homer/user/my_singlarity_images_ is the path where you want your file saved.
-You can change _deepmi/fastsurfer:latest_ with any tag provided in our [Dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags)
+## Building your own FastSurfer Singularity Image
+To build a custom FastSurfer Singularity image, the `Docker/build.py` script supports a flag for direct conversion.
+Simply add `--singularity /home/user/my_singlarity_images/fastsurfer-myimage.sif` to the call, which first builds the image with Docker and then converts the image to Singularity.
-If you want to use a locally available image that you created yourself, instead run:
```bash
-cd /home/user/my_singlarity_images
-singularity build fastsurfer-myimage.sif docker-daemon://fastsurfer:myimage
+singularity build /home/user/my_singlarity_images/fastsurfer-myimage.sif docker-daemon://fastsurfer:myimage
```
-For how to create your own Docker images see our [Docker guide](../Docker/README.md)
+For more information on how to create your own Docker images, see our [Docker guide](../Docker/README.md).
-# FastSurfer Singularity Image Usage
+## FastSurfer Singularity Image Usage
After building the Singularity image, you need to register at the FreeSurfer website (https://surfer.nmr.mgh.harvard.edu/registration.html) to acquire a valid license (for free) - just as when using Docker. This license needs to be passed to the script via the `--fs_license` flag. This is not necessary if you want to run the segmentation only.
@@ -59,6 +59,7 @@ Note, that the paths following `--fs_license`, `--t1`, and `--sd` are __inside__
A directory with the name as specified in `--sid` (here subjectX) will be created in the output directory. So in this example output will be written to /home/user/my_fastsurfer_analysis/subjectX/ . Make sure the output directory is empty, to avoid overwriting existing files.
+### Singularity without a GPU
You can run the Singularity equivalent of CPU-Docker by building a Singularity image from the CPU-Docker image (replace # with the current version number) and excluding the `--nv` argument in your Singularity exec command as following:
```bash
@@ -77,7 +78,7 @@ singularity exec --no-home \
    --parallel --3T
```
-# Singularity Best Practice
+## Singularity Best Practice
### Mounting Home
Do not mount the user home directory into the singularity container as the home directory.
diff --git a/Tutorial/Complete_FastSurfer_Tutorial.ipynb b/Tutorial/Complete_FastSurfer_Tutorial.ipynb
index ca880abe..2115a5a6 100644
--- a/Tutorial/Complete_FastSurfer_Tutorial.ipynb
+++ b/Tutorial/Complete_FastSurfer_Tutorial.ipynb
@@ -2354,7 +2354,7 @@ " device (no memory check will be done).\n", " --batch Batch size for inference.
Default: 1\n", " --py Command for python, used in both pipelines.\n", - " Default: python3.8\n", + " Default: python3.10\n", "\n", " Dev Flags:\n", " --ignore_fs_version Switch on to avoid check for FreeSurfer version.\n", diff --git a/Tutorial/README.md b/Tutorial/README.md index 0951f35a..39989682 100644 --- a/Tutorial/README.md +++ b/Tutorial/README.md @@ -33,7 +33,7 @@ Docker is an open platform for developing, shipping, and running applications. I If you decide against using docker, you need either python + pip or anaconda (conda) to install FastSurfer. #### 1. Python + pip -Python 3.8 or greater is generally recommended to run FastSurfer. In addition you will need the package manager pip to install the python dependencies used for FastSurfer (see requirements.txt in the main directory for a list). On Linux, pip is not installed by default. You can install it via +Python 3.10 is generally recommended to run FastSurfer. In addition you will need the package manager pip to install the python dependencies used for FastSurfer (see requirements.txt in the main directory for a list). On Linux, pip is not installed by default. You can install it via ```bash sudo apt install python3-pip diff --git a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb index 4fb0696e..db463305 100644 --- a/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb +++ b/Tutorial/Tutorial_FastSurferCNN_QuickSeg.ipynb @@ -218,7 +218,7 @@ "#@title Here we first setup the environment by downloading the open source deep-mi/fastsurfer project and the required packages\n", "import os\n", "import sys\n", - "from os.path import exists, join, basename, splitext\n", + "from os.path import exists, basename, splitext\n", "\n", "print(\"Starting setup. 
This could take a few minutes\")\n", "print(\"----------------------------------------------\")\n", @@ -474,12 +474,12 @@ "source": [ "#@title Click this run button, if you would prefer to download the segmentation in nifti-format\n", "import nibabel as nib\n", + "from google.colab import files\n", "# conversion to nifti\n", "data = nib.load(f'{SETUP_DIR}fastsurfer_seg/Tutorial/mri/aparc.DKTatlas+aseg.deep.mgz')\n", "img_nifti = nib.Nifti1Image(data.get_fdata(), data.affine, header=nib.Nifti1Header())\n", "nib.nifti1.save(img_nifti, f'{SETUP_DIR}fastsurfer_seg/Tutorial/mri/aparc.DKTatlas+aseg.deep.nii.gz')\n", "\n", - "from google.colab import files\n", "files.download(f'{SETUP_DIR}fastsurfer_seg/Tutorial/mri/aparc.DKTatlas+aseg.deep.nii.gz')\n" ] }, @@ -519,12 +519,12 @@ "source": [ "#@title Click this run button, if you would prefer to download the image in nifti-format\n", "import nibabel as nib\n", + "from google.colab import files\n", "# conversion to nifti\n", "data = nib.load(f\"{SETUP_DIR}140_orig.mgz\")\n", "img_nifti = nib.Nifti1Image(data.get_fdata(), data.affine, header=nib.Nifti1Header())\n", "nib.nifti1.save(img_nifti, f\"{SETUP_DIR}140_orig.nii.gz\")\n", "\n", - "from google.colab import files\n", "files.download(f\"{SETUP_DIR}140_orig.nii.gz\")\n" ] }, @@ -612,11 +612,11 @@ "%matplotlib inline\n", "import nibabel as nib\n", "import matplotlib.pyplot as plt\n", - "plt.style.use('seaborn-v0_8-whitegrid')\n", - "from skimage import color\n", "import torch\n", "import numpy as np\n", + "from skimage import color\n", "from torchvision import utils\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "\n", "def plot_predictions(image, pred):\n", " \"\"\"\n", @@ -676,7 +676,6 @@ "from ipywidgets import widgets\n", "import matplotlib.pyplot as plt\n", "import nibabel as nib\n", - "import numpy as np\n", "#from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n", "from skimage import measure\n", "\n", @@ -853,7 +852,7 @@ "def plot_3d_plotly_shape(structure, hemisphere, show_mesh=True, crop=True, grid=True):\n", " import plotly.graph_objects as go\n", " label = label_lookups(structure, hemisphere)\n", - " test_cond = np.in1d(pred_data, label).reshape(pred_data.shape)\n", + " test_cond = np.isin(pred_data, label).reshape(pred_data.shape)\n", " roi = np.where(test_cond, 1, 0)\n", " vert_p, faces_p, normals_p, values_p = measure.marching_cubes(roi, 0, spacing=(1, 1, 1))\n", "\n", diff --git a/brun_fastsurfer.sh b/brun_fastsurfer.sh index 2422589e..6b95b876 100755 --- a/brun_fastsurfer.sh +++ b/brun_fastsurfer.sh @@ -16,7 +16,7 @@ # script for batch-running FastSurfer -subjects="" +subjects=() subjects_stdin="true" POSITIONAL_FASTSURFER=() task_count="" @@ -24,7 +24,7 @@ task_id="" surf_only="false" seg_only="false" debug="false" -run_fastsurfer="default" +run_fastsurfer=() parallel_subjects="1" parallel_surf="false" statusfile="" @@ -55,14 +55,14 @@ Documentation of Options: Generally, brun_fastsurfer works similar to run_fastsurfer, but loops over multiple subjects from i. a list passed through stdin of the format (one subject per line) --- -= +=[ [ ...]] ... --- ii. a subject_list file using the same format (use Ctrl-D to end the input), or -iii. a list of subjects directly passed +iii. a list of subjects directly passed (this does not support subject-specific parameters) --batch "/": run the i-th of n batches (starting at 1) of the full list of subjects - (default: 1/1, == run all). + (default: 1/1, == run all). "slurm_task_id" is a valid option for "". 
Note, brun_fastsurfer.sh will also automatically detect being run in a SLURM JOBARRAY and split according to \$SLURM_ARRAY_TASK_ID and \$SLURM_ARRAY_TASK_COUNT (unless values are specifically assigned with the --batch argument). @@ -90,10 +90,8 @@ This tool requires functions in stools.sh (expected in same folder as this scrip EOF } -if [ -z "${BASH_SOURCE[0]}" ]; then - THIS_SCRIPT="$0" -else - THIS_SCRIPT="${BASH_SOURCE[0]}" +if [ -z "${BASH_SOURCE[0]}" ]; then THIS_SCRIPT="$0" +else THIS_SCRIPT="${BASH_SOURCE[0]}" fi # PRINT USAGE if called without params @@ -104,125 +102,83 @@ then fi # PARSE Command line -newline=" -" inputargs=("$@") POSITIONAL=() +SED_CLEANUP_SUBJECTS='s/\r$//;s/\s*\r\s*/\n/g;s/\s*$//;/^\s*$/d' i=0 while [[ $# -gt 0 ]] do # make key lowercase key=$(echo "$1" | tr '[:upper:]' '[:lower:]') +shift # past argument case $key in - --subject_list) - if [[ ! -f "$2" ]] - then - echo "ERROR: Could not find the subject list $2!" - exit 1 - fi - subjects="$subjects$newline$(cat $2)" - subjects_stdin="false" - shift # past argument + # parse/get the subjects to iterate over + #=================================================== + --subject_list|--subjects_list) + if [[ ! -f "$1" ]] + then + echo "ERROR: Could not find the subject list $1!" + exit 1 + fi + # append the subjects in the listfile (cleanup first) to the subjects array + mapfile -t -O ${#subjects} subjects < <(sed "$SED_CLEANUP_SUBJECTS" "$1") + subjects_stdin="false" shift # past value ;; - --subjects) - subjects_stdin="false" - shift # argument - while [[ "$(expr match \"$1\" '--.')" == 0 ]] - do - if [[ -n "$subjects" ]]; then subjects="$subjects$newline"; fi - subjects="$subjects$1" - shift # next value - done + --subjects) + subjects_stdin="false" + while [[ ! "$1" =~ ^-- ]] ; do subjects+=("$1") ; shift ; done ;; - --batch) - task_count=$(echo "$2" | cut -f2 -d/) - task_id=$(echo "$2" | cut -f1 -d/) - shift - shift - ;; - --parallel_subjects) - shift - if [[ "$(expr match \"$1\" '--.')" == 0 ]] - then - lower_value="$(echo "$1" | tr '[:upper:]' '[:lower:]')" - # has parameter - if [[ "$lower_value" =~ ^surf=?$ ]] - then - parallel_subjects="max" - parallel_surf="true" - elif [[ "$lower_value" =~ ^[0-9]+$ ]] - then - if [[ "$lower_value" -lt 0 ]] - then - parallel_subjects="max" - elif [[ "$lower_value" -lt 2 ]] - then - parallel_subjects="1" - else - parallel_subjects="$lower_value" - fi - elif [[ "$lower_value" =~ ^surf=[0-9]+$ ]] - then - parallel_surf="true" - if [[ "${lower_value:5}" -lt 0 ]] - then - parallel_subjects="max" - elif [[ "${lower_value:5}" -lt 2 ]] - then - parallel_subjects="1" - else - parallel_subjects="${lower_value:5}" - fi - else - echo "Invalid option for --parallel_subjects: $1" - exit 1 - fi - shift - else + # brun_fastsurfer-specific/custom options + #=================================================== + --batch) task_count=$(echo "$1" | cut -f2 -d/) ; task_id=$(echo "$1" | cut -f1 -d/) ; shift ;; + --run_fastsurfer) run_fastsurfer=($1) ; shift ;; + --parallel_subjects) + if [[ "$#" -lt 1 ]] || [[ "$1" =~ ^-- ]] + then # no additional parameter to --parallel_subjects, the next cmd args is unrelated + parallel_subjects="max" + else # has parameter + lower_value="$(echo "$1" | tr '[:upper:]' '[:lower:]')" + if [[ "$lower_value" =~ ^surf(=-[0-9]*|=max)?$ ]] + then # parameter is surf=max or surf= or surf + parallel_surf="true" parallel_subjects="max" + elif [[ "$lower_value" =~ ^surf=[0-9]*$ ]] + then # parameter is surf= + parallel_surf="true" + parallel_subjects="${lower_value:5}" + 
elif [[ "$lower_value" =~ ^(-[0-9]+|max)$ ]] ; then parallel_subjects="max" + elif [[ "$lower_value" =~ ^[0-9]+$ ]] ; then parallel_subjects="$lower_value" + else + echo "Invalid option for --parallel_subjects: $1" + exit 1 fi - ;; - --statusfile) - statusfile="$statusfile" - shift - shift - ;; - --surf_only) - surf_only="true" shift + fi ;; - --seg_only) - seg_only="true" - shift + --statusfile) statusfile="$1" ; shift ;; + --debug) debug="true" ;; + --help) usage ; exit ;; + # run_fastsurfer.sh options, with extra effect in brun_fastsurfer + #=================================================== + --surf_only) surf_only="true" ;; + --seg_only) seg_only="true" ;; + --sid|--t1) + echo "ERROR: --sid and --t1 are not valid for brun_fastsurfer.sh, these values are populated" + echo "via --subjects or --subject_list." + exit 1 ;; - --sid|--t1) - echo "ERROR: --sid and --t1 are not valid for brun_fastsurfer.sh, these values are populated" - echo "via --subjects or --subject_list." - exit 1 - ;; - --debug) - debug="true" - shift - ;; - --help) - usage - exit - ;; - --run_fastsurfer) - run_fastsurfer="$2" - shift - shift - ;; - *) # unknown option - POSITIONAL_FASTSURFER[$i]="$1" - i=$(($i + 1)) - shift + *) # unknown option/run_fastsurfer.sh option + POSITIONAL_FASTSURFER[$i]="$key" + i=$(($i + 1)) ;; esac done -set -- "${POSITIONAL[@]}" # restore positional parameters +if [[ "${#POSITIONAL[@]}" -gt 0 ]] +then + set -- "${POSITIONAL[@]}" # restore positional parameters +fi echo "$THIS_SCRIPT ${inputargs[*]}" date -R @@ -230,48 +186,50 @@ echo "" set -eo pipefail +source "$(dirname "$THIS_SCRIPT")/stools.sh" + if [[ -n "$SLURM_ARRAY_TASK_ID" ]] then if [[ -z "$task_count" ]] then task_count=$SLURM_ARRAY_TASK_COUNT fi - if [[ -z "$task_id" ]] + if [[ -z "$task_id" ]] || [[ "$task_id" == "slurm_task_id" ]] then task_id=$SLURM_ARRAY_TASK_ID fi echo "SLURM TASK ARRAY detected" fi +if [[ "$subjects_stdin" == "true" ]] +then + echo "Reading subjects from stdin, press Ctrl-D to end input (one subject per line)" + mapfile -t -O ${#subjects[@]} subjects < <(sed "$SED_CLEANUP_SUBJECTS") +fi + if [[ "$debug" == "true" ]] then echo "---START DEBUG---" echo "Debug parameters to script brun_fastsurfer:" echo "" echo "subjects: " - echo $subjects + printf "%s\n" "${subjects[@]}" echo "---" - echo "task_id/task_count: $task_id/$task_count" + echo "task_id/task_count: ${task_id:-not specified}/${task_count:-not specified}" if [[ "$parallel_subjects" != "1" ]] then - printf "--parallel_subjects" - if [[ "$parallel_surf" == "true" ]] - then - printf "surf=%s\n" "$parallel_subjects" + if [[ "$parallel_surf" == "true" ]] ; then echo "--parallel_subjects surf=$parallel_subjects" + else echo "--parallel_subjects $parallel_subjects" fi fi - if [[ "$run_fastsurfer" != "/fastsurfer/run_fastsurfer.sh" ]] - then - echo "running $run_fastsurfer" - fi - if [[ -n "$statusfile" ]] - then - echo "statusfile: $statusfile" + if [[ "${run_fastsurfer[*]}" == "" ]] ; then echo "running default run_fastsurfer" + else echo "running ${run_fastsurfer[*]}" fi + if [[ -n "$statusfile" ]] ; then echo "statusfile: $statusfile" ; fi echo "" - echo "FastSurfer parameters:" - if [[ "$seg_only" == "true" ]]; then echo "--seg_only"; fi - if [[ "$surf_only" == "true" ]]; then echo "--surf_only"; fi + printf "FastSurfer parameters:" + if [[ "$seg_only" == "true" ]]; then printf "\n--seg_only"; fi + if [[ "$surf_only" == "true" ]]; then printf "\n--surf_only"; fi for p in "${POSITIONAL_FASTSURFER[@]}" do if [[ "$p" = --* ]]; then printf "\n%s" "$p"; 
@@ -279,76 +237,69 @@ then fi done echo "" - echo "Running in $(ls -l /proc/$$/exe | cut -d">" -f2)" + echo "" + shell=$(ls -l "/proc/$$/exe" | cut -d">" -f2) + echo "Running in shell $shell: $($shell --version 2>/dev/null | head -n 1)" echo "" echo "---END DEBUG ---" fi -if [[ "$subjects_stdin" == "true" ]] -then - echo "Reading subjects from stdin, press Ctrl-D to end input (one subject per line)" - subjects="$(cat)" -fi - -if [[ -z "$subjects" ]] +if [[ "${#subjects[@]}" == 0 ]] then echo "ERROR: No subjects specified!" exit 1 fi -i=1 -num_subjects=$(echo "$subjects" | wc -l) - -if [[ "$run_fastsurfer" == "default" ]] +if [[ "${run_fastsurfer[*]}" == "" ]] then if [[ -n "$FASTSURFER_HOME" ]] then - run_fastsurfer=$FASTSURFER_HOME/run_fastsurfer.sh + run_fastsurfer=("$FASTSURFER_HOME/run_fastsurfer.sh") echo "INFO: run_fastsurfer not explicitly specified, using \$FASTSURFER_HOME/run_fastsurfer.sh." - elif [[ -f "$(dirname $THIS_SCRIPT)/run_fastsurfer.sh" ]] + elif [[ -f "$(dirname "$THIS_SCRIPT")/run_fastsurfer.sh" ]] then - run_fastsurfer="$(dirname $THIS_SCRIPT)/run_fastsurfer.sh" - echo "INFO: run_fastsurfer not explicitly specified, using $run_fastsurfer." + run_fastsurfer=("$(dirname "$THIS_SCRIPT")/run_fastsurfer.sh") + echo "INFO: run_fastsurfer not explicitly specified, using ${run_fastsurfer[0]}." elif [[ -f "/fastsurfer/run_fastsurfer.sh" ]] then - run_fastsurfer="/fastsurfer/run_fastsurfer.sh" + run_fastsurfer=("/fastsurfer/run_fastsurfer.sh") echo "INFO: run_fastsurfer not explicitly specified, using /fastsurfer/run_fastsurfer.sh." else echo "ERROR: Could not find FastSurfer, please set the \$FASTSURFER_HOME environment variable." fi fi +num_subjects=${#subjects[@]} if [[ -z "$task_id" ]] && [[ -z "$task_count" ]] then - subject_start=1 + subject_start=0 subject_end=$num_subjects elif [[ -z "$task_id" ]] || [[ -z "$task_count" ]] then echo "Both task_id and task_count have to be defined, invalid --batch argument?" + exit 1 else - subject_start=$(($(($task_id - 1)) * "$num_subjects" / "$task_count" + 1)) - subject_end=$(("$task_id" * "$num_subjects" / "$task_count")) + subject_start=$(((task_id - 1) * num_subjects / task_count)) + subject_end=$((task_id * num_subjects / task_count)) + subject_end=$((subject_end < num_subjects ? subject_end : num_subjects)) echo "Processing subjects $subject_start to $subject_end" fi +subject_len=$((subject_end - subject_start)) -if [[ "$parallel_subjects" != "1" ]] && [[ "$((subject_end - subject_start))" == 0 ]] +if [[ "$parallel_subjects" != "1" ]] && [[ "$subject_len" == 1 ]] then if [[ "$debug" == "true" ]] ; then echo "DEBUG: --parallel_subjects deactivated, since only one subject" ; fi parallel_subjects="1" fi seg_surf_only="" -if [[ "$surf_only" == "true" ]] -then - seg_surf_only=--surf_only -elif [[ "$seg_only" == "true" ]] -then - seg_surf_only=--seg_only +if [[ "$surf_only" == "true" ]] ; then seg_surf_only=--surf_only +elif [[ "$seg_only" == "true" ]] ; then seg_surf_only=--seg_only fi if [[ "$parallel_surf" == "true" ]] then - if [[ -n "$seg_surf_only" ]] + if [[ -n "$seg_surf_only" ]] then echo "ERROR: Cannot combine --parallel_subjects surf= and --seg_only or --surf_only." 
fi @@ -356,86 +307,127 @@ then fi ### IF THE SCRIPT GETS TERMINATED, ADD A MESSAGE -trap "{ echo \"brun_fastsurfer.sh terminated via signal at \$(date -R)!\" }" SIGINT SIGTERM +trap 'echo "brun_fastsurfer.sh terminated via signal at $(date -R)!"' SIGINT SIGTERM + +if [[ "$parallel_subjects" != "1" ]] +then + echo "Running up to $parallel_subjects in parallel" +fi + +function read_args_from_string () +{ + tab=$(printf '\t') + str_len=${#1} + position=0 + while [[ "$position" -le "$str_len" ]] + do + if [[ -z "${1:$position}" ]]; then position=$((str_len + 1)); continue ; fi + arg=$(expr "${1:$position} " : "\(\(\\\\.\|[^'\"[:space:]\\\\]\+\|'\([^']*\|''\)*'\|\"\([^\"\\\\]\+\|\\\\.\)*\"\)\+\).*") + if [[ -z "$arg" ]] + then + # could not parse + echo "Could not parse the line ${1:$position}, maybe incorrect quoting or escaping?" + exit 1 + else + # arg parsed + if [[ "$position" == "0" ]]; then image_path=$arg + else args=("${args[@]}" "$arg") + fi + position=$((position + ${#arg})) + fi + while [[ "${1:$position:1}" == " " ]] || [[ "${1:$position:1}" == "$tab" ]]; do position=$((position + 1)) ; done + done + export image_path + export args +} pids=() subjectids=() +ROOT_IFS=$IFS IFS=$'\n' -for subject in $subjects + +echo "${subjects[@]}" +# i is a 1-to-n index of the subject +i=$subject_start +for subject in "${subjects[@]:$subject_start:$subject_len}" do + IFS=$ROOT_IFS + i=$((i + 1)) if [[ "$debug" == "true" ]] then echo "DEBUG: subject $i: $subject" fi - # if the subject is in the selected batch - if [[ "$i" -ge "$subject_start" ]] && [[ "$i" -le "$subject_end" ]] - then - subject_id=$(echo "$subject" | cut -d= -f1) + subject_id=$(echo "$subject" | cut -d= -f1) - if [[ -n "$statusfile" ]] && [[ "$surf_only" == "true" ]] - then - status=$(awk -F ": " "/^$subject_id/ { print \$2 }" "$statusfile") - ## if status in statusfile is "Failed", skip this - if [[ "$status" =~ /^Failed.--seg_only/ ]] - then - echo "Skipping $subject_id's surface recon because the segmentation failed." - echo "$subject_id: Skipping surface recon (failed segmentation)" >> "$statusfile" - continue - fi - fi - - image_path=$(echo "$subject" | cut -d= -f2) - args=(--sid "$subject_id") - if [[ "$parallel_surf" == "true" ]] + if [[ -n "$statusfile" ]] && [[ "$surf_only" == "true" ]] + then + status=$(awk -F ": " "/^$subject_id/ { print \$2 }" "$statusfile") + ## if status in statusfile is "Failed", skip this + if [[ "$status" =~ /^Failed.*--seg_only/ ]] then - if [[ "$debug" == "true" ]] - then - echo "DEBUG: $run_fastsurfer --seg_only --t1 "$image_path"" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" - fi - $run_fastsurfer "--seg_only" --t1 "$image_path" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" - if [[ -n "$statusfile" ]] - then - print_status "$subject_id" "--seg_only" "$?" | tee -a "$statusfile" - fi + echo "Skipping $subject_id's surface recon because the segmentation failed." + echo "$subject_id: Skipping surface recon (failed segmentation)" >> "$statusfile" + continue fi + fi - if [[ "$surf_only" == "false" ]] && [[ "$parallel_surf" == "false" ]] - then - args=("${args[@]}" --t1 "$image_path") + image_parameters=$(echo "$subject" | cut -d= -f2-1000 --output-delimiter="=") + args=(--sid "$subject_id") + IFS=$' \t' + read_args_from_string "$image_parameters" + IFS=$ROOT_IFS + if [[ "$parallel_surf" == "true" ]] + then + # parallel_surf implies $seg_surf_only == "" (see line 353), i.e. 
both seg and surf + cmd=("${run_fastsurfer[@]}" "--seg_only" --t1 "$image_path" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}") + if [[ "$debug" == "true" ]] ; then echo "DEBUG:" "${cmd[@]}" ; fi + if [[ "$parallel_subjects" != "1" ]] ; then "${cmd[@]}" | prepend "$subject_id: " + else "${cmd[@]}" fi - if [[ "$debug" == "true" ]] + if [[ -n "$statusfile" ]] then - echo "DEBUG: $run_fastsurfer $seg_surf_only" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" "[&]" + print_status "$subject_id" "--seg_only" "$?" | tee -a "$statusfile" fi - if [[ "$parallel_subjects" != "1" ]] + fi + + if [[ "$surf_only" == "false" ]] && [[ "$parallel_surf" == "false" ]] + then + args=("${args[@]}" --t1 "$image_path") + fi + if [[ "$debug" == "true" ]] + then + echo "DEBUG: ${run_fastsurfer[*]} $seg_surf_only" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" "[&]" + fi + if [[ "$parallel_subjects" != "1" ]] + then + "${run_fastsurfer[@]}" $seg_surf_only "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" | prepend "$subject_id: " & + pids=("${pids[@]}" "$!") + subjectids=("${subjectids[@]}" "$subject_id") + else # serial execution + "${run_fastsurfer[@]}" $seg_surf_only "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" + if [[ -n "$statusfile" ]] then - $run_fastsurfer "$seg_surf_only" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" | prepend "$subject_id: " & - pids=("${pids[@]}" "$!") - subjectids=("${subjectids[@]}" "$subject_id") - else # serial execution - $run_fastsurfer "$seg_surf_only" "${args[@]}" "${POSITIONAL_FASTSURFER[@]}" - if [[ -n "$statusfile" ]] - then - print_status "$subject_id" "$seg_surf_only" "$?" | tee -a "$statusfile" - fi + print_status "$subject_id" "$seg_surf_only" "$?" | tee -a "$statusfile" fi fi - i=$(($i + 1)) + IFS=$'\n' done +IFS=$ROOT_IFS if [[ "$parallel_subjects" != "1" ]] then - i=0 + # indexing in arrays is a 0-base operation, so array[0] is the first element + i=-1 for pid in "${pids[@]}" do + i=$((i + 1)) wait "$pid" if [[ -n "$statusfile" ]] then print_status "${subjectids[$i]}" "$seg_surf_only" "$?" | tee -a "$statusfile" fi - i=$(($i + 1)) done fi # always exit successful -exit 0 \ No newline at end of file +exit 0 diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 00000000..d4bb2cbb --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/doc/_static/.gitkeep b/doc/_static/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/doc/_templates/autosummary/class.rst b/doc/_templates/autosummary/class.rst new file mode 100644 index 00000000..3322b321 --- /dev/null +++ b/doc/_templates/autosummary/class.rst @@ -0,0 +1,10 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :inherited-members: + +.. 
minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst new file mode 100644 index 00000000..cdbecc4f --- /dev/null +++ b/doc/_templates/autosummary/function.rst @@ -0,0 +1,8 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autofunction:: {{ objname }} + +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/module.rst b/doc/_templates/autosummary/module.rst new file mode 100644 index 00000000..13a2c278 --- /dev/null +++ b/doc/_templates/autosummary/module.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. automodule:: {{ fullname }} + :members: + diff --git a/doc/api/CerebNet.rst b/doc/api/CerebNet.rst new file mode 100644 index 00000000..46864c31 --- /dev/null +++ b/doc/api/CerebNet.rst @@ -0,0 +1,12 @@ +CerebNet +======== + + +.. currentmodule:: CerebNet + +.. autosummary:: + :toctree: generated/ + + apply_warp + inference + run_prediction diff --git a/doc/api/CerebNet_dataloader.rst b/doc/api/CerebNet_dataloader.rst new file mode 100644 index 00000000..86f95902 --- /dev/null +++ b/doc/api/CerebNet_dataloader.rst @@ -0,0 +1,13 @@ +CerebNet.dataloader +=================== + + +.. currentmodule:: CerebNet.data_loader + +.. autosummary:: + :toctree: generated/ + + augmentation + data_utils + dataset + loader diff --git a/doc/api/CerebNet_datasets.rst b/doc/api/CerebNet_datasets.rst new file mode 100644 index 00000000..ee987c52 --- /dev/null +++ b/doc/api/CerebNet_datasets.rst @@ -0,0 +1,14 @@ +CerebNet.datasets +================= + + +.. currentmodule:: CerebNet.datasets + +.. autosummary:: + :toctree: generated/ + + generate_hdf5 + load_data + utils + wm_merge_clean + \ No newline at end of file diff --git a/doc/api/CerebNet_models.rst b/doc/api/CerebNet_models.rst new file mode 100644 index 00000000..4e206fff --- /dev/null +++ b/doc/api/CerebNet_models.rst @@ -0,0 +1,12 @@ +CerebNet_models +=============== + + +.. currentmodule:: CerebNet.models + +.. autosummary:: + :toctree: generated/ + + networks + sub_module + diff --git a/doc/api/CerebNet_utils.rst b/doc/api/CerebNet_utils.rst new file mode 100644 index 00000000..d9843781 --- /dev/null +++ b/doc/api/CerebNet_utils.rst @@ -0,0 +1,15 @@ +CerebNet.utils +============== + + +.. currentmodule:: CerebNet.utils + +.. autosummary:: + :toctree: generated/ + + checkpoint + load_config + lr_scheduler + meters + metrics + misc diff --git a/doc/api/FastSurferCNN.data_loader.rst b/doc/api/FastSurferCNN.data_loader.rst new file mode 100644 index 00000000..5d505974 --- /dev/null +++ b/doc/api/FastSurferCNN.data_loader.rst @@ -0,0 +1,16 @@ +FastSurferCNN.data_loader +========================= + + +.. currentmodule:: FastSurferCNN.data_loader + +.. autosummary:: + :toctree: generated/ + + augmentation + conform + data_utils + dataset + loader + + diff --git a/doc/api/FastSurferCNN.models.rst b/doc/api/FastSurferCNN.models.rst new file mode 100644 index 00000000..83be62c2 --- /dev/null +++ b/doc/api/FastSurferCNN.models.rst @@ -0,0 +1,14 @@ +FastSurferCNN.models +==================== + + +.. currentmodule:: FastSurferCNN.models + +.. autosummary:: + :toctree: generated/ + + interpolation_layer + losses + networks + sub_module + diff --git a/doc/api/FastSurferCNN.rst b/doc/api/FastSurferCNN.rst new file mode 100644 index 00000000..c301013d --- /dev/null +++ b/doc/api/FastSurferCNN.rst @@ -0,0 +1,20 @@ +FastSurferCNN +============= + + +.. currentmodule:: FastSurferCNN + + +.. 
autosummary:: + :toctree: generated/ + + + download_checkpoints + generate_hdf5 + inference + quick_qc + reduce_to_aseg + run_prediction + segstats + version + diff --git a/doc/api/FastSurferCNN.utils.rst b/doc/api/FastSurferCNN.utils.rst new file mode 100644 index 00000000..c521c795 --- /dev/null +++ b/doc/api/FastSurferCNN.utils.rst @@ -0,0 +1,24 @@ +FastSurferCNN.utils +=================== + + +.. currentmodule:: FastSurferCNN.utils + +.. autosummary:: + :toctree: generated/ + + arg_types + checkpoint + common + load_config + logging + lr_scheduler + mapper + meters + metrics + misc + parser_defaults + run_tools + threads + + diff --git a/doc/api/HypVINN.rst b/doc/api/HypVINN.rst new file mode 100644 index 00000000..2f560de1 --- /dev/null +++ b/doc/api/HypVINN.rst @@ -0,0 +1,11 @@ +HypVINN +======= + + +.. currentmodule:: HypVINN + +.. autosummary:: + :toctree: generated/ + + inference + run_prediction diff --git a/doc/api/HypVINN_dataloader.rst b/doc/api/HypVINN_dataloader.rst new file mode 100644 index 00000000..4797d420 --- /dev/null +++ b/doc/api/HypVINN_dataloader.rst @@ -0,0 +1,11 @@ +HypVINN.data_loader +=================== + + +.. currentmodule:: HypVINN.data_loader + +.. autosummary:: + :toctree: generated/ + + data_utils + dataset \ No newline at end of file diff --git a/doc/api/HypVINN_models.rst b/doc/api/HypVINN_models.rst new file mode 100644 index 00000000..76bfdbd6 --- /dev/null +++ b/doc/api/HypVINN_models.rst @@ -0,0 +1,10 @@ +HypVINN.models +============== + + +.. currentmodule:: HypVINN.models + +.. autosummary:: + :toctree: generated/ + + networks diff --git a/doc/api/HypVINN_utils.rst b/doc/api/HypVINN_utils.rst new file mode 100644 index 00000000..90fded56 --- /dev/null +++ b/doc/api/HypVINN_utils.rst @@ -0,0 +1,8 @@ +HypVINN.utils +============= + + +.. currentmodule:: HypVINN.utils + +.. autosummary:: + :toctree: generated/ diff --git a/doc/api/index.rst b/doc/api/index.rst new file mode 100644 index 00000000..203f6e64 --- /dev/null +++ b/doc/api/index.rst @@ -0,0 +1,24 @@ +FastSurfer API +============== + +.. note:: Warning + The FastSurfer API is in development and will change without warning. Please consider no internal module, class or function final at this point! + +.. toctree:: + :maxdepth: 2 + + FastSurferCNN.rst + FastSurferCNN.data_loader.rst + FastSurferCNN.models.rst + FastSurferCNN.utils.rst + CerebNet.rst + CerebNet_dataloader.rst + CerebNet_datasets.rst + CerebNet_models.rst + CerebNet_utils.rst + HypVINN.rst + HypVINN_dataloader.rst + HypVINN_models.rst + HypVINN_utils.rst + recon_surf.rst + diff --git a/doc/api/recon_surf.rst b/doc/api/recon_surf.rst new file mode 100644 index 00000000..54403fb9 --- /dev/null +++ b/doc/api/recon_surf.rst @@ -0,0 +1,27 @@ +recon_surf +========== + + +.. currentmodule:: recon_surf + +.. autosummary:: + :toctree: generated/ + + align_points + align_seg + create_annotation + fs_balabels + lta + map_surf_label + N4_bias_correct + paint_cc_into_pred + rewrite_oriented_surface + rewrite_mc_surface + rotate_sphere + sample_parc + smooth_aparc + spherically_project_wrapper + + + + diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 00000000..a9c70804 --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,282 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + + +import inspect +from importlib import import_module +import sys +import os +from pathlib import Path + +# here i added the relative path because sphinx was not able +# to locate FastSurferCNN module directly for autosummary +sys.path.append(os.path.dirname(__file__) + "/..") +sys.path.append(os.path.dirname(__file__) + "/../recon_surf") +sys.path.append(os.path.dirname(__file__) + "/sphinx_ext") + +project = "FastSurfer" +author = "FastSurfer Developers" +copyright = f"2020, {author}" +gh_url = "https://github.com/deep-mi/FastSurfer" + + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +# If your documentation needs a minimal Sphinx version, state it here. +needs_sphinx = "5.0" + +# The document name of the “root” document, that is, the document that contains +# the root toctree directive. +root_doc = "index" + + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named "sphinx.ext.*") or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosectionlabel", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", + "numpydoc", + "sphinxcontrib.bibtex", + "sphinxcontrib.programoutput", + "sphinx_copybutton", + "sphinx_design", + "sphinx_issues", + # sphinx.ext.autosectionlabel and nbsphinx together with sphinxarg.ext causes a + # duplicate label warning: https://github.com/spatialaudio/nbsphinx/issues/787 + # nbsphinx is currently not 'needed' as we do not include ipynb files. + # "nbsphinx", + "IPython.sphinxext.ipython_console_highlighting", + "myst_parser", + "sphinxarg.ext", + "fix_links", +] + +# Suppress myst.xref_missing warning and i.e A target was +# not found for a cross-reference +# Reference: https://myst-parser.readthedocs.io/en/latest/configuration.html#build-warnings +suppress_warnings = [ + # "myst.xref_missing", + "myst.duplicate_def", + "autosectionlabel", +] + +# create anchors for which headings? +myst_heading_anchors = 7 + +templates_path = ["_templates"] +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "**.ipynb_checkpoints", +] + + +# Sphinx will warn about all references where the target cannot be found. +nitpicky = False +nitpick_ignore = [] + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [f"{package}."] + +# The name of a reST role (builtin or Sphinx extension) to use as the default +# role, that is, for text marked up `like this`. This can be set to 'py:obj' to +# make `filter` a cross-reference to the Python function “filter”. 
+default_role = "py:obj" + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output +html_theme = "furo" +html_static_path = ["_static"] +html_title = project +html_show_sphinx = False + +# Documentation to change footer icons: +# https://pradyunsg.me/furo/customisation/footer/#changing-footer-icons +html_theme_options = { + "footer_icons": [ + { + "name": "GitHub", + "url": gh_url, + "html": """ + + + + """, + "class": "", + }, + ], +} + + +# -- autosummary ------------------------------------------------------------- +autosummary_generate = True + +# -- autodoc ----------------------------------------------------------------- +autodoc_typehints = "none" +autodoc_member_order = "groupwise" +autodoc_warningiserror = True +autoclass_content = "class" + + +# -- intersphinx ------------------------------------------------------------- +intersphinx_mapping = { + "matplotlib": ("https://matplotlib.org/stable", None), + "mne": ("https://mne.tools/stable/", None), + "numpy": ("https://numpy.org/doc/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "python": ("https://docs.python.org/3", None), + # "scipy": ("https://docs.scipy.org/doc/scipy", None), + "sklearn": ("https://scikit-learn.org/stable/", None), +} +intersphinx_timeout = 5 + + +# -- sphinx-issues ----------------------------------------------------------- +issues_github_path = gh_url.split("https://github.com/")[-1] + +# -- autosectionlabels ------------------------------------------------------- +autosectionlabel_prefix_document = True + +# -- numpydoc ---------------------------------------------------------------- +numpydoc_class_members_toctree = False +numpydoc_attributes_as_param_list = False +# numpydoc_show_class_members = True + + +# x-ref +numpydoc_xref_param_type = True +numpydoc_xref_aliases = { + # Matplotlib + "Axes": "matplotlib.axes.Axes", + "Figure": "matplotlib.figure.Figure", + # Python + "bool": ":class:`python:bool`", + "Path": "pathlib.Path", + "TextIO": "io.TextIOBase", + # Scipy + "csc_matrix": "scipy.sparse.csc_matrix", +} +# numpydoc_xref_ignore = {} + +# validation +# https://numpydoc.readthedocs.io/en/latest/validation.html#validation-checks +error_ignores = { + "GL01", # docstring should start in the line immediately after the quotes + "EX01", # section 'Examples' not found + "ES01", # no extended summary found + "SA01", # section 'See Also' not found + "RT02", # The first line of the Returns section should contain only the type, unless multiple values are being returned # noqa + "PR01", # Parameters {missing_params} not documented + "GL08", # The object does not have a docstring + "SS05", # Summary must start with infinitive verb, not third person + "RT01", # No Returns section found + "SS06", # Summary should fit in a single line + "GL02", # Closing quotes should be placed in the line after the last text + "GL03", # Double line break found; please use only one blank line to + "SS03", # Summary does not end with a period + "YD01", # No Yields section found + "PR02", # Unknown parameters {unknown_params} + "SS01", # Short summary in a single should be present at the beginning of the docstring. 
+} +numpydoc_validate = True +numpydoc_validation_checks = {"all"} | set(error_ignores) +numpydoc_validation_exclude = { # regex to ignore during docstring check + r"\.__getitem__", + r"\.__contains__", + r"\.__hash__", + r"\.__mul__", + r"\.__sub__", + r"\.__add__", + r"\.__iter__", + r"\.__div__", + r"\.__neg__", +} + +# -- sphinxcontrib-bibtex ---------------------------------------------------- +bibtex_bibfiles = ["./references.bib"] + +# -- sphinx.ext.linkcode ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/extensions/linkcode.html + +# Alternative method for linking to code by Osama, not sure which one is better +from urllib.parse import quote + +# https://github.com/python-websockets/websockets/blob/e217458ef8b692e45ca6f66c5aeb7fad0aee97ee/docs/conf.py#L102-L134 +def linkcode_resolve(domain, info): + # Check if the domain is Python, if not return None + if domain != "py": + return None + if not info["module"]: + return None + + # Import the module using the module information + mod = import_module(info["module"]) + + # Check if the fullname contains a ".", indicating it's a method or attribute of + # a class + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + # Get the object from the module + obj = getattr(mod, objname) + try: + # Try to get the attribute from the object + obj = getattr(obj, attrname) + except AttributeError: + # If the attribute doesn't exist, return None + return None + else: + # If the fullname doesn't contain a ".", get the object directly from the module + obj = getattr(mod, info["fullname"]) + + try: + # Try to get the source file and line numbers of the object + lines, first_line = inspect.getsourcelines(obj) + except TypeError: + # If the object is not a Python object that has a source file, return None + return None + + # Replace "." 
with "/" in the module name to construct the file path + filename = quote(info["module"].replace(".", "/")) + # If the filename doesn't start with "tests", add a "/" at the beginning + if not filename.startswith("tests"): + filename = "/" + filename + + # Construct the URL that points to the source code of the object on GitHub + return f"{gh_url}/blob/dev{filename}.py#L{first_line}-L{first_line + len(lines) - 1}" + +# Which domains to search in to create links in markdown texts +# myst_ref_domains = ["myst", "std", "py"] + + +_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn" +_up = "^/\\.\\./" +_end = "(\\.md)?(#.*)?$" + +# re_reference_target=(regex) => used in missing-reference +fix_links_target = { + # all regexpr are ignorecase, individual replacements are applied until no further + # change occurs, but different (different repl str) replacements are not combined + # "^\\/overview\\/intro\\.md#": "/overview/index.rst#", + "^/?(.*)#(.*)ubuntu-(\\d{2})(\\d{2})": ("/\\1#\\2ubuntu-\\3-\\4",), + f"{_up}readme{_end}": ("/index.rst\\1", "/overview/intro.rst\\1"), + "^/overview/intro(#.*)?$": ("/overview/index.rst\\2",), + f"{_up}(singularity|docker)/readme{_end}": ("/overview/\\1.rst\\2",), + f"{_up}({_re_script_dirs})/readme{_end}": ("/scripts/\\1.rst\\2",), + f"{_up}license": ("/overview/license.rst",), +} +fix_links_alternative_targets = { + "/overview/intro": ("/index.rst", "/overview/index.rst"), +} +fix_links_project_root = Path("..") + diff --git a/images/FastSurfer_v5.pdf b/doc/images/FastSurfer_v5.pdf similarity index 100% rename from images/FastSurfer_v5.pdf rename to doc/images/FastSurfer_v5.pdf diff --git a/images/FastSurfer_v5.png b/doc/images/FastSurfer_v5.png similarity index 100% rename from images/FastSurfer_v5.png rename to doc/images/FastSurfer_v5.png diff --git a/images/detailed_network.pdf b/doc/images/detailed_network.pdf similarity index 100% rename from images/detailed_network.pdf rename to doc/images/detailed_network.pdf diff --git a/images/detailed_network.png b/doc/images/detailed_network.png similarity index 100% rename from images/detailed_network.png rename to doc/images/detailed_network.png diff --git a/images/teaser.png b/doc/images/teaser.png similarity index 100% rename from images/teaser.png rename to doc/images/teaser.png diff --git a/images/teaser_white.pdf b/doc/images/teaser_white.pdf similarity index 100% rename from images/teaser_white.pdf rename to doc/images/teaser_white.pdf diff --git a/doc/index.rst b/doc/index.rst new file mode 100644 index 00000000..e266e902 --- /dev/null +++ b/doc/index.rst @@ -0,0 +1,26 @@ +.. FastSurfer documentation master file, created by + sphinx-quickstart on Thu Nov 30 15:48:44 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +.. image:: images/teaser.png + :alt: FastSurfer Teaser Image + :align: center + +.. include:: ../README.md + :parser: fix_links.parser + :relative-docs: . + :start-after: + :end-before: + +.. include:: ../README.md + :parser: fix_links.parser + :relative-docs: . + :start-after: + +.. toctree:: + :hidden: + + overview/index + scripts/index + api/index \ No newline at end of file diff --git a/doc/make.bat b/doc/make.bat new file mode 100644 index 00000000..32bb2452 --- /dev/null +++ b/doc/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. 
+set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/CODE_OF_CONDUCT.md b/doc/overview/CODE_OF_CONDUCT.md similarity index 100% rename from CODE_OF_CONDUCT.md rename to doc/overview/CODE_OF_CONDUCT.md diff --git a/CONTRIBUTING.md b/doc/overview/CONTRIBUTING.md similarity index 92% rename from CONTRIBUTING.md rename to doc/overview/CONTRIBUTING.md index 3e010dc2..aaf0ce49 100644 --- a/CONTRIBUTING.md +++ b/doc/overview/CONTRIBUTING.md @@ -10,7 +10,7 @@ Please complete the following steps in advance to help us fix any potential bug - Make sure that you are using the latest version. - Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions. -- To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](issues?q=label%3Abug). +- To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](https://github.com/Deep-MI/FastSurfer/issues?q=label%3Abug). - Collect information about the bug: - Stack trace (Traceback) - OS, Platform and Version (Windows, Linux, macOS, x86, ARM) @@ -22,7 +22,7 @@ Please complete the following steps in advance to help us fix any potential bug We use GitHub issues to track bugs and errors. If you run into an issue with the project: -- Open an [Issue](issues/new). (Since we can't be sure at this point whether it is a bug or not, we ask you not to talk about a bug yet and not to label the issue.) +- Open an [Issue](https://github.com/Deep-MI/FastSurfer/issues/new). (Since we can't be sure at this point whether it is a bug or not, we ask you not to talk about a bug yet and not to label the issue.) - Explain the behavior you would expect and the actual behavior. - Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case. - Provide the information you collected in the previous section. @@ -42,12 +42,12 @@ Please follow these guidelines to help maintainers and the community to understa - Make sure that you are using the latest version. - Read the documentation carefully and find out if the functionality is already covered, maybe by an individual configuration. -- Perform a [search](issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. +- Perform a [search](https://github.com/Deep-MI/FastSurfer/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. 
- Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing an add-on/plugin library. ### How Do I Submit a Good Enhancement Suggestion? -Enhancement suggestions are tracked as [GitHub issues](issues). +Enhancement suggestions are tracked as [GitHub issues](https://github.com/Deep-MI/FastSurfer/issues). - Use a **clear and descriptive title** for the issue to identify the suggestion. - Provide a **step-by-step description of the suggested enhancement** in as many details as possible. diff --git a/EDITING.md b/doc/overview/EDITING.md similarity index 100% rename from EDITING.md rename to doc/overview/EDITING.md diff --git a/doc/overview/EXAMPLES.md b/doc/overview/EXAMPLES.md new file mode 100644 index 00000000..afe26a9d --- /dev/null +++ b/doc/overview/EXAMPLES.md @@ -0,0 +1,224 @@ +# Examples + +## Example 1: FastSurfer Docker +After pulling one of our images from Dockerhub, you do not need to have a separate installation of FreeSurfer on your computer (it is already included in the Docker image). However, if you want to run ___more than just the segmentation CNN___, you need to register at the FreeSurfer website (https://surfer.nmr.mgh.harvard.edu/registration.html) to acquire a valid license for free. The directory containing the license needs to be mounted and passed to the script via the `--fs_license` flag. Basically for Docker (as for Singularity below) you are starting a container image (with the run command) and pass several parameters for that, e.g. if GPUs will be used and mounting (linking) the input and output directories to the inside of the container image. In the second half of that call you pass parameters to our run_fastsurfer.sh script that runs inside the container (e.g. where to find the FreeSurfer license file, and the input data and other flags). + +To run FastSurfer on a given subject using the provided GPU-Docker, execute the following command: + +```bash +# 1. get the fastsurfer docker image (if it does not exist yet) +docker pull deepmi/fastsurfer + +# 2. Run command +docker run --gpus all -v /home/user/my_mri_data:/data \ + -v /home/user/my_fastsurfer_analysis:/output \ + -v /home/user/my_fs_license_dir:/fs_license \ + --rm --user $(id -u):$(id -g) deepmi/fastsurfer:latest \ + --fs_license /fs_license/license.txt \ + --t1 /data/subjectX/t1-weighted.nii.gz \ + --sid subjectX --sd /output \ + --parallel --3T +``` + +Docker Flags: +* The `--gpus` flag is used to allow Docker to access GPU resources. With it, you can also specify how many GPUs to use. In the example above, _all_ will use all available GPUS. To use a single one (e.g. GPU 0), set `--gpus device=0`. To use multiple specific ones (e.g. GPU 0, 1 and 3), set `--gpus 'device=0,1,3'`. +* The `-v` commands mount your data, output, and directory with the FreeSurfer license file into the docker container. Inside the container these are visible under the name following the colon (in this case /data, /output, and /fs_license). +* The `--rm` flag takes care of removing the container once the analysis finished. +* The `--user $(id -u):$(id -g)` part automatically runs the container with your group- (id -g) and user-id (id -u). All generated files will then belong to the specified user. 
Without the flag, the docker container will be run as root which is discouraged. + +FastSurfer Flag: +* The `--fs_license` points to your FreeSurfer license which needs to be available on your computer in the my_fs_license_dir that was mapped above. +* The `--t1` points to the t1-weighted MRI image to analyse (full path, with mounted name inside docker: /home/user/my_mri_data => /data) +* The `--sid` is the subject ID name (output folder name) +* The `--sd` points to the output directory (its mounted name inside docker: /home/user/my_fastsurfer_analysis => /output) +* The `--parallel` activates processing left and right hemisphere in parallel +* The `--3T` changes the atlas for registration to the 3T atlas for better Talairach transforms and ICV estimates (eTIV) + +Note, that the paths following `--fs_license`, `--t1`, and `--sd` are __inside__ the container, not global paths on your system, so they should point to the places where you mapped these paths above with the `-v` arguments (part after colon). + +A directory with the name as specified in `--sid` (here subjectX) will be created in the output directory if it does not exist. So in this example output will be written to /home/user/my_fastsurfer_analysis/subjectX/ . Make sure the output directory is empty, to avoid overwriting existing files. + +If you do not have a GPU, you can also run our CPU-Docker by dropping the `--gpus all` flag and specifying `--device cpu` at the end as a FastSurfer flag, see also [FastSurfer's docker documentation](../../Docker/README.md) for more details. + +## Example 2: FastSurfer Singularity +After building the Singularity image (see below or [these instructions](../../Singularity/README.md)), you also need to register at the FreeSurfer website (https://surfer.nmr.mgh.harvard.edu/registration.html) to acquire a valid license (for free) - same as when using Docker. This license needs to be passed to the script via the `--fs_license` flag. This is not necessary if you want to run the segmentation only. + +To run FastSurfer on a given subject using the Singularity image with GPU access, execute the following commands from a directory where you want to store singularity images. This will create a singularity image from our Dockerhub image and execute it: + +```bash +# 1. Build the singularity image (if it does not exist) +singularity build fastsurfer-gpu.sif docker://deepmi/fastsurfer + +# 2. Run command +singularity exec --nv \ + --no-home \ + -B /home/user/my_mri_data:/data \ + -B /home/user/my_fastsurfer_analysis:/output \ + -B /home/user/my_fs_license_dir:/fs_license \ + ./fastsurfer-gpu.sif \ + /fastsurfer/run_fastsurfer.sh \ + --fs_license /fs_license/license.txt \ + --t1 /data/subjectX/t1-weighted.nii.gz \ + --sid subjectX --sd /output \ + --parallel --3T +``` + +### Singularity Flags +* The `--nv` flag is used to access GPU resources. +* The `--no-home` flag stops mounting your home directory into singularity. +* The `-B` commands mount your data, output, and directory with the FreeSurfer license file into the Singularity container. Inside the container these are visible under the name following the colon (in this case /data, /output, and /fs_license). + +### FastSurfer Flags +* The `--fs_license` points to your FreeSurfer license which needs to be available on your computer in the my_fs_license_dir that was mapped above. 
+* The `--t1` points to the t1-weighted MRI image to analyse (full path, with mounted name inside docker: /home/user/my_mri_data => /data)
+* The `--sid` is the subject ID name (output folder name)
+* The `--sd` points to the output directory (its mounted name inside docker: /home/user/my_fastsurfer_analysis => /output)
+* The `--parallel` activates processing left and right hemisphere in parallel
+* The `--3T` changes the atlas for registration to the 3T atlas for better Talairach transforms and ICV estimates (eTIV)
+
+Note, that the paths following `--fs_license`, `--t1`, and `--sd` are __inside__ the container, not global paths on your system, so they should point to the places where you mapped these paths above with the `-B` arguments (part after colon).
+
+A directory with the name as specified in `--sid` (here subjectX) will be created in the output directory. So in this example output will be written to /home/user/my_fastsurfer_analysis/subjectX/ . Make sure the output directory is empty, to avoid overwriting existing files.
+
+You can run the Singularity equivalent of CPU-Docker by building a Singularity image from the CPU-Docker image and excluding the `--nv` argument in your Singularity exec command. Also append `--device cpu` as a FastSurfer flag.
+
+
+## Example 3: Native FastSurfer on subjectX with parallel processing of hemis
+
+For a native install you may want to make sure that you are on our stable branch, as the default dev branch is for development and could be broken at any time. For that you can directly clone the stable branch:
+
+```bash
+git clone --branch stable https://github.com/Deep-MI/FastSurfer.git
+```
+
+More details (e.g. you need all dependencies in the right versions and also FreeSurfer locally) can be found in our [Installation guide](INSTALL.md).
+Given you want to analyze data for a subject which is stored on your computer under /home/user/my_mri_data/subjectX/t1-weighted.nii.gz, run the following command from the console (do not forget to source FreeSurfer!):
+
+```bash
+# Source FreeSurfer
+export FREESURFER_HOME=/path/to/freesurfer
+source $FREESURFER_HOME/SetUpFreeSurfer.sh
+
+# Define data directory
+datadir=/home/user/my_mri_data
+fastsurferdir=/home/user/my_fastsurfer_analysis
+
+# Run FastSurfer
+./run_fastsurfer.sh --t1 $datadir/subjectX/t1-weighted.nii.gz \
+                    --sid subjectX --sd $fastsurferdir \
+                    --parallel --threads 4 --3T
+```
+
+The output will be stored in the $fastsurferdir (including the aparc.DKTatlas+aseg.deep.mgz segmentation under $fastsurferdir/subjectX/mri (default location)). Processing of the hemispheres will be run in parallel (--parallel flag) to significantly speed-up surface creation. Omit this flag to run the processing sequentially, e.g. if you want to save resources on a compute cluster.
+
+
+## Example 4: FastSurfer on multiple subjects
+
+In order to run FastSurfer on multiple cases, you may use the helper script `brun_fastsurfer.sh`. This script accepts multiple ways to define the subjects, for example a subjects_list file.
+Prepare the subjects_list file as follows (one subject per line; lines delimited by `\n`):
+```
+subject_id1=path_to_t1
+subject2=path_to_t1
+subject3=path_to_t1
+...
+subject10=path_to_t1
+```
+Note, that all paths (`path_to_t1`) are as if you passed them to the `run_fastsurfer.sh` script via `--t1`, so they may be with respect to the singularity or docker file system. Absolute paths are recommended.
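+
+If your data follows a consistent layout, a small shell loop can generate this file for you. The following is only a sketch: it assumes one folder per subject under /home/user/my_mri_data, each containing a `t1-weighted.nii.gz`, and writes paths as they will appear inside the container (matching the `/data` mount used in the examples below); adjust names and paths to your own setup.
+
+```bash
+# Sketch: build a subjects_list file from an assumed per-subject folder layout
+cd /home/user/my_mri_data
+for subject_id in */ ; do
+  subject_id=${subject_id%/}  # strip the trailing slash from the folder name
+  echo "${subject_id}=/data/${subject_id}/t1-weighted.nii.gz"
+done > subjects_list.txt
+```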
+ +The `brun_fastsurfer.sh` script can then be invoked in docker, singularity or on the native platform as follows: + +### Docker +```bash +docker run --gpus all -v /home/user/my_mri_data:/data \ + -v /home/user/my_fastsurfer_analysis:/output \ + -v /home/user/my_fs_license_dir:/fs_license \ + --entrypoint "/fastsurfer/brun_fastsurfer.sh" \ + --rm --user $(id -u):$(id -g) deepmi/fastsurfer:latest \ + --fs_license /fs_license/license.txt \ + --sd /output --subject_list /data/subjects_list.txt \ + --parallel --3T +``` +### Singularity +```bash +singularity exec --nv \ + --no-home \ + -B /home/user/my_mri_data:/data \ + -B /home/user/my_fastsurfer_analysis:/output \ + -B /home/user/my_fs_license_dir:/fs_license \ + ./fastsurfer-gpu.sif \ + /fastsurfer/brun_fastsurfer.sh \ + --fs_license /fs_license/license.txt \ + --sd /output \ + --subject_list /data/subjects_list.txt \ + --parallel --3T +``` +### Native +```bash +export FREESURFER_HOME=/path/to/freesurfer +source $FREESURFER_HOME/SetUpFreeSurfer.sh + +cd /home/user/FastSurfer +datadir=/home/user/my_mri_data +fastsurferdir=/home/user/my_fastsurfer_analysis + +# Run FastSurfer +./brun_fastsurfer.sh --subject_list $datadir/subjects_list.txt \ + --sd $fastsurferdir \ + --parallel --threads 4 --3T +``` + +### Flags +The `brun_fastsurfer.sh` script accepts almost all `run_fastsurfer.sh` flags (exceptions are `--t1` and `--sid`). In addition, +* the `--parallel_subjects` runs all subjects in parallel (experimental, parameter may change in future releases). This is particularly useful for surfaces computation `--surf_only`. +* to run segmentation in series, but surfaces in parallel, you may use `--parallel_subjects surf`. +* these options are in contrast (and in addition) to `--parallel`, which just parallelizes the hemispheres of one case. + +## Example 5: Quick Segmentation + +For many applications you won't need the surfaces. You can run only the aparc+DKT segmentation (in 1 minute on a GPU) via + +```bash +./run_fastsurfer.sh --t1 $datadir/subject1/t1-weighted.nii.gz \ + --asegdkt_segfile $outputdir/subject1/aparc.DKTatlas+aseg.deep.mgz \ + --conformed_name $outputdir/subject1/conformed.mgz \ + --threads 4 --seg_only --no_cereb +``` + +This will produce the segmentation in a conformed space (just as FreeSurfer would do). It also writes the conformed image that fits the segmentation. +Conformed means that the image will be isotropic in LIA orientation. +It will furthermore output a brain mask (`mri/mask.mgz`), a simplified segmentation file (`mri/aseg.auto_noCCseg.mgz`), the biasfield corrected image (`mri/orig_nu.mgz`), and the volume statistics (without eTIV) based on the FastSurferVINN segmentation (without the corpus callosum) (`stats/aseg+DKT.stats`). + +If you do not even need the biasfield corrected image and the volume statistics, you may add `--no_biasfield`. These steps especially benefit from larger assigned core counts `--threads 32`. + +The above ```run_fastsurfer.sh``` commands can also be called from the Docker or Singularity images by passing the flags and adjusting input and output directories to the locations inside the containers (where you mapped them via the -v flag in Docker or -B in Singularity). 
+
+```bash
+# Docker
+docker run --gpus all -v $datadir:/data \
+                      -v $outputdir:/output \
+                      --rm --user $(id -u):$(id -g) deepmi/fastsurfer:latest \
+                      --t1 /data/subject1/t1-weighted.nii.gz \
+                      --asegdkt_segfile /output/subject1/aparc.DKTatlas+aseg.deep.mgz \
+                      --conformed_name /output/subject1/conformed.mgz \
+                      --threads 4 --seg_only --3T
+```
+
+## Example 6: Running FastSurfer on a SLURM cluster via Singularity
+
+Starting with version 2.2, FastSurfer comes with a script that helps orchestrate FastSurfer optimally on a SLURM cluster: `srun_fastsurfer.sh`.
+
+This script distributes GPU-heavy and CPU-heavy workloads to different SLURM partitions and manages intermediate files in a work directory for IO performance.
+
+```bash
+srun_fastsurfer.sh --partition seg=GPU_Partition \
+                   --partition surf=CPU_Partition \
+                   --sd $outputdir \
+                   --data $datadir \
+                   --singularity_image $HOME/images/fastsurfer-singularity.sif \
+                   [...] # fastsurfer flags
+```
+
+This will create three dependent SLURM jobs: one for segmentation, one for surface reconstruction, and one for cleanup (which moves the data from the work directory to the `$outputdir`).
+There are many intricacies and options, so it is advised to use `--help`, `--debug` and `--dry` to inspect what will be scheduled, and to run a test on a small subset first. More control over subjects is available with `--subject_list`.
+
+The `$outputdir` and the `$datadir` need to be accessible from cluster nodes. Most IO is performed on a work directory (automatically generated from the `$HPCWORK` environment variable: `$HPCWORK/fastsurfer-processing/$(date +%Y%m%d-%H%M%S)`). Alternatively, an empty directory can be manually defined via `--work`. On successful cleanup, this directory will be removed.
diff --git a/doc/overview/FLAGS.md b/doc/overview/FLAGS.md
new file mode 100644
index 00000000..5c4c85a7
--- /dev/null
+++ b/doc/overview/FLAGS.md
@@ -0,0 +1,57 @@
+# Flags
+Next, you will need to select the `*fastsurfer-flags*` and replace `*fastsurfer-flags*` with your options. Please see the Examples below for some example flags.
+
+The `*fastsurfer-flags*` will usually include the subject directory (`--sd`; Note, this will be the mounted path - `/output` - for containers), the subject name/id (`--sid`) and the path to the input image (`--t1`). For example:
+
+```bash
+... --sd /output --sid test_subject --t1 /data/test_subject_t1.nii.gz --3T
+```
+Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb` and `--no_asegdkt` to switch off some segmentation modules (see above).
+Here, we have also added the `--3T` flag, which tells fastsurfer to register against the 3T atlas for ICV estimation (eTIV).
+
+In the following, we give an overview of the most important options, but you can view a full list of options with
+
+```bash
+./run_fastsurfer.sh --help
+```
+
+
+## Required arguments
+* `--sd`: Output directory \$SUBJECTS_DIR (equivalent to FreeSurfer setup --> $SUBJECTS_DIR/sid/mri; $SUBJECTS_DIR/sid/surf ... will be created).
+* `--sid`: Subject ID for directory inside \$SUBJECTS_DIR to be created ($SUBJECTS_DIR/sid/...)
+* `--t1`: T1 full head input (not bias corrected, global path). The network was trained with conformed images (UCHAR, 256x256x256, 1-0.7 mm voxels and standard slice orientation). These specifications are checked in the run_prediction.py script and the image is automatically conformed if it does not comply. Note, outputs will be in the conformed space (as in FreeSurfer).
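+
+As a minimal illustration, a native call passing only these required arguments could look like the sketch below (paths are placeholders; for container runs, replace them with the mounted paths and add the container-specific flags shown in the examples):
+
+```bash
+# Sketch: run_fastsurfer.sh with only the required arguments (placeholder paths)
+./run_fastsurfer.sh --sd /home/user/my_fastsurfer_analysis \
+                    --sid subjectX \
+                    --t1 /home/user/my_mri_data/subjectX/t1-weighted.nii.gz
+```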
+
+## Required for docker when running surface module
+* `--fs_license`: Path to FreeSurfer license key file (only needed for the surface module). Register (for free) at https://surfer.nmr.mgh.harvard.edu/registration.html to obtain it if you do not have FreeSurfer installed so far. Strictly necessary if you use Docker, optional for local install (your local FreeSurfer license will automatically be used). The license file is usually located in $FREESURFER_HOME/license.txt or $FREESURFER_HOME/.license .
+
+## Segmentation pipeline arguments (optional)
+* `--seg_only`: only run FastSurferCNN (generate segmentation, do not run the surface pipeline)
+* `--seg_log`: Name and location for the log-file for the segmentation (FastSurferCNN). Default: $SUBJECTS_DIR/$sid/scripts/deep-seg.log
+* `--viewagg_device`: Define where the view aggregation should be run on. Can be "auto" or a device (see --device). By default, the program checks if you have enough memory to run the view aggregation on the gpu. The total memory is considered for this decision. If this check fails, or if you actively override it by setting "cpu", view aggregation is run on the cpu. Equivalently, if you pass a different device, view aggregation will be run on that device (no memory check will be done).
+* `--device`: Select device for NN segmentation (_auto_, _cpu_, _cuda_, _cuda:<device_num>_, _mps_), where cuda means Nvidia GPU; you can select which one, e.g. "cuda:1". Default: "auto", check GPU and then CPU. "mps" is for native Mac installs to use the Apple silicon (M-chip) GPU.
+* `--asegdkt_segfile`: Name of the segmentation file, which includes the aparc+DKTatlas-aseg segmentations. Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz
+* `--no_cereb`: Switch off the cerebellum sub-segmentation
+* `--cereb_segfile`: Name of the cerebellum segmentation file. If not provided, this intermediate DL-based segmentation will not be stored, but only the merged segmentation will be stored (see --main_segfile). Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/cerebellum.CerebNet.nii.gz
+* `--no_biasfield`: Deactivate the calculation of partial volume-corrected statistics.
+
+## Surface pipeline arguments (optional)
+* `--surf_only`: only run the surface pipeline recon_surf. The segmentation created by FastSurferCNN must already exist in this case.
+* `--3T`: for Talairach registration, use the 3T atlas instead of the 1.5T atlas (which is used if the flag is not provided). This gives better (more consistent with FreeSurfer) ICV estimates (eTIV) for 3T and better Talairach registration matrices, but has little impact on standard volume or surface stats.
+* `--fstess`: Use mri_tessellate instead of marching cubes (default) for surface creation
+* `--fsqsphere`: Use FreeSurfer default instead of novel spectral spherical projection for qsphere
+* `--fsaparc`: Use FS aparc segmentations in addition to DL prediction (slower in this case and usually the mapped ones from the DL prediction are fine)
+* `--parallel`: Run both hemispheres in parallel
+* `--no_fs_T1`: Do not generate T1.mgz (normalized nu.mgz included in standard FreeSurfer output) and create brainmask.mgz directly from norm.mgz instead. Saves 1:30 min.
+* `--no_surfreg`: Skip the surface registration (`sphere.reg`), which is generated automatically by default. To save time, use this flag to turn this off.
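+
+For example, the segmentation and surface flags can be combined to split processing into two stages, e.g. segmenting on a GPU machine first and creating surfaces later. The following is only a sketch with placeholder paths; depending on your installation you may additionally need `--fs_license` (e.g. in containers) for the surface stage:
+
+```bash
+# Stage 1: segmentation only (fast, benefits from a GPU)
+./run_fastsurfer.sh --sd /output --sid subjectX \
+                    --t1 /data/subjectX/t1-weighted.nii.gz --seg_only
+
+# Stage 2: surfaces only, reusing the segmentation created above
+./run_fastsurfer.sh --sd /output --sid subjectX \
+                    --t1 /data/subjectX/t1-weighted.nii.gz \
+                    --surf_only --parallel --3T
+```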
+ +## Other +* `--threads`: Target number of threads for all modules (segmentation, surface pipeline), `1` (default) forces FastSurfer to only really use one core. Note, that the default value may change in the future for better performance on multi-core architectures. +* `--vox_size`: Forces processing at a specific voxel size. If a number between 0.7 and 1 is specified (below is experimental) the T1w image is conformed to that isotropic voxel size and processed. + If "min" is specified (default), the voxel size is read from the size of the minimal voxel size (smallest per-direction voxel size) in the T1w image: + If the minimal voxel size is bigger than 0.98mm, the image is conformed to 1mm isometric. + If the minimal voxel size is smaller or equal to 0.98mm, the T1w image will be conformed to isometric voxels of that voxel size. + The voxel size (whether set manually or derived) determines whether the surfaces are processed with highres options (below 1mm) or not. +* `--py`: Command for python, used in both pipelines. Default: python3.10 +* `--conformed_name`: Name of the file in which the conformed input image will be saved. Default location: \$SUBJECTS_DIR/\$sid/mri/orig.mgz +* `--ignore_fs_version`: Switch on to avoid check for FreeSurfer version. Program will terminate if the supported version (see recon-surf.sh) is not sourced. Can be used for testing dev versions. +* `-h`, `--help`: Prints help text diff --git a/INSTALL.md b/doc/overview/INSTALL.md similarity index 67% rename from INSTALL.md rename to doc/overview/INSTALL.md index 55887ec4..f410c975 100644 --- a/INSTALL.md +++ b/doc/overview/INSTALL.md @@ -1,13 +1,11 @@ -# Introduction - -FastSurfer is a pipeline for the segmentation of human brain MRI data. It consists of two main components: the networks for the fast segmentation of an MRI (FastSurferVINN, CerebNet) and the recon_surf script for the efficient creation of surfaces and most files and statistics that also FreeSurfer provides. +# Installation -The preferred way of installing and running FastSurfer is via Singularity or Docker containers. We provide pre-build images at Dockerhub for various application cases: i) for only the segmentation (both GPU and CPU), ii) for only the CPU-based recon-surf pipeline, and iii) for the full pipeline (GPU or CPU). +FastSurfer is a pipeline for the segmentation of human brain MRI data. It consists of two main components: the networks for the fast segmentation of an MRI (FastSurferVINN, CerebNet, ...) and the recon_surf script for the efficient creation of surfaces and most files and statistics that also FreeSurfer provides. -We also provide information on a native install on some operating systems, but since dependencies may vary, this can produce results different from our testing environment and we may not be able to support you if things don't work. Our testing is performed on Ubuntu 20.04 via our provided Docker images. +The preferred way of installing and running FastSurfer is via Singularity or Docker containers on a Linux host system (with a GPU). We provide pre-build images at Dockerhub for various application cases: i) for only the segmentation (both GPU and CPU), ii) for only the CPU-based recon-surf pipeline, and iii) for the full pipeline (GPU or CPU). +We also provide information on a native install on some operating systems, but since dependencies may vary, this can produce results different from our testing environment and we may not be able to support you if things don't work. 
Our testing is performed on Ubuntu 22.04 via our provided Docker images. -# Installation ## Linux @@ -15,18 +13,18 @@ Recommended System Spec: 8 GB system memory, NVIDIA GPU with 8 GB graphics memor Minimum System Spec: 8 GB system memory (this requires running FastSurfer on the CPU only, which is much slower) -Non-NVIDIA GPU architectures (Apple M1, AMD) are not officially supported, but experimental. +Non-NVIDIA GPU architectures (AMD) are experimental and not officially supported, but seem to work well also. ### Singularity -Assuming you have singularity installed already (by a system admin), you can build an image easily from our Dockerhub images. Run this command from a directory where you want to store singularity images: +Assuming you have singularity installed already (by a system admin), you can build a Singularity image easily from our Dockerhub images. Run this command from a directory where you want to store singularity images: ```bash singularity build fastsurfer-gpu.sif docker://deepmi/fastsurfer:latest ``` -Additionally, [the Singularity README](Singularity/README.md) contains detailed directions for building your own Singularity images from Docker. +Additionally, [the Singularity README](../../Singularity/README.md) contains detailed directions for building your own Singularity images from Docker. -Our [README](README.md#example-2--fastsurfer-singularity) explains how to run FastSurfer (for the full pipeline you will also need a FreeSurfer .license file!) and you can find details on how to build your own images here: [Docker](docker/README.md) and [Singularity](singularity/README.md). +[Example 2](EXAMPLES.md#example-2-fastsurfer-singularity) explains how to run FastSurfer (for the full pipeline you will also need a FreeSurfer .license file!) and you can find details on how to build your own images here: [Docker](../../Docker/README.md) and [Singularity](../../Singularity/README.md). ### Docker @@ -37,7 +35,7 @@ This is very similar to Singularity. Assuming you have Docker installed (by a sy docker pull deepmi/fastsurfer:latest ``` -Our [README](README.md#example-1--fastsurfer-docker) explains how to run FastSurfer (for the full pipeline you will also need a FreeSurfer .license file!) and you can find details on how to [build your own image](docker/README.md). +[Example 1](EXAMPLES.md#example-1-fastsurfer-docker) explains how to run FastSurfer (for the full pipeline you will also need a FreeSurfer .license file!) and you can find details on how to [build your own image](https://github.com/Deep-MI/FastSurfer/blob/dev/Docker/README.md). ### Native (Ubuntu 20.04 or Ubuntu 22.04) @@ -63,13 +61,13 @@ sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test sudo apt install -y g++-11 ``` -You also need to have bash-4.0 or higher (check with `bash --version`). +You also need to have bash-3.2 or higher (check with `bash --version`). -You also need a [working version of python3](#2-conda--for-python-) (we recommend python 3.8 -- we do not support other versions). These packages should be sufficient to install python dependencies and then run the FastSurfer neural network segmentation. If you want to run the full pipeline, you also need a [working installation of FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/rel7downloads) (including its dependencies). +You also need a working version of python3.10 (we do not support other versions). These packages should be sufficient to install python dependencies and then run the FastSurfer neural network segmentation. 
If you want to run the full pipeline, you also need a [working installation of FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/rel7downloads) (including its dependencies and a license file). If you are using pip, make sure pip is updated as older versions will fail. -#### 2. Conda (for python) +#### 2. Conda for python We recommend to install conda as your python environment. If you don't have conda on your system, an admin needs to install it: @@ -93,12 +91,12 @@ cd FastSurfer Create a new environment and install FastSurfer dependencies: ```bash -conda env create -f ./fastsurfer_env_gpu.yml -conda activate fastsurfer_gpu +conda env create -f ./env/fastsurfer.yml +conda activate fastsurfer ``` -If you do not have an NVIDIA GPU, replace `./fastsurfer_env_gpu.yml` with the cpu-only environment file `./fastsurfer_env_cpu.yml`. -If you only want to run the surface pipeline, replace `./fastsurfer_env_gpu.yml` with the reconsurf-only environment file `./fastsurfer_env_reconsurf.yml`. +If you do not have an NVIDIA GPU, you can create appropriate ymls on the fly with `python ./Docker/install_env.py -m $MODE -i ./env/FastSurfer.yml -o ./fastsurfer_$MODE.yml`. Here `$MODE` can be for example `cpu`, see also `python ./Docker/install_env.py --help` for other options like rocm or cuda versions. Finally, replace `./env/fastsurfer.yml` with your custom environment file `./fastsurfer_$MODE.yml`. +If you only want to run the surface pipeline, use `./env/fastsurfer_reconsurf.yml`. Next, add the fastsurfer directory to the python path (make sure you have changed into it already): ```bash @@ -115,10 +113,10 @@ You can also download all network checkpoint files (this should be done if you a python3 FastSurferCNN/download_checkpoints.py --all ``` -Once all dependencies are installed, you are ready to run the FastSurfer segmentation-only (!!) pipeline by calling ```./run_fastsurfer.sh --seg_only ....``` , see the [README](README.md#example-3--native-fastsurfer-on-subjectx--with-parallel-processing-of-hemis-) for command line flags. +Once all dependencies are installed, you are ready to run the FastSurfer segmentation-only (!!) pipeline by calling ```./run_fastsurfer.sh --seg_only ....``` , see [Example 3](EXAMPLES.md#example-3-native-fastsurfer-on-subjectx-with-parallel-processing-of-hemis) for command line flags. #### 5. FreeSurfer -To run the full pipeline, you will need to install FreeSurfer (we recommend and support version 7.3.2) according to their [Instructions](https://surfer.nmr.mgh.harvard.edu/fswiki/rel7downloads). There is a freesurfer email list, if you run into problems during this step. +To run the full pipeline, you will need to install FreeSurfer (we recommend and support version 7.4.1) according to their [Instructions](https://surfer.nmr.mgh.harvard.edu/fswiki/rel7downloads). There is a freesurfer email list, if you run into problems during this step. Make sure, the `${FREESURFER_HOME}` environment variable is set, so FastSurfer finds the FreeSurfer binaries. @@ -132,7 +130,7 @@ Build the Docker container with ROCm support. 
python Docker/build.py --device rocm --tag my_fastsurfer:rocm ``` -You will need to add a couple of flags to your docker run command for AMD, see [the Readme](README.md#example-1--fastsurfer-docker) for `**other-docker-flags**` or `**fastsurfer-flags**`: +You will need to add a couple of flags to your docker run command for AMD, see [Example 1](EXAMPLES.md#example-1-fastsurfer-docker) for `**other-docker-flags**` or `**fastsurfer-flags**`: ```bash docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd \ --device=/dev/dri --group-add video --ipc=host --shm-size 8G \ @@ -143,7 +141,7 @@ Note, that this docker image is experimental, uses a different Python version an ## MacOS -Processing on Mac CPUs is possible. On Apple Silicon, you can even use the GPU (experimental) by passing ```--device mps```. +Processing on Mac CPUs is possible. On Apple Silicon, you can even use the GPU by passing ```--device mps```. Recommended System Spec: Mac with Apple Silicon M-Chip and 16 GB system memory. @@ -162,31 +160,29 @@ Second, pull one of our Docker containers. Open a terminal window and run: docker pull deepmi/fastsurfer:latest ``` -Continue with the example in our [README](README.md#example-1--fastsurfer-docker). +Continue with the example in [Example 1](EXAMPLES.md#example-1-fastsurfer-docker). ### Native -On modern Macs with the Apple Silicon M1 or M2 ARM-based chips, we recommend a native installation as it runs much faster than Docker in our tests. The experimental support for the built-in AI Accelerator is also only available on native installations. Native installation also supports older Intel chips. +On modern Macs with the Apple Silicon M1 or M2 ARM-based chips, we recommend a native installation as it runs much faster than Docker in our tests. Access to the built-in AI accelerator (MPS) is also only available on native installations. A native installation also works on older Intel chips. -#### 1. Git and Bash -If you do not have git and a recent bash (version > 4.0 required!) installed, install them via the packet manager, e.g. brew. -This installs brew and then bash: +#### 1. Dependency packages +If you do not have git, python3.10 or bash (at least 3.2) you can install these via the packet manager brew. +This installs brew and then git and python3.10: ```sh /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" -brew install bash +brew install git python@3.10 ``` -Make sure you use this bash and not the older one provided with MacOS! - #### 2. Python -Create a python environment, activate it, and upgrade pip. Here we use pip, but you should also be able to use [conda](#2-conda--for-python-): +Create a python environment, activate it, and upgrade pip: ```sh -python3 -m venv $HOME/python-envs/fastsurfer +python3.10 -m venv $HOME/python-envs/fastsurfer source $HOME/python-envs/fastsurfer/bin/activate -python3 -m pip install --upgrade pip +python3.10 -m pip install --upgrade pip ``` #### 3. FastSurfer and Requirements @@ -197,13 +193,12 @@ cd FastSurfer export PYTHONPATH="${PYTHONPATH}:$PWD" ``` - Install the FastSurfer requirements ```sh -python3 -m pip install -r requirements.mac.txt +python3.10 -m pip install -r requirements.mac.txt ``` -If this step fails, you may need to edit ```requirements.mac.txt``` and adjust version number to what is available. +If this step fails, you may need to edit ```requirements.mac.txt``` and exclude version numbers that produce conflicts or break our code. 
On newer M1 Macs, we also had issues with the h5py package, which could be solved by using brew for help (not sure this is needed any longer): ```sh @@ -212,26 +207,24 @@ export HDF5_DIR="$(brew --prefix hdf5)" pip3 install --no-binary=h5py h5py ``` -You can also download all network checkpoint files (this should be done if you are installing for multiple users): +You can also download all network checkpoint files at this point already: ```sh -python3 FastSurferCNN/download_checkpoints.py --all +python3.10 FastSurferCNN/download_checkpoints.py --all ``` -Once all dependencies are installed, run the FastSurfer segmentation only (!!) by calling ```bash ./run_fastsurfer.sh --seg_only ....``` with the appropriate command line flags, see the [README](README.md#usage). - -Note: You may always need to prepend the command with `bash` (i.e. `bash run_fastsurfer.sh <...>`) to ensure that bash 4.0 is used instead of the system default. +Once all dependencies are installed, you can run the FastSurfer segmentation only by calling ```./run_fastsurfer.sh --seg_only ....``` with the appropriate command line flags, see the [commandline documentation](../../README.md#usage). -To run the full pipeline, install and source also the supported FreeSurfer version according to their [Instructions](https://surfer.nmr.mgh.harvard.edu/fswiki/rel7downloads). There is a freesurfer email list, if you run into problems during this step. +To run the full pipeline, install and source also the supported FreeSurfer version according to their [Instructions](https://surfer.nmr.mgh.harvard.edu/fswiki/rel7downloads). There is a freesurfer email list, if you run into problems during this step. Note, that currently FreeSurfer for MacOS supports no ARM, but only Intel, so on modern M-chips it will be slow due to the emulation. This is why we recommend using a Linux host system to run FastSurfer on larger datasets. #### 4. Apple AI Accelerator support -You can also try the experimental support for the Apple Silicon AI Accelerator by setting `PYTORCH_ENABLE_MPS_FALLBACK` and passing `--device mps`: +On modern M-Chips you can try the Apple Silicon AI Accelerator by setting `PYTORCH_ENABLE_MPS_FALLBACK` and passing `--device mps` for the segmentation module to make use of the fast GPU: ```sh export PYTORCH_ENABLE_MPS_FALLBACK=1 ./run_fastsurfer.sh --seg_only --device mps .... ``` -This will be at least twice as fast as `--device cpu`. The fallback environment variable is necessary as one function is not yet implemented for the GPU and will fall back to CPU. +This will be at least twice as fast as `--device cpu`. Currently setting the fallback environment variable is necessary as `aten::max_unpool2d` is not yet implemented for MPS and will fall back to CPU. ## Windows @@ -246,15 +239,15 @@ installed and running. 
After everything is installed, start Windows PowerShell and run the following command to pull the CPU Docker image (check on [dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags) what version tag is most recent for cpu): ```bash -docker pull deepmi/fastsurfer:cpu-v2.0.4 +docker pull deepmi/fastsurfer:cpu-latest ``` -Now you can run Fastsurfer the same way as described in our [README](README.md#example-1--fastsurfer-docker) for the CPU build, for example: +Now you can run Fastsurfer the same way as described in [Example 1](EXAMPLES.md#example-1-fastsurfer-docker) for the CPU build, for example: ```bash docker run -v C:/Users/user/my_mri_data:/data \ -v C:/Users/user/my_fastsurfer_analysis:/output \ -v C:/Users/user/my_fs_license_dir:/fs_license \ - --rm --user $(id -u):$(id -g) deepmi/fastsurfer:cpu-v2.0.0 \ + --rm --user $(id -u):$(id -g) deepmi/fastsurfer:cpu-latest \ --fs_license /fs_license/license.txt \ --t1 /data/subjectX/orig.mgz \ --device cpu \ @@ -281,7 +274,7 @@ After everything is installed, start Windows PowerShell and run the following co docker pull deepmi/fastsurfer:latest ``` -Now you can run Fastsurfer the same way as described in our [README](README.md#example-1--fastsurfer-docker), for example: +Now you can run Fastsurfer the same way as described in [Example 1](EXAMPLES.md#example-1-fastsurfer-docker), for example: ```bash docker run --gpus all -v C:/Users/user/my_mri_data:/data \ diff --git a/doc/overview/OUTPUT_FILES.md b/doc/overview/OUTPUT_FILES.md new file mode 100644 index 00000000..87416ed3 --- /dev/null +++ b/doc/overview/OUTPUT_FILES.md @@ -0,0 +1,76 @@ +# Output files + +## Segmentation module + +The segmentation module outputs the files shown in the table below. The two primary output files are the `aparc.DKTatlas+aseg.deep.mgz` file, which contains the FastSurfer segmentation of cortical and subcortical structures based on the DKT atlas, and the `aseg+DKT.stats` file, which contains summary statistics for these structures. Note, that the surface model (downstream) corrects these segmentations along the cortex with the created surfaces. So if the surface model is used, it is recommended to use the updated segmentations and stats (see below). + +| directory | filename | module | description | +|:----------|------------------------------|---------|--------------------------------------------------------------------| +| mri | aparc.DKTatlas+aseg.deep.mgz | asegdkt | cortical and subcortical segmentation | +| mri | aseg.auto_noCCseg.mgz | asegdkt | simplified subcortical segmentation without corpus callosum labels | +| mri | mask.mgz | asegdkt | brainmask | +| mri | orig.mgz | asegdkt | conformed image | +| mri | orig_nu.mgz | asegdkt | biasfield-corrected image | +| mri/orig | 001.mgz | asegdkt | original image | +| scripts | deep-seg.log | asegdkt | logfile | +| stats | aseg+DKT.stats | asegdkt | table of cortical and subcortical segmentation statistics | + +## Cerebnet module + +The cerebellum module outputs the files in the table shown below. Unless switched off by the `--no_cereb` argument, this module is automatically run whenever the segmentation module is run. It adds two files, an image with the sub-segmentation of the cerebellum and a text file with summary statistics. 
+
+
+| directory | filename                   | module   | description                                 |
+|:----------|----------------------------|----------|---------------------------------------------|
+| mri       | cerebellum.CerebNet.nii.gz | cerebnet | cerebellum sub-segmentation                 |
+| stats     | cerebellum.CerebNet.stats  | cerebnet | table of cerebellum segmentation statistics |
+
+## HypVINN module
+
+The hypothalamus module outputs the files in the table shown below. Unless switched off by the `--no_hypothal` argument, this module is automatically run whenever the segmentation module is run. It adds three files: an image with the sub-segmentation of the hypothalamus, a corresponding mask, and a text file with summary statistics.
+
+
+| directory | filename                         | module  | description                                   |
+|:----------|----------------------------------|---------|-----------------------------------------------|
+| mri       | hypothalamus.HypVINN.nii.gz      | hypvinn | hypothalamus sub-segmentation                 |
+| mri       | hypothalamus_mask.HypVINN.nii.gz | hypvinn | hypothalamus sub-segmentation mask            |
+| stats     | hypothalamus.HypVINN.stats       | hypvinn | table of hypothalamus segmentation statistics |
+
+If a T2 image is also passed, the following images are created.
+
+| directory | filename      | module  | description                    |
+|:----------|---------------|---------|--------------------------------|
+| mri       | T2_nu.mgz     | hypvinn | biasfield-corrected T2 image   |
+| mri       | T2_nu_reg.mgz | hypvinn | co-registered T2 to orig image |
+
+## Surface module
+
+The surface module is run unless switched off by the `--seg_only` argument. It outputs a large number of files, which generally correspond to the FreeSurfer nomenclature and definition. A selection of important output files is shown in the table below; for the other files, we refer to the [FreeSurfer documentation](https://surfer.nmr.mgh.harvard.edu/fswiki). In general, the "mri" directory contains images, including segmentations, the "surf" folder contains surface files (geometries and vertex-wise overlay data), the "label" folder contains cortical parcellation labels, and the "stats" folder contains tabular summary statistics. Many files are available for the left ("lh") and right ("rh") hemisphere of the brain. Symbolic links are created to map FastSurfer files to their FreeSurfer equivalents, which may need to be present for further processing (e.g., with FreeSurfer downstream modules).
+
+After running this module, some of the initial segmentations and corresponding volume estimates are fine-tuned (e.g., surface-based partial volume correction, addition of corpus callosum labels). Specifically, this concerns `aseg.mgz`, `aparc.DKTatlas+aseg.mapped.mgz`, and `aparc.DKTatlas+aseg.deep.withCC.mgz`, which were originally created by the segmentation module or have earlier versions resulting from that module.
+
+The primary output files are pial, white, and inflated surface files, the thickness overlay files, and the cortical parcellation (annotation) files. The preferred way of assessing this output is the [FreeView](https://surfer.nmr.mgh.harvard.edu/fswiki/FreeviewGuide) software. Summary statistics for volume and thickness estimates per anatomical structure are reported in the stats files, in particular the `aseg.stats`, and the left and right `aparc.DKTatlas.mapped.stats` files.
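+
+For instance, the main surface-module outputs can be loaded into FreeView roughly as sketched below (after sourcing FreeSurfer; the subject ID and $SUBJECTS_DIR are placeholders for your own setup):
+
+```bash
+# Sketch: inspect surfaces and the mapped segmentation in FreeView (placeholder paths)
+cd $SUBJECTS_DIR/subjectX
+freeview -v mri/orig.mgz \
+            mri/aparc.DKTatlas+aseg.mapped.mgz:colormap=lut \
+         -f surf/lh.white surf/lh.pial surf/rh.white surf/rh.pial
+```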
+ +| directory | filename | module | description | +|:----------|----------------------------------------------------------------|---------|----------------------------------------------------------------------------------------------| +| mri | aparc.DKTatlas+aseg.deep.withCC.mgz | surface | cortical and subcortical segmentation incl. corpus callosum after running the surface module | +| mri | aparc.DKTatlas+aseg.mapped.mgz | surface | cortical and subcortical segmentation after running the surface module | +| mri | aparc.DKTatlas+aseg.mgz | surface | symlink to aparc.DKTatlas+aseg.mapped.mgz | +| mri | aparc+aseg.mgz | surface | symlink to aparc.DKTatlas+aseg.mapped.mgz | +| mri | aseg.mgz | surface | subcortical segmentation after running the surface module | +| mri | wmparc.DKTatlas.mapped.mgz | surface | white matter parcellation | +| mri | wmparc.mgz | surface | symlink to wmparc.DKTatlas.mapped.mgz | +| surf | lh.area, rh.area | surface | surface area overlay file | +| surf | lh.curv, rh.curv | surface | curvature overlay file | +| surf | lh.inflated, rh.inflated | surface | inflated cortical surface | +| surf | lh.pial, rh.pial | surface | pial surface | +| surf | lh.thickness, rh.thickness | surface | cortical thickness overlay file | +| surf | lh.volume, rh.volume | surface | gray matter volume overlay file | +| surf | lh.white, rh.white | surface | white matter surface | +| label | lh.aparc.DKTatlas.annot, rh.aparc.DKTatlas.annot | surface | symlink to lh.aparc.DKTatlas.mapped.annot | +| label | lh.aparc.DKTatlas.mapped.annot, rh.aparc.DKTatlas.mapped.annot | surface | annotation file for cortical parcellations, mapped from ASEGDKT segmentation to the surface | +| stats | aseg.stats | surface | table of cortical and subcortical segmentation statistics after running the surface module | +| stats | lh.aparc.DKTatlas.mapped.stats, rh.aparc.DKTatlas.mapped.stats | surface | table of cortical parcellation statistics, mapped from ASEGDKT segmentation to the surface | +| stats | lh.curv.stats, rh.curv.stats | surface | table of curvature statistics | +| stats | wmparc.DKTatlas.mapped.stats | surface | table of white matter segmentation statistics | +| scripts | recon-all.log | surface | logfile | \ No newline at end of file diff --git a/SECURITY.md b/doc/overview/SECURITY.md similarity index 100% rename from SECURITY.md rename to doc/overview/SECURITY.md diff --git a/doc/overview/docker.rst b/doc/overview/docker.rst new file mode 100644 index 00000000..3cc4057e --- /dev/null +++ b/doc/overview/docker.rst @@ -0,0 +1,8 @@ +Docker Support +-------------- + +.. include:: ../../Docker/README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-line: 1 diff --git a/doc/overview/index.rst b/doc/overview/index.rst new file mode 100644 index 00000000..cc483845 --- /dev/null +++ b/doc/overview/index.rst @@ -0,0 +1,18 @@ +User Guide +========== + +.. toctree:: + :maxdepth: 2 + + intro + INSTALL.md + EXAMPLES.md + FLAGS.md + OUTPUT_FILES.md + docker + singularity + EDITING.md + SECURITY.md + CODE_OF_CONDUCT.md + CONTRIBUTING.md + license diff --git a/doc/overview/intro.rst b/doc/overview/intro.rst new file mode 100644 index 00000000..aa302e33 --- /dev/null +++ b/doc/overview/intro.rst @@ -0,0 +1,23 @@ +########################## +Introduction to FastSurfer +########################## + +.. include:: ../../README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-after: + :end-before: + +.. 
include:: ../../README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-after: + :end-before: + +.. include:: ../../README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-after: diff --git a/doc/overview/license.rst b/doc/overview/license.rst new file mode 100644 index 00000000..293ef81a --- /dev/null +++ b/doc/overview/license.rst @@ -0,0 +1,9 @@ +################## +FastSurfer License +################## + + +Apache License +============== + +.. literalinclude:: ../../LICENSE diff --git a/doc/overview/singularity.rst b/doc/overview/singularity.rst new file mode 100644 index 00000000..afb11cb8 --- /dev/null +++ b/doc/overview/singularity.rst @@ -0,0 +1,8 @@ +Singularity Support +------------------- + +.. include:: ../../Singularity/README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-line: 1 diff --git a/doc/references.bib b/doc/references.bib new file mode 100644 index 00000000..e9ffb83e --- /dev/null +++ b/doc/references.bib @@ -0,0 +1,34 @@ +@article{conformal_parameterization_2020, + author = {Choi, Gary P. T. and Leung-Liu, Yusan and Gu, Xianfeng and Lui, Lok Ming}, + doi = {10.1137/19M125337X}, + journal = {SIAM Journal on Imaging Sciences}, + number = {3}, + pages = {1049-1083}, + title = {Parallelizable Global Conformal Parameterization of Simply-Connected Surfaces via Partial Welding}, + volume = {13}, + year = {2020} +} + +@article{numpy_2020, + author = {Harris, Charles R. and Millman, K. Jarrod and van der Walt, Stéfan J. and Gommers, Ralf and Virtanen, Pauli and Cournapeau, David and Wieser, Eric and Taylor, Julian and Berg, Sebastian and Smith, Nathaniel J. and Kern, Robert and Picus, Matti and Hoyer, Stephan and van Kerkwijk, Marten H. and Brett, Matthew and Haldane, Allan and del Río, Jaime Fernández and Wiebe, Mark and Peterson, Pearu and Gérard-Marchant, Pierre and Sheppard, Kevin and Reddy, Tyler and Weckesser, Warren and Abbasi, Hameer and Gohlke, Christoph and Oliphant, Travis E.}, + doi = {10.1038/s41586-020-2649-2}, + journal = {Nature}, + month = {September}, + number = {7825}, + pages = {357--362}, + title = {Array programming with {NumPy}}, + volume = {585}, + year = {2020} +} + +@article{scipy_2020, + author = {Virtanen, Pauli and Gommers, Ralf and Oliphant, Travis E. and Haberland, Matt and Reddy, Tyler and Cournapeau, David and Burovski, Evgeni and Peterson, Pearu and Weckesser, Warren and Bright, Jonathan and van der Walt, Stéfan J. and Brett, Matthew and Wilson, Joshua and Millman, K. Jarrod and Mayorov, Nikolay and Nelson, Andrew R. J. and Jones, Eric and Kern, Robert and Larson, Eric and Carey, C J and Polat, İlhan and Feng, Yu and Moore, Eric W. and VanderPlas, Jake and Laxalde, Denis and Perktold, Josef and Cimrman, Robert and Henriksen, Ian and Quintero, E. A. and Harris, Charles R. and Archibald, Anne M. and Ribeiro, Antônio H. and Pedregosa, Fabian and van Mulbregt, Paul and {SciPy 1.0 Contributors} and Vijaykumar, Aditya and Bardelli, Alessandro Pietro and Rothberg, Alex and Hilboll, Andreas and Kloeckner, Andreas and Scopatz, Anthony and Lee, Antony and Rokem, Ariel and Woods, C. Nathan and Fulton, Chad and Masson, Charles and Häggström, Christian and Fitzgerald, Clark and Nicholson, David A. and Hagen, David R. and Pasechnik, Dmitrii V. and Olivetti, Emanuele and Martin, Eric and Wieser, Eric and Silva, Fabrice and Lenders, Felix and Wilhelm, Florian and Young, G. and Price, Gavin A. and Ingold, Gert-Ludwig and Allen, Gregory E. and Lee, Gregory R. 
and Audren, Hervé and Probst, Irvin and Dietrich, Jörg P. and Silterra, Jacob and Webber, James T and Slavič, Janko and Nothman, Joel and Buchner, Johannes and Kulick, Johannes and Schönberger, Johannes L. and de Miranda Cardoso, José Vinícius and Reimer, Joscha and Harrington, Joseph and Rodríguez, Juan Luis Cano and Nunez-Iglesias, Juan and Kuczynski, Justin and Tritz, Kevin and Thoma, Martin and Newville, Matthew and Kümmerer, Matthias and Bolingbroke, Maximilian and Tartre, Michael and Pak, Mikhail and Smith, Nathaniel J. and Nowaczyk, Nikolai and Shebanov, Nikolay and Pavlyk, Oleksandr and Brodtkorb, Per A. and Lee, Perry and McGibbon, Robert T. and Feldbauer, Roman and Lewis, Sam and Tygier, Sam and Sievert, Scott and Vigna, Sebastiano and Peterson, Stefan and More, Surhud and Pudlik, Tadeusz and Oshima, Takuya and Pingel, Thomas J. and Robitaille, Thomas P. and Spura, Thomas and Jones, Thouis R. and Cera, Tim and Leslie, Tim and Zito, Tiziano and Krauss, Tom and Upadhyay, Utkarsh and Halchenko, Yaroslav O. and Vázquez-Baeza, Yoshiki}, + doi = {10.1038/s41592-019-0686-2}, + journal = {Nature Methods}, + month = {March}, + number = {3}, + pages = {261--272}, + title = {{SciPy} 1.0: fundamental algorithms for scientific computing in {Python}}, + volume = {17}, + year = {2020} +} \ No newline at end of file diff --git a/doc/scripts/BATCH.md b/doc/scripts/BATCH.md new file mode 100644 index 00000000..a0ba46fa --- /dev/null +++ b/doc/scripts/BATCH.md @@ -0,0 +1,18 @@ +BATCH: brun_fastsurfer.sh +========================= + +Usage +----- + +```{command-output} ./brun_fastsurfer.sh --help +:cwd: /../ +``` + +Questions +--------- +Can I disable the progress bars in the output? + +> You can disable the progress bars by setting the TQDM_DISABLE environment variable to 1, if you have tqdm>=4.66. +> +> For docker, this can be done with the flag `-e`, e.g. `docker run -e TQDM_DISABLE=1 ...`, for singularity with the flag `--env`, e.g. `singularity exec --env TQDM_DISABLE=1 ...` and for native installations by prepending, e.g. `TQDM_DISABLE=1 ./run_fastsurfer.sh ...`. + diff --git a/doc/scripts/SLURM.md b/doc/scripts/SLURM.md new file mode 100644 index 00000000..8bc4068e --- /dev/null +++ b/doc/scripts/SLURM.md @@ -0,0 +1,45 @@ +SLURM: srun_fastsurfer.sh +========================= + +Usage +----- + +```{command-output} ./srun_fastsurfer.sh --help +:cwd: /../ +``` + +Debugging SLURM runs +-------------------- + +1. Did the run succeed? + + 1. Check whether all jobs are done (specifically the copy job). + ```bash + $ squeue -u $USER --Format JobArrayID,Name,State,Dependency + 1750814_3 FastSurfer-Seg-kueglRUNNING (null) + 1750815_3 FastSurfer-Surf-kuegPENDING aftercorr:1750814_*( + 1750816 FastSurfer-Cleanup-kPENDING afterany:1750815_*(u + 1750815_1 FastSurfer-Surf-kuegRUNNING (null) + 1750815_2 FastSurfer-Surf-kuegRUNNING (null) + ``` + Here, jobs are not finished yet. The FastSurfer-Cleanup-$USER Job moves data to the subject directory (--sd). + + 2. Check whether there are subject folders and log files in the subject directory, /slurm/logs for the latter. + + 3. Check the subject_success file in /slurm/scripts. It should have a line for each subject for both parts of the FastSurfer pipeline, e.g. `: Finished --seg_only successfully` or `: Finished --surf_only successfully`! If one of these is missing, the job was likely killed by slurm (e.g. because of the time or the memory limit). + + 4. 
For subjects that were unsuccessful (The subject_success will say so), check `//scripts/deep-seg.log` and `//scripts/recon-surf.log` to see what failed. + Can be found by looking for ": Failed <--seg_only/--surf_only> with exit code " in `/slurm/scripts/subject_success`. + + 5. For subjects that were terminated (missing in subject_success), find which job is associated with subject id `grep "" slurm/logs/surf_*.log`, then look at the end of the job and the job step logs (surf_XXX_YY.log and surf_XXX_YY_ZZ.log). If slurm terminated the job, it will say so there. You can increase the time and memory budget in `srun_fastsurfer.sh` with `--time` and `--mem` flags. + The following bash code snippet can help identify failed runs. + ``` + cd + for sub in * + do + if [[ -z "$(grep "$sub: Finished --surf" slurm/scripts/subject_success)" ]] + then + echo "$sub was terminated externally" + fi + done + ``` diff --git a/doc/scripts/advanced.rst b/doc/scripts/advanced.rst new file mode 100644 index 00000000..514d6000 --- /dev/null +++ b/doc/scripts/advanced.rst @@ -0,0 +1,11 @@ +Advanced scripts +================ + +.. toctree:: + :titlesonly: + + fastsurfercnn + cerebnet + hypvinn + recon_surf + segstats diff --git a/doc/scripts/cerebnet.rst b/doc/scripts/cerebnet.rst new file mode 100644 index 00000000..561d7299 --- /dev/null +++ b/doc/scripts/cerebnet.rst @@ -0,0 +1,19 @@ +CerebNet: run_prediction.py +=========================== + +.. note:: + We recommend to run CerebNet with the standard `run_fastsurfer.sh` interfaces! + +The `CerebNet/run_prediction.py` script enables the inference with CerebNet. In most +situations, it will be called from `run_fastsurfer.sh` a direct call to +`CerebNet/run_prediction.py` is not needed. + +.. argparse:: + :module: CerebNet.run_prediction + :func: setup_options + :prog: CerebNet/run_prediction.py + + +.. include:: ../../CerebNet/README.md + :parser: fix_links.parser + :start-line: 1 diff --git a/doc/scripts/fastsurfercnn.generate_hdf5.rst b/doc/scripts/fastsurfercnn.generate_hdf5.rst new file mode 100644 index 00000000..d565ab54 --- /dev/null +++ b/doc/scripts/fastsurfercnn.generate_hdf5.rst @@ -0,0 +1,16 @@ +FastSurferCNN: generate_hdf5.py +================================ + +.. include:: ../../FastSurferCNN/README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-after: + :end-before: + +Full commandline interface of FastSurferCNN/generate_hdf5.py +------------------------------------------------------------ +.. argparse:: + :module: FastSurferCNN.generate_hdf5 + :func: make_parser + :prog: FastSurferCNN/generate_hdf5.py diff --git a/doc/scripts/fastsurfercnn.rst b/doc/scripts/fastsurfercnn.rst new file mode 100644 index 00000000..92a1a677 --- /dev/null +++ b/doc/scripts/fastsurfercnn.rst @@ -0,0 +1,19 @@ +FastSurferCNN: run_prediction.py +================================ + +.. note:: + We recommend to run the surface pipeline with the standard `run_fastsurfer.sh` interfaces! + +.. include:: ../../FastSurferCNN/README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-after: + :end-before: + +Full commandline interface of FastSurferCNN/run_prediction.py +------------------------------------------------------------- +.. 
argparse:: + :module: FastSurferCNN.run_prediction + :func: make_parser + :prog: FastSurferCNN/run_prediction.py diff --git a/doc/scripts/fastsurfercnn.run_model.rst b/doc/scripts/fastsurfercnn.run_model.rst new file mode 100644 index 00000000..16429021 --- /dev/null +++ b/doc/scripts/fastsurfercnn.run_model.rst @@ -0,0 +1,15 @@ +FastSurferCNN: run_model.py +================================ + +.. include:: ../../FastSurferCNN/README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :start-after: + +Full commandline interface of FastSurferCNN/run_model.py +-------------------------------------------------------- +.. argparse:: + :module: FastSurferCNN.run_model + :func: make_parser + :prog: FastSurferCNN/run_model.py diff --git a/doc/scripts/hypvinn.rst b/doc/scripts/hypvinn.rst new file mode 100644 index 00000000..2dad15a0 --- /dev/null +++ b/doc/scripts/hypvinn.rst @@ -0,0 +1,19 @@ +HypVINN: run_prediction.py +========================== + +.. note:: + We recommend to run HypVINN with the standard `run_fastsurfer.sh` interfaces! + +The `HypVINN/run_prediction.py` script enables the inference with HypVINN. In most +situations, it will be called from `run_fastsurfer.sh` a direct call to +`HypVINN/run_prediction.py` is not needed. + +.. argparse:: + :module: HypVINN.run_prediction + :func: option_parse + :prog: HypVINN/run_prediction.py + + +.. include:: ../../HypVINN/README.md + :parser: fix_links.parser + :start-line: 1 diff --git a/doc/scripts/index.rst b/doc/scripts/index.rst new file mode 100644 index 00000000..adfa9549 --- /dev/null +++ b/doc/scripts/index.rst @@ -0,0 +1,10 @@ +Scripts +======= + +.. toctree:: + :maxdepth: 2 + + BATCH.md + SLURM.md + advanced + util diff --git a/doc/scripts/recon_surf.rst b/doc/scripts/recon_surf.rst new file mode 100644 index 00000000..f7b2bb2d --- /dev/null +++ b/doc/scripts/recon_surf.rst @@ -0,0 +1,14 @@ +Surface pipeline: recon-surf.sh +=============================== + +.. include:: ../../recon_surf/README.md + :parser: fix_links.parser + :relative-docs: . + :relative-images: + :heading-offset: 1 + +Usage help text +--------------- + +.. command-output:: ./recon_surf/recon-surf.sh --help + :cwd: /../ diff --git a/doc/scripts/segstats.rst b/doc/scripts/segstats.rst new file mode 100644 index 00000000..e2d8211c --- /dev/null +++ b/doc/scripts/segstats.rst @@ -0,0 +1,12 @@ +FastSurferCNN: segstats.py +========================== + +`segstats.py` is a script that is equivalent to FreeSurfer's `mri_segstats`. However, it is faster and (automatically) scales very well to multi-processing scenarios. + + +Full commandline interface of FastSurferCNN/segstats.py +------------------------------------------------------- +.. argparse:: + :module: FastSurferCNN.segstats + :func: make_arguments + :prog: FastSurferCNN/segstats.py diff --git a/doc/scripts/util.rst b/doc/scripts/util.rst new file mode 100644 index 00000000..9de57d51 --- /dev/null +++ b/doc/scripts/util.rst @@ -0,0 +1,8 @@ +FastSurfer Utilities +==================== + +.. 
toctree:: + :maxdepth: 2 + + fastsurfercnn.generate_hdf5.rst + fastsurfercnn.run_model.rst diff --git a/doc/sphinx_ext/fix_links/__init__.py b/doc/sphinx_ext/fix_links/__init__.py new file mode 100644 index 00000000..496389d2 --- /dev/null +++ b/doc/sphinx_ext/fix_links/__init__.py @@ -0,0 +1,30 @@ +from pathlib import Path + +from sphinx.application import Sphinx +from sphinx.directives.other import Include +from docutils.parsers.rst import directives + +from fix_links.resolve import MySTReplaceDomain, resolve_xref +from fix_links.parser import Parser, wrap_include_run + + +def setup(app: Sphinx): + + app.add_config_value("fix_links_types", ("ref", "myst", "doc"), "env", list) + app.add_config_value("fix_links_target", {}, "env", dict) + app.add_config_value("fix_links_alternative_targets", {}, "env", dict) + app.add_config_value("fix_links_project_root", Path("."), "env", Path) + app.add_domain(MySTReplaceDomain) + app.connect("missing-reference", resolve_xref) + + # override the myst parser without loading the myst parser for default parsing + # [Sphinx](https://github.com/sphinx-doc/sphinx) extension + app.add_source_parser(Parser, override=True) + + # update the Include directive's run command + Include.run = wrap_include_run(Include.run) + Include.option_spec["relative-images"] = directives.flag + Include.option_spec["relative-docs"] = directives.path + Include.option_spec["heading-offset"] = directives.nonnegative_int + + return {"parallel_read_safe": True, "parallel_write_safe": True, "version": "0.1"} diff --git a/doc/sphinx_ext/fix_links/parser.py b/doc/sphinx_ext/fix_links/parser.py new file mode 100644 index 00000000..9e492a11 --- /dev/null +++ b/doc/sphinx_ext/fix_links/parser.py @@ -0,0 +1,190 @@ +from functools import wraps +from os.path import relpath +from pathlib import Path +from typing import Optional, cast +from itertools import chain + +from docutils import nodes +from markdown_it import MarkdownIt +from markdown_it.tree import SyntaxTreeNode +from myst_parser.mdit_to_docutils.sphinx_ import SphinxRenderer +from myst_parser.sphinx_ import Parser as MySTParser +from sphinx import addnodes +from sphinx.directives.other import Include + + +def wrap_include_run(method): + @wraps(method) + def _wrapper(include_instance: Include): + doc_settings = include_instance.state.document.settings + key = "fix_links_parser_options" + if hasattr(doc_settings, key): + options = getattr(doc_settings, key, {}) + else: + options = {} + setattr(doc_settings, key, options) + + source_dir = Path(include_instance.state.document["source"]).parent + include_path = (source_dir / include_instance.arguments[0]).resolve() + if "relative-images" in include_instance.options: + from os.path import relpath + options["relative-images"] = relpath(include_path.parent, source_dir) + relative_docs = include_instance.options.get("relative-docs", ".") + if relative_docs != "/": + options["relative-docs"] = (relative_docs, source_dir, include_path.parent) + return method(include_instance) + + return _wrapper + + +class Renderer(SphinxRenderer): + """ + Renderer object to automatically fix headings that are not consecutive levels in + (included) Markdown files. Also includes alternative targets into anchors that + are rendered, but do not match a target. 
+ """ + + def __init__(self, parser: MarkdownIt): + self._heading_base: Optional[int] = None + super().__init__(parser) + + def update_section_level_state(self, section: nodes.section, level: int) -> None: + """This method is fixed such that """ + parent_level = max( + section_level + for section_level in self._level_to_section + if level > section_level + ) + if self._heading_base is None: + if (level > parent_level) and (parent_level + 1 != level): + self._heading_base = level - parent_level - 1 + else: + self._heading_base = 0 + + new_level = level - self._heading_base + if new_level < 0: + msg = (f"We fixed the offset to {new_level} based on the first heading, " + f"but following headings have lower numbers") + from myst_parser.warnings_ import MystWarnings + self.create_warning( + msg, + MystWarnings.MD_HEADING_NON_CONSECUTIVE, + line=section.line, + append_to=self.current_node, + ) + self._heading_base = level + new_level = 0 + + super().update_section_level_state(section, new_level) + + def _handle_relative_docs(self, destination: str) -> str: + from os.path import relpath, normpath + if destination.startswith("/"): + return relpath(destination[1:], self.sphinx_env.srcdir) + relative_include = self.md_env.get("relative-docs", None) + if relative_include is not None: + source_dir: Path + source_dir, include_dir = relative_include[1:] + return relpath( + include_dir / relative_include[0] / normpath(destination), + source_dir, + ) + return destination + + def render_link_anchor(self, token: SyntaxTreeNode, target: str) -> None: + + if not target.startswith("#"): + return self.render_link_unknown(token) + + if target[1:] in self.document.nameids: + return super().render_link_anchor(token, target) + + cfg_alt_tgts = self.sphinx_env.config.fix_links_alternative_targets + + include_abspaths = (Path(inc[0]).resolve() for inc in self.document.include_log) + doc_root = self.sphinx_env.srcdir + include_relpaths = (f"/{relpath(path, doc_root)}" for path in include_abspaths) + includes = (".",) + tuple(include_relpaths) + alt_targets = dict.fromkeys(chain(*(cfg_alt_tgts.get(f, ()) for f in includes))) + + # href_before = token.attrGet("href") + token.attrs["href"] = Path(self.current_node.source).name + target + self.render_link_unknown(token) + + ref_node = self.current_node.children[-1] + if isinstance(ref_node, addnodes.pending_xref): + ref_node["alternative_targets"] = tuple(alt_targets) + + def render_link_unknown(self, token: SyntaxTreeNode) -> None: + super().render_link_unknown(token) + ref_node: nodes.Element = cast(nodes.Element, self.current_node.children[-1]) + attr = ref_node.attributes + if (attr.get("refdomain", "") == "doc" and + (target := attr.get("reftarget", "")).startswith("..")): + attr["refdomain"] = None + # project_root: how absolute paths are interpreted w.r.t. the doc root + doc_root = Path(self.sphinx_env.srcdir) + project_root = self.sphinx_env.config.fix_links_project_root + target_path = relpath( + (doc_root / target).resolve(), + (doc_root / project_root).resolve(), + ) + attr["reftarget"] = f"/{target_path}" + + +class Parser(MySTParser): + """ + Parser to use `Renderer`, which automatically fixes non-consecutive headings and + manages alternative targets in the topmatter. + """ + + def parse(self, inputstring: str, document: nodes.document) -> None: + """Parse source text. 
+ + :param inputstring: The source string to parse + :param document: The root docutils node to add AST elements to + + """ + from myst_parser.warnings_ import create_warning + from myst_parser.parsers.mdit import create_md_parser + from myst_parser.config.main import ( + MdParserConfig, TopmatterReadError, merge_file_level, read_topmatter, + ) + + # get the global config + config: MdParserConfig = document.settings.env.myst_config + alt_targets = () + + # update the global config with the file-level config + try: + topmatter = read_topmatter(inputstring) + except TopmatterReadError: + pass # this will be reported during the render + else: + if topmatter: + if "alternative-targets" in topmatter: + alt_targets = tuple(topmatter.pop("alternative-targets").split()) + warning = lambda wtype, msg: create_warning( # noqa: E731 + document, msg, wtype, line=1, append_to=document, + ) + config = merge_file_level(config, topmatter, warning) + + from contextlib import contextmanager + + @contextmanager + def _restore(node, cfg_name: str, values: tuple[str]): + cfg = getattr(node, cfg_name) + before = cfg.get(".", ()) + cfg["."] = before + values + yield + cfg["."] = before + + parser = create_md_parser(config, Renderer) + with _restore( + document.settings.env.config, + "fix_links_alternative_targets", + alt_targets, + ): + parser.options["document"] = document + parser_options = getattr(document.settings, "fix_links_parser_options", {}) + parser.render(inputstring, parser_options) diff --git a/doc/sphinx_ext/fix_links/resolve.py b/doc/sphinx_ext/fix_links/resolve.py new file mode 100644 index 00000000..570e881d --- /dev/null +++ b/doc/sphinx_ext/fix_links/resolve.py @@ -0,0 +1,281 @@ + +import re +from functools import lru_cache, partial +from pathlib import Path +from typing import Generator, Any + +import sphinx.domains +from docutils import nodes +from sphinx.domains import Domain +from sphinx.application import Sphinx +from sphinx.environment import BuildEnvironment +from sphinx.builders import Builder +from sphinx import addnodes +from sphinx.util.logging import getLogger + +logger = getLogger(__name__) + + +@lru_cache +def make_pattern(s: str) -> re.Pattern: + return re.compile(s, re.IGNORECASE) + + +def loc(node) -> str: + return node["refdoc"] if "refdoc" in node.attributes else node.source + + +def resolve_included( + included: dict[str, set[str]], + found_docs: set[str], + uri_path: str, +) -> str: + """ + Iterate through including files resolved via inclusion links. + + Parameters + ---------- + included : dict[str, set[str]] + The dictionary mapping a file to the files it includes. + found_docs : set[str] + A set of doc files that are part of the documentation. + uri_path : str + The path to the included file + + Returns + ------- + str + The resolved path. + """ + def __resolve_all(path, include_tree=()): + for src, inc in included.items(): + if path in inc: + if src in found_docs: + yield src + elif src in include_tree: + logger.warning(f"Recursive inclusion in {path} -> {src}!") + else: + yield from __resolve_all(src, include_tree + (path,)) + + yield from __resolve_all(uri_path) + + +def resolve_xref( + app: Sphinx, + env: BuildEnvironment, + node: addnodes.pending_xref, + contnode: nodes.Element, +) -> nodes.reference | None: + """ + Replace unresolved names by replacing the link with configurable alternatives. 
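`resolve_included` only walks the include graph that Sphinx builds; a minimal standalone example of that traversal, with invented document names, looks like this:

```python
from fix_links.resolve import resolve_included  # assuming doc/sphinx_ext is on sys.path

# "index" is a real document; it includes "overview/index.md", which in turn
# includes "overview/intro.md" (neither .md file is a document of its own).
included = {
    "overview/index.md": {"overview/intro.md"},
    "index": {"overview/index.md"},
}
found_docs = {"index"}

# A link that points at the included-only file resolves to the including document:
print(list(resolve_included(included, found_docs, "overview/intro.md")))  # -> ['index']
```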
+ + For an unresolved :py:`sphinx.addnodes.pending_xref` `node` of reftype in the + config variable fix_links_types, this function will look through registered + replacements and replace with alternative labels. The first successful replacement, + that also matches a link in the documentation will be returned. + + This function is compatible with the missing-references sphinx-event. + + Parameters + ---------- + app : sphinx.application.Sphinx + env : sphinx.environment.BuildEnvironment + node : sphinx.addnodes.pending_xref + contnode : docutils.noes.Element + + Returns + ------- + docutils.nodes.reference, None + The first node that successfully links to a valid target. + """ + config = env.config + attr = "reftarget" + if node.attributes.get("reftype", "") in config.fix_links_types: + subs = {k: (make_pattern(k), v) for k, v in config.fix_links_target.items()} + _resolve_xref_with_ = partial(_resolve_xref_with, app, env, node, contnode) + + if attr not in node.attributes: + logger.debug( + f"[fix_links] Skipping replacement of {node.attibutes} (no {attr})", + location=loc(node), + ) + return + logger.debug( + f"[fix_links] Searching for replacement of {node[attr]}:", + location=loc(node), + ) + + from os.path import relpath + # project_root: how absolute paths are interpreted w.r.t. the doc root + doc_root = Path(env.srcdir) + project_root = env.config.fix_links_project_root + uri = node[attr] + if node["reftype"] == "doc": + if uri.startswith("/"): + node['reftarget'] = "index" + _uri_path = uri[1:] + else: + _uri_path = uri + _uri_id = node.attributes.get("refid", None) or "" + _uri_sep = "#" if _uri_id else "" + project_root = "." + reftype = "doc" + else: + _uri_path, _uri_sep, _uri_id = uri.partition("#") + if not _uri_id and getattr(node, "reftargetid", None) is not None: + _uri_sep, _uri_id = "#", node["reftargetid"] + + reftype = "ref" + # resolve the target Path in the link w.r.t. the source it came from + if _uri_path.startswith("/"): + # absolute with respect to documentation root + target_path = (doc_root / project_root / _uri_path[1:]).resolve() + else: + sourcefile_path = doc_root / node.source.split(":")[0] + target_path = (sourcefile_path.parent / _uri_path).resolve() + _uri_path = relpath(target_path, doc_root) + _uri_hash = _uri_sep + _uri_id + + if not _uri_path.startswith("../"): + # maybe this already fixed the path? 
+ ref = _resolve_xref_with_( + f"{_uri_path}{_uri_hash}".lower(), + node.source, + reftype=reftype, + ) + if ref is not None: + return ref + + # trace back the include path and check if this resolves the ref + if env.included: + potential_targets = resolve_included( + env.included, + env.found_docs, + _uri_path, + ) + _reftarget = node["reftarget"] + _reftype = "doc" if _uri_hash == "" else "ref" + for potential_doc in potential_targets: + potential_path = env.doc2path(potential_doc, False) + if potential_path.endswith(".rst"): + potential_path = potential_doc + potential_path = relpath( + doc_root / potential_path, + (doc_root / node["refdoc"]).parent, + ) + node["reftarget"] = potential_path + ref = _resolve_xref_with_( + (potential_path + _uri_hash).lower(), + node.source, + reftype=_reftype, + ) + if ref is not None: + return ref + node["reftarget"] = _reftarget + + source = f"{_uri_path}{_uri_sep}{_uri_id}" + for key, (pat, repls) in subs.items(): + # if this search string does not match, try next + if not pat.match(source): + continue + + tries = [] + # iterate over different replacement options + for repl in repls: + # repeatedly replace until no more changes are occur + replaced = pat.sub(repl, source) + while pat.match(replaced): + _replaced = pat.sub(repl, replaced) + if replaced == _replaced: + logger.warning( + f"[fix_links] Infinite replacement loop with string " + f"'{source}', pattern '{key}' and replacement '{repl}'!", + location=loc(node), + ) + break + replaced = _replaced + # search for a reference associated with the replaced link in std + ref = _resolve_xref_with_( + str(replaced).lower(), + node.source, + reftype=reftype, + ) + + # check and return the reference, if it is valid + if ref is not None: + return ref + + tries.append(str(replaced).lower()) + # if the pattern matched, but none of the replacements lead to a valid + # reference + logger.warning( + f"[fix_links] Target '{source}' matched the pattern '{pat.pattern}', " + f"but could not be resolved by any of the replacements {tuple(tries)} " + f": {node['reftarget']}!", + location=loc(node), + ) + + if env.included and _uri_path.startswith("../"): + logger.warning( + f"[fix_links] Could not find the external target {_uri_path} in " + f"included files, it is likely not included.", + location=loc(node), + ) + + # restore the reftarget attribute + node[attr] = uri + # node["reftype"] = prev_type + + +def _resolve_xref_with( + app: Sphinx, + env: BuildEnvironment, + node: addnodes.pending_xref, + contnode: nodes.Element, + target: str, + source: str, + reftype: str = "ref", +) -> nodes.reference | None: + std_domain = env.domains["std"] + ref: nodes.reference | None = std_domain.resolve_xref( + env, + node["refdoc"], # fromdocname + app.builder, + reftype, + target, + node, + contnode, + ) + + # check and return the reference, if it is valid + if ref is not None: + attrs = ("reftarget", "refuri", "refid") + target = next((a, ref[a]) for a in attrs if a in ref.attributes) + logger.debug( + f"[fix_links] <{node.source}> replacing {source} with {'='.join(target)}", + location=loc(node), + ) + return ref + + +class MySTReplaceDomain(Domain): + """""" + + name: str = "myst_repl" + + def resolve_any_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + target: str, + node: addnodes.pending_xref, + contnode: nodes.Element, + ) -> list[tuple[str, nodes.Element]]: + try: + ref: nodes.Element | None = resolve_xref(env.app, env, node, contnode) + if ref is not None: + return [("std:ref", ref)] + 
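To make the substitution loop above concrete, here is a standalone sketch of how a single `fix_links_target` entry rewrites a source string; the pattern and replacement are invented for illustration and are not shipped settings.

```python
import re

# invented fix_links_target entry: pattern -> tuple of replacement candidates
key = r"^fastsurfercnn/(.*)\.md(#.*)?$"
repls = (r"scripts/fastsurfercnn.\1\2",)
pat = re.compile(key, re.IGNORECASE)  # same flags as make_pattern()

source = "FastSurferCNN/README.md#usage"
if pat.match(source):
    for repl in repls:
        replaced = pat.sub(repl, source)
        # the real loop keeps substituting while the pattern still matches and then
        # asks the std domain whether the rewritten target actually exists
        print(replaced.lower())  # -> scripts/fastsurfercnn.readme#usage
```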
except StopIteration: + pass + return [] diff --git a/env/export_pip-r.sh b/env/export_pip-r.sh index 49433793..9ec14c61 100644 --- a/env/export_pip-r.sh +++ b/env/export_pip-r.sh @@ -48,21 +48,23 @@ echo "Exporting versions from $2..." echo "#" } > $1 -pip_cmd="python --version && pip list --format=freeze --no-color --all --disable-pip-version-check --no-input" +pip_cmd="python --version && pip list --format=freeze --no-color --disable-pip-version-check --no-input" if [ "${2/#.sif}" != "$2" ] then # singularity - cmd="singularity exec $2 /bin/bash -c '$pip_cmd'" + cmd=("singularity" "exec" "$2" "/bin/bash" -c "$pip_cmd") + clean_cmd="singularity exec $2 /bin/bash -c '$pip_cmd'" else # docker - cmd="docker run --entrypoint /bin/bash $2 -c '$pip_cmd'" + clean_cmd="docker run --rm -u : --entrypoint /bin/bash $2 -c '$pip_cmd'" + cmd=("docker" "run" --rm -u "$(id -u):$(id -g)" --entrypoint /bin/bash "$2" -c "$pip_cmd") fi { echo "# Which ran the following command:" - echo "# $cmd" + echo "# $clean_cmd" echo "#" } >> $1 -out=$($cmd) +out=$("${cmd[@]}") hardware=$(echo "$out" | grep "torch==" | cut -d"+" -f2) pyversion=$(echo "$out" | head -n 1 | cut -d" " -f2) { @@ -73,5 +75,3 @@ pyversion=$(echo "$out" | head -n 1 | cut -d" " -f2) echo "" echo "# $out" } >> $1 - -} \ No newline at end of file diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index 25a80a91..4e3567f3 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -5,28 +5,28 @@ channels: - defaults dependencies: - - h5py=3.7.0 - - lapy=1.0.1 - - matplotlib=3.7.1 - - nibabel=5.1.0 - - numpy=1.25.0 - - pandas=1.5.3 - - pillow=10.0.1 - - pip=23.1.2 - - python=3.10 - - python-dateutil=2.8.2 - - pyyaml=6.0 - - scikit-image=0.19.3 - - scikit-learn=1.2.2 - - scipy=1.10.1 - - setuptools=67.8.0 - - tensorboard=2.12.1 - - tqdm=4.65.0 - - yacs=0.1.8 - - pip - - pip: - - --extra-index-url https://download.pytorch.org/whl/cu117 - - simpleitk==2.2.1 - - torch==2.0.1 - - torchio==0.18.83 - - torchvision==0.15.2 +- h5py=3.11.0 +- lapy=1.1.0 +- matplotlib=3.9.2 +- nibabel=5.2.1 +- numpy=1.26.4 +- pandas=2.2.2 +- pillow=10.4.0 +- pip=24.2 +- python=3.10 +- python-dateutil=2.9.0 +- pyyaml=6.0.2 +- requests=2.32.3 +- scikit-image=0.24.0 +- scikit-learn=1.5.1 +- scipy=1.14.1 +- setuptools=72.2.0 +- tensorboard=2.17.1 +- tqdm=4.66.5 +- yacs=0.1.8 +- pip: + - --extra-index-url https://download.pytorch.org/whl/cu124 + - simpleitk==2.4.0 + - torch==2.4.0+cu124 + - torchio==0.19.9 + - torchvision==0.19.0+cu124 diff --git a/pyproject.toml b/pyproject.toml index a632c840..c6d16506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,21 +4,13 @@ build-backend = 'setuptools.build_meta' [project] name = 'fastsurfer' -version = "2.2.0" -description = 'a fast and accurate deep-learning based neuroimaging pipeline' +version = '2.3.0' +description = 'A fast and accurate deep-learning based neuroimaging pipeline' readme = 'README.md' license = {file = 'LICENSE'} -requires-python = '>=3.8' -authors = [ - {name = 'Martin Reuter', email = 'martin.reuter@dzne.de'}, - {name = 'Leonie Henschel', email = 'leonie.henschel@dzne.de'}, - {name = 'David Kügler', email = 'david.kuegler@dzne.de'}, -] -maintainers = [ - {name = 'Martin Reuter', email = 'martin.reuter@dzne.de'}, - {name = 'Leonie Henschel', email = 'leonie.henschel@dzne.de'}, - {name = 'David Kügler', email = 'david.kuegler@dzne.de'}, -] +requires-python = '>=3.10' +authors = [{name = 'Martin Reuter et al.'}] +maintainers = [{name = 'FastSurfer Developers'}] keywords = [ 'python', 'Deep learning', @@ -32,30 
+24,71 @@ classifiers = [ 'Operating System :: Unix', 'Operating System :: MacOS', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Natural Language :: English', - 'License :: OSI Approved :: MIT License', + 'License :: OSI Approved :: Apache Software License', 'Intended Audience :: Science/Research', ] dependencies = [ 'h5py>=3.7', - 'lapy>=0.4.1', - 'matplotlib>=3.5.1', - 'nibabel>=3.2.2', - 'numpy>=1.21', - 'pandas>=1.4.3', - 'pytorch>=1.12.0', + 'lapy>=1.1.0', + 'matplotlib>=3.7.1', + 'nibabel>=5.1.0', + 'numpy>=1.25,<2', + 'pandas>=1.5.3', 'pyyaml>=6.0', - 'scipy>=1.8.0', - 'yacs>=0.1.8', - 'simpleitk>=2.1.1', - 'scipy>=1.8.0', - 'tensorboard>=2.9.1', + 'requests>=2.31.0', + 'scikit-image>=0.19.3', + 'scikit-learn>=1.2.2', + 'scipy>=1.10.1,!=1.13.0', + 'simpleitk>=2.2.1', + 'tensorboard>=2.12.1', + 'torch>=2.0.1', 'torchio>=0.18.83', - 'tqdm>=4.64', + 'torchvision>=0.15.2', + 'tqdm>=4.65', + 'yacs>=0.1.8', +] + +[project.optional-dependencies] +doc = [ + 'furo!=2023.8.17', + 'matplotlib', + 'memory-profiler', + 'myst-parser', + 'numpydoc', + # sphinx 8 handles importing of files differently in some manner and will cause the + # build of the doc to fail. This will need to be addressed before we up to sphinx 8. + 'sphinx>=7.3,<8', + 'sphinxcontrib-bibtex', + 'sphinxcontrib-programoutput', + 'sphinx-argparse', + 'sphinx-copybutton', + 'sphinx-design', + 'sphinx-gallery', + 'sphinx-issues', + 'pypandoc', + 'nbsphinx', + 'IPython', # For syntax highlighting in notebooks + 'ipykernel', + 'scikit-image', + 'torchvision', + 'scikit-learn', +] +style = [ + 'bibclean', + 'codespell', + 'pydocstyle[toml]', + 'ruff', +] +all = [ + 'fastsurfer[doc]', + 'fastsurfer[style]', +] +full = [ + 'fastsurfer[all]', ] [project.urls] @@ -63,3 +96,60 @@ homepage = 'https://fastsurfer.org' documentation = 'https://fastsurfer.org' source = 'https://github.com/Deep-MI/FastSurfer' tracker = 'https://github.com/Deep-MI/FastSurfer/issues' + +[tool.setuptools] +packages = ['FastSurferCNN','CerebNet','recon_surf'] + +[tool.pydocstyle] +convention = 'numpy' +ignore-decorators = '(copy_doc|property|.*setter|.*getter|pyqtSlot|Slot)' +match = '^(?!setup|__init__|test_).*\.py' +match-dir = '^FastSurferCNN.*,^CerebNet.*,^recon-surf.*' +add_ignore = 'D100,D104,D107' + +[tool.ruff] +line-length = 120 +target-version = "py310" +extend-exclude = [ + "build", + "checkpoints", + "doc", + "env", + "images", + "setup.py", +] + +[tool.ruff.lint] +# https://docs.astral.sh/ruff/linter/#rule-selection +select = [ + "E", # pycodestyle + "F", # Pyflakes + "UP", # pyupgrade + "B", # flake8-bugbear + "I", # isort + # "SIM", # flake8-simplify +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] +"Tutorial/*.ipynb" = ["E501"] # exclude "Line too long" + +[tool.pytest.ini_options] +minversion = '6.0' +addopts = '--durations 20 --junit-xml=junit-results.xml --verbose' +filterwarnings = [] + +[tool.coverage.run] +branch = true +cover_pylib = false +omit = [ + '**/__init__.py', +] + +[tool.coverage.report] +exclude_lines = [ + 'pragma: no cover', + 'if __name__ == .__main__.:', +] +precision = 2 + diff --git a/recon_surf/N4_bias_correct.py b/recon_surf/N4_bias_correct.py index 347d1331..ba9cb9ec 100644 --- a/recon_surf/N4_bias_correct.py +++ b/recon_surf/N4_bias_correct.py @@ -15,19 +15,21 @@ # IMPORTS +# Group 1: Python 
native modules import argparse +import logging import sys from pathlib import Path -from typing import Optional, cast, Tuple -import logging +from typing import Optional, cast, Literal, TypeVar, Callable +# Group 2: External modules import SimpleITK as sitk import numpy as np from numpy import typing as npt +# Group 3: Internal modules import image_io as iio - HELPTEXT = """ Script to call SITK N4 Bias Correction @@ -37,7 +39,7 @@ Dependencies: - Python 3.8 + Python 3.8+ SimpleITK https://simpleitk.org/ (v2.1.1) @@ -48,7 +50,7 @@ --mask. If --mask auto is given, all 0 values will be masked to speed-up correction and avoid influence of flat zero regions on the bias field. If no mask is given, no mask is applied. The mean image intensity of the --out file is adjusted to be equal -to the mean intenstiy of the input. +to the mean intensity of the input. WM Normalization and UCHAR (done if --rescale is passed): @@ -82,58 +84,150 @@ Date: Mar-18-2022 Modified: David Kügler -Date: Oct-25-2023 +Date: Feb-27-2024 """ -h_verbosity = "Logging verbosity: 0 (none), 1 (normal), 2 (debug)" -h_invol = "path to input.nii.gz" -h_outvol = "path to corrected.nii.gz" -h_rescaled = "path to rescaled.nii.gz" -h_mask = "optional: path to mask.nii.gz" -h_aseg = "optional: path to aseg or aseg+dkt image to find the white matter mask" -h_shrink = " shrink factor, default: 4" -h_levels = " number of fitting levels, default: 4" -h_numiter = " max number of iterations per level, default: 50" -h_thres = " convergence threshold, default: 0.0" -h_tal = " file name of talairach.xfm if using this for finding origin" -h_threads = " number of threads, default: 1" +HELP_VERBOSITY = "Logging verbosity: 0 (none), 1 (normal), 2 (debug)" +HELP_INVOL = "path to input.nii.gz" +HELP_OUTVOL = "path to corrected.nii.gz" +HELP_UCHAR = ("sets the output dtype to uchar (only applies to outvol, rescalevol " + "is uchar by default.)") +HELP_RESCALED = "path to rescaled.nii.gz" +HELP_MASK = "optional: path to mask.nii.gz" +HELP_ASEG = "optional: path to aseg or aseg+dkt image to find the white matter mask" +HELP_SHRINK_FACTOR = " shrink factor, default: 4" +HELP_LEVELS = " number of fitting levels, default: 4" +HELP_NUM_ITER = " max number of iterations per level, default: 50" +HELP_THRESHOLD = " convergence threshold, default: 0.0" +HELP_TALAIRACH = " file name of talairach.xfm if using this for finding origin" +HELP_THREADS = " number of threads, default: 1" +LiteralSkipRescaling = Literal["skip rescaling"] +SKIP_RESCALING: LiteralSkipRescaling = "skip rescaling" +LiteralDoNotSave = Literal["do not save"] +DO_NOT_SAVE: LiteralDoNotSave = "do not save" + +logger = logging.getLogger(__name__) + +_T = TypeVar("_T", bound=str) + + +def path_or_(*constants: _T) -> Callable[[str], _T]: + def wrapper(a: str) -> Path | _T: + if a in constants: + return a + return Path(a) + return wrapper def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Namespace object holding options. 
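`path_or_` is a small factory for argparse `type=` callables: it keeps the given sentinel strings as-is and converts everything else to `pathlib.Path`. A quick throwaway illustration (the parser below is not the real `options_parse`):

```python
import argparse
from pathlib import Path

from N4_bias_correct import DO_NOT_SAVE, path_or_  # assuming recon_surf/ is on the path

p = argparse.ArgumentParser()
p.add_argument("--out", default=DO_NOT_SAVE, type=path_or_(DO_NOT_SAVE))

print(p.parse_args([]).out)                          # -> do not save  (sentinel kept as str)
out = p.parse_args(["--out", "corrected.nii.gz"]).out
print(isinstance(out, Path), out)                    # -> True corrected.nii.gz
```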
""" parser = argparse.ArgumentParser( description=HELPTEXT, ) parser.add_argument( "-v", "--verbosity", - dest="verbosity", choices=(0, 1, 2), default=-1, help=h_verbosity, type=int + dest="verbosity", + choices=(0, 1, 2), + default=-1, + help=HELP_VERBOSITY, + type=int, + ) + parser.add_argument( + "--in", + dest="invol", + type=Path, + help=HELP_INVOL, + required=True, + ) + parser.add_argument( + "--out", + dest="outvol", + help=HELP_OUTVOL, + default=DO_NOT_SAVE, + type=path_or_(DO_NOT_SAVE), + ) + parser.add_argument( + "--uchar", + dest="dtype", + action="store_const", + const="uint8", + help=HELP_UCHAR, + default="keep", + ) + parser.add_argument( + "--rescale", + dest="rescalevol", + help=HELP_RESCALED, + default=SKIP_RESCALING, + type=path_or_(SKIP_RESCALING), + ) + parser.add_argument( + "--mask", + dest="mask", + help=HELP_MASK, + default=None, + type=Path, ) - parser.add_argument("--in", dest="invol", help=h_invol, required=True) - parser.add_argument("--out", dest="outvol", help=h_outvol, default="do not save") - parser.add_argument("--rescale", dest="rescalevol", help=h_rescaled, default="skip rescaling") - parser.add_argument("--mask", dest="mask", help=h_mask, default=None) - parser.add_argument("--aseg", dest="aseg", help=h_aseg, default=None) - parser.add_argument("--shrink", dest="shrink", help=h_shrink, default=4, type=int) - parser.add_argument("--levels", dest="levels", help=h_levels, default=4, type=int) parser.add_argument( - "--numiter", dest="numiter", help=h_numiter, default=50, type=int + "--aseg", + dest="aseg", + help=HELP_ASEG, + default=None, + type=Path, ) - parser.add_argument("--thres", dest="thres", help=h_thres, default=0.0, type=float) - parser.add_argument("--tal", dest="tal", help=h_tal, default=None) parser.add_argument( - "--threads", dest="threads", help=h_threads, default=1, type=int + "--shrink", + dest="shrink", + help=HELP_SHRINK_FACTOR, + default=4, + type=int, + ) + parser.add_argument( + "--levels", + dest="levels", + help=HELP_LEVELS, + default=4, + type=int, + ) + parser.add_argument( + "--numiter", + dest="numiter", + help=HELP_NUM_ITER, + default=50, + type=int, + ) + parser.add_argument( + "--thres", + dest="thres", + help=HELP_THRESHOLD, + default=0.0, + type=float, + ) + parser.add_argument( + "--tal", + dest="tal", + help=HELP_TALAIRACH, + default=None, + type=Path, + ) + parser.add_argument( + "--threads", + dest="threads", + help=HELP_THREADS, + default=1, + type=int, ) parser.add_argument( "--version", action="version", - version="$Id: N4_bias_correct.py,v 2.0 2023/10/25 20:02:08 mreuter,dkuegler Exp $" + version="$Id: N4_bias_correct.py,v 2.1 2024/02/27 20:02:08 mreuter,dkuegler Exp $" ) return parser.parse_args() @@ -146,28 +240,28 @@ def itk_n4_bfcorrection( numiter: int = 50, thres: float = 0.0, ) -> sitk.Image: - """Perform the bias field correction. + """ + Perform the bias field correction. Parameters ---------- itk_image : sitk.Image - n-dimensional image + N-dimensional image. itk_mask : Optional[sitk.Image] - Image mask. Defaults to None. Optional + Image mask. Defaults to None. Optional. shrink : int - Shrink factors. Defaults to 4 + Shrink factors. Defaults to 4. levels : int - Number of levels for maximum number of iterations. Defaults to 4 + Number of levels for maximum number of iterations. Defaults to 4. numiter : int - Maximum number if iterations. Defaults 50 + Maximum number if iterations. Defaults 50. thres : float - Convergence threshold. Defaults to 0.0 + Convergence threshold. Defaults to 0.0. 
Returns ------- itk_bfcorr_image - Bias field corrected image - + Bias field corrected image. """ _logger = logging.getLogger(__name__ + ".itk_n4_bfcorrection") # if no mask is passed, create a simple mask from the image @@ -214,28 +308,28 @@ def normalize_wm_mask_ball( target_wm: float = 110., target_bg: float = 3. ) -> sitk.Image: - """Normalize WM image by Mask and optionally ball around talairach center. + """ + Normalize WM image by Mask and optionally ball around talairach center. Parameters ---------- itk_image : sitk.Image - n-dimensional itk image + N-dimensional itk image. itk_mask : sitk.Image, optional Image mask. - radius : float | int - Defaults to 50 [MISSING] + radius : float, int, default=50 + Radius of ball around centroid. Defaults to 50. centroid : np.ndarray - brain centroid. + Brain centroid. target_wm : float | int Target white matter intensity. Defaults to 110. target_bg : float | int - target background intensity. Defaults to 3 (1% of 255) + Target background intensity. Defaults to 3 (1% of 255). Returns ------- normalized_image : sitk.Image - Normalized WM image - + Normalized WM image. """ _logger = logging.getLogger(__name__ + ".normalize_wm_mask_ball") _logger.info(f"- centroid: {centroid}") @@ -261,14 +355,19 @@ def get_distance(axis): # get ndarray from sitk image image = sitk.GetArrayFromImage(itk_image) # get 90th percentiles of intensities in ball (to identify WM intensity) - source_intensity = np.percentile(image[ball], [1, 90]) + source_intensity_bg, source_intensity_wm = np.percentile(image[ball], [1, 90]) _logger.info( - f"- source background intensity: {source_intensity[0]:.2f}" - f"- source white matter intensity: {source_intensity[1]:.2f}" + f"- source background intensity: {source_intensity_bg:.2f}" + f"- source white matter intensity: {source_intensity_wm:.2f}" ) - return normalize_img(itk_image, itk_mask, tuple(source_intensity.tolist()), (target_bg, target_wm)) + return normalize_img( + itk_image, + itk_mask, + (source_intensity_bg, source_intensity_wm), + (target_bg, target_wm), + ) def normalize_wm_aseg( @@ -278,30 +377,31 @@ def normalize_wm_aseg( target_wm: float = 110., target_bg: float = 3. ) -> sitk.Image: - """Normalize WM image [MISSING]. + """ + Normalize WM image so the white matter has a mean intensity of target_wm and the + background has intensity target_bg. Parameters ---------- itk_image : sitk.Image - n-dimensional itk image + N-dimensional itk image. itk_mask : sitk.Image | None Image mask. itk_aseg : sitk.Image - aseg-like segmentation image to find WM. - radius : float | int - Defaults to 50 [MISSING] - centroid : Optional[np.ndarray] - Image centroid. Defaults to None + Aseg-like segmentation image to find WM. + radius : float, int, default=50 + Radius of ball around centroid. Defaults to 50. + centroid : np.ndarray, optional + Image centroid. Defaults to None. target_wm : float | int - target white matter intensity. Defaults to 110 + Target white matter intensity. Defaults to 110. target_bg : float | int - target background intensity Defaults to 3 (1% of 255) + Target background intensity Defaults to 3 (1% of 255). Returns ------- normed : sitk.Image - Normalized WM image - + Normalized WM image. 
""" _logger = logging.getLogger(__name__ + ".normalize_wm_aseg") @@ -319,14 +419,19 @@ def normalize_wm_aseg( f"- source white matter intensity: {source_wm_intensity:.2f}" ) - return normalize_img(itk_image, itk_mask, (source_bg, source_wm_intensity), (target_bg, target_wm)) + return normalize_img( + itk_image, + itk_mask, + (source_bg, source_wm_intensity), + (target_bg, target_wm), + ) def normalize_img( itk_image: sitk.Image, itk_mask: Optional[sitk.Image], - source_intensity: Tuple[float, float], - target_intensity: Tuple[float, float] + source_intensity: tuple[float, float], + target_intensity: tuple[float, float] ) -> sitk.Image: """ Normalize image by source and target intensity values. @@ -334,17 +439,26 @@ def normalize_img( Parameters ---------- itk_image : sitk.Image + Input image to be normalized. itk_mask : sitk.Image | None - source_intensity : Tuple[float, float] - target_intensity : Tuple[float, float] + Brain mask, voxels inside the mask are guaranteed to be > 0, + None is optional. + source_intensity : tuple[float, float] + Source intensity range. + target_intensity : tuple[float, float] + Target intensity range. Returns ------- - Rescaled image + sitk.Image + Rescaled image. """ _logger = logging.getLogger(__name__ + ".normalize_wm") # compute intensity transformation - m = (target_intensity[0] - target_intensity[1]) / (source_intensity[0] - source_intensity[1]) + m = ( + (target_intensity[0] - target_intensity[1]) + / (source_intensity[0] - source_intensity[1]) + ) _logger.info(f"- m: {m:.4f}") # itk_image already is Float32 and output should be also Float32, we clamp outside @@ -359,23 +473,23 @@ def normalize_img( def read_talairach_xfm(fname: Path | str) -> np.ndarray: - """Read Talairach transform. + """ + Read Talairach transform. Parameters ---------- fname : str - Filename to Talairach transform + Filename to Talairach transform. Returns ------- tal - Talairach transform matrix + Talairach transform matrix. Raises ------ ValueError if the file is of an invalid format. - """ _logger = logging.getLogger(__name__ + ".read_talairach_xfm") _logger.info(f"reading talairach transform from {fname}") @@ -383,22 +497,30 @@ def read_talairach_xfm(fname: Path | str) -> np.ndarray: lines = f.readlines() try: - transf_start = [l.lower().startswith("linear_") for l in lines].index(True) + 1 - tal_str = [l.replace(";", " ") for l in lines[transf_start:transf_start + 3]] + transform_iter = iter(lines) + # advance transform_iter to linear header + _ = next(ln for ln in transform_iter if ln.lower().startswith("linear_")) + # return the next 3 lines in transform_lines + transform_lines = (ln for ln, _ in zip(transform_iter, range(3))) + tal_str = [ln.replace(";", " ") for ln in transform_lines] tal = np.genfromtxt(tal_str) tal = np.vstack([tal, [0, 0, 0, 1]]) _logger.info(f"- tal: {tal}") return tal - except Exception as e: + except StopIteration: + _logger.error(msg := f"Could not find 'linear_' in {fname}.") + raise ValueError(msg) + except (Exception, StopIteration) as e: err = ValueError(f"Could not find taiairach transform in {fname}.") _logger.exception(err) raise err from e def get_tal_origin_voxel(tal: npt.ArrayLike, image: sitk.Image) -> np.ndarray: - """Get the origin of Talairach space in voxel coordinates. + """ + Get the origin of Talairach space in voxel coordinates. 
Parameters ---------- @@ -410,8 +532,7 @@ def get_tal_origin_voxel(tal: npt.ArrayLike, image: sitk.Image) -> np.ndarray: Returns ------- vox_origin : np.ndarray - Voxel coordinate of Talairach origin - + Voxel coordinate of Talairach origin. """ tal_inv = np.linalg.inv(tal) tal_origin = np.array(tal_inv[0:3, 3]).ravel() @@ -453,13 +574,14 @@ def get_image_mean(image: sitk.Image, mask: Optional[sitk.Image] = None) -> floa Parameters ---------- image : sitk.Image - image to get mean of + Image to get mean of. mask : sitk.Image, optional - optional mask to apply first + Optional mask to apply first. Returns ------- mean : float + The mean value of the image. """ img = sitk.GetArrayFromImage(image) if mask is not None: @@ -470,16 +592,17 @@ def get_image_mean(image: sitk.Image, mask: Optional[sitk.Image] = None) -> floa def get_brain_centroid(itk_mask: sitk.Image) -> np.ndarray: """ - Get the brain centroid from the itk_mask + Get the brain centroid from a binary image. Parameters ---------- itk_mask : sitk.Image + Binary image to compute the centroid of its labeled region. Returns ------- - brain centroid - + np.ndarray + Brain centroid. """ _logger = logging.getLogger(__name__ + ".get_brain_centroid") _logger.debug("No talairach center passed, estimating center from mask.") @@ -493,38 +616,45 @@ def get_brain_centroid(itk_mask: sitk.Image) -> np.ndarray: return itk_mask.TransformPhysicalPointToIndex(centroid_world) -if __name__ == "__main__": - - # Command Line options are error checking done here - options = options_parse() - LOGLEVEL = (logging.WARNING, logging.INFO, logging.DEBUG) - FORMAT = "" if options.verbosity < 0 else "%(levelname)s (%(module)s:%(lineno)s): " - FORMAT += "%(message)s" - logging.basicConfig(stream=sys.stdout, format=FORMAT) - logging.getLogger().setLevel(LOGLEVEL[abs(options.verbosity)]) - logger = logging.getLogger(__name__) - print_options(vars(options)) - - if options.rescalevol == "skip rescaling" and options.outvol == "do not save": - logger.error("Neither the rescaled nor the unrescaled volume are saved, aborting.") - sys.exit(1) +def main( + invol: Path, + outvol: LiteralDoNotSave | Path = DO_NOT_SAVE, + rescalevol: LiteralSkipRescaling | Path = SKIP_RESCALING, + dtype: str = "keep", + threads: int = 1, + mask: Optional[Path] = None, + aseg: Optional[Path] = None, + shrink: int = 4, + levels: int = 4, + numiter: int = 50, + thres: float = 0.0, + tal: Optional[Path] = None, + verbosity: int = -1, +) -> int | str: + if rescalevol == "skip rescaling" and outvol == DO_NOT_SAVE: + return ( + "Neither the rescaled nor the unrescaled volume are saved, " + "aborting." 
+ ) # set number of threads - sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(options.threads) + sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(threads) # read image (only nii supported) and convert to float32 - logger.debug(f"reading input volume {options.invol}") + logger.debug(f"reading input volume {invol}") # itk_image = sitk.ReadImage(options.invol, sitk.sitkFloat32) itk_image, image_header = iio.readITKimage( - options.invol, sitk.sitkFloat32, with_header=True + str(invol), + sitk.sitkFloat32, + with_header=True, ) # read mask (as uchar) - has_mask = bool(options.mask) + has_mask = bool(mask) if has_mask: - logger.debug(f"reading mask {options.mask}") + logger.debug(f"reading mask {mask}") itk_mask: Optional[sitk.Image] = iio.readITKimage( - options.mask, + str(mask), sitk.sitkUInt8, with_header=False ) @@ -538,50 +668,70 @@ def get_brain_centroid(itk_mask: sitk.Image) -> np.ndarray: itk_bfcorr_image = itk_n4_bfcorrection( itk_image, itk_mask, - options.shrink, - options.levels, - options.numiter, - options.thres, + shrink, + levels, + numiter, + thres, ) - if options.outvol != "do not save": + if outvol != DO_NOT_SAVE: logger.info("Skipping WM normalization, ignoring talairach and aseg inputs") # normalize to average input intensity kw_mask = {"mask": itk_mask} if has_mask else {} - m_bf_img = get_image_mean(itk_bfcorr_image, **kw_mask) - m_image = get_image_mean(itk_image, **kw_mask) - logger.info("- rescale") - logger.info(f" mean input: {m_image:.4f}, mean corrected {m_bf_img:.4f})") - # rescale keeping the zero-point and the mean image intensity - itk_outvol = normalize_img(itk_image, itk_mask, (0., m_bf_img), (0., m_image)) - logger.info("converting outvol to UCHAR") - itk_outvol = sitk.Cast( - sitk.Clamp(itk_outvol, lowerBound=0, upperBound=255), sitk.sitkUInt8 - ) + logger.info("- rescale") + out_dtype = dtype.lower() + if out_dtype == "uint8": + from FastSurferCNN.data_loader.conform import getscale + image = sitk.GetArrayFromImage(itk_bfcorr_image) + l_image, m_image = 0, 255 + l_bf_img, scale = getscale(image, l_image, m_image) + m_bf_img = l_bf_img + (m_image - l_image) / scale + logger.info(f" lower bound corrected: {l_bf_img:.4f}, upper bound corrected {m_bf_img:.4f})") + else: + m_bf_img = get_image_mean(itk_bfcorr_image, **kw_mask) + m_image = get_image_mean(itk_image, **kw_mask) + logger.info(f" mean input: {m_image:.4f}, mean corrected {m_bf_img:.4f})") + l_bf_img, l_image = 0.0, 0.0 + # rescale keeping the zero-point and the mean image intensity + + itk_outvol = normalize_img(itk_image, itk_mask, (l_bf_img, m_bf_img), (l_image, m_image)) + + if out_dtype in ("uint8", "int8", "uint16", "int16"): + dtype_info = np.iinfo(np.dtype(out_dtype)) + itk_outvol = sitk.Clamp(itk_outvol, lowerBound=dtype_info.min, upperBound=dtype_info.max) + + if out_dtype != "keep": + logger.info(f"converting outvol to {dtype.upper()}") + cap_dtype = out_dtype + for prefix in ("i", "ui", "f"): + if cap_dtype.startswith(prefix): + cap_dtype = prefix.upper() + cap_dtype[len(prefix):] + sitk_dtype = getattr(sitk, "sitk" + cap_dtype) + itk_outvol = sitk.Cast(itk_outvol, sitk_dtype) + image_header.set_data_dtype(np.dtype(out_dtype)) # write image - logger.info(f"writing: {options.outvol}") - iio.writeITKimage(itk_outvol, options.outvol, image_header) + logger.info(f"writing {type(outvol).__name__}: {outvol}") + iio.writeITKimage(itk_outvol, str(outvol), image_header) - - if options.rescalevol == "skip rescaling": + if rescalevol == SKIP_RESCALING: logger.info("Skipping WM normalization, 
ignoring talairach and aseg inputs") else: target_wm = 110. # do some rescaling - if options.aseg: + if aseg: # has aseg # used to be 110, but we found experimentally, that freesurfer wm-normalized # intensity insde the WM mask is closer to 105 (but also quite inconsistent). - # So when we have a WM mask, we need to use 105 and not 110 as for the - # percentile approach above. + # So when we have a WM mask, we need to use 105 and not 110 as for the + # percentile approach above. target_wm = 105. logger.info(f"normalize WM to {target_wm:.1f} (find WM from aseg)") # only grab the white matter - itk_aseg = iio.readITKimage(options.aseg, with_header=False) + itk_aseg = iio.readITKimage(str(aseg), with_header=False) itk_bfcorr_image = normalize_wm_aseg( itk_bfcorr_image, @@ -591,14 +741,14 @@ def get_brain_centroid(itk_mask: sitk.Image) -> np.ndarray: ) else: logger.info(f"normalize WM to {target_wm:.1f} (find WM from mask & talairach)") - if options.tal: - talairach_center = read_talairach_xfm(options.tal) + if tal: + talairach_center = read_talairach_xfm(tal) brain_centroid = get_tal_origin_voxel(talairach_center, itk_image) elif has_mask: brain_centroid = get_brain_centroid(itk_mask) else: - logger.error("Neither --tal, --mask, nor --aseg are passed, but rescaling is requested.") - sys.exit(1) + return ("Neither --tal, --mask, nor --aseg are passed, but " + "rescaling is requested.") itk_bfcorr_image = normalize_wm_mask_ball( itk_bfcorr_image, @@ -613,7 +763,29 @@ def get_brain_centroid(itk_mask: sitk.Image) -> np.ndarray: ) # write image - logger.info(f"writing: {options.rescalevol}") - iio.writeITKimage(itk_bfcorr_image, options.rescalevol, image_header) + logger.info(f"writing {type(rescalevol).__name__}: {rescalevol}") + iio.writeITKimage(itk_bfcorr_image, str(rescalevol), image_header) + + return 0 + + +if __name__ == "__main__": + # Command Line options are error checking done here + options = options_parse() + LOGLEVEL = (logging.WARNING, logging.INFO, logging.DEBUG) + FORMAT = "%(message)s" + if options.verbosity >= 0: + FORMAT = "%(levelname)s (%(module)s:%(lineno)s): " + FORMAT + + logging.basicConfig(stream=sys.stdout, format=FORMAT) + logging.getLogger().setLevel(LOGLEVEL[abs(options.verbosity)]) + + if options.rescalevol == "skip rescaling" and options.outvol == "do not save": + logger.error("Neither the rescaled nor the unrescaled volume are saved, aborting.") + sys.exit(1) + + args = vars(options) + print_options(args) + invol = args.pop("invol") - sys.exit(0) + sys.exit(main(invol, **args)) diff --git a/recon_surf/README.md b/recon_surf/README.md index 8a2daeaf..6e812bd4 100644 --- a/recon_surf/README.md +++ b/recon_surf/README.md @@ -13,7 +13,7 @@ will be run in hires mode. Also note, that if a file exists at `$subjects_dir/$subject_id/mri/orig_nu.mgz`, this file will be used as the bias-field corrected image and the bias-field correction is skipped. # Usage -The *recon_surf* directory contains scripts to run the analysis. In addition, a working installation of __FreeSurfer__ (v7.3.2) is needed for a native install (or use our Docker/Singularity images). +The *recon_surf* directory contains scripts to run the analysis. In addition, a working installation of __FreeSurfer__ (the supported version, usually the most recent) is needed for a native install (or use our Docker/Singularity images). The main script is called __recon-surf.sh__ which accepts certain arguments via the command line. 
List them by running the following command: @@ -22,11 +22,11 @@ List them by running the following command: ./recon-surf.sh --help ``` -### Required arguments +## Required arguments * `--sd`: Output directory \$SUBJECTS_DIR (equivalent to FreeSurfer setup --> $SUBJECTS_DIR/sid/mri; $SUBJECTS_DIR/sid/surf ... will be created). * `--sid`: Subject ID for directory inside \$SUBJECTS_DIR to be created ($SUBJECTS_DIR/sid/...) -### Optional arguments +## Optional arguments * `--t1`: T1 full head input (not bias corrected). This must be conformed (dimensions: same along each axis, voxel size: isotropic, LIA orientation, and data type UCHAR). Images can be conformed using FastSurferCNN's [conform.py](https://github.com/Deep-MI/FastSurfer/blob/stable/FastSurferCNN/data_loader/conform.py) script (usage example: python3 FastSurferCNN/data_loader/conform.py -i -o ). If not passed we use the orig.mgz in the output subject mri directory if available. * `--asegdkt_segfile`: Global path with filename of segmentation (where and under which name to find it, must already exist). This must be conformed (dimensions: same along each axis, voxel size: isotropic, and LIA orientation). FastSurferCNN's segmentations are conformed by default. Please ensure that segmentations produced otherwise are also conformed and equivalent in dimension and voxel size to the --t1 image. Default location: $SUBJECTS_DIR/$sid/mri/aparc.DKTatlas+aseg.deep.mgz * `--3T`: for Talairach registration, use the 3T atlas instead of the 1.5T atlas (which is used if the flag is not provided). This gives better (more consistent with FreeSurfer) ICV estimates (eTIV) for 3T and better Talairach registration matrices, but has little impact on standard volume or surface stats. @@ -36,94 +36,99 @@ List them by running the following command: * `--parallel`: Run both hemispheres in parallel * `--threads`: Set openMP and ITK threads to -### Other -* `--py`: Command for python, used in both pipelines. Default: python3.8 +## Other +* `--py`: Command for python, used in both pipelines. Default: python3.10 * `--no_surfreg`: Skip surface registration with FreeSurfer (if only stats are needed) * `--fs_license`: Path to FreeSurfer license key file. Register at https://surfer.nmr.mgh.harvard.edu/registration.html for free to obtain it if you do not have FreeSurfer installed already For more details see `--help`. -### Example 1: recon-surf inside Docker +### Example 1: Surface module inside Docker -Docker can be again be used to simplify the installation (no FreeSurfer on system required). +Docker can be used to simplify the installation (no FreeSurfer on system required). Given you already ran the segmentation pipeline, and want to just run the surface pipeline on top of it (i.e. on a different cluster), the following command can be used: ```bash -# 1. Build the singularity image (if it does not exist) -docker pull deepmi/fastsurfer:surfonly-cpu-v2.0.0 +# 1. Pull the docker image (if it does not exist locally) +docker pull deepmi/fastsurfer:cpu-v?.?.? # 2. Run command docker run -v /home/user/my_fastsurfer_analysis:/output \ -v /home/user/my_fs_license_dir:/fs_license \ - --rm --user $(id -u):$(id -g) deepmi/fastsurfer:cpu-v2.2.0 \ + --rm --user $(id -u):$(id -g) deepmi/fastsurfer:cpu-v?.?.? \ --fs_license /fs_license/license.txt \ - --sid subjectX --sd /output --3T + --sid subjectX --sd /output --3T --surf_only ``` +Check [Dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags) to find out the latest release version and replace the "?". 
Docker Flags: * The `-v` commands mount your output, and directory with the FreeSurfer license file into the Docker container. Inside the container these are visible under the name following the colon (in this case /output and /fs_license). -As the --t1 and --asegdkt_segfile flag are not set, a subfolder within the target directory named after the subject (here: subjectX) needs to exist and contain t1-weighted conformed image, -mask and segmentations (as output by our FastSurfer segmentation networks, i.e. under /home/user/my_fastsurfeer_analysis/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz, mask.mgz, and orig.mgz)). The directory will then be populated with the FreeSurfer file structure, including surfaces, statistics -and labels file (equivalent to a FreeSurfer recon-all run). - +This essentially calls the run_fastsurfer.sh script as entry point and starts only the surface module. It assumes that this case `subjectX` exists already and that the output files of the segmentation module are +available in the subjectX/mri directory (e.g. `/home/user/my_fastsurfeer_analysis/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz`, `mask.mgz`, `orig.mgz` etc.). The directory will then be populated with the FreeSurfer file structure, including surfaces, statistics and labels file (equivalent to a FreeSurfer recon-all run). It is possible to modify the entry point during the docker call and directly run recon-surf.sh, as we will demonstrate with the Singularity example next. -### Example 2: recon-surf inside Singularity -Singularity can be used as for the full pipeline. Given you already ran the segmentation pipeline, and want to just run +## Example 2: recon-surf inside Singularity +Singularity can be used instead of Docker to run the full pipeline or individual modules. In this example we change the entrypoint to `recon-surf.sh` instead of the standard +`run_fastsurfer.sh`. Usually it is recomended to just use the default, so this is for expert users who may want to try out specific flags that are not passed to the wrapper. +Given you already ran the segmentation pipeline, and want to just run the surface pipeline on top of it (i.e. on a different cluster), the following command can be used: ```bash # 1. Build the singularity image (if it does not exist) -singularity build fastsurfer-reconsurf.sif docker://deepmi/fastsurfer:surfonly-cpu-v2.0.0 +singularity build fastsurfer-cpu-v?.?.?.sif docker://deepmi/fastsurfer:cpu-v?.?.? # 2. Run command singularity exec --no-home \ -B /home/user/my_fastsurfer_analysis:/output \ -B /home/user/my_fs_license_dir:/fs_license \ - ./fastsurfer.sif \ + ./fastsurfer-cpu-?.?.?.sif \ /fastsurfer/recon_surf/recon-surf.sh \ --fs_license /fs_license/license.txt \ - --sid subjectX --sd /output --3T + --sid subjectX --sd /output --3T \ + --t1 /subjectX/mri/orig.mgz \ + --asegdkt_segfile /subjectX/mri/aparc.DKTatlas+aseg.deep.mgz ``` +Check [Dockerhub](https://hub.docker.com/r/deepmi/fastsurfer/tags) to find out the latest release version and replace the "?". -#### Singularity Flags: +### Singularity Flags: * The `-B` commands mount your output, and directory with the FreeSurfer license file into the Singularity container. Inside the container these are visible under the name following the colon (in this case /data, /output, and /fs_license). 
* The `--no-home` command disables the automatic mount of the users home directory (see [Best Practice](../Singularity/README.md#mounting-home)) -As the --t1 and --asegdkt_segfile flag are not set, a subfolder within the target directory named after the subject (here: subjectX) needs to exist and contain t1-weighted conformed image, -mask and segmentations (as output by our FastSurfer segmentation networks, i.e. under /home/user/my_fastsurfeer_analysis/subjectX/mri/aparc.DKTatlas+aseg.deep.mgz, mask.mgz, and orig.mgz)). The directory will then be populated with the FreeSurfer file structure, including surfaces, statistics +The `--t1` and `--asegdkt_segfile` flags point to the already existing conformed T1 input and segmentation from the segmentation module. Also other files from that pipeline +will be reused (e.g. the `mask.mgz`, `orig_nu.mgz`). The subject directory will then be populated with the FreeSurfer file structure, including surfaces, statistics and labels file (equivalent to a FreeSurfer recon-all run). -### Example 3: Native installation - recon-surf on a single subject (subjectX) +## Example 3: Native installation - recon-surf on a single subject (subjectX) -Given you want to analyze data for subjectX which is stored on your computer under /home/user/my_mri_data/subjectX/orig.mgz, +Given you want to analyze data for subjectX which is stored on your computer under `/home/user/my_mri_data/subjectX/orig.mgz`, run the following command from the console (do not forget to source FreeSurfer!): ```bash # Source FreeSurfer -export FREESURFER_HOME=/path/to/freesurfer/fs732 +export FREESURFER_HOME=/path/to/freesurfer source $FREESURFER_HOME/SetUpFreeSurfer.sh # Define data directory datadir=/home/user/my_mri_data segdir=/home/user/my_segmentation_data -targetdir=/home/user/my_recon_surf_output # equivalent to FreeSurfer's SUBJECT_DIR +targetdir=/home/user/my_recon_surf_output # equivalent to FreeSurfer's SUBJECTS_DIR # Run recon-surf ./recon-surf.sh --sid subjectX \ --sd $targetdir \ - --py python3.8 \ - --3T - + --py python3.10 \ + --3T \ + --t1 /subjectX/mri/orig.mgz \ + --asegdkt_segfile /subjectX/mri/aparc.DKTatlas+aseg.deep.mgz ``` -As the --t1 and --asegdkt_segfile flag are not set, a subfolder within the target directory named after the subject (here: subjectX) needs to exist and contain t1-weighted conformed image, -mask and segmentations (as output by our FastSurfer segmentation networks, i.e. under `/home/user/my_fastsurfeer_analysis/subjectX/mri/orig.mgz`, `.../aparc.DKTatlas+aseg.deep.mgz`, and `.../mask.mgz`). The directory will then be populated with the FreeSurfer file structure, including surfaces, statistics and labels file (equivalent to a FreeSurfer recon-all run). -The script will also generate a bias-field corrected image at `/home/user/my_fastsurfeer_analysis/subjectX/mri/orig_nu.mgz`, if this did not already exist. +The `--t1` and `--asegdkt_segfile` flags point to the already existing conformed T1 input and segmentation from the segmentation module. Also other files from that pipeline +will be reused (e.g. the `mask.mgz`, `orig_nu.mgz`, i.e. under `/home/user/my_fastsurfeer_analysis/subjectX/mri/mask.mgz`). The `subjectX` directory will then be populated with the FreeSurfer file structure, including surfaces, statistics and labels file (equivalent to a FreeSurfer recon-all run). +The script will generate a bias-field corrected image at `/home/user/my_fastsurfeer_analysis/subjectX/mri/orig_nu.mgz`, if this did not already exist. 
### Example 4: recon-surf on multiple subjects -Most of the recon_surf functionality can also be achieved by running `run_fastsurfer.sh` with the `--surf_only` flag. This means we can also use the `brun_fastsurfer.sh` command with `--surf_only` to achieve similar results (see also [Example 4 in the main README](../README.md#example-4-fastsurfer-on-multiple-subjects). +Most of the recon_surf functionality can also be achieved by running `run_fastsurfer.sh` with the `--surf_only` flag. This means we can also use the `brun_fastsurfer.sh` command with `--surf_only` to achieve similar results (see also [Example 4](../doc/overview/EXAMPLES.md#example-4-fastsurfer-on-multiple-subjects). There are however some small differences to be aware of: 1. the path to and the filename of the t1 image in the subject_list file is optional. @@ -147,7 +152,7 @@ singularity exec --no-home \ ``` A dedicated subfolder will be used for each subject within the target directory. -As the --t1 and --asegdkt_segfile flags are not set, a subfolder within the target directory named after each subject (`$subject_id`) needs to exist and contain T1-weighted conformed image, +As the `--t1` and `--asegdkt_segfile` flags are not set, a subfolder within the target directory named after each subject (`$subject_id`) needs to exist and contain T1-weighted conformed image, mask and segmentations (as output by our FastSurfer segmentation networks, i.e. under `$subjects_dir/$subject_id/mri/orig.mgz`, `$subjects_dir/$subject_id/mri/mask.mgz`, and `$subjects_dir/$subject_id/mri/aparc.DKTatlas+aseg.deep.mgz`, respectively). The directory will then be populated with the FreeSurfer file structure, including surfaces, statistics and labels file (equivalent to a FreeSurfer recon-all run). The script will also generate a bias-field corrected image at `$subjects_dir/$subject_id/mri/orig_nu.mgz`, if this did not already exist. @@ -156,46 +161,46 @@ The logs of individual subject's processing can be found in `$subjects_dir/$subj # Manual Edits -### Brainmask Edits +## Brainmask Edits -Currently, FastSurfer has only very limited functionality for manual edits due to missing entrypoints into the recon-surf script. Starting with FastSurfer v2.0.0 one frequently requested edit type (brainmask editing) is now possible, as the initial mask is created in the first segmentation stage. By running segmentation and surface processing in two steps, the mask can be edited in-between. +Currently, FastSurfer has only very limited functionality for manual edits due to missing entrypoints into the recon-surf script. Starting with FastSurfer v2 one frequently requested edit type (brainmask editing) is now possible, as the initial mask is created in the first segmentation stage. By running segmentation and surface processing in two steps, the mask can be edited in-between. For a **Docker setup** one can: 1. Run segmentation only: -```bash -docker run --gpus=all --rm --name $CONTAINER_NAME \ - -v $PATH_TO_IMAGE_DIR:$IMAGE_DIR \ - -v $PATH_TO_OUTPUT_DIR:$OUTPUT_DIR \ - --user $UID:$GID deepmi/fastsurfer:gpu-v2.0.0 \ - --t1 $IMAGE_DIR/input.mgz \ - --sd $OUTPUT_DIR \ - --sid $SUBJECT_ID \ - --seg_only -``` + ```bash + docker run --gpus=all --rm --name $CONTAINER_NAME \ + -v $PATH_TO_IMAGE_DIR:$IMAGE_DIR \ + -v $PATH_TO_OUTPUT_DIR:$OUTPUT_DIR \ + --user $UID:$GID deepmi/fastsurfer:gpu-v?.?.? \ + --t1 $IMAGE_DIR/input.mgz \ + --sd $OUTPUT_DIR \ + --sid $SUBJECT_ID \ + --seg_only + ``` 2. Modify the ```$PATH_TO_OUTPUT_DIR/$SUBJECT_ID/mri/mask.mgz``` file as required. 3. 
Run the following Docker command to run the surface processing pipeline (remove `--3T` if you are working with 1.5T data): -```bash -docker run --rm --name $CONTAINER_NAME \ - -v $PATH_TO_OUTPUT_DIR:$OUTPUT_DIR \ - -v $PATH_TO_FS_LICENSE_DIR:$FS_LICENSE_DIR \ - --user $UID:$GID deepmi/fastsurfer:gpu-v2.0.0 \ - --sid $SUBJECT_ID \ - --sd $OUTPUT_DIR/$SUBJECT_ID \ - --surf_only --3T \ - --fs_license $FS_LICENSE_DIR/license_file -``` + ```bash + docker run --rm --name $CONTAINER_NAME \ + -v $PATH_TO_OUTPUT_DIR:$OUTPUT_DIR \ + -v $PATH_TO_FS_LICENSE_DIR:$FS_LICENSE_DIR \ + --user $UID:$GID deepmi/fastsurfer:gpu-v?.?.? \ + --sid $SUBJECT_ID \ + --sd $OUTPUT_DIR/$SUBJECT_ID \ + --surf_only --3T \ + --fs_license $FS_LICENSE_DIR/license_file + ``` For a **local install** you can similarly: -1. Go to the FastSurfer directory, source FreeSurfer 7.3.2 and run the segmentation step: -```bash -cd $FASTSURFER_HOME -source $FREESURFER_HOME/SetUpFreeSurfer.sh -./run_fastsurfer.sh --t1 $IMAGE_DIR/input.mgz --sd $OUTPUT_DIR --sid $SUBJECT_ID --seg_only -``` +1. Go to the FastSurfer directory, source FreeSurfer and run the segmentation step: + ```bash + cd $FASTSURFER_HOME + source $FREESURFER_HOME/SetUpFreeSurfer.sh + ./run_fastsurfer.sh --t1 $IMAGE_DIR/input.mgz --sd $OUTPUT_DIR --sid $SUBJECT_ID --seg_only + ``` 2. Modify the ```$OUTPUT_DIR/$SUBJECT_ID/mri/mask.mgz``` file. 3. Run the surface pipeline (remove `--3T` if you are working with 1.5T data): -```bash -./run_fastsurfer.sh --sd $OUTPUT_DIR --sid $SUBJECT_ID --fs_license $FS_LICENSE_DIR/license_file --surf_only --3T -``` + ```bash + ./run_fastsurfer.sh --sd $OUTPUT_DIR --sid $SUBJECT_ID --fs_license $FS_LICENSE_DIR/license_file --surf_only --3T + ``` diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index 0824dc2c..e94a8f1c 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -28,18 +28,18 @@ def rmat2angles(R: npt.NDArray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Extract rotation angles (alpha,beta,gamma) in FreeSurfer format (mris_register) from a rotation matrix. + """ + Extract rotation angles (alpha,beta,gamma) in FreeSurfer format (mris_register) from a rotation matrix. Parameters ---------- R : npt.NDArray - Rotation matrix + Rotation matrix. Returns ------- alpha, beta, gamma - Rotation degree - + Rotation degree. """ alpha = np.degrees(-np.arctan2(R[1, 0], R[0, 0])) beta = np.degrees(np.arcsin(R[2, 0])) @@ -48,22 +48,22 @@ def rmat2angles(R: npt.NDArray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: def angles2rmat(alpha: float, beta: float, gamma: float) -> np.array: - """Convert FreeSurfer angles (alpha,beta,gamma) in degrees to a rotation matrix. + """ + Convert FreeSurfer angles (alpha,beta,gamma) in degrees to a rotation matrix. Parameters ---------- alpha : float - FreeSurfer angle in degrees + FreeSurfer angle in degrees. beta : float - FreeSurfer angle in degrees + FreeSurfer angle in degrees. gamma : float - FreeSurfer angle in degrees + FreeSurfer angle in degrees. Returns ------- R - rotation angles - + Rotation angles. """ sa = np.sin(np.radians(alpha)) sb = np.sin(np.radians(beta)) @@ -82,25 +82,25 @@ def angles2rmat(alpha: float, beta: float, gamma: float) -> np.array: def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: - """Find the rotation matrix. + """ + Find the rotation matrix. Parameters ---------- p_mov : npt.NDArray - [MISSING] + Source points. p_dst : npt.NDArray - [MISSING] + Destination points. 
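A tiny sanity check for the two angle helpers above (the angles are arbitrary); `rmat2angles` is documented as the inverse of `angles2rmat`, so the round trip is expected to return the original FreeSurfer angles:

```python
import numpy as np

from align_points import angles2rmat, rmat2angles  # assuming recon_surf/ is on the path

R = angles2rmat(10.0, 20.0, 30.0)
print(np.allclose(R.T @ R, np.eye(3)))  # -> True, R is an orthonormal rotation matrix
print(rmat2angles(R))                   # expected to recover roughly (10.0, 20.0, 30.0)
```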
Returns ------- R - Rotation matrix + Rotation matrix. Raises ------ ValueError - Shape of points should be identical - + Shape of points should be identical. """ if p_mov.shape != p_dst.shape: raise ValueError( @@ -131,20 +131,20 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: - """[MISSING]. + """ + Find rigid transformation matrix between two point sets. Parameters ---------- p_mov : npt.NDArray - [MISSING] + Source points. p_dst : npt.NDArray - [MISSING] + Destination points. Returns ------- T - Homogeneous transformation matrix - + Homogeneous transformation matrix. """ if p_mov.shape != p_dst.shape: raise ValueError( @@ -175,30 +175,28 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: return T def find_affine(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: - """Find affine by least squares solution of overdetermined system. + """ + Find affine by least squares solution of overdetermined system. Assuming we have more than 4 point pairs Parameters ---------- p_mov : npt.NDArray - [MISSING] + The source points. p_dst : npt.NDArray - [MISSING] + The destination points. Returns ------- T - Affine transformation matrix + Affine transformation matrix. Raises ------ ValueError - Shape of points should be identical - + Shape of points should be identical. """ - from scipy.linalg import pinv - if p_mov.shape != p_dst.shape: raise ValueError( "Shape of points should be identical, but mov = {}, dst = {} expecting Nx3".format( diff --git a/recon_surf/align_seg.py b/recon_surf/align_seg.py index 1866d1d9..98b5e5ed 100755 --- a/recon_surf/align_seg.py +++ b/recon_surf/align_seg.py @@ -39,7 +39,7 @@ Dependencies: - Python 3.8 + Python 3.8+ numpy SimpleITK https://simpleitk.org/ (v2.1.1) @@ -65,13 +65,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = optparse.OptionParser( version="$Id:align_seg.py,v 1.0 2022/08/24 21:22:08 mreuter Exp $", @@ -100,24 +100,24 @@ def options_parse(): def get_seg_centroids(seg_mov: sitk.Image, seg_dst: sitk.Image, label_ids: Optional[npt.NDArray[int]] = []) -> Tuple[npt.NDArray, npt.NDArray]: - """Extract the centroids of the segmentation labels for mov and dst in RAS coords. + """ + Extract the centroids of the segmentation labels for mov and dst in RAS coords. Parameters ---------- seg_mov : sitk.Image - Source segmentation image + Source segmentation image. seg_dst : sitk.Image - Target segmentation image + Target segmentation image. label_ids : Optional[npt.NDArray[int]] - List of label ids to extract (Default value = []) + List of label ids to extract (Default value = []). Returns ------- centroids_mov - List of centroids of source segmentation + List of centroids of source segmentation. centroids_dst - List of centroids of target segmentation - + List of centroids of target segmentation. """ if not label_ids: # use all joint labels except -1 and 0: @@ -159,7 +159,8 @@ def align_seg_centroids( label_ids: Optional[npt.NDArray[int]] = [], affine: bool = False ) -> npt.NDArray: - """Align the segmentations based on label centroids (rigid is default). + """ + Align the segmentations based on label centroids (rigid is default). Parameters ---------- @@ -177,7 +178,6 @@ def align_seg_centroids( ------- T Aligned centroids RAS2RAS transform. 
- """ # get centroids of each label in image centroids_mov, centroids_dst = get_seg_centroids(seg_mov, seg_dst, label_ids) @@ -191,12 +191,13 @@ def align_seg_centroids( def get_vox2ras(img:sitk.Image) -> npt.NDArray: - """Extract voxel to RAS (affine) from sitk image. + """ + Extract voxel to RAS (affine) from sitk image. Parameters ---------- seg : sitk.Image - sitk Image. + Sitk Image. Returns ------- @@ -219,7 +220,8 @@ def get_vox2ras(img:sitk.Image) -> npt.NDArray: return vox2ras def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDArray: - """Registrate Left - right (make upright). + """ + Registrate Left - right (make upright). Register cortial lables @@ -235,7 +237,6 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA ------- Tsqrt RAS2RAS transformation matrix for registration. - """ lhids = np.array( [ @@ -307,7 +308,7 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA 2035, ] ) - l = lhids.size + ls = lhids.size label_stats = sitk.LabelShapeStatisticsImageFilter() label_stats.Execute(seg) centroids = np.empty([2 * lhids.size, 3]) @@ -341,7 +342,7 @@ def align_flipped(seg: sitk.Image, mid_slice: Optional[float] = None) -> npt.NDA # now right is left and left is right (re-order) centroids_flipped = np.concatenate( - (centroids_flipped[l::, :], centroids_flipped[0:l, :]) + (centroids_flipped[ls::, :], centroids_flipped[0:ls, :]) ) # register centroids to LR-flipped versions T = align.find_rigid(centroids, centroids_flipped) diff --git a/recon_surf/create_annotation.py b/recon_surf/create_annotation.py index fd2847bf..4e8345ba 100755 --- a/recon_surf/create_annotation.py +++ b/recon_surf/create_annotation.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# This is test line just to check rebase commit # IMPORTS import optparse @@ -42,7 +42,7 @@ Dependencies: - Python 3.8 + Python 3.8+ numpy, nibabel, sklearn @@ -90,13 +90,13 @@ def options_parse(): - """Command line option parser. + """ + Command line option parser. Returns ------- options - object holding options - + Object holding options. """ parser = optparse.OptionParser( version="$Id:create_annotation.py,v 1.0 2022/08/24 21:22:08 mreuter Exp $", @@ -156,45 +156,45 @@ def map_multiple_labels( out_dir: Optional[str] = None, stop_missing: bool = True ) -> Tuple[npt.ArrayLike, npt.ArrayLike]: - """Map a list of labels from one surface (e.g. fsavaerage sphere.reg) to another. + """ + Map a list of labels from one surface (e.g. fsavaerage sphere.reg) to another. Labels are just names without hemisphere or path, - which are passed via hemi, src_dir, out_dir) + which are passed via hemi, src_dir, out_dir). Parameters ---------- hemi : str - "lh" or "rh" for reading labels + "lh" or "rh" for reading labels. src_dir : str - director of the source file + Director of the source file. src_labels : npt.ArrayLike - List of labels + List of labels. src_sphere_name : str - filename of source sphere + Filename of source sphere. trg_sphere_name : str - filename of target sphere + Filename of target sphere. trg_white_name : str - filename of target white + Filename of target white. trg_sid : str - target subject id + Target subject id. out_dir : Optional[str] - directory for output, defaults to None + Directory for output, defaults to None. 
stop_missing : bool - determines whether to stop on a missing src label file, or continue - with a warning. Defaults to True + Determines whether to stop on a missing src label file, or continue + with a warning. Defaults to True. Returns ------- all_labels - mapped labels + Mapped labels. all_values - values of mapped labels + Values of mapped labels. Raises ------ ValueError - Label file missing - + Label file missing. """ # get reverse mapping (trg->src) for sampling rev_mapping, _, _ = getSurfCorrespondence(trg_sphere_name, src_sphere_name) @@ -214,7 +214,7 @@ def map_multiple_labels( # map label from src to target if os.path.exists(src_label_name): # print("Mapping label {}.{} ...".format(hemi,l_name)) - l, v = mapSurfLabel( + ll, vv = mapSurfLabel( src_label_name, out_label_name, trg_white, trg_sid, rev_mapping ) else: @@ -224,10 +224,10 @@ def map_multiple_labels( ) else: print("\nWARNING: Label file missing {}\n".format(src_label_name)) - l = [] - v = [] - all_labels.append(l) - all_values.append(v) + ll = [] + vv = [] + all_labels.append(ll) + all_values.append(vv) return all_labels, all_values @@ -236,37 +236,37 @@ def read_multiple_labels( input_dir: str, label_names: npt.ArrayLike ) -> Tuple[ List[npt.NDArray], List[npt.NDArray]]: - """Read multiple label files from input_dir. + """ + Read multiple label files from input_dir. Parameters ---------- hemi : str - "lh" or "rh" for reading labels + "lh" or "rh" for reading labels. input_dir : str - director of the source + Director of the source. label_names : npt.ArrayLike - List of labels + List of labels. Returns ------- all_labels - read labels + Read labels. all_values - values of read labels - + Values of read labels. """ all_labels = [] all_values = [] for l_name in label_names: label_file = os.path.join(input_dir, hemi + "." + l_name + ".label") if os.path.exists(label_file): - l, v = fs.read_label(label_file, read_scalars=True) + ll, vv = fs.read_label(label_file, read_scalars=True) else: print("\nWARNING: Label file missing {}\n".format(label_file)) - l = [] - v = [] - all_labels.append(l) - all_values.append(v) + ll = [] + vv = [] + all_labels.append(ll) + all_values.append(vv) return all_labels, all_values @@ -275,7 +275,8 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, col_ids: npt.ArrayLike, trg_white: Union[str, npt.NDArray], cortex_label_name: Optional[str] = None ) -> Tuple[npt.NDArray, npt.NDArray]: - """Create an annotation from multiple labels. + """ + Create an annotation from multiple labels. Here we also consider the label values and overwrite existing labels if values of current are larger (or equal, so the order of the labels matters). @@ -284,23 +285,22 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, Parameters ---------- all_labels : npt.ArrayLike - List of all Labels + List of all Labels. all_values : npt.ArrayLike - List of all values + List of all values. col_ids : npt.ArrayLike - List of col ids + List of col ids. trg_white : Union[str, npt.NDArray] - target file of white + Target file of white. cortex_label_name : Optional[str] - Path to the cortex label file. Defaults to None + Path to the cortex label file. Defaults to None. Returns ------- annot_ids - Ids of build Annotations + Ids of build Annotations. annot_vals - Values of build Annotations - + Values of build Annotations. 
""" # create annot from a bunch of labels (and values) if isinstance(trg_white, str): @@ -339,22 +339,22 @@ def build_annot(all_labels: npt.ArrayLike, all_values: npt.ArrayLike, def read_colortable(colortab_name: str) -> Tuple[npt.ArrayLike, List[str], npt.ArrayLike]: - """Read the colortable of given name. + """ + Read the colortable of given name. Parameters ---------- colortab_name : str - Path and Name of the colortable file + Path and Name of the colortable file. Returns ------- ids - List of ids + List of ids. names - List of names + List of names. colors - List of colors corresponding to ids and names - + List of colors corresponding to ids and names. """ colortab = np.genfromtxt(colortab_name, dtype="i8", usecols=(0, 2, 3, 4, 5)) ids = colortab[:, 0] @@ -371,25 +371,25 @@ def write_annot( out_annot: str, append: Union[None, str] = "" ) -> None: - """Combine the colortable with the annotations ids to write an annotation file. + """ + Combine the colortable with the annotations ids to write an annotation file. The annotation file contains colortable information Care needs to be taken that the colortable file has the same number - and order of labels as specified in the label_names list + and order of labels as specified in the label_names list. Parameters ---------- annot_ids : npt.ArrayLike - List of annotation ids + List of annotation ids. label_names : npt.ArrayLike - list of label names + List of label names. colortab_name : str - Path and name of colortable file + Path and name of colortable file. out_annot : str - Path and name of output annotation file + Path and name of output annotation file. append : Union[None, str] - String to append to colour name. Defaults to "" - + String to append to colour name. Defaults to "". """ # colortab_name="colortable_BA.txt" col_ids, col_names, col_colors = read_colortable(colortab_name) @@ -413,14 +413,15 @@ def write_annot( def create_annotation(options, verbose: bool = True) -> None: - """Map (if required), build and write annotation. + """ + Map (if required), build and write annotation. (Main function) Parameters ---------- - options : - object holding options + options : Any + Object holding options hemi: "lh" or "rh" for reading labels colortab: colortab with label ids, names and colors labeldir: dir where to find the label files (when reading) @@ -429,10 +430,9 @@ def create_annotation(options, verbose: bool = True) -> None: cortex: optional path to hemi.cortex for optional masking of annotation to only cortex append: optional, e.g. ".thresh" can be appended to label names (I/O) for exvivo FS labels srcsphere: optional, when mapping: path to src sphere.reg - trgsphere: optional, when mapping: path to trg sphere.reg + trgsphere: optional, when mapping: path to trg sphere.reg. verbose : bool - True if options should be printed. Defaults to True - + True if options should be printed. Defaults to True. """ print() print("Map BA Labels Parameters:") diff --git a/recon_surf/fs_balabels.py b/recon_surf/fs_balabels.py index a13e9e44..0fbdbb7b 100755 --- a/recon_surf/fs_balabels.py +++ b/recon_surf/fs_balabels.py @@ -23,7 +23,6 @@ from typing import Tuple, List import numpy as np import sys -import nibabel.freesurfer.io as fs from create_annotation import ( map_multiple_labels, read_colortable, @@ -42,7 +41,7 @@ --fsaverage --hemi Dependencies: - Python 3.8 + Python 3.8+ numpy, nibabel, sklearn Also FreeSurfer v7.3.2 is needed @@ -74,13 +73,13 @@ def options_parse(): - """Command line option parser. 
+ """ + Create a command line interface and return command line options. Returns ------- - options - object holding options - + options : argparse.Namespace + Namespace object holding options. """ parser = optparse.OptionParser( version="$Id:fs_balabels.py,v 1.0 2022/08/24 21:22:08 mreuter Exp $", @@ -110,27 +109,27 @@ def read_colortables( colappend: List[str], drop_unknown: bool = True ) -> Tuple[List, List, List]: - """Read multiple colortables and appends extensions, drops unknown by default. + """ + Read multiple colortables and appends extensions, drops unknown by default. Parameters ---------- colnames : List[str] - List of color-names + List of color-names. colappend : List[str] - List of appends for names + List of appends for names. drop_unknown : bool True if unknown colors should be dropped. - Defaults to True + Defaults to True. Returns ------- all_ids - List of all ids + List of all ids. all_names - List of all names + List of all names. all_cols - List of all colors - + List of all colors. """ pos = 0 all_names = [] diff --git a/recon_surf/fs_time b/recon_surf/fs_time index 8319946f..62b3d712 100755 --- a/recon_surf/fs_time +++ b/recon_surf/fs_time @@ -1,152 +1,34 @@ -#!/bin/tcsh -f +#!/bin/bash # fs_time -set VERSION = '$Id: fs_time,v 1.11 2016/02/16 17:17:20 zkaufman Exp $'; -set outfile = (); -set key = ("@#@FSTIME "); +VERSION='$Id: fs_time,v 1.0 2024/03/08 15:12:00 kueglerd Exp $' +outfile="" +key="@#@FSTIME " +cmd=() +verbose=0 -if($?FSTIME_LOAD == 0) then +if [[ -z "$FSTIME_LOAD" ]] +then # Turn on by default - setenv FSTIME_LOAD 1 -endif - -set inputargs = ($argv); -set PrintHelp = 0; -if($#argv == 0) goto usage_exit; -set n = `echo $argv | grep -e -help | wc -l` -if($n != 0) then - set PrintHelp = 1; - goto usage_exit; -endif -set n = `echo $argv | grep -e -version | wc -l` -if($n != 0) then - echo $VERSION - exit 0; -endif - -source $FREESURFER_HOME/sources.csh - -goto parse_args; -parse_args_return: -goto check_params; -check_params_return: - -@ nargs = $#argv - 1 - -if($FSTIME_LOAD) then - set upt = `uptime | sed 's/,/ /g'`; - @ a = $#upt - 2 - @ b = $#upt - 1 - #echo "@#@FSLOADPRE $dt $argv[1] N $nargs $upt[$a] $upt[$b] $upt[$#upt]" - set upt = "L $upt[$a] $upt[$b] $upt[$#upt]" -else - set upt = "" -endif - -set dt = `date '+%Y:%m:%d:%H:%M:%S'` -set fmt = "$key $dt $argv[1] N $nargs e %e S %S U %U P %P M %M F %F R %R W %W c %c w %w I %I O %O $upt" - -set cmd = /usr/bin/time -if($#outfile) set cmd = ($cmd -o $outfile) -$cmd -f "$fmt" $argv -set st = $status -if($#outfile) cat $outfile - -if($FSTIME_LOAD) then - set dt = `date '+%Y:%m:%d:%H:%M:%S'` - set upt = `uptime | sed 's/,/ /g'`; - @ a = $#upt - 2 - @ b = $#upt - 1 - echo "@#@FSLOADPOST $dt $argv[1] N $nargs $upt[$a] $upt[$b] $upt[$#upt]" -endif - -exit $st - -############################################### - -############--------------################## -parse_args: -set cmdline = ($argv); -while( $#argv != 0 ) - - set flag = $argv[1]; shift; - - switch($flag) - - case "-o": - if($#argv < 1) goto arg1err; - set outfile = $argv[1]; shift; - breaksw - - case "-k": - if($#argv < 1) goto arg1err; - set key = $argv[1]; shift; - breaksw - - case "-l": - case "-load": - setenv FSTIME_LOAD 1 - breaksw - case "-no-load": - setenv FSTIME_LOAD 0 - breaksw - - case "-debug": - set verbose = 1; - set echo = 1; - breaksw - - default: - # must be at the start of the command to run - # put item back into the list - set argv = ($flag $argv) - break; - breaksw - endsw - -end - -goto parse_args_return; 
-############--------------################## - -############--------------################## -check_params: - -if(! -e /usr/bin/time) then - echo "ERROR: cannot find /usr/bin/time" - exit 1; -endif - -if($#argv == 0) then - goto usage_exit; -endif - -goto check_params_return; -############--------------################## - -############--------------################## -arg1err: - echo "ERROR: flag $flag requires one argument" - exit 1 - -############--------------################## -usage_exit: - echo "" - echo "fs_time [options] command args" - echo " options:" - echo " -o outputfile : save resource info into outputfile" - echo " -k key" - echo " -l : report on load averages as from uptime" + export FSTIME_LOAD=1 +fi - if(! $PrintHelp) exit 1; - echo $VERSION - cat $0 | awk 'BEGIN{prt=0}{if(prt) print $0; if($1 == "BEGINHELP") prt = 1 }' -exit 1; +function usage() +{ + cat << EOF -#---- Everything below here is printed out as part of help -----# -BEGINHELP - -This is a frontend for the unix /usr/bin/time program to keep track of +fs_time [options] command args + options: + -o outputfile : save resource info into outputfile + -k key + -l : report on load averages as from uptime +EOF +} + +function help() +{ + cat << EOF +This is a frontend for the unix /usr/bin/time program to keep track of resources used by a process. The basic usage is like that of time, ie, fs_time [options] command args @@ -169,16 +51,16 @@ Default fs_time Output (see also the manual page for /usr/bin/time): 7. U Total number of CPU-seconds that the process spent in user mode. 8. P Percentage of the CPU that this job got, computed as (U+S)/e. 9. M Maximum resident set size of the process during its lifetime, in Kbytes. -10. F Number of major page faults that occurred while the process was running. +10. F Number of major page faults that occurred while the process was running. These are faults where the page has to be read in from disk. 11. R Number of minor, or recoverable, page faults. These are faults for pages that are not valid but which have not yet been claimed by other virtual pages. Thus the data in the page is still valid but the system tables must be updated. 12. W Number of times the process was swapped out of main memory. -13. c Number of times the process was context-switched involuntarily - (because the time slice expired). -14. w Number of waits: times that the program was context-switched voluntarily, +13. c Number of times the process was context-switched involuntarily + (because the time slice expired). +14. w Number of waits: times that the program was context-switched voluntarily, for instance while waiting for an I/O operation to complete. 15. I Number of file system inputs by the process. 16. O Number of file system outputs by the process. @@ -187,14 +69,14 @@ Default fs_time Output (see also the manual page for /usr/bin/time): Example: fs_time -o resource.dat mri_convert orig.mgz myfile.mgz -mri_convert orig.mgz myfile.mgz +mri_convert orig.mgz myfile.mgz reading from orig.mgz... TR=2730.00, TE=3.44, TI=1000.00, flip angle=7.00 i_ras = (-1, 0, 0) j_ras = (2.38419e-07, 0, -1) k_ras = (-1.93715e-07, 1, 0) writing to myfile.mgz... -@#@FSTIME 2016:01:21:18:27:08 mri_convert N 2 e 2.20 S 0.05 U 1.64 P 77% M 23628 F 0 R 5504 W 0 c 7 w 3 I 0 O 20408 +@#@FSTIME 2016:01:21:18:27:08 mri_convert N 2 e 2.20 S 0.05 U 1.64 P 77% M 23628 F 0 R 5504 W 0 c 7 w 3 I 0 O 20408 The above command runs the mri_convert command with two arguments and produces the information about resources. 
It also creates a file @@ -224,5 +106,147 @@ If the env variable FSTIME_LOAD is set to 1, the output looks something like The 3 numbers are the system load averages for the past 1, 5, and 15 minutes as given by uptime. +EOF +} + +function arg1err() +{ + # param 1 : flag + echo "ERROR: flag $1 requires one argument" + exit 1 +} + +inputargs=("$@") +any_help=$(echo "$@" | grep -e -help) +if [[ -n "$any_help" ]] +then + usage + help + exit 0 +fi +any_version=$(echo "$@" | grep -e -version) +if [[ -n "$any_version" ]] +then + echo "$VERSION" + exit 0 +fi + +# sourcing FreeSurfer should not be needed at this point, +# source $FREESURFER_HOME/sources.sh + +cmdline=("$@") +while [[ $# != 0 ]] +do + + flag=$1 + shift + + case $flag in + -o) + if [[ "$#" -lt 1 ]] ; then arg1err "$flag"; fi + outfile=$1 + shift + ;; + -k) + if [[ "$#" -lt 1 ]] ; then arg1err "$flag" ; fi + key=$1 + shift + ;; + -l|-load) + export FSTIME_LOAD=1 + ;; + -no-load) + export FSTIME_LOAD=0 + ;; + -debug) + verbose=1 + ;; + *) + # must be at the start of the command to run + # put item back into the list + cmd=("$flag" "$@") + break + ;; + esac + +done + +if [[ "$verbose" == 1 ]] +then + echo "Parameters to fs_time:" + if [[ -n "$outfile" ]] ; then echo "-o $outfile" ; fi + if [[ "$key" != "@#@FSTIME " ]] ; then echo "-k $key" ; fi + echo "FSTIME_LOAD=$FSTIME_LOAD" + echo "command:" + echo "${cmd[*]}" + echo "" +fi + +# CHECK PARAMS +if [[ "${#cmd[@]}" == 0 ]] +then + usage + echo "ERROR: no command passed to execute" + exit 1 +fi + +if [[ ! -e /usr/bin/time ]] +then + echo "ERROR: cannot find /usr/bin/time" + exit 1 +fi + +command="${cmd[0]}" +npyargs=0 +# remove python from command +if [[ "$command" =~ python(3(.[0-9]+)?)?$ ]] +then + npyargs=1 + for c in "${cmd[@]:1}" ; do + if [[ ! "$c" =~ ^- ]] ; then command="$c" ; break ; fi + npyargs=$((npyargs + 1)) + done +fi +# remove $FASTSURFER_HOME path from command +command_short="${command:0:${#FASTSURFER_HOME}}" +if [[ -n "$FASTSURFER_HOME" ]] && [[ "$command_short" == "$FASTSURFER_HOME" ]] +then + command="${command:${#command_short}}" + while [[ "${command:0:1}" == "/" ]] ; do command="${command:1}" ; done +fi + +nargs=$((${#cmd[@]} - 1 - npyargs)) + +function make_string () +{ + # param 1 : key + # param 2 : command + # param 3 : num args + dt=$(date '+%Y:%m:%d:%H:%M:%S') + uptime_data_array=($(uptime | sed 's/,/ /g')) + echo "$1 $dt $2 N $3 ${uptime_data_array[-3]} ${uptime_data_array[-2]} ${uptime_data_array[-1]}" + export upt="L ${uptime_data_array[-3]} ${uptime_data_array[-2]} ${uptime_data_array[-1]}" +} + +if [[ "$FSTIME_LOAD" == 1 ]] +then + make_string "@#@FSLOADPRE" "$command" "$nargs" +else + upt="" +fi + +dt=$(date '+%Y:%m:%d:%H:%M:%S') +fmt="$key $dt $command N $nargs e %e S %S U %U P %P M %M F %F R %R W %W c %c w %w I %I O %O $upt" +timecmd=("/usr/bin/time") +if [[ -n "$outfile" ]] ; then timecmd=("${timecmd[@]}" -o "$outfile"); fi +"${timecmd[@]}" -f "$fmt" "${cmd[@]}" +st=$? 
+if [[ -n "$outfile" ]] ; then cat $outfile; fi +if [[ "$FSTIME_LOAD" == 1 ]] +then + make_string "@#@FSLOADPOST" "$command" "$nargs" +fi + +exit $st diff --git a/recon_surf/functions.sh b/recon_surf/functions.sh new file mode 100644 index 00000000..f256444e --- /dev/null +++ b/recon_surf/functions.sh @@ -0,0 +1,146 @@ + +# set the binpath variable +if [ -z "$FASTSURFER_HOME" ] +then + binpath="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )/" +else + binpath="$FASTSURFER_HOME/recon_surf/" +fi +export binpath + +# fs_time command from fs60, fs72 fails in parallel mode, use local one +# also check for failure (e.g. on mac it fails) +timecmd="${binpath}fs_time" +$timecmd echo testing &> /dev/null +if [ "${PIPESTATUS[0]}" -ne 0 ] ; then + echo "time command failing, not using time..." + timecmd="" +fi +export timecmd + +function RunIt() +{ + # parameters + # $1 : cmd (command to run) + # $2 : LF (log file) + # $3 : CMDF (command file) optional + # if CMDF is passed, then LF is ignored and cmd is echoed into CMDF and not run + local cmd=$1 + local LF=$2 + if [[ $# -eq 3 ]] + then + local CMDF=$3 + printf -v tmp %q "$cmd" + echo "echo $tmp" | tee -a $CMDF + echo "$timecmd $cmd" | tee -a $CMDF + echo "if [ \${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi" >> $CMDF + else + echo "$cmd" | tee -a "$LF" + $timecmd $cmd 2>&1 | tee -a "$LF" + if [ "${PIPESTATUS[0]}" -ne 0 ] ; then exit 1 ; fi + fi +} + +function RunBatchJobs() +{ +# parameters +# $1 : LF +# $2 ... : CMDFS + local LOG_FILE=$1 + # launch jobs found in command files (shift past first logfile arg). + # job output goes to a logfile named after the command file, which + # later gets appended to LOG_FILE + + echo + echo "RunBatchJobs: Logfile: $LOG_FILE" + + local PIDS=() + local LOGS=() + shift + local JOB + local LOG + for cmdf in "$@"; do + echo "RunBatchJobs: CMDF: $cmdf" + chmod u+x "$cmdf" + JOB="$cmdf" + LOG=$cmdf.log + echo "" >& "$LOG" + echo " $JOB" >> "$LOG" + echo "" >> "$LOG" + exec "$JOB" >> "$LOG" 2>&1 & + PIDS=("${PIDS[@]}" "$!") + LOGS=("${LOGS[@]}" "$LOG") + + done + # wait till all processes have finished + local PIDS_STATUS=() + for pid in "${PIDS[@]}"; do + echo "Waiting for PID $pid of (${PIDS[*]}) to complete..." + wait "$pid" + PIDS_STATUS=("${PIDS_STATUS[@]}" "$?") + done + # now append their logs to the main log file + for log in "${LOGS[@]}" + do + cat "$log" >> "$LOG_FILE" + rm -f "$log" + done + echo "PIDs (${PIDS[*]}) completed and logs appended." + # and check for failures + for pid_status in "${PIDS_STATUS[@]}" + do + if [ "$pid_status" != "0" ] ; then + exit 1 + fi + done +} + +function softlink_or_copy() +{ + # params + # 1: file + # 2: target + # 3: logfile + # 4: cmdf + local LF="$3" + local ln_cmd=(ln -sf "$1" "$2") + local cp_cmd=(cp "$1" "$2") + if [[ $# -eq 4 ]] + then + local CMDF=$4 + { + echo "echo $(echo_quoted "${ln_cmd[@]}")" + echo "$timecmd $(echo_quoted "${ln_cmd[@]}")" + echo "if [ \${PIPESTATUS[0]} -ne 0 ]" + echo "then" + echo " echo $(echo_quoted "${cp_cmd[@]}")" + echo " $timecmd $(echo_quoted "${cp_cmd[@]}")" + echo " if [ \${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi" + echo "fi" + } | tee -a "$CMDF" + else + { + echo_quoted "${ln_cmd[@]}" + $timecmd "${ln_cmd[@]}" 2>&1 + if [ "${PIPESTATUS[0]}" -ne 0 ] + then + echo_quoted "${cp_cmd[@]}" + $timecmd "${cp_cmd[@]}" 2>&1 + if [ "${PIPESTATUS[0]}" -ne 0 ] ; then exit 1 ; fi + fi + } | tee -a "$LF" + fi +} + +function echo_quoted() +{ + # params ... 
1-N + sep="" + for i in "$@" + do + if [[ "${i/ /}" != "$i" ]] ; then j="%q" ; else j="%s" ; fi + printf "%s$j" "$sep" "$i" + sep=" " + done + echo "" +} \ No newline at end of file diff --git a/recon_surf/image_io.py b/recon_surf/image_io.py index 3e3306bd..44e53100 100644 --- a/recon_surf/image_io.py +++ b/recon_surf/image_io.py @@ -181,7 +181,8 @@ def writeITKimage( filename: str, header: Optional[nib.freesurfer.mghformat.MGHHeader] = None ) -> None: - """[MISSING]. + """ + Write the given ITK image to a file. Parameters ---------- diff --git a/recon_surf/lta.py b/recon_surf/lta.py index dd38e1dc..3a6f32f1 100755 --- a/recon_surf/lta.py +++ b/recon_surf/lta.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -from typing import Dict # Copyright 2021 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn # @@ -24,43 +23,43 @@ def writeLTA( filename: str, T: npt.ArrayLike, src_fname: str, - src_header: Dict, + src_header: dict, dst_fname: str, - dst_header: Dict + dst_header: dict ) -> None: - """Write linear transform array info to a .lta file. + """ + Write linear transform array info to an .lta file. Parameters ---------- filename : str - File to write on + File to write to. T : npt.ArrayLike - Linear transform array to be saved + Linear transform array to be saved. src_fname : str - Source filename + Source filename. src_header : Dict - Source header + Source header. dst_fname : str - Destination filename + Destination filename. dst_header : Dict - Destination header + Destination header. Raises ------ ValueError - Header format missing field (Source or Destination) - + Header format missing field (Source or Destination). """ from datetime import datetime import getpass fields = ("dims", "delta", "Mdc", "Pxyz_c") for field in fields: - if not field in src_header: + if field not in src_header: raise ValueError( "writeLTA Error: src_header format missing field: {}".format(field) ) - if not field in dst_header: + if field not in dst_header: raise ValueError( "writeLTA Error: dst_header format missing field: {}".format(field) ) diff --git a/recon_surf/map_surf_label.py b/recon_surf/map_surf_label.py index 4b47c5da..52f0d317 100755 --- a/recon_surf/map_surf_label.py +++ b/recon_surf/map_surf_label.py @@ -36,7 +36,7 @@ Dependencies: - Python 3.8 + Python 3.8+ numpy, nibabel, sklearn @@ -59,13 +59,13 @@ def options_parse(): - """Command line option parser. + """ + Create a command line interface and return command line options. Returns ------- - options - object holding options - + options : argparse.Namespace + Namespace object holding options. """ parser = optparse.OptionParser( version="$Id:map_surf_label.py,v 1.0 2022/08/24 21:22:08 mreuter Exp $", @@ -99,30 +99,30 @@ def writeSurfLabel( values: npt.NDArray, surf: npt.NDArray ) -> None: - """Write a FS surface label file to filename (e.g. lh.labelname.label). + """ + Write a FS surface label file to filename (e.g. lh.labelname.label). Stores sid string in the header, then number of vertices and table of vertex index, RAS wm-surface coords (taken from surf) - and values (which can be zero) + and values (which can be zero). Parameters ---------- filename : str - File there surface label is written + File where the surface label is written. sid : str - Subject id + Subject id. label : npt.NDArray[str] - List of label names + List of label names. values : npt.NDArray - List of values + List of values. surf : npt.NDArray - Surface coordinations + Surface coordinates.
Raises ------ ValueError - Label and values should have same sizes - + If label and values are not the same size. """ if values is None: values = np.zeros(label.shape) @@ -151,32 +151,33 @@ def getSurfCorrespondence( trg_sphere: Union[str, Tuple, np.ndarray], tree: Optional[KDTree] = None ) -> Tuple[np.ndarray, np.ndarray, KDTree]: - """For each vertex in src_sphere find the closest vertex in trg_sphere. + """ + For each vertex in src_sphere find the closest vertex in trg_sphere. - Spheres are Nx3 arrays of coordinates on the sphere (usually R=100 FS format) - *_sphere can also be a file name of the sphere.reg files, then we load it. - The KDtree can be passed in cases where src moves around and trg stays fixed + src_sphere and trg_sphere are Nx3 arrays of coordinates on the sphere + (usually radius R=100 FS format). They can also be file names of the corresponding + sphere.reg files to be loaded from disk. The KDtree can optionally be passed in + cases where src moves around and trg stays fixed. Parameters ---------- src_sphere : Union[str, Tuple, np.ndarray] Either filepath (as str) or surface vertices - of source sphere + of source sphere. trg_sphere : Union[str, Tuple, np.ndarray] Either filepath (as str) or surface vertices - of target sphere + of target sphere. tree : Optional[KDTree] - Defaults to None + Defaults to None. Returns ------- mapping : np.ndarray - Surface mapping of the trg surface + Surface mapping of the trg surface. distances : np.ndarray - Surface distance of the trg surface + Surface distance of the trg surface. tree : KDTree - KDTree of the trg surface - + KDTree of the trg surface. """ # We can also work with file names instead of surface vertices if isinstance(src_sphere, str): @@ -204,36 +205,37 @@ def mapSurfLabel( trg_sid: str, rev_mapping: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: - """Map a label from src surface according to the correspondence. + """ + Map a label from src surface according to the correspondence. trg_surf is passed so that the label file will list the - correct coordinates (usually the white surface), can be vetrices or filename + correct coordinates (usually the white surface), can be vertices or a filename. Parameters ---------- src_label_name : str - Path to label file of source + Path to label file of source. out_label_name : str - Path to label file of output + Path to label file of output. trg_surf : Union[str, np.ndarray] - Numpy array of vertex coordinates or filepath to target surface + Numpy array of vertex coordinates or filepath to target surface. trg_sid : str - Subject id of the target subject (as stored in the output label file header) + Subject id of the target subject (as stored in the output label file header). rev_mapping : np.ndarray - a mapping from target to source, listing the corresponding src vertex for each vertex on the trg surface + A mapping from target to source, listing the corresponding src vertex + for each vertex on the trg surface. Returns ------- trg_label : np.ndarray - target labels + Target labels. trg_values : np.ndarray - target values + Target values. Raises ------ ValueError - Label and trg vertices should have same sizes - + If label and trg vertices are not of the same size.
""" print("Mapping label: {} ...".format(src_label_name)) src_label, src_values = fs.read_label(src_label_name, read_scalars=True) @@ -252,11 +254,8 @@ def mapSurfLabel( inside[src_label] = True values = np.zeros(smax) values[src_label] = src_values - inside_trg = inside[rev_mapping] trg_label = np.nonzero(inside[rev_mapping])[0] trg_values = values[rev_mapping[trg_label]] - # print(trg_values) - # print(trg_label.size) if out_label_name is not None: writeSurfLabel(out_label_name, trg_sid, trg_label, trg_values, trg_surf) return trg_label, trg_values diff --git a/recon_surf/paint_cc_into_pred.py b/recon_surf/paint_cc_into_pred.py index 9afcaac6..1d41ae57 100644 --- a/recon_surf/paint_cc_into_pred.py +++ b/recon_surf/paint_cc_into_pred.py @@ -14,11 +14,12 @@ # IMPORTS -import numpy as np -from numpy import typing as npt -import nibabel as nib + import sys import argparse +import numpy as np +import nibabel as nib +from numpy import typing as npt HELPTEXT = """ Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to @@ -30,7 +31,7 @@ Dependencies: - Python 3.8 + Python 3.8+ Nibabel to read and write FreeSurfer data http://nipy.org/nibabel/ @@ -42,13 +43,13 @@ def argument_parse(): - """Command line option parser. + """ + Create a command line interface and return command line options. Returns ------- - options - object holding options - + options : argparse.Namespace + Namespace object holding options. """ parser = argparse.ArgumentParser(usage=HELPTEXT) parser.add_argument( @@ -79,22 +80,22 @@ def argument_parse(): def paint_in_cc(pred: npt.ArrayLike, aseg_cc: npt.ArrayLike) -> npt.ArrayLike: - """Paint corpus callosum segmentation into prediction. + """ + Paint corpus callosum segmentation into aseg+dkt segmentation map. Note, that this function modifies the original array and does not create a copy. Parameters ---------- - pred : npt.ArrayLike - Deep-learning prediction + asegdkt : npt.ArrayLike + Deep-learning segmentation map. aseg_cc : npt.ArrayLike - Aseg segmentation with CC + Aseg segmentation with CC. Returns ------- - pred - Prediction with added CC - + asegdkt + Segmentation map with added CC. """ cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) pred[cc_mask] = aseg_cc[cc_mask] @@ -115,3 +116,6 @@ def paint_in_cc(pred: npt.ArrayLike, aseg_cc: npt.ArrayLike) -> npt.ArrayLike: pred_with_cc_fin.to_filename(options.output) sys.exit(0) + + +# TODO: Rename the file (paint_cc_into_asegdkt or similar) and functions. 
diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index 36d31cad..19feb9dc 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -30,6 +30,7 @@ DoParallel=0 # if 1, run hemispheres in parallel threads="1" # number of threads to use for running FastSurfer allow_root="" # flag for allowing execution as root user atlas3T="false" # flag to use/do not use the 3t atlas for talairach registration/etiv +segstats_legacy="false" # flag to enable segstats legacy mode # Dev flags default check_version=1 # Check for supported FreeSurfer version (terminate if not detected) @@ -38,25 +39,17 @@ hires_voxsize_threshold=0.999 # Threshold below which the hires options are pas if [ -z "$FASTSURFER_HOME" ] then - binpath="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )/" + binpath="$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )/" + FASTSURFER_HOME="$(cd -- "$(dirname "$binpath")" >/dev/null 2>&1 ; pwd -P )/" else binpath="$FASTSURFER_HOME/recon_surf/" fi -# fs_time command from fs60, fs72 fails in parallel mode, use local one -# also check for failure (e.g. on mac it fails) -timecmd="${binpath}fs_time" -$timecmd echo testing &> /dev/null -if [ ${PIPESTATUS[0]} -ne 0 ] ; then - echo "time command failing, not using time..." - timecmd="" -fi - -# check bash version > 4 +# check bash version > 3.1 (needed for printf %q) function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; } -if [ $(version ${BASH_VERSION}) -lt $(version "4.0.0") ]; then - echo "bash ${BASH_VERSION} is too old. Should be newer than 4.0, please upgrade!" +if [ "$(version "${BASH_VERSION}")" -lt "$(version "3.1.0")" ]; then + echo "bash ${BASH_VERSION} is too old. Should be newer than 3.1, please upgrade!" exit 1 fi @@ -89,22 +82,28 @@ FLAGS: that segmentations produced otherwise are conformed. Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz - --fstess Switch on mri_tesselate for surface creation + --fstess Revert to FreeSurfer mri_tesselate for surface creation (default: mri_mc) - --fsqsphere Use FreeSurfer iterative inflation for qsphere + --fsqsphere Revert to FreeSurfer iterative inflation for qsphere (default: spectral spherical projection) --fsaparc Additionally create FS aparc segmentations and ribbon. Skipped by default (--> DL prediction is used which - is faster, and usually these mapped ones are fine) + is faster, and usually these mapped ones are fine). + Note, if you switch this on it will create all cortical + parcellations with FreeSurfer's spherical atlases and + also map these into the aparc+aseg file instead of + the FastSurfer ones. FastSurfer's cortical DKT atlas + results can still be found in: + .aparc.DKTatlas.mapped.stats --3T Use the 3T atlas for talairach registration (gives better - etiv estimates for 3T MR images, default: 1.5T atlas). + eTIV estimates for 3T MR images, default: 1.5T atlas). --parallel Run both hemispheres in parallel --threads Set openMP and ITK threads to - --py Command for python, default $python + --py Command for python, default ${python} --fs_license Path to FreeSurfer license key file. Register at https://surfer.nmr.mgh.harvard.edu/registration.html for free to obtain it if you do not have FreeSurfer - installed already + installed already. -h --help Print Help Dev Flags: @@ -115,9 +114,9 @@ Dev Flags: standard FreeSurfer output) and create brainmask.mgz directly from norm.mgz instead. Saves 1:30 min. 
--no_surfreg Do not run Surface registration with FreeSurfer (for - cross-subject correspondence), Not recommended, but - speeds up processing if you e.g. just need the - segmentation stats + cross-subject correspondence). Not recommended, but + speeds up processing if you just need the stats and + don't want to do thickness analysis on the cortex. --allow_root Allow execution as root user REFERENCES: @@ -137,82 +136,8 @@ EOF } - -function RunIt() -{ -# parameters -# $1 : cmd (command to run) -# $2 : LF (log file) -# $3 : CMDF (command file) optional -# if CMDF is passed, then LF is ignored and cmd is echoed into CMDF and not run - cmd=$1 - LF=$2 - if [[ $# -eq 3 ]] - then - CMDF=$3 - echo "echo \"$cmd\" " |& tee -a $CMDF - echo "$timecmd $cmd " |& tee -a $CMDF - echo "if [ \${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi" >> $CMDF - else - echo $cmd |& tee -a $LF - $timecmd $cmd |& tee -a $LF - if [ ${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi - fi -} - -function RunBatchJobs() -{ -# parameters -# $1 : LF -# $2 ... : CMDFS - LOG_FILE=$1 - # launch jobs found in command files (shift past first logfile arg). - # job output goes to a logfile named after the command file, which - # later gets appended to LOG_FILE - - echo - echo "RunBatchJobs: Logfile: $LOG_FILE" - - PIDS=() - LOGS=() - shift - for cmdf in $*; do - echo "RunBatchJobs: CMDF: $cmdf" - chmod u+x $cmdf - JOB="$cmdf" - LOG=$cmdf.log - echo "" >& $LOG - echo " $JOB" >> $LOG - echo "" >> $LOG - exec $JOB >> $LOG 2>&1 & - PIDS=(${PIDS[@]} $!) - LOGS=(${LOGS[@]} $LOG) - - done - # wait till all processes have finished - PIDS_STATUS=() - for pid in "${PIDS[@]}"; do - echo "Waiting for PID $pid of (${PIDS[*]}) to complete..." - wait $pid - PIDS_STATUS=(${PIDS_STATUS[@]} $?) - done - # now append their logs to the main log file - for log in "${LOGS[@]}" - do - cat $log >> $LOG_FILE - rm -f $log - done - echo "PIDs (${PIDS[*]}) completed and logs appended." - # and check for failures - for pid_status in "${PIDS_STATUS[@]}" - do - if [ "$pid_status" != "0" ] ; then - exit 1 - fi - done -} - - +# Load the RunIt and the RunBatchJobs functions, also sets up timecmd +source "$binpath/functions.sh" # PRINT USAGE if called without params if [[ $# -eq 0 ]] @@ -229,128 +154,67 @@ while [[ $# -gt 0 ]] do # make key lowercase key=$(echo "$1" | tr '[:upper:]' '[:lower:]') +shift # past argument case $key in - --sid) - subject="$2" - shift # past argument - shift # past value - ;; - --sd) - export SUBJECTS_DIR="$2" - shift # past argument - shift # past value - ;; - --t1) - t1="$2" - shift # past argument - shift # past value - ;; - --asegdkt_segfile | --aparc_aseg_segfile | --seg) + --sid) subject="$1" ; shift ;; + --sd) export SUBJECTS_DIR="$1" ; shift ;; + --t1) t1="$1" ; shift ;; + --asegdkt_segfile | --aparc_aseg_segfile | --seg) if [ "$key" == "--seg" ] || [ "$key" == "--aparc_aseg_segfile" ]; then - echo "WARNING: $1 is deprecated and will be removed, use --asegdkt_segfile ." + echo "WARNING: $key is deprecated and will be removed, use --asegdkt_segfile ." fi - asegdkt_segfile="$2" - shift # past argument + asegdkt_segfile="$1" shift # past value ;; - --vol_segstats) + --vol_segstats) echo "WARNING: the --vol_segstats flag is obsolete and will be removed, --vol_segstats ignored." 
- shift # past argument - ;; - --fstess) - fstess=1 - shift # past argument - ;; - --fsqsphere) - fsqsphere=1 - shift # past argument - ;; - --fsaparc) - fsaparc=1 - shift # past argument - ;; - --no_surfreg) - fssurfreg=0 - shift # past argument - ;; - --3t) - atlas3T="true" - shift ;; - --parallel) - DoParallel=1 - shift # past argument - ;; - --threads) - threads="$2" - shift # past argument - shift # past value - ;; - --py) - python="$2" - shift # past argument - shift # past value - ;; - --fs_license) - if [ -f "$2" ]; then - export FS_LICENSE="$2" + --segstats_legacy) segstats_legacy="true" ;; + --fstess) fstess=1 ;; + --fsqsphere) fsqsphere=1 ;; + --fsaparc) fsaparc=1 ;; + --no_surfreg) fssurfreg=0 ;; + --3t) atlas3T="true" ;; + --parallel) DoParallel=1 ;; + --threads) threads="$1" ; shift ;; + --py) python="$1" ; shift ;; + --fs_license) + if [ -f "$1" ]; then + export FS_LICENSE="$1" else - echo "Provided FreeSurfer license file $2 could not be found. Make sure to provide the full path and name. Exiting..." - exit 1; + echo "Provided FreeSurfer license file $1 could not be found. Make sure to provide the full path and name. Exiting..." + exit 1; fi - shift # past argument shift # past value ;; - --ignore_fs_version) - check_version=0 - shift # past argument - ;; - --no_fs_t1 ) - get_t1=0 - shift # past argument - ;; - --allow_root) - allow_root="--allow_root" - shift # past argument - ;; - -h|--help) - usage - exit - ;; - *) # unknown option - echo ERROR: Flag $key unrecognized. - exit 1 - ;; + --ignore_fs_version) check_version=0 ;; + --no_fs_t1 ) get_t1=0 ;; + --allow_root) allow_root="--allow_root" ;; + -h|--help) usage ; exit ;; + # unknown option + *) echo "ERROR: Flag $key unrecognized." ; exit 1 ;; esac done set -- "${POSITIONAL[@]}" # restore positional parameters # CHECKS -echo -echo sid $subject -echo T1 $t1 -echo asegdkt_segfile $asegdkt_segfile -echo +echo "" +echo "sid $subject" +echo "T1 $t1" +echo "asegdkt_segfile $asegdkt_segfile" +echo "" # Warning if run as root user if [ -z "$allow_root" ] && [ "$(id -u)" == "0" ] - then - echo "You are trying to run '$0' as root. We advice to avoid running FastSurfer as root, " - echo "because it will lead to files and folders created as root." - echo "If you are running FastSurfer in a docker container, you can specify the user with " - echo "'-u \$(id -u):\$(id -g)' (see https://docs.docker.com/engine/reference/run/#user)." - echo "If you want to force running as root, you may pass --allow_root to recon-surf.sh." - exit 1; -fi - -if [ "$subject" == "subject" ] then - echo "Subject ID cannot be \"subject\", please choose a different sid" - # Explanation, see https://github.com/Deep-MI/FastSurfer/issues/186 - # this is a bug in FreeSurfer's argparse when calling "mri_brainvol_stats subject" - exit 1 + echo "You are trying to run '$0' as root. We advice to avoid running FastSurfer as root, " + echo "because it will lead to files and folders created as root." + echo "If you are running FastSurfer in a docker container, you can specify the user with " + echo "'-u \$(id -u):\$(id -g)' (see https://docs.docker.com/engine/reference/run/#user)." + echo "If you want to force running as root, you may pass --allow_root to recon-surf.sh." 
+ exit 1; fi if [ -z "$SUBJECTS_DIR" ] @@ -372,9 +236,9 @@ export FREESURFER=$FREESURFER_HOME if [ "$check_version" == "1" ] then - if grep -q -v ${FS_VERSION_SUPPORT} $FREESURFER_HOME/build-stamp.txt + if grep -q -v "${FS_VERSION_SUPPORT}" "$FREESURFER_HOME/build-stamp.txt" then - echo "ERROR: You are trying to run recon-surf with FreeSurfer version $(cat $FREESURFER_HOME/build-stamp.txt)." + echo "ERROR: You are trying to run recon-surf with FreeSurfer version $(cat "$FREESURFER_HOME/build-stamp.txt")." echo "We are currently supporting only FreeSurfer $FS_VERSION_SUPPORT" echo "Therefore, make sure to export and source the correct FreeSurfer version before running recon-surf.sh: " echo "export FREESURFER_HOME=/path/to/your/local/fs$FS_VERSION_SUPPORT" @@ -426,7 +290,7 @@ then fsthreads="-threads $threads -itkthreads $threads" fi -if [ $(echo -n "${SUBJECTS_DIR}/${subject}" | wc -m) -gt 185 ] +if [ "$(echo -n "${SUBJECTS_DIR}/${subject}" | wc -m)" -gt 185 ] then echo "ERROR: subject directory path is very long." echo "This is known to cause errors due to some commands run by freesurfer versions built for Ubuntu." @@ -441,148 +305,161 @@ if [ -f "$SUBJECTS_DIR/$subject/mri/wm.mgz" ] || [ -f "$SUBJECTS_DIR/$subject/mr exit 1 fi -# Check input segmentation quality -echo "Checking Input Segmentation Quality ..." -cmd="$python ${binpath}/../FastSurferCNN/quick_qc.py --asegdkt_segfile $asegdkt_segfile" -echo $cmd -$cmd -if [ ${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi -echo - # collect info -StartTime=`date`; -tSecStart=`date '+%s'`; -year=`date +%Y` -month=`date +%m` -day=`date +%d` -hour=`date +%H` -min=`date +%M` +StartTime=$(date); +tSecStart=$(date '+%s') +year=$(date +%Y) +month=$(date +%m) +day=$(date +%d) +hour=$(date +%H) +min=$(date +%M) # Setup dirs -mkdir -p $SUBJECTS_DIR/$subject/scripts -mkdir -p $SUBJECTS_DIR/$subject/mri/transforms -mkdir -p $SUBJECTS_DIR/$subject/mri/tmp -mkdir -p $SUBJECTS_DIR/$subject/surf -mkdir -p $SUBJECTS_DIR/$subject/label -mkdir -p $SUBJECTS_DIR/$subject/stats +mkdir -p "$SUBJECTS_DIR/$subject/scripts" +mkdir -p "$SUBJECTS_DIR/$subject/mri/transforms" +mkdir -p "$SUBJECTS_DIR/$subject/mri/tmp" +mkdir -p "$SUBJECTS_DIR/$subject/surf" +mkdir -p "$SUBJECTS_DIR/$subject/label" +mkdir -p "$SUBJECTS_DIR/$subject/stats" -mdir=$SUBJECTS_DIR/$subject/mri -sdir=$SUBJECTS_DIR/$subject/surf -ldir=$SUBJECTS_DIR/$subject/label +mdir="$SUBJECTS_DIR/$subject/mri" +sdir="$SUBJECTS_DIR/$subject/surf" +statsdir="$SUBJECTS_DIR/$subject/stats" +ldir="$SUBJECTS_DIR/$subject/label" -mask=$mdir/mask.mgz +mask="$mdir/mask.mgz" # Set up log file -DoneFile=$SUBJECTS_DIR/$subject/scripts/recon-surf.done -if [ $DoneFile != /dev/null ] ; then rm -f $DoneFile ; fi -LF=$SUBJECTS_DIR/$subject/scripts/recon-surf.log -if [ $LF != /dev/null ] ; then rm -f $LF ; fi -echo "Log file for recon-surf.sh" >> $LF -date |& tee -a $LF -echo "" |& tee -a $LF -echo "export SUBJECTS_DIR=$SUBJECTS_DIR" |& tee -a $LF -echo "cd `pwd`" |& tee -a $LF -echo $0 ${inputargs[*]} |& tee -a $LF -echo "" |& tee -a $LF -cat $FREESURFER_HOME/build-stamp.txt |& tee -a $LF -echo $VERSION |& tee -a $LF -uname -a |& tee -a $LF - +DoneFile="$SUBJECTS_DIR/$subject/scripts/recon-surf.done" +if [ "$DoneFile" != /dev/null ] ; then rm -f "$DoneFile" ; fi +LF="$SUBJECTS_DIR/$subject/scripts/recon-surf.log" +if [ "$LF" != /dev/null ] ; then rm -f "$LF" ; fi +echo "Log file for recon-surf.sh" >> "$LF" +{ # all output tee -a "$LF" + date 2>&1 + echo " " + echo "export SUBJECTS_DIR=$SUBJECTS_DIR" + echo "cd $(pwd)" + echo_quoted 
"$0" "${inputargs[@]}" + echo " " + cat "$FREESURFER_HOME/build-stamp.txt" 2>&1 + echo "$VERSION" + uname -a 2>&1 + echo " " + echo " " + echo "==================== Checking validity of inputs =================================" + echo " " # Print parallelization parameters -echo " " |& tee -a $LF -if [ "$DoParallel" == "1" ] -then - echo " RUNNING both hemis in PARALLEL " |& tee -a $LF -else - echo " RUNNING both hemis SEQUENTIALLY " |& tee -a $LF -fi -echo " RUNNING $OMP_NUM_THREADS number of OMP THREADS " |& tee -a $LF -echo " RUNNING $ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS number of ITK THREADS " |& tee -a $LF -echo " " |& tee -a $LF + # Print parallelization parameters + if [ "$DoParallel" == "1" ] + then + echo " RUNNING both hemis in PARALLEL" + else + echo " RUNNING both hemis SEQUENTIALLY" + fi + echo " RUNNING $OMP_NUM_THREADS number of OMP THREADS" + echo " RUNNING $ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS number of ITK THREADS" + echo " " + + # Check input segmentation quality + echo "Checking Input Segmentation Quality ..." +} | tee -a "$LF" +cmd="$python $FASTSURFER_HOME/FastSurferCNN/quick_qc.py --asegdkt_segfile $asegdkt_segfile" +RunIt "$cmd" "$LF" +echo "" | tee -a "$LF" -#if false; then ########################################## START ######################################################## -echo " " |& tee -a $LF -echo "================== Creating orig and rawavg from input =========================" |& tee -a $LF -echo " " |& tee -a $LF +{ + echo " " + echo "================== Creating orig and rawavg from input =========================" + echo " " +} | tee -a "$LF" CONFORM_LF=$SUBJECTS_DIR/$subject/scripts/conform.log -if [ $CONFORM_LF != /dev/null ] ; then rm -f $CONFORM_LF ; fi -echo "Log file for Conform test" > $CONFORM_LF +if [ "$CONFORM_LF" != /dev/null ] ; then rm -f "$CONFORM_LF" ; fi +echo "Log file for Conform test" > "$CONFORM_LF" # check for input conformance -cmd="$python ${binpath}../FastSurferCNN/data_loader/conform.py -i $t1 --check_only --vox_size min --verbose" -RunIt "$cmd" "$LF -a $CONFORM_LF" +cmd="$python $FASTSURFER_HOME/FastSurferCNN/data_loader/conform.py -i $t1 --check_only --vox_size min --verbose" +RunIt "$cmd" "$LF" 2>&1 | tee -a "$CONFORM_LF" # look into the CONFORM_LF to find the voxel sizes, the second conform.py call will check the legality of vox_size -vox_size=`cat $CONFORM_LF | grep -E " - Voxel Size " | cut -d' ' -f5 | cut -d'x' -f1` +vox_size=$(grep -E " - Voxel Size " "$CONFORM_LF" | cut -d' ' -f5 | cut -d'x' -f1) # remove the temporary conform_log (all info is also in the recon-surf logfile) -if [ -f "$CONFORM_LF" ]; then rm -f $CONFORM_LF ; fi +if [ -f "$CONFORM_LF" ]; then rm -f "$CONFORM_LF" ; fi # here, we check the correct vox_size by passing it to the next conform, so errors in this line might be caused above -cmd="$python ${binpath}../FastSurferCNN/data_loader/conform.py -i $asegdkt_segfile --check_only --vox_size $vox_size --dtype any --verbose" -RunIt "$cmd" $LF +cmd="$python $FASTSURFER_HOME/FastSurferCNN/data_loader/conform.py -i $asegdkt_segfile --check_only --vox_size $vox_size --dtype any --verbose" +RunIt "$cmd" "$LF" if (( $(echo "$vox_size < $hires_voxsize_threshold" | bc -l) )) then - echo "The voxel size $vox_size is less than $hires_voxsize_threshold, so we are proceeding with hires options." |& tee -a $LF + echo "The voxel size $vox_size is less than $hires_voxsize_threshold, so we are proceeding with hires options." 
| tee -a "$LF" hiresflag="-hires" noconform_if_hires=" -noconform" hires_surface_suffix=".predec" else - echo "The voxel size $vox_size is not less than $hires_voxsize_threshold, so we are proceeding with standard options." |& tee -a $LF + echo "The voxel size $vox_size is not less than $hires_voxsize_threshold, so we are proceeding with standard options." | tee -a "$LF" hiresflag="" noconform_if_hires="" hires_surface_suffix="" fi -# create orig.mgz and aparc.DKTatlas+aseg.orig.mgz (copy of segmentation) +# create orig.mgz and aparc.DKTatlas+aseg.orig.mgz (copy of T1 and segmentation) +# also ensures .mgz format (in case inputs are nifti) cmd="mri_convert $t1 $mdir/orig.mgz" -RunIt "$cmd" $LF +RunIt "$cmd" "$LF" cmd="mri_convert $asegdkt_segfile $mdir/aparc.DKTatlas+aseg.orig.mgz" -RunIt "$cmd" $LF +RunIt "$cmd" "$LF" + +# link original T1 input to rawavg (needed by pctsurfcon) +pushd "$mdir" > /dev/null || ( echo "Could not change to $mdir" ; exit 1 ) + softlink_or_copy "orig.mgz" "rawavg.mgz" "$LF" +popd > /dev/null || ( echo "Could not change to subject_dir" ; exit 1 ) + + -# link to rawavg (needed by pctsurfcon) -pushd $mdir -cmd="ln -sf orig.mgz rawavg.mgz" -RunIt "$cmd" $LF -popd +### The following steps are now usually done outside recon-surf already by the segmentation pipeline. +### However, if these files such as mask, aseg.auto_noCCseg, orig_nu or talairach transforms don't +### exist, we recreate them here, so that this can run on other type of input where only a T1 and +### segmentation is provided. This may need update if it changes in the segmentation pipeline. + + +# ============================= MASK & ASEG_noCC ======================================== -### START SUPERSEDED BY SEGMENTATION PIPELINE, will be removed in the future -### ---------- if [ ! -f "$mask" ] || [ ! -f "$mdir/aseg.auto_noCCseg.mgz" ] ; then - # Mask or aseg.auto_noCCseg not found; create them - echo " " |& tee -a $LF - echo "============= Creating aseg.auto_noCCseg (map aparc labels back) ===============" |& tee -a $LF - echo " " |& tee -a $LF + { + # Mask or aseg.auto_noCCseg not found; create them from aparc.DKTatlas+aseg + echo " " + echo "============= Creating aseg.auto_noCCseg (map aparc labels back) ===============" + echo " " + } | tee -a "$LF" # reduce labels to aseg, then create mask (dilate 5, erode 4, largest component), also mask aseg to remove outliers # output will be uchar (else mri_cc will fail below) - cmd="$python ${binpath}/../FastSurferCNN/reduce_to_aseg.py -i $mdir/aparc.DKTatlas+aseg.orig.mgz -o $mdir/aseg.auto_noCCseg.mgz --outmask $mask --fixwm" - RunIt "$cmd" $LF + cmd="$python $FASTSURFER_HOME/FastSurferCNN/reduce_to_aseg.py -i $mdir/aparc.DKTatlas+aseg.orig.mgz -o $mdir/aseg.auto_noCCseg.mgz --outmask $mask --fixwm" + RunIt "$cmd" "$LF" fi -### END SUPERSEDED BY SEGMENTATION PIPELINE, will be removed in the future -### ---------- -echo " " |& tee -a $LF -echo "============= Computing Talairach Transform and NU (bias corrected) ============" |& tee -a $LF -echo " " |& tee -a $LF -pushd $mdir +# ============================= NU BIAS CORRECTION ======================================= -### START SUPERSEDED BY SEGMENTATION PIPELINE, will be removed in the future -### ---------- -# only run the bias field correction, if the bias field corrected does not exist already if [ ! 
-f "$mdir/orig_nu.mgz" ]; then + # only run the bias field correction, if the bias field corrected does not exist already + { + echo " " + echo "============= Computing NU (bias corrected) ============" + echo " " + } | tee -a "$LF" # nu processing is changed here compared to recon-all: we use the brainmask from the # segmentation to improve the nu correction (and speedup) # orig_nu N3 in FS6 took 44 sec, FS 7.3.2 uses --ants-n4 (takes 3 min and does not accept @@ -593,122 +470,111 @@ if [ ! -f "$mdir/orig_nu.mgz" ]; then # frontal head), we don't. Also this avoids a second call to nu correct. # talairach.xfm is also not needed here at all, it can be dropped if other places in the # stream can be changed to avoid it. - #cmd="mri_nu_correct.mni --no-rescale --i $mdir/orig.mgz --o $mdir/orig_nu.mgz --n 1 --proto-iters 1000 --distance 50 --mask $mdir/mask.mgz" - cmd="$python ${binpath}/N4_bias_correct.py --in $mdir/orig.mgz --rescale $mdir/orig_nu.mgz --aseg $mdir/aparc.DKTatlas+aseg.orig.mgz --threads $threads" - RunIt "$cmd" $LF + pushd "$mdir" > /dev/null || ( echo "Cannot change to $mdir" ; exit 1 ) + #cmd="mri_nu_correct.mni --no-rescale --i $mdir/orig.mgz --o $mdir/orig_nu.mgz --n 1 --proto-iters 1000 --distance 50 --mask $mdir/mask.mgz" + cmd="$python ${binpath}/N4_bias_correct.py --in $mdir/orig.mgz --rescale $mdir/orig_nu.mgz --aseg $mdir/aparc.DKTatlas+aseg.orig.mgz --threads $threads" + RunIt "$cmd" "$LF" + popd > /dev/null || (echo "Could not popd" ; exit 1) fi -### END SUPERSEDED BY SEGMENTATION PIPELINE, will be removed in the future -### ---------- -# talairach.xfm: compute talairach full head (25sec) -if [[ "$atlas3T" == "true" ]] -then - echo "Using the 3T atlas for talairach registration." - atlas="--atlas 3T18yoSchwartzReactN32_as_orig" -else - echo "Using the default atlas (1.5T) for talairach registration." 
- atlas="" -fi -cmd="talairach_avi --i $mdir/orig_nu.mgz --xfm $mdir/transforms/talairach.auto.xfm $atlas" -RunIt "$cmd" $LF -# create copy -cmd="cp $mdir/transforms/talairach.auto.xfm $mdir/transforms/talairach.xfm" -RunIt "$cmd" $LF -# talairach.lta: convert to lta -cmd="lta_convert --src $mdir/orig.mgz --trg $FREESURFER_HOME/average/mni305.cor.mgz --inxfm $mdir/transforms/talairach.xfm --outlta $mdir/transforms/talairach.xfm.lta --subject fsaverage --ltavox2vox" -RunIt "$cmd" $LF - -# FS would here create better nu.mgz using talairach transform (finds wm and maps it to approx 110) -#NuIterations="1 --proto-iters 1000 --distance 50" # default 3T -#FS60 cmd="mri_nu_correct.mni --i $mdir/orig.mgz --o $mdir/nu.mgz --uchar $mdir/transforms/talairach.xfm --n $NuIterations --mask $mdir/mask.mgz" -#FS72 cmd="mri_nu_correct.mni --i $mdir/orig.mgz --o $mdir/nu.mgz --uchar $mdir/transforms/talairach.xfm --n $NuIterations --ants-n4" -# all this is basically useless, as we did a good orig_nu already, including WM normalization - -# Since we do not run mri_em_register we sym-link other talairach transform files here -pushd $mdir/transforms -cmd="ln -sf talairach.xfm.lta talairach_with_skull.lta" -RunIt "$cmd" $LF -cmd="ln -sf talairach.xfm.lta talairach.lta" -RunIt "$cmd" $LF -popd - -# Add xfm to nu -# (use orig_nu, if nu.mgz does not exist already); by default, it should exist -if [[ -e "$mdir/nu.mgz" ]]; then src_nu_file="$mdir/nu.mgz" -else src_nu_file="$mdir/orig_nu.mgz" -fi -cmd="mri_add_xform_to_header -c $mdir/transforms/talairach.xfm $src_nu_file $mdir/nu.mgz" -RunIt "$cmd" $LF -popd +# ============================= TALAIRACH ============================================== +if [[ ! -f "$mdir/transforms/talairach.lta" ]] || [[ ! -f "$mdir/transforms/talairach_with_skull.lta" ]]; then + { + echo " " + echo "============= Computing Talairach Transform ============" + echo " " + echo "\"$binpath/talairach-reg.sh\" \"$mdir\" \"$atlas3T\" \"$LF\"" + } | tee -a "$LF" + "$binpath/talairach-reg.sh" "$mdir" "$atlas3T" "$LF" +fi -echo " " |& tee -a $LF -echo "============ Creating brainmask from aseg and norm, and update aseg ============" |& tee -a $LF -echo " " |& tee -a $LF + +# ============================= BRAINMASK ============================================== +{ + echo " " + echo "============ Creating brainmask from aseg and nu or T1 ============" + echo " " +} | tee -a $LF # create norm by masking nu cmd="mri_mask $mdir/nu.mgz $mdir/mask.mgz $mdir/norm.mgz" -RunIt "$cmd" $LF - +RunIt "$cmd" "$LF" if [ "$get_t1" == "1" ] then # create T1.mgz from nu (!! here we could also try passing aseg?) 
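# --- Editorial aside (illustration only, not part of the patch) ---------------------
# The TALAIRACH block above now delegates the registration to "$binpath/talairach-reg.sh"
# (added elsewhere in this PR, not shown in this hunk). As a hypothetical sketch only,
# reconstructed from the inline commands removed above, such a helper taking the mri
# directory, the 3T-atlas flag and the log file might look roughly like this (the real
# script may differ; file names and the nu.mgz handling are simplified):
#   #!/bin/bash
#   mdir=$1 ; atlas3T=$2 ; LF=$3
#   if [[ "$atlas3T" == "true" ]] ; then atlas="--atlas 3T18yoSchwartzReactN32_as_orig" ; else atlas="" ; fi
#   talairach_avi --i "$mdir/orig_nu.mgz" --xfm "$mdir/transforms/talairach.auto.xfm" $atlas 2>&1 | tee -a "$LF"
#   cp "$mdir/transforms/talairach.auto.xfm" "$mdir/transforms/talairach.xfm"
#   lta_convert --src "$mdir/orig.mgz" --trg "$FREESURFER_HOME/average/mni305.cor.mgz" \
#               --inxfm "$mdir/transforms/talairach.xfm" --outlta "$mdir/transforms/talairach.xfm.lta" \
#               --subject fsaverage --ltavox2vox 2>&1 | tee -a "$LF"
#   ln -sf talairach.xfm.lta "$mdir/transforms/talairach.lta"
#   ln -sf talairach.xfm.lta "$mdir/transforms/talairach_with_skull.lta"
#   mri_add_xform_to_header -c "$mdir/transforms/talairach.xfm" "$mdir/orig_nu.mgz" "$mdir/nu.mgz" 2>&1 | tee -a "$LF"
# -------------------------------------------------------------------------------------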
cmd="mri_normalize -g 1 -seed 1234 -mprage $mdir/nu.mgz $mdir/T1.mgz $noconform_if_hires" - RunIt "$cmd" $LF - + RunIt "$cmd" "$LF" # create brainmask by masking T1 cmd="mri_mask $mdir/T1.mgz $mdir/mask.mgz $mdir/brainmask.mgz" - RunIt "$cmd" $LF + RunIt "$cmd" "$LF" else - # Default: create brainmask by linkage to norm.mgz (masked nu.mgz) - pushd $mdir - cmd="ln -sf norm.mgz brainmask.mgz" - RunIt "$cmd" $LF - popd + # create brainmask by linkage to norm.mgz (masked nu.mgz) + pushd "$mdir" > /dev/null || ( echo "Could not cd to $mdir" ; exit 1 ) + softlink_or_copy "norm.mgz" "brainmask.mgz" "$LF" + popd > /dev/null || (echo "Could not popd" ; exit 1 ) fi -# create aseg.auto including cc segmentation and 46 sec, requires norm.mgz + +# ============================= CC SEGMENTATION ============================================ + +{ + echo " " + echo "============ Creating and adding CC Segmentation ============" + echo " " +} | tee -a $LF +# create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz # Note: if original input segmentation already contains CC, this will exit with ERROR # in the future maybe check and skip this step (and next) cmd="mri_cc -aseg aseg.auto_noCCseg.mgz -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" -RunIt "$cmd" $LF - -# add cc into aparc.DKTatlas+aseg.deep (not sure if this is really needed) +RunIt "$cmd" "$LF" +# add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) cmd="$python ${binpath}paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" -RunIt "$cmd" $LF +RunIt "$cmd" "$LF" -echo " " |& tee -a $LF -echo "========= Creating filled from brain (brainfinalsurfs, wm.asegedit, wm) =======" |& tee -a $LF -echo " " |& tee -a $LF +# ============================= FILLED ===================================================== + +{ + echo " " + echo "========= Creating filled from brain (brainfinalsurfs, wm.asegedit, wm) =======" + echo " " +} | tee -a $LF + +# filled is needed to generate initial WM surfaces cmd="recon-all -s $subject -asegmerge -normalization2 -maskbfs -segmentation -fill $hiresflag $fsthreads" -RunIt "$cmd" $LF +RunIt "$cmd" "$LF" -# ================================================== SURFACES ========================================================== +# ======= +# ================================================== SURFACES ============================================================== +# ======= -CMDFS="" +CMDFS=() for hemi in lh rh; do -CMDF="$SUBJECTS_DIR/$subject/scripts/$hemi.processing.cmdf" -CMDFS="$CMDFS $CMDF" -rm -rf $CMDF + CMDF="$SUBJECTS_DIR/$subject/scripts/$hemi.processing.cmdf" + CMDFS+=("$CMDF") + rm -rf "$CMDF" + echo "#!/bin/bash" > "$CMDF" -echo "#!/bin/bash" > $CMDF -echo "echo " |& tee -a $CMDF -echo "echo \"================== Creating surfaces $hemi - orig.nofix ==================\"" |& tee -a $CMDF -echo "echo " |& tee -a $CMDF +# ============================= TESSELATE - SMOOTH ===================================================== -if [ "$fstess" == "1" ] -then - cmd="recon-all -subject $subject -hemi $hemi -tessellate -smooth1 -no-isrunning $hiresflag $fsthreads" - RunIt "$cmd" $LF $CMDF -else - # instead of mri_tesselate lego land use marching cube + { + echo "echo \" \"" + echo "echo \"================== Creating surfaces $hemi - orig.nofix ==================\"" + echo "echo \" \"" + } | tee -a "$CMDF" + if [ "$fstess" == "1" ] + then + cmd="recon-all -subject $subject -hemi $hemi -tessellate -smooth1 
-no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" + else + # instead of mri_tesselate lego land use marching cube if [ $hemi == "lh" ]; then hemivalue=255; @@ -718,221 +584,246 @@ else # extract initial surface "?h.orig.nofix" cmd="mri_pretess $mdir/filled.mgz $hemivalue $mdir/brain.mgz $mdir/filled-pretess$hemivalue.mgz" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" # Marching cube does not return filename and wrong volume info! outmesh=$sdir/$hemi.orig.nofix$hires_surface_suffix cmd="mri_mc $mdir/filled-pretess$hemivalue.mgz $hemivalue $outmesh" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" # Rewrite surface orig.nofix to fix vertex locs bug (scannerRAS instead of surfaceRAS set with mc) - cmd="$python ${binpath}rewrite_mc_surface.py --input $outmesh --output $outmesh --filename_pretess $mdir/filled-pretess$hemivalue.mgz" - RunIt "$cmd" $LF $CMDF + #cmd="$python ${binpath}rewrite_mc_surface.py --input $outmesh --output $outmesh --filename_pretess $mdir/filled-pretess$hemivalue.mgz" + #RunIt "$cmd" "$LF" "$CMDF" # Check if the surfaceRAS was correctly set and exit otherwise (sanity check in case nibabel changes their default header behaviour) - cmd="mris_info $outmesh | tr -s ' ' | grep -q 'vertex locs : surfaceRAS'" - echo "echo \"$cmd\" " |& tee -a $CMDF - echo "$timecmd $cmd " |& tee -a $CMDF - echo "if [ \${PIPESTATUS[1]} -ne 0 ] ; then echo \"Incorrect header information detected in $outmesh: vertex locs is not set to surfaceRAS. Exiting... \"; exit 1 ; fi" >> $CMDF + { + cmd="mris_info $outmesh | tr -s ' ' | grep -q 'vertex locs : surfaceRAS'" + echo "echo \"$cmd\"" + echo "$timecmd $cmd" + } | tee -a "$CMDF" + echo "if [ \${PIPESTATUS[1]} -ne 0 ] ; then echo \"Incorrect header information detected in $outmesh: vertex locs is not set to surfaceRAS. Exiting... \"; exit 1 ; fi" >> "$CMDF" # Reduce to largest component (usually there should only be one) cmd="mris_extract_main_component $outmesh $outmesh" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" # for hires decimate mesh - if [ ! -z "$hiresflag" ]; then + if [ -n "$hiresflag" ]; then DecimationFaceArea="0.5" # Reduce the number of faces such that the average face area is # DecimationFaceArea. If the average face area is already more # than DecimationFaceArea, then the surface is not changed. 
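# --- Editorial aside (illustration only, not part of the patch) ---------------------
# The surfaceRAS sanity check queued above relies on bash's PIPESTATUS array: after a
# pipeline finishes, PIPESTATUS[i] holds the exit status of the i-th command of that
# pipeline, whereas $? only reports the last one. PIPESTATUS is overwritten by the very
# next command, so it has to be read immediately. A minimal self-contained illustration:
false | true | true
echo "exit codes of the pipeline stages: ${PIPESTATUS[@]}"   # prints: 1 0 0
# -------------------------------------------------------------------------------------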
# set cmd = (mris_decimate -a $DecimationFaceArea ../surf/$hemi.orig.nofix.predec ../surf/$hemi.orig.nofix) cmd="mris_remesh --desired-face-area $DecimationFaceArea --input $outmesh --output $sdir/$hemi.orig.nofix" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" fi - # -smooth1 (explicitly state 10 iteration (default) but may change in future) cmd="mris_smooth -n 10 -nw -seed 1234 $sdir/$hemi.orig.nofix $sdir/$hemi.smoothwm.nofix" - RunIt "$cmd" $LF $CMDF - -fi - - - -echo "echo " |& tee -a $CMDF -echo "echo \"=================== Creating surfaces $hemi - qsphere ====================\"" |& tee -a $CMDF -echo "echo " |& tee -a $CMDF + RunIt "$cmd" "$LF" "$CMDF" + fi -#surface inflation (54sec both hemis) (needed for qsphere and for topo-fixer) -cmd="recon-all -subject $subject -hemi $hemi -inflate1 -no-isrunning $hiresflag $fsthreads" -RunIt "$cmd" $LF $CMDF +# ============================= INFLATE1 - QSPHERE ===================================================== + { + echo "echo \"\"" + echo "echo \"=================== Creating surfaces $hemi - qsphere ====================\"" + echo "echo \"\"" + } | tee -a "$CMDF" -if [ "$fsqsphere" == "1" ] -then - # quick spherical mapping (2min48sec) - cmd="recon-all -subject $subject -hemi $hemi -qsphere -no-isrunning $hiresflag $fsthreads" - RunIt "$cmd" $LF $CMDF - -else + #surface inflation (54sec both hemis) (needed for qsphere and for topo-fixer) + cmd="recon-all -subject $subject -hemi $hemi -inflate1 -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" + if [ "$fsqsphere" == "1" ] + then + # quick spherical mapping (2min48sec) + cmd="recon-all -subject $subject -hemi $hemi -qsphere -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" + else # instead of mris_sphere, directly project to sphere with spectral approach # equivalent to -qsphere # (23sec) - cmd="$python ${binpath}spherically_project_wrapper.py --hemi $hemi --sdir $sdir --subject $subject" - cmd="$cmd --threads=$threads --py $python --binpath ${binpath}" + cmd="$python ${binpath}spherically_project_wrapper.py --hemi $hemi --sdir $sdir" + printf -v tmp %q "$python" + cmd="$cmd --subject $subject --threads=$threads --py ${tmp} --binpath ${binpath}" + RunIt "$cmd" "$LF" "$CMDF" + fi - RunIt "$cmd" $LF $CMDF +# ============================= FIX - WHITEPREAPARC - CORTEXLABEL ============================================ -fi + { + echo "echo \"\"" + echo "echo \"=================== Creating surfaces $hemi - fix ========================\"" + echo "echo \"\"" + } | tee -a "$CMDF" + cmd="recon-all -subject $subject -hemi $hemi -fix -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" $LF $CMDF -echo "echo " |& tee -a $CMDF -echo "echo \"=================== Creating surfaces $hemi - fix ========================\"" |& tee -a $CMDF -echo "echo " |& tee -a $CMDF + # fix the surfaces if they are corrupt + cmd="$python ${binpath}rewrite_oriented_surface.py --file $sdir/$hemi.orig.premesh --backup $sdir/$hemi.orig.premesh.noorient" + RunIt "$cmd" $LF $CMDF + cmd="$python ${binpath}rewrite_oriented_surface.py --file $sdir/$hemi.orig --backup $sdir/$hemi.orig.noorient" + RunIt "$cmd" $LF $CMDF -## -fix -cmd="recon-all -subject $subject -hemi $hemi -fix -autodetgwstats -white-preaparc -cortex-label -no-isrunning $hiresflag $fsthreads" -RunIt "$cmd" $LF $CMDF + cmd="recon-all -subject $subject -hemi $hemi -autodetgwstats -white-preaparc -cortex-label -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" ## copy nofix to orig and inflated for next step # -white 
(don't know how to call this from recon-all as it needs -whiteonly setting and by default it also creates the pial. # create first WM surface white.preaparc from topo fixed orig surf, also first cortex label (1min), (3min for deep learning surf) +# ============================= INFLATE2 - CURVHK =================================================== -echo "echo \" \"" |& tee -a $CMDF -echo "echo \"================== Creating surfaces $hemi - inflate2 ====================\"" |& tee -a $CMDF -echo "echo \" \"" |& tee -a $CMDF + { + echo "echo \"\"" + echo "echo \"================== Creating surfaces $hemi - inflate2 ====================\"" + echo "echo \"\"" + } | tee -a "$CMDF" + # create nicer inflated surface from topo fixed (not needed, just later for visualization) + cmd="recon-all -subject $subject -hemi $hemi -smooth2 -inflate2 -curvHK -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" -# create nicer inflated surface from topo fixed (not needed, just later for visualization) -cmd="recon-all -subject $subject -hemi $hemi -smooth2 -inflate2 -curvHK -no-isrunning $hiresflag $fsthreads" -RunIt "$cmd" $LF $CMDF +# ============================= MAP-DKT ========================================================== -echo "echo \" \"" |& tee -a $CMDF -echo "echo \"=========== Creating surfaces $hemi - map input asegdkt_segfile to surf ===============\"" |& tee -a $CMDF -echo "echo \" \"" |& tee -a $CMDF + { + echo "echo \" \"" + echo "echo \"=========== Creating surfaces $hemi - map input asegdkt_segfile to surf ===============\"" + echo "echo \" \"" + } | tee -a "$CMDF" - # sample input segmentation (aparc.DKTatlas+aseg orig) onto wm surface: - # map input aparc to surface (requires thickness (and thus pail) to compute projfrac 0.5), here we do projmm which allows us to compute based only on white - # this is dangerous, as some cortices could be < 0.6 mm, but then there is no volume label probably anyway. - # Also note that currently we cannot mask non-cortex regions here, should be done in mris_anatomical stats later - # the smoothing helps - cmd="mris_sample_parc -ct $FREESURFER_HOME/average/colortable_desikan_killiany.txt -file ${binpath}$hemi.DKTatlaslookup.txt -projmm 0.6 -f 5 -surf white.preaparc $subject $hemi aparc.DKTatlas+aseg.orig.mgz aparc.DKTatlas.mapped.prefix.annot" - RunIt "$cmd" $LF $CMDF + # sample input segmentation (aparc.DKTatlas+aseg orig) onto wm surface: + # map input aparc to surface (requires thickness (and thus pail) to compute projfrac 0.5), here we do projmm which allows us to compute based only on white + # this is dangerous, as some cortices could be < 0.6 mm, but then there is no volume label probably anyway. 
+ # Also note that currently we cannot mask non-cortex regions here, should be done in mris_anatomical stats later + # the smoothing helps + #cmd="mris_sample_parc -ct $FREESURFER_HOME/average/colortable_desikan_killiany.txt -file ${binpath}$hemi.DKTatlaslookup.txt -projmm 0.6 -f 5 -surf white.preaparc $subject $hemi aparc.DKTatlas+aseg.orig.mgz aparc.DKTatlas.mapped.prefix.annot" + #RunIt "$cmd" "$LF" "$CMDF" + #cmd="$python ${binpath}smooth_aparc.py --insurf $sdir/$hemi.white.preaparc --inaparc $ldir/$hemi.aparc.DKTatlas.mapped.prefix.annot --incort $ldir/$hemi.cortex.label --outaparc $ldir/$hemi.aparc.DKTatlas.mapped.annot" + #RunIt "$cmd" "$LF" "$CMDF" + cmd="$python ${binpath}sample_parc.py --inseg $mdir/aparc.DKTatlas+aseg.orig.mgz --insurf $sdir/$hemi.white.preaparc --incort $ldir/$hemi.cortex.label --outaparc $ldir/$hemi.aparc.DKTatlas.mapped.annot --seglut ${binpath}$hemi.DKTatlaslookup.txt --surflut ${binpath}DKTatlaslookup.txt --projmm 0.6 --radius 2" + RunIt "$cmd" "$LF" "$CMDF" - cmd="$python ${binpath}smooth_aparc.py --insurf $sdir/$hemi.white.preaparc --inaparc $ldir/$hemi.aparc.DKTatlas.mapped.prefix.annot --incort $ldir/$hemi.cortex.label --outaparc $ldir/$hemi.aparc.DKTatlas.mapped.annot" - RunIt "$cmd" $LF $CMDF +# ============================= SPHERE - SURFREG (optional) ============================================== -# if we segment with FS or if surface registration is requested do it here: -if [ "$fsaparc" == "1" ] || [ "$fssurfreg" == "1" ] ; then - echo "echo \" \"" |& tee -a $CMDF - echo "echo \"============ Creating surfaces $hemi - FS sphere, surfreg ===============\"" |& tee -a $CMDF - echo "echo \" \"" |& tee -a $CMDF + # if we segment with FS or if surface registration is requested do it here: + if [ "$fsaparc" == "1" ] || [ "$fssurfreg" == "1" ] ; then + { + echo "echo \" \"" + echo "echo \"============ Creating surfaces $hemi - FS sphere, surfreg ===============\"" + echo "echo \" \"" + } | tee -a "$CMDF" - # Surface registration for cross-subject correspondence (registration to fsaverage) - cmd="recon-all -subject $subject -hemi $hemi -sphere $hiresflag -no-isrunning $fsthreads" - RunIt "$cmd" $LF "$CMDF" - - # (mr) FIX: sometimes FreeSurfer Sphere Reg. fails and moves pre and post central - # one gyrus too far posterior, FastSurferCNN's image-based segmentation does not - # seem to do this, so we initialize the spherical registration with the better - # cortical segmentation from FastSurferCNN, this replaces recon-all -surfreg - # 1. get alpha, beta, gamma for global alignment (rotation) based on aseg centers - # (note the former fix, initializing with pre-central label, is not working in FS7.2 - # as they broke the label initialization in mris_register) - cmd="$python ${binpath}/rotate_sphere.py \ - --srcsphere $sdir/${hemi}.sphere \ - --srcaparc $ldir/$hemi.aparc.DKTatlas.mapped.annot \ - --trgsphere $FREESURFER_HOME/subjects/fsaverage/surf/${hemi}.sphere \ - --trgaparc $FREESURFER_HOME/subjects/fsaverage/label/${hemi}.aparc.annot \ - --out $sdir/${hemi}.angles.txt" - RunIt "$cmd" $LF "$CMDF" - # 2. 
use global rotation as initialization to non-linear registration: - cmd="mris_register -curv -norot -rotate \`cat $sdir/${hemi}.angles.txt\` \ - $sdir/${hemi}.sphere \ - $FREESURFER_HOME/average/${hemi}.folding.atlas.acfb40.noaparc.i12.2016-08-02.tif \ - $sdir/${hemi}.sphere.reg" - RunIt "$cmd" $LF "$CMDF" - # command to generate new aparc to check if registration was OK - # run only for debugging - #cmd="mris_ca_label -l $SUBJECTS_DIR/$subject/label/${hemi}.cortex.label \ - # -aseg $SUBJECTS_DIR/$subject/mri/aseg.presurf.mgz \ - # -seed 1234 $subject $hemi $SUBJECTS_DIR/$subject/surf/${hemi}.sphere.reg \ - # $SUBJECTS_DIR/$subject/label/${hemi}.aparc.DKTatlas-guided.annot" + # Surface registration for cross-subject correspondence (registration to fsaverage) + cmd="recon-all -subject $subject -hemi $hemi -sphere $hiresflag -no-isrunning $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" -fi - - -if [ "$fsaparc" == "1" ] ; then - - echo "echo \" \"" |& tee -a $CMDF - echo "echo \"============ Creating surfaces $hemi - FS asegdkt_segfile..pial ===============\"" |& tee -a $CMDF - echo "echo \" \"" |& tee -a $CMDF - - # 20-25 min for traditional surface segmentation (each hemi) - # this creates aparc and creates pial using aparc, also computes jacobian - cmd="recon-all -subject $subject -hemi $hemi -jacobian_white -avgcurv -cortparc -white -pial -no-isrunning $hiresflag $fsthreads" - RunIt "$cmd" $LF $CMDF - - # Here insert DoT2Pial later! + # (mr) FIX: sometimes FreeSurfer Sphere Reg. fails and moves pre and post central + # one gyrus too far posterior, FastSurferCNN's image-based segmentation does not + # seem to do this, so we initialize the spherical registration with the better + # cortical segmentation from FastSurferCNN, this replaces recon-all -surfreg + # 1. get alpha, beta, gamma for global alignment (rotation) based on aseg centers + # (note the former fix, initializing with pre-central label, is not working in FS7.2 + # as they broke the label initialization in mris_register) + cmd="$python ${binpath}/rotate_sphere.py \ + --srcsphere $sdir/${hemi}.sphere \ + --srcaparc $ldir/$hemi.aparc.DKTatlas.mapped.annot \ + --trgsphere $FREESURFER_HOME/subjects/fsaverage/surf/${hemi}.sphere \ + --trgaparc $FREESURFER_HOME/subjects/fsaverage/label/${hemi}.aparc.annot \ + --out $sdir/${hemi}.angles.txt" + RunIt "$cmd" "$LF" "$CMDF" + # 2. 
use global rotation as initialization to non-linear registration: + cmd="mris_register -curv -norot -rotate \`cat $sdir/${hemi}.angles.txt\` \ + $sdir/${hemi}.sphere \ + $FREESURFER_HOME/average/${hemi}.folding.atlas.acfb40.noaparc.i12.2016-08-02.tif \ + $sdir/${hemi}.sphere.reg" + RunIt "$cmd" "$LF" "$CMDF" + # command to generate new aparc to check if registration was OK + # run only for debugging + #cmd="mris_ca_label -l $SUBJECTS_DIR/$subject/label/${hemi}.cortex.label \ + # -aseg $SUBJECTS_DIR/$subject/mri/aseg.presurf.mgz \ + # -seed 1234 $subject $hemi $SUBJECTS_DIR/$subject/surf/${hemi}.sphere.reg \ + # $SUBJECTS_DIR/$subject/label/${hemi}.aparc.DKTatlas-guided.annot" + fi -else +# ============================= WHITE & PIAL & (FSSURFSEG optional) =============================================== - echo "echo \" \"" |& tee -a $CMDF - echo "echo \"================ Creating surfaces $hemi - white and pial direct ===================\"" |& tee -a $CMDF - echo "echo \" \"" |& tee -a $CMDF + if [ "$fsaparc" == "1" ] ; then + { + echo "echo \" \"" + echo "echo \"============ Creating surfaces $hemi - FS asegdkt_segfile..pial ===============\"" + echo "echo \" \"" + } | tee -a "$CMDF" + # 20-25 min for traditional surface segmentation (each hemi) + # this creates aparc and creates pial using aparc, also computes jacobian + cmd="recon-all -subject $subject -hemi $hemi -jacobian_white -avgcurv -cortparc -white -pial -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" + # Here insert DoT2Pial later! + else + { + echo "echo \" \"" + echo "echo \"================ Creating surfaces $hemi - white and pial direct ===================\"" + echo "echo \" \"" + } | tee -a "$CMDF" # 4 min compute white : - echo "pushd $mdir" >> $CMDF + echo "pushd $mdir > /dev/null" >> "$CMDF" cmd="mris_place_surface --adgws-in ../surf/autodet.gw.stats.$hemi.dat --seg aseg.presurf.mgz --wm wm.mgz --invol brain.finalsurfs.mgz --$hemi --i ../surf/$hemi.white.preaparc --o ../surf/$hemi.white --white --nsmooth 0 --rip-label ../label/$hemi.cortex.label --rip-bg --rip-surf ../surf/$hemi.white.preaparc --aparc ../label/$hemi.aparc.DKTatlas.mapped.annot" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" # 4 min compute pial : cmd="mris_place_surface --adgws-in ../surf/autodet.gw.stats.$hemi.dat --seg aseg.presurf.mgz --wm wm.mgz --invol brain.finalsurfs.mgz --$hemi --i ../surf/$hemi.white --o ../surf/$hemi.pial.T1 --pial --nsmooth 0 --rip-label ../label/$hemi.cortex+hipamyg.label --pin-medial-wall ../label/$hemi.cortex.label --aparc ../label/$hemi.aparc.DKTatlas.mapped.annot --repulse-surf ../surf/$hemi.white --white-surf ../surf/$hemi.white" - RunIt "$cmd" $LF $CMDF - echo "popd" >> $CMDF + RunIt "$cmd" "$LF" "$CMDF" + echo "popd > /dev/null" >> "$CMDF" # Here insert DoT2Pial later --> if T2pial is not run, need to softlink pial.T1 to pial! 
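# --- Editorial aside (illustration only, not part of the patch) ---------------------
# The block below replaces the former plain "ln -sf" calls with the softlink_or_copy
# helper. That helper is defined earlier in recon-surf.sh and is not part of this hunk;
# the function below is a hypothetical sketch of its assumed behaviour (symlink the
# source to the target, fall back to a copy, and either log directly or queue the
# command into a cmdf batch file when a fourth argument is given):
function softlink_or_copy_sketch()
{
  # $1 = source, $2 = target, $3 = log file, $4 = optional cmdf to append the command to
  local cmd="ln -sf $1 $2"
  if [[ $# -ge 4 ]]
  then
    { echo "echo \"$cmd\"" ; echo "$cmd || cp $1 $2" ; } >> "$4"
  else
    echo "$cmd" | tee -a "$3"
    ln -sf "$1" "$2" 2>&1 | tee -a "$3"
    if [[ ! -e "$2" ]] ; then cp "$1" "$2" 2>&1 | tee -a "$3" ; fi
  fi
}
# -------------------------------------------------------------------------------------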
- echo "pushd $sdir" >> $CMDF - cmd="ln -sf $hemi.pial.T1 $hemi.pial" - RunIt "$cmd" $LF $CMDF - echo "popd" >> $CMDF + echo "pushd $sdir > /dev/null" >> "$CMDF" + softlink_or_copy "$hemi.pial.T1" "$hemi.pial" "$LF" "$CMDF" + echo "popd > /dev/null" >> "$CMDF" - echo "pushd $mdir" >> $CMDF + echo "pushd $mdir > /dev/null" >> "$CMDF" # these are run automatically in fs7* recon-all and cannot be called directly without -pial flag (or other t2 flags) if [ "$fssurfreg" == "1" ] ; then # jacobian needs sphere reg which might be turned off by user (on by default) cmd="mris_jacobian ../surf/$hemi.white ../surf/$hemi.sphere.reg ../surf/$hemi.jacobian_white" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" fi cmd="mris_place_surface --curv-map ../surf/$hemi.white 2 10 ../surf/$hemi.curv" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" cmd="mris_place_surface --area-map ../surf/$hemi.white ../surf/$hemi.area" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" cmd="mris_place_surface --curv-map ../surf/$hemi.pial 2 10 ../surf/$hemi.curv.pial" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" cmd="mris_place_surface --area-map ../surf/$hemi.pial ../surf/$hemi.area.pial" - RunIt "$cmd" $LF $CMDF + RunIt "$cmd" "$LF" "$CMDF" cmd="mris_place_surface --thickness ../surf/$hemi.white ../surf/$hemi.pial 20 5 ../surf/$hemi.thickness" - RunIt "$cmd" $LF $CMDF - echo "popd" >> $CMDF -fi + RunIt "$cmd" "$LF" "$CMDF" + echo "popd > /dev/null" >> "$CMDF" + fi -# in FS7 curvstats moves here -cmd="recon-all -subject $subject -hemi $hemi -curvstats -no-isrunning $hiresflag $fsthreads" -RunIt "$cmd" $LF "$CMDF" +# ============================= CURVSTATS =============================================== + + # in FS7 curvstats moves here + cmd="recon-all -subject $subject -hemi $hemi -curvstats -no-isrunning $hiresflag $fsthreads" + RunIt "$cmd" "$LF" "$CMDF" -if [ "$DoParallel" == "0" ] ; then - echo " " |& tee -a $LF - echo " RUNNING $hemi sequentially ... " |& tee -a $LF - echo " " |& tee -a $LF - chmod u+x $CMDF - RunIt "$CMDF" $LF -fi + + + + if [ "$DoParallel" == "0" ] ; then + { + echo " " + echo " RUNNING $hemi sequentially ... " + echo " " + } | tee -a "$LF" + chmod u+x $CMDF + RunIt "$CMDF" $LF + fi done # hemi loop ---------------------------------- @@ -940,187 +831,288 @@ done # hemi loop ---------------------------------- if [ "$DoParallel" == 1 ] ; then - echo " " |& tee -a $LF - echo " RUNNING HEMIs in PARALLEL !!! " |& tee -a $LF - echo " " |& tee -a $LF - RunBatchJobs $LF $CMDFS + { + echo "" + echo " RUNNING HEMIs in PARALLEL !!! " + echo "" + } | tee -a "$LF" + RunBatchJobs "$LF" "${CMDFS[@]}" fi +# ============================= RIBBON =============================================== -echo " " |& tee -a $LF -echo "============================ Creating surfaces - ribbon ===========================" |& tee -a $LF -echo " " |& tee -a $LF +{ + echo "" + echo "============================ Creating surfaces - ribbon ===========================" + echo "" +} | tee -a "$LF" # -cortribbon 4 minutes, ribbon is used in mris_anatomical stats to remove voxels from surface based volumes that should not be cortex - # anatomical stats can run without ribon, but will omit some surface based measures then + # anatomical stats can run without ribbon, but will omit some surface based measures then # wmparc needs ribbon, probably other stuff (aparc to aseg etc). - # could be stripped but lets run it to have these measures below + # So lets run it to have these measures below. 
cmd="recon-all -subject $subject -cortribbon $hiresflag $fsthreads" - RunIt "$cmd" $LF - - + RunIt "$cmd" "$LF" -if [ "$fsaparc" == "1" ] ; then - echo " " |& tee -a $LF - echo "============= Creating surfaces - other FS asegdkt_segfile and stats =======================" |& tee -a $LF - echo " " |& tee -a $LF +# ============================= FSAPARC - parc23 surfcon hypo ... ========================================= - cmd="recon-all -subject $subject -cortparc2 -cortparc3 -pctsurfcon -hyporelabel $hiresflag $fsthreads" - RunIt "$cmd" $LF + if [ "$fsaparc" == "1" ] ; then + { + echo "" + echo "============= Creating surfaces - other FS asegdkt_segfile and stats =======================" + echo "" + } | tee -a "$LF" + cmd="recon-all -subject $subject -cortparc2 -cortparc3 -pctsurfcon -hyporelabel $hiresflag $fsthreads" + RunIt "$cmd" "$LF" - cmd="recon-all -subject $subject -apas2aseg -aparc2aseg -wmparc -parcstats -parcstats2 -parcstats3 -segstats $hiresflag $fsthreads" - RunIt "$cmd" $LF - # removed -balabels here and do that below independent of fsaparc flag + cmd="recon-all -subject $subject -apas2aseg -aparc2aseg -wmparc -parcstats -parcstats2 -parcstats3 $hiresflag $fsthreads" + RunIt "$cmd" "$LF" + # removed -balabels here and do that below independent of fsaparc flag + # removed -segstats here (now part of mri_segstats.py/segstats.py + fi # (FS-APARC) -fi # (FS-APARC) +# ============================= MAPPED SURF-STATS ========================================= -echo " " |& tee -a $LF -echo "===================== Creating surfaces - mapped stats =========================" |& tee -a $LF -echo " " |& tee -a $LF - - - # 2x18sec create stats from mapped aparc -for hemi in lh rh; do - cmd="mris_anatomical_stats -th3 -mgz -cortex $ldir/$hemi.cortex.label -f $sdir/../stats/$hemi.aparc.DKTatlas.mapped.stats -b -a $ldir/$hemi.aparc.DKTatlas.mapped.annot -c $ldir/aparc.annot.mapped.ctab $subject $hemi white" - RunIt "$cmd" $LF -done - - -if [ "$fsaparc" == "0" ] ; then - - echo " " |& tee -a $LF - echo "============= Creating surfaces - pctsurfcon, hypo, segstats ====================" |& tee -a $LF - echo " " |& tee -a $LF +{ + echo "" + echo "===================== Creating surfaces - mapped stats =========================" + echo "" +} | tee -a "$LF" - # pctsurfcon (has no way to specify which annot to use, so we need to link ours as aparc is not available) - pushd $ldir - cmd="ln -sf lh.aparc.DKTatlas.mapped.annot lh.aparc.annot" - RunIt "$cmd" $LF - cmd="ln -sf rh.aparc.DKTatlas.mapped.annot rh.aparc.annot" - RunIt "$cmd" $LF - popd + # 2x18sec create stats from mapped aparc for hemi in lh rh; do - cmd="pctsurfcon --s $subject --$hemi-only" - RunIt "$cmd" $LF + cmd="mris_anatomical_stats -th3 -mgz -cortex $ldir/$hemi.cortex.label -f $statsdir/$hemi.aparc.DKTatlas.mapped.stats -b -a $ldir/$hemi.aparc.DKTatlas.mapped.annot -c $ldir/aparc.annot.mapped.ctab $subject $hemi white" + RunIt "$cmd" "$LF" done - pushd $ldir - cmd="rm *h.aparc.annot" - RunIt "$cmd" $LF - popd - - # 25 sec hyporelabel run whatever else can be done without sphere, cortical ribbon and segmentations - # -hyporelabel creates aseg.presurf.hypos.mgz from aseg.presurf.mgz - # -apas2aseg creates aseg.mgz by editing aseg.presurf.hypos.mgz with surfaces - cmd="recon-all -subject $subject -hyporelabel -apas2aseg $hiresflag $fsthreads" - RunIt "$cmd" $LF - -fi -# creating aparc.DKTatlas+aseg.mapped.mgz by mapping aparc.DKTatlas.mapped from surface to aseg.mgz -# (should be a nicer aparc+aseg compared to orig CNN segmentation, due to 
surface updates) -cmd="mri_surf2volseg --o $mdir/aparc.DKTatlas+aseg.mapped.mgz --label-cortex --i $mdir/aseg.mgz --threads $threads --lh-annot $ldir/lh.aparc.DKTatlas.mapped.annot 1000 --lh-cortex-mask $ldir/lh.cortex.label --lh-white $sdir/lh.white --lh-pial $sdir/lh.pial --rh-annot $ldir/rh.aparc.DKTatlas.mapped.annot 2000 --rh-cortex-mask $ldir/rh.cortex.label --rh-white $sdir/rh.white --rh-pial $sdir/rh.pial" -RunIt "$cmd" $LF - +# ============================= FASTSURFER - surfcon hypo stats ========================================= + + if [ "$fsaparc" == "0" ] ; then + { + echo "" + echo "============= Creating surfaces - pctsurfcon, hypo, segstats ====================" + echo "" + } | tee -a "$LF" + # pctsurfcon (has no way to specify which annot to use, so we need to link ours as aparc is not available) + pushd "$ldir" > /dev/null || (echo "Could not cd to $ldir" ; exit 1) + softlink_or_copy "lh.aparc.DKTatlas.mapped.annot" "lh.aparc.annot" "$LF" + softlink_or_copy "rh.aparc.DKTatlas.mapped.annot" "rh.aparc.annot" "$LF" + popd > /dev/null || (echo "Could not popd" ; exit 1) + for hemi in lh rh; do + cmd="pctsurfcon --s $subject --$hemi-only" + RunIt "$cmd" "$LF" + done + pushd "$ldir" > /dev/null || (echo "Could not cd to $ldir" ; exit 1) + cmd="rm *h.aparc.annot" + RunIt "$cmd" "$LF" + popd > /dev/null || (echo "Could not popd" ; exit 1) + + # 25 sec hyporelabel run whatever else can be done without sphere, cortical ribbon and segmentations + # -hyporelabel creates aseg.presurf.hypos.mgz from aseg.presurf.mgz + # -apas2aseg creates aseg.mgz by editing aseg.presurf.hypos.mgz with surfaces + cmd="recon-all -subject $subject -hyporelabel -apas2aseg $hiresflag $fsthreads" + RunIt "$cmd" "$LF" + fi -if [ "$fsaparc" == "0" ] ; then - # get stats for the aseg (note these are surface fine tuned, that may be good or bad, below we also do the stats for the input aseg (plus some processing) - cmd="recon-all -subject $subject -segstats $hiresflag $fsthreads" - RunIt "$cmd" $LF +# ============================= MAPPED-TO-VOL ========================================= -fi + # creating aparc.DKTatlas+aseg.mapped.mgz by mapping aparc.DKTatlas.mapped from surface to aseg.mgz + # (should be a nicer aparc+aseg compared to orig CNN segmentation, due to surface updates) + cmd="mri_surf2volseg --o $mdir/aparc.DKTatlas+aseg.mapped.mgz --label-cortex --i $mdir/aseg.mgz --threads $threads --lh-annot $ldir/lh.aparc.DKTatlas.mapped.annot 1000 --lh-cortex-mask $ldir/lh.cortex.label --lh-white $sdir/lh.white --lh-pial $sdir/lh.pial --rh-annot $ldir/rh.aparc.DKTatlas.mapped.annot 2000 --rh-cortex-mask $ldir/rh.cortex.label --rh-white $sdir/rh.white --rh-pial $sdir/rh.pial" + RunIt "$cmd" "$LF" +# ============================= FASTSURFER - STATS ========================================= -echo " " |& tee -a $LF -echo "===================== Creating wmparc from mapped =======================" |& tee -a $LF -echo " " |& tee -a $LF + if [ "$fsaparc" == "0" ] ; then + # get stats for the aseg (note these are surface fine tuned, that may be good or bad, below we also do the stats for the input aseg (plus some processing) + # cmd="recon-all -subject $subject -segstats $hiresflag $fsthreads" + if [[ "$segstats_legacy" == "true" ]] ; then + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/mri_brainvol_stats.py" + --subject "$subject") + RunIt "$(echo_quoted "${cmd[@]}")" "$LF" - # 1m 11sec also create stats for aseg.presurf.hypos (which is basically the aseg derived from the input with CC and hypos) - # difference between 
this and the surface improved one above are probably tiny, so the surface improvement above can probably be skipped to save time - cmd="mri_segstats --seed 1234 --seg $mdir/aseg.presurf.hypos.mgz --sum $mdir/../stats/aseg.presurf.hypos.stats --pv $mdir/norm.mgz --empty --brainmask $mdir/brainmask.mgz --brain-vol-from-seg --excludeid 0 --excl-ctxgmwm --supratent --subcortgray --in $mdir/norm.mgz --in-intensity-name norm --in-intensity-units MR --etiv --surf-wm-vol --surf-ctx-vol --totalgray --euler --ctab /$FREESURFER_HOME/ASegStatsLUT.txt --subject $subject" - RunIt "$cmd" $LF + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/mri_segstats.py" --seed 1234 + --seg "$mdir/aseg.mgz" --sum "$statsdir/aseg.stats" --pv "$mdir/norm.mgz" + "--in-intensity-name" norm "--in-intensity-units" MR --subject "$subject" + --surf-wm-vol --ctab "$FREESURFER_HOME/ASegStatsLUT.txt" --etiv + --threads "$threads") +# cmd="$python $FASTSURFER_HOME/FastSurferCNN/mri_segstats.py --seed 1234 --seg $mdir/wmparc.mgz --sum $statsdir/wmparc.stats --pv $mdir/norm.mgz --in-intensity-name norm --in-intensity-units MR --subject $subject --surf-wm-vol --ctab $FREESURFER_HOME/WMParcStatsLUT.txt --etiv" + else + # calculate brainvol stats and aseg stats with segstats.py + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/segstats.py" --sid "$subject" + --segfile "$mdir/aseg.mgz" --segstatsfile "$statsdir/aseg.stats" + --pvfile "$mdir/norm.mgz" --normfile "$mdir/norm.mgz" --threads "$threads" + # --excl-ctxgmwm: exclude Left/Right WM / Cortex despite ASegStatsLUT.txt + --excludeid 0 2 3 41 42 + --lut "$FREESURFER_HOME/ASegStatsLUT.txt" --empty + measures --compute "BrainSeg" "BrainSegNotVent" "VentricleChoroidVol" + "lhCortex" "rhCortex" "Cortex" "lhCerebralWhiteMatter" + "rhCerebralWhiteMatter" "CerebralWhiteMatter" + "SubCortGray" "TotalGray" "SupraTentorial" + "SupraTentorialNotVent" "Mask($mdir/mask.mgz)" + "BrainSegVol-to-eTIV" "MaskVol-to-eTIV" "lhSurfaceHoles" + "rhSurfaceHoles" "SurfaceHoles" + "EstimatedTotalIntraCranialVol") + RunIt "$(echo_quoted "${cmd[@]}")" "$LF" + echo "Extract the brainvol stats section from segstats output." | tee -a "$LF" + # ... 
so stats/brainvol.stats also exists (but it is slightly different +# cmd="recon-all -subject $subject -segstats $hiresflag $fsthreads" +# RunIt "$cmd" "$LF" + + # this call is only "required" to "compute" brainvol.stats, so --normfile/--pvfile + # are not required + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/segstats.py" --sid "$subject" + --segfile "$mdir/aseg.mgz" --pvfile "$mdir/norm.mgz" + --measure_only --threads "$threads" --segstatsfile "$statsdir/brainvol.stats" + measures --file "$statsdir/aseg.stats" + --import "BrainSeg" "BrainSegNotVent" "SupraTentorial" + "SupraTentorialNotVent" "SubCortGray" "lhCortex" "rhCortex" + "Cortex" "TotalGray" "lhCerebralWhiteMatter" + "rhCerebralWhiteMatter" "CerebralWhiteMatter" "Mask" + --compute "SupraTentorialNotVentVox" "BrainSegNotVentSurf" + "VentricleChoroidVol") + fi + RunIt "$(echo_quoted "${cmd[@]}")" "$LF" + fi +# ============================= MAPPED-WMPARC ========================================= +{ + echo "" + echo "===================== Creating wmparc from mapped =======================" + echo "" +} | tee -a "$LF" + + if [[ "$segstats_legacy" == "true" ]] ; then + # 1m 11sec also create stats for aseg.presurf.hypos (which is basically the aseg derived from the input with CC and + # hypos) difference between this and the surface improved one above are probably tiny, so the surface improvement + # above can probably be skipped to save time + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/mri_segstats.py" --seed 1234 + --seg "$mdir/aseg.presurf.hypos.mgz" --sum "$statsdir/aseg.presurf.hypos.stats" + --pv "$mdir/norm.mgz" --empty --brainmask "$mdir/brainmask.mgz" + --brain-vol-from-seg --excludeid 0 --excl-ctxgmwm --supratent --subcortgray + "--in" "$mdir/norm.mgz" "--in-intensity-name" norm "--in-intensity-units" MR + --etiv --surf-wm-vol --surf-ctx-vol --totalgray --euler + --ctab "$FREESURFER_HOME/ASegStatsLUT.txt" --subject "$subject") + else + # segstats.py version of the mri_segstats call + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/segstats.py" --sid "$subject" + --segfile "$mdir/aseg.presurf.hypos.mgz" --normfile "$mdir/norm.mgz" + --pvfile "$mdir/norm.mgz" --segstatsfile "$statsdir/aseg.presurf.hypos.stats" + # --excl-ctxgmwm: exclude Left/Right WM / Cortex despite ASegStatsLUT.txt + --excludeid 0 2 3 41 42 + --lut "$FREESURFER_HOME/ASegStatsLUT.txt" --threads "$threads" --empty + --volume_precision 1 + measures --file "$statsdir/aseg.stats" --import "all") + fi + RunIt "$(echo_quoted "${cmd[@]}")" "$LF" # -wmparc based on mapped aparc labels (from input asegdkt_segfile) (1min40sec) needs ribbon and we need to point it to aparc.mapped: cmd="mri_surf2volseg --o $mdir/wmparc.DKTatlas.mapped.mgz --label-wm --i $mdir/aparc.DKTatlas+aseg.mapped.mgz --threads $threads --lh-annot $ldir/lh.aparc.DKTatlas.mapped.annot 3000 --lh-cortex-mask $ldir/lh.cortex.label --lh-white $sdir/lh.white --lh-pial $sdir/lh.pial --rh-annot $ldir/rh.aparc.DKTatlas.mapped.annot 4000 --rh-cortex-mask $ldir/rh.cortex.label --rh-white $sdir/rh.white --rh-pial $sdir/rh.pial" - RunIt "$cmd" $LF + RunIt "$cmd" "$LF" # takes a few mins - cmd="mri_segstats --seed 1234 --seg $mdir/wmparc.DKTatlas.mapped.mgz --sum $mdir/../stats/wmparc.DKTatlas.mapped.stats --pv $mdir/norm.mgz --excludeid 0 --brainmask $mdir/brainmask.mgz --in $mdir/norm.mgz --in-intensity-name norm --in-intensity-units MR --subject $subject --surf-wm-vol --ctab $FREESURFER_HOME/WMParcStatsLUT.txt" - RunIt "$cmd" $LF - -# Create symlinks for downstream analysis (sub-segmentations, TRACULA, etc.) 
-if [ "$fsaparc" == "0" ] ; then - # Symlink of aparc.DKTatlas+aseg.mapped.mgz - pushd $mdir - cmd="ln -sf aparc.DKTatlas+aseg.mapped.mgz aparc.DKTatlas+aseg.mgz" - RunIt "$cmd" $LF - cmd="ln -sf aparc.DKTatlas+aseg.mapped.mgz aparc+aseg.mgz" - RunIt "$cmd" $LF - popd - - # Symlink of wmparc.mapped - pushd $mdir - cmd="ln -sf wmparc.DKTatlas.mapped.mgz wmparc.mgz" - RunIt "$cmd" $LF - popd - - # Symbolic link for mapped surface parcellations - pushd $ldir - cmd="ln -sf lh.aparc.DKTatlas.mapped.annot lh.aparc.DKTatlas.annot" - RunIt "$cmd" $LF - cmd="ln -sf rh.aparc.DKTatlas.mapped.annot rh.aparc.DKTatlas.annot" - RunIt "$cmd" $LF -fi + #cmd="mri_segstats --seed 1234 --seg $mdir/wmparc.DKTatlas.mapped.mgz --sum $mdir/../stats/wmparc.DKTatlas.mapped.stats --pv $mdir/norm.mgz --excludeid 0 --brainmask $mdir/brainmask.mgz --in $mdir/norm.mgz --in-intensity-name norm --in-intensity-units MR --subject $subject --surf-wm-vol --ctab $FREESURFER_HOME/WMParcStatsLUT.txt" + if [[ "$segstats_legacy" == "true" ]] ; then + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/mri_segstats.py" + --seed 1234 --seg "$mdir/wmparc.DKTatlas.mapped.mgz" + --sum "$statsdir/wmparc.DKTatlas.mapped.stats" --pv "$mdir/norm.mgz" + --excludeid 0 --brainmask "$mdir/brainmask.mgz" "--in" "$mdir/norm.mgz" + "--in-intensity-name" norm "--in-intensity-units" MR + --subject "$subject" --surf-wm-vol + --ctab "$FREESURFER_HOME/WMParcStatsLUT.txt") + else + # + cmd=($python "$FASTSURFER_HOME/FastSurferCNN/segstats.py" + --sid "$subject" --sd "$SUBJECTS_DIR" --pvfile "$mdir/norm.mgz" + --segfile "$mdir/wmparc.DKTatlas.mapped.mgz" --normfile "$mdir/norm.mgz" + --lut "$FREESURFER_HOME/WMParcStatsLUT.txt" --threads "$threads" + --segstatsfile "$statsdir/wmparc.DKTatlas.mapped.stats" --empty + --volume_precision 1 + measures --file "$statsdir/brainvol.stats" --import "Mask" + "VentricleChoroidVol" "rhCerebralWhiteMatter" "lhCerebralWhiteMatter" + "CerebralWhiteMatter") + fi + RunIt "$(echo_quoted "${cmd[@]}")" "$LF" +# ============================= FASTSURFER - SYMLINKS ========================================= -# balabels need sphere.reg -if [ "$fssurfreg" == "1" ] ; then - # can be produced if surf registration exists - #cmd="recon-all -subject $subject -balabels $hiresflag $fsthreads" - #RunIt "$cmd" $LF - # here we run our version of balabels: mapping and annot creation is very fast - # time is used in mris_anatomical_stats (called 4 times, BA and BA-thresh for each hemi) - cmd="$python ${binpath}/fs_balabels.py --sd $SUBJECTS_DIR --sid $subject" - RunIt "$cmd" $LF -fi + # Create symlinks for downstream analysis (sub-segmentations, TRACULA, etc.) 
+ if [ "$fsaparc" == "0" ] ; then + # Symlink of aparc.DKTatlas+aseg.mapped.mgz + pushd "$mdir" > /dev/null || (echo "Could not cd to $mdir" ; exit 1) + softlink_or_copy "aparc.DKTatlas+aseg.mapped.mgz" "aparc.DKTatlas+aseg.mgz" "$LF" + softlink_or_copy "aparc.DKTatlas+aseg.mapped.mgz" "aparc+aseg.mgz" "$LF" + # Symlink of wmparc.mapped + softlink_or_copy "wmparc.DKTatlas.mapped.mgz" "wmparc.mgz" "$LF" + popd > /dev/null || ( echo "Could not popd" ; exit 1 ) + # Symbolic link for mapped surface parcellations + pushd "$ldir" > /dev/null || (echo "Could not cd to $ldir" ; exit 1) + softlink_or_copy "lh.aparc.DKTatlas.mapped.annot" "lh.aparc.DKTatlas.annot" "$LF" + softlink_or_copy "rh.aparc.DKTatlas.mapped.annot" "rh.aparc.DKTatlas.annot" "$LF" + popd > /dev/null || ( echo "Could not popd" ; exit 1 ) + fi + + +# ============================= BALABELS ========================================= + + # balabels need sphere.reg + if [ "$fssurfreg" == "1" ] ; then + # can be produced if surf registration exists + #cmd="recon-all -subject $subject -balabels $hiresflag $fsthreads" + #RunIt "$cmd" "$LF" + # here we run our version of balabels: mapping and annot creation is very fast + # time is used in mris_anatomical_stats (called 4 times, BA and BA-thresh for each hemi) + cmd="$python ${binpath}/fs_balabels.py --sd $SUBJECTS_DIR --sid $subject" + RunIt "$cmd" "$LF" + fi -echo " " |& tee -a $LF -echo "================= DONE =========================================================" |& tee -a $LF -echo " " |& tee -a $LF # Collect info -EndTime=`date` -tSecEnd=`date '+%s'` -tRunHours=`echo \($tSecEnd - $tSecStart\)/3600|bc -l` -tRunHours=`printf %6.3f $tRunHours` +EndTime=$(date) +tSecEnd=$(date '+%s') +tRunHours=$(printf %6.3f "$(bc <<< "($tSecEnd - $tSecStart) / 3600")") -echo "Started at $StartTime " |& tee -a $LF -echo "Ended at $EndTime" |& tee -a $LF -echo "#@#%# recon-surf-run-time-hours $tRunHours" |& tee -a $LF +{ + echo "" + echo "================= DONE =========================================================" + echo "" + + echo "Started at $StartTime" + echo "Ended at $EndTime" + echo "#@#%# recon-surf-run-time-hours $tRunHours" +} | tee -a "$LF" # Create the Done File -echo "------------------------------" > $DoneFile -echo "SUBJECT $subject" >> $DoneFile -echo "START_TIME $StartTime" >> $DoneFile -echo "END_TIME $EndTime" >> $DoneFile -echo "RUNTIME_HOURS $tRunHours" >> $DoneFile -echo "USER `id -un`" >> $DoneFile -echo "HOST `hostname`" >> $DoneFile -echo "PROCESSOR `uname -m`" >> $DoneFile -echo "OS `uname -s`" >> $DoneFile -echo "UNAME `uname -a`" >> $DoneFile -echo "VERSION $VERSION" >> $DoneFile -echo "CMDPATH $0" >> $DoneFile -echo "CMDARGS ${inputargs[*]}" >> $DoneFile - -echo "recon-surf.sh $subject finished without error at `date`" |& tee -a $LF +{ + echo "------------------------------" + echo "SUBJECT $subject" + echo "START_TIME $StartTime" + echo "END_TIME $EndTime" + echo "RUNTIME_HOURS $tRunHours" + # id -n sends an error message in docker (no user name), fall back to the USER environment variable or + username=$(id -un 2>&1) + if echo "$username" | grep -q "^id: " ; then + if [[ -n "$USER" ]] ; then username="$USER" + else username="$(id -u)" + fi + fi + echo "USER $username" + echo "HOST $(hostname)" + echo "PROCESSOR $(uname -m)" + echo "OS $(uname -s)" + echo "UNAME $(uname -a)" + echo "VERSION $VERSION" + echo "CMDPATH $0" + echo "CMDARGS ${inputargs[*]}" +} > "$DoneFile" +echo "recon-surf.sh $subject finished without error at $(date)" | tee -a "$LF" cmd="$python 
${binpath}utils/extract_recon_surf_time_info.py -i $LF -o $SUBJECTS_DIR/$subject/scripts/recon-surf_times.yaml" RunIt "$cmd" "/dev/null" diff --git a/recon_surf/recon-surfreg.sh b/recon_surf/recon-surfreg.sh index 0d9b88d0..87175a2c 100755 --- a/recon_surf/recon-surfreg.sh +++ b/recon_surf/recon-surfreg.sh @@ -19,7 +19,7 @@ FS_VERSION_SUPPORT="7.3.2" # Regular flags default subject=""; # Subject name -python="python3.8" # python version +python="python3.10" # python version DoParallel=0 # if 1, run hemispheres in parallel threads="1" # number of threads to use for running FastSurfer allow_root="" # flag for allowing execution as root user @@ -109,12 +109,12 @@ function RunIt() if [[ $# -eq 3 ]] then CMDF=$3 - echo "echo \"$cmd\" " |& tee -a $CMDF - echo "$timecmd $cmd " |& tee -a $CMDF + echo "echo \"$cmd\" " | tee -a $CMDF + echo "$timecmd $cmd " | tee -a $CMDF echo "if [ \${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi" >> $CMDF else - echo $cmd |& tee -a $LF - $timecmd $cmd |& tee -a $LF + echo $cmd | tee -a $LF + $timecmd $cmd 2>&1 | tee -a $LF #if [ ${PIPESTATUS[0]} -ne 0 ] ; then exit 1 ; fi fi } @@ -356,28 +356,28 @@ if [ $DoneFile != /dev/null ] ; then rm -f $DoneFile ; fi LF=$SUBJECTS_DIR/$subject/scripts/recon-surfreg.log if [ $LF != /dev/null ] ; then rm -f $LF ; fi echo "Log file for recon-surfreg.sh" >> $LF -date |& tee -a $LF -echo "" |& tee -a $LF -echo "export SUBJECTS_DIR=$SUBJECTS_DIR" |& tee -a $LF -echo "cd `pwd`" |& tee -a $LF -echo $0 ${inputargs[*]} |& tee -a $LF -echo "" |& tee -a $LF -cat $FREESURFER_HOME/build-stamp.txt |& tee -a $LF -echo $VERSION |& tee -a $LF -uname -a |& tee -a $LF +date 2>&1 | tee -a $LF +echo "" | tee -a $LF +echo "export SUBJECTS_DIR=$SUBJECTS_DIR" | tee -a $LF +echo "cd `pwd`" | tee -a $LF +echo $0 ${inputargs[*]} | tee -a $LF +echo "" | tee -a $LF +cat $FREESURFER_HOME/build-stamp.txt 2>&1 | tee -a $LF +echo $VERSION | tee -a $LF +uname -a 2>&1 | tee -a $LF # Print parallelization parameters -echo " " |& tee -a $LF +echo " " | tee -a $LF if [ "$DoParallel" == "1" ] then - echo " RUNNING both hemis in PARALLEL " |& tee -a $LF + echo " RUNNING both hemis in PARALLEL " | tee -a $LF else - echo " RUNNING both hemis SEQUENTIALLY " |& tee -a $LF + echo " RUNNING both hemis SEQUENTIALLY " | tee -a $LF fi -echo " RUNNING $OMP_NUM_THREADS number of OMP THREADS " |& tee -a $LF -echo " RUNNING $ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS number of ITK THREADS " |& tee -a $LF -echo " " |& tee -a $LF +echo " RUNNING $OMP_NUM_THREADS number of OMP THREADS " | tee -a $LF +echo " RUNNING $ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS number of ITK THREADS " | tee -a $LF +echo " " | tee -a $LF #if false; then @@ -396,9 +396,9 @@ for hemi in lh rh; do CMDFS="$CMDFS $CMDF" rm -rf $CMDF - echo "echo \" \"" |& tee -a $CMDF - echo "echo \"============ Creating surfaces $hemi - FS sphere, surfreg ===============\"" |& tee -a $CMDF - echo "echo \" \"" |& tee -a $CMDF + echo "echo \" \"" | tee -a $CMDF + echo "echo \"============ Creating surfaces $hemi - FS sphere, surfreg ===============\"" | tee -a $CMDF + echo "echo \" \"" | tee -a $CMDF # Surface registration for cross-subject correspondence (registration to fsaverage) cmd="recon-all -subject $subject -hemi $hemi -sphere -no-isrunning $fsthreads" @@ -432,9 +432,9 @@ for hemi in lh rh; do # $SUBJECTS_DIR/$subject/label/${hemi}.aparc.DKTatlas-guided.annot" if [ "$DoParallel" == "0" ] ; then - echo " " |& tee -a $LF - echo " RUNNING $hemi sequentially ... 
" |& tee -a $LF - echo " " |& tee -a $LF + echo " " | tee -a $LF + echo " RUNNING $hemi sequentially ... " | tee -a $LF + echo " " | tee -a $LF chmod u+x $CMDF RunIt "$CMDF" $LF fi @@ -444,16 +444,16 @@ done # hemi loop ---------------------------------- if [ "$DoParallel" == 1 ] ; then - echo " " |& tee -a $LF - echo " RUNNING HEMIs in PARALLEL !!! " |& tee -a $LF - echo " " |& tee -a $LF + echo " " | tee -a $LF + echo " RUNNING HEMIs in PARALLEL !!! " | tee -a $LF + echo " " | tee -a $LF RunBatchJobs $LF $CMDFS fi -echo " " |& tee -a $LF -echo "================= DONE =========================================================" |& tee -a $LF -echo " " |& tee -a $LF +echo " " | tee -a $LF +echo "================= DONE =========================================================" | tee -a $LF +echo " " | tee -a $LF # Collect info EndTime=`date` @@ -461,9 +461,9 @@ tSecEnd=`date '+%s'` tRunHours=`echo \($tSecEnd - $tSecStart\)/3600|bc -l` tRunHours=`printf %6.3f $tRunHours` -echo "Started at $StartTime " |& tee -a $LF -echo "Ended at $EndTime" |& tee -a $LF -echo "#@#%# recon-surfreg-run-time-hours $tRunHours" |& tee -a $LF +echo "Started at $StartTime " | tee -a $LF +echo "Ended at $EndTime" | tee -a $LF +echo "#@#%# recon-surfreg-run-time-hours $tRunHours" | tee -a $LF # Create the Done File echo "------------------------------" > $DoneFile @@ -471,7 +471,7 @@ echo "SUBJECT $subject" >> $DoneFile echo "START_TIME $StartTime" >> $DoneFile echo "END_TIME $EndTime" >> $DoneFile echo "RUNTIME_HOURS $tRunHours" >> $DoneFile -echo "USER `id -un`" >> $DoneFile +echo "USER `id -un`" >> $DoneFile 2> /dev/null echo "HOST `hostname`" >> $DoneFile echo "PROCESSOR `uname -m`" >> $DoneFile echo "OS `uname -s`" >> $DoneFile @@ -480,7 +480,7 @@ echo "VERSION $VERSION" >> $DoneFile echo "CMDPATH $0" >> $DoneFile echo "CMDARGS ${inputargs[*]}" >> $DoneFile -echo "recon-surfreg.sh $subject finished without error at `date`" |& tee -a $LF +echo "recon-surfreg.sh $subject finished without error at `date`" | tee -a $LF cmd="$python ${binpath}utils/extract_recon_surf_time_info.py -i $LF -o $SUBJECTS_DIR/$subject/scripts/recon-surfreg_times.yaml" RunIt "$cmd" "/dev/null" diff --git a/recon_surf/rewrite_mc_surface.py b/recon_surf/rewrite_mc_surface.py index 8372f94f..bd13429e 100644 --- a/recon_surf/rewrite_mc_surface.py +++ b/recon_surf/rewrite_mc_surface.py @@ -21,14 +21,13 @@ def options_parse(): - """Command line option parser. + """ + Create a command line interface and return command line options. Returns ------- options - object holding options - - + Namespace object holding options. """ parser = optparse.OptionParser( version="$Id: rewrite_mc_surface,v 1.1 2020/06/23 15:42:08 henschell $", @@ -54,19 +53,21 @@ def options_parse(): def resafe_surface(insurf: str, outsurf: str, pretess: str) -> None: - """Take path to insurf and rewrite it to outsurf thereby fixing vertex locs flag error. + """ + Take path to insurf and rewrite it to outsurf thereby fixing vertex locs flag error. - (scannerRAS instead of surfaceRAS after marching cube) + This function fixes header information not properly saved in marching cubes. + It makes sure the file header correctly references the scannerRAS instead of the surfaceRAS, + i.e. filename and volume is set to the correct data in the header. Parameters ---------- insurf : str - Path and name of input surface + Path and name of input surface. outsurf : str - Path and name of output surface + Path and name of output surface. 
pretess : str - Path and name of file the input surface was created on (e.g. filled-pretess127.mgz) - + Path and name of file the input surface was created on (e.g. filled-pretess127.mgz). """ surf = fs.read_geometry(insurf, read_metadata=True) diff --git a/recon_surf/rewrite_oriented_surface.py b/recon_surf/rewrite_oriented_surface.py new file mode 100644 index 00000000..fc9c4a37 --- /dev/null +++ b/recon_surf/rewrite_oriented_surface.py @@ -0,0 +1,124 @@ +# Copyright 2024 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import shutil +# IMPORTS +import sys +import argparse +from pathlib import Path + +import lapy + +__version__ = "1.0" + + +def make_parser() -> argparse.ArgumentParser: + """ + Create a command line interface and return command line options. + + Returns + ------- + options + Namespace object holding options. + """ + parser = argparse.ArgumentParser( + description="Script to load and safe surface (that are guaranteed to be " + "correctly oriented) under a given name", + ) + parser.add_argument( + "--file", "-f", + type=Path, + dest="file", + help="path to surface to check and fix", + required=True, + ) + parser.add_argument( + "--backup", + type=Path, + dest="backup", + help="if the surface is corrupted, create a backup of the original surface. " + "Default: no backup.", + default=None, + ) + parser.add_argument( + "--version", + action="version", + version=f"{__version__} 2024/08/08 12:20:10 kueglerd", + ) + return parser + + +def resafe_surface( + surface_file: Path | str, + surface_backup: Path | str | None = None, +) -> bool: + """ + Take path to surface_file and rewrite it to fix improperly oriented triangles. + + If the surface is not oriented and surface_backup is set, rename the old + surface_file to surface_backup. Else just overwrite with the corrected surface. + + Parameters + ---------- + surface_file : Path, str + Path and name of input surface. + surface_backup : Path, str, optional + Path and name of output surface. + + Returns + ------- + bool + Whether the surface was rewritten. 
+ """ + import getpass + try: + getpass.getuser() + except Exception: + # nibabel crashes in write_geometry, if getpass.getuser does not return + # make sure the process has a username + from os import environ + environ.setdefault("USERNAME", "UNKNOWN") + + triamesh = lapy.TriaMesh.read_fssurf(str(surface_file)) + fsinfo = triamesh.fsinfo + + # make sure the triangles are oriented (normals pointing to the same direction + if not triamesh.is_oriented(): + if surface_backup is not None: + print(f"Renaming {surface_file} to {surface_backup}") + shutil.move(surface_file, surface_backup) + + print("Surface was not oriented, flipping corrupted normals.") + triamesh.orient_() + + from packaging.version import Version + if Version(lapy.__version__) <= Version("1.0.1"): + print(f"lapy version {lapy.__version__}<=1.0.1 detected, fixing fsinfo.") + triamesh.fsinfo = fsinfo + + triamesh.write_fssurf(str(surface_file)) + return True + else: + print("Surface was already oriented.") + return False + + +if __name__ == "__main__": + # Command Line options are error checking done here + parser = make_parser() + args = parser.parse_args() + + print(f"Reading in surface: {args.file} ...") + if resafe_surface(args.file, args.backup): + print(f"Outputting surface as: {args.file}") + sys.exit(0) diff --git a/recon_surf/rotate_sphere.py b/recon_surf/rotate_sphere.py index c61bc140..b944b376 100644 --- a/recon_surf/rotate_sphere.py +++ b/recon_surf/rotate_sphere.py @@ -18,10 +18,12 @@ # IMPORTS import optparse +import sys + import numpy as np from numpy import typing as npt -import sys import nibabel.freesurfer.io as fs + import align_points as align @@ -35,7 +37,7 @@ --out Dependencies: - Python 3.8 + Python 3.8+ SimpleITK https://simpleitk.org/ (v2.1.1) @@ -66,13 +68,13 @@ def options_parse(): - """Command line option parser. + """ + Create a command line interface and return command line options. Returns ------- - options - object holding options - + options : argparse.Namespace + Namespace object holding options. """ parser = optparse.OptionParser( version="$Id: rotate_sphere.py,v 1.0 2022/03/18 21:22:08 mreuter Exp $", @@ -106,7 +108,8 @@ def align_aparc_centroids( labels_dst: npt.ArrayLike, label_ids: npt.ArrayLike = [] ) -> np.ndarray: - """Align centroid of aparc parcels on the sphere (Attention mapping back to sphere!). + """ + Align centroid of aparc parcels on the sphere (Attention mapping back to sphere!). Parameters ---------- @@ -119,13 +122,12 @@ def align_aparc_centroids( labels_dst : npt.ArrayLike Labels of aparc parcelation for rotation destination. label_ids : npt.ArrayLike - Ids of the centroid to be aligned. Defaults to [] + Ids of the centroid to be aligned. Defaults to []. Returns ------- - R - Rotation Matrix - + R : npt.NDArray[float] + Rotation Matrix. """ #nferiorparietal, inferiortemporal, lateraloccipital, postcentral, posteriorsingulate # precentral, precuneus, superiorfrontal, supramarginal diff --git a/recon_surf/sample_parc.py b/recon_surf/sample_parc.py new file mode 100644 index 00000000..3d784ad2 --- /dev/null +++ b/recon_surf/sample_parc.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 + + +# Copyright 2024 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS +import optparse +import sys + +import numpy as np +import nibabel.freesurfer.io as fs +import nibabel as nib +from scipy import sparse +from scipy.sparse.csgraph import connected_components +from lapy import TriaMesh + +from smooth_aparc import smooth_aparc + + +HELPTEXT = """ +Script to sample labels from image to surface and clean up. + +USAGE: +sample_parc --inseg --insurf --incort + --seglut --surflut --outaparc + --projmm --radius + + +Dependencies: + Python 3.8 + + Numpy + http://www.numpy.org + + Nibabel to read and write FreeSurfer surface meshes + http://nipy.org/nibabel/ + + +Original Author: Martin Reuter +Date: Dec-18-2023 + +""" + +h_inseg = "path to input segmentation image" +h_incort = "path to input cortex label mask" +h_insurf = "path to input surface" +h_outaparc = "path to output aparc" +h_surflut = "FreeSurfer look-up-table for values on surface" +h_seglut = "Look-up-table for values in segmentation image (rows need to correspond to surflut)" +h_projmm = "Sample along normal at projmm distance (in mm), default 0" +h_radius = "Search around sample location at radius (in mm) for label if 'unknown', default None" + + +def options_parse(): + """ + Create a command line interface and return command line options. + + Returns + ------- + options : argparse.Namespace + Namespace object holding options. + """ + parser = optparse.OptionParser( + version="$Id: smooth_aparc,v 1.0 2018/06/24 11:34:08 mreuter Exp $", + usage=HELPTEXT, + ) + parser.add_option("--inseg", dest="inseg", help=h_inseg) + parser.add_option("--insurf", dest="insurf", help=h_insurf) + parser.add_option("--incort", dest="incort", help=h_incort) + parser.add_option("--surflut", dest="surflut", help=h_surflut) + parser.add_option("--seglut", dest="seglut", help=h_seglut) + parser.add_option("--outaparc", dest="outaparc", help=h_outaparc) + parser.add_option("--projmm", dest="projmm", help=h_projmm, default=0.0, type="float") + parser.add_option("--radius", dest="radius", help=h_radius, default=None, type="float") + (options, args) = parser.parse_args() + + if options.insurf is None or options.inseg is None or options.outaparc is None: + sys.exit("ERROR: Please specify input surface, input image and output aparc!") + + if options.surflut is None or options.seglut is None: + sys.exit("ERROR: Please specify surface and segmentatin image LUT!") + + # maybe later add functionality, to not have a cortex label, e.g. + # like FreeSurfer find largest connected component and fill only + # the other unknown regions + if options.incort is None: + sys.exit("ERROR: Please specify surface cortex label!") + + return options + +def construct_adj_cluster(tria, annot): + """ + Compute adjacency matrix of edges from same annotation label only. + + Operates only on triangles and removes edges that cross annotation + label boundaries. + + Returns + ------- + csc_matrix + The non-directed adjacency matrix + will be symmetric. Each inner edge (i,j) will have + the number of triangles that contain this edge. + Inner edges usually 2, boundary edges 1. 
Higher + numbers can occur when there are non-manifold triangles. + The sparse matrix can be binarized via: + adj.data = np.ones(adj.data.shape). + """ + t0 = tria[:, 0] + t1 = tria[:, 1] + t2 = tria[:, 2] + i = np.column_stack((t0, t1, t1, t2, t2, t0)).reshape(-1) + j = np.column_stack((t1, t0, t2, t1, t0, t2)).reshape(-1) + ia = annot[i] + ja = annot[j] + keep_edges = (ia == ja) + i = i[keep_edges] + j = j[keep_edges] + dat = np.ones(i.shape) + n = annot.shape[0] + return sparse.csc_matrix((dat, (i, j)), shape=(n, n)) + +def find_all_islands(surf, annot): + """ + Find vertices in disconnected islands for all labels in surface annotation. + + Parameters + ---------- + surf : tuple + Surface as returned by nibabel fs.read_geometry, where: + surf[0] is the np.array of (n, 3) vertex coordinates and + surf[1] is the np.array of (m, 3) triangle indices. + annot : np.ndarray + Annotation as an int array of (n,) with label ids for each vertex. + This is for example the first element of the tupel returned by + nibabel fs.read_annot. + + Returns + ------- + vidx : np.ndarray (i,) + Arrray listing vertex indices of island vertices, empty if no islands + (components disconnetcted from largest label region) are found. + """ + # construct adjaceny matrix without edges across regions: + adjM = construct_adj_cluster(surf[1], annot) + # compute disconnected components + n_comp, labels = connected_components(csgraph=adjM, directed=False, return_labels=True) + # for each label, get islands that are not connected to main component + lids = np.unique(annot) + vidx = np.array([], dtype = np.int32) + for lid in lids: + ll = labels[annot==lid] + lidx = np.arange(labels.size)[annot==lid] + lmax = np.bincount(ll).argmax() + v = lidx[(ll != lmax)] + if v.size > 0: + print("Found disconnected islands ({} vertices total) for label {}!".format(v.size, lid)) + vidx = np.concatenate((vidx,v)) + return vidx + +def sample_nearest_nonzero(img, vox_coords, radius=3.0): + """ + Sample closest non-zero value in a ball of radius around vox_coords. + + Parameters + ---------- + img : nibabel.image + Image to sample. Voxels need to be isotropic. + vox_coords : ndarray float shape(n,3) + Coordinates in voxel space around which to search. + radius : float default 3.0 + Consider all voxels inside this radius to find a non-zero value. + + Returns + ------- + samples : np.ndarray(n,) + Sampled values, returns zero for vertices where values are zero in ball. + """ + # check for isotropic voxels + voxsize = img.header.get_zooms() + print("Check isotropic vox sizes: {}".format(voxsize)) + assert (np.max(np.abs(voxsize - voxsize[0])) < 0.001), 'Voxels not isotropic!' 
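Stepping back to `construct_adj_cluster` and `find_all_islands` above: because edges that cross annotation boundaries are dropped, each label region decomposes into its connected components, and everything outside the largest component of a label is an island. A toy sketch on a synthetic two-triangle mesh (not FastSurfer data; the label filtering is assumed to have already removed boundary-crossing edges):

```python
# Two triangles that share no edge form two components, so the smaller one
# would be reported as an island for its label.
import numpy as np
from scipy import sparse
from scipy.sparse.csgraph import connected_components

tria = np.array([[0, 1, 2], [3, 4, 5]])
i = tria[:, [0, 1, 1, 2, 2, 0]].reshape(-1)
j = tria[:, [1, 0, 2, 1, 0, 2]].reshape(-1)
adj = sparse.csc_matrix((np.ones(i.size), (i, j)), shape=(6, 6))

n_comp, comp = connected_components(csgraph=adj, directed=False)
print(n_comp, comp)   # 2 [0 0 0 1 1 1]
```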
+ data = np.asarray(img.dataobj) + + # radius in voxels: + rvox = radius * voxsize[0] + + # sample window around nearest voxel + x_nn = np.rint(vox_coords).astype(int) + # Reason: to always have the same number of voxels that we check + # and to be consistent with FreeSurfer, we center the window at + # the nearest neighbor voxel, instead of at the float vox coordinates + + # create box with 2*rvox+1 side length to fully contain ball + # and get coordiante offsets with zero at center + ri = np.floor(rvox).astype(int) + ll = np.arange(-ri,ri+1) + xv, yv, zv = np.meshgrid(ll, ll, ll) + # modify distances slightly, to avoid randomness when + # sorting with different radius values for voxels that otherwise + # have the same distance to center + xvd = xv+0.001 + yvd = yv+0.002 + zvd = zv+0.003 + ddm = np.sqrt(xvd*xvd + yvd*yvd + zvd*zvd).flatten() + # also compute correct distance for ball mask below + dd = np.sqrt(xv*xv + yv*yv + zv*zv).flatten() + ddball = dd<=rvox + + # flatten and keep only ball with radius + xv = xv.flatten()[ddball] + yv = yv.flatten()[ddball] + zv = zv.flatten()[ddball] + ddm = ddm[ddball] + + # stack to get offset vectors + offsets = np.column_stack((xv, yv, zv)) + + # sort offsets according to distance + # Note: we keep the first zero voxel so we can later + # determine if all voxels are zero with the argmax trick + sortidx = np.argsort(ddm) + offsets = offsets[sortidx,:] + + # reshape and tile to add to list of coords + n = x_nn.shape[0] + toffsets = np.tile(offsets.transpose().reshape(1,3,offsets.shape[0]),(n,1,1)) + s_coords = x_nn[:, :, np.newaxis] + toffsets + + # get image data at the s_coords locations + s_data = data[s_coords[:,0], s_coords[:,1], s_coords[:,2]] + + # get first non-zero if possible + nzidx = (s_data!=0).argmax(axis=1) + # the above return index zero if all elements are zero which is OK for us + # as we can then sample there and get a value of zero + samples = s_data[np.arange(s_data.shape[0]),nzidx] + return samples + + +def sample_img(surf, img, cortex=None, projmm=0.0, radius=None): + """ + Sample volume at a distance from the surface. + + Parameters + ---------- + surf : tuple | str + Surface as returned by nibabel fs.read_geometry, where: + surf[0] is the np.array of (n, 3) vertex coordinates and + surf[1] is the np.array of (m, 3) triangle indices. + If type is str, read surface from file. + img : nibabel.image | str + Image to sample. + If type is str, read image from file. + cortex : np.ndarray | str + Filename of cortex label or np.array with cortex indices. + projmm : float + Sample projmm mm along the surface vertex normals (default=0). + radius : float, optional + If given and if the sample is equal to zero, then consider + all voxels inside this radius to find a non-zero value. + + Returns + ------- + samples : np.ndarray (n,) + Sampled values. 
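The key step in `sample_img` below is mapping surface-RAS points (optionally pushed `projmm` mm along the vertex normals) to voxel indices. A condensed sketch of just that mapping, assuming an MGH image and points already in surface RAS:

```python
# Condensed version of the coordinate handling in sample_img: vox2ras_tkr maps
# voxel indices to FreeSurfer surface RAS, so its inverse maps sample points
# back to (nearest-neighbor) voxel indices.
import numpy as np
import nibabel as nib

def surface_ras_to_voxel(points_ras: np.ndarray, img: nib.MGHImage) -> np.ndarray:
    ras2vox = np.linalg.inv(img.header.get_vox2ras_tkr())
    vox = points_ras @ ras2vox[:3, :3].T + ras2vox[:3, 3]
    return np.rint(vox).astype(int)
```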
+ """ + if isinstance(surf, str): + surf = fs.read_geometry(surf, read_metadata=True) + if isinstance(img, str): + img = nib.load(img) + if isinstance(cortex, str): + cortex = fs.read_label(cortex) + nvert = surf[0].shape[0] + # Compute Cortex Mask + if cortex is not None: + mask = np.zeros(nvert, dtype=bool) + mask[cortex] = True + else: + mask = np.ones(nvert, dtype=bool) + + data = np.asarray(img.dataobj) + # Use LaPy TriaMesh for vertex normal computation + T = TriaMesh(surf[0], surf[1]) + + # make sure the triangles are oriented (normals pointing to the same direction + if not T.is_oriented(): + print("WARNING: Surface is not oriented, flipping corrupted normals.") + T.orient_() + + # compute sample coordinates projmm mm along the surface normal + # in surface RAS coordiante system: + x = T.v + projmm * T.vertex_normals() + # mask cortex + xx = x[mask] + + # compute Transformation from surface RAS to voxel space: + Torig = img.header.get_vox2ras_tkr() + Tinv = np.linalg.inv(Torig) + x_vox = np.dot(xx, Tinv[:3, :3].T) + Tinv[:3, 3] + # sample at nearest voxel + x_nn = np.rint(x_vox).astype(int) + samples_nn = data[x_nn[:,0], x_nn[:,1], x_nn[:,2]] + # no search for zeros, done: + if not radius: + samplesfull = np.zeros(nvert, dtype="int") + samplesfull[mask] = samples_nn + return samplesfull + # search for zeros, but no zeros exist, done: + zeros = np.asarray(samples_nn==0).nonzero()[0] + if zeros.size == 0: + samplesfull = np.zeros(nvert, dtype="int") + samplesfull[mask] = samples_nn + return samplesfull + # here we need to do the hard work of searching in a windows + # for non-zero samples + print("sample_img: found {} zero samples, searching radius ...".format(zeros.size)) + z_nn = x_nn[zeros] + z_samples = sample_nearest_nonzero(img, z_nn, radius=radius) + samples_nn[zeros] = z_samples + samplesfull = np.zeros(nvert, dtype="int") + samplesfull[mask] = samples_nn + return samplesfull + + +def replace_labels(img_labels, img_lut, surf_lut): + """ + Replace image labels with corresponding surface labels or unknown. + + Parameters + ---------- + img_labels : np.ndarray(n,) + Array with imgage label ids. + img_lut : str + Filename for image label look up table. + surf_lut : str + Filename for surface label look up table. + + Returns + ------- + surf_labels : np.ndarray (n,) + Array with surface label ids. + surf_ctab : np.ndarray shape(m,4) + Surface color table (RGBA). + surf_names : np.ndarray[str] shape(m,) + Names of label regions. + """ + surflut = np.loadtxt(surf_lut, usecols=(0,2,3,4,5), dtype="int") + surf_ids = surflut[:,0] + surf_ctab = surflut[:,1:5] + surf_names = np.loadtxt(surf_lut, usecols=(1), dtype="str") + imglut = np.loadtxt(img_lut, usecols=(0,2,3,4,5), dtype="int") + img_ids = imglut[:,0] + img_names = np.loadtxt(img_lut, usecols=(1), dtype="str") + assert (np.all(img_names == surf_names)), "Label names in the LUTs do not agree!" + lut = np.zeros((img_labels.max()+1,), dtype = img_labels.dtype) + lut[img_ids] = surf_ids + surf_labels = lut[img_labels] + return surf_labels, surf_ctab, surf_names + + +def sample_parc (surf, seg, imglut, surflut, outaparc, cortex=None, projmm=0.0, radius=None): + """ + Sample cortical GM labels from image to surface and smooth. + + Parameters + ---------- + surf : tuple | str + Surface as returned by nibabel fs.read_geometry, where: + surf[0] is the np.array of (n, 3) vertex coordinates and + surf[1] is the np.array of (m, 3) triangle indices. + If type is str, read surface from file. + seg : nibabel.image | str + Image to sample. 
+ If type is str, read image from file. + imglut : str + Filename for image label look up table. + surflut : str + Filename for surface label look up table. + outaparc : str + Filename for output surface parcellation. + cortex : np.ndarray | str + Filename of cortex label or np.ndarray with cortex indices. + projmm : float + Sample projmm mm along the surface vertex normals (default=0). + radius : float, optional + If given and if the sample is equal to zero, then consider + all voxels inside this radius to find a non-zero value. + """ + if isinstance(cortex, str): + cortex = fs.read_label(cortex) + if isinstance(surf, str): + surf = fs.read_geometry(surf, read_metadata=True) + if isinstance(seg, str): + seg = nib.load(seg) + # get rid of unknown labels first and translate the rest (avoids too much filling + # later as sampling will search around sample point if label is zero) + segdata, surfctab, surfnames = replace_labels(np.asarray(seg.dataobj), imglut, surflut) + # create img with new data (needed by sample img) + seg2 = nib.MGHImage(segdata, seg.affine, seg.header) + # sample from image to surface (and search if zero label) + surfsamples = sample_img(surf, seg2, cortex, projmm, radius) + # find label islands + vidx = find_all_islands(surf, surfsamples) + # set islands to zero (to ensure they get smoothed away later) + surfsamples[vidx] = 0 + # smooth boundaries and remove islands inside cortex region + smooths = smooth_aparc(surf, surfsamples, cortex) + # write annotation + fs.write_annot(outaparc, smooths, ctab=surfctab, names=surfnames) + + +if __name__ == "__main__": + # Command Line options are error checking done here + options = options_parse() + + sample_parc(options.insurf, options.inseg, options.seglut, options.surflut, options.outaparc, options.incort, options.projmm, options.radius) + + sys.exit(0) + diff --git a/recon_surf/smooth_aparc.py b/recon_surf/smooth_aparc.py index fbd50e78..43a063e8 100644 --- a/recon_surf/smooth_aparc.py +++ b/recon_surf/smooth_aparc.py @@ -20,8 +20,8 @@ import optparse import sys import numpy as np -from numpy import typing as npt import nibabel.freesurfer.io as fs +from numpy import typing as npt from scipy import sparse @@ -33,7 +33,7 @@ Dependencies: - Python 3.8 + Python 3.8+ Numpy http://www.numpy.org @@ -54,13 +54,13 @@ def options_parse(): - """Command line option parser. + """ + Create a command line interface and return command line options. Returns ------- options - object holding options - + Namespace object holding options. """ parser = optparse.OptionParser( version="$Id: smooth_aparc,v 1.0 2018/06/24 11:34:08 mreuter Exp $", @@ -78,28 +78,29 @@ def options_parse(): return options -def get_adjM(trias: npt.NDArray, n: int): - """[MISSING]. +def get_adjM(trias: npt.NDArray[int], n: int): + """ + Create symmetric sparse adjacency matrix of triangle mesh. Parameters ---------- - trias : npt.NDArray + trias : npt.NDArray[int](m, 3) + Triangle mesh matrix. n : int - Shape of tje matrix + Shape of output (n,n) adjaceny matrix, where n>=m. Returns ------- - adjM : np.ndarray - Adjoint matrix - + adjM : np.ndarray (bool) shape (n,n) + Symmetric sparse CSR adjacency matrix, true corresponds to an edge. 
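Taken together, `sample_parc.py` above remaps the volume labels via the two look-up tables, samples them at (optionally projected) vertex positions with a nearest-non-zero fallback, removes disconnected label islands, and finally smooths the boundaries with `smooth_aparc`. A hypothetical invocation (all paths and LUT file names are placeholders) would be `python3 sample_parc.py --inseg mri/aparc.DKTatlas+aseg.deep.mgz --insurf surf/lh.white --incort label/lh.cortex.label --seglut <image-LUT> --surflut <surface-LUT> --outaparc label/lh.aparc.annot --projmm 0 --radius 3`.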
""" - I = trias - J = I[:, [1, 2, 0]] + T = trias + J = T[:, [1, 2, 0]] # flatten - I = I.flatten() + T = T.flatten() J = J.flatten() - adj = sparse.csr_matrix((np.ones(I.shape, dtype=bool), (I, J)), shape=(n, n)) + adj = sparse.csr_matrix((np.ones(T.shape, dtype=bool), (T, J)), shape=(n, n)) # if max adj is > 1 we have non manifold or mesh trias are not oriented # if matrix is not symmetric, we have a boundary # in case we have boundary, make sure this is a symmetric matrix @@ -108,18 +109,18 @@ def get_adjM(trias: npt.NDArray, n: int): def bincount2D_vectorized(a: npt.NDArray) -> np.ndarray: - """Count number of occurrences of each value in array of non-negative ints. + """ + Count number of occurrences of each value in array of non-negative ints. Parameters ---------- - a : npt.NDArray - Array + a : np.ndarray + Input 2D array of non-negative ints. Returns ------- np.ndarray - Array of counted values - + Array of counted values. """ N = a.max() + 1 a_offs = a + np.arange(a.shape[0])[:, None] * N @@ -129,27 +130,30 @@ def bincount2D_vectorized(a: npt.NDArray) -> np.ndarray: def mode_filter( adjM: sparse.csr_matrix, labels: npt.NDArray[str], - fillonlylabel: str = "", + fillonlylabel = None, novote: npt.ArrayLike = [] -) -> npt.NDArray[str]: - """[MISSING]. +) -> npt.NDArray[int]: + """ + Apply mode filter (smoothing) to integer labels on mesh vertices. Parameters ---------- - adjM : sparse.csr_matrix - Adjoint matrix - labels : npt.NDArray[str] - List of labels - fillonlylabel : str - Label to fill exclusively. Defaults to "" + adjM : sparse.csr_matrix[bool] + Symmetric adjacency matrix defining edges between vertices, + this determines what edges can vote so usually one adds the + identity to the adjacency matrix so that each vertex is included + in its own vote. + labels : npt.NDArray[int] + List of integer labels at each vertex of the mesh. + fillonlylabel : int + Label to fill exclusively. Defaults to None to smooth all labels. novote : npt.ArrayLike - Entries that should not vote. Defaults to [] + Label ids that should not vote. Defaults to []. Returns ------- - labels_new - New filtered labels - + labels_new : npt.NDArray[int] + New smoothed labels. 
""" # make sure labels lengths equals adjM dimension n = labels.shape[0] @@ -191,14 +195,14 @@ def mode_filter( # of all ids to fill, find neighbors nbrs = adjM[ids, :] # get vertex ids (I, J ) of each edge in nbrs - [I, J, V] = sparse.find(nbrs) + [II, JJ, VV] = sparse.find(nbrs) # check if we have neighbors with -1 or 0 # this could produce problems in the loop below, so lets stop for now: - nlabels = labels[J] + nlabels = labels[JJ] if any(nlabels == -1) or any(nlabels == 0): sys.exit("there are -1 or 0 labels in neighbors!") # create sparse matrix with labels at neighbors - nlabels = sparse.csr_matrix((labels[J], (I, J))) + nlabels = sparse.csr_matrix((labels[JJ], (II, JJ))) # print("nlabels: {}".format(nlabels)) from scipy.stats import mode @@ -208,8 +212,8 @@ def mode_filter( # get rid of rows that have uniform vote (or are empty) # for this to work no negative numbers should exist # get row counts, max and sums - rmax = nlabels.max(1).A.squeeze() - sums = nlabels.sum(axis=1).A1 + rmax = nlabels.max(1).toarray().squeeze() + sums = np.asarray(nlabels.sum(axis=1)).ravel() counts = np.diff(nlabels.indptr) # then keep rows where max*counts differs from sums rmax = np.multiply(rmax, counts) @@ -220,7 +224,7 @@ def mode_filter( # since we have only rows that were non-uniform, they should not become empty # rows may become unform: we still need to vote below to update this label if novote: - rr = np.in1d(nlabels.data, novote) + rr = np.isin(nlabels.data, novote) nlabels.data[rr] = 0 nlabels.eliminate_zeros() # run over all rows and compute mode (maybe vectorize later) @@ -231,7 +235,7 @@ def mode_filter( rempty += 1 continue # print(str(rvals)) - mvals = mode(rvals)[0] + mvals = mode(rvals, keepdims=True)[0] # print(str(mvals)) if mvals.size != 0: # print(str(row)+' '+str(ids[row])+' '+str(mvals[0])) @@ -246,35 +250,29 @@ def mode_filter( return labels_new -def smooth_aparc( - insurfname: str, - inaparcname: str, - incortexname: str, - outaparcname: str -) -> None: - """Smoothes aparc. +def smooth_aparc(surf, labels, cortex = None): + """ + Smooth aparc label regions on the surface and fill holes. + + First all labels with 0 and -1 unside cortex are filled via repeated + mode filtering, then all labels are smoothed first with a wider and + then with smaller filters to produce smooth label boundaries. Labels + outside cortex are set to -1 at the end. Parameters ---------- - insurfname : str - Suface filepath and name of source - inaparcname : str - Annotation filepath and name of source - incortexname : str - Label filepath and name of source - outaparcname : str - Suface filepath and name of destination + surf : nibabel surface + Suface filepath and name of source. + labels : np.array[int] + Labels at each vertex (int). + cortex : np.array[int] + Vertex ids inside cortex mask. + Returns + ------- + smoothed_labels : np.array[int] + Smoothed labels. 
""" - # read input files - print("Reading in surface: {} ...".format(insurfname)) - surf = fs.read_geometry(insurfname, read_metadata=True) - print("Reading in annotation: {} ...".format(inaparcname)) - aparc = fs.read_annot(inaparcname) - print("Reading in cortex label: {} ...".format(incortexname)) - cortex = fs.read_label(incortexname) - # set labels (n) and triangles (n x 3) - labels = aparc[0] faces = surf[1] nvert = labels.size if labels.size != surf[0].shape[0]: @@ -286,8 +284,11 @@ def smooth_aparc( ) # Compute Cortex Mask - mask = np.zeros(labels.shape, dtype=bool) - mask[cortex] = True + if cortex is not None: + mask = np.zeros(labels.shape, dtype=bool) + mask[cortex] = True + else: + mask = np.ones(labels.shape, dtype=bool) # check if we have places where non-cortex has some labels noncortnum = np.where(~mask & (labels != -1)) print( @@ -302,7 +303,7 @@ def smooth_aparc( noncortids = np.where(~mask) # remove triangles where one vertex is non-cortex to avoid these edges to vote on neighbors later - rr = np.in1d(faces, noncortids) + rr = np.isin(faces, noncortids) rr = np.reshape(rr, faces.shape) rr = np.amax(rr, 1) faces = faces[~rr, :] @@ -320,6 +321,7 @@ def smooth_aparc( # print("minlab: "+str(np.min(labels))+" maxlab: "+str(np.max(labels))) # set all labels inside cortex that are -1 or 0 to fill label + labels = labels.copy() fillonlylabel = np.max(labels) + 1 labels[mask & (labels == -1)] = fillonlylabel labels[mask & (labels == 0)] = fillonlylabel @@ -340,7 +342,7 @@ def smooth_aparc( ) fillids = np.where(labels == fillonlylabel)[0] labels[fillids] = 0 - rr = np.in1d(faces, fillids) + rr = np.isin(faces, fillids) rr = np.reshape(rr, faces.shape) rr = np.amax(rr, 1) faces = faces[~rr, :] @@ -352,18 +354,54 @@ def smooth_aparc( idssize = ids.size counter += 1 # SMOOTH other labels (first with wider kernel then again fine-tune): - labels = mode_filter(adjM * adjM, labels) + adjM2 = adjM * adjM + adjM4 = adjM2 * adjM2 + labels = mode_filter(adjM4, labels) + labels = mode_filter(adjM2, labels) labels = mode_filter(adjM, labels) # set labels outside cortex to -1 labels[~mask] = -1 + return labels + + +def main( + insurfname: str, + inaparcname: str, + incortexname: str, + outaparcname: str +) -> None: + """ + Read files, smooth the aparc labels on the surface and save the smoothed labels. + + Parameters + ---------- + insurfname : str + Suface filepath and name of source. + inaparcname : str + Annotation filepath and name of source. + incortexname : str + Label filepath and name of source. + outaparcname : str + Surface filepath and name of destination. 
+ """ + # read input files + print("Reading in surface: {} ...".format(insurfname)) + surf = fs.read_geometry(insurfname, read_metadata=True) + print("Reading in annotation: {} ...".format(inaparcname)) + aparc = fs.read_annot(inaparcname) + print("Reading in cortex label: {} ...".format(incortexname)) + cortex = fs.read_label(incortexname) + # set labels (n) and triangles (n x 3) + labels = aparc[0] + slabels = smooth_aparc(surf, labels, cortex) print("Outputting fixed annot: {}".format(outaparcname)) - fs.write_annot(outaparcname, labels, aparc[1], aparc[2]) + fs.write_annot(outaparcname, slabels, aparc[1], aparc[2]) if __name__ == "__main__": # Command Line options are error checking done here options = options_parse() - smooth_aparc(options.insurf, options.inaparc, options.incort, options.outaparc) + main(options.insurf, options.inaparc, options.incort, options.outaparc) sys.exit(0) diff --git a/recon_surf/spherically_project.py b/recon_surf/spherically_project.py index 42a56e6d..ddb58ce4 100644 --- a/recon_surf/spherically_project.py +++ b/recon_surf/spherically_project.py @@ -19,8 +19,8 @@ import nibabel.freesurfer.io as fs import numpy as np import math -from lapy.diffGeo import tria_mean_curvature_flow -from lapy.triaMesh import TriaMesh +from lapy.diffgeo import tria_mean_curvature_flow +from lapy import TriaMesh from lapy.solver import Solver HELPTEXT = """ @@ -48,7 +48,7 @@ Dependencies: - Python 3.8 + Python 3.8+ Scipy 0.10 or later to solve the generalized eigenvalue problem. http://docs.scipy.org/doc/scipy/reference/tutorial/arpack.html diff --git a/recon_surf/spherically_project_wrapper.py b/recon_surf/spherically_project_wrapper.py index e65904a7..82035bf3 100644 --- a/recon_surf/spherically_project_wrapper.py +++ b/recon_surf/spherically_project_wrapper.py @@ -14,20 +14,20 @@ # IMPORTS -from subprocess import Popen, PIPE import shlex import argparse from typing import Any +from subprocess import Popen, PIPE def setup_options(): - """Command line option parser. + """ + Create a command line interface and return command line options. Returns ------- - options - object holding options - + options : argparse.Namespace + Namespace object holding options. """ # Validation settings parser = argparse.ArgumentParser(description="Wrapper for spherical projection") @@ -46,54 +46,59 @@ def setup_options(): def call(command: str, **kwargs: Any) -> int: - """Run command with arguments. + """ + Run command with arguments. Wait for command to complete. Sends output to logging module. Parameters ---------- command : str - Command to call + Command to call. **kwargs : Any + Keyword arguments. Returns ------- int - Returncode of called command - + Returncode of called command. """ kwargs["stdout"] = PIPE kwargs["stderr"] = PIPE command_split = shlex.split(command) p = Popen(command_split, **kwargs) - stdout = p.communicate()[0] + stdout, stderr = p.communicate() if stdout: for line in stdout.decode("utf-8").split("\n"): print(line) + if stderr: + print("stderr") + for line in stderr.decode("utf-8").split("\n"): + print(line) return p.returncode def spherical_wrapper(command1: str, command2: str, **kwargs: Any) -> int: - """Run the first command. If it fails the fallback command is run as well. + """ + Run the first command. If it fails the fallback command is run instead. Parameters ---------- command1 : str - Command to call + Command to call. command2 : str - Fallback command to call + Fallback command to call. **kwargs : Any - Arguments. The same as for the Popen constructor + Arguments. 
The same as for the Popen constructor. Returns ------- code_1 - Return code of command1. If command1 failed return code of command2 - + Return code of command1. If command1 failed return code of command2. """ # First try to run standard spherical project print("Running command: {}".format(command1)) @@ -143,4 +148,8 @@ def spherical_wrapper(command1: str, command2: str, **kwargs: Any) -> int: + " -qsphere -no-isrunning " + threading ) - spherical_wrapper(cmd1, cmd2) + # make sure the process has a username, so nibabel does not crash in write_geometry + from os import environ + env = dict(environ) + env.setdefault("USERNAME", "UNKNOWN") + spherical_wrapper(cmd1, cmd2, env=env) diff --git a/recon_surf/talairach-reg.sh b/recon_surf/talairach-reg.sh new file mode 100755 index 00000000..8f199a34 --- /dev/null +++ b/recon_surf/talairach-reg.sh @@ -0,0 +1,106 @@ +#!/bin/bash + +# Copyright 2024 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script only runs the FreeSurfer talairach registration pipeline +# The call signature is: +usage="talairach-reg.sh <3T atlas: true/false> " + +if [[ "$#" != "3" ]] +then + echo "Invalid number of arguments to talairach-reg.sh, must be '$usage'" + exit 1 +fi +if ! [[ -d "$1" ]] +then + echo "First argument must be the mri-directory: $usage" + exit 1 +fi +mdir="$1" +if [[ "$2" != "true" ]] && [[ "$2" != "false" ]] +then + echo "Second argument must be true or false: $usage" + exit 1 +fi +atlas3T="$2" +if ! [[ -f "$3" ]] +then + echo "Third argument must be the logfile (must already exist): $usage" + exit 1 +fi +LF="$3" + +if [ -z "$FASTSURFER_HOME" ] +then + binpath="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )/" +else + binpath="$FASTSURFER_HOME/recon_surf/" +fi + +# Load the RunIt and the RunBatchJobs functions +source "$binpath/functions.sh" + +# needs //mri +# needs //mri/transforms +mkdir -p $mdir/transforms +mkdir -p $mdir/tmp + +pushd "$mdir" > /dev/null || ( echo "Could not change to $mdir!" | tee -a "$LF" && exit 1) + +# talairach.xfm: compute talairach full head (25sec) +if [[ "$atlas3T" == "true" ]] +then + echo "Using the 3T atlas for talairach registration." + atlas="--atlas 3T18yoSchwartzReactN32_as_orig" +else + echo "Using the default atlas (1.5T) for talairach registration." + atlas="" +fi +if [[ ! 
-f /bin/tcsh ]] ; then + echo "ERROR: The talairach_avi script requires tcsh, but /bin/tcsh does not exist" + exit 1 +fi +cmd="talairach_avi --i $mdir/orig_nu.mgz --xfm $mdir/transforms/talairach.auto.xfm $atlas" +RunIt "$cmd" $LF +# create copy +cmd="cp $mdir/transforms/talairach.auto.xfm $mdir/transforms/talairach.xfm" +RunIt "$cmd" $LF +# talairach.lta: convert to lta +cmd="lta_convert --src $mdir/orig.mgz --trg $FREESURFER_HOME/average/mni305.cor.mgz --inxfm $mdir/transforms/talairach.xfm --outlta $mdir/transforms/talairach.xfm.lta --subject fsaverage --ltavox2vox" +RunIt "$cmd" $LF + +# FS would here create better nu.mgz using talairach transform (finds wm and maps it to approx 110) +#NuIterations="1 --proto-iters 1000 --distance 50" # default 3T +#FS60 cmd="mri_nu_correct.mni --i $mdir/orig.mgz --o $mdir/nu.mgz --uchar $mdir/transforms/talairach.xfm --n $NuIterations --mask $mdir/mask.mgz" +#FS72 cmd="mri_nu_correct.mni --i $mdir/orig.mgz --o $mdir/nu.mgz --uchar $mdir/transforms/talairach.xfm --n $NuIterations --ants-n4" +# all this is basically useless, as we did a good orig_nu already, including WM normalization + +# Since we do not run mri_em_register we sym-link other talairach transform files here +pushd "$mdir/transforms" > /dev/null || ( echo "ERROR: Could not change to the transforms directory $mdir/transforms!" | tee -a "$LF" && exit 1 ) + cmd="ln -sf talairach.xfm.lta talairach_with_skull.lta" + RunIt "$cmd" $LF + cmd="ln -sf talairach.xfm.lta talairach.lta" + RunIt "$cmd" $LF +popd > /dev/null || exit 1 + +# Add xfm to nu +# (use orig_nu, if nu.mgz does not exist already); by default, it should exist +if [[ -e "$mdir/nu.mgz" ]]; then src_nu_file="$mdir/nu.mgz" +else src_nu_file="$mdir/orig_nu.mgz" +fi +cmd="mri_add_xform_to_header -c $mdir/transforms/talairach.xfm $src_nu_file $mdir/nu.mgz" +RunIt "$cmd" $LF + +popd > /dev/null || return \ No newline at end of file diff --git a/recon_surf/utils/extract_recon_surf_time_info.py b/recon_surf/utils/extract_recon_surf_time_info.py index d5238675..dd8897b9 100644 --- a/recon_surf/utils/extract_recon_surf_time_info.py +++ b/recon_surf/utils/extract_recon_surf_time_info.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -import datetime +from datetime import datetime, timedelta +from pathlib import Path + import dateutil.parser import argparse import yaml @@ -29,10 +31,10 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float try: current_date_time = dateutil.parser.parse(current_datetime_str) previous_date_time = dateutil.parser.parse(previous_datetime_str) - except: # strptime considers the computers time locale settings + except dateutil.parser.ParserError: # strptime considers the computers time locale settings locale.setlocale(locale.LC_TIME,"") - current_date_time = datetime.datetime.strptime(current_datetime_str, "%a %d. %b %H:%M:%S %Z %Y") - previous_date_time = datetime.datetime.strptime(previous_datetime_str, "%a %d. %b %H:%M:%S %Z %Y") + current_date_time = datetime.strptime(current_datetime_str, "%a %d. %b %H:%M:%S %Z %Y") + previous_date_time = datetime.strptime(previous_datetime_str, "%a %d. 
%b %H:%M:%S %Z %Y") stage_duration = (current_date_time - previous_date_time).total_seconds() return stage_duration @@ -43,26 +45,27 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float parser.add_argument( "-i", "--input_file_path", - type=str, + type=Path, default="scripts/recon-surf.log", help="Path to recon-surf.log file", ) parser.add_argument( "-o", "--output_file_path", - type=str, - default="", + type=Path, + default=None, help="Path to output recon-surf_time.log file", ) parser.add_argument( - "--time_units", type=str, default="m", help="Units of time [s, m]" + "--time_units", + choices=["m", "s"], + default="m", + help="Units of time [s, m]", ) args = parser.parse_args() - lines = [] with open(args.input_file_path) as file: - for line in file: - lines.append(line.rstrip()) + lines = [line.rstrip() for line in file.readlines()] timestamp_feature = "@#@FSTIME" recon_all_stage_feature = "#@# " @@ -94,19 +97,15 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float "#", "This may cause", ] - filtered_cmds = ["ln ", "rm ", "cp "] + filtered_cmds = ["ln", "rm", "cp"] - if args.output_file_path == "": - output_file_path = ( - args.input_file_path.rsplit("/", 1)[0] + "/" + "recon-surf_times.yaml" - ) + if not args.output_file_path: + output_file_path = args.input_file_path.parent / "recon-surf_times.yaml" else: output_file_path = args.output_file_path print( - "[INFO] Parsing file for recon_surf time information: {}\n".format( - args.input_file_path - ) + f"[INFO] Parsing file for recon_surf time information: {args.input_file_path}\n" ) if args.time_units not in ["s", "m"]: print("[WARN] Invalid time_units! Must be in s or m. Defaulting to m...") @@ -114,24 +113,24 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float else: time_units = args.time_units - yaml_dict = {} - yaml_dict["date"] = lines[1] - recon_surf_commands = [] + yaml_dict = {"date": lines[1]} + pre_recon_surf_stage_name = "Starting up / no stage defined yet" + current_recon_surf_stage_name = pre_recon_surf_stage_name + recon_surf_commands = [{current_recon_surf_stage_name: []}] for i, line in enumerate(lines): ## Use recon_surf "stage" names as top level of recon-surf_commands entries: if "======" in line and "teration" not in line: stage_line = line - current_recon_surf_stage_name = stage_line.strip("=")[1:-1].replace( - " ", "-" - ) + current_recon_surf_stage_name = stage_line.strip("= ").replace(" ", "-") if current_recon_surf_stage_name == "DONE": continue recon_surf_commands.append({current_recon_surf_stage_name: []}) + line_parts = line.split() if "recon-surf.sh" in line and "--sid" in line: try: - yaml_dict["subject_id"] = line.split()[line.split().index("--sid") + 1] + yaml_dict["subject_id"] = line_parts[line_parts.index("--sid") + 1] except ValueError: print( "[WARN] Could not extract subject ID from log file! It will not be added to the output." 
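For orientation, each `@#@FSTIME` entry is parsed positionally: the timestamp sits at index 1, the command name at index 2, and the elapsed seconds follow the literal `e`. A self-contained sketch with a synthetic log line (field layout inferred from this script's parsing code):

```python
# Synthetic FSTIME line; the field layout is an assumption based on how this
# script indexes line_parts.
from datetime import datetime, timedelta

line = "@#@FSTIME 2024:08:08:12:20:10 mri_convert N 2 e 8.58 S 0.26 U 1.28"
parts = line.split()
start = datetime.strptime(parts[1], "%Y:%m:%d:%H:%M:%S")
assert parts[5] == "e"
duration_s = float(parts[6])
end = start + timedelta(seconds=duration_s)
print(parts[2], start.strftime("%H:%M:%S"), end.strftime("%H:%M:%S"),
      round(duration_s / 60.0, 2))
# mri_convert 12:20:10 12:20:18 0.14
```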
@@ -143,19 +142,19 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float ## Parse out cmd name, start time, and duration: entry_dict = {} - cmd_name = line.split()[2] + " " + cmd_name = line_parts[2] if cmd_name in filtered_cmds: continue - date_time_str = line.split()[1] + date_time_str = line_parts[1] start_time = date_time_str[11:] - start_date_time = datetime.datetime.strptime( + start_date_time = datetime.strptime( date_time_str, "%Y:%m:%d:%H:%M:%S" ) - assert line.split()[5] == "e" - cmd_duration = float(line.split()[6]) + assert line_parts[5] == "e" + cmd_duration = float(line_parts[6]) - end_date_time = start_date_time + datetime.timedelta(0, float(cmd_duration)) + end_date_time = start_date_time + timedelta(0, float(cmd_duration)) end_date_time_str = end_date_time.strftime("%Y:%m:%d:%H:%M:%S") end_time = end_date_time_str[11:] @@ -164,7 +163,7 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float cmd_line = None for previous_line_index in range(i - 1, -1, -1): temp_line = lines[previous_line_index] - if cmd_name in temp_line and all( + if cmd_name + " " in temp_line and all( phrase not in temp_line for phrase in cmd_line_filter_phrases ): cmd_line = temp_line @@ -172,9 +171,8 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float break else: print( - "[WARN] Could not find the line containing the full command for {} in line {}! Skipping...\n".format( - cmd_name[:-1], i - ) + f"[WARN] Could not find the line containing the full command for " + f"{cmd_name} in line {i+1}! Skipping...\n" ) continue @@ -187,9 +185,12 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float entry_dict["duration_m"] = round(cmd_duration / 60.0, 2) ## Parse out the same details for each stage in recon-all - if cmd_name == "recon-all ": + if cmd_name == "recon-all": entry_dict["stages"] = [] first_stage = True + previous_datetime_str = "" + stage_name = "" + previous_stage_start_time = "" for j in range(cmd_line_index, i): if ( @@ -245,6 +246,6 @@ def get_recon_all_stage_duration(line: str, previous_datetime_str: str) -> float yaml_dict["recon-surf_commands"] = recon_surf_commands - print("[INFO] Writing output to file: {}".format(output_file_path)) + print(f"[INFO] Writing output to file: {output_file_path}") with open(output_file_path, "w") as outfile: yaml.dump(yaml_dict, outfile, sort_keys=False) diff --git a/requirements.cpu.txt b/requirements.cpu.txt new file mode 100644 index 00000000..fe88cd9e --- /dev/null +++ b/requirements.cpu.txt @@ -0,0 +1,104 @@ +# +# This file is autogenerated by kueglerd from deepmi/fastsurfer:cpu-v2.3.0 +# by the following command from FastSurfer: +# +# ./requirements.cpu.txt deepmi/fastsurfer:cpu-v2.3.0 +# +# Which ran the following command: +# docker run --rm -u : --entrypoint /bin/bash deepmi/fastsurfer:cpu-v2.3.0 -c 'python --version && pip list --format=freeze --no-color --disable-pip-version-check --no-input' +# +# +# Image was configured for cpu using python version 3.10.14 +# +--extra-index-url https://download.pytorch.org/whl/cpu + +# Python 3.10.14 +absl-py==2.1.0 +Brotli==1.1.0 +cached-property==1.5.2 +certifi==2024.7.4 +cffi==1.17.0 +charset-normalizer==3.3.2 +click==8.1.7 +colorama==0.4.6 +contourpy==1.2.1 +cycler==0.12.1 +Deprecated==1.2.14 +filelock==3.15.4 +fonttools==4.53.1 +fsspec==2024.6.1 +grpcio==1.62.2 +h2==4.1.0 +h5py==3.11.0 +hpack==4.0.0 +humanize==4.10.0 +hyperframe==6.0.1 +idna==3.8 +imagecodecs==2024.6.1 +imageio==2.35.1 
+importlib_metadata==8.4.0 +importlib_resources==6.4.4 +Jinja2==3.1.4 +joblib==1.4.2 +kiwisolver==1.4.5 +lapy==1.1.0 +lazy_loader==0.4 +Markdown==3.6 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.2 +mdurl==0.1.2 +mpmath==1.3.0 +munkres==1.1.4 +networkx==3.3 +nibabel==5.2.1 +numpy==1.26.4 +packaging==24.1 +pandas==2.2.2 +pillow==10.4.0 +pip==24.2 +plotly==5.23.0 +protobuf==4.25.3 +psutil==6.0.0 +pycparser==2.22 +Pygments==2.18.0 +pyparsing==3.1.4 +PySide6==6.7.2 +PySocks==1.7.1 +python-dateutil==2.9.0 +pytz==2024.1 +PyWavelets==1.7.0 +PyYAML==6.0.2 +requests==2.32.3 +rich==13.8.0 +scikit-image==0.24.0 +scikit-learn==1.5.1 +scikit-sparse==0.4.14 +scipy==1.14.1 +setuptools==72.2.0 +shellingham==1.5.4 +shiboken6==6.7.2 +SimpleITK==2.4.0 +six==1.16.0 +sympy==1.13.2 +tenacity==9.0.0 +tensorboard==2.17.1 +tensorboard-data-server==0.7.0 +threadpoolctl==3.5.0 +tifffile==2024.8.24 +torch==2.4.0+cpu +torchio==0.19.9 +torchvision==0.19.0+cpu +tornado==6.4.1 +tqdm==4.66.5 +typer==0.12.5 +typing_extensions==4.12.2 +tzdata==2024.1 +unicodedata2==15.1.0 +urllib3==2.2.2 +Werkzeug==3.0.4 +wheel==0.44.0 +wrapt==1.16.0 +yacs==0.1.8 +zipp==3.20.0 +zstandard==0.23.0 diff --git a/requirements.mac.txt b/requirements.mac.txt index f1afd67f..95af69a7 100644 --- a/requirements.mac.txt +++ b/requirements.mac.txt @@ -1,185 +1,19 @@ -# -# This file is manually created from the autogenerated -# requirements.txt . It is experimental to support MAC -# (intel, apple silicon and gpus via mps). For this we -# currently need the nightly torch and torchvision. -# ---extra-index-url https://download.pytorch.org/whl/nightly/cpu +h5py>=3.7 +lapy>=1.0.1 +matplotlib>=3.7.1 +nibabel>=5.1.0 +numpy>=1.25,<2 +pandas>=1.5.3 +pyyaml>=6.0 +requests>=2.31.0 +scikit-image>=0.19.3 +scikit-learn>=1.2.2 +scipy>=1.10.1,!=1.13.0 +simpleitk>=2.2.1 +tensorboard>=2.12.1 +torch>=2.0.1 +torchio>=0.18.83 +torchvision>=0.15.2 +tqdm>=4.65 +yacs>=0.1.8 -absl-py==1.2.0 - # via tensorboard -cachetools==5.2.0 - # via google-auth -certifi==2022.6.15 - # via requests -charset-normalizer==2.1.0 - # via requests -click==8.1.3 - # via torchio -cycler==0.11.0 - # via matplotlib -deprecated==1.2.13 - # via torchio -fonttools==4.34.4 - # via matplotlib -google-auth==2.9.1 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==0.4.6 - # via tensorboard -grpcio==1.47.0 - # via tensorboard -h5py==3.7.0 - # via -r requirements.in -humanize==4.2.3 - # via torchio -idna==3.3 - # via requests -imageio==2.19.5 - # via scikit-image -importlib-metadata==4.12.0 - # via markdown -joblib==1.2.0 - # via scikit-learn -kiwisolver==1.4.4 - # via matplotlib -lapy==0.4.1 - # via -r requirements.in -markdown==3.4.1 - # via tensorboard -matplotlib==3.5.1 - # via -r requirements.in -networkx==2.8.5 - # via scikit-image -nibabel==3.2.2 - # via - # -r requirements.in - # torchio -numpy==1.23.5 - # via - # -r requirements.in - # h5py - # imageio - # lapy - # matplotlib - # nibabel - # pandas - # pywavelets - # scikit-image - # scikit-learn - # scipy - # tensorboard - # tifffile - # torchio - # torchvision -oauthlib==3.2.0 - # via requests-oauthlib -packaging==21.3 - # via - # matplotlib - # nibabel - # scikit-image -pandas==1.4.3 - # via -r requirements.in -pillow==9.2.0 - # via - # -r requirements.in - # imageio - # matplotlib - # scikit-image - # torchvision -plotly==5.9.0 - # via lapy -protobuf==3.19.4 - # via tensorboard -pyasn1==0.4.8 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.2.8 - # via google-auth -pyparsing==3.0.9 - # via - # matplotlib - # 
packaging -python-dateutil==2.8.2 - # via - # -r requirements.in - # matplotlib - # pandas -pytz==2022.1 - # via pandas -pywavelets==1.3.0 - # via scikit-image -pyyaml==6.0 - # via - # -r requirements.in - # yacs -requests==2.28.1 - # via - # requests-oauthlib - # tensorboard - # torchvision -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rsa==4.8 - # via google-auth -scikit-image==0.19.2 - # via -r requirements.in -scikit-learn==1.1.2 - # via -r requirements.in -scipy==1.8.0 - # via - # -r requirements.in - # lapy - # scikit-image - # scikit-learn - # torchio -simpleitk==2.1.1 - # via - # -r requirements.in - # torchio -six==1.16.0 - # via - # google-auth - # grpcio - # python-dateutil -tenacity==8.0.1 - # via plotly -tensorboard==2.9.1 - # via -r requirements.in -tensorboard-data-server==0.6.1 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.1.0 - # via scikit-learn -tifffile==2022.5.4 - # via scikit-image -torch>=1.13.0.dev20220815 - # manually set nighly -torchio==0.18.83 - # via -r requirements.in -torchvision>=0.14.0.dev20220815 - # manually set nighly -tqdm==4.64 - # via - # -r requirements.in - # torchio -typing-extensions==4.3.0 - # via - # torch - # torchvision -urllib3==1.26.10 - # via requests -werkzeug==2.1.2 - # via tensorboard -wheel==0.37.1 - # via tensorboard -wrapt==1.14.1 - # via deprecated -yacs==0.1.8 - # via -r requirements.in -zipp==3.8.1 - # via importlib-metadata diff --git a/requirements.txt b/requirements.txt index 4ee513bf..f168f221 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,210 +1,117 @@ # -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: +# This file is autogenerated by kueglerd from deepmi/fastsurfer:cu124-v2.3.0 +# by the following command from FastSurfer: # -# pip-compile requirements.in +# ./requirements.txt deepmi/fastsurfer:cu124-v2.3.0 # ---extra-index-url https://download.pytorch.org/whl/cu117 +# Which ran the following command: +# docker run --rm -u : --entrypoint /bin/bash deepmi/fastsurfer:cu124-v2.3.0 -c 'python --version && pip list --format=freeze --no-color --disable-pip-version-check --no-input' +# +# +# Image was configured for cu124 using python version 3.10.14 +# +--extra-index-url https://download.pytorch.org/whl/cu124 -absl-py==2.0.0 - # via tensorboard -cachetools==5.3.2 - # via google-auth -certifi==2023.7.22 - # via requests +# Python 3.10.14 +absl-py==2.1.0 +Brotli==1.1.0 +cached-property==1.5.2 +certifi==2024.7.4 +cffi==1.17.0 charset-normalizer==3.3.2 - # via requests click==8.1.7 - # via torchio -cmake==3.27.7 - # via triton -contourpy==1.2.0 - # via matplotlib +colorama==0.4.6 +contourpy==1.2.1 cycler==0.12.1 - # via matplotlib -deprecated==1.2.14 - # via torchio -filelock==3.13.1 - # via - # torch - # triton -fonttools==4.44.3 - # via matplotlib -google-auth==2.23.4 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -grpcio==1.59.2 - # via tensorboard -h5py==3.7.0 - # via -r requirements.in -humanize==4.8.0 - # via torchio -idna==3.4 - # via requests -imageio==2.32.0 - # via scikit-image -jinja2==3.1.2 - # via torch -joblib==1.3.2 - # via scikit-learn +Deprecated==1.2.14 +filelock==3.15.4 +fonttools==4.53.1 +fsspec==2024.6.1 +grpcio==1.62.2 +h2==4.1.0 +h5py==3.11.0 +hpack==4.0.0 +humanize==4.10.0 +hyperframe==6.0.1 +idna==3.8 +imagecodecs==2024.6.1 +imageio==2.35.1 +importlib_metadata==8.4.0 +importlib_resources==6.4.4 +Jinja2==3.1.4 +joblib==1.4.2 kiwisolver==1.4.5 - # via 
matplotlib -lapy==1.0.1 - # via -r requirements.in -lit==17.0.5 - # via triton -markdown==3.5.1 - # via tensorboard -markupsafe==2.1.3 - # via - # jinja2 - # werkzeug -matplotlib==3.7.1 - # via -r requirements.in +lapy==1.1.0 +lazy_loader==0.4 +Markdown==3.6 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.2 +mdurl==0.1.2 mpmath==1.3.0 - # via sympy -networkx==3.2.1 - # via - # scikit-image - # torch -nibabel==5.1.0 - # via - # -r requirements.in - # lapy - # torchio -numpy==1.25.0 - # via - # -r requirements.in - # contourpy - # h5py - # imageio - # lapy - # matplotlib - # nibabel - # pandas - # pywavelets - # scikit-image - # scikit-learn - # scipy - # tensorboard - # tifffile - # torchio - # torchvision -oauthlib==3.2.2 - # via requests-oauthlib -packaging==23.2 - # via - # matplotlib - # nibabel - # plotly - # scikit-image -pandas==1.5.3 - # via -r requirements.in -pillow==10.0.1 - # via - # -r requirements.in - # imageio - # matplotlib - # scikit-image - # torchvision -plotly==5.18.0 - # via lapy -protobuf==4.25.1 - # via tensorboard -psutil==5.9.6 - # via lapy -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pyparsing==3.1.1 - # via matplotlib -python-dateutil==2.8.2 - # via - # -r requirements.in - # matplotlib - # pandas -pytz==2023.3.post1 - # via pandas -pywavelets==1.4.1 - # via scikit-image -pyyaml==6.0 - # via - # -r requirements.in - # yacs -requests==2.31.0 - # via - # requests-oauthlib - # tensorboard - # torchvision -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rsa==4.9 - # via google-auth -scikit-image==0.19.3 - # via -r requirements.in -scikit-learn==1.2.2 - # via -r requirements.in -scipy==1.10.1 - # via - # -r requirements.in - # lapy - # scikit-image - # scikit-learn - # torchio -simpleitk==2.2.1 - # via - # -r requirements.in - # torchio +munkres==1.1.4 +networkx==3.3 +nibabel==5.2.1 +numpy==1.26.4 +nvidia-cublas-cu12==12.4.2.65 +nvidia-cuda-cupti-cu12==12.4.99 +nvidia-cuda-nvrtc-cu12==12.4.99 +nvidia-cuda-runtime-cu12==12.4.99 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.0.44 +nvidia-curand-cu12==10.3.5.119 +nvidia-cusolver-cu12==11.6.0.99 +nvidia-cusparse-cu12==12.3.0.142 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.4.99 +nvidia-nvtx-cu12==12.4.99 +packaging==24.1 +pandas==2.2.2 +pillow==10.4.0 +pip==24.2 +plotly==5.23.0 +protobuf==4.25.3 +psutil==6.0.0 +pycparser==2.22 +Pygments==2.18.0 +pyparsing==3.1.4 +PySide6==6.7.2 +PySocks==1.7.1 +python-dateutil==2.9.0 +pytz==2024.1 +PyWavelets==1.7.0 +PyYAML==6.0.2 +requests==2.32.3 +rich==13.7.1 +scikit-image==0.24.0 +scikit-learn==1.5.1 +scikit-sparse==0.4.14 +scipy==1.14.1 +setuptools==72.2.0 +shellingham==1.5.4 +shiboken6==6.7.2 +SimpleITK==2.4.0 six==1.16.0 - # via python-dateutil -sympy==1.12 - # via torch -tenacity==8.2.3 - # via plotly -tensorboard==2.12.1 - # via -r requirements.in -tensorboard-data-server==0.7.2 - # via tensorboard -tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.2.0 - # via scikit-learn -tifffile==2023.9.26 - # via scikit-image -torch==2.0.1+cu117 - # via - # -r requirements.in - # torchio - # torchvision - # triton -torchio==0.18.83 - # via -r requirements.in -torchvision==0.15.2+cu117 - # via -r requirements.in -tqdm==4.65.0 - # via - # -r requirements.in - # torchio -triton==2.0.0 - # via torch -typing-extensions==4.8.0 - # via torch -urllib3==2.1.0 - # via requests -werkzeug==3.0.1 - # via tensorboard -wheel==0.41.3 - # via tensorboard +sympy==1.13.2 +tenacity==9.0.0 +tensorboard==2.17.1 
+tensorboard-data-server==0.7.0 +threadpoolctl==3.5.0 +tifffile==2024.8.24 +torch==2.4.0+cu124 +torchio==0.19.9 +torchvision==0.19.0+cu124 +tornado==6.4.1 +tqdm==4.66.5 +triton==3.0.0 +typer==0.12.5 +typing_extensions==4.12.2 +tzdata==2024.1 +unicodedata2==15.1.0 +urllib3==2.2.2 +Werkzeug==3.0.4 +wheel==0.44.0 wrapt==1.16.0 - # via deprecated yacs==0.1.8 - # via -r requirements.in - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +zipp==3.20.0 +zstandard==0.23.0 diff --git a/requirements_cpu.txt b/requirements_cpu.txt deleted file mode 100644 index 2e342b12..00000000 --- a/requirements_cpu.txt +++ /dev/null @@ -1,201 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile --output-file=requirements_cpu.txt requirements.in -# ---extra-index-url https://download.pytorch.org/whl/cpu - -absl-py==2.0.0 - # via tensorboard -cachetools==5.3.2 - # via google-auth -certifi==2023.7.22 - # via requests -charset-normalizer==3.3.2 - # via requests -click==8.1.7 - # via torchio -contourpy==1.2.0 - # via matplotlib -cycler==0.12.1 - # via matplotlib -deprecated==1.2.14 - # via torchio -filelock==3.13.1 - # via torch -fonttools==4.44.3 - # via matplotlib -google-auth==2.23.4 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -grpcio==1.59.2 - # via tensorboard -h5py==3.7.0 - # via -r requirements.in -humanize==4.8.0 - # via torchio -idna==3.4 - # via requests -imageio==2.32.0 - # via scikit-image -jinja2==3.1.2 - # via torch -joblib==1.3.2 - # via scikit-learn -kiwisolver==1.4.5 - # via matplotlib -lapy==1.0.1 - # via -r requirements.in -markdown==3.5.1 - # via tensorboard -markupsafe==2.1.3 - # via - # jinja2 - # werkzeug -matplotlib==3.7.1 - # via -r requirements.in -mpmath==1.3.0 - # via sympy -networkx==3.2.1 - # via - # scikit-image - # torch -nibabel==5.1.0 - # via - # -r requirements.in - # lapy - # torchio -numpy==1.25.0 - # via - # -r requirements.in - # contourpy - # h5py - # imageio - # lapy - # matplotlib - # nibabel - # pandas - # pywavelets - # scikit-image - # scikit-learn - # scipy - # tensorboard - # tifffile - # torchio - # torchvision -oauthlib==3.2.2 - # via requests-oauthlib -packaging==23.2 - # via - # matplotlib - # nibabel - # plotly - # scikit-image -pandas==1.5.3 - # via -r requirements.in -pillow==10.0.1 - # via - # -r requirements.in - # imageio - # matplotlib - # scikit-image - # torchvision -plotly==5.18.0 - # via lapy -protobuf==4.25.1 - # via tensorboard -psutil==5.9.6 - # via lapy -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pyparsing==3.1.1 - # via matplotlib -python-dateutil==2.8.2 - # via - # -r requirements.in - # matplotlib - # pandas -pytz==2023.3.post1 - # via pandas -pywavelets==1.4.1 - # via scikit-image -pyyaml==6.0 - # via - # -r requirements.in - # yacs -requests==2.31.0 - # via - # requests-oauthlib - # tensorboard - # torchvision -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rsa==4.9 - # via google-auth -scikit-image==0.19.3 - # via -r requirements.in -scikit-learn==1.2.2 - # via -r requirements.in -scipy==1.10.1 - # via - # -r requirements.in - # lapy - # scikit-image - # scikit-learn - # torchio -simpleitk==2.2.1 - # via - # -r requirements.in - # torchio -six==1.16.0 - # via python-dateutil -sympy==1.12 - # via torch -tenacity==8.2.3 - # via plotly -tensorboard==2.12.1 - # via -r requirements.in -tensorboard-data-server==0.7.2 - # via tensorboard 
-tensorboard-plugin-wit==1.8.1 - # via tensorboard -threadpoolctl==3.2.0 - # via scikit-learn -tifffile==2023.9.26 - # via scikit-image -torch==2.0.1+cpu - # via - # -r requirements.in - # torchio - # torchvision -torchio==0.18.83 - # via -r requirements.in -torchvision==0.15.2+cpu - # via -r requirements.in -tqdm==4.65.0 - # via - # -r requirements.in - # torchio -typing-extensions==4.8.0 - # via torch -urllib3==2.1.0 - # via requests -werkzeug==3.0.1 - # via tensorboard -wheel==0.41.3 - # via tensorboard -wrapt==1.16.0 - # via deprecated -yacs==0.1.8 - # via -r requirements.in - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 7b8fc43a..ebed4eb2 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -32,52 +32,47 @@ fi fastsurfercnndir="$FASTSURFER_HOME/FastSurferCNN" cerebnetdir="$FASTSURFER_HOME/CerebNet" +hypvinndir="$FASTSURFER_HOME/HypVINN" reconsurfdir="$FASTSURFER_HOME/recon_surf" # Regular flags defaults subject="" t1="" +t2="" merged_segfile="" cereb_segfile="" asegdkt_segfile="" asegdkt_segfile_default="\$SUBJECTS_DIR/\$SID/mri/aparc.DKTatlas+aseg.deep.mgz" asegdkt_statsfile="" cereb_statsfile="" -cereb_flags="" +cereb_flags=() +hypo_segfile="" +hypo_statsfile="" +hypvinn_flags=() conformed_name="" +conformed_name_t2="" +norm_name="" +norm_name_t2="" seg_log="" +run_talairach_registration="false" +atlas3T="false" viewagg="auto" device="auto" batch_size="1" run_seg_pipeline="1" run_biasfield="1" run_surf_pipeline="1" -flag_3T="" -fstess="" -fsqsphere="" -fsaparc="" -fssurfreg="" +surf_flags=() vox_size="min" -doParallel="" run_asegdkt_module="1" run_cereb_module="1" +run_hypvinn_module="1" threads="1" -# python3.10 -s excludes user-directory package inclusion, but passing "python3.10 -s" is not possible -# python-s is a miniscript to add this flag, but this only works if python-s is defined -if [[ -n "$(which python-s)" ]]; then - python="python-s" -elif [[ -f "/fastsurfer/python-s" ]]; then - python="/fastsurfer/python-s" -else - python="python3.10" -fi -allow_root="" +# python3.10 -s excludes user-directory package inclusion +python="python3.10 -s" +allow_root=() version_and_quit="" -# Dev flags defaults -vcheck="" -vfst1="" - function usage() { # --merged_segfile @@ -147,7 +142,8 @@ FLAGS: SEGMENTATION PIPELINE: --seg_only Run only FastSurferVINN (generate segmentation, do not run surface pipeline) - --seg_log Log-file for the segmentation (FastSurferVINN, CerebNet) + --seg_log Log-file for the segmentation (FastSurferVINN, CerebNet, + HypVINN) Default: \$SUBJECTS_DIR/\$sid/scripts/deep-seg.log --conformed_name Name of the file in which the conformed input @@ -159,6 +155,9 @@ SEGMENTATION PIPELINE: --norm_name Name of the biasfield corrected image Default location: \$SUBJECTS_DIR/\$sid/mri/orig_nu.mgz + --tal_reg Perform the talairach registration for eTIV estimates + in --seg_only stream and stats files (is affected by + the --3T flag, see below). MODULES: By default, all modules are run. @@ -181,7 +180,7 @@ SEGMENTATION PIPELINE: APARC module (see above). Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz - --cereb_segfile + --cereb_segfile Name of DL-based segmentation file of the cerebellum. This segmentation is always at 1mm isotropic resolution, since inference is always based on a @@ -194,6 +193,23 @@ SEGMENTATION PIPELINE: --no_biasfield Deactivate the calculation of partial volume-corrected statistics. 
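Taken together with the hypothalamus options documented in the next block, the segmentation flags above support invocations along the following lines (a sketch only; subject IDs, data paths and the license location are placeholders, not part of the patch):

```bash
# Full run on a 3T scan, adding a T2 image for the hypothalamus module with
# robust T1/T2 registration and QC snapshots (all paths are placeholders):
./run_fastsurfer.sh --sid sub-001 --sd /output \
    --t1 /data/sub-001/T1w.nii.gz --t2 /data/sub-001/T2w.nii.gz \
    --reg_mode robust --qc_snap --3T \
    --fs_license /path/to/fs_license.txt --threads 4

# Segmentation-only variant that keeps the talairach registration so the
# eTIV-dependent measures still appear in the stats files:
./run_fastsurfer.sh --sid sub-001 --sd /output \
    --t1 /data/sub-001/T1w.nii.gz --seg_only --tal_reg --3T \
    --fs_license /path/to/fs_license.txt
```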
+ HYPOTHALAMUS MODULE (HypVINN): + --no_hypothal Skip the hypothalamus segmentation. + --no_biasfield This option implies --no_hypothal, as the hypothalamus + sub-segmentation requires biasfield-corrected images. + --t2 *Optional* T2 full head input (does not have to be bias + corrected, a mandatory biasfield correction step is + performed). Requires an ABSOLUTE Path! + --reg_mode + Ignored, if no T2 image is passed. + Specifies the registration method used to register T1 + and T2 images. Options are 'coreg' (default) for + mri_coreg, 'robust' for mri_robust_register, and 'none' + to skip registration (this requires T1 and T2 are + externally co-registered). + --qc_snap Create QC snapshots in \$SUBJECTS_DIR/\$sid/qc_snapshots + to simplify the QC process. + SURFACE PIPELINE: --surf_only Run surface pipeline only. The segmentation input has to exist already in this case. @@ -202,7 +218,7 @@ SURFACE PIPELINE: --parallel Run both hemispheres in parallel --threads Set openMP and ITK threads to - Resource Options: +Resource Options: --device Set device on which inference should be run ("cpu" for CPU, "cuda" for Nvidia GPU, or pass specific device, e.g. cuda:1), default check GPU and then CPU @@ -217,7 +233,8 @@ SURFACE PIPELINE: device (no memory check will be done). --batch Batch size for inference. Default: 1 --py Command for python, used in both pipelines. - Default: python3.8 + Default: "$python" + (-s: do no search for packages in home directory) Dev Flags: --ignore_fs_version Switch on to avoid check for FreeSurfer version. @@ -254,11 +271,18 @@ Henschel L*, Kuegler D*, Reuter M. (*co-first). FastSurferVINN: Building for HighRes Brain MRI. NeuroImage 251 (2022), 118933. http://dx.doi.org/10.1016/j.neuroimage.2022.118933 +For cerebellum sub-segmentation: Faber J*, Kuegler D*, Bahrami E*, et al. (*co-first). CerebNet: A fast and reliable deep-learning pipeline for detailed cerebellum sub-segmentation. NeuroImage 264 (2022), 119703. https://doi.org/10.1016/j.neuroimage.2022.119703 +For hypothalamus sub-segemntation: +Estrada S, Kuegler D, Bahrami E, Xu P, Mousa D, Breteler MMB, Aziz NA, Reuter M. + FastSurfer-HypVINN: Automated sub-segmentation of the hypothalamus and adjacent + structures on high-resolutional brain MRI. Imaging Neuroscience 2023; 1 1–32. + https://doi.org/10.1162/imag_a_00034 + EOF } @@ -277,229 +301,153 @@ do # make key lowercase key=$(echo "$1" | tr '[:upper:]' '[:lower:]') +shift # past argument + case $key in - --fs_license) - if [[ -f "$2" ]]; then - export FS_LICENSE="$2" + ############################################################## + # general options + ############################################################## + --fs_license) + if [[ -f "$1" ]] + then + export FS_LICENSE="$1" else - echo "Provided FreeSurfer license file $2 could not be found. Make sure to provide the full path and name. Exiting..." - exit 1; + echo "ERROR: Provided FreeSurfer license file $1 could not be found. Make sure to provide the full path and name. Exiting..." + exit 1 fi - shift # past argument shift # past value ;; - --sid) - subject="$2" - shift # past argument - shift # past value - ;; - --sd) - sd="$2" - shift # past argument - shift # past value - ;; - --t1) - t1="$2" - shift # past argument - shift # past value - ;; - --merged_segfile) - merged_segfile="$2" - shift # past argument - shift # past value - ;; - --seg | --asegdkt_segfile | --aparc_aseg_segfile) - if [[ "$key" == "--seg" ]]; then - echo "WARNING: --seg is deprecated and will be removed, use --asegdkt_segfile ." 
- fi - if [[ "$key" == "--aparc_aseg_segfile" ]]; then - echo "WARNING: --aparc_aseg_segfile is deprecated and will be removed, use --asegdkt_segfile " + + # options that *just* set a flag + #============================================================= + --allow_root) allow_root=("--allow_root") ;; + # options that set a variable + --sid) subject="$1" ; shift ;; + --sd) sd="$1" ; shift ;; + --t1) t1="$1" ; shift ;; + --t2) t2="$1" ; shift ;; + --seg_log) seg_log="$1" ; shift ;; + --conformed_name) conformed_name="$1" ; shift ;; + --norm_name) norm_name="$1" ; shift ;; + --norm_name_t2) norm_name_t2="$1" ; shift ;; + --seg|--asegdkt_segfile|--aparc_aseg_segfile) + if [[ "$key" != "--asegdkt_segfile" ]] + then + echo "WARNING: --$key is deprecated and will be removed, use --asegdkt_segfile ." fi - asegdkt_segfile="$2" - shift # past argument - shift # past value - ;; - --asegdkt_statsfile) - asegdkt_statsfile="$2" - shift # past argument - shift # past value - ;; - --cereb_segfile) - cereb_segfile="$2" - shift # past argument - shift # past value - ;; - --cereb_statsfile) - cereb_statsfile="$2" - shift # past argument - shift # past value - ;; - --mask_name) - mask_name="$2" - shift # past argument - shift # past value - ;; - --norm_name) - norm_name="$2" - shift # past argument - shift # past value - ;; - --aseg_segfile) - aseg_segfile="$2" - shift # past argument + asegdkt_segfile="$1" shift # past value ;; - --conformed_name) - conformed_name="$2" - shift # past argument - shift # past value - ;; - --seg_log) - seg_log="$2" - shift # past argument - shift # past value + --vox_size) vox_size="$1" ; shift ;; + # --3t: both for surface pipeline and the --tal_reg flag + --3t) surf_flags=("${surf_flags[@]}" "--3T") ; atlas3T="true" ;; + --threads) threads="$1" ; shift ;; + --py) python="$1" ; shift ;; + -h|--help) usage ; exit ;; + --version) + if [[ "$#" -lt 1 ]] || [[ "$1" =~ ^-- ]]; then + # no more args or next arg starts with -- + version_and_quit="1" + else + case "$(echo "$1" | tr '[:upper:]' '[:lower:]')" in + all) version_and_quit="+checkpoints+git+pip" ;; + +*) version_and_quit="$1" ;; + *) echo "ERROR: Invalid option for --version: '$1', must be 'all' or [+checkpoints][+git][+pip]" + exit 1 + ;; + esac + shift + fi ;; - --viewagg_device | --run_viewagg_on) + + ############################################################## + # seg-pipeline options + ############################################################## + + # common options for seg + #============================================================= + --surf_only) run_seg_pipeline="0" ;; + --no_biasfield) run_biasfield="0" ;; + --tal_reg) run_talairach_registration="true" ;; + --device) device="$1" ; shift ;; + --batch) batch_size="$1" ; shift ;; + --viewagg_device|--run_viewagg_on) if [[ "$key" == "--run_viewagg_on" ]] then echo "WARNING: --run_viewagg_on (cpu|gpu|check) is deprecated and will be removed, use --viewagg_device ." fi - case "$2" in + case "$1" in check) - echo "WARNING: the option \"check\" is deprecated for --viewagg_device , use \"auto\"." - viewagg="auto" - ;; - gpu) - viewagg="cuda" - ;; - *) - viewagg="$2" - ;; + echo "WARNING: the option \"check\" is deprecated for --viewagg_device , use \"auto\"." + viewagg="auto" + ;; + gpu) viewagg="cuda" ;; + *) viewagg="$1" ;; esac - shift # past argument shift # past value ;; - --no_cuda) + --no_cuda) echo "WARNING: --no_cuda is deprecated and will be removed, use --device cpu." 
device="cpu" - shift # past argument - ;; - --no_biasfield) - run_biasfield="0" - shift # past argument ;; - --no_asegdkt | --no_aparc) + + # asegdkt module options + #============================================================= + --no_asegdkt|--no_aparc) if [[ "$key" == "--no_aparc" ]] then echo "WARNING: --no_aparc is deprecated and will be removed, use --no_asegdkt." fi run_asegdkt_module="0" - shift # past argument - ;; - --no_cereb) - run_cereb_module="0" - shift # past argument - ;; - --device) - device=$2 - shift # past argument - shift # past value - ;; - --batch) - batch_size="$2" - shift # past argument - shift # past value - ;; - --seg_only) - run_surf_pipeline="0" - shift # past argument - ;; - --surf_only) - run_seg_pipeline="0" - shift # past argument - ;; - --fstess) - fstess="--fstess" - shift # past argument - ;; - --fsqsphere) - fsqsphere="--fsqsphere" - shift # past argument - ;; - --fsaparc) - fsaparc="--fsaparc" - shift # past argument - ;; - --no_surfreg) - fssurfreg="--no_surfreg" - shift # past argument - ;; - --vox_size) - vox_size="$2" - shift # past argument - shift # past value - ;; - --3t) - flag_3T="--3T" - shift - ;; - --parallel) - doParallel="--parallel" - shift # past argument ;; - --threads) - threads="$2" - shift # past argument - shift # past value - ;; - --py) - python="$2" - shift # past argument + --asegdkt_statsfile) asegdkt_statsfile="$1" ; shift ;; + --aseg_segfile) aseg_segfile="$1" ; shift ;; + --mask_name) mask_name="$1" ; shift ;; + --merged_segfile) merged_segfile="$1" ; shift ;; + + # cereb module options + #============================================================= + --no_cereb) run_cereb_module="0" ;; + # several options that set a variable + --cereb_segfile) cereb_segfile="$1" ; shift ;; + --cereb_statsfile) cereb_statsfile="$1" ; shift ;; + + # hypothal module options + #============================================================= + --no_hypothal) run_hypvinn_module="0" ;; + # several options that set a variable + --hypo_segfile) hypo_segfile="$1" ; shift ;; + --hypo_statsfile) hypo_statsfile="$1" ; shift ;; + --reg_mode) + mode=$(echo "$1" | tr "[:upper:]" "[:lower:]") + if [[ "$mode" =~ ^(none|coreg|robust)$ ]] ; then + hypvinn_flags+=(--regmode "$mode") + else + echo "Invalid --reg_mode option, must be 'none', 'coreg' or 'robust'." 
+ exit 1 + fi shift # past value ;; - --ignore_fs_version) - vcheck="--ignore_fs_version" - shift # past argument - ;; - --no_fs_t1 ) - vfst1="--no_fs_T1" - shift # past argument - ;; - --allow_root) - allow_root="--allow_root" - shift # past argument - ;; - -h|--help) - usage - exit + # several options that set a variable + --qc_snap) hypvinn_flags+=(--qc_snap) ;; + + ############################################################## + # surf-pipeline options + ############################################################## + --seg_only) run_surf_pipeline="0" ;; + # several flag options that are *just* passed through to recon-surf.sh + --fstess|--fsqsphere|--fsaparc|--no_surfreg|--parallel|--ignore_fs_version) + surf_flags=("${surf_flags[@]}" "$key") ;; - --version) - if [[ "$#" -lt 2 ]]; then - version_and_quit="1" - else - case $2 in - all) - version_and_quit="+checkpoints+git+pip" - shift - ;; - +*) - version_and_quit="$2" - shift - ;; - --*) - version_and_quit="1" - ;; - *) - echo "Invalid option for --version" - exit 1 - ;; - esac - fi - shift + --no_fs_t1) surf_flags=("${surf_flags[@]}" "--no_fs_T1") ;; + + # temporary segstats development flag + --segstats_legacy) + surf_flags=("${surf_flags[@]}" "$key") ;; - *) # unknown option - echo ERROR: Flag $1 unrecognized. - exit 1 + *) # unknown option + # if not empty arguments, error & exit + if [[ "$key" != "" ]] ; then echo "ERROR: Flag '$key' unrecognized." ; exit 1 ; fi ;; esac done @@ -514,103 +462,155 @@ else fi ########################################## VERSION AND QUIT HERE ######################################## -version_args="" +version_args=() if [[ -f "$FASTSURFER_HOME/BUILD.info" ]] - then - version_args="--build_cache $FASTSURFER_HOME/BUILD.info --prefer_cache" +then + version_args=(--build_cache "$FASTSURFER_HOME/BUILD.info" --prefer_cache) fi if [[ -n "$version_and_quit" ]] +then + # if version_and_quit is 1, it should only print the version number+git branch + if [[ "$version_and_quit" != "1" ]] then - # if version_and_quit is 1, it should only print the version number+git branch - if [[ "$version_and_quit" != "1" ]] - then - version_args="$version_args --sections $version_and_quit" - fi - $python $FASTSURFER_HOME/FastSurferCNN/version.py $version_args - exit + version_args=("${version_args[@]}" --sections "$version_and_quit") + fi + $python "$FASTSURFER_HOME/FastSurferCNN/version.py" "${version_args[@]}" + exit +fi + +# make sure the python executable is valid and found +if [[ -z "$(which "${python/ */}")" ]]; then + echo "Cannot find the python interpreter ${python/ */}." + exit 1 fi # Warning if run as root user -if [[ -z "$allow_root" ]] && [[ "$(id -u)" == "0" ]] - then - echo "You are trying to run '$0' as root. We advice to avoid running FastSurfer as root, " - echo "because it will lead to files and folders created as root." - echo "If you are running FastSurfer in a docker container, you can specify the user with " - echo "'-u \$(id -u):\$(id -g)' (see https://docs.docker.com/engine/reference/run/#user)." - echo "If you want to force running as root, you may pass --allow_root to run_fastsurfer.sh." - exit 1; +if [[ "${#allow_root}" == 0 ]] && [[ "$(id -u)" == "0" ]] +then + echo "You are trying to run '$0' as root. We advice to avoid running FastSurfer as root, " + echo "because it will lead to files and folders created as root." + echo "If you are running FastSurfer in a docker container, you can specify the user with " + echo "'-u \$(id -u):\$(id -g)' (see https://docs.docker.com/engine/reference/run/#user)." 
+ echo "If you want to force running as root, you may pass --allow_root to run_fastsurfer.sh." + exit 1; fi # CHECKS if [[ "$run_seg_pipeline" == "1" ]] && { [[ -z "$t1" ]] || [[ ! -f "$t1" ]]; } - then - echo "ERROR: T1 image ($t1) could not be found. Must supply an existing T1 input (full head) via " - echo "--t1 (absolute path and name) for generating the segmentation." - echo "NOTES: If running in a container, make sure symlinks are valid!" - exit 1; +then + echo "ERROR: T1 image ($t1) could not be found. Must supply an existing T1 input (full head) via " + echo "--t1 (absolute path and name) for generating the segmentation." + echo "NOTES: If running in a container, make sure symlinks are valid!" + exit 1; +fi + +if [[ -z "${sd}" ]] +then + echo "ERROR: No subject directory defined via --sd. This is required!" + exit 1; +fi +if [[ ! -d "${sd}" ]] +then + echo "INFO: The subject directory did not exist, creating it now." + if ! mkdir -p "$sd" ; then echo "ERROR: directory creation failed" ; exit 1; fi +fi +if [[ "$(stat -c "%u:%g" "$sd")" == "0:0" ]] && [[ "$(id -u)" != "0" ]] && [[ "$(stat -c "%a" "$sd" | tail -c 2)" -lt 6 ]] +then + echo "ERROR: The subject directory ($sd) is owned by root and is not writable. FastSurfer cannot write results! " + echo "This can happen if the directory is created by docker. Make sure to create the directory before invoking docker!" + exit 1; fi if [[ -z "$subject" ]] - then - echo "ERROR: must supply subject name via --sid" - exit 1; +then + echo "ERROR: must supply subject name via --sid" + exit 1; fi if [[ -z "$merged_segfile" ]] - then - merged_segfile="${sd}/${subject}/mri/fastsurfer.merged.mgz" +then + merged_segfile="${sd}/${subject}/mri/fastsurfer.merged.mgz" fi if [[ -z "$asegdkt_segfile" ]] - then - asegdkt_segfile="${sd}/${subject}/mri/aparc.DKTatlas+aseg.deep.mgz" +then + asegdkt_segfile="${sd}/${subject}/mri/aparc.DKTatlas+aseg.deep.mgz" fi if [[ -z "$aseg_segfile" ]] - then - aseg_segfile="${sd}/${subject}/mri/aseg.auto_noCCseg.mgz" +then + aseg_segfile="${sd}/${subject}/mri/aseg.auto_noCCseg.mgz" fi if [[ -z "$asegdkt_statsfile" ]] - then - asegdkt_statsfile="${sd}/${subject}/stats/aseg+DKT.stats" +then + asegdkt_statsfile="${sd}/${subject}/stats/aseg+DKT.stats" fi - if [[ -z "$cereb_segfile" ]] - then - cereb_segfile="${sd}/${subject}/mri/cerebellum.CerebNet.nii.gz" +then + cereb_segfile="${sd}/${subject}/mri/cerebellum.CerebNet.nii.gz" fi if [[ -z "$cereb_statsfile" ]] - then - cereb_statsfile="${sd}/${subject}/stats/cerebellum.CerebNet.stats" +then + cereb_statsfile="${sd}/${subject}/stats/cerebellum.CerebNet.stats" +fi + +if [[ -z "$hypo_segfile" ]] +then + hypo_segfile="${sd}/${subject}/mri/hypothalamus.HypVINN.nii.gz" +fi + +if [[ -z "$hypo_statsfile" ]] +then + hypo_statsfile="${sd}/${subject}/stats/hypothalamus.HypVINN.stats" fi if [[ -z "$mask_name" ]] - then - mask_name="${sd}/${subject}/mri/mask.mgz" +then + mask_name="${sd}/${subject}/mri/mask.mgz" fi if [[ -z "$conformed_name" ]] +then + conformed_name="${sd}/${subject}/mri/orig.mgz" +fi + +if [[ -z "$conformed_name_t2" ]] then - conformed_name="${sd}/${subject}/mri/orig.mgz" + conformed_name_t2="${sd}/${subject}/mri/T2orig.mgz" fi if [[ -z "$norm_name" ]] - then - norm_name="${sd}/${subject}/mri/orig_nu.mgz" +then + norm_name="${sd}/${subject}/mri/orig_nu.mgz" +fi + +if [[ -z "$norm_name_t2" ]] +then + norm_name_t2="${sd}/${subject}/mri/T2_nu.mgz" fi if [[ -z "$seg_log" ]] - then - seg_log="${sd}/${subject}/scripts/deep-seg.log" +then + 
seg_log="${sd}/${subject}/scripts/deep-seg.log" fi if [[ -z "$build_log" ]] - then - build_log="${sd}/${subject}/scripts/build.log" +then + build_log="${sd}/${subject}/scripts/build.log" +fi + +if [[ -n "$t2" ]] +then + if [[ ! -f "$t2" ]] + then + echo "ERROR: T2 file $t2 does not exist!" + exit 1; + fi + copy_name_T2="${sd}/${subject}/mri/orig/T2.001.mgz" fi if [[ -z "$PYTHONUNBUFFERED" ]] @@ -647,148 +647,302 @@ fi #fi if [[ "${asegdkt_segfile: -3}" != "${conformed_name: -3}" ]] - then - echo "ERROR: Specified segmentation output and conformed image output do not have same file type." - echo "You passed --asegdkt_segfile ${asegdkt_segfile} and --conformed_name ${conformed_name}." - echo "Make sure these have the same file-format and adjust the names passed to the flags accordingly!" - exit 1; +then + echo "ERROR: Specified segmentation output and conformed image output do not have same file type." + echo "You passed --asegdkt_segfile ${asegdkt_segfile} and --conformed_name ${conformed_name}." + echo "Make sure these have the same file-format and adjust the names passed to the flags accordingly!" + exit 1; fi if [[ "$run_surf_pipeline" == "1" ]] && { [[ "$run_asegdkt_module" == "0" ]] || [[ "$run_seg_pipeline" == "0" ]]; } +then + if [[ ! -f "$asegdkt_segfile" ]] then - if [[ ! -f "$asegdkt_segfile" ]] - then - echo "ERROR: To run the surface pipeline, a whole brain segmentation must already exist." - echo "You passed --surf_only or --no_asegdkt, but the whole-brain segmentation ($asegdkt_segfile) could not be found." - echo "If the segmentation is not saved in the default location ($asegdkt_segfile_default), specify the absolute path and name via --asegdkt_segfile" - exit 1; - fi - if [[ ! -f "$conformed_name" ]] - then - echo "ERROR: To run the surface pipeline only, a conformed T1 image must already exist." - echo "You passed --surf_only but the conformed image ($conformed_name) could not be found." - echo "If the conformed image is not saved in the default location (\$SUBJECTS_DIR/\$SID/mri/orig.mgz)," - echo "specify the absolute path and name via --conformed_name." - exit 1; - fi + echo "ERROR: To run the surface pipeline, a whole brain segmentation must already exist." + echo "You passed --surf_only or --no_asegdkt, but the whole-brain segmentation ($asegdkt_segfile) could not be found." + echo "If the segmentation is not saved in the default location ($asegdkt_segfile_default), specify the absolute path and name via --asegdkt_segfile" + exit 1; + fi + if [[ ! -f "$conformed_name" ]] + then + echo "ERROR: To run the surface pipeline only, a conformed T1 image must already exist." + echo "You passed --surf_only but the conformed image ($conformed_name) could not be found." + echo "If the conformed image is not saved in the default location (\$SUBJECTS_DIR/\$SID/mri/orig.mgz)," + echo "specify the absolute path and name via --conformed_name." + exit 1; + fi fi if [[ "$run_seg_pipeline" == "1" ]] && { [[ "$run_asegdkt_module" == "0" ]] && [[ "$run_cereb_module" == "1" ]]; } +then + if [[ ! -f "$asegdkt_segfile" ]] then - if [[ ! -f "$asegdkt_segfile" ]] - then - echo "ERROR: To run the cerebellum segmentation but no asegdkt, the aseg segmentation must already exist." - echo "You passed --no_asegdkt but the asegdkt segmentation ($asegdkt_segfile) could not be found." 
- echo "If the segmentation is not saved in the default location ($asegdkt_segfile_default), specify the absolute path and name via --asegdkt_segfile" - exit 1; - fi + echo "ERROR: To run the cerebellum segmentation but no asegdkt, the aseg segmentation must already exist." + echo "You passed --no_asegdkt but the asegdkt segmentation ($asegdkt_segfile) could not be found." + echo "If the segmentation is not saved in the default location ($asegdkt_segfile_default), specify the absolute path and name via --asegdkt_segfile" + exit 1; + fi fi if [[ "$run_surf_pipeline" == "0" ]] && [[ "$run_seg_pipeline" == "0" ]] +then + echo "ERROR: You specified both --surf_only and --seg_only. Therefore neither part of the pipeline will be run." + echo "To run the whole FastSurfer pipeline, omit both flags." + exit 1; +fi + +if [[ "$run_surf_pipeline" == "1" ]] || [[ "$run_talairach_registration" == "true" ]] +then + msg="The surface pipeline and the talairach-registration in the segmentation pipeline require a FreeSurfer License" + if [[ -z "$FS_LICENSE" ]] then - echo "ERROR: You specified both --surf_only and --seg_only. Therefore neither part of the pipeline will be run." - echo "To run the whole FastSurfer pipeline, omit both flags." + msg="$msg, but no license was provided via --fs_license or the FS_LICENSE environment variable." + if [[ "$DO_NOT_SEARCH_FS_LICENSE_IN_FREESURFER_HOME" != "true" ]] && [[ -n "$FREESURFER_HOME" ]] + then + echo "WARNING: $msg Checking common license files in \$FREESURFER_HOME." + for filename in "license.dat" "license.txt" ".license" + do + if [[ -f "$FREESURFER_HOME/$filename" ]] + then + echo "Trying with '$FREESURFER_HOME/$filename', specify a license with --fs_license to overwrite." + export FS_LICENSE="$FREESURFER_HOME/$filename" + break + fi + done + if [[ -z "$FS_LICENSE" ]]; then echo "ERROR: No license found..." ; exit 1 ; fi + else + echo "ERROR: $msg" + exit 1; + fi + elif [[ ! -f "$FS_LICENSE" ]] + then + echo "ERROR: $msg, but the provided path is not a file: $FS_LICENSE." exit 1; + fi fi ########################################## START ######################################################## mkdir -p "$(dirname "$seg_log")" -if [[ -f "$seg_log" ]] && [[ "$run_seg_pipeline" == "1" ]] - then - append_flag=("$seg_log") -else - append_flag=(-a "$seg_log") +source "${reconsurfdir}/functions.sh" + +if [[ -f "$seg_log" ]]; then log_existed="true" +else log_existed="false" fi -VERSION=$($python $FASTSURFER_HOME/FastSurferCNN/version.py $version_args) -echo "Version: $VERSION" |& tee "${append_flag[@]}" +VERSION=$($python "$FASTSURFER_HOME/FastSurferCNN/version.py" "${version_args[@]}") +echo "Version: $VERSION" | tee -a "$seg_log" ### IF THE SCRIPT GETS TERMINATED, ADD A MESSAGE -trap "{ echo \"run_fastsurfer.sh terminated via signal at \$(date -R)!\" >> \"$seg_log\" }" SIGINT SIGTERM +trap "{ echo \"run_fastsurfer.sh terminated via signal at \$(date -R)!\" >> \"$seg_log\" ; }" SIGINT SIGTERM # create the build log, file with all version info in parallel printf "%s %s\n%s\n" "$THIS_SCRIPT" "${inputargs[*]}" "$(date -R)" >> "$build_log" -$python "$FASTSURFER_HOME/FastSurferCNN/version.py" $version_args >> "$build_log" & +$python "$FASTSURFER_HOME/FastSurferCNN/version.py" --sections all -o "$build_log" --prefer_cache & -if [[ ! 
-f "$seg_log" ]] || [[ "$run_seg_pipeline" != "1" ]] - then - echo "Running run_fastsurfer.sh on a " +if [[ "$run_seg_pipeline" != "1" ]] +then + echo "Running run_fastsurfer.sh without segmentation ; expecting previous --seg_only run in ${sd}/${subject}" | tee -a "$seg_log" fi if [[ "$run_seg_pipeline" == "1" ]] +then + # "============= Running FastSurferCNN (Creating Segmentation aparc.DKTatlas.aseg.mgz) ===============" + # use FastSurferCNN to create cortical parcellation + anatomical segmentation into 95 classes. + echo "Log file for segmentation FastSurferCNN/run_prediction.py" >> "$seg_log" + { date 2>&1 ; echo "" ; } | tee -a "$seg_log" + + if [[ "$run_asegdkt_module" == "1" ]] + then + cmd=($python "$fastsurfercnndir/run_prediction.py" --t1 "$t1" + --asegdkt_segfile "$asegdkt_segfile" --conformed_name "$conformed_name" + --brainmask_name "$mask_name" --aseg_name "$aseg_segfile" --sid "$subject" + --seg_log "$seg_log" --vox_size "$vox_size" --batch_size "$batch_size" + --viewagg_device "$viewagg" --device "$device" "${allow_root[@]}") + # specify the subject dir $sd, if asegdkt_segfile explicitly starts with it + if [[ "$sd" == "${asegdkt_segfile:0:${#sd}}" ]]; then cmd=("${cmd[@]}" --sd "$sd"); fi + echo_quoted "${cmd[@]}" | tee -a "$seg_log" + "${cmd[@]}" + exit_code="${PIPESTATUS[0]}" + if [[ "${exit_code}" == 2 ]] + then + echo "ERROR: FastSurfer asegdkt segmentation failed QC checks." | tee -a "$seg_log" + exit 1 + elif [[ "${exit_code}" -ne 0 ]] + then + echo "ERROR: FastSurfer asegdkt segmentation failed." | tee -a "$seg_log" + exit 1 + fi + fi + if [[ -n "$t2" ]] then - # "============= Running FastSurferCNN (Creating Segmentation aparc.DKTatlas.aseg.mgz) ===============" - # use FastSurferCNN to create cortical parcellation + anatomical segmentation into 95 classes. - echo "Log file for segmentation FastSurferCNN/run_prediction.py" >> "$seg_log" - date |& tee -a "$seg_log" - echo "" |& tee -a "$seg_log" + { + echo "INFO: Copying T2 file to ${copy_name_T2}..." + cmd=("nib-convert" "$t2" "$copy_name_T2") + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + + echo "INFO: Robust scaling (partial conforming) of T2 image..." + cmd=($python "${fastsurfercnndir}/data_loader/conform.py" --no_strict_lia + --no_vox_size --no_img_size "$t2" "$conformed_name_t2") + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + echo "Done." 
+ } | tee -a "$seg_log" + fi - if [[ "$run_asegdkt_module" == "1" ]] + if [[ "$run_biasfield" == "1" ]] + then + { + # this will always run, since norm_name is set to subject_dir/mri/orig_nu.mgz, if it is not passed/empty + cmd=($python "${reconsurfdir}/N4_bias_correct.py" "--in" "$conformed_name" + --rescale "$norm_name" --aseg "$asegdkt_segfile" --threads "$threads") + echo "INFO: Running N4 bias-field correction" + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + } | tee -a "$seg_log" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] + then + echo "ERROR: Biasfield correction failed" | tee -a "$seg_log" + exit 1 + fi + + if [[ "$run_talairach_registration" == "true" ]] + then + cmd=("$reconsurfdir/talairach-reg.sh" "$sd/$subject/mri" "$atlas3T" "$seg_log") + { + echo "INFO: Running talairach registration" + echo_quoted "${cmd[@]}" + } | tee -a "$seg_log" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] then - cmd="$python $fastsurfercnndir/run_prediction.py --t1 $t1 --asegdkt_segfile $asegdkt_segfile --conformed_name $conformed_name --brainmask_name $mask_name --aseg_name $aseg_segfile --sid $subject --seg_log $seg_log --vox_size $vox_size --batch_size $batch_size --viewagg_device $viewagg --device $device $allow_root" - echo "$cmd" |& tee -a "$seg_log" - $cmd - exit_code="${PIPESTATUS[0]}" - if [[ "${exit_code}" == 2 ]] - then - echo "ERROR: FastSurfer asegdkt segmentation failed QC checks." - exit 1 - elif [[ "${exit_code}" -ne 0 ]] - then - echo "ERROR: FastSurfer asegdkt segmentation failed." - exit 1 - fi + echo "ERROR: talairach registration failed" | tee -a "$seg_log" + exit 1 + fi fi - # compute the bias-field corrected image + if [[ "$run_asegdkt_module" ]] + then + cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$asegdkt_segfile" + --segstatsfile "$asegdkt_statsfile" --normfile "$norm_name" + --threads "$threads" "${allow_root[@]}" --empty --excludeid 0 + --sd "${sd}" --sid "${subject}" + --ids 2 4 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 43 44 46 47 + 49 50 51 52 53 54 58 60 63 77 251 252 253 254 255 1002 1003 1005 + 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 + 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 + 1034 1035 2002 2003 2005 2006 2007 2008 2009 2010 2011 2012 2013 + 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 + 2027 2028 2029 2030 2031 2034 2035 + --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" + measures --compute "Mask($mask_name)" "BrainSeg" "BrainSegNotVent" + "SupraTentorial" "SupraTentorialNotVent" + "SubCortGray" "rhCerebralWhiteMatter" + "lhCerebralWhiteMatter" "CerebralWhiteMatter" + # make sure to read white matter hypointensities from the + ) + if [[ "$run_talairach_registration" == "true" ]] + then + cmd=("${cmd[@]}" "EstimatedTotalIntraCranialVol" + "BrainSegVol-to-eTIV" "MaskVol-to-eTIV") + fi + { + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + } | tee -a "$seg_log" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] + then + echo "ERROR: asegdkt statsfile generation failed" | tee -a "$seg_log" + exit 1 + fi + fi + fi # [[ "$run_biasfield" == "1" ]] + + if [[ -n "$t2" ]] + then if [[ "$run_biasfield" == "1" ]] + then + # ... 
we have a t2 image, bias field-correct it (save robustly scaled uchar) + cmd=($python "${reconsurfdir}/N4_bias_correct.py" "--in" "$copy_name_T2" + --out "$norm_name_t2" --threads "$threads" --uchar) + { + echo "INFO: Running N4 bias-field correction of the t2" + echo_quoted "${cmd[@]}" + } | tee -a "$seg_log" + "${cmd[@]}" 2>&1 | tee -a "$seg_log" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] then - # this will always run, since norm_name is set to subject_dir/mri/orig_nu.mgz, if it is not passed/empty - echo "INFO: Running N4 bias-field correction" | tee -a "$seg_log" - cmd="$python ${reconsurfdir}/N4_bias_correct.py --in $conformed_name --rescale $norm_name --aseg $asegdkt_segfile --threads $threads" - echo "$cmd" |& tee -a "$seg_log" - $cmd - if [[ "${PIPESTATUS[0]}" -ne 0 ]] - then - echo "ERROR: Biasfield correction failed" | tee -a "$seg_log" - exit 1 - fi + echo "ERROR: T2 Biasfield correction failed" | tee -a "$seg_log" + exit 1 + fi + else + # no biasfield, but a t2 is passed; presumably, this is biasfield corrected + cmd=($python "${fastsurfercnndir}/data_loader/conform.py" --no_strict_lia + --no_iso_vox --no_img_size "$t2" "$norm_name_t2") + { + echo "INFO: Robustly rescaling $t2 to uchar ($norm_name_t2), which is assumed to already be biasfield corrected." + echo "WARNING: --no_biasfield is activated, but FastSurfer does not check, if " + echo " passed T2 image is properly scaled and typed. T2 needs to be uchar and" + echo " robustly scaled (see FastSurferCNN/utils/data_loader/conform.py)!" + } | tee -a "$seg_log" + "${cmd[@]}" 2>&1 | tee -a "$seg_log" + fi + fi - if [[ "$run_asegdkt_module" ]] - then - cmd="$python ${fastsurfercnndir}/segstats.py --segfile $asegdkt_segfile --segstatsfile $asegdkt_statsfile --normfile $norm_name $allow_root --empty --excludeid 0 --ids 2 4 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 43 44 46 47 49 50 51 52 53 54 58 60 63 77 251 252 253 254 255 1002 1003 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1034 1035 2002 2003 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2034 2035 --lut $fastsurfercnndir/config/FreeSurferColorLUT.txt --threads $threads " - echo "$cmd" |& tee -a "$seg_log" - $cmd |& tee -a "$seg_log" - if [[ "${PIPESTATUS[0]}" -ne 0 ]] - then - echo "ERROR: asegdkt statsfile generation failed" | tee -a "$seg_log" - exit 1 - fi - fi + if [[ "$run_cereb_module" == "1" ]] + then + if [[ "$run_biasfield" == "1" ]] + then + cereb_flags=("${cereb_flags[@]}" --norm_name "$norm_name" + --cereb_statsfile "$cereb_statsfile") + else + echo "INFO: Running CerebNet without generating a statsfile, since biasfield correction deactivated '--no_biasfield'." | tee -a "$seg_log" fi - if [[ "$run_cereb_module" == "1" ]] - then - if [[ "$run_biasfield" == "1" ]] - then - cereb_flags="$cereb_flags --norm_name $norm_name --cereb_statsfile $cereb_statsfile" - else - echo "INFO: Running CerebNet without generating a statsfile, since biasfield correction deactivated '--no_biasfield'." 
|& tee -a $seg_log - fi + cmd=($python "$cerebnetdir/run_prediction.py" --t1 "$t1" + --asegdkt_segfile "$asegdkt_segfile" --conformed_name "$conformed_name" + --cereb_segfile "$cereb_segfile" --seg_log "$seg_log" --async_io + --batch_size "$batch_size" --viewagg_device "$viewagg" --device "$device" + --threads "$threads" "${cereb_flags[@]}" "${allow_root[@]}") + # specify the subject dir $sd, if asegdkt_segfile explicitly starts with it + if [[ "$sd" == "${cereb_segfile:0:${#sd}}" ]] ; then cmd=("${cmd[@]}" --sd "$sd"); fi + echo_quoted "${cmd[@]}" | tee -a "$seg_log" + "${cmd[@]}" # no tee, directly logging to $seg_log + if [[ "${PIPESTATUS[0]}" -ne 0 ]] + then + echo "ERROR: Cerebellum Segmentation failed" | tee -a "$seg_log" + exit 1 + fi + fi - cmd="$python $cerebnetdir/run_prediction.py --t1 $t1 --asegdkt_segfile $asegdkt_segfile --conformed_name $conformed_name --cereb_segfile $cereb_segfile --seg_log $seg_log --batch_size $batch_size --viewagg_device $viewagg --device $device --async_io --threads $threads$cereb_flags $allow_root" - echo "$cmd" |& tee -a "$seg_log" - $cmd - if [[ "${PIPESTATUS[0]}" -ne 0 ]] - then - echo "ERROR: Cerebellum Segmentation failed" | tee -a "$seg_log" - exit 1 - fi + if [[ "$run_hypvinn_module" == "1" ]] + then + # currently, the order of the T2 preprocessing only is registration to T1w + cmd=($python "$hypvinndir/run_prediction.py" --sd "${sd}" --sid "${subject}" + "${hypvinn_flags[@]}" "${allow_root[@]}" --threads "$threads" --async_io + --batch_size "$batch_size" --seg_log "$seg_log" --device "$device" + --viewagg_device "$viewagg" --t1) + if [[ "$run_biasfield" == "1" ]] + then + cmd+=("$norm_name") + if [[ -n "$t2" ]] ; then cmd+=(--t2 "$norm_name_t2"); fi + else + echo "WARNING: We strongly recommend to *not* exclude the biasfield (--no_biasfield) with the hypothal module!" + cmd+=("$t1") + if [[ -n "$t2" ]] ; then cmd+=(--t2 "$t2"); fi + fi + echo_quoted "${cmd[@]}" | tee -a "$seg_log" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] + then + echo "ERROR: Hypothalamus Segmentation failed" | tee -a "$seg_log" + exit 1 fi + fi # if [[ ! -f "$merged_segfile" ]] # then @@ -797,17 +951,17 @@ if [[ "$run_seg_pipeline" == "1" ]] fi if [[ "$run_surf_pipeline" == "1" ]] - then - # ============= Running recon-surf (surfaces, thickness etc.) =============== - # use recon-surf to create surface models based on the FastSurferCNN segmentation. - pushd "$reconsurfdir" - cmd="./recon-surf.sh --sid $subject --sd $sd --t1 $conformed_name --asegdkt_segfile $asegdkt_segfile" - cmd="$cmd $fstess $fsqsphere $flag_3T $fsaparc $fssurfreg $doParallel --threads $threads --py $python" - cmd="$cmd $vcheck $vfst1 $allow_root" - echo "$cmd" |& tee -a "$seg_log" - $cmd - if [[ "${PIPESTATUS[0]}" -ne 0 ]] ; then exit 1 ; fi - popd +then + # ============= Running recon-surf (surfaces, thickness etc.) =============== + # use recon-surf to create surface models based on the FastSurferCNN segmentation. 
+ pushd "$reconsurfdir" > /dev/null || exit 1 + cmd=("./recon-surf.sh" --sid "$subject" --sd "$sd" --t1 "$conformed_name" + --asegdkt_segfile "$asegdkt_segfile" --threads "$threads" --py "$python" + "${surf_flags[@]}" "${allow_root[@]}") + echo_quoted "${cmd[@]}" | tee -a "$seg_log" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" -ne 0 ]] ; then exit 1 ; fi + popd > /dev/null || return fi ########################################## End ######################################################## diff --git a/srun_fastsurfer.sh b/srun_fastsurfer.sh index 3339d6a5..4c32a0ec 100755 --- a/srun_fastsurfer.sh +++ b/srun_fastsurfer.sh @@ -39,11 +39,13 @@ extra_singularity_options_seg="" email="" pattern="*.{nii.gz,nii,mgz}" subject_list="" -subject_list_awk_code="\$1:\$2" +subject_list_awk_code_sid="\$1" +subject_list_awk_code_args="\$2" subject_list_delim="=" jobarray="" timelimit_seg=5 -timelimit_surf=180 +# 1mm can take 1h per hemi plus 1h extra on a single core (depending on cpu speed) +timelimit_surf=$((4 * 60)) function usage() { @@ -55,7 +57,8 @@ srun_fastsurfer.sh [--data ] [--sd ] [--work ] (--pattern |--subject_list [--subject_list_delim ] - [--subject_list_awk_code :]) + [--subject_list_awk_code_sid ] + [--subject_list_awk_code_t1 ]) [--singularity_image ] [--extra_singularity_options [(seg|surf)=]] [--num_cases_per_task ] [--num_cpus_per_task ] [--cpu_only] [--time (surf|seg)=] @@ -84,6 +87,10 @@ Data- and subject-related options: a work directory, which can use IO-optimized cluster storage (see --work). --work: directory with fast filesystem on cluster (default: \$HPCWORK/fastsurfer-processing/$(date +%Y%m%d-%H%M%S)) + NOTE: THIS SCRIPT considers this directory to be owned by this script and job! + No modifications should be made to the directory after the job is started until it is + finished (if the job fails, cleanup of this directory may be necessary) and it should be + empty! --data: (root) directory to search in for t1 files (default: current work directory). --pattern: glob string to find image files in 'data directory' (default: *.{nii,nii.gz,mgz}), for example --data /data/ --pattern \*/\*/mri/t1.nii.gz @@ -92,18 +99,28 @@ Data- and subject-related options: subject_id1=/path/to/t1.mgz ... This option invalidates the --pattern option. + May also add additional parameters like: + subject_id1=/path/to/t1.mgz --vox_size 1.0 --subject_list_delim: alternative delimiter in the file (default: "="). For example, if you pass --subject_list_delim "," the subject_list file is parsed as a comma-delimited csv file. ---subject_list_awk_code :: alternative way to construct - subject_id and subject_path from the row in the subject_list (default: '\$1:\$2'), other - examples: '\$1:\$2/\$1/mri/orig.mgz', where the first field is the subject_id and the second - field is the containing folder, e.g. the study. +--subject_list_awk_code_sid : alternative way to construct the subject_id + from the row in the subject_list (default: '\$1'). +--subject_list_awk_code_args : alternative way to construct the image_path and + additional parameters from the row in the subject_list (default: '\$2'), other examples: + '\$2/\$1/mri/orig.mgz', where the first field (of the subject_list file) is the subject_id + and the second field is the containing folder, e.g. the study. 
+ Example for additional parameters: + --subject_list_delim "," --subject_list_awk_code_args '\$2 " --vox_size " \$4' + to implement from the subject_list line + subject-101,raw/T1w-101A.nii.gz,study-1,0.9 + to (additional arguments must be comma-separated) + --sid subject-101 --t1 /raw/T1w-101A.nii.gz --vox_size 0.9 FastSurfer options: --fs_license: path to the freesurfer license (either absolute path or relative to pwd) --seg_only: only run the segmentation pipeline --surf_only: only run the surface pipeline (--sd must contain previous --seg_only processing) -... also standard FastSurfer options can be passed, like --3T, --no_cereb, etc. +--***: also standard FastSurfer options can be passed, like --3T, --no_cereb, etc. Singularity-related options: --singularity_image: Path to the singularity image to use for segmentation and surface @@ -160,11 +177,11 @@ EOF } # the memory required for the surface and the segmentation pipeline depends on the -# voxel size of the image, here we use values proven to work for 0.7mm (and also 0.8 and 1m) +# voxel size of the image, here we use values proven to work for 0.7mm (and also 0.8 and 1mm) mem_seg_cpu=10 # in GB, seg on cpu, actually required: 9G mem_seg_gpu=7 # in GB, seg on gpu, actually required: 6G -mem_surf_parallel=20 # in GB, hemi in parallel -mem_surf_noparallel=18 # in GB, hemi in series +mem_surf_parallel=6 # in GB, hemi in parallel +mem_surf_noparallel=4 # in GB, hemi in series num_cpus_surf=1 # base number of cpus to use for surfaces (doubled if --parallel) do_parallel="false" @@ -190,173 +207,123 @@ source "$(dirname "$THIS_SCRIPT")/stools.sh" inputargs=("$@") while [[ $# -gt 0 ]] do +KEY=$1 # make key lowercase -key=$(echo "$1" | tr '[:upper:]' '[:lower:]') +key=$(echo "$KEY" | tr '[:upper:]' '[:lower:]') +shift case $key in - --fs_license) - fs_license="$2" - shift # past argument - shift # past value - ;; - --data) - in_dir="$2" - shift - shift - ;; - --sd) - out_dir="$2" - shift # past argument - shift # past value - ;; - --pattern) - pattern="$2" - shift # past argument - shift # past value - ;; - --subject_list) - subject_list="$2" - shift - shift - ;; - --subject_list_delim) - subject_list_delim="$2" - shift - shift - ;; + --fs_license) fs_license="$1" ; shift ;; + --data) in_dir="$1" ; shift ;; + --sd) out_dir="$1" ; shift;; + --pattern) pattern="$1"; shift ;; + --subject_list|--subjects_list) subject_list="$1" ; shift ;; + --subject_list_delim|--subjects_list_delim) subject_list_delim="$1" ; shift ;; --subject_list_awk_code) - subject_list_awk_code="$2" - shift - shift - ;; - --num_cases_per_task) - num_cases_per_task="$2" - shift # past argument - shift # past value - ;; - --num_cpus_per_task) - num_cpus_per_task="$2" - shift # past argument - shift # past value - ;; - --cpu_only) - cpu_only="true" - shift # past argument - ;; - --skip_cleanup) - do_cleanup="false" - shift # past argument - ;; - --work) - hpc_work="$2" - shift - shift - ;; - --surf_only) - surf_only="true" - shift - ;; - --seg_only) - seg_only="true" - shift - ;; - --parallel) - do_parallel="true" - shift - ;; - --singularity_image) - singularity_image="$2" - shift - shift + echo "--subject_list_awk_code is outdated, use subject_list_awk_code_sid and subject_list_awk_code_args!" 
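Splitting the old single awk code into separate sid/args codes makes it possible to carry per-case options in the subject list; a sketch with placeholder paths, mirroring the example from the usage text above:

```bash
# subjects.csv (placeholder content; column 1 = subject ID, column 2 = T1 path
# relative to --data, column 4 = voxel size to forward to FastSurfer):
#   subject-101,raw/T1w-101A.nii.gz,study-1,0.9
#   subject-102,raw/T1w-102A.nii.gz,study-1,1.0
./srun_fastsurfer.sh --data /groups/study1 --sd /groups/study1/derivatives \
    --work "$HPCWORK/fastsurfer-tmp" \
    --singularity_image /images/fastsurfer.sif \
    --fs_license /licenses/fs_license.txt \
    --subject_list subjects.csv --subject_list_delim "," \
    --subject_list_awk_code_sid '$1' \
    --subject_list_awk_code_args '$2 " --vox_size " $4'
```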
+ exit 1 ;; + --subject_list_awk_code_sid|--subjects_list_awk_code_sid) subject_list_awk_code_sid="$1" ; shift ;; + --subject_list_awk_code_args|--subjects_list_awk_code_args) subject_list_awk_code_args="$1" ; shift ;; + --num_cases_per_task) num_cases_per_task="$1" ; shift ;; + --num_cpus_per_task) num_cpus_per_task="$1" ; shift ;; + --cpu_only) cpu_only="true" ;; + --skip_cleanup) do_cleanup="false" ;; + --work) hpc_work="$1" ; shift ;; + --surf_only) surf_only="true" ;; + --seg_only) seg_only="true" ;; + --parallel) do_parallel="true" ;; + --singularity_image) singularity_image="$1" ; shift ;; --partition) - partition_temp="$2" # make key lowercase - lower_value=$(echo "$2" | tr '[:upper:]' '[:lower:]') + lower_value=$(echo "$1" | tr '[:upper:]' '[:lower:]') if [[ "$lower_value" =~ seg=* ]] then - partition_seg=${partition_temp:4} + partition_seg=${1:4} elif [[ "$lower_value" =~ surf=* ]] then - partition_surf=${partition_temp:5} + partition_surf=${1:5} else - partition=$2 + partition=$1 fi shift - shift ;; --extra_singularity_options) - singularity_opts_temp="$2" # make key lowercase - lower_value=$(echo "$2" | tr '[:upper:]' '[:lower:]') + lower_value=$(echo "$1" | tr '[:upper:]' '[:lower:]') if [[ "$lower_value" =~ seg=* ]] then - extra_singularity_options_seg=${singularity_opts_temp:4} + extra_singularity_options_seg=${1:4} elif [[ "$lower_value" =~ surf=* ]] then - extra_singularity_options_surf=${singularity_opts_temp:5} + extra_singularity_options_surf=${1:5} else - extra_singularity_options=$2 + extra_singularity_options=$1 fi shift - shift ;; --time) - time_temp="$2" # make key lowercase - lower_value=$(echo "$2" | tr '[:upper:]' '[:lower:]') + lower_value=$(echo "$1" | tr '[:upper:]' '[:lower:]') if [[ "$lower_value" =~ ^seg=[0-9]+ ]] then - timelimit_seg=${time_temp:4} + timelimit_seg=${1:4} elif [[ "$lower_value" =~ surf=([0-9]+|[0-9]{0,1}(:[0-9]{2}){0,1}) ]] then - timelimit_surf=${time_temp:5} + timelimit_surf=${1:5} else - echo "Invalid parameter to --time: $2, must be seg|surf=" + echo "Invalid parameter to --time: $1, must be seg|surf=" exit 1 fi shift - shift - ;; - --email) - email="$2" - shift - shift - ;; - --dry) - submit_jobs="false" - shift - ;; - --debug) - debug="true" - shift ;; - --slurm_jobarray) - jobarray=$2 - shift + --email) email="$1" ; shift ;; + --dry) submit_jobs="false" ;; + --debug) debug="true" ;; + --slurm_jobarray) jobarray=$1 ; shift ;; + --help) usage ; exit ;; + --mem) + # make key lowercase + lower_value=$(echo "$1" | tr '[:upper:]' '[:lower:]') + if [[ "$lower_value" =~ ^seg=[0-9]+$ ]] + then + mem_seg_cpu=${1:4} + mem_seg_gpu=${1:4} + elif [[ "$lower_value" =~ ^surf=[0-9]+$ ]] + then + mem_surf_parallel=${1:5} + mem_surf_noparallel=${1:5} + else + echo "Invalid parameter to --mem: $1, must be seg|surf=" + exit 1 + fi shift ;; - --help) - usage - exit - ;; *) # unknown option - POSITIONAL_FASTSURFER[$i]=$1 - i=$(($i + 1)) - shift + POSITIONAL_FASTSURFER[i]=$KEY + i=$((i + 1)) ;; esac done -echo "Log of FastSurfer SLURM script" -date -R -echo "$THIS_SCRIPT ${inputargs[*]}" -echo "" +# create a temporary logfile, which we copy over to the final log file location once it +# is available +tmpLF=$(mktemp) +LF=$tmpLF + +function log() { echo "$@" | tee -a "$LF" ; } +function logf() { printf "%s" "$@" | tee -a "$LF" ; } + +log "Log of FastSurfer SLURM script" +log "$(date -R)" +log "$THIS_SCRIPT ${inputargs[*]}" +log "" make_hpc_work="false" if [[ -d "$hpc_work" ]] then + # delete_hpc_work is true, if the hpc_work directory was created by this 
script. delete_hpc_work="false" else delete_hpc_work="true" @@ -368,6 +335,7 @@ then echo "Neither --work nor \$HPCWORK are defined, make sure to pass --work!" exit 1 else + # check_hpc_work only has log messages, if it also exists check_hpc_work "$HPCWORK/fastsurfer-processing" "false" hpc_work_already_exists="true" # create a new and unused directory @@ -378,56 +346,71 @@ then make_hpc_work="true" fi else + # also checks, if hpc_work is also empty (required) + # check_hpc_work only has log messages, if it also exists check_hpc_work "$hpc_work" "true" fi if [[ "$debug" == "true" ]] then - echo "Debug parameters to script srun_fastsurfer:" - echo "" - echo "SLURM options:" - echo "submit jobs and perform operations: $submit_jobs" - echo "perform the cleanup step: $do_cleanup" - echo "seg/surf running on slurm partition:" \ - "$(first_non_empty_arg "$partition_seg" "$partition")" "/" \ - "$(first_non_empty_arg "$partition_surf" "$partition")" - echo "num_cpus_per_task/max. num_cases_per_task: $num_cpus_per_task/$num_cases_per_task" - echo "segmentation on cpu only: $cpu_only" - echo "Data options:" - echo "source dir: $in_dir" - if [[ -n "$subject_list" ]] + function debug () { log "$@" ; } + function debugf () { logf "$@" ; } +else + # all debug messages go into logfile no matter what, but here, not to the console + function debug () { echo "$@" >> "$LF" ; } + function debugf () { printf "%s" "$@" >> "$LF" ; } + if [[ "$submit_jobs" == "false" ]] then - echo "Reading subjects from subject_list file $subject_list" - echo "subject_list read options: delimiter: '${subject_list_delim}', awk code: '${subject_list_awk_code}'" - else - echo "pattern to search for images: $pattern" + log "dry run, no jobs or operations are performed" + log "" fi - echo "output (subject) dir: $out_dir" - echo "work dir: $hpc_work" - echo "" - echo "FastSurfer parameters:" - echo "singularity image: $singularity_image" - echo "FreeSurfer license: $fs_license" - if [[ "$seg_only" == "true" ]]; then echo "--seg_only"; fi - if [[ "$surf_only" == "true" ]]; then echo "--surf_only"; fi - if [[ "$do_parallel" == "true" ]]; then echo "--parallel"; fi - for p in "${POSITIONAL_FASTSURFER[@]}" - do - if [[ "$p" == --* ]]; then printf "\n%s" "$p"; - else printf " %s" "$p"; - fi - done - echo "" - echo "Running in$(ls -l /proc/$$/exe | cut -d">" -f2)" - echo "" +fi + +debug "Debug parameters to script srun_fastsurfer:" +debug "" +debug "SLURM options:" +debug "submit jobs and perform operations: $submit_jobs" +debug "perform the cleanup step: $do_cleanup" +debug "seg/surf running on slurm partition:" \ + "$(first_non_empty_arg "$partition_seg" "$partition")" "/" \ + "$(first_non_empty_arg "$partition_surf" "$partition")" +debug "num_cpus_per_task/max. 
num_cases_per_task: $num_cpus_per_task/$num_cases_per_task" +debug "segmentation on cpu only: $cpu_only" +debug "Data options:" +debug "source dir: $in_dir" +if [[ -n "$subject_list" ]] +then + debug "Reading subjects from subject_list file $subject_list" + debug "subject_list read options:" + debug " delimiter: '${subject_list_delim}'" + debug " sid awk code: '${subject_list_awk_code_sid}'" + debug " args awk code: '${subject_list_awk_code_args}'" else - if [[ "$submit_jobs" == "false" ]]; then echo "dry run, no jobs or operations are performed"; echo ""; fi + debug "pattern to search for images: $pattern" fi +debug "output (subject) dir: $out_dir" +debug "work dir: $hpc_work" +debug "" +debug "FastSurfer parameters:" +debug "singularity image: $singularity_image" +debug "FreeSurfer license: $fs_license" +if [[ "$seg_only" == "true" ]]; then debug "--seg_only"; fi +if [[ "$surf_only" == "true" ]]; then debug "--surf_only"; fi +if [[ "$do_parallel" == "true" ]]; then debug "--parallel"; fi +for p in "${POSITIONAL_FASTSURFER[@]}" +do + if [[ "$p" == --* ]]; then debugf "\n%s" "$p"; + else debugf " %s" "$p"; + fi +done +shell=$(stat -c %N "/proc/$$/exe" | cut -d">" -f2 | tail -c +3 | head -c -2) +debug "Running in shell $shell: $($shell --version 2>/dev/null | head -n 1)" +debug "" if [[ "${pattern/#\/}" != "$pattern" ]] - then - echo "ERROR: Absolute paths in --pattern are not allowed, set a base path with --data (this may even be /, i.e. root)." - exit 1 +then + echo "ERROR: Absolute paths in --pattern are not allowed, set a base path with --data (this may even be /, i.e. root)." + exit 1 fi check_singularity_image "$singularity_image" @@ -437,12 +420,12 @@ check_out_dir "$out_dir" if [[ "$cpu_only" == "true" ]] && [[ "$timelimit_seg" -lt 6 ]] then - echo "WARNING!!!" - echo "------------------------------------------------------------------------" - echo "You specified the segmentation shall be performed on the cpu, but the" - echo "time limit per segmentation is less than 6 minutes (default is optimized " - echo "for GPU acceleration @ 5 minutes). This is very likely insufficient!" - echo "------------------------------------------------------------------------" + log "WARNING!!!" + log "------------------------------------------------------------------------" + log "You specified the segmentation shall be performed on the cpu, but the" + log "time limit per segmentation is less than 6 minutes (default is optimized " + log "for GPU acceleration @ 5 minutes). This is very likely insufficient!" + log "------------------------------------------------------------------------" fi # step zero: make directories @@ -450,7 +433,7 @@ if [[ "$submit_jobs" == "true" ]] then if [[ "$make_hpc_work" == "true" ]]; then mkdir "$hpc_work" ; fi make_hpc_work_dirs "$hpc_work" - echo "Setting up the work directory..." + log "Setting up the work directory..." 
fi wait # for directories to be made @@ -458,18 +441,19 @@ wait # for directories to be made # step one: copy singularity image to hpc all_cases_file="/$hpc_work/scripts/subject_list" -echo "cp \"$singularity_image\" \"$hpc_work/images/fastsurfer.sif\"" -echo "cp \"$(dirname $THIS_SCRIPT)/brun_fastsurfer.sh\" \"$hpc_work/scripts\"" -echo "cp \"$fs_license\" \"$hpc_work/scripts/.fs_license\"" -echo "Create Status/Success file at $hpc_work/scripts/subject_success" +log "cp \"$singularity_image\" \"$hpc_work/images/fastsurfer.sif\"" +script_dir="$(dirname "$THIS_SCRIPT")" +log "cp \"$script_dir/brun_fastsurfer.sh\" \"$script_dir/stools.sh\" \"$hpc_work/scripts\"" +log "cp \"$fs_license\" \"$hpc_work/scripts/.fs_license\"" +log "Create Status/Success file at $hpc_work/scripts/subject_success" tofile="cat" if [[ "$submit_jobs" == "true" ]] then cp "$singularity_image" "$hpc_work/images/fastsurfer.sif" & - cp "$(dirname $THIS_SCRIPT)/brun_fastsurfer.sh" "$hpc_work/scripts" & + cp "$script_dir/brun_fastsurfer.sh" "$script_dir/stools.sh" "$hpc_work/scripts" & cp "$fs_license" "$hpc_work/scripts/.fs_license" & - echo "#Status/Success file of srun_fastsurfer-run $(date)" > "$hpc_work/scripts/subject_success" & + log "#Status/Success file of srun_fastsurfer-run $(date)" > "$hpc_work/scripts/subject_success" & tofile="tee $all_cases_file" fi @@ -478,10 +462,16 @@ fi if [[ -n "$subject_list" ]] then # the test for files (check_subject_images) requires paths to be wrt - cases=$(translate_cases "$in_dir" "$subject_list" "$in_dir" "${subject_list_delim}" "${subject_list_awk_code}") + cases=$(translate_cases "$in_dir" "$subject_list" "$in_dir" "${subject_list_delim}" "${subject_list_awk_code_sid}" "${subject_list_awk_code_args}") check_subject_images "$cases" + if [[ "$debug" == "true" ]] + then + log "Debug output of the parsed subject_list:" + log "$cases" + log "" + fi - cases=$(translate_cases "$in_dir" "$subject_list" "/source" "${subject_list_delim}" "${subject_list_awk_code}" | $tofile) + cases=$(translate_cases "$in_dir" "$subject_list" "/source" "${subject_list_delim}" "${subject_list_awk_code_sid}" "${subject_list_awk_code_args}" | $tofile) else cases=$(read_cases "$in_dir" "$pattern" "/source" | $tofile) fi @@ -490,30 +480,44 @@ num_cases=$(echo "$cases" | wc -l) if [[ "$num_cases" -lt 1 ]] || [[ -z "$cases" ]] then wait - echo "WARNING: No cases found using the parameters provided. Aborting job submission!" + log "WARNING: No cases found using the parameters provided. Aborting job submission!" if [[ "$submit_jobs" == "true" ]] && [[ "$do_cleanup" == "true" ]] then - echo "Cleaning temporary work directory!" + log "Cleaning temporary work directory!" rm -R "$hpc_work/images" rm -R "$hpc_work/scripts" - if [[ "$delete_hpc_work" == "false" ]] + if [[ "$delete_hpc_work" == "true" ]] then + # delete_hpc_work is true, if the hpc_work directory was created by this script. rm -R "$hpc_work" fi fi exit 0 fi +cleanup_mode="mv" +if [[ "$do_cleanup" == "true" ]] +then + if [[ -n "$jobarray" ]]; then jobarray_defined="true" + else jobarray_defined="false" + fi + check_cases_in_out_dir "$out_dir" "$cases" "$jobarray_defined" + if [[ "$cleanup_mode" == "cp" ]] + then + log "Overwriting existing cases in $out_dir by data generated with FastSurfer." + fi +fi if [[ "$submit_jobs" != "true" ]] then - echo "Copying singularity image and scripts..." + log "Copying singularity image and scripts..." 
fi wait # for copy and other stuff brun_fastsurfer="scripts/brun_fastsurfer.sh" fastsurfer_options=() +log_name="slurm-submit" if [[ "$debug" == "true" ]] then fastsurfer_options=("${fastsurfer_options[@]}" --debug) @@ -527,15 +531,15 @@ then slurm_email=(--mail-user "$email") if [[ "$debug" == "true" ]] then - echo "Sending emails on ALL conditions" + log "Sending emails on ALL conditions" slurm_email=("${slurm_email[@]}" --mail-type "ALL,ARRAY_TASKS") else - echo "Sending emails on END,FAIL conditions" + log "Sending emails on END,FAIL conditions" slurm_email=("${slurm_email[@]}" --mail-type "END,FAIL,ARRAY_TASKS") fi fi -jobarray_size="$(($(($num_cases - 1)) / $num_cases_per_task + 1))" -real_num_cases_per_task="$(($(($num_cases - 1)) / $jobarray_size + 1))" +jobarray_size="$(($((num_cases - 1)) / num_cases_per_task + 1))" +real_num_cases_per_task="$(($((num_cases - 1)) / jobarray_size + 1))" if [[ "$jobarray_size" -gt 1 ]] then if [[ -n "$jobarray" ]] @@ -544,6 +548,7 @@ then else jobarray_option=("--array=1-$jobarray_size") fi + fastsurfer_options=("${fastsurfer_options[@]}" --batch "slurm_task_id/$jobarray_size") jobarray_depend="aftercorr" else jobarray_option=() @@ -567,58 +572,61 @@ then seg_cmd_file=$(mktemp) fi # END OF NEW - slurm_partition=$(first_non_empty_partition "$partition_seg" "$partition") + slurm_part_=$(first_non_empty_partition "$partition_seg" "$partition") + if [[ -z "$slurm_part_" ]] ; then slurm_partition=() ; else slurm_partition=("$slurm_part_") ; fi { echo "#!/bin/bash" echo "module load singularity" - echo "srun --ntasks=1 --nodes=1 --cpus-per-task=$num_cpus_per_task \\" - echo " singularity exec --nv -B \"$hpc_work:/data,$in_dir:/source:ro\" --no-home \\" + echo "singularity exec --nv -B \"$hpc_work:/data,$in_dir:/source:ro\" --no-home --env TQDM_DISABLE=1 \\" if [[ -n "$extra_singularity_options" ]] || [[ -n "$extra_singularity_options_seg" ]]; then echo " $extra_singularity_options $extra_singularity_options_seg\\" fi echo " $hpc_work/images/fastsurfer.sif \\" echo " /data/$brun_fastsurfer ${fastsurfer_options[*]} ${fastsurfer_seg_options[*]}" - } > $seg_cmd_file + echo "# discard the exit code of run_fastsurfer (exit with success), so following" + echo "# jobarray items will be started by slurm under the aftercorr dependency" + echo "# see https://github.com/Deep-MI/FastSurfer/pull/434#issuecomment-1910805112" + echo "exit 0" + } > "$seg_cmd_file" if [[ "$cpu_only" == "true" ]]; then mem_seg="$mem_seg_cpu" else mem_seg="$mem_seg_gpu" fi - # note that there can be a decent startup cost for each run, running multiple cases per task significantly reduces this + # note that there can be a decent startup cost for each run, running multiple cases + # per task significantly reduces this seg_slurm_sched=("--mem=${mem_seg}G" "--cpus-per-task=$num_cpus_per_task" - --time=$(($timelimit_seg * $real_num_cases_per_task + 5)) - $slurm_partition "${slurm_email[@]}" + --time=$((timelimit_seg * real_num_cases_per_task + 5)) + "${slurm_partition[@]}" "${slurm_email[@]}" "${jobarray_option[@]}" -J "FastSurfer-Seg-$USER" -o "$hpc_work/logs/seg_%A_%a.log" "$seg_cmd_filename") if [[ "$cpu_only" == "true" ]] then - if [[ "$debug" == "true" ]] - then - echo "Schedule SLURM job without gpu" - fi + debug "Schedule SLURM job without gpu" else seg_slurm_sched=(--gpus=1 "${seg_slurm_sched[@]}") fi - echo "chmod +x $seg_cmd_filename" - chmod +x $seg_cmd_file - echo "sbatch --parsable ${seg_slurm_sched[*]}" + log "chmod +x $seg_cmd_filename" + chmod +x "$seg_cmd_file" + log "sbatch 
--parsable ${seg_slurm_sched[*]}" echo "--- sbatch script $seg_cmd_filename ---" - cat $seg_cmd_file + cat "$seg_cmd_file" echo "--- end of script ---" if [[ "$submit_jobs" == "true" ]] then - seg_jobid=$(sbatch --parsable ${seg_slurm_sched[*]}) - echo "Submitted Segmentation Jobs $seg_jobid" + seg_jobid=$(sbatch --parsable "${seg_slurm_sched[@]}") + log "Submitted Segmentation Jobs $seg_jobid" else - echo "Not submitting the Segmentation Jobs to slurm (--dry)." + log "Not submitting the Segmentation Jobs to slurm (--dry)." seg_jobid=SEG_JOB_ID fi + log_name="${log_name}_${seg_jobid}" cleanup_depend="afterany:$seg_jobid" surf_depend="--depend=$jobarray_depend:$seg_jobid" elif [[ "$surf_only" == "true" ]] then # do not run segmentation, but copy over all cases from data to work copy_jobid= - make_copy_job "$hpc_work" "$out_dir" "$hpc_work/scripts/subject_list" "$submit_jobs" + make_copy_job "$hpc_work" "$out_dir" "$hpc_work/scripts/subject_list" "$LF" "$submit_jobs" if [[ -n "$copy_jobid" ]] then surf_depend="--depend=afterok:$copy_jobid" @@ -626,6 +634,7 @@ then echo "ERROR: \$copy_jobid not defined!" exit 1 fi + log_name="${log_name}_${copy_jobid}" fi if [[ "$seg_only" != "true" ]] @@ -660,7 +669,8 @@ then if [[ "$mem_surf" -gt "$((mem_per_core * cores_per_task))" ]]; then mem_per_core=$((mem_per_core+1)) fi - slurm_partition=$(first_non_empty_partition "$partition_surf" "$partition") + slurm_part_=$(first_non_empty_partition "$partition_surf" "$partition") + if [[ -z "$slurm_part_" ]] ; then slurm_partition=() ; else slurm_partition=("$slurm_part_") ; fi { echo "#!/bin/bash" echo "module load singularity" @@ -676,34 +686,45 @@ then echo " /fastsurfer/run_fastsurfer.sh)" echo "$hpc_work/$brun_fastsurfer --run_fastsurfer \"\${run_fastsurfer[*]}\" \\" echo " ${fastsurfer_options[*]} ${fastsurfer_surf_options[*]}" - } > $surf_cmd_file + } > "$surf_cmd_file" surf_slurm_sched=("--mem-per-cpu=${mem_per_core}G" "--cpus-per-task=$cores_per_task" "--ntasks=$real_num_cases_per_task" "--nodes=1-$real_num_cases_per_task" "--hint=nomultithread" "${jobarray_option[@]}" "$surf_depend" -J "FastSurfer-Surf-$USER" -o "$hpc_work/logs/surf_%A_%a.log" - $slurm_partition "${slurm_email[@]}" "$surf_cmd_filename") - chmod +x $surf_cmd_file - echo "sbatch --parsable ${surf_slurm_sched[*]}" + "${slurm_partition[@]}" "${slurm_email[@]}" "$surf_cmd_filename") + chmod +x "$surf_cmd_file" + log "sbatch --parsable ${surf_slurm_sched[*]}" echo "--- sbatch script $surf_cmd_filename ---" - cat $surf_cmd_file + cat "$surf_cmd_file" echo "--- end of script ---" if [[ "$submit_jobs" == "true" ]] then - surf_jobid=$(sbatch --parsable ${surf_slurm_sched[*]}) - echo "Submitted Surface Jobs $surf_jobid" + surf_jobid=$(sbatch --parsable "${surf_slurm_sched[@]}") + log "Submitted Surface Jobs $surf_jobid" else - echo "Not submitting the Surface Jobs to slurm (--dry)." + log "Not submitting the Surface Jobs to slurm (--dry)." surf_jobid=SURF_JOB_ID fi + log_name="${log_name}_${surf_jobid}" cleanup_depend="afterany:$surf_jobid" fi # step four: copy results back and clean the output directory if [[ "$do_cleanup" == "true" ]] then - make_cleanup_job "$hpc_work" "$out_dir" "$cleanup_depend" "$delete_hpc_work" "$submit_jobs" + # delete_hpc_work is true, if the hpc_work directory was created by this script. + make_cleanup_job "$hpc_work" "$out_dir" "$cleanup_depend" "$cleanup_mode" "$LF" "$delete_hpc_work" "$submit_jobs" else - echo "Skipping the cleanup (no cleanup job scheduled, find your results in $hpc_work." 
-fi \ No newline at end of file + log "Skipping the cleanup (no cleanup job scheduled, find your results in $hpc_work." +fi + + +if [[ "$submit_jobs" == "true" ]] +then + log_dir="$out_dir/slurm/logs" + mkdir -p "$log_dir" + cp "$tmpLF" "$log_dir/$log_name.log" +fi +rm "$tmpLF" diff --git a/stools.sh b/stools.sh index 5c4f6251..59d335ba 100755 --- a/stools.sh +++ b/stools.sh @@ -42,7 +42,8 @@ function translate_cases () #param2 subject_list file #param3 target_dir #param4 optional, delimiter - #param5 optional, awk snippets to modify the subject_id and the subject_path (split by :), default '$1:$2' + #param5 optional, awk snippets to modify the subject_id, default '$1' + #param6 optional, awk snippets to modify the image_path, default '$2' if [[ "$#" -gt 3 ]] then delimiter=$4 @@ -51,16 +52,31 @@ function translate_cases () fi if [[ "$#" -gt 4 ]] then - subid_awk="$(echo "$5" | cut -f1 -d:)" - subpath_awk="$(echo "$5" | cut -f2 -d:)" + subid_awk="$5" else subid_awk='$1' + fi + if [[ "$#" -gt 5 ]] + then + subpath_awk="$6" + else subpath_awk='$2' fi - init='BEGIN { regex="^(" source_dir "|" target_dir ")"; }' - script=$(printf "%s length(\$NF) > 1 { subid=%s; subpath=%s; gsub(regex, \"\", subpath); print subid \"=\" target_dir \"/\" subpath; }" \ - "$init" "$subid_awk" "$subpath_awk") + script=" + BEGIN { + regex=\"^(\" source_dir \"|\" target_dir \")\"; + regex2=\",(\" source_dir \"|\" target_dir \")/*\"; + } + length(\$NF) > 1 { + subid=${subid_awk}; + subpath=${subpath_awk}; + gsub(regex, \"\", subpath); + gsub(regex2, \",\" target_dir \"/\", subpath); + print subid \"=\" target_dir \"/\" subpath; + }" #>&2 echo "awk -F \"$delimiter\" -v target_dir=\"$3\" -v source_dir=\"$1\" \"$script\" \"$2\"" + #>&2 cat "$2" + #>&2 awk -F "$delimiter" -v target_dir="$3" -v source_dir="$1" "$script" "$2" awk -F "$delimiter" -v target_dir="$3" -v source_dir="$1" "$script" "$2" } @@ -84,6 +100,7 @@ function check_out_dir () #param2 true/false, optional check empty, default false if [[ -z "$1" ]]; then echo "The subject directory (output directory) is not defined." + exit 1 elif [[ ! -d "$1" ]]; then echo "The subject directory $1 (output directory) does not exists." read -r -p "Create the directory? [y/N]" -n 1 retval @@ -108,6 +125,39 @@ function check_fs_license () exit 1 fi } +function check_cases_in_out_dir () +{ + #param1 out_dir + #param2 cases + #param3 optional: true/false jobarray defined (default: false) + if [[ "$#" -gt 2 ]] && [[ "$3" == "true" ]] + then + jobarray_defined="true" + else + jobarray_defined="false" + fi + case_already_exists="" + for subject in $2 + do + subject_id=$(echo "$subject" | cut -d= -f1) + if [[ -e "$1/$subject_id" ]] + then + case_already_exists="$case_already_exists, $subject_id" + fi + done + if [[ "$case_already_exists" != "" ]] + then + echo "Some cases already exist in $1 (${case_already_exists:2})" + if [[ "$jobarray_defined" == "true" ]] + then + echo "This list does not filter for the --slurm_jobarray argument!" + fi + read -r -p "Continue AND OVERWRITE those results? 
[y/N]" -n 1 retval + echo "" + if [[ "$retval" == "y" ]] || [[ "$retval" == "Y" ]] ; then export cleanup_mode="cp"; + else exit 1; fi + fi +} function check_seg_surf_only () { #param1 seg_only @@ -126,7 +176,16 @@ function check_subject_images () for subject in $1 do subject_id=$(echo "$subject" | cut -d= -f1) - image_path=$(echo "$subject" | cut -d= -f2) + image_parameters=$(echo "$subject" | cut -d= -f2) + i=0 + OLD_IFS=$IFS + IFS="," + for arg in $image_parameters + do + if [[ "$i" == 0 ]]; then image_path="$arg"; fi + i=$((i + 1)) + done + IFS=$OLD_IFS #TODO: also check here, if any of the folders up to the mounted dir leading to the file are symlinks #TODO: if so, this will lead to problems if [[ ! -e "$image_path" ]] @@ -164,21 +223,20 @@ function make_cleanup_job () # param1: hpc_work directory # param2: output directory # param3: dependency tag - # param4: optional: true/false (delete hpc_work directory, default=false) - # param5: optional: true/false (submit jobs, default=true) - - # param4: optional: true/false (delete hpc_work, default=false) - # param5: optional: true/false (submit jobs, default=true) + # param4: mode: the mode in which to copy (mv/cp) + # param5: logfile the log file + # param6: optional: true/false (delete hpc_work directory, default=false) + # param7: optional: true/false (submit jobs, default=true) local clean_cmd_file local submit_jobs local delete_hpc_work_dir - if [[ "$#" -gt 3 ]] && [[ "$4" == "true" ]] + if [[ "$#" -gt 5 ]] && [[ "$6" == "true" ]] then delete_hpc_work_dir="true" else delete_hpc_work_dir="false" fi - if [[ "$#" -gt 4 ]] && [[ "$5" == "false" ]] + if [[ "$#" -gt 6 ]] && [[ "$7" == "false" ]] then submit_jobs="false" clean_cmd_file=$(mktemp) @@ -193,6 +251,13 @@ function make_cleanup_job () local clean_slurm_sched=(-d "$3" -J "FastSurfer-Cleanup-$USER" --ntasks=1 --cpus-per-task=4 -o "$out_dir/slurm/logs/cleanup_%A.log" "$clean_cmd_filename") + local mode=$4 + if [[ "$mode" != "mv" ]] && [[ "$mode" != "cp" ]] + then + >&2 echo "invalid mode $mode" + exit 1 + fi + local logfile=$5 mkdir -p "$out_dir/slurm/logs" { @@ -215,7 +280,10 @@ function make_cleanup_job () echo "then" echo " for s in $hpc_work/cases/*" echo " do" - echo " mv -f -t \"$out_dir\" \$s &" + if [[ "$mode" == "mv" ]]; then echo " mv -f -t \"$out_dir\" \$s &" + elif [[ "$mode" == "cp" ]]; then echo " cp -r -t \"$out_dir\" \$s && rm -R \$s &" + else >&2 echo "invalid mode $mode"; exit 1; + fi echo " pids=(\${pids[@]} \$!)" echo " done" echo "fi" @@ -233,7 +301,10 @@ function make_cleanup_job () then echo " rm -R $hpc_work" else - echo " rm -R $hpc_work/*" + echo " rm -R $hpc_work/images" + echo " rm $hpc_work/scripts" + echo " rm $hpc_work/cases" + echo " rm $hpc_work/logs" fi echo "else" echo " echo \"Cleanup finished with errors!\"" @@ -241,18 +312,18 @@ function make_cleanup_job () } > $clean_cmd_file chmod +x $clean_cmd_file - echo "sbatch --parsable ${clean_slurm_sched[*]}" + echo "sbatch --parsable ${clean_slurm_sched[*]}" | tee -a $logfile echo "--- sbatch script $clean_cmd_filename ---" cat $clean_cmd_file echo "--- end of script ---" if [[ "$submit_jobs" == "false" ]] then - echo "Not submitting the Cleanup Jobs to slurm (--dry)." - clean_jobid=CLEAN_JOB_ID + echo "Not submitting the Cleanup Jobs to slurm (--dry)." 
| tee -a $logfile + export clean_jobid=CLEAN_JOB_ID else - clean_jobid=$(sbatch --parsable ${clean_slurm_sched[*]}) - echo "Submitted Cleanup Jobs $clean_jobid" + export clean_jobid=$(sbatch --parsable ${clean_slurm_sched[*]}) + echo "Submitted Cleanup Jobs $clean_jobid" | tee -a $logfile fi } @@ -262,10 +333,11 @@ function make_copy_job () # param1: hpc_work directory # param2: output directory # param3: subject_list file - # param4: optional: true/false (submit jobs, default=true) + # param4: logfile log file + # param5: optional: true/false (submit jobs, default=true) local copy_cmd_file - if [[ "$#" -gt 3 ]] && [[ "$4" == "false" ]] + if [[ "$#" -gt 4 ]] && [[ "$5" == "false" ]] then copy_cmd_file=$(mktemp) else @@ -276,6 +348,7 @@ function make_copy_job () local hpc_work=$1 local out_dir=$2 local subject_list=$3 + local logfile=$4 local copy_slurm_sched=(-J "FastSurfer-Copyseg-$USER" --ntasks=1 --cpus-per-task=4 -o "$out_dir/slurm/logs/copy_%A.log" "$copy_cmd_filename") @@ -300,18 +373,18 @@ function make_copy_job () } > $copy_cmd_file chmod +x $copy_cmd_file - echo "sbatch --parsable ${copy_slurm_sched[*]}" + echo "sbatch --parsable ${copy_slurm_sched[*]}" | tee -a "$logfile" echo "--- sbatch script $copy_cmd_filename ---" cat $copy_cmd_file echo "--- end of script ---" if [[ "$#" -gt 3 ]] && [[ "$4" == "false" ]] then - echo "Not submitting the Copyseg Jobs to slurm (--dry)." - copy_jobid=COPY_JOB_ID + echo "Not submitting the Copyseg Job to slurm (--dry)." | tee -a "$logfile" + export copy_jobid=COPY_JOB_ID else - copy_jobid=$(sbatch --parsable ${copy_slurm_sched[*]}) - echo "Submitted Copyseg Jobs $copy_jobid" + export copy_jobid=$(sbatch --parsable ${copy_slurm_sched[*]}) + echo "Submitted Copyseg Job $copy_jobid" | tee -a "$logfile" fi } diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..29780dc9 --- /dev/null +++ b/test/__init__.py @@ -0,0 +1,8 @@ + + + +__all__ = [ # This is a list of modules that should be imported when using the import * syntax + 'test_file_existence', + 'test_error_messages', + 'test_errors' + ] diff --git a/test/quick_test/__init__.py b/test/quick_test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/quick_test/data/errors.yaml b/test/quick_test/data/errors.yaml new file mode 100644 index 00000000..a54810ae --- /dev/null +++ b/test/quick_test/data/errors.yaml @@ -0,0 +1,11 @@ +errors: + - "error" + - "error:" + - "exception" + - "traceback" + +whitelist: + - "without error" + - "not included" + - "distance" + - "correcting" diff --git a/test/quick_test/data/files.yaml b/test/quick_test/data/files.yaml new file mode 100644 index 00000000..c06e6d17 --- /dev/null +++ b/test/quick_test/data/files.yaml @@ -0,0 +1,267 @@ +files: + - "scripts/deep-seg.log" + - "scripts/build.log" + - "scripts/recon-surf.log" + - "scripts/build-stamp.txt" + - "scripts/lastcall.build-stamp.txt" + - "scripts/recon-all.cmd" + - "scripts/recon-all.log" + - "scripts/recon-all-status.log" + - "scripts/recon-all.local-copy" + - "scripts/ponscc.cut.log" + - "scripts/lh.processing.cmdf" + - "scripts/rh.processing.cmdf" + - "scripts/recon-all.env.bak" + - "scripts/defect2seg.log" + - "scripts/recon-surf_times.yaml" + - "scripts/pctsurfcon.log.old" + - "scripts/pctsurfcon.log" + - "scripts/patchdir.txt" + - "scripts/recon-config.yaml" + - "scripts/unknown-args.txt" + - "scripts/recon-all.env" + - "scripts/recon-all.done" + - "scripts/recon-surf.done" + - "mri/orig.mgz" + - "mri/aparc.DKTatlas+aseg.deep.mgz" + - "mri/mask.mgz" + - 
"mri/aseg.auto_noCCseg.mgz" + - "mri/orig_nu.mgz" + - "mri/cerebellum.CerebNet.nii.gz" + - "mri/aparc.DKTatlas+aseg.orig.mgz" + - "mri/rawavg.mgz" + - "mri/nu.mgz" + - "mri/norm.mgz" + - "mri/T1.mgz" + - "mri/brainmask.mgz" + - "mri/aseg.auto.mgz" + - "mri/aparc.DKTatlas+aseg.deep.withCC.mgz" + - "mri/aseg.presurf.mgz" + - "mri/brain.mgz" + - "mri/brain.finalsurfs.mgz" + - "mri/brain.finalsurfs.manedit.mgz" + - "mri/antsdn.brain.mgz" + - "mri/segment.dat" + - "mri/wm.seg.mgz" + - "mri/wm.asegedit.mgz" + - "mri/wm.mgz" + - "mri/filled.mgz" + - "mri/filled.auto.mgz" + - "mri/filled-pretess127.mgz" + - "mri/filled-pretess255.mgz" + - "mri/rh.surface.defects.mgz" + - "mri/lh.surface.defects.mgz" + - "mri/ribbon.mgz" + - "mri/lh.ribbon.mgz" + - "mri/rh.ribbon.mgz" + - "mri/aseg.presurf.hypos.mgz" + - "mri/aseg.mgz" + - "mri/aparc.DKTatlas+aseg.mapped.mgz" + - "mri/wmparc.DKTatlas.mapped.mgz" + - "mri/aparc.DKTatlas+aseg.mgz" + - "mri/aparc+aseg.mgz" + - "mri/wmparc.mgz" + - "mri/orig/001.mgz" + - "mri/transforms/talairach_avi.log" + - "mri/transforms/talairach.auto.xfm" + - "mri/transforms/talairach.auto.xfm.lta" + - "mri/transforms/talairach.xfm" + - "mri/transforms/talairach.xfm.lta" + - "mri/transforms/talairach_with_skull.lta" + - "mri/transforms/talairach.lta" + - "mri/transforms/cc_up.lta" + - "stats/aseg+DKT.stats" + - "stats/cerebellum.CerebNet.stats" + - "stats/lh.curv.stats" + - "stats/rh.curv.stats" + - "stats/lh.aparc.DKTatlas.mapped.stats" + - "stats/brainvol.stats" + - "stats/rh.aparc.DKTatlas.mapped.stats" + - "stats/lh.w-g.pct.stats" + - "stats/rh.w-g.pct.stats" + - "stats/aseg.stats" + - "stats/aseg.presurf.hypos.stats" + - "stats/wmparc.DKTatlas.mapped.stats" + - "stats/lh.BA_exvivo.stats" + - "stats/lh.BA_exvivo.thresh.stats" + - "stats/rh.BA_exvivo.stats" + - "stats/rh.BA_exvivo.thresh.stats" + - "surf/rh.orig.nofix" + - "surf/lh.orig.nofix" + - "surf/rh.smoothwm.nofix" + - "surf/lh.smoothwm.nofix" + - "surf/rh.inflated.nofix" + - "surf/lh.inflated.nofix" + - "surf/lh.qsphere.nofix" + - "surf/rh.qsphere.nofix" + - "surf/lh.defect_labels" + - "surf/lh.defect_borders" + - "surf/lh.defect_chull" + - "surf/rh.defect_labels" + - "surf/rh.defect_borders" + - "surf/rh.defect_chull" + - "surf/rh.orig.premesh" + - "surf/rh.defects.pointset" + - "surf/lh.orig.premesh" + - "surf/lh.defects.pointset" + - "surf/rh.orig" + - "surf/autodet.gw.stats.rh.dat" + - "surf/lh.orig" + - "surf/autodet.gw.stats.lh.dat" + - "surf/rh.white.preaparc" + - "surf/lh.white.preaparc" + - "surf/rh.smoothwm" + - "surf/rh.inflated" + - "surf/rh.sulc" + - "surf/rh.white.preaparc.K" + - "surf/rh.white.preaparc.H" + - "surf/rh.white.H" + - "surf/rh.white.K" + - "surf/lh.smoothwm" + - "surf/lh.inflated" + - "surf/lh.sulc" + - "surf/lh.white.preaparc.K" + - "surf/lh.white.preaparc.H" + - "surf/lh.white.H" + - "surf/lh.white.K" + - "surf/rh.inflated.K" + - "surf/rh.inflated.H" + - "surf/lh.inflated.K" + - "surf/lh.inflated.H" + - "surf/lh.sphere" + - "surf/lh.angles.txt" + - "surf/rh.sphere" + - "surf/rh.angles.txt" + - "surf/lh.sphere.reg" + - "surf/rh.sphere.reg" + - "surf/lh.white" + - "surf/rh.white" + - "surf/lh.pial.T1" + - "surf/lh.pial" + - "surf/lh.jacobian_white" + - "surf/lh.curv" + - "surf/lh.area" + - "surf/lh.curv.pial" + - "surf/lh.area.pial" + - "surf/rh.pial.T1" + - "surf/rh.pial" + - "surf/rh.jacobian_white" + - "surf/rh.curv" + - "surf/rh.area" + - "surf/rh.curv.pial" + - "surf/rh.area.pial" + - "surf/lh.thickness" + - "surf/lh.area.mid" + - "surf/lh.volume" + - "surf/lh.smoothwm.K.crv" + - 
"surf/lh.smoothwm.H.crv" + - "surf/lh.smoothwm.K1.crv" + - "surf/lh.smoothwm.K2.crv" + - "surf/lh.smoothwm.S.crv" + - "surf/lh.smoothwm.C.crv" + - "surf/lh.smoothwm.BE.crv" + - "surf/lh.smoothwm.FI.crv" + - "surf/rh.thickness" + - "surf/rh.area.mid" + - "surf/rh.volume" + - "surf/rh.smoothwm.K.crv" + - "surf/rh.smoothwm.H.crv" + - "surf/rh.smoothwm.K1.crv" + - "surf/rh.smoothwm.K2.crv" + - "surf/rh.smoothwm.S.crv" + - "surf/rh.smoothwm.C.crv" + - "surf/rh.smoothwm.BE.crv" + - "surf/rh.smoothwm.FI.crv" + - "surf/lh.w-g.pct.mgh" + - "surf/rh.w-g.pct.mgh" + - "label/rh.nofix.cortex.label" + - "label/lh.nofix.cortex.label" + - "label/rh.cortex.label" + - "label/rh.cortex+hipamyg.label" + - "label/lh.cortex.label" + - "label/lh.cortex+hipamyg.label" + - "label/rh.aparc.DKTatlas.mapped.annot" + - "label/lh.aparc.DKTatlas.mapped.annot" + - "label/aparc.annot.mapped.ctab" + - "label/lh.aparc.DKTatlas.annot" + - "label/rh.aparc.DKTatlas.annot" + - "label/lh.BA1_exvivo.label" + - "label/lh.BA2_exvivo.label" + - "label/lh.BA3a_exvivo.label" + - "label/lh.BA3b_exvivo.label" + - "label/lh.BA4a_exvivo.label" + - "label/lh.BA4p_exvivo.label" + - "label/lh.BA6_exvivo.label" + - "label/lh.BA44_exvivo.label" + - "label/lh.BA45_exvivo.label" + - "label/lh.V1_exvivo.label" + - "label/lh.V2_exvivo.label" + - "label/lh.MT_exvivo.label" + - "label/lh.perirhinal_exvivo.label" + - "label/lh.entorhinal_exvivo.label" + - "label/lh.BA1_exvivo.thresh.label" + - "label/lh.BA2_exvivo.thresh.label" + - "label/lh.BA3a_exvivo.thresh.label" + - "label/lh.BA3b_exvivo.thresh.label" + - "label/lh.BA4a_exvivo.thresh.label" + - "label/lh.BA4p_exvivo.thresh.label" + - "label/lh.BA6_exvivo.thresh.label" + - "label/lh.BA44_exvivo.thresh.label" + - "label/lh.BA45_exvivo.thresh.label" + - "label/lh.V1_exvivo.thresh.label" + - "label/lh.V2_exvivo.thresh.label" + - "label/lh.MT_exvivo.thresh.label" + - "label/lh.perirhinal_exvivo.thresh.label" + - "label/lh.entorhinal_exvivo.thresh.label" + - "label/lh.FG1.mpm.vpnl.label" + - "label/lh.FG2.mpm.vpnl.label" + - "label/lh.FG3.mpm.vpnl.label" + - "label/lh.FG4.mpm.vpnl.label" + - "label/lh.hOc1.mpm.vpnl.label" + - "label/lh.hOc2.mpm.vpnl.label" + - "label/lh.hOc3v.mpm.vpnl.label" + - "label/lh.hOc4v.mpm.vpnl.label" + - "label/lh.BA_exvivo.annot" + - "label/BA_exvivo.ctab" + - "label/lh.BA_exvivo.thresh.annot" + - "label/BA_exvivo.thresh.ctab" + - "label/lh.mpm.vpnl.annot" + - "label/rh.BA1_exvivo.label" + - "label/rh.BA2_exvivo.label" + - "label/rh.BA3a_exvivo.label" + - "label/rh.BA3b_exvivo.label" + - "label/rh.BA4a_exvivo.label" + - "label/rh.BA4p_exvivo.label" + - "label/rh.BA6_exvivo.label" + - "label/rh.BA44_exvivo.label" + - "label/rh.BA45_exvivo.label" + - "label/rh.V1_exvivo.label" + - "label/rh.V2_exvivo.label" + - "label/rh.MT_exvivo.label" + - "label/rh.perirhinal_exvivo.label" + - "label/rh.entorhinal_exvivo.label" + - "label/rh.BA1_exvivo.thresh.label" + - "label/rh.BA2_exvivo.thresh.label" + - "label/rh.BA3a_exvivo.thresh.label" + - "label/rh.BA3b_exvivo.thresh.label" + - "label/rh.BA4a_exvivo.thresh.label" + - "label/rh.BA4p_exvivo.thresh.label" + - "label/rh.BA6_exvivo.thresh.label" + - "label/rh.BA44_exvivo.thresh.label" + - "label/rh.BA45_exvivo.thresh.label" + - "label/rh.V1_exvivo.thresh.label" + - "label/rh.V2_exvivo.thresh.label" + - "label/rh.MT_exvivo.thresh.label" + - "label/rh.perirhinal_exvivo.thresh.label" + - "label/rh.entorhinal_exvivo.thresh.label" + - "label/rh.FG1.mpm.vpnl.label" + - "label/rh.FG2.mpm.vpnl.label" + - "label/rh.FG3.mpm.vpnl.label" + - 
"label/rh.FG4.mpm.vpnl.label" + - "label/rh.hOc1.mpm.vpnl.label" + - "label/rh.hOc2.mpm.vpnl.label" + - "label/rh.hOc3v.mpm.vpnl.label" + - "label/rh.hOc4v.mpm.vpnl.label" + - "label/rh.BA_exvivo.annot" + - "label/rh.BA_exvivo.thresh.annot" + - "label/rh.mpm.vpnl.annot" diff --git a/test/quick_test/test_errors.py b/test/quick_test/test_errors.py new file mode 100644 index 00000000..01611b5e --- /dev/null +++ b/test/quick_test/test_errors.py @@ -0,0 +1,96 @@ +import sys +import yaml +import unittest +import argparse +from pathlib import Path + + +class TestErrors(unittest.TestCase): + """ + A test case class to check for the word "error" in the given log files. + """ + + error_file_path: Path = Path("./test/quick_test/data/errors.yaml") + + error_flag = False + + @classmethod + def setUpClass(cls): + """ + Set up the test class. + This method retrieves the log directory from the command line argument, + and assigns it to a class variable. + """ + + # Open the error_file_path and read the errors and whitelist into arrays + with open(cls.error_file_path, 'r') as file: + data = yaml.safe_load(file) + cls.errors = data.get('errors', []) + cls.whitelist = data.get('whitelist', []) + + # Retrieve the log files in given log directory + try: + # cls.log_directory = Path(cls.log_directory) + print(cls.log_directory) + cls.log_files = [file for file in cls.log_directory.iterdir() if file.suffix == '.log'] + except FileNotFoundError: + raise FileNotFoundError(f"Log directory not found at path: {cls.log_directory}") + + def test_find_errors_in_logs(self): + """ + Test that the words "error", "exception", and "traceback" are not in the log files. + + This method retrieves the log files in the log directory, reads each log file line by line, + and checks that none of the keywords are in any line. + """ + + files_with_errors = {} + + # Check if any of the keywords are in the log files + for log_file in self.log_files: + rel_path = log_file.relative_to(self.log_directory) + print(f"Checking file: {rel_path}") + try: + with log_file.open('r') as file: + lines = file.readlines() + lines_with_errors = [] + for line_number, line in enumerate(lines, start=1): + if any(error in line.lower() for error in self.errors): + if not any(white in line.lower() for white in self.whitelist): + # Get two lines before and after the current line + context = lines[max(0, line_number-2):min(len(lines), line_number+3)] + lines_with_errors.append((line_number, context)) + print(lines_with_errors) + files_with_errors[rel_path] = lines_with_errors + self.error_flag = True + except FileNotFoundError: + raise FileNotFoundError(f"Log file not found at path: {log_file}") + continue + + # Print the lines and context with errors for each file + for file, lines in files_with_errors.items(): + print(f"\nFile {file}, in line {files_with_errors[file][0][0]}:") + for line_number, line in lines: + print(*line, sep = "") + + # Assert that there are no lines with any of the keywords + self.assertEqual(self.error_flag, False, f"Found errors in the following files: {files_with_errors}") + print("No errors found in any log files.") + + +if __name__ == '__main__': + """ + Main entry point of the script. 
+
+
+    This block checks if there are any command line arguments,
+    assigns the first argument to the log_directory class variable, and runs the unittest main function.
+    """
+
+    parser = argparse.ArgumentParser(description="Test for errors in log files.")
+    parser.add_argument('log_directory', type=Path, help="The directory containing the log files.")
+
+    args = parser.parse_args()
+
+    TestErrors.log_directory = args.log_directory
+
+    unittest.main(argv=[sys.argv[0]])
diff --git a/test/quick_test/test_file_existence.py b/test/quick_test/test_file_existence.py
new file mode 100644
index 00000000..77abcb09
--- /dev/null
+++ b/test/quick_test/test_file_existence.py
@@ -0,0 +1,71 @@
+import sys
+import yaml
+import unittest
+import argparse
+from pathlib import Path
+
+
+class TestFileExistence(unittest.TestCase):
+    """
+    A test case class to check the existence of files in a folder based on a YAML file.
+
+    This class defines test methods to verify if each file specified in the YAML file exists in the given folder.
+    """
+
+    file_path: Path = Path("./test/quick_test/data/files.yaml")
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Set up the test case by loading the YAML file and extracting the folder path.
+
+        This method is executed once before any test methods in the class.
+        """
+
+        # Open the file_path and read the files into an array
+        with cls.file_path.open('r') as file:
+            data = yaml.safe_load(file)
+            cls.files = data.get('files', [])
+
+        # Get a list of all files in the folder recursively
+        cls.filenames = []
+        for file in cls.folder_path.glob('**/*'):
+            if file.is_file():
+                # Get the relative path from the current directory to the file
+                rel_path = file.relative_to(cls.folder_path)
+                cls.filenames.append(str(rel_path))
+
+    def test_file_existence(self):
+        """
+        Test method to check the existence of files in the folder.
+
+        This method checks that each file specified in the YAML file exists in the folder, using the file list gathered in setUpClass.
+        """
+
+        # Check if each file in the YAML file exists in the folder
+        if not self.files:
+            self.fail("The 'files' key was not found in the YAML file")
+
+        for file in self.files:
+            print(f"Checking for file: {file}")
+            self.assertIn(file, self.filenames, f"File '{file}' does not exist in the folder.")
+
+        print("All files present")
+
+
+if __name__ == '__main__':
+    """
+    Main entry point of the script.
+
+    This block checks if there are any command line arguments, assigns the first argument to the folder_path class variable,
+    and runs the unittest main function.
+    """
+
+    parser = argparse.ArgumentParser(description="Test for file existence based on a YAML file.")
+    parser.add_argument('folder_path', type=Path, help="The path to the folder to check.")
+
+    args = parser.parse_args()
+
+    TestFileExistence.folder_path = args.folder_path
+
+    unittest.main(argv=[sys.argv[0]])
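
Usage sketch for the two quick tests added above (illustrative only, not part of the patch): both scripts take a single positional path argument, as defined by their argparse parsers, and should be run from the repository root because files.yaml and errors.yaml are loaded via relative paths. The subject directory below is hypothetical; per files.yaml, the recon log files sit in the scripts/ subfolder of a case.

```bash
# Hypothetical FastSurfer case directory; adjust to your own output.
SUBJECT_DIR=/data/fastsurfer_output/subjectX

# Check that every file listed in test/quick_test/data/files.yaml exists in the case folder.
python test/quick_test/test_file_existence.py "$SUBJECT_DIR"

# Scan all *.log files in the given directory for the keywords from test/quick_test/data/errors.yaml.
python test/quick_test/test_errors.py "$SUBJECT_DIR/scripts"
```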
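For a quick manual look outside the unittest harness, the keyword/whitelist idea encoded in errors.yaml can be approximated with grep. This is only a rough shell equivalent of what test_errors.py does; the Python test additionally prints two lines of context around each hit.

```bash
# Keywords and whitelist terms taken from test/quick_test/data/errors.yaml.
log="$SUBJECT_DIR/scripts/recon-surf.log"   # hypothetical log file
grep -inE "error|exception|traceback" "$log" \
  | grep -ivE "without error|not included|distance|correcting" \
  || echo "no error keywords found in $log"
```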
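translate_cases in stools.sh rewrites each subject_list line into a `subject_id=<path under target_dir>` pair, using the awk snippets passed as param5 (subject id) and param6 (image path). A minimal sketch of calling it directly, assuming stools.sh can be sourced on its own; the CSV content and all paths are invented:

```bash
source ./stools.sh

# Invented two-column subject list: subject id, then image path relative to the source dir.
printf 'subj01,subj01/T1.nii.gz\nsubj02,subj02/T1.nii.gz\n' > /tmp/subjects.csv

# Arguments: source dir, subject_list file, target dir, delimiter, sid awk code, image-path awk code.
# Prints lines such as: subj01=/source/subj01/T1.nii.gz
translate_cases /data/study /tmp/subjects.csv /source "," '$1' '$2'
```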