Skip to content

Commit

Permalink
Merge pull request #893 from AFM-SPM/SylviaWhittle/multi-class-unet
Browse files Browse the repository at this point in the history
Add multi-class deep learning segmentation support
  • Loading branch information
llwiggins authored Sep 5, 2024
2 parents 2567fb3 + 3b3cdd2 commit 70d02bd
Show file tree
Hide file tree
Showing 10 changed files with 964 additions and 212 deletions.
28 changes: 27 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def skeleton_linear3() -> dict:

# U-Net fixtures
@pytest.fixture()
def mock_model_5_by_5() -> MagicMock:
def mock_model_5_by_5_single_class() -> MagicMock:
"""Create a mock model."""
model_mocker = MagicMock()

Expand Down Expand Up @@ -1056,6 +1056,31 @@ def side_effect_predict(input_array: npt.NDArray[np.float32]) -> npt.NDArray[np.
.reshape((1, 5, 5, 1))
.astype(np.float32)
)
if np.array_equal(
input_array_without_batch_and_channel,
np.array(
[
[0.1, 0.2, 0.1, 0.2, 0.1],
[0.2, 0.1, 1.0, 0.1, 0.2],
[0.1, 1.0, 1.0, 1.0, 0.1],
[0.2, 0.1, 1.0, 0.1, 0.2],
[0.1, 0.2, 0.1, 0.2, 0.1],
]
).astype(np.float32),
):
return (
np.array(
[
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]
)
.reshape((1, 5, 5, 1))
.astype(np.float32)
)
raise ValueError(
"Input is not as expected. Check the image crop sent to the model and check the"
"mocked unet predict function has a case for that exact input."
Expand All @@ -1065,5 +1090,6 @@ def side_effect_predict(input_array: npt.NDArray[np.float32]) -> npt.NDArray[np.
model_mocker.predict.side_effect = side_effect_predict
# Override the output of the input_shape property
model_mocker.input_shape = (1, 5, 5, 1)
model_mocker.output_shape = (1, 5, 5, 1)

return model_mocker
Binary file not shown.
Binary file not shown.
Binary file modified tests/resources/process_scan_topostats_file_regtest.topostats
Binary file not shown.
Loading

0 comments on commit 70d02bd

Please sign in to comment.