Skip to content

Commit

Permalink
✨Fully support decathlon datalist (#63)
Browse files Browse the repository at this point in the history
* refactor dataset, add decathlon support
* fix wrong import of lightning, test_kfold_crossval still fails
* fixes needed for metatensor changes in monai
* remove python38 support
* increase version number

---------

Co-authored-by: Bryn Lloyd <lloyd@itis.swiss>
  • Loading branch information
dyollb and dyollb authored Mar 14, 2024
1 parent 9e176d7 commit 288663f
Show file tree
Hide file tree
Showing 32 changed files with 785 additions and 361 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/deploy_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ jobs:
os: windows-latest,
python-version: "3.9"
}
- {
name: "ubuntu-latest - Python 3.8",
os: ubuntu-latest,
python-version: "3.8"
}
- {
name: "ubuntu-latest - Python 3.9",
os: ubuntu-latest,
Expand Down
25 changes: 15 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand All @@ -26,34 +26,34 @@ repos:
- id: requirements-txt-fixer

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.277
rev: v0.3.2
hooks:
- id: ruff
name: ruff
args: [ --fix, --exit-non-zero-on-fix ]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 24.2.0
hooks:
- id: black
name: format code

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: format imports
args: ["--profile", "black"]

- repo: https://github.com/asottile/pyupgrade
rev: v3.8.0
rev: v3.15.1
hooks:
- id: pyupgrade
name: upgrade code
args: ["--py38-plus"]
args: ["--py39-plus"]

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.17
hooks:
- id: mdformat
name: format markdown
Expand All @@ -63,25 +63,30 @@ repos:
exclude: CHANGELOG.md

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
name: check PEP8
args: ["--ignore=E501,W503,E203"]

- repo: https://github.com/hadialqattan/pycln
rev: v2.1.5
rev: v2.4.0
hooks:
- id: pycln
name: prune imports
args: [--expand-stars]

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
rev: 1.8.4
hooks:
- id: nbqa-black
additional_dependencies: [black]
name: format notebooks
- id: nbqa-mypy
additional_dependencies: [mypy]
name: static analysis for notebooks

- repo: https://github.com/crate-ci/typos
rev: v1.19.0
hooks:
- id: typos
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ segmantic-unet train-config -c config.yml

## What is this tissue_list?

The example above included a tissue_list option. This is a path to a text file specifying the labels contained in a segmented image. By convention the 'label=0' is the background and is ommited from the the format. A segmentation with three tissues 'Bone'=1, 'Fat'=2, and 'Skin'=3 would be specified as follows:
The example above included a tissue_list option. This is a path to a text file specifying the labels contained in a segmented image. By convention the 'label=0' is the background and is omitted from the format. A segmentation with three tissues 'Bone'=1, 'Fat'=2, and 'Skin'=3 would be specified as follows:

```
V7
Expand All @@ -96,7 +96,7 @@ Instead of providing the 'image_dir'/'labels_dir' pair, the training data can al

```json
{
"dataset": ["/dataA/dataset.json", "/dataB/dataset.json"],
"datalist": ["/dataA/dataset.json", "/dataB/dataset.json"],
"output_dir": "<path where trained model and logs are saved>",
"Etc": "etc"
}
Expand Down
17 changes: 11 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ build-backend = "setuptools.build_meta"
name = "segmantic"
authors = [{name = "Bryn Lloyd", email = "lloyd@itis.swiss"}]
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
license = {file = "LICENSE"}
dynamic = ["version"]
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -72,7 +71,7 @@ profile = "black"
disallow_untyped_defs = false
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_unused_ignores = false
warn_return_any = true
strict_equality = true
no_implicit_optional = false
Expand All @@ -90,9 +89,15 @@ module = [
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = "itk.*,PIL,matplotlib.*,torch,torchvision.*,numba,setuptools,pytest,typer.*,click,colorama,nibabel,sklearn.*,yaml,scipy.*,adabelief_pytorch,h5py,SimpleITK,vtk.*,sitk_cli"
module = "itk.*,PIL,matplotlib.*,torchvision.*,numba,setuptools,pytest,typer.*,click,colorama,nibabel,sklearn.*,yaml,scipy.*,adabelief_pytorch,h5py,SimpleITK,vtk.*,sitk_cli"
ignore_missing_imports = true

[tool.ruff]
select = ["E", "F"]
ignore = ["E501"]
lint.select = ["E", "F"]
lint.ignore = ["E501"]

[tool.typos]

[tool.typos.default.extend-identifiers]
# *sigh* monai adds 'd' to dictionary transforms
SpatialPadd = "SpatialPadd"
15 changes: 14 additions & 1 deletion scripts/check_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,18 @@ def fix_binary_masks(directory: Path, file_glob: str = "*.nii.gz"):
)


def round_half_up(input_dir: Path, output_dir: Path = None):
    """Report suspicious label masks found in *input_dir*.

    Scans every ``*.nii.gz`` file and prints
    - the intensity range of any image with values outside ``[0, 3]``
    - the pixel type of any image stored as 32/64-bit float

    NOTE(review): despite the name, nothing is rounded or written, and
    *output_dir* is never used -- presumably the rounding/写-back step is
    still to be implemented; confirm intent.
    """
    import SimpleITK as sitk

    for path in input_dir.glob("*.nii.gz"):
        image = sitk.ReadImage(path)
        values = sitk.GetArrayViewFromImage(image)
        lo, hi = np.min(values), np.max(values)
        if lo < 0 or hi > 3:
            print(f"{path.name}: [{lo}, {hi}]")
        if image.GetPixelID() in (sitk.sitkFloat32, sitk.sitkFloat64):
            print(f"{path.name}: {image.GetPixelIDTypeAsString()}")


if __name__ == "__main__":
typer.run(fix_binary_masks)
typer.run(round_half_up)
31 changes: 31 additions & 0 deletions scripts/check_training_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

import numpy as np
import SimpleITK as sitk
import typer

from segmantic.utils.file_iterators import find_matching_files


def check_training_data(
    image_dir: Path, labels_dir: Path, copy_image_information: bool = False
):
    """Sanity-check paired image/label volumes.

    For each image/label pair matched between *image_dir* and *labels_dir*:
    - prints a warning and skips the pair if the voxel grids differ in size
    - with ``copy_image_information=True``, copies spacing/origin/direction
      from the image onto the label and rewrites the label as uint8
    - otherwise asserts that spacing and origin agree to two decimals
      (raises ``AssertionError`` on the first larger mismatch)
    """
    pairs = find_matching_files([image_dir / "*.nii.gz", labels_dir / "*.nii.gz"])
    for pair in pairs:
        image = sitk.ReadImage(pair[0])
        label = sitk.ReadImage(pair[1])
        # a size mismatch cannot be repaired by copying metadata -- skip
        if image.GetSize() != label.GetSize():
            print(f"Size mismatch {pair[0].name}: {image.GetSize()} != {label.GetSize()}")
            continue
        if copy_image_information:
            label.CopyInformation(image)
            sitk.WriteImage(sitk.Cast(label, sitk.sitkUInt8), pair[1])
        elif (
            image.GetSpacing() != label.GetSpacing()
            or image.GetOrigin() != label.GetOrigin()
        ):
            # tolerate tiny float deviations, fail loudly on real mismatches
            np.testing.assert_almost_equal(
                image.GetSpacing(), label.GetSpacing(), decimal=2
            )
            np.testing.assert_almost_equal(
                image.GetOrigin(), label.GetOrigin(), decimal=2
            )


if __name__ == "__main__":
    # CLI entry point: expose check_training_data via typer
    typer.run(check_training_data)
2 changes: 0 additions & 2 deletions scripts/evaluate_segmentations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

from pathlib import Path

import SimpleITK as sitk
Expand Down
22 changes: 22 additions & 0 deletions scripts/extract_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Optional

import typer


def extract_unet(input_file: Path, output_file: Optional[Path] = None):
    """Load segmantic unet lightning module and export inner monai UNet"""
    # heavy imports kept local so the CLI loads fast
    import torch

    from segmantic.seg.monai_unet import Net

    # default: same path with a .pth suffix next to the checkpoint
    target = input_file.with_suffix(".pth") if output_file is None else output_file
    if target.exists() and target.samefile(input_file):
        raise RuntimeError("Input and output file are identical")
    module = Net.load_from_checkpoint(input_file)
    # NOTE(review): relies on the private `_model` attribute of Net -- the
    # inner monai UNet; confirm this stays stable across Net versions
    torch.save(module._model.state_dict(), target)


if __name__ == "__main__":
    # CLI entry point: expose extract_unet via typer
    typer.run(extract_unet)
36 changes: 0 additions & 36 deletions scripts/generate_dataset.py

This file was deleted.

81 changes: 81 additions & 0 deletions scripts/make_datalist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import random
from pathlib import Path

import typer

from segmantic.image.labels import load_tissue_list
from segmantic.utils.file_iterators import find_matching_files


def make_datalist(
    data_dir: Path = typer.Option(
        ...,
        help="root data directory. Paths in datalist will be relative to this directory",
    ),
    image_dir: Path = typer.Option(..., help="Directory containing images"),
    labels_dir: Path = typer.Option(None, help="Directory containing labels"),
    datalist_path: Path = typer.Option(..., help="Filename of output datalist"),
    num_channels: int = 1,
    num_classes: int = -1,
    tissuelist_path: Path = None,
    percent: float = 1.0,
    description: str = "",
    image_glob: str = "*.nii.gz",
    labels_glob: str = "*.nii.gz",
    test_only: bool = False,
    seed: int = 104,
) -> int:
    """Write a decathlon-style datalist JSON for a directory of images/labels.

    Either *tissuelist_path* or *num_classes* must be provided to build the
    label id -> name mapping. With ``test_only=True`` all images are listed
    under "test"; otherwise image/label pairs are shuffled (deterministically,
    via *seed*), the first 10 pairs are held out as the test set and the rest
    are split into training/validation. *percent* scales how much of the
    remaining data is used (at ``percent=1.0`` everything is used, with a
    80/20 training/validation split).

    Returns the number of characters written to *datalist_path*.

    Raises:
        ValueError: if neither *tissuelist_path* nor *num_classes* is given.
    """
    # build the label id -> name mapping; label 0 (background) is omitted
    if tissuelist_path is not None:
        tissuelist = load_tissue_list(tissuelist_path)
        labels = {str(id): n for n, id in tissuelist.items() if id != 0}
    elif num_classes > 0:
        labels = {str(id): f"tissue{id:02d}" for id in range(1, num_classes + 1)}
    else:
        raise ValueError("Either specify 'tissuelist_path' or 'num_classes'")

    data_config = {
        "description": description,
        "num_channels": num_channels,
        "labels": labels,
    }

    # add all files as test files
    if test_only:
        test_files = (data_dir / image_dir).glob(image_glob)
        data_config["training"] = []
        data_config["validation"] = []
        data_config["test"] = [str(f.relative_to(data_dir)) for f in test_files]

    # build proper datalist with training/validation/test split
    else:
        matches = find_matching_files(
            [data_dir / image_dir / image_glob, data_dir / labels_dir / labels_glob]
        )
        pairs = [
            (p[0].relative_to(data_dir), p[1].relative_to(data_dir)) for p in matches
        ]

        # deterministic shuffle, then hold out the first 10 pairs for testing
        random.Random(seed).shuffle(pairs)
        test, pairs = pairs[:10], pairs[10:]
        num_valid = int(percent * 0.2 * len(pairs))
        num_training = len(pairs) - num_valid if percent >= 1.0 else 4 * num_valid

        data_config["training"] = [
            {"image": str(im), "label": str(lbl)} for im, lbl in pairs[:num_training]
        ]
        # BUG FIX: pairs[-num_valid:] returned ALL pairs when num_valid == 0
        # (e.g. tiny datasets or percent=0); guard against the empty slice.
        validation_pairs = pairs[-num_valid:] if num_valid > 0 else []
        data_config["validation"] = [
            {"image": str(im), "label": str(lbl)} for im, lbl in validation_pairs
        ]
        # BUG FIX: a stray trailing comma previously wrapped this list in a
        # tuple, serializing "test" as a nested [[...]] array in the JSON.
        data_config["test"] = [str(im) for im, _ in test]

    return datalist_path.write_text(json.dumps(data_config, indent=2))


def main():
    # CLI entry point: expose make_datalist via typer
    typer.run(make_datalist)


if __name__ == "__main__":
    main()
5 changes: 2 additions & 3 deletions scripts/visualize_label_surfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List

import itk
import numpy as np
Expand All @@ -15,11 +14,11 @@ def extract_surfaces(
file_path: Path,
output_dir: Path,
tissuelist_path: Path,
selected_tissues: List[int] = [],
selected_tissues: list[int] = [],
):
image = itk.imread(f"{file_path}", pixel_type=itk.US)

tissues: Dict[int, str] = {}
tissues: dict[int, str] = {}
if tissuelist_path.exists():
name_id_map = load_tissue_list(tissuelist_path)
tissues = {id: name for name, id in name_id_map.items()}
Expand Down
2 changes: 1 addition & 1 deletion src/segmantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" ML-based segmentation for medical images
"""

__version__ = "0.3.0"
__version__ = "0.4.0"
Loading

0 comments on commit 288663f

Please sign in to comment.