diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3c45b11f..90a92f78 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,7 @@ jobs: strategy: matrix: platform: [ubuntu-latest, windows-latest, macos-latest] - python-version: [3.9, "3.10"] + python-version: [3.9, "3.10", "3.11"] env: DISPLAY: ':99.0' steps: @@ -44,17 +44,15 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools - pip install pytest - pip install pytest-qt pip install pytest-xvfb pip install coverage - pip install -e ".[testing]" + pip install -e ".[dev]" pip install matplotlib working-directory: src/client - name: Install server dependencies (for communication tests) run: | - pip install -e ".[testing]" + pip install -e ".[dev]" working-directory: src/server - name: Test with pytest @@ -94,10 +92,9 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade setuptools pip install numpy - pip install pytest pip install wheel pip install coverage - pip install -e ".[testing]" + pip install -e ".[dev]" working-directory: src/server - name: Test with pytest @@ -114,4 +111,43 @@ jobs: files: src/server/coverage.xml env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + deploy: + # this will run when you have tagged a commit, starting with "v*" + # and requires that you have put your twine API key in your + # github secrets (see readme for details) + needs: [test_client, test_server] + runs-on: ubuntu-latest + if: contains(github.ref, 'tags') + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install twine + pip install build + + - name: Build and publish dcp_client + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + run: | + git tag + python -m build . + twine upload dist/* + working-directory: src/client + + - name: Build and publish dcp_server + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + run: | + git tag + python -m build . + twine upload dist/* + working-directory: src/server diff --git a/.gitignore b/.gitignore index 8d5b07d1..0f64af41 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,14 @@ data/ in_progress/ curated/ +uncurated/ # model dir *mytrainedmodel/ #configs src/client/dcp_client/config.cfg +src/server/dcp_server/config.cfg # Byte-compiled / optimized / DLL files __pycache__/ @@ -80,7 +82,7 @@ instance/ .scrapy # Sphinx documentation -docs/_build/ +# docs/build/ # PyBuilder target/ @@ -151,8 +153,6 @@ dmypy.json .idea/ .DS_Store -docs/ -test-napari.pub data/ BentoML/ -models/ + diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..a41b4211 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,30 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.10" + # You can also specify other tool versions: + # nodejs: "19" + # rust: "1.64" + # golang: "1.19" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# Optionally declare the Python requirements required to build your docs +python: + # Install both Python packages before building the docs + install: + - method: pip + path: src/client + - method: pip + path: src/server + - requirements: docs/requirements.txt diff --git a/README.md b/README.md index e4c42374..bf0e3499 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ![stability-wip](https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg) ![tests](https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform/actions/workflows/test.yml/badge.svg?event=push) [![codecov](https://codecov.io/gh/HelmholtzAI-Consultants-Munich/data-centric-platform/branch/main/graph/badge.svg)](https://codecov.io/gh/HelmholtzAI-Consultants-Munich/data-centric-platform) - +[![Documentation Status](https://readthedocs.org/projects/data-centric-platform/badge/?version=latest)](https://data-centric-platform.readthedocs.io/en/latest/?badge=latest) ## How to use this? @@ -16,7 +16,7 @@ To run the client GUI follow the instructions described in [DCP Client Installat DCP handles all kinds of **segmentation tasks**! Try it out if you need to do: * **Instance** segmentation * **Semantic** segmentation -* **Panoptic** segmentation +* **Multi-class instance** segmentation ### Toy data This repo includes the ```data/``` directory with some toy data which you can use as the *Uncurated dataset* folder. You can create (empty) folders for the other two directories required in the welcome window and start playing around. @@ -29,3 +29,4 @@ Our platform encourages the use of data centric practices. With the user friendl - Focus on data curation: no interaction with model parameters during training and inference #### *Get more with less!* + diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..dc1312ab --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..cbf1e365 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +sphinx +sphinx-rtd-theme diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..c0df37ab --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,48 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'DCP' +copyright = '2024, Christina Bukas, Mariia Koren, Helena Pelin' +author = 'Christina Bukas, Mariia Koren, Helena Pelin' +release = '0.1' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [] + +templates_path = ['_templates'] +exclude_patterns = [] + +language = 'English' + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'alabaster' +#html_static_path = ['_static'] + +import os +import sys +from pathlib import Path +import sphinx_rtd_theme + +# Add parent dir to known paths +p = Path(__file__).parents[2] +sys.path.insert(0, os.path.abspath(p)) +sys.path.insert(0, os.path.join(p, 'src/server/dcp_server')) +# Add the following extensions +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.napoleon', + 'sphinx_rtd_theme' +] + +# Use RTD theme +html_theme = "sphinx_rtd_theme" diff --git a/docs/source/dcp_client.gui.rst b/docs/source/dcp_client.gui.rst new file mode 100644 index 00000000..f0e027ad --- /dev/null +++ b/docs/source/dcp_client.gui.rst @@ -0,0 +1,34 @@ +dcp\_client.gui package +======================= + +.. automodule:: dcp_client.gui + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +dcp\_client.gui.main\_window module +----------------------------------- + +.. automodule:: dcp_client.gui.main_window + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.gui.napari\_window module +------------------------------------- + +.. automodule:: dcp_client.gui.napari_window + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.gui.welcome\_window module +-------------------------------------- + +.. automodule:: dcp_client.gui.welcome_window + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/dcp_client.rst b/docs/source/dcp_client.rst new file mode 100644 index 00000000..f921a211 --- /dev/null +++ b/docs/source/dcp_client.rst @@ -0,0 +1,52 @@ +dcp\_client package +=================== + +The dcp_client package contains modules and subpackages for interacting with a server for model inference and training. It provides functionalities for managing GUI windows, handling image storage, and connecting to the server for model operations. + +dcp_client.app + Defines the core application class and related functionalities. + - ``dcp_client.app.Application``: Represents the main application and provides methods for image management, model interaction, and server connectivity. + - ``dcp_client.app.DataSync``: Abstract base class for data synchronization operations. + - ``dcp_client.app.ImageStorage``: Abstract base class for image storage operations. + - ``dcp_client.app.Model``: Abstract base class for model operations. + +dcp_client.gui + Contains modules for GUI components. + - ``dcp_client.gui.main_window``: Defines the main application window and associated event functions. + - ``dcp_client.gui.napari_window``: Manages the Napari window and its functionalities. + - ``dcp_client.gui.welcome_window``: Implements the welcome window and its interactions. + +dcp_client.utils + Contains utility modules for various tasks. + - ``dcp_client.utils.bentoml_model``: Handles interactions with BentoML for model inference and training. + - ``dcp_client.utils.fsimagestorage``: Provides functions for managing images stored in the filesystem. + - ``dcp_client.utils.settings``: Defines initialization functions and settings. + - ``dcp_client.utils.sync_src_dst``: Implements data synchronization between source and destination. + - ``dcp_client.utils.utils``: Offers various utility functions for common tasks. + + +Submodules +---------- + +dcp\_client.app module +---------------------- + +.. automodule:: dcp_client.app + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.gui module +---------------------- +.. toctree:: + :maxdepth: 4 + + dcp_client.gui + +dcp\_client.utils module +------------------------ +.. toctree:: + :maxdepth: 4 + + dcp_client.utils + \ No newline at end of file diff --git a/docs/source/dcp_client.utils.rst b/docs/source/dcp_client.utils.rst new file mode 100644 index 00000000..61f563f7 --- /dev/null +++ b/docs/source/dcp_client.utils.rst @@ -0,0 +1,51 @@ +dcp\_client.utils package +========================= + +.. automodule:: dcp_client.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +dcp\_client.utils.bentoml\_model module +--------------------------------------- + +.. automodule:: dcp_client.utils.bentoml_model + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.utils.fsimagestorage module +--------------------------------------- + +.. automodule:: dcp_client.utils.fsimagestorage + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.utils.settings module +--------------------------------- + +.. automodule:: dcp_client.utils.settings + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.utils.sync\_src\_dst module +--------------------------------------- + +.. automodule:: dcp_client.utils.sync_src_dst + :members: + :undoc-members: + :show-inheritance: + +dcp\_client.utils.utils module +------------------------------ + +.. automodule:: dcp_client.utils.utils + :members: + :undoc-members: + :show-inheritance: + diff --git a/docs/source/dcp_client_installation.rst b/docs/source/dcp_client_installation.rst new file mode 100644 index 00000000..b4a883b2 --- /dev/null +++ b/docs/source/dcp_client_installation.rst @@ -0,0 +1,125 @@ +.. _DCP Client: + +DCP Client +=========== + +The client of our data centric platform for microscopy imaging. + +.. image:: https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg + :alt: stability-wip + +Installation +------------- + +Before starting make sure you have navigated to ``data-centric-platform/src/client``. All future steps expect you are in the client directory. This installation has been tested using a conda environment with python version 3.9 on a mac local machine. In your dedicated environment run: + +.. code-block:: bash + + pip install -e . + +Running the client: A step-by-step guide! +------------------------------------------ + +1. **Launching the client** +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +DCP includes a client and server side for using our data centric platform. The client and server communicate via the `bentoml `_ library. + There are currently two options available: running the server locally, or connecting to the running instance on the FZJ jusuf-cloud. + Before continuing, you need to make sure that DCP server is running, either locally or on the cloud. See :doc: `dcp_server_installation` for instructions on how to launch the server. **Note:** In order for this connection to succeed, you will need to have contacted the team developing DCP, so they can add your IP to the list of accepted requests. + +After you are certain the server is running, simply run: + + .. code-block:: bash + + dcp-client --mode local/remote + +Set the ``--mode`` argument to ``local`` or ``remote`` depending on which setup you have chosen for the server. + +2. **Welcome window** +~~~~~~~~~~~~~~~~~~~~~~ + + The welcome window should have now popped up. + + .. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/client_welcome_window.png + :width: 400 + :height: 200 + :align: center + + + Here you will need to select the directories which we will be using throughout the data centric workflow. The following directories need to be defined: + + - **Uncurated Dataset Path:** + + This folder is intended to store all images of your dataset. These images may be accompanied by corresponding segmentations. If present, segmentation files should share the same filename as their associated image, appended with a suffix as specified in ``server/dcp_server/config.cfg`` file (default: '_seg'). + + - **Curation in Progress Path (Optional):** + + Images for which the segmentation is a work in progress should be moved here. Each image in this folder can have one or multiple segmentations corresponding to it (by changing the filename of the segmentation in the napari layer list after editing it, see **Viewer**). If you do not want to use an intermediate working dir, you can skip setting a path to this directory (it is not required). No future functions affect this directory, it is only used to move to and from the uncurated and curated directories. + + - **Curated Dataset Path:** + + This folder is intended to contain images along with their final segmentations. **Only** move images here when the segmentation is complete and finalised, you won't be able to change them after they have been moved here. These are then used for training your model. + +3. **Setting paths** +~~~~~~~~~~~~~~~~~~~~~ + + After setting the paths for these three folders, you can click the **Start** button. If you have set the server configuration to the cloud, you will receive a message notifying you that your data will be uploaded to the cloud. Click **Ok** to continue. + +4. **Data Overview** +~~~~~~~~~~~~~~~~~~~~ + + The main working window will appear next. This gives you an overview of the directories selected in the previous step along with three options: + + - **Generate Labels:** Click this button to generate labels for all images in the "Uncurated dataset" directory. This will call the ``segment_image`` service from the server + - **View image and fix label:** Click this button to launch your viewer. The napari software is used for visualising, and editing the images segmentations. See **Viewer** + - **Train Model:** Click this model to train your model on the images in the "Curated dataset" directory. This will call the ``train`` service from the server + + .. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/client_data_overview_window.png + :width: 500 + :height: 200 + :align: center + +5. **The viewer** +~~~~~~~~~~~~~~~~~~~~ + + In DCP, we use [napari](https://napari.org/stable) for viewing our images and masks, adding, editing or removing labels. An example of the viewer can be seen below. After adding or removing any objects and editing existing objects wherever necessary, there are two options available: + + - Click the **Move to Curation in progress folder** if you are not 100% certain about the labels you have created. You can also click on the label in the labels layer and change the name. This will result in several label files being created in the *In progress folder*, which can be examined later on. **Note:** When changing the layer name in Napari, the user should rename it such that they add their initials or any other new info after _seg. E.g., if the labels of 1_seg.tiff have been changed in the Napari viewer, then the appropriate naming would for example be: 1_seg_CB.tiff and not 1_CB_seg.tiff. + - Click the **Move to Curated dataset folder** if you are certain that the labels you are now viewing are final and require no more curation. These images and labels will later be used for training the machine learning model, so make sure that you select this option only if you are certain about the labels. If several labels are displayed (opened from the 'Curation in progress' step), make sure to **click** on the single label in the labels layer list you wish to be moved to the *Curated data folder*. The other images will then be automatically deleted from this folder. + + .. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/client_napari_viewer.png + :width: 900 + :height: 500 + :align: center + +Data centric workflow [intended usage summary] +---------------------------------------------- + +The intended usage of DCP would include the following: + +1. Setting up configuration, run client (with server already running) and select data directories +2. Generate labels for data in *Uncurated data folder* +3. Visualise the resulting labels with the viewer and correct labels wherever necessary - once done move the image *Curated data folder*. Repeat this step for a couple of images until a few are placed into the *Curated data folder*. Depending on the qualitative evaluation of the label generation you might want to include fewer or more images, i.e. if the resulting masks require few edits, then few images will most likely be sufficient, whereas if many edits to the mask are required it is likely that more images are needed in the *Curated data folder*. You can always start with a small number and adjust later +4. Train the model with the images in the *Curated data folder* +5. Repeat steps 2-4 until you are satisfied with the masks generated for the remaining images in the *Uncurated data folder*. Every time the model is trained in step 4, the masks generated in step 2 should be of higher quality, until the model need not be trained any more + + .. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/dcp_pipeline.png + :width: 400 + :height: 400 + :align: center + +DCP Shortcuts +------------- + +- In the Data Overview window, clicking on an image and the hitting the **Enter** key, is equivalent to clicking the 'View Image and Fix Label' button +- The viewer accepts all Napari Shortcuts. The current list of the shortcuts for macOS can be see below: + +.. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/napari_shortcuts.png + :width: 600 + :height: 500 + :align: center + + + + + diff --git a/docs/source/dcp_server.rst b/docs/source/dcp_server.rst new file mode 100644 index 00000000..78a1ef23 --- /dev/null +++ b/docs/source/dcp_server.rst @@ -0,0 +1,54 @@ +dcp\_server package +=================== + +The dcp_server package is structured to handle various server-side functionalities related model serving for segmentation and training. + +dcp_server.models + Defines various models for cell classification and segmentation, including CellClassifierFCNN, CellClassifierShallowModel, CellposePatchCNN, CustomCellposeModel, and UNet. + These models handle tasks such as evaluation, forward pass, training, and updating configurations. + +dcp_server.segmentationclasses + Defines segmentation classes for specific projects, such as GFPProjectSegmentation, GeneralSegmentation, and MitoProjectSegmentation. + These classes contain methods for segmenting images and training models on images and masks. + +dcp_server.serviceclasses + Defines service classes, such as CustomBentoService and CustomRunnable, for serving the models with BentoML and handling computation on remote Python workers. + +dcp_server.utils + Provides various utility functions for dealing with image storage, image processing, feature extraction, file handling, configuration reading, and path manipulation. + + +Submodules +---------- + +dcp\_server.models module +------------------------- + +.. automodule:: dcp_server.models + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.segmentationclasses module +-------------------------------------- + +.. automodule:: dcp_server.segmentationclasses + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.serviceclasses module +--------------------------------- + +.. automodule:: dcp_server.serviceclasses + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.utils module +--------------------------------- + +.. toctree:: + :maxdepth: 4 + + dcp_server.utils \ No newline at end of file diff --git a/docs/source/dcp_server.utils.rst b/docs/source/dcp_server.utils.rst new file mode 100644 index 00000000..a6334330 --- /dev/null +++ b/docs/source/dcp_server.utils.rst @@ -0,0 +1,37 @@ +dcp\_server.utils package +========================= + +.. automodule:: dcp_server.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +dcp\_server.utils.fsimagestorage module +--------------------------------------- + +.. automodule:: dcp_server.utils.fsimagestorage + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.utils.helpers module +--------------------------------------- + +.. automodule:: dcp_server.utils.helpers + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.utils.processing module +----------------------------------- + +.. automodule:: dcp_server.utils.processing + :members: + :undoc-members: + :show-inheritance: + + + diff --git a/docs/source/dcp_server_installation.rst b/docs/source/dcp_server_installation.rst new file mode 100644 index 00000000..823f97af --- /dev/null +++ b/docs/source/dcp_server_installation.rst @@ -0,0 +1,129 @@ +.. _DCP Server: + +DCP Server +=========== + + +The server of our data centric platform for microscopy imaging. + +.. image:: https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg + :alt: stability-wip + +The client and server communicate via the `bentoml `_ library. The client interacts with the server every time we run model inference or training, so the server should be running before starting the client. + + +Installation +-------------- + +Before starting make sure you have navigated to ``data-centric-platform/src/server``. All future steps expect you are in the server directory. In your dedicated environment run: + +.. code-block:: bash + + pip install -e ".[dev]" + +Launch DCP Server +------------------ + +Simply run: + +.. code-block:: bash + + python dcp_server/main.py + +Once the server is running, you can verify it is working by visiting http://localhost:7010/ in your web browser. + +Customization (for developers) +-------------------------------- + +All service configurations are set in the ``config.cfg`` file. Please, obey the `formal JSON format `_. + +The config file has to have the six main parts. All the ``marked`` arguments are mandatory: + +- ``setup`` + + - ``segmentation`` - segmentation type from the ``segmentationclasses.py``. Currently, only ``GeneralSegmentation`` is available (MitoProjectSegmentation and GFPProjectSegmentation are stale). + - ``model_to_use`` - name of the model class from the ``models.py`` you want to use. Currently, available models are: ``CustomCellposeModel``, ``UNet`` and ``CellposePatchCNN``. See **Models** section for more information. + - ``accepted_types`` - types of images currently accepted for the analysis + - ``seg_name_string`` - end string for masks to run on (All the segmentations of the image should contain this string - used to save and search for segmentations of the images) +- ``service`` + + - ``runner_name`` - name of the runner for the bentoml service + - ``bento_model_path`` - name for the trained model which will be saved after calling the (re)train from service - is saved under ``bentoml/models`` + - ``service_name`` - name for the bentoml service + - ``port`` - on which port to start the service +- ``model`` - configuration for the model instantiation. Here, pass any arguments you need or want to change. Take care that the names of the arguments are the same as of the original model class ``__init__()`` function. + + - ``segmentor``: model configuration for the segmentor. ``CustomCellposeModel`` takes arguments used in the init of CellposeModel, see `here `__. + - ``classifier``: model configuration for classifier, see ``__init__()`` of ``UNet``, ``CellClassifierFCNN`` or ``CellClassifierShallowModel``. +- ``data`` - data configuration + + - ``data_root``: if you are running the server remotely, then you need to specify the project path here. Should match the ``server: data-path`` argument in the client config. +- ``train`` - configuration for the model training. Take care that the names of the arguments are the same as of the original model's ``train()`` function. + + - ``segmentor``: for ``CustomCellposeModel`` the ``train()`` function arguments can be found `here `__. Pass any arguments you need or want to change or leave empty {}, then default arguments will be used. + - ``classifier``: train configuration for classifier, see ``train()`` of ``UNet``, ``CellClassifierFCNN`` or ``CellClassifierShallowModel``. +- ``eval`` - configuration for the model evaluation. Take care that the names of the arguments are the same as of the original model's ``eval()`` function. + + - ``segmentor``: for ``CustomCellposeModel`` the ``eval()`` function arguments can be found `here `__. Pass any arguments you need or want to change or leave empty {}, then default arguments will be used. + - ``classifier``: train configuration for classifier, see ``eval()`` of ``UNet``, ``CellClassifierFCNN``or ``CellClassifierShallowModel`` + - ``mask_channel_axis``: If a multi-class instance segmentation model has been used, then the masks returned by the model should have two channels, one for the instance segmentation results and one indicating the objects class. This variable indicated at which dim the channel axis should be stored. Currently should be kept at 0, as this is the only way the masks can be visualized correctly by napari in the client. + +To make it easier to run the server we provide you with three config files: + + - ``config.cfg`` is set up to work for a panoptic segmentation task + - ``config_instance.cfg`` for instance segmentation + - ``config_semantic.cfg`` for semantic segmentation + +Make sure to rename the config you wish to use to ``config.cfg``. The default is panoptic segmentation. + +Models +------- + +The models are currently integrated into DCP: + +- **Instance** Segmentation: + + - ``CustomCellpose``: Inherits from cellpose.models.CellposeModel, see `here `__ for more information. +- **Semantic** Segmentation: + + - ``UNet``: A vanilla U-Net model, trained on the full images +- **Multi Class Instance** Segmentation: + + - ``Inst2MultiSeg``: Includes a segmentor for instance segmentation, sequentially followed by a classifier for semantic segmentation. The segmentor can only be ``CustomCellposeModel`` model, while the classifier can be one of: + + - ``PatchClassifier`` or "FCNN" (in config): A CNN model for obtaining class labels, trained on images patches of individual objects, extarcted using the instance mask from the previous step + - ``FeatureClassifier`` or "RandomForest" (in config): A Random Forest model for obtaining class labels, trained on shape and intensity features of the objects, extracted using the instance mask from the previous step. + - ``MultiCellpose``: Includes **n** CustomCellpose models, where n equals the number of classes, stacked such that each model predicts only the object corresponding to each class. + - ``UNet``: If the post-processing argument is set, then the instance mask is deduced from the labels mask. Will not be able to handle touching objects + + +Running with Docker +------------------------------------------------------- + +.. note:: + DO NOT USE UNTIL ISSUE IS SOLVED: Currently doesn't work for generate labels + +Docker-Compose +~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + docker compose up + +Docker Non-Interactively +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + docker build -t dcp-server . + docker run -p 7010:7010 -it dcp-server + +Docker Interactively +~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: bash + + docker build -t dcp-server . + docker run -it dcp-server bash + bentoml serve service:svc --reload --port=7010 + diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000..c6532633 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,78 @@ +.. dcp documentation master file, created by + sphinx-quickstart on Sun Feb 11 18:53:28 2024. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Data Centric Platform +=============================== + +*A data centric platform for all-kinds segmentation in microscopy imaging* + +.. image:: https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg + :alt: stability-wip + +.. image:: https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform/actions/workflows/test.yml/badge.svg?event=push + :alt: tests + +.. image:: https://codecov.io/gh/HelmholtzAI-Consultants-Munich/data-centric-platform/branch/main/graph/badge.svg + :target: https://codecov.io/gh/HelmholtzAI-Consultants-Munich/data-centric-platform + +How to use it? +---------------- + +The client and server communicate via the `bentoml `_ library. +The client interacts with the server every time we run model inference or training. +For full functionality of the software the server should be running, either locally or remotely. + +To install and start the server side follow the instructions described in :ref:`DCP Server`. +To run the client GUI follow the instructions described in :ref:`DCP Client` + +DCP handles all kinds of **segmentation tasks**! Try it out if you need to do: + +- **Instance** segmentation +- **Semantic** segmentation +- **Multi-class instance** segmentation + +Toy data +-------- + +Our github repo includes the ``data/`` directory with some toy data which you can use as the *Uncurated dataset* folder. You can create (empty) folders for the other two directories required in the welcome window and start playing around. + +Enabling data centric development +---------------------------------- + +Our platform encourages the use of data centric practices. With the user friendly client interface you can: + +- **Detect and remove outliers** from your training data: only confirmed samples are used to train our models +- **Detect and correct labeling errors**: editing labels with the integrated napari visualisation tool +- **Establish consensus**: allows for multiple annotators before curated label is passed to train model +- **Focus on data curation**: no interaction with model parameters during training and inference + +.. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/dcp_pipeline.png + :width: 400 + :height: 400 + :align: center + +.. centered:: + *Get more with less!* + +DCP Imaging Conventions +----------------------- +DCP currently follows the imaging conventions described below: + +- Only 2D images are accepted +- The accepted imaging formats are: ``(".jpg", ".jpeg", ".png", ".tiff", ".tif")`` +- RGB and RGBA images are accepted, however they will be converted to grayscale after read into DCP. The dims can be [C, H, W] or [H, W, C] +- Existing segementations can be used, however they need to be TIFF files and have the same name as the corresponding image followed by '_seg', e.g. image1_seg.tiff\ + +.. toctree:: + :maxdepth: 3 + :caption: Contents: + + dcp_client_installation + dcp_server_installation + dcp_server + dcp_client + + + diff --git a/src/client/MANIFEST.in b/src/client/MANIFEST.in new file mode 100644 index 00000000..c6c02f12 --- /dev/null +++ b/src/client/MANIFEST.in @@ -0,0 +1 @@ +include dcp_client/*.yaml \ No newline at end of file diff --git a/src/client/README.md b/src/client/README.md index 8d0a16ac..ce0e4354 100644 --- a/src/client/README.md +++ b/src/client/README.md @@ -2,89 +2,27 @@ The client of our data centric platform for microscopy imaging. ![stability-wip](https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg) +[![Documentation Status](https://readthedocs.org/projects/data-centric-platform/badge/?version=latest)](https://data-centric-platform.readthedocs.io/en/latest/?badge=latest) ## How to use? + ### Installation -Before starting make sure you have navigated to ```data-centric-platform/src/client```. All future steps expect you are in the client directory. This installation has been tested using a conda environment with python version 3.9 on a mac local machine. In your dedicated environment run: +This has been tested on Python versions 3.9, 3.10 and 3.11 on latest versions of Windows, Ubuntu and MacOS. In your dedicated environment run: ``` -pip install -e . +pip install dcp_client ``` -### Running the client: A step to step guide! -1. **Configurations** - -Before launching the GUI you will need to set up your client configuration file, _dcp_client/config.cfg_. Please, obey the [formal JSON format](https://www.json.org/json-en.html). Here, we will define how the client will interact with the server. There are currently two options available: running the server locally, or connecting to the running instance on the FZJ jusuf-cloud. To connect to a locally running server, set: - ``` - "server":{ - "user": "local", - "host": "local", - "data-path": "None", - "ip": "localhost", - "port": 7010 - } - ``` - To connect to the running service on jusuf-cloud, set: - ``` - "server":{ - "user": "xxxxx", - "host": "xxxxxx", - "data-path": "xxxxx", - "ip": "xxx.xx.xx.xx", - "port": xxxx - } - ``` - Before continuing, you need to make sure that DCP server is running, either locally or on the cloud. See [DCP Server Installation & Launch](https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform/blob/main/src/server/README.md#using-pypi) for instructions on how to launch the server. **Note:** In order for this connection to succeed, you will need to have contacted the team developing DCP, so they can add your IP to the list of accepted requests. - -To make it easier for you we provide you with two config files, one works when running a local server and one for remote - just make sure you rename the config file you wish to use to ```config.cfg```. The defualt is local configuration. - - -2. **Launching the client** - -After setting your config simply run: - ``` - python dcp_client/main.py - ``` - -3. **Welcome window** - -The welcome window should have now popped up. - - - - Here you will need to select the directories which we will be using throughout the data centric workflow. The following directories need to be defined: - - * **Uncurated dataset path:** This folder is intended to store all images of your dataset. These images may be accompanied by corresponding segmentations. If present, segmentation files should share the same filename as their associated image, appended with a suffix as specified in 'setup/seg_name_string', defined in ```server/dcp_server/config.cfg``` (default: '_seg'). - * **Curation in progress path:(Optional)** Images for which the segmentation is a work in progress should be moved here. Each image in this folder can have one or multiple segmentations corresponding to it (by changing the filename of the segmentation in the napari layer list after editing it, see **Viewer**). If you do not want to use an intermediate working dir, you can skip setting a path to this directory (it is not required). No future functions affect this directory, it is only used to move to and from the uncurated and curated directories. - * **Curated dataset path:** This folder is intended to contain images along with their final segmentations. **Only** move images here when the segmentation is complete and finalised, you won't be able to change them after they have been moved here. These are then used for training your model. - -4. **Setting paths** - -After setting the paths for these three folders, you can click the **Start** button. If you have set the server configuration to the cloud, you will receive a message notifying you that your data will be uploaded to the cloud. Clik **Ok** to continue. - -5. **Data Overview** - -The main working window will appear next. This gives you an overview of the directories selected in the previous step along with three options: - - * **Generate Labels:** Click this button to generate labels for all images in the "Uncurated dataset" directory. This will call the ```segment_image``` service from the server - * **View image and fix label:** Click this button to launch your viewer. The napari software is used for visualising, and editing the images segmentations. See **Viewer** - * **Train Model:** Click this model to train your model on the images in the "Curated dataset" directory. This will call the ```train``` service from the server - ![Alt Text](https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform/blob/main/src/client/readme_figs/client_data_overview_window.png) - -6. **The viewer** - -In DCP, we use [napari](https://napari.org/stable) for viewing our images and makss, adding, editing or removing labels. An example of the viewer can be seen below. After adding or removing any objects and editing existing objects wherever necessary, there are two options available: -- Click the **Move to Curation in progress folder** if you are not 100% certain about the labels you have created. You can also click on the label in the labels layer and change the name. This will result in several label files being created in the *In progress folder*, which can be examined later on. **Note:** When changing the layer name in Napari, the user should rename it such that they add their initials or any other new info after _seg. E.g., if the labels of 1_seg.tiff have been changed in the Napari viewer, then the appropriate naming would for example be: 1_seg_CB.tiff and not 1_CB_seg.tiff. -- Click the **Move to Curated dataset folder** if you are certain that the labels you are now viewing are final and require no more curation. These images and labels will later be used for training the machine learning model, so make sure that you select this option only if you are certain about the labels. If several labels are displayed (opened from the 'Curation in progress' step), make sure to **click** on the single label in the labels layer list you wish to be moved to the *Curated data folder*. The other images will then be automatically deleted from this folder. - -![Alt Text](https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform/blob/main/src/client/readme_figs/client_napari_viewer.png) +### Installation for developers +Before starting, make sure you have navigated to ```data-centric-platform/src/client```. All future steps expect you are in the client directory. This installation has been tested using a conda environment with python version 3.9 on a mac local machine. In your dedicated environment run: +``` +pip install -e . +``` -### Data centric workflow [intended usage summary] -The intended usage of DCP would include the following: -1. Setting up configuration, run client (with server already running) and select data directories -2. Generate labels for data in *Uncurated data folder* -3. Visualise the resulting labels with the viewer and correct labels wherever necessary - once done move the image *Curated data folder*. Repeat this step for a couple of images until a few are placed into the *Curated data folder*. Depending on the qualitative evaluation of the label generation you might want to include fewer or more images, i.e. if the resulting masks require few edits, then few images will most likely be sufficient, whereas if many edits to the mask are required it is likely that more images are needed in the *Curated data folder*. You can always start with a small number and adjust later -4. Train the model with the images in the *Curated data folder* -6. Repeat steps 2-4 until you are satisfied with the masks generated for the remaining images in the *Uncurated data folder*. Every time the model is trained in step 4, the masks generated in step 2 should be of higher quality, until the model need not be trained any more - +#### Launch DCP client +Make sure the server is already running, either locally or remotely. Then, depending on the configuration, simply run: +``` +dcp-client --mode local/remote +``` - +## Want to know more? +Visit our [documentation](https://data-centric-platform.readthedocs.io/en/latest/dcp_client_installation.html) for more information and a step by step guide on how to run the client. diff --git a/src/client/dcp_client/__init__.py b/src/client/dcp_client/__init__.py index e69de29b..f4ffe44b 100644 --- a/src/client/dcp_client/__init__.py +++ b/src/client/dcp_client/__init__.py @@ -0,0 +1,42 @@ +""" +Overview of dcp_client Package +============================== + +The `dcp_client` package contains modules and subpackages for interacting with a server for model inference and training. It provides functionalities for managing GUI windows, handling image storage, and connecting to the server for model operations. + +Subpackages +------------ + +- **dcp_client.gui package**: Contains modules for GUI components. + + - **Submodules**: + + - ``dcp_client.gui.main_window``: Defines the main application window and associated event functions. + - ``dcp_client.gui.napari_window``: Manages the Napari window and its functionalities. + - ``dcp_client.gui.welcome_window``: Implements the welcome window and its interactions. + +- **dcp_client.utils package**: Contains utility modules for various tasks. + + - **Submodules**: + + - ``dcp_client.utils.bentoml_model``: Handles interactions with BentoML for model inference and training. + - ``dcp_client.utils.fsimagestorage``: Provides functions for managing images stored in the filesystem. + - ``dcp_client.utils.settings``: Defines initialization functions and settings. + - ``dcp_client.utils.sync_src_dst``: Implements data synchronization between source and destination. + - ``dcp_client.utils.utils``: Offers various utility functions for common tasks. + +Submodules +------------ + +- **dcp_client.app module**: Defines the core application class and related functionalities. + + - **Classes**: + + - ``dcp_client.app.Application``: Represents the main application and provides methods for image management, model interaction, and server connectivity. + - ``dcp_client.app.DataSync``: Abstract base class for data synchronization operations. + - ``dcp_client.app.ImageStorage``: Abstract base class for image storage operations. + - ``dcp_client.app.Model``: Abstract base class for model operations. + +This package structure allows for easy management of GUI components, image storage, model interactions, and server connectivity within the dcp_client application. + +""" diff --git a/src/client/dcp_client/app.py b/src/client/dcp_client/app.py index d46a4244..8ae4e8a9 100644 --- a/src/client/dcp_client/app.py +++ b/src/client/dcp_client/app.py @@ -11,7 +11,7 @@ class Model(ABC): @abstractmethod def run_train(self, path: str) -> None: pass - + @abstractmethod def run_inference(self, path: str) -> None: pass @@ -21,7 +21,7 @@ class DataSync(ABC): @abstractmethod def sync(self, src: str, dst: str, path: str) -> None: pass - + class ImageStorage(ABC): @abstractmethod @@ -35,22 +35,29 @@ def save_image(self, to_directory, cur_selected_img, img) -> None: def search_segs(self, img_directory, cur_selected_img): """Returns a list of full paths of segmentations for an image""" # Take all segmentations of the image from the current directory: - search_string = utils.get_path_stem(cur_selected_img) + '_seg' - seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + search_string = utils.get_path_stem(cur_selected_img) + "_seg" + seg_files = [ + file_name + for file_name in os.listdir(img_directory) + if ( + search_string == utils.get_path_stem(file_name) + or str(file_name).startswith(search_string) + ) + ] return seg_files class Application: def __init__( - self, + self, ml_model: Model, syncer: DataSync, image_storage: ImageStorage, server_ip: str, server_port: int, - eval_data_path: str = '', - train_data_path: str = '', - inprogr_data_path: str = '', + eval_data_path: str = "", + train_data_path: str = "", + inprogr_data_path: str = "", ): self.ml_model = ml_model self.syncer = syncer @@ -60,73 +67,90 @@ def __init__( self.eval_data_path = eval_data_path self.train_data_path = train_data_path self.inprogr_data_path = inprogr_data_path - self.cur_selected_img = '' - self.cur_selected_path = '' - self.seg_filepaths = [] + self.cur_selected_img = "" + self.cur_selected_path = "" + self.seg_filepaths = [] def upload_data_to_server(self): """ Uploads the train and eval data to the server. """ - success_f1, message1 = self.syncer.first_sync(path=self.train_data_path) - success_f2, message2 = self.syncer.first_sync(path=self.eval_data_path) + success_f1, message1 = self.syncer.first_sync(path=self.train_data_path) + success_f2, message2 = self.syncer.first_sync(path=self.eval_data_path) return success_f1, success_f2, message1, message2 def try_server_connection(self): """ Checks if the ml model is connected to server and attempts to connect if not. """ - connection_success = self.ml_model.connect(ip=self.server_ip, port=self.server_port) + connection_success = self.ml_model.connect( + ip=self.server_ip, port=self.server_port + ) return connection_success - + def run_train(self): - """ Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path """ - if not self.ml_model.is_connected and not self.try_server_connection(): + """Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path""" + if not self.ml_model.is_connected and not self.try_server_connection(): message_title = "Warning" message_text = "Connection could not be established. Please check if the server is running and try again." return message_text, message_title # if syncer.host name is None then local machine is used to train message_title = "Information" - if self.syncer.host_name=="local": + if self.syncer.host_name == "local": message_text = self.ml_model.run_train(self.train_data_path) else: - success_sync, srv_relative_path = self.syncer.sync(src='client', dst='server', path=self.train_data_path) + success_sync, srv_relative_path = self.syncer.sync( + src="client", dst="server", path=self.train_data_path + ) # make sure syncing of folders was successful - if success_sync=="Success": message_text = self.ml_model.run_train(srv_relative_path) - else: message_text = None - if message_text is None: + if success_sync == "Success": + message_text = self.ml_model.run_train(srv_relative_path) + else: + message_text = None + if message_text is None: message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider." message_title = "Error" return message_text, message_title - + def run_inference(self): - """ Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path """ - if not self.ml_model.is_connected and not self.try_server_connection(): + """Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path""" + if not self.ml_model.is_connected and not self.try_server_connection(): message_title = "Warning" message_text = "Connection could not be established. Please check if the server is running and try again." return message_text, message_title - - if self.syncer.host_name=="local": + + if self.syncer.host_name == "local": # model serving directly from local - list_of_files_not_suported = self.ml_model.run_inference(self.eval_data_path) - success_sync = "Success" + list_of_files_not_suported = self.ml_model.run_inference( + self.eval_data_path + ) + success_sync = "Success" else: # sync data so that server gets updated files in client - e.g. if file was moved to curated srv_relative_path = utils.get_relative_path(self.eval_data_path) - success_sync, _ = self.syncer.sync(src='client', dst='server', path=self.eval_data_path) + success_sync, _ = self.syncer.sync( + src="client", dst="server", path=self.eval_data_path + ) # model serving from server list_of_files_not_suported = self.ml_model.run_inference(srv_relative_path) - # sync data so that client gets new masks - success_sync, _ = self.syncer.sync(src='server', dst='client', path=self.eval_data_path) + # sync data so that client gets new masks + success_sync, _ = self.syncer.sync( + src="server", dst="client", path=self.eval_data_path + ) # check if serving could not be performed for some files and prepare message - if list_of_files_not_suported is None or success_sync=="Error": + if list_of_files_not_suported is None or success_sync == "Error": message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider." message_title = "Error" else: list_of_files_not_suported = list(list_of_files_not_suported) if len(list_of_files_not_suported) > 0: - message_text = "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \ - Currently supported image file formats are: " + ", ".join(settings.accepted_types)+ ". The files that were not supported are: " + ", ".join(list_of_files_not_suported) + message_text = ( + "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \ + Currently supported image file formats are: " + + ", ".join(settings.accepted_types) + + ". The files that were not supported are: " + + ", ".join(list_of_files_not_suported) + ) message_title = "Warning" else: message_text = "Success! Masks generated for all images" @@ -134,31 +158,69 @@ def run_inference(self): return message_text, message_title def load_image(self, image_name=None): + """ + Loads an image from the file system storage. + + :param str image_name: The name of the image file to load. + If not provided, loads the currently selected image. + + :return: The loaded image. + :rtype: numpy.ndarray + + """ if image_name is None: - return self.fs_image_storage.load_image(self.cur_selected_path, self.cur_selected_img) - else: return self.fs_image_storage.load_image(self.cur_selected_path, image_name) - + return self.fs_image_storage.load_image( + self.cur_selected_path, self.cur_selected_img + ) + else: + return self.fs_image_storage.load_image(self.cur_selected_path, image_name) + def search_segs(self): - """ Searches in cur_selected_path for all possible segmentation files associated to cur_selected_img. - These files should have a _seg extension to the cur_selected_img filename. """ - self.seg_filepaths = self.fs_image_storage.search_segs(self.cur_selected_path, self.cur_selected_img) - + """Searches in cur_selected_path for all possible segmentation files associated to cur_selected_img. + These files should have a _seg extension to the cur_selected_img filename.""" + self.seg_filepaths = self.fs_image_storage.search_segs( + self.cur_selected_path, self.cur_selected_img + ) + def save_image(self, dst_directory, image_name, img): - """ Saves img array image in the dst_directory with filename cur_selected_img """ + """Saves img array image in the dst_directory with filename cur_selected_img + + :param dst_directory: The destination directory where the image will be saved. + :type dst_directory: str + :param image_name: The name of the image file. + :type image_name: str + :param img: The image that will be saved. + :type img: numpy.ndarray + """ self.fs_image_storage.save_image(dst_directory, image_name, img) def move_images(self, dst_directory, move_segs=False): - """ Moves cur_selected_img image from the current directory to the dst_directory """ - #if image_name is None: - self.fs_image_storage.move_image(self.cur_selected_path, dst_directory, self.cur_selected_img) + """ + Moves cur_selected_img image from the current directory to the dst_directory. + + :param dst_directory: The destination directory where the images will be moved. + :type dst_directory: str + + :param move_segs: If True, moves the corresponding segmentation along with the image. Default is False. + :type move_segs: bool + + """ + # if image_name is None: + self.fs_image_storage.move_image( + self.cur_selected_path, dst_directory, self.cur_selected_img + ) if move_segs: for seg_name in self.seg_filepaths: - self.fs_image_storage.move_image(self.cur_selected_path, dst_directory, seg_name) + self.fs_image_storage.move_image( + self.cur_selected_path, dst_directory, seg_name + ) def delete_images(self, image_names): - """ If image_name in the image_names list exists in the current directory it is deleted """ + """If image_name in the image_names list exists in the current directory it is deleted. + + :param image_names: A list of image names to be deleted. + :type image_names: list[str] + """ for image_name in image_names: - if os.path.exists(os.path.join(self.cur_selected_path, image_name)): + if os.path.exists(os.path.join(self.cur_selected_path, image_name)): self.fs_image_storage.delete_image(self.cur_selected_path, image_name) - - diff --git a/src/client/dcp_client/config.cfg b/src/client/dcp_client/config.yaml similarity index 100% rename from src/client/dcp_client/config.cfg rename to src/client/dcp_client/config.yaml diff --git a/src/client/dcp_client/config_remote.cfg b/src/client/dcp_client/config_remote.yaml similarity index 100% rename from src/client/dcp_client/config_remote.cfg rename to src/client/dcp_client/config_remote.yaml diff --git a/src/client/dcp_client/gui/__init__.py b/src/client/dcp_client/gui/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/client/dcp_client/gui/_my_widget.py b/src/client/dcp_client/gui/_my_widget.py index 38a37449..acf54b61 100644 --- a/src/client/dcp_client/gui/_my_widget.py +++ b/src/client/dcp_client/gui/_my_widget.py @@ -1,19 +1,43 @@ from PyQt5.QtWidgets import QWidget, QMessageBox from PyQt5.QtCore import QTimer + class MyWidget(QWidget): + """ + This class represents a custom widget. + """ msg = None - sim = False # will be used for testing to simulate user click + sim = False # will be used for testing to simulate user click + + def create_warning_box( + self, + message_text: str = " ", + message_title: str = "Information", + add_cancel_btn: bool = False, + custom_dialog=None, + ) -> None: + """Creates a warning box with the specified message and options. - def create_warning_box(self, message_text: str=" ", message_title: str="Information", add_cancel_btn: bool=False, custom_dialog=None) -> None: - #setup box - if custom_dialog is not None: self.msg = custom_dialog - else: self.msg = QMessageBox() + :param message_text: The text to be displayed in the message box. + :type message_text: str + :param message_title: The title of the message box. Default is "Information". + :type message_title: str + :param add_cancel_btn: Flag indicating whether to add a cancel button to the message box. Default is False. + :type add_cancel_btn: bool + :param custom_dialog: An optional custom dialog to use instead of creating a new QMessageBox instance. Default is None. + :type custom_dialog: Any + :return: None + """ + # setup box + if custom_dialog is not None: + self.msg = custom_dialog + else: + self.msg = QMessageBox() - if message_title=="Warning": + if message_title == "Warning": message_type = QMessageBox.Warning - elif message_title=="Error": + elif message_title == "Error": message_type = QMessageBox.Critical else: message_type = QMessageBox.Information @@ -24,12 +48,16 @@ def create_warning_box(self, message_text: str=" ", message_title: str="Informat if add_cancel_btn: self.msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) # simulate button click if specified - workaround used for testing - if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked) + if self.sim: + QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked) else: self.msg.setStandardButtons(QMessageBox.Ok) # simulate button click if specified - workaround used for testing - if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked) + if self.sim: + QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked) # return if user clicks Ok and False otherwise usr_response = self.msg.exec() - if usr_response == QMessageBox.Ok: return True - else: return False \ No newline at end of file + if usr_response == QMessageBox.Ok: + return True + else: + return False diff --git a/src/client/dcp_client/gui/main_window.py b/src/client/dcp_client/gui/main_window.py index 92ffa79b..cc4b4897 100644 --- a/src/client/dcp_client/gui/main_window.py +++ b/src/client/dcp_client/gui/main_window.py @@ -24,19 +24,41 @@ class WorkerThread(QThread): - ''' Worker thread for displaying Pulse ProgressBar during model serving ''' + """ + Worker thread for displaying Pulse ProgressBar during model serving. + + """ + task_finished = pyqtSignal(tuple) - def __init__(self, app: Application, task: str = None, parent = None,): + + def __init__( + self, + app: Application, + task: str = None, + parent=None, + ): + """ + Initialize the WorkerThread. + + :param app: The Application instance. See dcp_client.app for more information. + :type app: dcp_client.app.Application + :param task: The task performed by the worker thread. Can be 'inference' or 'train'. + :type task: str, optional + :param parent: The parent QObject (default is None). + """ super().__init__(parent) self.app = app self.task = task - def run(self): - ''' Once run_inference the tuple of (message_text, message_title) will be returned to on_finished''' + def run(self) -> None: + """ + Once run_inference or run_train is executed, the tuple of + (message_text, message_title) will be returned to on_finished. + """ try: - if self.task == 'inference': + if self.task == "inference": message_text, message_title = self.app.run_inference() - elif self.task == 'train': + elif self.task == "train": message_text, message_title = self.app.run_train() else: message_text, message_title = "Unknown task", "Error" @@ -195,35 +217,46 @@ def paint(self, painter, option, index): pixmap = index.data(Qt.DecorationRole) painter.drawPixmap(option.rect, pixmap) + class MainWindow(MyWidget): - ''' + """ Main Window Widget object. - Opens the main window of the app where selected images in both directories are listed. - User can view the images, train the mdoel to get the labels, and visualise the result. + Opens the main window of the app where selected images in both directories are listed. + User can view the images, train the model to get the labels, and visualise the result. + :param eval_data_path: Chosen path to images without labeles, selected by the user in the WelcomeWindow :type eval_data_path: string :param train_data_path: Chosen path to images with labeles, selected by the user in the WelcomeWindow :type train_data_path: string - ''' + """ + def __init__(self, app: Application) -> None: + """ + Initializes the MainWindow. + + :param app: The Application instance. See dcp_client.app for more information. + :type app: dcp_client.app.Application + :param app.eval_data_path: Chosen path to images without labels, selected by the user in the WelcomeWindow. + :type app.eval_data_path: str + :param app.train_data_path: Chosen path to images with labels, selected by the user in the WelcomeWindow. + :type app.train_data_path: str + """ - def __init__(self, app: Application): super().__init__() self.app = app self.title = "Data Overview" self.worker_thread = None self.main_window() - - def main_window(self): - ''' - Sets up the GUI - ''' + + def main_window(self) -> None: + """Sets up the GUI""" self.setWindowTitle(self.title) self.resize(1000, 700) self.setStyleSheet("background-color: #f3f3f3;") + main_layout = QVBoxLayout() - dir_layout = QHBoxLayout() - + dir_layout = QHBoxLayout() + self.uncurated_layout = QVBoxLayout() self.inprogress_layout = QVBoxLayout() self.curated_layout = QVBoxLayout() @@ -231,6 +264,7 @@ def main_window(self): self.eval_dir_layout = QVBoxLayout() self.eval_dir_layout.setContentsMargins(0,0,0,0) + self.label_eval = QLabel(self) self.label_eval.setText("Uncurated Dataset") self.label_eval.setMinimumHeight(50) @@ -267,10 +301,12 @@ def main_window(self): for i in range(1,4): self.list_view_eval.hideColumn(i) - #self.list_view_eval.setFixedSize(600, 600) - self.list_view_eval.setRootIndex(model_eval.setRootPath(self.app.eval_data_path)) + # self.list_view_eval.setFixedSize(600, 600) + self.list_view_eval.setRootIndex( + model_eval.setRootPath(self.app.eval_data_path) + ) self.list_view_eval.clicked.connect(self.on_item_eval_selected) - + self.eval_dir_layout.addWidget(self.list_view_eval) self.uncurated_layout.addLayout(self.eval_dir_layout) @@ -310,8 +346,9 @@ def main_window(self): self.inprogr_dir_layout.addWidget(self.label_inprogr) # add in progress dir list model_inprogr = MyQFileSystemModel(app=self.app) - + #self.list_view = QListView(self) + self.list_view_inprogr = QTreeView(self) self.list_view_inprogr.setToolTip("Select an image, click it, then press Enter") # self.list_view_inprogr.setIconSize(QSize(50,50)) @@ -325,8 +362,10 @@ def main_window(self): for i in range(1,4): self.list_view_inprogr.hideColumn(i) - #self.list_view_inprogr.setFixedSize(600, 600) - self.list_view_inprogr.setRootIndex(model_inprogr.setRootPath(self.app.inprogr_data_path)) + # self.list_view_inprogr.setFixedSize(600, 600) + self.list_view_inprogr.setRootIndex( + model_inprogr.setRootPath(self.app.inprogr_data_path) + ) self.list_view_inprogr.clicked.connect(self.on_item_inprogr_selected) self.inprogr_dir_layout.addWidget(self.list_view_inprogr) self.inprogress_layout.addLayout(self.inprogr_dir_layout) @@ -347,8 +386,8 @@ def main_window(self): dir_layout.addLayout(self.inprogress_layout) # Curated layout - self.train_dir_layout = QVBoxLayout() - self.train_dir_layout.setContentsMargins(0,0,0,0) + self.train_dir_layout = QVBoxLayout() + self.train_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_train = QLabel(self) self.label_train.setText("Curated dataset") self.label_train.setMinimumHeight(50) @@ -362,6 +401,7 @@ def main_window(self): model_train = MyQFileSystemModel(app=self.app) # model_train.setNameFilters(["*_seg.tiff"]) #self.list_view = QListView(self) + self.list_view_train = QTreeView(self) self.list_view_train.setToolTip("Select an image, click it, then press Enter") # self.list_view_train.setIconSize(QSize(50,50)) @@ -375,12 +415,14 @@ def main_window(self): for i in range(1,4): self.list_view_train.hideColumn(i) - #self.list_view_train.setFixedSize(600, 600) - self.list_view_train.setRootIndex(model_train.setRootPath(self.app.train_data_path)) + # self.list_view_train.setFixedSize(600, 600) + self.list_view_train.setRootIndex( + model_train.setRootPath(self.app.train_data_path) + ) self.list_view_train.clicked.connect(self.on_item_train_selected) self.train_dir_layout.addWidget(self.list_view_train) self.curated_layout.addLayout(self.train_dir_layout) - + self.train_button = QPushButton("Train Model", self) self.train_button.setStyleSheet( """QPushButton @@ -395,6 +437,7 @@ def main_window(self): "QPushButton:pressed { background-color: #7bc432; }" ) self.train_button.clicked.connect(self.on_train_button_clicked) # add selected image + self.curated_layout.addWidget(self.train_button, alignment=Qt.AlignCenter) dir_layout.addLayout(self.curated_layout) @@ -402,58 +445,68 @@ def main_window(self): # add progress bar progress_layout = QHBoxLayout() - progress_layout.addStretch(1) + progress_layout.addStretch(1) self.progress_bar = QProgressBar(self) self.progress_bar.setMinimumWidth(1000) self.progress_bar.setAlignment(Qt.AlignCenter) self.progress_bar.setRange(0,1) + progress_layout.addWidget(self.progress_bar) main_layout.addLayout(progress_layout) self.setLayout(main_layout) self.show() - def on_item_train_selected(self, item): - ''' - Is called once an image is selected in the 'curated dataset' folder - ''' + def on_item_train_selected(self, item: QModelIndex) -> None: + """ + Is called once an image is selected in the 'curated dataset' folder. + + :param item: The selected item from the 'curated dataset' folder. + :type item: QModelIndex + """ self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.train_data_path - def on_item_eval_selected(self, item): - ''' - Is called once an image is selected in the 'uncurated dataset' folder - ''' + def on_item_eval_selected(self, item: QModelIndex) -> None: + """ + Is called once an image is selected in the 'uncurated dataset' folder. + + :param item: The selected item from the 'uncurated dataset' folder. + :type item: QModelIndex + """ self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.eval_data_path - def on_item_inprogr_selected(self, item): - ''' - Is called once an image is selected in the 'in progress' folder - ''' + def on_item_inprogr_selected(self, item: QModelIndex) -> None: + """ + Is called once an image is selected in the 'in progress' folder. + + :param item: The selected item from the 'in progress' folder. + :type item: QModelIndex + """ self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.inprogr_data_path - def on_train_button_clicked(self): - ''' - Is called once user clicks the "Train Model" button - ''' + def on_train_button_clicked(self) -> None: + """ + Is called once user clicks the "Train Model" button. + """ self.train_button.setEnabled(False) - self.progress_bar.setRange(0,0) + self.progress_bar.setRange(0, 0) # initialise the worker thread - self.worker_thread = WorkerThread(app=self.app, task='train') + self.worker_thread = WorkerThread(app=self.app, task="train") self.worker_thread.task_finished.connect(self.on_finished) # start the worker thread to train self.worker_thread.start() - def on_run_inference_button_clicked(self): - ''' - Is called once user clicks the "Generate Labels" button - ''' + def on_run_inference_button_clicked(self) -> None: + """ + Is called once user clicks the "Generate Labels" button. + """ self.inference_button.setEnabled(False) - self.progress_bar.setRange(0,0) + self.progress_bar.setRange(0, 0) # initialise the worker thread - self.worker_thread = WorkerThread(app=self.app, task='inference') + self.worker_thread = WorkerThread(app=self.app, task="inference") self.worker_thread.task_finished.connect(self.on_finished) # start the worker thread to run inference self.worker_thread.start() @@ -473,12 +526,15 @@ def on_launch_napari_button_clicked(self): message_text = f"An error occurred while opening the Napari window: {str(e)}" _ = self.create_warning_box(message_text, message_title="Error") - def on_finished(self, result): - ''' - Is called once the worker thread emits the on finished signal - ''' + def on_finished(self, result: tuple) -> None: + """ + Is called once the worker thread emits the on finished signal. + + :param result: The result emitted by the worker thread. See return type of WorkerThread.run + :type result: tuple + """ # Stop the pulsation - self.progress_bar.setRange(0,1) + self.progress_bar.setRange(0, 1) # Display message of result message_text, message_title = result _ = self.create_warning_box(message_text, message_title) @@ -500,20 +556,21 @@ def on_finished(self, result): from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils import settings from dcp_client.utils.sync_src_dst import DataRSync + settings.init() image_storage = FilesystemImageStorage() ml_model = BentomlModel() - data_sync = DataRSync(user_name="local", - host_name="local", - server_repo_path=None) + data_sync = DataRSync(user_name="local", host_name="local", server_repo_path=None) app = QApplication(sys.argv) - app_ = Application(ml_model=ml_model, - syncer=data_sync, - image_storage=image_storage, - server_ip='0.0.0.0', - server_port=7010, - eval_data_path='data', - train_data_path='', # set path - inprogr_data_path='') # set path + app_ = Application( + ml_model=ml_model, + syncer=data_sync, + image_storage=image_storage, + server_ip="0.0.0.0", + server_port=7010, + eval_data_path="data", + train_data_path="", # set path + inprogr_data_path="", + ) # set path window = MainWindow(app=app_) - sys.exit(app.exec()) \ No newline at end of file + sys.exit(app.exec()) diff --git a/src/client/dcp_client/gui/napari_window.py b/src/client/dcp_client/gui/napari_window.py index c386eac2..2b3c065e 100644 --- a/src/client/dcp_client/gui/napari_window.py +++ b/src/client/dcp_client/gui/napari_window.py @@ -1,24 +1,32 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from copy import deepcopy from qtpy.QtWidgets import QPushButton, QComboBox, QLabel, QGridLayout from qtpy.QtCore import Qt import napari +import numpy as np if TYPE_CHECKING: from dcp_client.app import Application -from dcp_client.utils.utils import get_path_stem, check_equal_arrays, Compute4Mask +from dcp_client.utils.utils import get_path_stem, check_equal_arrays +from dcp_client.utils.compute4mask import Compute4Mask from dcp_client.gui._my_widget import MyWidget + class NapariWindow(MyWidget): - '''Napari Window Widget object. + """Napari Window Widget object. Opens the napari image viewer to view and fix the labeles. - :param app: - :type Application - ''' + :param app: The Application instance. + :type app: Application + """ + + def __init__(self, app: Application) -> None: + """Initializes the NapariWindow. - def __init__(self, app: Application): + :param app: The Application instance. + :type app: Application + """ super().__init__() self.app = app self.setWindowTitle("napari viewer") @@ -33,17 +41,24 @@ def __init__(self, app: Application): self.viewer = napari.Viewer(show=False) self.viewer.add_image(img, name=get_path_stem(self.app.cur_selected_img)) for seg_file in self.seg_files: - self.viewer.add_labels(self.app.load_image(seg_file), name=get_path_stem(seg_file)) + self.viewer.add_labels( + self.app.load_image(seg_file), name=get_path_stem(seg_file) + ) main_window = self.viewer.window._qt_window layout = QGridLayout() layout.addWidget(main_window, 0, 0, 1, 4) # select the first seg as the currently selected layer if there are any segs - if len(self.seg_files): + if ( + len(self.seg_files) + and len(self.viewer.layers[get_path_stem(self.seg_files[0])].data.shape) > 2 + ): self.cur_selected_seg = self.viewer.layers.selection.active.name self.layer = self.viewer.layers[self.cur_selected_seg] - self.viewer.layers.selection.events.changed.connect(self.on_seg_channel_changed) + self.viewer.layers.selection.events.changed.connect( + self.on_seg_channel_changed + ) # set first mask as active by default self.active_mask_index = 0 self.viewer.dims.events.current_step.connect(self.axis_changed) @@ -54,17 +69,27 @@ def __init__(self, app: Application): for seg_file in self.seg_files: layer_name = get_path_stem(seg_file) # get unique instance labels for each seg - self.original_instance_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[0]) - self.original_class_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[1]) + self.original_instance_mask[layer_name] = deepcopy( + self.viewer.layers[layer_name].data[0] + ) + self.original_class_mask[layer_name] = deepcopy( + self.viewer.layers[layer_name].data[1] + ) # compute unique instance ids - self.instances[layer_name] = Compute4Mask.get_unique_objects(self.original_instance_mask[layer_name]) + self.instances[layer_name] = Compute4Mask.get_unique_objects( + self.original_instance_mask[layer_name] + ) # remove border from class mask - self.contours_mask[layer_name] = Compute4Mask.get_contours(self.original_instance_mask[layer_name], contours_level=0.8) - self.viewer.layers[layer_name].data[1][self.contours_mask[layer_name]!=0] = 0 - + self.contours_mask[layer_name] = Compute4Mask.get_contours( + self.original_instance_mask[layer_name], contours_level=0.8 + ) + self.viewer.layers[layer_name].data[1][ + self.contours_mask[layer_name] != 0 + ] = 0 + self.qctrl = self.viewer.window.qt_viewer.controls.widgets[self.layer] - if self.layer.data.shape[0] >= 2: + if len(self.layer.data.shape) > 2: # User hint message_label = QLabel('Choose an active mask') message_label.setStyleSheet( @@ -77,16 +102,16 @@ def __init__(self, app: Application): padding: 8px 16px;""" ) - - message_label.setAlignment(Qt.AlignRight) layout.addWidget(message_label, 1, 0) - + # Drop list to choose which is an active mask self.mask_choice_dropdown = QComboBox() self.mask_choice_dropdown.setEnabled(False) - self.mask_choice_dropdown.addItem('Instance Segmentation Mask', userData=0) - self.mask_choice_dropdown.addItem('Labels Mask', userData=1) + self.mask_choice_dropdown.addItem( + "Instance Segmentation Mask", userData=0 + ) + self.mask_choice_dropdown.addItem("Labels Mask", userData=1) layout.addWidget(self.mask_choice_dropdown, 1, 1) # when user has chosen the mask, we don't want to change it anymore to avoid errors @@ -128,7 +153,9 @@ def __init__(self, app: Application): ) layout.addWidget(add_to_inprogress_button, 2, 0, 1, 2) - add_to_inprogress_button.clicked.connect(self.on_add_to_inprogress_button_clicked) + add_to_inprogress_button.clicked.connect( + self.on_add_to_inprogress_button_clicked + ) add_to_curated_button = QPushButton('Move to \'Curated dataset\' folder') add_to_curated_button.setStyleSheet( @@ -144,12 +171,13 @@ def __init__(self, app: Application): "QPushButton:pressed { background-color: #006FBA; }" ) + layout.addWidget(add_to_curated_button, 2, 2, 1, 2) add_to_curated_button.clicked.connect(self.on_add_to_curated_button_clicked) self.setLayout(layout) - def set_editable_mask(self): + def set_editable_mask(self) -> None: """ This function is not implemented. In theory the use can choose between which mask to edit. Currently painting and erasing is only possible on instance mask and in the class mask only @@ -157,108 +185,139 @@ def set_editable_mask(self): """ pass - def on_seg_channel_changed(self, event): + def on_seg_channel_changed(self, event) -> None: """ Is triggered each time the user selects a different layer in the viewer. """ if (act := self.viewer.layers.selection.active) is not None: # updater cur_selected_seg with the new selection from the user self.cur_selected_seg = act.name - if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image: pass + if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image: + pass # set self.layer to new selection from user - elif self.layer is not None: self.layer = self.viewer.layers[self.cur_selected_seg] - else: pass - - def axis_changed(self, event): + elif self.layer is not None: + self.layer = self.viewer.layers[self.cur_selected_seg] + else: + pass + + def axis_changed(self, event) -> None: """ - Is triggered each time the user switches the viewer between the mask channels. At this point the class mask + Is triggered each time the user switches the viewer between the mask channels. At this point the class mask needs to be updated according to the changes made tot the instance segmentation mask. """ self.active_mask_index = self.viewer.dims.current_step[0] masks = deepcopy(self.layer.data) # if user has switched to the instance mask - if self.active_mask_index==0: + if self.active_mask_index == 0: class_mask_with_contours = Compute4Mask.add_contour(masks[1], masks[0]) - if not check_equal_arrays(class_mask_with_contours.astype(bool), self.original_class_mask[self.cur_selected_seg].astype(bool)): + if not check_equal_arrays( + class_mask_with_contours.astype(bool), + self.original_class_mask[self.cur_selected_seg].astype(bool), + ): self.update_instance_mask(masks[0], masks[1]) self.switch_to_instance_mask() # else if user has switched to the class mask - elif self.active_mask_index==1: - if not check_equal_arrays(masks[0], self.original_instance_mask[self.cur_selected_seg]): + elif self.active_mask_index == 1: + if not check_equal_arrays( + masks[0], self.original_instance_mask[self.cur_selected_seg] + ): self.update_labels_mask(masks[0]) self.switch_to_labels_mask() - def switch_to_instance_mask(self): + def switch_to_instance_mask(self) -> None: """ - Switch the application to the active mask mode by enabling 'paint_button', 'erase_button' + Switch the application to the active mask mode by enabling 'paint_button', 'erase_button' and 'fill_button'. """ self.switch_controls("paint_button", True) self.switch_controls("erase_button", True) self.switch_controls("fill_button", True) - def switch_to_labels_mask(self): + def switch_to_labels_mask(self) -> None: """ Switch the application to non-active mask mode by enabling 'fill_button' and disabling 'paint_button' and 'erase_button'. """ if self.cur_selected_seg in [layer.name for layer in self.viewer.layers]: - self.viewer.layers[self.cur_selected_seg].mode = 'pan_zoom' - info_message_paint = "Painting objects is only possible in the instance layer for now." - info_message_erase = "Erasing objects is only possible in the instance layer for now." + self.viewer.layers[self.cur_selected_seg].mode = "pan_zoom" + info_message_paint = ( + "Painting objects is only possible in the instance layer for now." + ) + info_message_erase = ( + "Erasing objects is only possible in the instance layer for now." + ) self.switch_controls("paint_button", False, info_message_paint) self.switch_controls("erase_button", False, info_message_erase) - self.switch_controls("fill_button", True) + self.switch_controls("fill_button", True) - def update_labels_mask(self, instance_mask): - """ - If the instance mask has changed since the last switch between channels the class mask needs to be updated accordingly. - - Parameters: - - instance_mask (numpy.ndarray): The updated instance mask, changed by the user. - - labels_mask (numpy.ndarray): The existing labels mask, which needs to be updated. + def update_labels_mask(self, instance_mask: np.ndarray) -> None: + """Updates the class mask based on changes in the instance mask. + + If the instance mask has changed since the last switch between channels, the class mask needs to be updated accordingly. + + :param instance_mask: The updated instance mask, changed by the user. + :type instance_mask: numpy.ndarray + :return: None """ - self.original_class_mask[self.cur_selected_seg] = Compute4Mask.compute_new_labels_mask(self.original_class_mask[self.cur_selected_seg], - instance_mask, - self.original_instance_mask[self.cur_selected_seg], - self.instances[self.cur_selected_seg]) + self.original_class_mask[self.cur_selected_seg] = ( + Compute4Mask.compute_new_labels_mask( + self.original_class_mask[self.cur_selected_seg], + instance_mask, + self.original_instance_mask[self.cur_selected_seg], + self.instances[self.cur_selected_seg], + ) + ) # update original instance mask and instances self.original_instance_mask[self.cur_selected_seg] = instance_mask - self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects( + self.original_instance_mask[self.cur_selected_seg] + ) # compute contours to remove from class mask visualisation - self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours(instance_mask, contours_level=0.8) + self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours( + instance_mask, contours_level=0.8 + ) vis_labels_mask = deepcopy(self.original_class_mask[self.cur_selected_seg]) - vis_labels_mask[self.contours_mask[self.cur_selected_seg]!=0] = 0 + vis_labels_mask[self.contours_mask[self.cur_selected_seg] != 0] = 0 # update the viewer self.layer.data[1] = vis_labels_mask self.layer.refresh() - def update_instance_mask(self, instance_mask, labels_mask): - """ - If the labels mask has changed **only if an object has been removed** the instance mask is updated. - - Parameters: - - instance_mask (numpy.ndarray): The existing instance mask, which needs to be updated. - - labels_mask (numpy.ndarray): The updated labels mask, changed by the user. + def update_instance_mask( + self, instance_mask: np.ndarray, labels_mask: np.ndarray + ) -> None: + """Updates the instance mask based on changes in the labels mask. + + If the labels mask has changed, but only if an object has been removed, the instance mask is updated accordingly. + + :param instance_mask: The existing instance mask, which needs to be updated. + :type instance_mask: numpy.ndarray + :param labels_mask: The updated labels mask, changed by the user. + :type labels_mask: numpy.ndarray """ # add contours back to labels mask labels_mask = Compute4Mask.add_contour(labels_mask, instance_mask) # and compute the updated instance mask - self.original_instance_mask[self.cur_selected_seg] = Compute4Mask.compute_new_instance_mask(labels_mask, - instance_mask) - self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.original_instance_mask[self.cur_selected_seg] = ( + Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask) + ) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects( + self.original_instance_mask[self.cur_selected_seg] + ) self.original_class_mask[self.cur_selected_seg] = labels_mask # update the viewer self.layer.data[0] = self.original_instance_mask[self.cur_selected_seg] self.layer.refresh() - def switch_controls(self, target_widget, status: bool, info_message=None): - """ - Enable or disable a specific widget. - - Parameters: - - target_widget (str): The name of the widget to be controlled within the QCtrl object. - - status (bool): If True, the widget will be enabled; if False, it will be disabled. - - info_message (str or None): Optionally add an info message when hovering over some widget. + def switch_controls( + self, target_widget: str, status: bool, info_message: Optional[str] = None + ) -> None: + """Enables or disables a specific widget. + + :param target_widget: The name of the widget to be controlled within the QCtrl object. + :type target_widget: str + :param status: If True, the widget will be enabled; if False, it will be disabled. + :type status: bool + :param info_message: Optionally add an info message when hovering over some widget. Default is None. + :type info_message: str or None """ try: getattr(self.qctrl, target_widget).setEnabled(status) @@ -267,70 +326,76 @@ def switch_controls(self, target_widget, status: bool, info_message=None): except: pass - def on_add_to_curated_button_clicked(self): - ''' - Defines what happens when the "Move to curated dataset folder" button is clicked. - ''' - if self.app.cur_selected_path == str(self.app.train_data_path): - message_text = "Image is already in the \'Curated data\' folder and should not be changed again" + def on_add_to_curated_button_clicked(self) -> None: + """Defines what happens when the "Move to curated dataset folder" button is clicked.""" + if self.app.cur_selected_path == str(self.app.train_data_path): + message_text = "Image is already in the 'Curated data' folder and should not be changed again" _ = self.create_warning_box(message_text, message_title="Warning") return - + # take the name of the currently selected layer (by the user) seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in seg_name_to_save: + if "_seg" not in seg_name_to_save: message_text = ( - "Please select the segmenation you wish to save from the layer list." - "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." ) _ = self.create_warning_box(message_text, message_title="Warning") return - + # Save the (changed) seg seg = self.viewer.layers[seg_name_to_save].data seg[1] = Compute4Mask.add_contour(seg[1], seg[0]) - annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(seg) + annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = ( + Compute4Mask.assert_consistent_labels(seg) + ) if annot_error: - message_text = ("There seems to be a problem with your mask. We expect each object to be a connected component. For object(s) with ID(s) \n" - +str(faulty_ids_annot)+"\n" - "more than one connected component was found. Please go back and fix this.") + message_text = ( + "There seems to be a problem with your mask. We expect each object to be a connected component. For object(s) with ID(s) \n" + + str(faulty_ids_annot) + + "\n" + "more than one connected component was found. Please go back and fix this." + ) self.create_warning_box(message_text, "Warning") elif mask_mismatch_error: - message_text = ("There seems to be a mismatch between your class and instance masks for object(s) with ID(s) \n" - +str(faulty_ids_missmatch)+"\n" - "This should not occur and will cause a problem later during model training. Please go back and check.") + message_text = ( + "There seems to be a mismatch between your class and instance masks for object(s) with ID(s) \n" + + str(faulty_ids_missmatch) + + "\n" + "This should not occur and will cause a problem later during model training. Please go back and check." + ) self.create_warning_box(message_text, "Warning") - else: + else: # Move original image self.app.move_images(self.app.train_data_path) - self.app.save_image(self.app.train_data_path, seg_name_to_save+'.tiff', seg) + self.app.save_image( + self.app.train_data_path, seg_name_to_save + ".tiff", seg + ) # We remove seg from the current directory if it exists (both eval and inprogr allowed) self.app.delete_images(self.seg_files) - # TODO Create the Archive folder for the rest? Or move them as well? + # TODO Create the Archive folder for the rest? Or move them as well? self.viewer.close() self.close() - def on_add_to_inprogress_button_clicked(self): - ''' - Defines what happens when the "Move to curation in progress folder" button is clicked. - ''' + def on_add_to_inprogress_button_clicked(self) -> None: + """Defines what happens when the "Move to curation in progress folder" button is clicked.""" # TODO: Do we allow this? What if they moved it by mistake? User can always manually move from their folders?) if self.app.cur_selected_path == str(self.app.train_data_path): - message_text = "Images from '\Curated data'\ folder can not be moved back to \'Curatation in progress\' folder." + message_text = "Images from '\Curated data'\ folder can not be moved back to 'Curatation in progress' folder." _ = self.create_warning_box(message_text, message_title="Warning") return - + # take the name of the currently selected layer (by the user) seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in seg_name_to_save: + if "_seg" not in seg_name_to_save: message_text = ( - "Please select the segmenation you wish to save from the layer list." - "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." ) _ = self.create_warning_box(message_text, message_title="Warning") return @@ -340,8 +405,7 @@ def on_add_to_inprogress_button_clicked(self): # Save the (changed) seg - this will overwrite existing seg if seg name hasn't been changed in viewer seg = self.viewer.layers[seg_name_to_save].data seg[1] = Compute4Mask.add_contour(seg[1], seg[0]) - self.app.save_image(self.app.inprogr_data_path, seg_name_to_save+'.tiff', seg) - + self.app.save_image(self.app.inprogr_data_path, seg_name_to_save + ".tiff", seg) + self.viewer.close() self.close() - \ No newline at end of file diff --git a/src/client/dcp_client/gui/welcome_window.py b/src/client/dcp_client/gui/welcome_window.py index d9dc8211..fe112f37 100644 --- a/src/client/dcp_client/gui/welcome_window.py +++ b/src/client/dcp_client/gui/welcome_window.py @@ -1,23 +1,37 @@ from __future__ import annotations from typing import TYPE_CHECKING -from qtpy.QtWidgets import QPushButton, QVBoxLayout, QHBoxLayout, QLabel, QFileDialog, QLineEdit +from qtpy.QtWidgets import ( + QPushButton, + QVBoxLayout, + QHBoxLayout, + QLabel, + QFileDialog, + QLineEdit, +) from qtpy.QtCore import Qt, QEvent + from dcp_client.gui.main_window import MainWindow from dcp_client.gui._my_widget import MyWidget if TYPE_CHECKING: from dcp_client.app import Application + class WelcomeWindow(MyWidget): - '''Welcome Window Widget object. - The first window of the application providing a dialog that allows users to select directories. + """Welcome Window Widget object. + The first window of the application providing a dialog that allows users to select directories. Currently supported image file types that can be selected for segmentation are: .jpg, .jpeg, .png, .tiff, .tif. By clicking 'start' the MainWindow is called. - ''' + """ + + def __init__(self, app: Application) -> None: + """Initializes the WelcomeWindow. - def __init__(self, app: Application): + :param app: The Application instance. + :type app: Application + """ super().__init__() self.app = app self.setWindowTitle("Welcome to Helmholtz AI Data-Centric Tool") @@ -40,6 +54,7 @@ def __init__(self, app: Application): ) self.main_layout.addWidget(instructions_label) + input_layout = QHBoxLayout() self.text_layout = QVBoxLayout() @@ -49,11 +64,13 @@ def __init__(self, app: Application): val_label = QLabel(self) val_label.setText('Uncurated dataset path:') + inprogr_label = QLabel(self) - inprogr_label.setText('Curation in progress path:') + inprogr_label.setText("Curation in progress path:") train_label = QLabel(self) train_label.setText('Curated dataset path:') + self.text_layout.addWidget(val_label) self.text_layout.addWidget(inprogr_label) self.text_layout.addWidget(train_label) @@ -76,17 +93,45 @@ def __init__(self, app: Application): # self.train_textbox.setToolTip("Double-click to browse") self.train_textbox.textEdited.connect(lambda x: self.on_text_changed(self.train_textbox, "train", x)) self.train_textbox.installEventFilter(self) + ''' + self.val_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.val_textbox, "eval", x) + ) + + self.inprogr_textbox = QLineEdit(self) + self.inprogr_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.inprogr_textbox, "inprogress", x) + ) + + self.train_textbox = QLineEdit(self) + self.train_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.train_textbox, "train", x) + ) + ''' self.path_layout.addWidget(self.val_textbox) self.path_layout.addWidget(self.inprogr_textbox) self.path_layout.addWidget(self.train_textbox) + self.file_open_button_val = QPushButton("Browse", self) + self.file_open_button_val.show() + self.file_open_button_val.clicked.connect(self.browse_eval_clicked) + self.file_open_button_prog = QPushButton("Browse", self) + self.file_open_button_prog.show() + self.file_open_button_prog.clicked.connect(self.browse_inprogr_clicked) + self.file_open_button_train = QPushButton("Browse", self) + self.file_open_button_train.show() + self.file_open_button_train.clicked.connect(self.browse_train_clicked) + self.button_layout.addWidget(self.file_open_button_val) + self.button_layout.addWidget(self.file_open_button_prog) + self.button_layout.addWidget(self.file_open_button_train) + input_layout.addLayout(self.text_layout) input_layout.addLayout(self.path_layout) input_layout.addLayout(self.button_layout) self.main_layout.addLayout(input_layout) - self.start_button = QPushButton('Start', self) + self.start_button = QPushButton("Start", self) self.start_button.setFixedSize(120, 30) self.start_button.setStyleSheet( """QPushButton @@ -102,7 +147,7 @@ def __init__(self, app: Application): ) self.start_button.show() # check if we need to upload data to server - self.done_upload = False # we only do once + self.done_upload = False # we only do once if self.app.syncer.host_name == "local": self.start_button.clicked.connect(self.start_main) else: @@ -112,11 +157,10 @@ def __init__(self, app: Application): self.show() - def browse_eval_clicked(self): - ''' - Activates when the user clicks the button to choose the evaluation directory (QFileDialog) and + def browse_eval_clicked(self) -> None: + """Activates when the user clicks the button to choose the evaluation directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). - ''' + """ self.fd = QFileDialog() try: self.fd.setFileMode(QFileDialog.Directory) @@ -125,12 +169,11 @@ def browse_eval_clicked(self): self.val_textbox.setText(self.app.eval_data_path) finally: self.fd = None - - def browse_train_clicked(self): - ''' - Activates when the user clicks the button to choose the train directory (QFileDialog) and + + def browse_train_clicked(self) -> None: + """Activates when the user clicks the button to choose the train directory (QFileDialog) and displays the name of the train directory chosen in the train textbox line (QLineEdit). - ''' + """ fd = QFileDialog() fd.setFileMode(QFileDialog.Directory) @@ -138,11 +181,18 @@ def browse_train_clicked(self): self.app.train_data_path = fd.selectedFiles()[0] self.train_textbox.setText(self.app.train_data_path) - def on_text_changed(self, field_obj, field_name, text): - ''' - Update data paths based on text changes in input fields. + def on_text_changed(self, field_obj: QLineEdit, field_name: str, text: str) -> None: + """ + Update data paths based on text changes in input fields. Used for copying paths in the welcome window. - ''' + + :param field_obj: The QLineEdit object. + :type field_obj: QLineEdit + :param field_name: The name of the data field being updated. + :type field_name: str + :param text: The updated text. + :type text: str + """ if field_name == "train": self.app.train_data_path = text @@ -168,24 +218,32 @@ def browse_inprogr_clicked(self): ''' Activates when the user clicks the button to choose the curation in progress directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). - ''' + """ fd = QFileDialog() fd.setFileMode(QFileDialog.Directory) - if fd.exec_(): # Browse clicked - self.app.inprogr_data_path = fd.selectedFiles()[0] #TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd() + if fd.exec_(): # Browse clicked + self.app.inprogr_data_path = fd.selectedFiles()[ + 0 + ] # TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd() self.inprogr_textbox.setText(self.app.inprogr_data_path) - - def start_main(self): - ''' - Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique. - ''' - - if len({self.app.inprogr_data_path, self.app.train_data_path, self.app.eval_data_path})<3: + + def start_main(self) -> None: + """Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique.""" + + if ( + len( + { + self.app.inprogr_data_path, + self.app.train_data_path, + self.app.eval_data_path, + } + ) + < 3 + ): self.message_text = "All directory names must be distinct." _ = self.create_warning_box(self.message_text, message_title="Warning") - elif self.app.train_data_path and self.app.eval_data_path: self.hide() self.mw = MainWindow(self.app) @@ -193,27 +251,35 @@ def start_main(self): self.message_text = "You need to specify a folder both for your uncurated and curated dataset (even if the curated folder is currently empty). Please go back and select folders for both." _ = self.create_warning_box(self.message_text, message_title="Warning") - def start_upload_and_main(self): - ''' + def start_upload_and_main(self) -> None: + """ If the configs are set to use remote not local server then the user is asked to confirm the upload of their data to the server and the upload starts before launching the main window. - ''' + """ if self.done_upload is False: - message_text = ("Your current configurations are set to run some operations on the cloud. \n" - "For this we need to upload your data to our server." - "We will now upload your data. Click ok to continue. \n" - "If you do not agree close the application and contact your software provider.") - usr_response = self.create_warning_box(message_text, message_title="Warning", add_cancel_btn=True) - if usr_response: + message_text = ( + "Your current configurations are set to run some operations on the cloud. \n" + "For this we need to upload your data to our server." + "We will now upload your data. Click ok to continue. \n" + "If you do not agree close the application and contact your software provider." + ) + usr_response = self.create_warning_box( + message_text, message_title="Warning", add_cancel_btn=True + ) + if usr_response: success_up1, success_up2, _, _ = self.app.upload_data_to_server() - if success_up1=="Error" or success_up2=="Error": - message_text = ("An error has occured during data upload to the server. \n" - "Please check your configuration file and ensure that the server connection settings are correct and you have been given access to the server. \n" - "If the problem persists contact your software provider. Exiting now.") - usr_response = self.create_warning_box(message_text, message_title="Error") - self.close() - else: + if success_up1 == "Error" or success_up2 == "Error": + message_text = ( + "An error has occured during data upload to the server. \n" + "Please check your configuration file and ensure that the server connection settings are correct and you have been given access to the server. \n" + "If the problem persists contact your software provider. Exiting now." + ) + usr_response = self.create_warning_box( + message_text, message_title="Error" + ) + self.close() + else: self.done_upload = True self.start_upload_and_main() - else: self.start_main() - \ No newline at end of file + else: + self.start_main() diff --git a/src/client/dcp_client/main.py b/src/client/dcp_client/main.py index 0f9da389..ef16a971 100644 --- a/src/client/dcp_client/main.py +++ b/src/client/dcp_client/main.py @@ -1,37 +1,63 @@ +import argparse import sys +import warnings from os import path -from PyQt5.QtWidgets import QApplication +from dcp_client.app import Application +from dcp_client.gui.welcome_window import WelcomeWindow from dcp_client.utils import settings -from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils.bentoml_model import BentomlModel +from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils.sync_src_dst import DataRSync from dcp_client.utils.utils import read_config -from dcp_client.app import Application -from dcp_client.gui.welcome_window import WelcomeWindow +from PyQt5.QtWidgets import QApplication -import warnings -warnings.simplefilter('ignore') +warnings.simplefilter("ignore") def main(): + settings.init() - dir_name = path.dirname(path.abspath(sys.argv[0])) - server_config = read_config('server', config_path = path.join(dir_name, 'config.cfg')) + + dir_name = path.dirname(path.abspath(__file__)) + + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--mode", + choices=["local", "remote"], + required=True, + help="Choose mode: local or remote", + ) + args = parser.parse_args() + + if args.mode == "local": + server_config = read_config( + "server", config_path=path.join(dir_name, "config.yaml") + ) + elif args.mode == "remote": + server_config = read_config( + "server", config_path=path.join(dir_name, "config_remote.yaml") + ) image_storage = FilesystemImageStorage() ml_model = BentomlModel() - data_sync = DataRSync(user_name=server_config["user"], - host_name=server_config["host"], - server_repo_path=server_config["data-path"]) - welcome_app = Application(ml_model=ml_model, - syncer=data_sync, - image_storage=image_storage, - server_ip=server_config["ip"], - server_port=server_config["port"]) + data_sync = DataRSync( + user_name=server_config["user"], + host_name=server_config["host"], + server_repo_path=server_config["data-path"], + ) + welcome_app = Application( + ml_model=ml_model, + syncer=data_sync, + image_storage=image_storage, + server_ip=server_config["ip"], + server_port=server_config["port"], + ) app = QApplication(sys.argv) window = WelcomeWindow(welcome_app) sys.exit(app.exec()) + if __name__ == "__main__": main() diff --git a/src/client/dcp_client/utils/__init__.py b/src/client/dcp_client/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/client/dcp_client/utils/bentoml_model.py b/src/client/dcp_client/utils/bentoml_model.py index a6e72643..5f57b421 100644 --- a/src/client/dcp_client/utils/bentoml_model.py +++ b/src/client/dcp_client/utils/bentoml_model.py @@ -1,44 +1,92 @@ import asyncio -from typing import Optional +from typing import Optional, List from bentoml.client import Client as BentoClient from bentoml.exceptions import BentoMLException +import numpy as np from dcp_client.app import Model + class BentomlModel(Model): + """BentomlModel class for connecting to a BentoML server and running training and inference tasks.""" + + def __init__(self, client: Optional[BentoClient] = None): + """Initializes the BentomlModel. - def __init__( - self, - client: Optional[BentoClient] = None - ): + :param client: Optional BentoClient instance. If None, it will be initialized during connection. + :type client: Optional[BentoClient] + """ self.client = client - def connect(self, ip: str = '0.0.0.0', port: int = 7010): - url = f"http://{ip}:{port}" #"http://0.0.0.0:7010" + def connect(self, ip: str = "0.0.0.0", port: int = 7010) -> bool: + """Connects to the BentoML server. + + :param ip: IP address of the BentoML server. Default is '0.0.0.0'. + :type ip: str + :param port: Port number of the BentoML server. Default is 7010. + :type port: int + :return: True if connection is successful, False otherwise. + :rtype: bool + """ + url = f"http://{ip}:{port}" # "http://0.0.0.0:7010" try: - self.client = BentoClient.from_url(url) + self.client = BentoClient.from_url(url) return True - except : return False # except ConnectionRefusedError - + except: + return False # except ConnectionRefusedError + @property - def is_connected(self): + def is_connected(self) -> bool: + """Checks if the BentomlModel is connected to the BentoML server. + + :return: True if connected, False otherwise. + :rtype: bool + """ return bool(self.client) - async def _run_train(self, data_path): + async def _run_train(self, data_path: str) -> Optional[str]: + """Runs the training task asynchronously. + + :param data_path: Path to the training data. + :type data_path: str + :return: Response from the server if successful, None otherwise. + :rtype: str, or None + """ try: response = await self.client.async_train(data_path) return response - except BentoMLException: return None + except BentoMLException: + return None + + def run_train(self, data_path: str): + """Runs the training. - def run_train(self, data_path): + :param data_path: Path to the training data. + :type data_path: str + :return: Response from the server if successful, None otherwise. + """ return asyncio.run(self._run_train(data_path)) - async def _run_inference(self, data_path): + async def _run_inference(self, data_path: str) -> Optional[np.ndarray]: + """Runs the inference task asynchronously. + + :param data_path: Path to the data for inference. + :type data_path: str + :return: List of files not supported by the server if unsuccessful, otherwise returns None. + :rtype: np.ndarray, or None + """ try: response = await self.client.async_segment_image(data_path) return response - except BentoMLException: return None - - def run_inference(self, data_path): + except BentoMLException: + return None + + def run_inference(self, data_path: str) -> List: + """Runs the inference. + + :param data_path: Path to the data for inference. + :type data_path: str + :return: List of files not supported by the server if unsuccessful, otherwise returns None. + """ list_of_files_not_suported = asyncio.run(self._run_inference(data_path)) - return list_of_files_not_suported \ No newline at end of file + return list_of_files_not_suported diff --git a/src/client/dcp_client/utils/compute4mask.py b/src/client/dcp_client/utils/compute4mask.py new file mode 100644 index 00000000..f14bff5d --- /dev/null +++ b/src/client/dcp_client/utils/compute4mask.py @@ -0,0 +1,210 @@ +from typing import List +import numpy as np +from skimage.measure import find_contours, label +from skimage.draw import polygon_perimeter + + +class Compute4Mask: + """ + Compute4Mask provides methods for manipulating masks to make visualisation in the viewer easier. + """ + + @staticmethod + def get_contours( + instance_mask: np.ndarray, contours_level: float = None + ) -> np.ndarray: + """Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged + objects in napari window (mask). + + :param instance_mask: The instance mask array. + :type instance_mask: numpy.ndarray + :param contours_level: Value along which to find contours in the array. See skimage.measure.find_contours for more. + :type: None or float + :return: A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background. + :rtype: numpy.ndarray + + """ + instance_ids = Compute4Mask.get_unique_objects( + instance_mask + ) # get object instance labels ignoring background + contour_mask = np.zeros_like(instance_mask) + for instance_id in instance_ids: + # get a binary mask only of object + single_obj_mask = np.zeros_like(instance_mask) + single_obj_mask[instance_mask == instance_id] = 1 + try: + # compute contours for mask + contours = find_contours(single_obj_mask, contours_level) + # sometimes little dots appeas as additional contours so remove these + if len(contours) > 1: + contour_sizes = [contour.shape[0] for contour in contours] + contour = contours[contour_sizes.index(max(contour_sizes))].astype( + int + ) + else: + contour = contours[0] + # and draw onto contours mask + rr, cc = polygon_perimeter( + contour[:, 0], contour[:, 1], contour_mask.shape + ) + contour_mask[rr, cc] = instance_id + except: + print("Could not create contour for instance id", instance_id) + return contour_mask + + @staticmethod + def add_contour(labels_mask: np.ndarray, instance_mask: np.ndarray) -> np.ndarray: + """Add contours of objects to the labels mask. + + :param labels_mask: The class mask array without the contour pixels annotated. + :type labels_mask: numpy.ndarray + :param instance_mask: The instance mask array. + :type instance_mask: numpy.ndarray + :return: The updated class mask including contours. + :rtype: numpy.ndarray + """ + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + where_instances = np.where(instance_mask == instance_id) + # get unique class ids where the object is present + class_vals, counts = np.unique( + labels_mask[where_instances], return_counts=True + ) + # and take the class id which is most heavily represented + class_id = class_vals[np.argmax(counts)] + # make sure instance mask and class mask match + labels_mask[np.where(instance_mask == instance_id)] = class_id + return labels_mask + + @staticmethod + def compute_new_instance_mask( + labels_mask: np.ndarray, instance_mask: np.ndarray + ) -> np.ndarray: + """Given an updated labels mask, update also the instance mask accordingly. + So far the user can only remove an entire object in the labels mask view by + setting the color of the object to the background. + Therefore the instance mask can only change by entirely removing an object. + + :param labels_mask: The labels mask array, with changes made by the user. + :type labels_mask: numpy.ndarray + :param instance_mask: The existing instance mask, which needs to be updated. + :type instance_mask: numpy.ndarray + :return: The updated instance mask. + :rtype: numpy.ndarray + """ + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + unique_items_in_class_mask = list( + np.unique(labels_mask[instance_mask == instance_id]) + ) + if ( + len(unique_items_in_class_mask) == 1 + and unique_items_in_class_mask[0] == 0 + ): + instance_mask[instance_mask == instance_id] = 0 + return instance_mask + + @staticmethod + def compute_new_labels_mask( + labels_mask: np.ndarray, + instance_mask: np.ndarray, + original_instance_mask: np.ndarray, + old_instances: np.ndarray, + ) -> np.ndarray: + """Given the existing labels mask, the updated instance mask is used to update the labels mask. + + :param labels_mask: The existing labels mask, which needs to be updated. + :type labels_mask: numpy.ndarray + :param instance_mask: The instance mask array, with changes made by the user. + :type instance_mask: numpy.ndarray + :param original_instance_mask: The instance mask array, before the changes made by the user. + :type original_instance_mask: numpy.ndarray + :param old_instances: A list of the instance label ids in original_instance_mask. + :type old_instances: list + :return: The new labels mask, with updated changes according to those the user has made in the instance mask. + :rtype: numpy.ndarray + """ + new_labels_mask = np.zeros_like(labels_mask) + for instance_id in np.unique(instance_mask): + where_instance = np.where(instance_mask == instance_id) + # if the label is background skip + if instance_id == 0: + continue + # if the label is a newly added object, add with the same id to the labels mask + # this is an indication to the user that this object needs to be assigned a class + elif instance_id not in old_instances: + new_labels_mask[where_instance] = instance_id + else: + where_instance_orig = np.where(original_instance_mask == instance_id) + # if the locations of the instance haven't changed, means object wasn't changed, do nothing + num_classes = np.unique(labels_mask[where_instance]) + # if area was erased and object retains same class + if len(num_classes) == 1: + new_labels_mask[where_instance] = num_classes[0] + # area was added where there is background or other class + else: + old_class_id, counts = np.unique( + labels_mask[where_instance_orig], return_counts=True + ) + # assert len(old_class_id)==1 + # old_class_id = old_class_id[0] + # and take the class id which is most heavily represented + old_class_id = old_class_id[np.argmax(counts)] + new_labels_mask[where_instance] = old_class_id + + return new_labels_mask + + @staticmethod + def get_unique_objects(active_mask: np.ndarray) -> List: + """Gets unique objects from the active mask. + + :param active_mask: The mask array. + :type active_mask: numpy.ndarray + :return: A list of unique object labels. + :rtype: list + """ + return list(np.unique(active_mask)[1:]) + + @staticmethod + def assert_consistent_labels(mask: np.ndarray) -> tuple: + """Before saving the final mask make sure the user has not mistakenly made an error during annotation, + such that one instance id does not correspond to exactly one class id. Also checks whether for one instance id + multiple classes exist. + :param mask: The mask which we want to test. + :type mask: numpy.ndarray + :return: + - A boolean which is True if there is more than one connected components corresponding to an instance id and Fale otherwise. + - A boolean which is True if there is a missmatch between the instance mask and class masks (not 1-1 correspondance) and Flase otherwise. + - A list with all the instance ids for which more than one connected component was found. + - A list with all the instance ids for which a missmatch between class and instance masks was found. + :rtype : + - bool + - bool + - list[int] + - list[int] + """ + user_annot_error = False + mask_mismatch_error = False + faulty_ids_annot = [] + faulty_ids_missmatch = [] + instance_mask, class_mask = mask[0], mask[1] + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + # check if there are more than one objects (connected components) with same instance_id + if np.unique(label(instance_mask == instance_id)).shape[0] > 2: + user_annot_error = True + faulty_ids_annot.append(instance_id) + # and check if there is a mismatch between class mask and instance mask - should never happen! + if ( + np.unique(class_mask[np.where(instance_mask == instance_id)]).shape[0] + > 1 + ): + mask_mismatch_error = True + faulty_ids_missmatch.append(instance_id) + + return ( + user_annot_error, + mask_mismatch_error, + faulty_ids_annot, + faulty_ids_missmatch, + ) diff --git a/src/client/dcp_client/utils/fsimagestorage.py b/src/client/dcp_client/utils/fsimagestorage.py index 98af9afa..3e8a5e3c 100644 --- a/src/client/dcp_client/utils/fsimagestorage.py +++ b/src/client/dcp_client/utils/fsimagestorage.py @@ -1,20 +1,61 @@ -from skimage.io import imread, imsave import os +import numpy as np +from skimage.io import imread, imsave from dcp_client.app import ImageStorage + class FilesystemImageStorage(ImageStorage): + """FilesystemImageStorage class for handling image storage operations on the local filesystem.""" - def load_image(self, from_directory, cur_selected_img): + def load_image(self, from_directory: str, cur_selected_img: str) -> np.ndarray: + """Loads an image from the specified directory. + + :param from_directory: Path to the directory containing the image. + :type from_directory: str + :param cur_selected_img: Name of the image file. + :type cur_selected_img: str + :return: Loaded image. + """ # Read the selected image and read the segmentation if any: return imread(os.path.join(from_directory, cur_selected_img)) - - def move_image(self, from_directory, to_directory, cur_selected_img): - print(f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}") - os.replace(os.path.join(from_directory, cur_selected_img), os.path.join(to_directory, cur_selected_img)) - def save_image(self, to_directory, cur_selected_img, img): + def move_image(self, from_directory: str, to_directory: str, cur_selected_img: str) -> None: + """Moves an image from one directory to another. + + :param from_directory: Path to the source directory. + :type from_directory: str + :param to_directory: Path to the destination directory. + :type to_directory: str + :param cur_selected_img: Name of the image file. + :type cur_selected_img: str + """ + print( + f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}" + ) + os.replace( + os.path.join(from_directory, cur_selected_img), + os.path.join(to_directory, cur_selected_img), + ) + + def save_image(self, to_directory: str, cur_selected_img: str, img: np.ndarray) -> None: + """Saves an image to the specified directory. + + :param to_directory: Path to the directory where the image will be saved. + :type to_directory: str + :param cur_selected_img: Name of the image file. + :type cur_selected_img: str + :param img: Image data to be saved. + """ + imsave(os.path.join(to_directory, cur_selected_img), img) - - def delete_image(self, from_directory, cur_selected_img): + + def delete_image(self, from_directory: str, cur_selected_img: str) -> None: + """Deletes an image from the specified directory. + + :param from_directory: Path to the directory containing the image. + :type from_directory: str + :param cur_selected_img: Name of the image file. + :type cur_selected_img: str + """ os.remove(os.path.join(from_directory, cur_selected_img)) diff --git a/src/client/dcp_client/utils/settings.py b/src/client/dcp_client/utils/settings.py index 2fd6bcb2..5107fb82 100644 --- a/src/client/dcp_client/utils/settings.py +++ b/src/client/dcp_client/utils/settings.py @@ -1,5 +1,6 @@ -def init(): +def init() -> None: + """ Initialise global variables.""" global accepted_types accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") global seg_name_string - seg_name_string = '_seg' + seg_name_string = "_seg" diff --git a/src/client/dcp_client/utils/sync_src_dst.py b/src/client/dcp_client/utils/sync_src_dst.py index 091c475a..0698901d 100644 --- a/src/client/dcp_client/utils/sync_src_dst.py +++ b/src/client/dcp_client/utils/sync_src_dst.py @@ -6,14 +6,16 @@ class DataRSync(DataSync): - ''' + """ Class which uses rsync bash command to sync data between client and server - ''' - def __init__(self, - user_name: str, - host_name: str, - server_repo_path: str, - ): + """ + + def __init__( + self, + user_name: str, + host_name: str, + server_repo_path: str, + ) -> None: """Constructs all the necessary attributes for the CustomRunnable. :param user_name: the user name of the server - if "local", then it is assumed that local machine is used for the server @@ -22,39 +24,50 @@ def __init__(self, :type: host_name: str :param server_repo_path: the server path where we wish to sync data - if None, then it is assumed that local machine is used for the server :type server_repo_path: str - """ + """ self.user_name = user_name self.host_name = host_name self.server_repo_path = server_repo_path - def first_sync(self, path): + def first_sync(self, path: str) -> tuple: """ During the first sync the folder structure should be created on the server + + :param path: Path to the local directory to synchronize. + :type path: str + :return: result message of subprocess + :rtype: tuple """ - server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path + server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path try: # Run the subprocess command - result = subprocess.run(["rsync", - "-azP" , - path, - server], - check=True) + result = subprocess.run(["rsync", "-azP", path, server], check=True) return ("Success", result.stdout) except subprocess.CalledProcessError as e: return ("Error", e) + def sync(self, src: str, dst: str, path: str) -> tuple: + """Syncs the data between the src and the dst. Both src and dst can be one of either + 'client' or 'server', whereas path is the local path we wish to sync + + :param src: A string specifying the source, from where the data will be sent to dst. Can be 'client' or 'server'. + :type src: str + :param dst: A string specifying the destination, where the data from src will be sent to. Can be 'client' or 'server'. + :type dst: str + :param path: Path to the directory we want to synchronize. + :type path: str + :return: result message of subprocess + :rtype: tuple - def sync(self, src, dst, path): - """ Syncs the data between the src and the dst. Both src and dst can be one of either - 'client' or 'server', whereas path is the local path we wish to sync""" - path += '/' # otherwise it doesn't go in the directory - rel_path = get_relative_path(path) # get last folder, i.e. uncurated, curated + """ + path += "/" # otherwise it doesn't go in the directory + rel_path = get_relative_path(path) # get last folder, i.e. uncurated, curated server_full_path = os.path.join(self.server_repo_path, rel_path) - server_full_path += '/' - server = self.user_name + "@" + self.host_name + ":" + server_full_path - print('server is: ', server) - - if src=='server': + server_full_path += "/" + server = self.user_name + "@" + self.host_name + ":" + server_full_path + print("server is: ", server) + + if src == "server": src = server dst = path else: @@ -62,19 +75,14 @@ def sync(self, src, dst, path): dst = server try: # Run the subprocess command - _ = subprocess.run(["rsync", - "-r" , - "--delete", - src, - dst], - check=True) + _ = subprocess.run(["rsync", "-r", "--delete", src, dst], check=True) return ("Success", server_full_path) except subprocess.CalledProcessError as e: return ("Error", e) - -if __name__=="__main__": - ds = DataRSync() #vm2 + +if __name__ == "__main__": + ds = DataRSync() # vm2 # These combinations work for me: # ubuntu@jusuf-vm2:/path... # jusuf-vm2:/path... @@ -82,6 +90,8 @@ def sync(self, src, dst, path): src = "client" # dst = 'client' # src = 'server' - #path = "data/" - path = "/Users/christina.bukas/Documents/AI_projects/code/data-centric-platform/data" - ds.sync(src, dst, path) \ No newline at end of file + # path = "data/" + path = ( + "/Users/christina.bukas/Documents/AI_projects/code/data-centric-platform/data" + ) + ds.sync(src, dst, path) diff --git a/src/client/dcp_client/utils/utils.py b/src/client/dcp_client/utils/utils.py index 1423c3fd..8b9890b3 100644 --- a/src/client/dcp_client/utils/utils.py +++ b/src/client/dcp_client/utils/utils.py @@ -6,20 +6,32 @@ from skimage.measure import find_contours, label from skimage.draw import polygon_perimeter + from pathlib import Path, PurePath -import json +import yaml +import numpy as np from dcp_client.utils import settings + class IconProvider(QFileIconProvider): def __init__(self) -> None: + """Initializes the IconProvider with the default icon size.""" super().__init__() - self.ICON_SIZE = QSize(512,512) + self.ICON_SIZE = QSize(512, 512) + + def icon(self, type: QFileIconProvider.IconType) -> QIcon: + """Returns the icon for the specified file type. - def icon(self, type: 'QFileIconProvider.IconType'): + :param type: The type of the file for which the icon is requested. + :type type: QFileIconProvider.IconType + :return: The icon for the file type. + :rtype: QIcon + """ try: fn = type.filePath() - except AttributeError: return super().icon(type) # TODO handle exception differently? + except AttributeError: + return super().icon(type) # TODO handle exception differently? if fn.endswith(settings.accepted_types): a = QPixmap(self.ICON_SIZE) @@ -29,34 +41,46 @@ def icon(self, type: 'QFileIconProvider.IconType'): else: return super().icon(type) -def read_config(name, config_path = 'config.cfg') -> dict: + +def read_config(name: str, config_path: str = "config.yaml") -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 'setup','train') :type name: string - :param config_path: path to the configuration file, defaults to 'config.cfg' + :param config_path: path to the configuration file, defaults to 'config.yaml' :type config_path: str, optional :return: dictionary from the config section given by name :rtype: dict - """ + """ with open(config_path) as config_file: - config_dict = json.load(config_file) + config_dict = yaml.safe_load( + config_file + ) # json.load(config_file) for .cfg file # Check if config file has main mandatory keys - assert all([i in config_dict.keys() for i in ['server']]) + assert all([i in config_dict.keys() for i in ["server"]]) return config_dict[name] -def get_relative_path(filepath): return PurePath(filepath).name -def get_path_stem(filepath): return str(Path(filepath).stem) +def get_relative_path(filepath: str) -> str: + """Returns the name of the file from the given filepath. -def get_path_name(filepath): return str(Path(filepath).name) + :param filepath: The path of the file. + :type filepath: str + :return: The name of the file. + :rtype: str + """ + return PurePath(filepath).name -def get_path_parent(filepath): return str(Path(filepath).parent) -def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) +def get_path_stem(filepath: str) -> str: + """Returns the stem (filename without its extension) from the given filepath. -def check_equal_arrays(array1, array2): - return np.array_equal(array1, array2) + :param filepath: The path of the file. + :type filepath: str + :return: The stem of the file. + :rtype: str + """ + return str(Path(filepath).stem) class CustomItemDelegate(QStyledItemDelegate): """ @@ -88,164 +112,51 @@ def sizeHint(self, option, index): size = super().sizeHint(option, index) size.setHeight(100) return size - -class Compute4Mask: - - @staticmethod - def get_contours(instance_mask, contours_level=None): - ''' - Find contours of objects in the instance mask. - This function is used to identify the contours of the objects to prevent - the problem of the merged objects in napari window (mask). - - Parameters: - - instance_mask (numpy.ndarray): The instance mask array. - - Returns: - - contour_mask (numpy.ndarray): A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background. - ''' - instance_ids = Compute4Mask.get_unique_objects(instance_mask) # get object instance labels ignoring background - contour_mask= np.zeros_like(instance_mask) - for instance_id in instance_ids: - # get a binary mask only of object - single_obj_mask = np.zeros_like(instance_mask) - single_obj_mask[instance_mask==instance_id] = 1 - # compute contours for mask - contours = find_contours(single_obj_mask, contours_level) - # sometimes little dots appeas as additional contours so remove these - if len(contours)>1: - contour_sizes = [contour.shape[0] for contour in contours] - contour = contours[contour_sizes.index(max(contour_sizes))].astype(int) - else: contour = contours[0] - # and draw onto contours mask - rr, cc = polygon_perimeter(contour[:, 0], contour[:, 1], contour_mask.shape) - contour_mask[rr, cc] = instance_id - return contour_mask - - @staticmethod - def add_contour(labels_mask, instance_mask): - ''' - Add contours of objects to the labels mask. - - Parameters: - - labels_mask (numpy.ndarray): The class mask array without the contour pixels annotated. - - instance_mask (numpy.ndarray): The instance mask array. - - Returns: - - labels_mask (numpy.ndarray): The updated class mask including contours. - ''' - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - where_instances = np.where(instance_mask==instance_id) - # get unique class ids where the object is present - class_vals, counts = np.unique(labels_mask[where_instances], return_counts=True) - # and take the class id which is most heavily represented - class_id = class_vals[np.argmax(counts)] - # make sure instance mask and class mask match - labels_mask[np.where(instance_mask==instance_id)] = class_id - return labels_mask - - - @staticmethod - def compute_new_instance_mask(labels_mask, instance_mask): - ''' - Given an updated labels mask, update also the instance mask accordingly. So far the user can only remove an entire object in the labels mask view. - Therefore the instance mask can only change by entirely removing an object. - - Parameters: - - labels_mask (numpy.ndarray): The labels mask array, with changes made by the user. - - instance_mask (numpy.ndarray): The existing instance mask, which needs to be updated. - Returns: - - instance_mask (numpy.ndarray): The updated instance mask. - ''' - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - unique_items_in_class_mask = list(np.unique(labels_mask[instance_mask==instance_id])) - if len(unique_items_in_class_mask)==1 and unique_items_in_class_mask[0]==0: - instance_mask[instance_mask==instance_id] = 0 - return instance_mask - - - @staticmethod - def compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances): - ''' - Given the existing labels mask, the updated instance mask is used to update the labels mask. - - Parameters: - - labels_mask (numpy.ndarray): The existing labels mask, which needs to be updated. - - instance_mask (numpy.ndarray): The instance mask array, with changes made by the user. - - original_instance_mask (numpy.ndarray): The instance mask array, before the changes made by the user. - - old_instances (List): A list of the instance label ids in original_instance_mask. - Returns: - - new_labels_mask (numpy.ndarray): The new labels mask, with updated changes according to those the user has made in the instance mask. - ''' - new_labels_mask = np.zeros_like(labels_mask) - for instance_id in np.unique(instance_mask): - where_instance = np.where(instance_mask==instance_id) - # if the label is background skip - if instance_id==0: continue - # if the label is a newly added object, add with the same id to the labels mask - # this is an indication to the user that this object needs to be assigned a class - elif instance_id not in old_instances: - new_labels_mask[where_instance] = instance_id - else: - where_instance_orig = np.where(original_instance_mask==instance_id) - # if the locations of the instance haven't changed, means object wasn't changed, do nothing - num_classes = np.unique(labels_mask[where_instance]) - # if area was erased and object retains same class - if len(num_classes)==1: - new_labels_mask[where_instance] = num_classes[0] - # area was added where there is background or other class - else: - old_class_id, counts = np.unique(labels_mask[where_instance_orig], return_counts=True) - #assert len(old_class_id)==1 - #old_class_id = old_class_id[0] - # and take the class id which is most heavily represented - old_class_id = old_class_id[np.argmax(counts)] - new_labels_mask[where_instance] = old_class_id - - return new_labels_mask - - @staticmethod - def get_unique_objects(active_mask): - """ - Get unique objects from the active mask. - """ - return list(np.unique(active_mask)[1:]) - - @staticmethod - def assert_consistent_labels(mask): - """ - Before saving the final mask make sure the user has not mistakenly made an error during annotation, - such that one instance id does not correspond to exactly one class id. Also checks whether for one instance id - multiple classes exist. - :param mask: The mask which we want to test. - :type mask: numpy.ndarray - :return: - - A boolean which is True if there is more than one connected components corresponding to an instance id and Fale otherwise. - - A boolean which is True if there is a missmatch between the instance mask and class masks (not 1-1 correspondance) and Flase otherwise. - - A list with all the instance ids for which more than one connected component was found. - - A list with all the instance ids for which a missmatch between class and instance masks was found. - :rtype : - - bool - - bool - - list[int] - - list[int] - """ - user_annot_error = False - mask_mismatch_error = False - faulty_ids_annot = [] - faulty_ids_missmatch = [] - instance_mask, class_mask = mask[0], mask[1] - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - # check if there are more than one objects (connected components) with same instance_id - if np.unique(label(instance_mask==instance_id)).shape[0] > 2: - user_annot_error = True - faulty_ids_annot.append(instance_id) - # and check if there is a mismatch between class mask and instance mask - should never happen! - if np.unique(class_mask[np.where(instance_mask==instance_id)]).shape[0]>1: - mask_mismatch_error = True - faulty_ids_missmatch.append(instance_id) - - return user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch \ No newline at end of file + + +def get_path_name(filepath: str) -> str: + """Returns the name of the file from the given filepath. + + :param filepath: The path of the file. + :type filepath: str + :return: The name of the file. + :rtype: str + """ + return str(Path(filepath).name) + + +def get_path_parent(filepath: str) -> str: + """Returns the parent directory of the given filepath. + + :param filepath: The path of the file. + :type filepath: str + :return: The parent directory of the file. + :rtype: str + """ + return str(Path(filepath).parent) + + +def join_path(root_dir: str, filepath: str) -> str: + """Joins the root directory path with the given filepath. + + :param root_dir: The root directory. + :type root_dir: str + :param filepath: The path of the file. + :type filepath: str + :return: The joined path. + :rtype: str + """ + return str(Path(root_dir, filepath)) + + +def check_equal_arrays(array1: np.ndarray, array2: np.ndarray) -> bool: + """Checks if two arrays are equal. + + :param array1: The first array. + :type array1: numpy.ndarray + :param array2: The second array. + :type array2: numpy.ndarray + :return: True if the arrays are equal, False otherwise. + :rtype: bool + """ + return np.array_equal(array1, array2) diff --git a/src/client/pyproject.toml b/src/client/pyproject.toml index ed28d77c..93af7bd7 100644 --- a/src/client/pyproject.toml +++ b/src/client/pyproject.toml @@ -9,11 +9,10 @@ packages = ['dcp_client'] dependencies = {file = ["requirements.txt"]} [project] -name = "data-centric-tool-client" +name = "data-centric-platform-client" version = "0.1" -requires-python = ">=3.8" -description = "" -# license = {file = "LICENSE.txt"} +requires-python = ">=3.9" +description = "The client of the data centric platform for microscopy image segmentation" keywords = [] classifiers = [ "Programming Language :: Python :: 3", @@ -22,23 +21,27 @@ classifiers = [ readme = "README.md" dynamic = ["dependencies"] authors = [ - {name="Christina Bukas", email="christina.bukas@helmholtz-muenchen.de"}, - {name="Helena Pelin", email="helena.pelin@helmholtz-muenchen.de"} + {name="Christina Bukas", email="christina.bukas@helmholtz-munich.de"}, + {name="Helena Pelin", email="helena.pelin@helmholtz-munich.de"}, + {name="Mariia Koren", email="mariia.koren@helmholtz-munich.de"}, + {name="Marie Piraud", email="marie.piraud@helmholtz-munich.de"}, ] maintainers = [ - {name="Christina Bukas", email="christina.bukas@helmholtz-muenchen.de"}, - {name="Helena Pelin", email="helena.pelin@helmholtz-muenchen.de"} + {name="Christina Bukas", email="christina.bukas@helmholtz-munich.de"}, + {name="Helena Pelin", email="helena.pelin@helmholtz-munich.de"} ] [project.optional-dependencies] dev = [ - "pytest", + "pytest>=7.4.3", + "pytest-qt>=4.2.0", + "sphinx", + "sphinx-rtd-theme" ] [project.urls] repository = "https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform" -# homepage = "https://example.com" -# documentation = "https://readthedocs.org" +documentation = "https://readthedocs.org/projects/data-centric-platform" [project.scripts] dcp-client = "dcp_client.main:main" diff --git a/src/client/readme_figs/napari_shortcuts.png b/src/client/readme_figs/napari_shortcuts.png new file mode 100644 index 00000000..4781f119 Binary files /dev/null and b/src/client/readme_figs/napari_shortcuts.png differ diff --git a/src/client/requirements.txt b/src/client/requirements.txt index 798b769f..e47ad839 100644 --- a/src/client/requirements.txt +++ b/src/client/requirements.txt @@ -1,4 +1,2 @@ napari[pyqt5]>=0.4.17 -bentoml[grpc]==1.0.16 -pytest>=7.4.3 -pytest-qt>=4.2.0 \ No newline at end of file +bentoml[grpc]==1.0.16 \ No newline at end of file diff --git a/src/client/test/test_app.py b/src/client/test/test_app.py index e4e6d1f9..ad31285a 100644 --- a/src/client/test/test_app.py +++ b/src/client/test/test_app.py @@ -1,5 +1,6 @@ import os import sys + sys.path.append("../") import pytest import subprocess @@ -13,88 +14,101 @@ from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils.sync_src_dst import DataRSync + @pytest.fixture def app(): img1 = data.astronaut() img2 = data.coffee() img3 = data.cat() - if not os.path.exists('in_prog'): - os.mkdir('in_prog') - imsave('in_prog/coffee.png', img2) + if not os.path.exists("in_prog"): + os.mkdir("in_prog") + imsave("in_prog/coffee.png", img2) - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img3) + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img3) - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - app = Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010, - os.path.join(os.getcwd(), 'eval_data_path')) + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + app = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + os.path.join(os.getcwd(), "eval_data_path"), + ) return app, img1, img2, img3 + def test_load_image(app): app, img, img2, _ = app # Unpack the app, img, and img2 from the fixture - - app.cur_selected_img = 'coffee.png' - app.cur_selected_path = 'in_prog' + + app.cur_selected_img = "coffee.png" + app.cur_selected_path = "in_prog" img_test = app.load_image() # if image_name is None assert img.all() == img_test.all() - app.cur_selected_path = 'eval_data_path' - img_test2 = app.load_image('cat.png') # if a filename is given + app.cur_selected_path = "eval_data_path" + img_test2 = app.load_image("cat.png") # if a filename is given assert img2.all() == img_test2.all() + def test_run_inference_no_connection(app): - app, _, _, _ = app + app, _, _, _ = app message_text, message_title = app.run_inference() - assert message_text=="Connection could not be established. Please check if the server is running and try again." - assert message_title=="Warning" + assert ( + message_text + == "Connection could not be established. Please check if the server is running and try again." + ) + assert message_title == "Warning" + def test_run_inference_run(app): - app, _, _, _ = app + app, _, _, _ = app # start the sevrer in the background locally command = [ "bentoml", - "serve", - '--working-dir', - '../server/dcp_server', + "serve", + "--working-dir", + "../server/dcp_server", "service:svc", "--reload", "--port=7010", ] process = subprocess.Popen(command, stdin=subprocess.PIPE, shell=False) # and wait until it is setup - if sys.platform == 'win32' or sys.platform == 'cygwin': time.sleep(240) - else: time.sleep(60) + if sys.platform == "win32" or sys.platform == "cygwin": + time.sleep(240) + else: + time.sleep(60) # then do model serving message_text, message_title = app.run_inference() # and assert returning message print(f"HERE: {message_text, message_title}") - assert message_text== "Success! Masks generated for all images" - assert message_title=="Information" + assert message_text == "Success! Masks generated for all images" + assert message_title == "Information" # finally clean up process process.terminate() process.wait() process.kill() + def test_search_segs(app): - app, _, _, _ = app - app.cur_selected_img = 'cat.png' - app.cur_selected_path = 'eval_data_path' + app, _, _, _ = app + app.cur_selected_img = "cat.png" + app.cur_selected_path = "eval_data_path" app.search_segs() - res = app.seg_filepaths - assert len(res)==1 - assert res[0]=='cat_seg.tiff' + res = app.seg_filepaths + assert len(res) == 1 + assert res[0] == "cat_seg.tiff" # also remove the seg as it is not needed for other scripts - os.remove('eval_data_path/cat_seg.tiff') + os.remove("eval_data_path/cat_seg.tiff") -''' + +""" def test_run_train(): pass @@ -107,7 +121,4 @@ def test_move_images(): def test_delete_images(): pass -''' - - - +""" diff --git a/src/client/test/test_compute4mask.py b/src/client/test/test_compute4mask.py index e76dfc1c..5304e2ee 100644 --- a/src/client/test/test_compute4mask.py +++ b/src/client/test/test_compute4mask.py @@ -1,83 +1,114 @@ import numpy as np import pytest -from dcp_client.utils.utils import Compute4Mask +from dcp_client.utils.compute4mask import Compute4Mask + @pytest.fixture def sample_data(): - instance_mask = np.array([[0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [2, 2, 0, 0, 0], - [0, 0, 3, 3, 0]]) - labels_mask = np.array([[0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [2, 2, 0, 0, 0], - [0, 0, 1, 1, 0]]) + instance_mask = np.array( + [ + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [2, 2, 0, 0, 0], + [0, 0, 3, 3, 0], + ] + ) + labels_mask = np.array( + [ + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [2, 2, 0, 0, 0], + [0, 0, 1, 1, 0], + ] + ) return instance_mask, labels_mask + def test_get_unique_objects(sample_data): instance_mask, _ = sample_data unique_objects = Compute4Mask.get_unique_objects(instance_mask) assert unique_objects == [1, 2, 3] + def test_get_contours(sample_data): instance_mask, _ = sample_data contour_mask = Compute4Mask.get_contours(instance_mask) assert contour_mask.shape == instance_mask.shape - assert contour_mask[0,1] == 1 # randomly check a contour location is present + assert contour_mask[0, 1] == 1 # randomly check a contour location is present + def test_add_contour(sample_data): instance_mask, labels_mask = sample_data contours_mask = Compute4Mask.get_contours(instance_mask, contours_level=0.1) labels_mask_wo_contour = np.copy(labels_mask) - labels_mask_wo_contour[contours_mask!=0] = 0 - updated_labels_mask = Compute4Mask.add_contour(labels_mask_wo_contour, instance_mask) + labels_mask_wo_contour[contours_mask != 0] = 0 + updated_labels_mask = Compute4Mask.add_contour( + labels_mask_wo_contour, instance_mask + ) assert np.array_equal(updated_labels_mask[:3], labels_mask[:3]) + def test_compute_new_instance_mask(sample_data): instance_mask, labels_mask = sample_data - labels_mask[labels_mask==1] = 0 - updated_instance_mask = Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask) - assert list(np.unique(updated_instance_mask))==[0,2] + labels_mask[labels_mask == 1] = 0 + updated_instance_mask = Compute4Mask.compute_new_instance_mask( + labels_mask, instance_mask + ) + assert list(np.unique(updated_instance_mask)) == [0, 2] + def test_compute_new_labels_mask_obj_added(sample_data): instance_mask, labels_mask = sample_data original_instance_mask = np.copy(instance_mask) instance_mask[0, 0] = 4 old_instances = Compute4Mask.get_unique_objects(original_instance_mask) - new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) - assert new_labels_mask[0,0]==4 + new_labels_mask = Compute4Mask.compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ) + assert new_labels_mask[0, 0] == 4 + def test_compute_new_labels_mask_obj_erased(sample_data): instance_mask, labels_mask = sample_data original_instance_mask = np.copy(instance_mask) instance_mask[0] = 0 old_instances = Compute4Mask.get_unique_objects(original_instance_mask) - new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) - assert np.all(new_labels_mask[0])==0 + new_labels_mask = Compute4Mask.compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ) + assert np.all(new_labels_mask[0]) == 0 assert np.array_equal(new_labels_mask[1:], labels_mask[1:]) + def test_compute_new_labels_mask_obj_added(sample_data): instance_mask, labels_mask = sample_data original_instance_mask = np.copy(instance_mask) instance_mask[:, -1] = 1 old_instances = Compute4Mask.get_unique_objects(original_instance_mask) - new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) - assert np.all(new_labels_mask[:, -1])==1 + new_labels_mask = Compute4Mask.compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ) + assert np.all(new_labels_mask[:, -1]) == 1 + def assert_consistent_labels(sample_data): instance_mask, labels_mask = sample_data - user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(sample_data) - assert user_annot_error==False - assert mask_mismatch_error==False - assert len(faulty_ids_annot)==len(faulty_ids_missmatch)==0 - instance_mask[instance_mask==3] = 1 - labels_mask[1,2] = 2 - user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(np.stack(instance_mask, labels_mask)) - assert user_annot_error==True - assert mask_mismatch_error==True - assert len(faulty_ids_annot)==1 - assert faulty_ids_annot[0]==1 - assert len(faulty_ids_missmatch)==1 - assert faulty_ids_missmatch[0]==1 + user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = ( + Compute4Mask.assert_consistent_labels(sample_data) + ) + assert user_annot_error == False + assert mask_mismatch_error == False + assert len(faulty_ids_annot) == len(faulty_ids_missmatch) == 0 + instance_mask[instance_mask == 3] = 1 + labels_mask[1, 2] = 2 + user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = ( + Compute4Mask.assert_consistent_labels(np.stack(instance_mask, labels_mask)) + ) + assert user_annot_error == True + assert mask_mismatch_error == True + assert len(faulty_ids_annot) == 1 + assert faulty_ids_annot[0] == 1 + assert len(faulty_ids_missmatch) == 1 + assert faulty_ids_missmatch[0] == 1 diff --git a/src/client/test/test_fsimagestorage.py b/src/client/test/test_fsimagestorage.py index 275e5f0b..f971fbfe 100644 --- a/src/client/test/test_fsimagestorage.py +++ b/src/client/test/test_fsimagestorage.py @@ -5,42 +5,48 @@ from dcp_client.utils.fsimagestorage import FilesystemImageStorage + @pytest.fixture def fis(): return FilesystemImageStorage() + @pytest.fixture def sample_image(): # Create a sample image img = data.astronaut() - fname = 'test_img.png' + fname = "test_img.png" imsave(fname, img) return fname - + + def test_load_image(fis, sample_image): - img_test = fis.load_image('.', sample_image) + img_test = fis.load_image(".", sample_image) assert img_test.all() == data.astronaut().all() os.remove(sample_image) + def test_move_image(fis, sample_image): - temp_dir = 'temp' + temp_dir = "temp" os.mkdir(temp_dir) - fis.move_image('.', temp_dir, sample_image) - assert os.path.exists(os.path.join(temp_dir, 'test_img.png')) - os.remove(os.path.join(temp_dir, 'test_img.png')) + fis.move_image(".", temp_dir, sample_image) + assert os.path.exists(os.path.join(temp_dir, "test_img.png")) + os.remove(os.path.join(temp_dir, "test_img.png")) os.rmdir(temp_dir) + def test_save_image(fis): img = data.astronaut() - fname = 'output.png' - fis.save_image('.', fname, img) + fname = "output.png" + fis.save_image(".", fname, img) assert os.path.exists(fname) os.remove(fname) + def test_delete_image(fis, sample_image): - temp_dir = 'temp' + temp_dir = "temp" os.mkdir(temp_dir) - fis.move_image('.', temp_dir, sample_image) - fis.delete_image(temp_dir, 'test_img.png') - assert not os.path.exists(os.path.join(temp_dir, 'test_img.png')) + fis.move_image(".", temp_dir, sample_image) + fis.delete_image(temp_dir, "test_img.png") + assert not os.path.exists(os.path.join(temp_dir, "test_img.png")) os.rmdir(temp_dir) diff --git a/src/client/test/test_main_window.py b/src/client/test/test_main_window.py index bdd72d01..5bb61dcc 100644 --- a/src/client/test/test_main_window.py +++ b/src/client/test/test_main_window.py @@ -1,6 +1,7 @@ import os import pytest import sys + sys.path.append("../") from skimage import data @@ -23,7 +24,8 @@ def setup_global_variable(): settings.accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") yield settings.accepted_types -@pytest.fixture() + +@pytest.fixture def app(qtbot, setup_global_variable): settings.accepted_types = setup_global_variable @@ -32,101 +34,108 @@ def app(qtbot, setup_global_variable): img2 = data.coffee() img3 = data.cat() - if not os.path.exists('train_data_path'): - os.mkdir('train_data_path') - imsave('train_data_path/astronaut.png', img1) - - if not os.path.exists('in_prog'): - os.mkdir('in_prog') - imsave('in_prog/coffee.png', img2) - - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img3) - - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - application = Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010, - 'eval_data_path', - 'train_data_path', - 'in_prog') + if not os.path.exists("train_data_path"): + os.mkdir("train_data_path") + imsave("train_data_path/astronaut.png", img1) + + if not os.path.exists("in_prog"): + os.mkdir("in_prog") + imsave("in_prog/coffee.png", img2) + + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img3) + + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + application = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + "eval_data_path", + "train_data_path", + "in_prog", + ) # Create an instance of MainWindow widget = MainWindow(application) qtbot.addWidget(widget) yield widget widget.close() - + + def test_main_window_setup(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable assert app.title == "Data Overview" + def test_item_train_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view - #index = app.list_view_train.model().index(0, 0) + # index = app.list_view_train.model().index(0, 0) index = app.list_view_train.indexAt(app.list_view_train.viewport().rect().topLeft()) pos = app.list_view_train.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_train.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_train.viewport(), Qt.LeftButton, pos=pos) app.on_item_train_selected(index) # Assert that the selected item matches the expected item assert app.list_view_train.selectionModel().currentIndex() == index - assert app.app.cur_selected_img=='astronaut.png' - assert app.app.cur_selected_path==app.app.train_data_path + assert app.app.cur_selected_img == "astronaut.png" + assert app.app.cur_selected_path == app.app.train_data_path + def test_item_inprog_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view - index = app.list_view_inprogr.indexAt(app.list_view_inprogr.viewport().rect().topLeft()) + index = app.list_view_inprogr.indexAt( + app.list_view_inprogr.viewport().rect().topLeft() + ) pos = app.list_view_inprogr.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_inprogr.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_inprogr.viewport(), Qt.LeftButton, pos=pos) app.on_item_inprogr_selected(index) # Assert that the selected item matches the expected item assert app.list_view_inprogr.selectionModel().currentIndex() == index assert app.app.cur_selected_img == "coffee.png" assert app.app.cur_selected_path == app.app.inprogr_data_path + def test_item_eval_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_eval.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_eval.viewport(), Qt.LeftButton, pos=pos) app.on_item_eval_selected(index) # Assert that the selected item matches the expected item assert app.list_view_eval.selectionModel().currentIndex() == index - assert app.app.cur_selected_img=='cat.png' - assert app.app.cur_selected_path==app.app.eval_data_path + assert app.app.cur_selected_img == "cat.png" + assert app.app.cur_selected_path == app.app.eval_data_path + def test_train_button_click(qtbot, app): # Click the "Train Model" button app.sim = True QTest.mouseClick(app.train_button, Qt.LeftButton) # Wait until the worker thread is done - while app.worker_thread.isRunning(): QTest.qSleep(1000) + while app.worker_thread.isRunning(): + QTest.qSleep(1000) # The train functionality of the thread is tested with app tests + def test_inference_button_click(qtbot, app): # Click the "Generate Labels" button app.sim = True QTest.mouseClick(app.inference_button, Qt.LeftButton) # Wait until the worker thread is done - while app.worker_thread.isRunning(): QTest.qSleep(1000) - #QTest.qWaitForWindowActive(app, timeout=5000) + while app.worker_thread.isRunning(): + QTest.qSleep(1000) + # QTest.qWaitForWindowActive(app, timeout=5000) # The inference functionality of the thread is tested with app tests + def test_on_finished(qtbot, app): # Assert that the on_finished function re-enabled the buttons and set the worker thread to None assert app.train_button.isEnabled() @@ -150,7 +159,8 @@ def test_launch_napari_button_click_without_selection(qtbot, app): # Try clicking the view button without having selected an image app.sim = True qtbot.mouseClick(app.launch_nap_button, Qt.LeftButton) - assert not hasattr(app, 'nap_win') + assert not hasattr(app, "nap_win") + def test_launch_napari_button_click(qtbot, app): settings.accepted_types = setup_global_variable @@ -158,29 +168,28 @@ def test_launch_napari_button_click(qtbot, app): index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_eval.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_eval.viewport(), Qt.LeftButton, pos=pos) app.on_item_eval_selected(index) # Now click the view button qtbot.mouseClick(app.launch_nap_button, Qt.LeftButton) # Assert that the napari window has launched - assert hasattr(app, 'nap_win') + assert hasattr(app, "nap_win") assert app.nap_win.isVisible() -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def cleanup_files(request): # This code runs after all tests from all files have completed yield # Clean up - for fname in os.listdir('train_data_path'): - os.remove(os.path.join('train_data_path', fname)) - os.rmdir('train_data_path') - - for fname in os.listdir('in_prog'): - os.remove(os.path.join('in_prog', fname)) - os.rmdir('in_prog') - - for fname in os.listdir('eval_data_path'): - os.remove(os.path.join('eval_data_path', fname)) - os.rmdir('eval_data_path') + paths_to_clean = ["train_data_path", "in_prog", "eval_data_path"] + for path in paths_to_clean: + try: + for fname in os.listdir(path): + os.remove(os.path.join(path, fname)) + os.rmdir(path) + except FileNotFoundError: + pass + except Exception as e: + # Handle other exceptions + print(f"An error occurred while cleaning up {path}: {e}") diff --git a/src/client/test/test_mywidget.py b/src/client/test/test_mywidget.py index e75172c1..7e10f53f 100644 --- a/src/client/test/test_mywidget.py +++ b/src/client/test/test_mywidget.py @@ -1,35 +1,47 @@ import pytest import sys -sys.path.append('../') + +sys.path.append("../") from PyQt5.QtWidgets import QMessageBox from dcp_client.gui._my_widget import MyWidget + @pytest.fixture def app(qtbot): - #q_app = QApplication([]) + # q_app = QApplication([]) widget = MyWidget() qtbot.addWidget(widget) yield widget widget.close() + def test_create_warning_box_ok(qtbot, app): result = None app.sim = True + def execute_warning_box(): nonlocal result box = QMessageBox() result = app.create_warning_box("Test Message", custom_dialog=box) - qtbot.waitUntil(execute_warning_box, timeout=5000) - assert result is True + + qtbot.waitUntil(execute_warning_box, timeout=5000) + assert result is True + def test_create_warning_box_cancel(qtbot, app): result = None app.sim = True + def execute_warning_box(): nonlocal result box = QMessageBox() - result = app.create_warning_box("Test Message", add_cancel_btn=True, custom_dialog=box) - qtbot.waitUntil(execute_warning_box, timeout=5000) # Add a timeout for the function to execute - assert result is False + result = app.create_warning_box( + "Test Message", add_cancel_btn=True, custom_dialog=box + ) + + qtbot.waitUntil( + execute_warning_box, timeout=5000 + ) # Add a timeout for the function to execute + assert result is False diff --git a/src/client/test/test_napari_window.py b/src/client/test/test_napari_window.py index 06978ebf..8c31ebcf 100644 --- a/src/client/test/test_napari_window.py +++ b/src/client/test/test_napari_window.py @@ -21,11 +21,12 @@ # yield napari_app # napari_app.close() + @pytest.fixture def napari_window(qtbot): - #img1 = data.astronaut() - #img2 = data.coffee() + # img1 = data.astronaut() + # img2 = data.coffee() img = data.cat() img_mask = np.zeros((2, img.shape[0], img.shape[1]), dtype=np.uint8) img_mask[0, 50:50, 50:50] = 1 @@ -34,61 +35,63 @@ def napari_window(qtbot): img_mask[1, 100:200, 100:200] = 1 img_mask[0, 200:300, 200:300] = 3 img_mask[1, 200:300, 200:300] = 2 - #img3_mask = img2_mask.copy() + # img3_mask = img2_mask.copy() + + if not os.path.exists("train_data_path"): + os.mkdir("train_data_path") - if not os.path.exists('train_data_path'): - os.mkdir('train_data_path') + if not os.path.exists("in_prog"): + os.mkdir("in_prog") - if not os.path.exists('in_prog'): - os.mkdir('in_prog') + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img) - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img) - - imsave('eval_data_path/cat_seg.tiff', img_mask) + imsave("eval_data_path/cat_seg.tiff", img_mask) - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") application = Application( - BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", 7010, - os.path.join(os.getcwd(), 'eval_data_path'), - os.path.join(os.getcwd(), 'train_data_path'), - os.path.join(os.getcwd(), 'in_prog') + os.path.join(os.getcwd(), "eval_data_path"), + os.path.join(os.getcwd(), "train_data_path"), + os.path.join(os.getcwd(), "in_prog"), ) - application.cur_selected_img = 'cat.png' + application.cur_selected_img = "cat.png" application.cur_selected_path = application.eval_data_path widget = NapariWindow(application) - qtbot.addWidget(widget) - yield widget + qtbot.addWidget(widget) + yield widget widget.close() + def test_napari_window_initialization(napari_window): assert napari_window.viewer is not None assert napari_window.qctrl is not None assert napari_window.mask_choice_dropdown is not None + def test_on_add_to_curated_button_clicked(napari_window, monkeypatch): # Mock the create_warning_box method def mock_create_warning_box(message_text, message_title): - return None - monkeypatch.setattr(napari_window, 'create_warning_box', mock_create_warning_box) + return None + + monkeypatch.setattr(napari_window, "create_warning_box", mock_create_warning_box) - napari_window.app.cur_selected_img = 'cat.png' + napari_window.app.cur_selected_img = "cat.png" napari_window.app.cur_selected_path = napari_window.app.eval_data_path - napari_window.viewer.layers.selection.active.name = 'cat_seg' + napari_window.viewer.layers.selection.active.name = "cat_seg" # Simulate the button click napari_window.on_add_to_curated_button_clicked() - assert not os.path.exists('eval_data_path/cat.tiff') - assert not os.path.exists('eval_data_path/cat_seg.tiff') - assert os.path.exists('train_data_path/cat.png') - assert os.path.exists('train_data_path/cat_seg.tiff') - + assert not os.path.exists("eval_data_path/cat.tiff") + assert not os.path.exists("eval_data_path/cat_seg.tiff") + assert os.path.exists("train_data_path/cat.png") + assert os.path.exists("train_data_path/cat_seg.tiff") diff --git a/src/client/test/test_sync_src_dst.py b/src/client/test/test_sync_src_dst.py index ca652644..15ed79d3 100644 --- a/src/client/test/test_sync_src_dst.py +++ b/src/client/test/test_sync_src_dst.py @@ -1,26 +1,25 @@ import pytest -from dcp_client.utils.sync_src_dst import DataRSync +from dcp_client.utils.sync_src_dst import DataRSync @pytest.fixture def rsyncer(): - syncer = DataRSync(user_name="local", - host_name="local", - server_repo_path='.') + syncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") return syncer + def test_init(rsyncer): - assert rsyncer.user_name=="local" - assert rsyncer.host_name=="local" - assert rsyncer.server_repo_path=='.' + assert rsyncer.user_name == "local" + assert rsyncer.host_name == "local" + assert rsyncer.server_repo_path == "." + def test_first_sync_e(rsyncer): msg, _ = rsyncer.first_sync("eval_data_path") - assert msg=="Error" + assert msg == "Error" + def test_sync(rsyncer): msg, _ = rsyncer.sync("server", "client", "eval_data_path") - assert msg=="Error" - - + assert msg == "Error" diff --git a/src/client/test/test_utils.py b/src/client/test/test_utils.py index d09c8df6..88d2ce5b 100644 --- a/src/client/test/test_utils.py +++ b/src/client/test/test_utils.py @@ -3,35 +3,38 @@ sys.path.append("../") from dcp_client.utils import utils + def test_get_relative_path(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_relative_path(filepath)== 'something.txt' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_relative_path(filepath) == "something.txt" + def test_get_path_stem(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_stem(filepath)== 'something' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_stem(filepath) == "something" + def test_get_path_name(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_name(filepath)== 'something.txt' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_name(filepath) == "something.txt" + def test_get_path_parent(): - if sys.platform == 'win32' or sys.platform == 'cygwin': - filepath = '\\here\\we\\are\\testing\\something.txt' - assert utils.get_path_parent(filepath)== '\\here\\we\\are\\testing' + if sys.platform == "win32" or sys.platform == "cygwin": + filepath = "\\here\\we\\are\\testing\\something.txt" + assert utils.get_path_parent(filepath) == "\\here\\we\\are\\testing" else: - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_parent(filepath)== '/here/we/are/testing' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_parent(filepath) == "/here/we/are/testing" + def test_join_path(): - if sys.platform == 'win32' or sys.platform == 'cygwin': - filepath = '\\here\\we\\are\\testing\\something.txt' - path1 = '\\here\\we\\are\\testing' - path2 = 'something.txt' + if sys.platform == "win32" or sys.platform == "cygwin": + filepath = "\\here\\we\\are\\testing\\something.txt" + path1 = "\\here\\we\\are\\testing" + path2 = "something.txt" else: - filepath = '/here/we/are/testing/something.txt' - path1 = '/here/we/are/testing' - path2 = 'something.txt' + filepath = "/here/we/are/testing/something.txt" + path1 = "/here/we/are/testing" + path2 = "something.txt" assert utils.join_path(path1, path2) == filepath - - diff --git a/src/client/test/test_welcome_window.py b/src/client/test/test_welcome_window.py index 0d842c7f..854b12ca 100644 --- a/src/client/test/test_welcome_window.py +++ b/src/client/test/test_welcome_window.py @@ -1,6 +1,7 @@ import pytest import sys -sys.path.append('../') + +sys.path.append("../") from PyQt5.QtCore import Qt from PyQt5.QtWidgets import QMessageBox @@ -12,39 +13,48 @@ from dcp_client.utils.sync_src_dst import DataRSync from dcp_client.utils import settings + @pytest.fixture def setup_global_variable(): settings.accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") yield settings.accepted_types + @pytest.fixture def app(qtbot): - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + application = Application( + BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010 + ) # Create an instance of WelcomeWindow # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() + @pytest.fixture def app_remote(qtbot): - rsyncer = DataRSync(user_name="remote", host_name="remote", server_repo_path='.') - application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + rsyncer = DataRSync(user_name="remote", host_name="remote", server_repo_path=".") + application = Application( + BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010 + ) # Create an instance of WelcomeWindow # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() + def test_welcome_window_initialization(app): assert app.title == "Welcome to Helmholtz AI Data-Centric Tool" assert app.val_textbox.text() == "" assert app.inprogr_textbox.text() == "" assert app.train_textbox.text() == "" + def test_warning_for_same_paths(qtbot, app, monkeypatch): app.app.eval_data_path = "/same/path" app.app.train_data_path = "/same/path" @@ -54,32 +64,43 @@ def test_warning_for_same_paths(qtbot, app, monkeypatch): def custom_exec(self): return QMessageBox.Ok - monkeypatch.setattr(QMessageBox, 'exec', custom_exec) - qtbot.mouseClick(app.start_button, Qt.LeftButton) + monkeypatch.setattr(QMessageBox, "exec", custom_exec) + qtbot.mouseClick(app.start_button, Qt.LeftButton) assert app.create_warning_box assert app.message_text == "All directory names must be distinct." + def test_on_text_changed(qtbot, app): app.app.train_data_path = "/initial/train/path" app.app.eval_data_path = "/initial/eval/path" app.app.inprogr_data_path = "/initial/inprogress/path" - app.on_text_changed(field_obj=app.train_textbox, field_name="train", text="/new/train/path") + app.on_text_changed( + field_obj=app.train_textbox, field_name="train", text="/new/train/path" + ) assert app.app.train_data_path == "/new/train/path" - app.on_text_changed(field_obj=app.val_textbox, field_name="eval", text="/new/eval/path") + app.on_text_changed( + field_obj=app.val_textbox, field_name="eval", text="/new/eval/path" + ) assert app.app.eval_data_path == "/new/eval/path" - app.on_text_changed(field_obj=app.inprogr_textbox, field_name="inprogress", text="/new/inprogress/path") + app.on_text_changed( + field_obj=app.inprogr_textbox, + field_name="inprogress", + text="/new/inprogress/path", + ) assert app.app.inprogr_data_path == "/new/inprogress/path" + def test_start_main_not_selected(qtbot, app): app.app.train_data_path = None app.app.eval_data_path = None app.sim = True qtbot.mouseClick(app.start_button, Qt.LeftButton) - assert not hasattr(app, 'mw') + assert not hasattr(app, "mw") + def test_start_main(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable @@ -93,11 +114,12 @@ def test_start_main(qtbot, app, setup_global_variable): # Simulate clicking the start button qtbot.mouseClick(app.start_button, Qt.LeftButton) # Check if the main window is created - #assert qtbot.waitUntil(lambda: hasattr(app, 'mw'), timeout=1000) - assert hasattr(app, 'mw') + # assert qtbot.waitUntil(lambda: hasattr(app, 'mw'), timeout=1000) + assert hasattr(app, "mw") # Check if the WelcomeWindow is hidden assert app.isHidden() + def test_start_upload_and_main(qtbot, app_remote, setup_global_variable, monkeypatch): settings.accepted_types = setup_global_variable app_remote.app.eval_data_path = "/path/to/eval" @@ -107,15 +129,15 @@ def test_start_upload_and_main(qtbot, app_remote, setup_global_variable, monkeyp def custom_exec(self): return QMessageBox.Ok - monkeypatch.setattr(QMessageBox, 'exec', custom_exec) - qtbot.mouseClick(app_remote.start_button, Qt.LeftButton) + monkeypatch.setattr(QMessageBox, "exec", custom_exec) + qtbot.mouseClick(app_remote.start_button, Qt.LeftButton) # should close because error on upload! - assert app_remote.done_upload==False + assert app_remote.done_upload == False assert not app_remote.isVisible() - assert not hasattr(app_remote, 'mw') - + assert not hasattr(app_remote, "mw") + -'''' +"""' # TODO wait for github respose def test_browse_eval_clicked(qtbot, app, monkeypatch): # Mock the QFileDialog so that it immediately returns a directory @@ -162,4 +184,4 @@ def test_browse_inprogr_clicked(qtbot, app): # Check if the textbox is updated with the selected path assert app.inprogr_textbox.text() == app.app.inprogr_data_path -''' \ No newline at end of file +""" diff --git a/src/server/MANIFEST.in b/src/server/MANIFEST.in new file mode 100644 index 00000000..ffd67494 --- /dev/null +++ b/src/server/MANIFEST.in @@ -0,0 +1 @@ +include dcp_server/*.yaml \ No newline at end of file diff --git a/src/server/README.md b/src/server/README.md index 0b9fb5e7..4c19dadf 100644 --- a/src/server/README.md +++ b/src/server/README.md @@ -3,12 +3,19 @@ The server of our data centric platform for microscopy imaging. ![stability-wip](https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg) +[![Documentation Status](https://readthedocs.org/projects/data-centric-platform/badge/?version=latest)](https://data-centric-platform.readthedocs.io/en/latest/?badge=latest) The client and server communicate via the [bentoml](https://www.bentoml.com/?gclid=Cj0KCQiApKagBhC1ARIsAFc7Mc6iqOLi2OcLtqMbGx1KrFjtLUEZ-bhnqlT2zWREE0x7JImhtNmKlFEaAvSSEALw_wcB) library. The client interacts with the server every time we run model inference or training, so the server should be running before starting the client. ## How to use? ### Installation +This has been tested on Python versions 3.9, 3.10 and 3.11 on latest versions of Windows, Ubuntu and MacOS. In your dedicated environment run: +``` +pip install dcp_server +``` + +### Installation for developers Before starting make sure you have navigated to ```data-centric-platform/src/server```. All future steps expect you are in the server directory. In your dedicated environment run: ``` pip install -e . @@ -21,61 +28,5 @@ python dcp_server/main.py ``` Once the server is running, you can verify it is working by visiting http://localhost:7010/ in your web browser. -## Customization (for developers) - -All service configurations are set in the _config.cfg_ file. Please, obey the [formal JSON format](https://www.json.org/json-en.html). - -The config file has to have the five main parts. All the ```marked``` arguments are mandatory: - - - ``` setup ``` - - ```segmentation ``` - segmentation type from the segmentationclasses.py. Currently, only **GeneralSegmentation** is available (MitoProjectSegmentation and GFPProjectSegmentation are stale). - - ```accepted_types``` - types of images currently accepted for the analysis - - ```seg_name_string``` - end string for masks to run on (All the segmentations of the image should contain this string - used to save and search for segmentations of the images) -- ```service``` - - ```model_to_use``` - name of the model class from the models.py you want to use. Currently, available models are: - - **CustomCellposeModel**: Inherits [CellposeModel](https://cellpose.readthedocs.io/en/latest/api.html#cellposemodel) class - - **CellposePatchCNN**: Includes a segmentor and a clasifier. Currently segmentor can only be ```CustomCellposeModel```, and classifier is ```CellClassifierFCNN```. The model sequentially runs the segmentor and then classifier, on patches of the objects to classify them. - - ```save_model_path``` - name for the trained model which will be saved after calling the (re)train from service - is saved under ```bentoml/models``` - - ```runner_name``` - name of the runner for the bentoml service - - ```service_name``` - name for the bentoml service - - ```port``` - on which port to start the service -- ```model``` - configuration for the model instatiation. Here, pass any arguments you need or want to change. Take care that the names of the arguments are the same as of original model class' _init()_ function! - - ```segmentor```: model configuration for the segmentor. Currently takes argumnets used in the init of CellposeModel, see [here](https://cellpose.readthedocs.io/en/latest/api.html#cellposemodel). - - ```classifier```: model configuration for classifier, see _init()_ of ```CellClassifierFCNN``` -- ```train``` - configuration for the model training. Take care that the names of the arguments are the same as of original model's _train()_ function! - - ```segmentor```: If using cellpose - the _train()_ function arguments can be found [here](https://cellpose.readthedocs.io/en/latest/api.html#id7). Here, pass any arguments you need or want to change or leave empty {}, then default arguments will be used. - - ```classifier```: train configuration for classifier, see _train()_ of ```CellClassifierFCNN``` -- ```eval``` - configuration for the model evaluation.. Take care that the names of the arguments are the same as of original model's _eval()_ function! - - ```segmentor```: If using cellpose - the _eval()_ function arguments can be found [here](https://cellpose.readthedocs.io/en/latest/api.html#id3). Here, pass any arguments you need or want to change or leave empty {}, then default arguments will be used. - - ```classifier```: train configuration for classifier, see _eval()_ of ```CellClassifierFCNN```. - - ```mask_channel_axis```: If a multi-class instance segmentation model has been used, then the masks returned by the model should have two channels, one for the instance segmentation results and one indicating the obects class. This variable indicated at which dim the channel axis should be stored. Currently should be kept at 0, as this is the only way the masks can be visualised correcly by napari in the client. - -To make it easier for you we provide you with two config files: ```config.cfg``` is set up to work for a panoptic segmentation task, while ```config_instance.cfg``` for instance segmentation. Make sure to rename the config you wish to use to ```config.cfg```. The default is panoptic segmentation. - -## Models -The current models are currently integrated into DCP: -* CellPose --> for instance segmentation tasks -* CellposePatchCNN --> for panoptic segmentation tasks: includes the Cellpose model for instance segmentation followed by a patch wise CNN model on the predicted instances for obtaining class labels - -## Running with Docker [DO NOT USE UNTIL ISSUE IS SOLVED] - -### Docker --> Currently doesn't work for generate labels? - -#### Docker-Compose -``` -docker compose up -``` -#### Docker Non-Interactively -``` -docker build -t dcp-server . -docker run -p 7010:7010 -it dcp-server -``` - -#### Docker Interactively -``` -docker build -t dcp-server . -docker run -it dcp-server bash -bentoml serve service:svc --reload --port=7010 -``` - - +## Want to know more? +Visit our [documentation](https://data-centric-platform.readthedocs.io/en/latest/dcp_server_installation.html) for more information on server configurations and available models. diff --git a/src/server/dcp_server/__init__.py b/src/server/dcp_server/__init__.py index e69de29b..a125355e 100644 --- a/src/server/dcp_server/__init__.py +++ b/src/server/dcp_server/__init__.py @@ -0,0 +1,24 @@ +""" +Overview of dcp_server Package +============================== + +The dcp_server package is structured to handle various server-side functionalities related model serving for segmentation and training. + +Submodules: +------------ + +dcp_server.models + Defines various models for cell classification and segmentation, including CellClassifierFCNN, CellClassifierShallowModel, CellposePatchCNN, CustomCellposeModel, and UNet. + These models handle tasks such as evaluation, forward pass, training, and updating configurations. + +dcp_server.segmentationclasses + Defines segmentation classes for specific projects, such as GFPProjectSegmentation, GeneralSegmentation, and MitoProjectSegmentation. + These classes contain methods for segmenting images and training models on images and masks. + +dcp_server.serviceclasses + Defines service classes, such as CustomBentoService and CustomRunnable, for serving the models with BentoML and handling computation on remote Python workers. + +dcp_server.utils + Provides various utility functions for dealing with image storage, image processing, feature extraction, file handling, configuration reading, and path manipulation. + +""" diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.yaml similarity index 65% rename from src/server/dcp_server/config.cfg rename to src/server/dcp_server/config.yaml index 24f44ea0..5652469a 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.yaml @@ -1,35 +1,40 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CellposeMultichannel", - "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], - "seg_name_string": "_seg" + "model_to_use": "Inst2MultiSeg" }, "service": { "runner_name": "bento_runner", - "bento_model_path": "cp-multi", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": "Cellpose", "segmentor": { "model_type": "cyto" }, + "classifier_name": "PatchClassifier", "classifier":{ - "model_class": "RandomForest", "in_channels": 1, "num_classes": 2, "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" + "black_bg": False, + "include_mask": True } }, "data": { - "data_root": "data" + "data_root": "data", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ @@ -39,12 +44,7 @@ "min_train_masks": 1 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, - "n_epochs": 2, + "n_epochs": 20, "lr": 0.001, "batch_size": 1, "optimizer": "Adam" @@ -59,10 +59,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } diff --git a/src/server/dcp_server/config_instance.cfg b/src/server/dcp_server/config_instance.yaml similarity index 72% rename from src/server/dcp_server/config_instance.cfg rename to src/server/dcp_server/config_instance.yaml index da9cfd84..db266da0 100644 --- a/src/server/dcp_server/config_instance.cfg +++ b/src/server/dcp_server/config_instance.yaml @@ -1,14 +1,12 @@ { "setup": { "segmentation": "GeneralSegmentation", - "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], - "seg_name_string": "_seg" + "model_to_use": "CustomCellpose" }, "service": { - "model_to_use": "CustomCellposeModel", - "save_model_path": "cells", - "runner_name": "cellpose_runner", + "runner_name": "bento_runner", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, @@ -20,12 +18,16 @@ }, "data": { - "data_root": "data" + "data_root": "data", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg", + "gray": True, + "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 10, + "n_epochs": 5, "channels": [0,0], "min_train_masks": 1 } diff --git a/src/server/dcp_server/config_semantic.yaml b/src/server/dcp_server/config_semantic.yaml new file mode 100644 index 00000000..e72459ac --- /dev/null +++ b/src/server/dcp_server/config_semantic.yaml @@ -0,0 +1,46 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "UNet" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "semantic-Unet", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "classifier":{ + "in_channels": 1, + "num_classes": 2, + "features":[64,128,256,512] + } + }, + + "data": { + "data_root": "data", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg", + "gray": True, + "rescale": True + }, + + "train":{ + "classifier":{ + "n_epochs": 2, + "lr": 0.001, + "batch_size": 1, + "optimizer": "Adam" + } + }, + + "eval":{ + "classifier": { + + }, + "compute_instance": True, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py deleted file mode 100644 index f4fbe8ef..00000000 --- a/src/server/dcp_server/fsimagestorage.py +++ /dev/null @@ -1,220 +0,0 @@ -import os -import numpy as np -from skimage.io import imread, imsave -from skimage.transform import resize, rescale - -from dcp_server import utils - -# Import configuration -setup_config = utils.read_config('setup', config_path = 'config.cfg') - -class FilesystemImageStorage(): - """Class used to deal with everything related to image storing and processing - loading, saving, transforming... - """ - def __init__(self, data_root, model_used): - self.root_dir = data_root - self.model_used = model_used - - def load_image(self, cur_selected_img, is_gray=True): - """Load the image (using skiimage) - - :param cur_selected_img: full path of the image that needs to be loaded - :type cur_selected_img: str - :return: loaded image - :rtype: ndarray - """ - try: - return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=is_gray) - except ValueError: return None - - def save_image(self, to_save_path, img): - """Save given image (using skiimage) - - :param to_save_path: full path to the directory that the image needs to be save into (use also image name in the path, eg. '/users/new_image.png') - :type to_save_path: str - :param img: image you wish to save - :type img: ndarray - """ - imsave(os.path.join(self.root_dir, to_save_path), img) - - def search_images(self, directory): - """Get a list of full paths of the images in the directory - - :param directory: path to the directory to search images in - :type directory: str - :return: list of image paths found in the directory (only image types that are supported - see config.cfg 'setup' section) - :rtype: list - """ - # Take all segmentations of the image from the current directory: - directory = os.path.join(self.root_dir, directory) - seg_files = [file_name for file_name in os.listdir(directory) if setup_config['seg_name_string'] in file_name] - # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted - image_files = [os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) and (utils.get_file_extension(file_name) in setup_config['accepted_types'])] - return image_files - - def search_segs(self, cur_selected_img): - """Returns a list of full paths of segmentations for an image - - :param cur_selected_img: full path of the image which segmentations we need to find - :type cur_selected_img: str - :return: list segmentation paths for the given image - :rtype: list - """ - # Check the directory the image was selected from: - img_directory = utils.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) - # Take all segmentations of the image from the current directory: - search_string = utils.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] - #seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] - # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) - - seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] - - return seg_files - - def get_image_seg_pairs(self, directory): - """Get pairs of (image, image_seg) - Used, e.g., in training to create training data-training labels pairs - - :param directory: path to the directory to search images and segmentations in - :type directory: str - :return: list of tuple pairs (image, image_seg) - :rtype: list - """ - image_files = self.search_images(os.path.join(self.root_dir, directory)) - seg_files = [] - for image in image_files: - seg = self.search_segs(image) - #TODO - the search seg returns all the segs, but here we need only one, hence the seg[0]. Check if it is from training path? - seg_files.append(seg[0]) - return list(zip(image_files, seg_files)) - - def get_unsupported_files(self, directory): - """Get unsupported files found in the given directory - - :param directory: direcory path to search for files in - :type directory: str - :return: list of unsupported files - :rtype: list - """ - - return [file_name for file_name in os.listdir(os.path.join(self.root_dir, directory)) - if not file_name.startswith('.') and utils.get_file_extension(file_name) not in setup_config['accepted_types']] - - def get_image_size_properties(self, img, file_extension): - """Get properties of the image size - - :param img: image (numpy array) - :type img: ndarray - :param file_extension: file extension of the image as saved in the directory - :type file_extension: str - :return: size properties: - - height - - width - - z_axis - - """ - - orig_size = img.shape - # png and jpeg will be RGB by default and 2D - # tif can be grayscale 2D or 3D [Z, H, W] - # image channels have already been removed in imread with is_gray=True - if file_extension in (".jpg", ".jpeg", ".png"): - height, width = orig_size[0], orig_size[1] - z_axis = None - elif file_extension in (".tiff", ".tif") and len(orig_size)==2: - height, width = orig_size[0], orig_size[1] - z_axis = None - # if we have 3 dimensions the [Z, H, W] - elif file_extension in (".tiff", ".tif") and len(orig_size)==3: - print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') - height, width = orig_size[1], orig_size[2] - z_axis = 0 - else: - print('File not currently supported. See documentation for accepted types') - - return height, width, z_axis - - def rescale_image(self, img, height, width, channel_ax=None, order=2): - """rescale image - - :param img: image - :type img: ndarray - :param height: height of the image - :type height: int - :param width: width of the image - :type width: int - :param channel_ax: channel axis - :type channel_ax: int - :return: rescaled image - :rtype: ndarray - """ - if self.model_used == "UNet": - height_pad = (height//16 + 1)*16 - height - width_pad = (width//16 + 1)*16 - width - return np.pad(img, ((0, height_pad),(0, width_pad))) - else: - # Cellpose segmentation runs best with 512 size? TODO: check - max_dim = max(height, width) - rescale_factor = max_dim/512 - return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax) - - def resize_mask(self, mask, height, width, channel_ax=None, order=2): - """resize the mask so it matches the original image size - - :param mask: image - :type mask: ndarray - :param height: height of the image - :type height: int - :param width: width of the image - :type width: int - :param order: from scikit-image - the order of the spline interpolation, default is 0 if image.dtype is bool and 1 otherwise. - :type order: int - :return: resized image - :rtype: ndarray - """ - - if self.model_used == "UNet": - # we assume an order C, H, W - if channel_ax is not None and channel_ax==0: - height_pad = mask.shape[1] - height - width_pad = mask.shape[2]- width - return mask[:, :-height_pad, :-width_pad] - elif channel_ax is not None and channel_ax==2: - height_pad = mask.shape[0] - height - width_pad = mask.shape[1]- width - return mask[:-height_pad, :-width_pad, :] - elif channel_ax is not None and channel_ax==1: - height_pad = mask.shape[2] - height - width_pad = mask.shape[0]- width - return mask[:-width_pad, :, :-height_pad] - - else: - if channel_ax is not None: - n_channel_dim = mask.shape[channel_ax] - output_size = [height, width] - output_size.insert(channel_ax, n_channel_dim) - else: output_size = [height, width] - return resize(mask, output_size, order=order) - - def prepare_images_and_masks_for_training(self, train_img_mask_pairs): - """Image and mask processing for training. - - :param train_img_mask_pairs: list pairs of (image, image_seg) (as returned by get_image_seg_pairs() function) - :type train_img_mask_pairs: list - :return: lists of processed images and masks - :rtype: list, list - """ - imgs=[] - masks=[] - for img_file, mask_file in train_img_mask_pairs: - img = self.load_image(img_file) - mask = imread(mask_file) - if self.model_used == "UNet": - # Unet only accepts image sizes divisable by 16 - height_pad = (img.shape[0]//16 + 1)*16 - img.shape[0] - width_pad = (img.shape[1]//16 + 1)*16 - img.shape[1] - img = np.pad(img, ((0, height_pad),(0, width_pad))) - mask = np.pad(mask, ((0, 0), (0, height_pad),(0, width_pad))) - imgs.append(img) - masks.append(mask) - return imgs, masks \ No newline at end of file diff --git a/src/server/dcp_server/main.py b/src/server/dcp_server/main.py index 9add94b8..9c149b5b 100644 --- a/src/server/dcp_server/main.py +++ b/src/server/dcp_server/main.py @@ -1,11 +1,14 @@ -import subprocess from os import path import sys -from utils import read_config +import subprocess + +from dcp_server.utils.helpers import read_config + -def main(): - '''entry point to bentoml - ''' +def main() -> None: + """ + Contains main functionality related to the server. + """ # global config_path # args = sys.argv # if len(args) > 1: @@ -14,21 +17,24 @@ def main(): # else: # config_path = 'config.cfg' - local_path = path.join(__file__, '..') + local_path = path.join(__file__, "..") dir_name = path.dirname(path.abspath(sys.argv[0])) - service_config = read_config('service', config_path = path.join(dir_name, 'config.cfg')) - port = str(service_config['port']) + service_config = read_config( + "service", config_path=path.join(dir_name, "config.yaml") + ) + port = str(service_config["port"]) - subprocess.run([ - "bentoml", - "serve", - '--working-dir', - local_path, - "service:svc", - "--reload", - "--port="+port, - ]) - + subprocess.run( + [ + "bentoml", + "serve", + "--working-dir", + local_path, + "service:svc", + "--reload", + "--port=" + port, + ] + ) if __name__ == "__main__": diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py deleted file mode 100644 index 187f6395..00000000 --- a/src/server/dcp_server/models.py +++ /dev/null @@ -1,816 +0,0 @@ -from cellpose import models, utils -import torch -from torch import nn -from torch.optim import Adam -from torch.utils.data import TensorDataset, DataLoader -from torchmetrics import F1Score -from copy import deepcopy -from tqdm import tqdm -import numpy as np -from scipy.ndimage import label -from skimage.measure import label as label_mask - - -from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import f1_score, log_loss -from sklearn.exceptions import NotFittedError - -from cellpose.metrics import aggregated_jaccard_index -from cellpose.dynamics import labels_to_flows -#from segment_anything import SamPredictor, sam_model_registry -#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator - -from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, create_dataset_for_rf - -class CustomCellposeModel(models.CellposeModel, nn.Module): - """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing - additional attributes and methods needed for this project. - """ - def __init__(self, model_config, train_config, eval_config, model_name): - """Constructs all the necessary attributes for the CustomCellposeModel. - The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts. - Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted. - - :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization - :type model_config: dict - :param train_config: dictionary passed from the config file with all the arguments for training function - :type train_config: dict - :param eval_config: dictionary passed from the config file with all the arguments for eval function - :type eval_config: dict - """ - - # Initialize the cellpose model - #super().__init__(**model_config["segmentor"]) - nn.Module.__init__(self) - models.CellposeModel.__init__(self, **model_config["segmentor"]) - self.mkldnn = False # otherwise we get error with saving model - self.train_config = train_config - self.eval_config = eval_config - self.loss = 1e6 - self.model_name = model_name - - def update_configs(self, train_config, eval_config): - """Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def eval_all_outputs(self, img): - """Get all outputs of the model when running eval. - - :param img: Input image for segmentation. - :type img: numpy.ndarray - :return: Probability mask for the input image. - :rtype: numpy.ndarray - """ - - return super().eval(x=img, **self.eval_config["segmentor"]) - - def eval(self, img): - """Evaluate the model - find mask of the given image - Calls the original eval function. - - :param img: image to evaluate on - :type img: np.ndarray - :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. - :rtype: np.ndarray - """ - return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask - - def train(self, imgs, masks): - """Trains the given model - Calls the original train function. - - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] - """ - - if not isinstance(masks, np.ndarray): # TODO Remove: all these should be taken care of in fsimagestorage - masks = np.array(masks) - - if masks[0].shape[0] == 2: - masks = list(masks[:,0,...]) - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) - - # compute loss and metric - true_bin_masks = [mask>0 for mask in masks] # get binary masks - true_flows = labels_to_flows(masks) # get cellpose flows - # get predicted flows and cell probability - pred_masks = [] - pred_flows = [] - true_lbl = [] - for idx, img in enumerate(imgs): - mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"]) - pred_masks.append(mask) - pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack cell probability map, horizontal and vertical flow - true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]])) - - true_lbl = np.stack(true_lbl) - pred_flows=np.stack(pred_flows) - pred_flows = torch.from_numpy(pred_flows).float().to('cpu') - # compute loss, combination of mse for flows and bce for cell probability - self.loss = self.loss_fn(true_lbl, pred_flows) - self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - - def masks_to_outlines(self, mask): - """ get outlines of masks as a 0-1 array - Calls the original cellpose.utils.masks_to_outlines function - - :param mask: int, 2D or 3D array, mask of an image - :type mask: ndarray - :return: outlines - :rtype: ndarray - """ - return utils.masks_to_outlines(mask) #[True, False] outputs - - -class CellClassifierFCNN(nn.Module): - - """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - - """ - - def __init__(self, model_config, train_config, eval_config): - """Initialize the fully convolutional classifier. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - super().__init__() - - self.in_channels = model_config["classifier"].get("in_channels",1) - self.num_classes = model_config["classifier"].get("num_classes",3) - - self.train_config = train_config["classifier"] - self.eval_config = eval_config["classifier"] - - self.include_mask = model_config["classifier"]["include_mask"] - self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels - - self.layer1 = nn.Sequential( - nn.Conv2d(self.in_channels, 16, 3, 2, 5), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer2 = nn.Sequential( - nn.Conv2d(16, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer3 = nn.Sequential( - nn.Conv2d(64, 128, 3, 2, 4), - nn.BatchNorm2d(128), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - self.final_conv = nn.Conv2d(128, self.num_classes, 1) - self.pooling = nn.AdaptiveMaxPool2d(1) - - self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") - - def update_configs(self, train_config, eval_config): - """ - Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def forward(self, x): - """ Performs forward pass of the CellClassifierFCNN. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor after passing through the network. - :rtype: torch.Tensor - """ - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.final_conv(x) - x = self.pooling(x) - x = x.view(x.size(0), -1) - return x - - def train (self, imgs, labels): - """Trains the given model - - :param imgs: List of input images with shape (3, dx, dy). - :type imgs: List[np.ndarray[np.uint8]] - :param labels: List of classification labels. - :type labels: List[int] - """ - - lr = self.train_config['lr'] - epochs = self.train_config['n_epochs'] - batch_size = self.train_config['batch_size'] - # optimizer_class = self.train_config['optimizer'] - - # Convert input images and labels to tensors - - # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = torch.permute(imgs, (0, 3, 1, 2)) - # Your classification label mask - labels = torch.LongTensor([label for label in labels]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, labels) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - # TODO check if we should replace self.parameters with super.parameters() - - for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): - self.loss, self.metric = 0, 0 - for data in train_dataloader: - imgs, labels = data - - optimizer.zero_grad() - preds = self.forward(imgs) - - l = loss_fn(preds, labels) - l.backward() - optimizer.step() - self.loss += l.item() - - self.metric += self.metric_fn(preds, labels) - - self.loss /= len(train_dataloader) - self.metric /= len(train_dataloader) - - def eval(self, img): - """Evaluates the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: y_hat - predicted label. - :rtype: torch.Tensor - """ - # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) - # convert to tensor - img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) - preds = self.forward(img) - y_hat = torch.argmax(preds, 1) - return y_hat - - -class CellposePatchCNN(nn.Module): - """ - Cellpose & patches of cells and then cnn to classify each patch - """ - - def __init__(self, model_config, train_config, eval_config, model_name): - """Constructs all the necessary attributes for the CellposePatchCNN - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - super().__init__() - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.include_mask = self.model_config["classifier"]["include_mask"] - self.model_name = model_name - self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") - - # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - "Cellpose") - - if self.classifier_class == "FCNN": - self.classifier = CellClassifierFCNN(self.model_config, - self.train_config, - self.eval_config) - - elif self.classifier_class == "RandomForest": - self.classifier = CellClassifierShallowModel(self.model_config, - self.train_config, - self.eval_config) - # make sure include mask is set to False if we are using the random forest model - self.include_mask = False - - def update_configs(self, train_config, eval_config): - """Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def train(self, imgs, masks): - """Trains the given model. First trains the segmentor and then the clasiffier. - - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, - second channel classes, so [2, H, W] or [2, 3, H, W] for 3D - """ - # train cellpose - masks = np.array(masks) - masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks - self.segmentor.train(deepcopy(imgs), masks_instances) - # create patch dataset to train classifier - masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - patches, patch_masks, labels = create_patch_dataset(imgs, - masks_classes, - masks_instances, - noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], - include_mask = self.include_mask) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # train classifier - self.classifier.train(x, labels) - # and compute metric and loss - self.metric = (self.segmentor.metric + self.classifier.metric) / 2 - self.loss = (self.segmentor.loss + self.classifier.loss)/2 - - def eval(self, img): - """Evaluate the model on the provided image and return the final mask. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: Final mask containing instance mask and class masks. - :rtype: np.ndarray[np.uint16] - """ - # TBD we assume image is 2D [H, W] (see fsimage storage) - # The final mask which is returned should have - # first channel the output of cellpose and the rest are the class channels - with torch.no_grad(): - # get instance mask from segmentor - instance_mask = self.segmentor.eval(img) - # find coordinates of detected objects - class_mask = np.zeros(instance_mask.shape) - - max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] - if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) - noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] - - # get patches centered around detected objects - patches, patch_masks, instance_labels, _ = get_centered_patches(img, - instance_mask, - max_patch_size, - noise_intensity=noise_intensity, - include_mask=self.include_mask) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # loop over patches and create classification mask - for idx in range(len(x)): - patch_class = self.classifier.eval(x[idx]) - # Assign predicted class to corresponding location in final_mask - patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class - class_mask[instance_mask==instance_labels[idx]] = patch_class + 1 - # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW - - return final_mask - -class CellClassifierShallowModel: - """ - This class implements a shallow model for cell classification using scikit-learn. - """ - - def __init__(self, model_config, train_config, eval_config): - """Constructs all the necessary attributes for the CellClassifierShallowModel - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - - self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params - - - def train(self, X_train, y_train): - """Trains the model using the provided training data. - - :param X_train: Features of the training data. - :type X_train: numpy.ndarray - :param y_train: Labels of the training data. - :type y_train: numpy.ndarray - """ - - self.model.fit(X_train,y_train) - - y_hat = self.model.predict(X_train) - y_hat_proba = self.model.predict_proba(X_train) - - self.metric = f1_score(y_train, y_hat, average='micro') - # Binary Cross Entrop Loss - self.loss = log_loss(y_train, y_hat_proba) - - - def eval(self, X_test): - """Evaluates the model on the provided test data. - - :param X_test: Features of the test data. - :type X_test: numpy.ndarray - :return: y_hat - predicted labels. - :rtype: numpy.ndarray - """ - - X_test = X_test.reshape(1,-1) - - try: - y_hat = self.model.predict(X_test) - except NotFittedError as e: - y_hat = np.zeros(X_test.shape[0]) - - return y_hat - -class UNet(nn.Module): - - """ - Unet is a convolutional neural network architecture for semantic segmentation. - - :param in_channels: Number of input channels (default: 3). - :type in_channels: int - :param out_channels: Number of output channels (default: 4). - :type out_channels: int - :param features: List of feature channels for each encoder level (default: [64,128,256,512]). - :type features: list - """ - - class DoubleConv(nn.Module): - """ - DoubleConv module consists of two consecutive convolutional layers with - batch normalization and ReLU activation functions. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize DoubleConv module. - - :param in_channels: Number of input channels. - :type in_channels: int - :param out_channels: Number of output channels. - :type out_channels: int - """ - - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), - nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), - ) - - def forward(self, x): - """Forward pass through the DoubleConv module. - - :param x: Input tensor. - :type x: torch.Tensor - """ - return self.conv(x) - - - def __init__(self, model_config, train_config, eval_config, model_name): - """Constructs all the necessary attributes for the UNet model. - - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - super().__init__() - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - ''' - self.in_channels = self.model_config["unet"]["in_channels"] - self.out_channels = self.model_config["unet"]["out_channels"] - self.features = self.model_config["unet"]["features"] - ''' - self.in_channels = self.model_config["classifier"]["in_channels"] - self.out_channels = self.model_config["classifier"]["num_classes"] + 1 - self.features = self.model_config["classifier"]["features"] - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - - # Encoder - for feature in self.features: - self.encoder.append( - UNet.DoubleConv(self.in_channels, feature) - ) - self.in_channels = feature - - # Decoder - for feature in self.features[::-1]: - self.decoder.append( - nn.ConvTranspose2d( - feature*2, feature, kernel_size=2, stride=2 - ) - ) - self.decoder.append( - UNet.DoubleConv(feature*2, feature) - ) - - self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2) - self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) - - def forward(self, x): - """ - Forward pass of the UNet model. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor. - :rtype: torch.Tensor - """ - skip_connections = [] - for encoder in self.encoder: - x = encoder(x) - skip_connections.append(x) - x = self.pool(x) - - x = self.bottle_neck(x) - skip_connections = skip_connections[::-1] - - for i in np.arange(len(self.decoder), step=2): - x = self.decoder[i](x) - skip_connection = skip_connections[i//2] - concatenate_skip = torch.cat((skip_connection, x), dim=1) - x = self.decoder[i+1](concatenate_skip) - - return self.output_conv(x) - - def train(self, imgs, masks): - """ - Trains the UNet model using the provided images and masks. - - :param imgs: Input images for training. - :type imgs: list[numpy.ndarray] - :param masks: Masks corresponding to the input images. - :type masks: list[numpy.ndarray] - """ - - lr = self.train_config["classifier"]['lr'] - epochs = self.train_config["classifier"]['n_epochs'] - batch_size = self.train_config["classifier"]['batch_size'] - - # Convert input images and labels to tensors - # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs - - # Classification label mask - masks = np.array(masks) - masks = torch.stack([torch.from_numpy(mask[1].astype(np.int16)) for mask in masks]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, masks) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) - - for _ in tqdm(range(epochs), desc="Running UNet training"): - - self.loss = 0 - - for imgs, masks in train_dataloader: - imgs = imgs.float() - masks = masks.long() - - #forward path - preds = self.forward(imgs) - loss = loss_fn(preds, masks) - - #backward path - optimizer.zero_grad() - loss.backward() - optimizer.step() - - self.loss += loss.detach().mean().item() - - self.loss /= len(train_dataloader) - - def eval(self, img): - """ - Evaluate the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: predicted mask consists of instance and class masks - :rtype: numpy.ndarray - """ - with torch.no_grad(): - # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) - img = torch.from_numpy(img).float().unsqueeze(0) - - img = img.unsqueeze(1) if img.ndim == 3 else img - - preds = self.forward(img) - class_mask = torch.argmax(preds, 1).numpy()[0] - - instance_mask = label((class_mask > 0).astype(int))[0] - - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) - - return final_mask - -class CellposeMultichannel(): - ''' - Multichannel image segmentation model. - Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type. - ''' - - def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"): - """Constructs all the necessary attributes for the CellposeMultichannel model. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - self.num_of_channels = self.model_config["classifier"]["num_classes"] - - self.cellpose_models = [ - CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - self.model_name - ) for _ in range(self.num_of_channels) - ] - - def train(self, imgs, masks): - """ - Train the model on the provided images and masks. - - :param imgs: Input images for training. - :type imgs: list[numpy.ndarray] - :param masks: Masks corresponding to the input images. - :type masks: list[numpy.ndarray] - """ - - for i in range(self.num_of_channels): - - masks_class = [] - - for mask in masks: - mask_class = mask.copy() - # set all instances in the instance mask not corresponding to the class in question to zero - mask_class[0][mask_class[1]!=(i+1)] = 0 - masks_class.append(mask_class) - - self.cellpose_models[i].train(imgs, masks_class) - - self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) - self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)]) - - - def eval(self, img): - """Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of - each object is assigned based on majority voting between the models. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: predicted mask consists of instance and class masks - :rtype: numpy.ndarray - """ - - instance_masks, class_masks, model_confidences = [], [], [] - - for i in range(self.num_of_channels): - # get the instance mask and pixel-wise cell probability mask - instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) - confidence = probs[2] - # assign the appropriate class to all objects detected by this model - class_mask = np.zeros_like(instance_mask) - class_mask[instance_mask>0]=(i + 1) - - instance_masks.append(instance_mask) - class_masks.append(class_mask) - model_confidences.append(confidence) - # merge the outputs of the different models using the pixel-wise cell probability mask - merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences) - # set all connected components to the same label in the instance mask - instance_mask = label_mask(merged_mask_instances>0) - # and set the class with the most pixels to that object - for inst_id in np.unique(instance_mask)[1:]: - where_inst_id = np.where(instance_mask==inst_id) - vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) - class_mask[where_inst_id] = vals[np.argmax(counts)] - # take the final mask by stancking instance and class mask - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) - - return final_mask - - def merge_masks(self, inst_masks, class_masks, probabilities): - """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model - with the maximum probability is selected for each pixel. - - :param inst_masks: List of predicted instance masks from each model. - :type inst_masks: List[np.array] - :param class_masks: List of corresponding class masks from each model. - :type class_masks: List[np.array] - :param probabilities: List of corresponding pixel-wise cell probability masks - :type probabilities: List[np.array] - :return: A tuple containing the following elements: - - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected - - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected - :rtype: tuple - """ - # Convert lists to numpy arrays - inst_masks = np.array(inst_masks) - class_masks = np.array(class_masks) - probabilities = np.array(probabilities) - - # Find the index of the mask with the maximum probability for each pixel - max_prob_indices = np.argmax(probabilities, axis=0) - - # Use the index to select the corresponding mask for each pixel - final_mask_inst = inst_masks[max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])] - final_mask_class = class_masks[max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])] - - return final_mask_inst, final_mask_class - - - - - - -# class CustomSAMModel(): -# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb -# def __init__(self): -# pass diff --git a/src/server/dcp_server/models/__init__.py b/src/server/dcp_server/models/__init__.py new file mode 100644 index 00000000..eba3d089 --- /dev/null +++ b/src/server/dcp_server/models/__init__.py @@ -0,0 +1,8 @@ +# dcp_server.models/__init__.py + +from .custom_cellpose import CustomCellpose +from .inst_to_multi_seg import Inst2MultiSeg +from .multicellpose import MultiCellpose +from .unet import UNet + +__all__ = ["CustomCellpose", "Inst2MultiSeg", "MultiCellpose", "UNet"] diff --git a/src/server/dcp_server/models/classifiers.py b/src/server/dcp_server/models/classifiers.py new file mode 100644 index 00000000..43fed489 --- /dev/null +++ b/src/server/dcp_server/models/classifiers.py @@ -0,0 +1,233 @@ +from tqdm import tqdm +from typing import List +import numpy as np + +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import F1Score + +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import f1_score, log_loss +from sklearn.exceptions import NotFittedError + + +class PatchClassifier(nn.Module): + """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Initialize the fully convolutional classifier. + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configuration. + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + super().__init__() + + self.model_name = model_name + self.model_config = model_config["classifier"] + self.data_config = data_config + self.train_config = train_config["classifier"] + self.eval_config = eval_config["classifier"] + + self.build_model() + + def train(self, imgs: List[np.ndarray], labels: List[np.ndarray]) -> None: + """Trains the given model + + :param imgs: List of input images with shape (3, dx, dy). + :type imgs: List[np.ndarray[np.uint8]] + :param labels: List of classification labels. + :type labels: List[int] + """ + + # Convert input images and labels to tensors + imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) + imgs = torch.permute(imgs, (0, 3, 1, 2)) + # Your classification label mask + labels = torch.LongTensor([label for label in labels]) + + # Create a training dataset and dataloader + train_dataloader = DataLoader( + TensorDataset(imgs, labels), batch_size=self.train_config["batch_size"] + ) + + loss_fn = nn.CrossEntropyLoss() + optimizer = Adam(params=self.parameters(), lr=self.train_config["lr"]) + # optimizer_class = self.train_config["optimizer"] + # eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') + + # TODO check if we should replace self.parameters with super.parameters() + for _ in tqdm( + range(self.train_config["n_epochs"]), + desc="Running PatchClassifier training", + ): + + self.loss, self.metric = 0, 0 + for data in train_dataloader: + imgs, labels = data + + optimizer.zero_grad() + preds = self.forward(imgs) + + l = loss_fn(preds, labels) + l.backward() + optimizer.step() + self.loss += l.item() + + self.metric += self.metric_fn(preds, labels) + + self.loss /= len(train_dataloader) + self.metric /= len(train_dataloader) + + def eval(self, img: np.ndarray) -> torch.Tensor: + """Evaluates the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: y_hat - predicted label. + :rtype: torch.Tensor + """ + # convert to tensor + img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze( + 0 + ) + preds = self.forward(img) + y_hat = torch.argmax(preds, 1) + return y_hat + + def build_model(self) -> None: + """Builds the PatchClassifer.""" + in_channels = self.model_config["in_channels"] + in_channels = ( + in_channels + 1 if self.model_config["include_mask"] else in_channels + ) + + self.layer1 = nn.Sequential( + nn.Conv2d(in_channels, 16, 3, 2, 5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer2 = nn.Sequential( + nn.Conv2d(16, 64, 3, 1, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer3 = nn.Sequential( + nn.Conv2d(64, 128, 3, 2, 4), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + self.final_conv = nn.Conv2d(128, self.model_config["num_classes"], 1) + self.pooling = nn.AdaptiveMaxPool2d(1) + + self.metric_fn = F1Score( + num_classes=self.model_config["num_classes"], task="multiclass" + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of the PatchClassifier. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor after passing through the network. + :rtype: torch.Tensor + """ + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.final_conv(x) + x = self.pooling(x) + x = x.view(x.size(0), -1) + return x + + +class FeatureClassifier: + """This class implements a shallow model for cell classification using scikit-learn.""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the FeatureClassifier + + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configuration. + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + + self.model_name = model_name + self.model_config = model_config["classifier"] # use for initialising model + # self.data_config = data_config + # self.train_config = train_config + # self.eval_config = eval_config + + self.model = RandomForestClassifier( + **self.model_config + ) # TODO chnage config so RandomForestClassifier accepts input params + + def train(self, X_train: List[np.ndarray], y_train: List[np.ndarray]) -> None: + """Trains the model using the provided training data. + + :param X_train: Features of the training data. + :type X_train: numpy.ndarray + :param y_train: Labels of the training data. + :type y_train: numpy.ndarray + """ + self.model.fit(X_train, y_train) + + y_hat = self.model.predict(X_train) + y_hat_proba = self.model.predict_proba(X_train) + + # Binary Cross Entrop Loss + self.loss = log_loss(y_train, y_hat_proba) + self.metric = f1_score(y_train, y_hat, average="micro") + + def eval(self, X_test: np.ndarray) -> np.ndarray: + """Evaluates the model on the provided test data. + + :param X_test: Features of the test data. + :type X_test: numpy.ndarray + :return: y_hat - predicted labels. + :rtype: numpy.ndarray + """ + + X_test = X_test.reshape(1, -1) + + try: + y_hat = self.model.predict(X_test) + except NotFittedError as e: + y_hat = np.zeros(X_test.shape[0]) + + return y_hat diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py new file mode 100644 index 00000000..b41d04bb --- /dev/null +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -0,0 +1,150 @@ +from copy import deepcopy +from typing import List +import numpy as np + +import torch +from torch import nn + +from cellpose import models, utils +from cellpose.metrics import aggregated_jaccard_index +from cellpose.dynamics import labels_to_flows + +from .model import Model + + +class CustomCellpose(models.CellposeModel, Model): + """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing + additional attributes and methods needed for this project. + """ + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the CustomCellpose. + The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts. + Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted. + + :param model_name: The name of the current model + :type model_name: str + :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization + :type model_config: dict + :param data_config: dictionary passed from the config file with all the data configurations + :type data_config: dict + :param train_config: dictionary passed from the config file with all the arguments for training function + :type train_config: dict + :param eval_config: dictionary passed from the config file with all the arguments for eval function + :type eval_config: dict + """ + + # Initialize the cellpose model + # super().__init__(**model_config["segmentor"]) + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + models.CellposeModel.__init__(self, **model_config["segmentor"]) + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + self.model_name = model_name + self.mkldnn = False # otherwise we get error with saving model + self.loss = 1e6 + self.metric = 0 + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """Trains the given model + Calls the original train function. + + :param imgs: images to train on (training data) + :type imgs: List[np.ndarray] + :param masks: masks of the given images (training labels) + :type masks: List[np.ndarray] + """ + if self.train_config["segmentor"]["n_epochs"] == 0: + return + super().train( + train_data=deepcopy(imgs), # Cellpose changes the images + train_labels=masks, + **self.train_config["segmentor"] + ) + pred_masks, pred_flows, true_flows = self.compute_masks_flows(imgs, masks) + # get loss, combination of mse for flows and bce for cell probability + self.loss = self.loss_fn(true_flows, pred_flows) + self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model - find mask of the given image + Calls the original eval function. + + :param img: image to evaluate on + :type img: np.ndarray + :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. + :rtype: np.ndarray + """ + # 0 to take only mask - inline with other models eval should always return the final mask + return super().eval(x=img, **self.eval_config["segmentor"])[0] + + def eval_all_outputs(self, img: np.ndarray) -> tuple: + """Get all outputs of the model when running eval. + + :param img: Input image for segmentation. + :type img: numpy.ndarray + :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details. + :rtype: tuple + """ + + return super().eval(x=img, **self.eval_config["segmentor"]) + + # I introduced typing here as suggest by the docstring + def compute_masks_flows( + self, imgs: List[np.ndarray], masks: List[np.ndarray] + ) -> tuple: + """Computes instance, binary mask and flows in x and y - needed for loss and metric computations + + :param imgs: images to train on (training data) + :type imgs: List[np.ndarray] + :param masks: masks of the given images (training labels) + :type masks: List[np.ndarray] + :return: A tuple containing the following elements: + - pred_masks List [np.ndarray]: A list of predicted instance masks + - pred_flows (torch.Tensor): A tensor holding the stacked predicted cell probability map, horizontal and vertical flows for all images + - true_lbl (np.ndarray): A numpy array holding the stacked true binary mask, horizontal and vertical flows for all images + :rtype: tuple + """ + # compute for loss and metric + true_bin_masks = [mask > 0 for mask in masks] # get binary masks + true_flows = labels_to_flows(masks) # get cellpose flows + # get predicted flows and cell probability + pred_masks = [] + pred_flows = [] + true_lbl = [] + for idx, img in enumerate(imgs): + mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"]) + pred_masks.append(mask) + pred_flows.append( + np.stack([flows[1][0], flows[1][1], flows[2]]) + ) # stack cell probability map, horizontal and vertical flow + true_lbl.append( + np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]]) + ) + + true_lbl = np.stack(true_lbl) + pred_flows = np.stack(pred_flows) + pred_flows = torch.from_numpy(pred_flows).float().to("cpu") + return pred_masks, pred_flows, true_lbl + + def masks_to_outlines(self, mask: np.ndarray) -> np.ndarray: + """get outlines of masks as a 0-1 array + Calls the original cellpose.utils.masks_to_outlines function + + :param mask: int, 2D or 3D array, mask of an image + :type mask: ndarray + :return: outlines + :rtype: ndarray + """ + return utils.masks_to_outlines(mask) # [True, False] outputs diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py new file mode 100644 index 00000000..43c3db01 --- /dev/null +++ b/src/server/dcp_server/models/inst_to_multi_seg.py @@ -0,0 +1,175 @@ +from typing import List + +import numpy as np +import torch + +from .model import Model +from .custom_cellpose import CustomCellpose +from dcp_server.models.classifiers import PatchClassifier, FeatureClassifier +from dcp_server.utils.processing import ( + get_centered_patches, + find_max_patch_size, + create_patch_dataset, + create_dataset_for_rf, +) + +# Dictionary mapping class names to their corresponding classes + +segmentor_mapping = {"Cellpose": CustomCellpose} +classifier_mapping = { + "PatchClassifier": PatchClassifier, + "RandomForest": FeatureClassifier, +} + + +class Inst2MultiSeg(Model): + """A two stage model for: 1. instance segmentation and 2. object wise classification""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the Inst2MultiSeg + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configurations + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + # super().__init__() + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + + self.model_name = model_name + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + + self.segmentor_class = self.model_config.get("segmentor_name", "Cellpose") + self.classifier_class = self.model_config.get( + "classifier_name", "PatchClassifier" + ) + + # Initialize the cellpose model and the classifier + segmentor = segmentor_mapping.get(self.segmentor_class) + self.segmentor = segmentor( + self.segmentor_class, + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + classifier = classifier_mapping.get(self.classifier_class) + self.classifier = classifier( + self.classifier_class, + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + + # make sure include mask is set to False if we are using the random forest model + if self.classifier_class == "RandomForest": + if ( + "include_mask" not in self.model_config["classifier"].keys() + or self.model_config["classifier"]["include_mask"] is True + ): + # print("Include mask=True was found, but for Random Forest, this parameter must be set to False. Doing this now.") + self.model_config["classifier"]["include_mask"] = False + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """Trains the given model. First trains the segmentor and then the clasiffier. + + :param imgs: images to train on (training data) + :type imgs: List[np.ndarray] + :param masks: masks of the given images (training labels) + :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, + second channel classes, so [2, H, W] or [2, 3, H, W] for 3D. + """ + # train cellpose + masks_instances = [mask[0] for mask in masks] + # masks_instances = list(np.array(masks)[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks + self.segmentor.train(imgs, masks_instances) + masks_classes = [mask[1] for mask in masks] + # create patch dataset to train classifier + # masks_classes = list( + # masks[:,1,...] + # ) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + x, patch_masks, labels = create_patch_dataset( + imgs, + masks_classes, + masks_instances, + noise_intensity=self.data_config["noise_intensity"], + max_patch_size=self.data_config["patch_size"], + include_mask=self.model_config["classifier"]["include_mask"], + ) + # additionally extract features from the patches if you are in RF model + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(x, patch_masks) + # train classifier + self.classifier.train(x, labels) + # and compute metric and loss + self.metric = (self.segmentor.metric + self.classifier.metric) / 2 + self.loss = (self.segmentor.loss + self.classifier.loss) / 2 + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image and return the final mask. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: Final mask containing instance mask and class masks. + :rtype: np.ndarray[np.uint16] + """ + # TBD we assume image is 2D [H, W] (see fsimage storage) + # The final mask which is returned should have + # first channel the output of cellpose and the rest are the class channels + with torch.no_grad(): + # get instance mask from segmentor + instance_mask = self.segmentor.eval(img) + # find coordinates of detected objects + class_mask = np.zeros(instance_mask.shape) + + max_patch_size = self.data_config["patch_size"] + if max_patch_size is None: + max_patch_size = find_max_patch_size(instance_mask) + + # get patches centered around detected objects + x, patch_masks, instance_labels, _ = get_centered_patches( + img, + instance_mask, + max_patch_size, + noise_intensity=self.data_config["noise_intensity"], + include_mask=self.model_config["classifier"]["include_mask"], + ) + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(x, patch_masks) + # loop over patches and create classification mask + for idx in range(len(x)): + patch_class = self.classifier.eval(x[idx]) + # Assign predicted class to corresponding location in final_mask + patch_class = ( + patch_class.item() + if isinstance(patch_class, torch.Tensor) + else patch_class + ) + class_mask[instance_mask == instance_labels[idx]] = patch_class + 1 + # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 + final_mask = np.stack( + (instance_mask, class_mask), axis=self.eval_config["mask_channel_axis"] + ).astype( + np.uint16 + ) # size 2xHxW + + return final_mask diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py new file mode 100644 index 00000000..3cda12c1 --- /dev/null +++ b/src/server/dcp_server/models/model.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +from typing import List +import numpy as np + + +class Model(ABC): + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + + self.model_name = model_name + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + + self.loss = 1e6 + self.metric = 0 + + @abstractmethod + def train(self, imgs: List[np.array], masks: List[np.array]) -> None: + pass + + @abstractmethod + def eval(self, img: np.array) -> np.array: + pass + + ''' + def update_configs(self, + config: dict, + ctype: str + ) -> None: + """ Update the training or evaluation configurations. + + :param config: Dictionary containing the updated configuration. + :type config: dict + :param ctype:type of config to update, will be train or eval + :type ctype: str + """ + if ctype=='train': self.train_config = config + else: self.eval_config = config + ''' + + +# from segment_anything import SamPredictor, sam_model_registry +# from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator +# class CustomSAMModel(): +# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb +# def __init__(self): +# pass diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py new file mode 100644 index 00000000..5ece6b97 --- /dev/null +++ b/src/server/dcp_server/models/multicellpose.py @@ -0,0 +1,165 @@ +from typing import List +import numpy as np +from skimage.measure import label as label_mask + +from .model import Model +from .custom_cellpose import CustomCellpose + + +class MultiCellpose(Model): + """ + Multichannel image segmentation model. + Run the separate CustomCellpose models for each channel return the mask corresponding to each object type. + """ + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the MultiCellpose model. + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + self.model_name = model_name + self.num_of_channels = self.model_config["classifier"]["num_classes"] + + self.cellpose_models = [ + CustomCellpose( + "Cellpose", + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + for _ in range(self.num_of_channels) + ] + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """ + Train the model on the provided images and masks. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images. + :type masks: list[numpy.ndarray] + """ + + for i in range(self.num_of_channels): + + masks_class = [] + + for mask in masks: + mask_class = mask[0].copy() # TODO - Do we need copy?? + # set all instances in the instance mask not corresponding to the class in question to zero + mask_class[0][mask_class[1] != (i + 1)] = 0 + masks_class.append(mask_class) + self.cellpose_models[i].train(imgs, masks_class) + + self.metric = np.mean( + [self.cellpose_models[i].metric for i in range(self.num_of_channels)] + ) + self.loss = np.mean( + [self.cellpose_models[i].loss for i in range(self.num_of_channels)] + ) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of + each object is assigned based on majority voting between the models. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: predicted mask consists of instance and class masks + :rtype: numpy.ndarray + """ + + instance_masks, class_masks, model_confidences = [], [], [] + + for i in range(self.num_of_channels): + # get the instance mask and pixel-wise cell probability mask + instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) + confidence_map = probs[2] + # assign the appropriate class to all objects detected by this model + class_mask = np.zeros_like(instance_mask) + class_mask[instance_mask > 0] = i + 1 + + instance_masks.append(instance_mask) + class_masks.append(class_mask) + model_confidences.append(confidence_map) + # merge the outputs of the different models using the pixel-wise cell probability mask + merged_mask_instances, class_mask = self.merge_masks( + instance_masks, class_masks, model_confidences + ) + # set all connected components to the same label in the instance mask + instance_mask = label_mask(merged_mask_instances > 0) + # and set the class with the most pixels to that object + for inst_id in np.unique(instance_mask)[1:]: + where_inst_id = np.where(instance_mask == inst_id) + vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) + class_mask[where_inst_id] = vals[np.argmax(counts)] + # take the final mask by stancking instance and class mask + final_mask = np.stack( + (instance_mask, class_mask), axis=self.eval_config["mask_channel_axis"] + ).astype(np.uint16) + + return final_mask + + def merge_masks( + self, + inst_masks: List[np.ndarray], + class_masks: List[np.ndarray], + probabilities: List[np.ndarray], + ) -> tuple: + """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model + with the maximum probability is selected for each pixel. + + :param inst_masks: List of predicted instance masks from each model. + :type inst_masks: List[np.array] + :param class_masks: List of corresponding class masks from each model. + :type class_masks: List[np.array] + :param probabilities: List of corresponding pixel-wise cell probability masks + :type probabilities: List[np.array] + :return: A tuple containing the following elements: + - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected + - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected + :rtype: tuple + """ + # Convert lists to numpy arrays + inst_masks = np.array(inst_masks) + class_masks = np.array(class_masks) + probabilities = np.array(probabilities) + + # Find the index of the mask with the maximum probability for each pixel + max_prob_indices = np.argmax(probabilities, axis=0) + + # Use the index to select the corresponding mask for each pixel + final_mask_inst = inst_masks[ + max_prob_indices, + np.arange(inst_masks.shape[1])[:, None], + np.arange(inst_masks.shape[2]), + ] + final_mask_class = class_masks[ + max_prob_indices, + np.arange(class_masks.shape[1])[:, None], + np.arange(class_masks.shape[2]), + ] + + return final_mask_inst, final_mask_class diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py new file mode 100644 index 00000000..9d85a5f7 --- /dev/null +++ b/src/server/dcp_server/models/unet.py @@ -0,0 +1,235 @@ +from typing import List +from tqdm import tqdm +import numpy as np +from scipy.ndimage import label + +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import JaccardIndex + +from .model import Model +from dcp_server.utils.processing import convert_to_tensor + + +class UNet(nn.Module, Model): + """ + Unet is a convolutional neural network architecture for semantic segmentation. + + :param in_channels: Number of input channels (default: 3). + :type in_channels: int + :param out_channels: Number of output channels (default: 4). + :type out_channels: int + :param features: List of feature channels for each encoder level (default: [64,128,256,512]). + :type features: list + """ + + class DoubleConv(nn.Module): + """ + DoubleConv module consists of two consecutive convolutional layers with + batch normalization and ReLU activation functions. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + """ + Initialize DoubleConv module. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + """ + + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the DoubleConv module. + + :param x: Input tensor. + :type x: torch.Tensor + """ + return self.conv(x) + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the UNet model. + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configurations + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + nn.Module.__init__(self) + # super().__init__() + + self.model_name = model_name + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + + self.loss = 1e6 + self.metric = 0 + self.num_classes = self.model_config["classifier"]["num_classes"] + 1 + self.metric_f = JaccardIndex( + task="multiclass", num_classes=self.num_classes, average="macro", ignore_index=0 + ) + + self.build_model() + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """ + Trains the UNet model using the provided images and masks. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images. + :type masks: list[numpy.ndarray] + """ + + imgs = convert_to_tensor(imgs, np.float32) + masks = convert_to_tensor( + [mask[1] for mask in masks], np.int16, unsqueeze=False + ) + + # Create a training dataset and dataloader + train_dataloader = DataLoader( + TensorDataset(imgs, masks), + batch_size=self.train_config["classifier"]["batch_size"], + ) + + loss_fn = nn.CrossEntropyLoss() + optimizer = Adam( + params=self.parameters(), lr=self.train_config["classifier"]["lr"] + ) + + for _ in tqdm( + range(self.train_config["classifier"]["n_epochs"]), + desc="Running UNet training", + ): + + self.loss = 0 + + for imgs, masks in train_dataloader: + # forward path + preds = self.forward(imgs.float()) + loss = loss_fn(preds, masks.long()) + + # backward path + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.loss += loss.detach().mean().item() + + self.loss /= len(train_dataloader) + + # compute metric on test set after train is complete + for imgs, masks in train_dataloader: + pred_masks = self.forward(imgs.float()) + self.metric += self.metric_f(pred_masks, masks) + self.metric /= len(train_dataloader) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: predicted mask consists of instance and class masks + :rtype: numpy.ndarray + """ + with torch.no_grad(): + + # img = torch.from_numpy(img).float().unsqueeze(0) + # img = img.unsqueeze(1) if img.ndim == 3 else img + img = convert_to_tensor([img], np.float32) + + preds = self.forward(img) + class_mask = torch.argmax(preds, 1).numpy()[0] + if self.eval_config["compute_instance"] is True: + instance_mask = label((class_mask > 0).astype(int))[0] + final_mask = np.stack( + [instance_mask, class_mask], + axis=self.eval_config["mask_channel_axis"], + ).astype(np.uint16) + else: + final_mask = class_mask.astype(np.uint16) + + return final_mask + + def build_model(self) -> None: + """Builds the UNet.""" + in_channels = self.model_config["classifier"]["in_channels"] + out_channels = self.num_classes + features = self.model_config["classifier"]["features"] + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Encoder + for feature in features: + self.encoder.append(UNet.DoubleConv(in_channels, feature)) + in_channels = feature + + # Decoder + for feature in features[::-1]: + self.decoder.append( + nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2) + ) + self.decoder.append(UNet.DoubleConv(feature * 2, feature)) + + self.bottle_neck = UNet.DoubleConv(features[-1], features[-1] * 2) + self.output_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the UNet model. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor. + :rtype: torch.Tensor + """ + skip_connections = [] + for encoder in self.encoder: + x = encoder(x) + skip_connections.append(x) + x = self.pool(x) + + x = self.bottle_neck(x) + skip_connections = skip_connections[::-1] + + for i in np.arange(len(self.decoder), step=2): + x = self.decoder[i](x) + skip_connection = skip_connections[i // 2] + concatenate_skip = torch.cat((skip_connection, x), dim=1) + x = self.decoder[i + 1](concatenate_skip) + + return self.output_conv(x) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index ccb5fff8..b3897ff7 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -1,73 +1,85 @@ -from dcp_server import utils import os -# Import configuration -setup_config = utils.read_config('setup', config_path = 'config.cfg') +from dcp_server.utils import helpers +from dcp_server.utils.fsimagestorage import FilesystemImageStorage +from dcp_server import models as DCPModels -class GeneralSegmentation(): - """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. - """ - def __init__(self, imagestorage, runner, model): - """Constructs all the necessary attributes for the GeneralSegmentation. + +class GeneralSegmentation: + """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images.""" + + def __init__( + self, imagestorage: FilesystemImageStorage, runner, model: DCPModels + ) -> None: + """Constructs all the necessary attributes for the GeneralSegmentation. :param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object :param runner: runner used in the service :type runner: CustomRunnable class object - :param model: model used for segmentation + :param model: model used for segmentation :type model: class object from the models.py - """ + """ self.imagestorage = imagestorage - self.runner = runner + self.runner = runner self.model = model self.no_files_msg = "No image-label pairs found in curated directory" - - async def segment_image(self, input_path, list_of_images): + + async def segment_image(self, input_path: str, list_of_images: str) -> None: """Segments images from the given directory - :param input_path: directory where the images are saved + :param input_path: directory where the images are saved and where segmentation results will be saved :type input_path: str :param list_of_images: list of image objects from the directory that are currently supported :type list_of_images: list - """ + """ for img_filepath in list_of_images: - # Load the image - img = self.imagestorage.load_image(img_filepath) - # Get size properties - height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) - img = self.imagestorage.rescale_image(img, height, width) + img = self.imagestorage.prepare_img_for_eval(img_filepath) # Add channel ax into the model's evaluation parameters dictionary - self.model.eval_config['segmentor']['z_axis'] = z_axis + if self.imagestorage.model_used != "UNet": + self.model.eval_config["segmentor"][ + "channel_axis" + ] = self.imagestorage.channel_ax # Evaluate the model - mask = await self.runner.evaluate.async_run(img = img) - # Resize the mask - mask = self.imagestorage.resize_mask(mask, height, width, self.model.eval_config['mask_channel_axis'], order=0) + mask = await self.runner.evaluate.async_run(img=img) + # And prepare the mask for saving + mask = self.imagestorage.prepare_mask_for_save( + mask, self.model.eval_config["mask_channel_axis"] + ) # Save segmentation - seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = ( + helpers.get_path_stem(img_filepath) + + self.imagestorage.seg_name_string + + ".tiff" + ) self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) - async def train(self, input_path): - """train model on images and masks in the given input directory. + async def train(self, input_path: str) -> str: + """Train model on images and masks in the given input directory. Calls the runner's train function. :param input_path: directory where the images are saved :type input_path: str :return: runner's train function output - path of the saved model :rtype: str - """ + """ train_img_mask_pairs = self.imagestorage.get_image_seg_pairs(input_path) if not train_img_mask_pairs: return self.no_files_msg - - imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs) - model_save_path = await self.runner.train.async_run(imgs, masks) + + imgs, masks = self.imagestorage.prepare_images_and_masks_for_training( + train_img_mask_pairs + ) + model_save_path = await self.runner.train.async_run(imgs, masks) return model_save_path +''' + class GFPProjectSegmentation(GeneralSegmentation): def __init__(self, imagestorage, runner): super().__init__(imagestorage, runner) @@ -78,11 +90,11 @@ async def segment_image(self, input_path, list_of_images): class MitoProjectSegmentation(GeneralSegmentation): - """Segmentation class inheriting the attributes and functions from the original GeneralSegmentation and implementing + """ Segmentation class inheriting the attributes and functions from the original GeneralSegmentation and implementing additional attributes and methods needed for this project. """ def __init__(self, imagestorage, runner, model): - """Constructs all the necessary attributes for the MitoProjectSegmentation. Inherits all from the GeneralSegmentation + """ Constructs all the necessary attributes for the MitoProjectSegmentation. Inherits all from the GeneralSegmentation :param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object @@ -95,7 +107,7 @@ def __init__(self, imagestorage, runner, model): # The only difference is in segment image async def segment_image(self, input_path, list_of_images): - """Segments images from the given directory. + """ Segments images from the given directory. The function differs from the parent class' function in obtaining the outlines of the masks. :param input_path: directory where the images are saved @@ -108,7 +120,7 @@ async def segment_image(self, input_path, list_of_images): # Load the image img = self.imagestorage.load_image(img_filepath) # Get size properties - height, width, channel_ax = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) + height, width, channel_ax = self.imagestorage.get_image_size_properties(img, helpers.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, width, channel_ax) # Add channel ax into the model's evaluation parameters dictionary @@ -128,5 +140,6 @@ async def segment_image(self, input_path, list_of_images): new_mask[outlines==True] = 1 # Save segmentation - seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), new_mask) +''' diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index 0be4fb0d..d464545b 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -1,9 +1,10 @@ from __future__ import annotations import bentoml import typing as t -from dcp_server.fsimagestorage import FilesystemImageStorage from dcp_server.serviceclasses import CustomBentoService, CustomRunnable -from dcp_server.utils import read_config + +from dcp_server.utils.fsimagestorage import FilesystemImageStorage +from dcp_server.utils.helpers import read_config import sys, inspect @@ -11,29 +12,46 @@ segmentation_module = __import__("segmentationclasses") # Import configuration -service_config = read_config('service', config_path = 'config.cfg') -model_config = read_config('model', config_path = 'config.cfg') -data_config = read_config('data', config_path = 'config.cfg') -train_config = read_config('train', config_path = 'config.cfg') -eval_config = read_config('eval', config_path = 'config.cfg') -setup_config = read_config('setup', config_path = 'config.cfg') +service_config = read_config("service", config_path="config.yaml") +model_config = read_config("model", config_path="config.yaml") +data_config = read_config("data", config_path="config.yaml") +train_config = read_config("train", config_path="config.yaml") +eval_config = read_config("eval", config_path="config.yaml") +setup_config = read_config("setup", config_path="config.yaml") # instantiate the model -model_class = getattr(models_module, setup_config['model_to_use']) -model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config, model_name=setup_config['model_to_use']) +model_class = getattr(models_module, setup_config["model_to_use"]) +model = model_class( + model_name=setup_config["model_to_use"], + model_config=model_config, + data_config=data_config, + train_config=train_config, + eval_config=eval_config, +) custom_model_runner = t.cast( - "CustomRunner", bentoml.Runner(CustomRunnable, name=service_config['runner_name'], - runnable_init_params={"model": model, "save_model_path": service_config['bento_model_path']}) + "CustomRunner", + bentoml.Runner( + CustomRunnable, + name=service_config["runner_name"], + runnable_init_params={ + "model": model, + "save_model_path": service_config["bento_model_path"], + }, + ), ) # instantiate the segmentation type -segm_class = getattr(segmentation_module, setup_config['segmentation']) -fsimagestorage = FilesystemImageStorage(data_config['data_root'], setup_config['model_to_use']) -segmentation = segm_class(imagestorage=fsimagestorage, - runner = custom_model_runner, - model = model) +segm_class = getattr(segmentation_module, setup_config["segmentation"]) +fsimagestorage = FilesystemImageStorage(data_config, setup_config["model_to_use"]) +segmentation = segm_class( + imagestorage=fsimagestorage, runner=custom_model_runner, model=model +) # Call the service -service = CustomBentoService(runner=segmentation.runner, segmentation=segmentation, service_name=service_config['service_name']) -svc = service.start_service() \ No newline at end of file +service = CustomBentoService( + runner=segmentation.runner, + segmentation=segmentation, + service_name=service_config["service_name"], +) +svc = service.start_service() diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 72a1e351..bb1b8e30 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -5,24 +5,26 @@ from typing import List from dcp_server import models as DCPModels +import dcp_server.segmentationclasses as DCPSegClasses class CustomRunnable(bentoml.Runnable): - ''' + """ BentoML, Runner represents a unit of computation that can be executed on a remote Python worker and scales independently. CustomRunnable is a custom runner defined to meet all the requirements needed for this project. - ''' - SUPPORTED_RESOURCES = ("cpu",) #TODO add here? + """ + + SUPPORTED_RESOURCES = ("cpu",) # TODO add here? SUPPORTS_CPU_MULTI_THREADING = False - def __init__(self, model, save_model_path): + def __init__(self, model: DCPModels, save_model_path: str) -> None: """Constructs all the necessary attributes for the CustomRunnable. :param model: model to be trained or evaluated - will be one of classes in models.py :param save_model_path: full path of the model object that it will be saved into :type save_model_path: str - """ - + """ + self.model = model self.save_model_path = save_model_path # update with the latest model if it already exists to continue training from there? @@ -44,12 +46,20 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: mask = self.model.eval(img=img) return mask - - def check_and_load_model(self): + + def check_and_load_model(self) -> None: + """Checks if the specified model exists in BentoML's model repository. + If the model exists, it loads the latest version of the model into + memory. + """ bento_model_list = [model.tag.name for model in bentoml.models.list()] if self.save_model_path in bento_model_list: - loaded_model = bentoml.picklable_model.load_model(self.save_model_path+":latest") - assert loaded_model.__class__.__name__ == self.model.__class__.__name__, 'Check your config, loaded model and model to use not the same!' + loaded_model = bentoml.picklable_model.load_model( + self.save_model_path + ":latest" + ) + assert ( + loaded_model.__class__.__name__ == self.model.__class__.__name__ + ), "Check your config, loaded model and model to use not the same!" self.model = loaded_model @bentoml.Runnable.method(batchable=False) @@ -62,14 +72,14 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :type masks: List[np.ndarray] :return: path of the saved model :rtype: str - """ + """ self.model.train(imgs, masks) # Save the bentoml model bentoml.picklable_model.save_model( - self.save_model_path, + self.save_model_path, self.model, external_modules=[DCPModels], - ) + ) # bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store # self.model, # Model instance being saved # external_modules=[DCPModels] @@ -77,42 +87,49 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: return self.save_model_path -class CustomBentoService(): - """BentoML Service class. Contains all the functions necessary to serve the service with BentoML - """ - def __init__(self, runner, segmentation, service_name): + +class CustomBentoService: + """BentoML Service class. Contains all the functions necessary to serve the service with BentoML""" + + def __init__( + self, runner: CustomRunnable, segmentation: DCPSegClasses, service_name: str + ) -> None: """Constructs all the necessary attributes for the class CustomBentoService(): :param runner: runner used in the service :type runner: CustomRunnable class object :param segmentation: segmentation type used in the service :type segmentation: segmentation class object from the segmentationclasses.py - :param service_name: name of the service + :param service_name: name of the service :type service_name: str - """ + """ self.runner = runner self.segmentation = segmentation self.service_name = service_name - def start_service(self): + def start_service(self) -> None: """Starts the service :return: service object needed in service.py and for the bentoml serve call. - """ + """ svc = bentoml.Service(self.service_name, runners=[self.runner]) - @svc.api(input=Text(), output=NumpyNdarray()) #input path to the image output message with success and the save path - async def segment_image(input_path: str): + @svc.api( + input=Text(), output=NumpyNdarray() + ) # input path to the image output message with success and the save path + async def segment_image(input_path: str) -> np.ndarray: """function served within the service, used to segment images :param input_path: directory where the images for segmentation are saved :type input_path: str :return: list of files not supported :rtype: ndarray - """ + """ list_of_images = self.segmentation.imagestorage.search_images(input_path) - list_of_files_not_suported = self.segmentation.imagestorage.get_unsupported_files(input_path) - + list_of_files_not_suported = ( + self.segmentation.imagestorage.get_unsupported_files(input_path) + ) + if not list_of_images: return np.array(list_of_images) else: @@ -121,20 +138,19 @@ async def segment_image(input_path: str): return np.array(list_of_files_not_suported) @svc.api(input=Text(), output=Text()) - async def train(input_path): + async def train(input_path: str) -> str: """function served within the service, used to retrain the model :param input_path: directory where the images for training are saved :type input_path: str :return: message of success if training went well :rtype: str - """ + """ print("Calling retrain from server.") # Train the model msg = await self.segmentation.train(input_path) - if msg!=self.segmentation.no_files_msg: + if msg != self.segmentation.no_files_msg: msg = "Success! Trained model saved in: " + msg return msg - + return svc - diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py deleted file mode 100644 index 7e6818e8..00000000 --- a/src/server/dcp_server/utils.py +++ /dev/null @@ -1,344 +0,0 @@ -from pathlib import Path -import json -from copy import deepcopy -import numpy as np -from scipy.ndimage import find_objects -from skimage import measure -from copy import deepcopy -import SimpleITK as sitk -from radiomics import shape2D - -def read_config(name, config_path = 'config.cfg') -> dict: - """Reads the configuration file - - :param name: name of the section you want to read (e.g. 'setup','train') - :type name: string - :param config_path: path to the configuration file, defaults to 'config.cfg' - :type config_path: str, optional - :return: dictionary from the config section given by name - :rtype: dict - """ - with open(config_path) as config_file: - config_dict = json.load(config_file) - # Check if config file has main mandatory keys - assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']]) - return config_dict[name] - -def get_path_stem(filepath): return str(Path(filepath).stem) - - -def get_path_name(filepath): return str(Path(filepath).name) - - -def get_path_parent(filepath): return str(Path(filepath).parent) - - -def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) - - -def get_file_extension(file): return str(Path(file).suffix) - - -def crop_centered_padded_patch(img: np.ndarray, - patch_center_xy, - patch_size, - obj_label, - mask: np.ndarray=None, - noise_intensity=None) -> np.ndarray: - """ - Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary. - - Args: - img (np.ndarray): The input array from which the patch will be cropped. - patch_center_xy (tuple): The coordinates (row, column) at the center of the patch. - patch_size (tuple): The size of the patch to be cropped (height, width). - obj_label (int): The instance label of the mask at the patch - mask (np.ndarray, optional): The mask array that asociated with the array x; - mask is used during training to mask out non-central elements; - for RandomForest, it is used to calculate pyradiomics features. - noise_intensity (float, optional): Intensity of noise to be added to the background. - - Returns: - np.ndarray: The cropped patch with applied padding. - """ - - height, width = patch_size # Size of the patch - img_height, img_width = img.shape[0], img.shape[1] # Size of the input image - - # Calculate the boundaries of the patch - top = patch_center_xy[0] - height // 2 - bottom = top + height - left = patch_center_xy[1] - width // 2 - right = left + width - - # Crop the patch from the input array - if mask is not None: - mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask - # Zero out values in the patch where the mask is not equal to the central label - mask_other_objs = (mask_ != obj_label) & (mask_ > 0) - img[mask_other_objs] = 0 - # Add random noise at locations where other objects are present if noise_intensity is given - if noise_intensity is not None: img[mask_other_objs] = np.random.normal(scale=noise_intensity, size=img[mask_other_objs].shape) - mask[mask_other_objs] = 0 - # crop the mask - mask = mask[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] - - patch = img[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] - # Calculate the required padding amounts and apply padding if necessary - if left < 0: - patch = np.hstack(( - np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), - patch)) - if mask is not None: - mask = np.hstack(( - np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype(np.uint8), - mask)) - # Apply padding on the right side if necessary - if right > img_width: - patch = np.hstack(( - patch, - np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - img_width), patch.shape[2])).astype(np.uint8))) - if mask is not None: - mask = np.hstack(( - mask, - np.zeros((mask.shape[0], (right - img_width), mask.shape[2])).astype(np.uint8))) - # Apply padding on the top side if necessary - if top < 0: - patch = np.vstack(( - np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), - patch)) - if mask is not None: - mask = np.vstack(( - np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), - mask)) - # Apply padding on the bottom side if necessary - if bottom > img_height: - patch = np.vstack(( - patch, - np.random.normal(scale=noise_intensity, size=(bottom - img_height, patch.shape[1], patch.shape[2])).astype(np.uint8))) - if mask is not None: - mask = np.vstack(( - mask, - np.zeros((bottom - img_height, mask.shape[1], mask.shape[2])).astype(np.uint8))) - - return patch, mask - - -def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray: - """ - Compute the centers of mass for each object in a mask. - - Args: - mask (np.ndarray): The input mask containing labeled objects. - - Returns: - list of tuples: A list of coordinates (row, column) representing the centers of mass for each object. - list of ints: Holds the label for each object in the mask - """ - - # Compute the centers of mass for each labeled object in the mask - ''' - return [(int(x[0]), int(x[1])) - for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] - ''' - centers = [] - labels = [] - for region in measure.regionprops(mask): - center = region.centroid - centers.append((int(center[0]), int(center[1]))) - labels.append(region.label) - return centers, labels - - - -def get_centered_patches(img, - mask, - p_size: int, - noise_intensity=5, - mask_class=None, - include_mask=False): - - ''' - Extracts centered patches from the input image based on the centers of objects identified in the mask. - - Args: - img (np.array): The input image. - mask (np.array): The mask representing the objects in the image. - p_size (int): The size of the patches to extract. - noise_intensity (float): The intensity of noise to add to the patches. - mask_class (np.array): The mask representing the classes of the objects in the image. - include_mask (bool): Whether or not to include mask as input argument to model. - - ''' - - patches, patch_masks, instance_labels, class_labels = [], [], [], [] - # if image is 2D add an additional dim for channels - if img.ndim<3: img = img[:, :, np.newaxis] - if mask.ndim<3: mask = mask[:, :, np.newaxis] - # compute center of mass of objects - centers_of_mass, instance_labels = get_center_of_mass_and_label(mask) - # Crop patches around each center of mass - for c, obj_label in zip(centers_of_mass, instance_labels): - c_x, c_y = c - patch, patch_mask = crop_centered_padded_patch(img.copy(), - (c_x, c_y), - (p_size, p_size), - obj_label, - mask=deepcopy(mask), - noise_intensity=noise_intensity) - if include_mask: - patch_mask = 255 * (patch_mask > 0).astype(np.uint8) - patch = np.concatenate((patch, patch_mask), axis=-1) - - patches.append(patch) - patch_masks.append(patch_mask) - if mask_class is not None: - # get the class instance for the specific object - instance_labels.append(obj_label) - class_l = np.unique(mask_class[mask[:,:,0]==obj_label]) - assert class_l.shape[0] == 1, "ERROR"+str(class_l) - class_l = int(class_l[0]) - #-1 because labels from mask start from 1, we want classes to start from 0 - class_labels.append(class_l-1) - - return patches, patch_masks, instance_labels, class_labels - -def get_objects(mask): - return find_objects(mask) - -def find_max_patch_size(mask): - - # Find objects in the mask - objects = get_objects(mask) - - # Initialize variables to store the maximum patch size - max_patch_size = 0 - - # Iterate over the found objects - for obj in objects: - # Extract start and stop values from the slice object - slices = [s for s in obj] - start = [s.start for s in slices] - stop = [s.stop for s in slices] - - # Calculate the size of the patch along each axis - patch_size = tuple(stop[i] - start[i] for i in range(len(start))) - - # Calculate the total size (area) of the patch - total_size = 1 - for size in patch_size: - total_size *= size - - # Check if the current patch size is larger than the maximum - if total_size > max_patch_size: - max_patch_size = total_size - - max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) - - return max_patch_size_edge - -def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask): - ''' - Splits img and masks into patches of equal size which are centered around the cells. - If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use - the max cell size to define the patch size. All patches and masks should then be returned - in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same - convention of dims, e.g. CxHxW) - include_mask(bool) : Flag indicating whether to include the mask along with patches. - ''' - - if max_patch_size is None: - max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances]) - - patches, patch_masks, labels = [], [], [] - for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): - # mask_instance has dimension WxH - # mask_class has dimension WxH - patch, patch_mask, _, label = get_centered_patches(img, - mask_instance, - max_patch_size, - noise_intensity=noise_intensity, - mask_class=mask_class, - include_mask = include_mask) - patches.extend(patch) - patch_masks.extend(patch_mask) - labels.extend(label) - return patches, patch_masks, labels - - -def get_shape_features(img, mask): - """ - Calculate shape-based radiomic features from an image within the region defined by the mask. - - Args: - - img (np.ndarray): The input image. - - mask (np.ndarray): The mask corresponding to the image. - - Returns: - - np.ndarray: An array containing the calculated shape-based radiomic features, such as: - Elongation, Sphericity, Perimeter surface. - """ - - mask = 255 * ((mask) > 0).astype(np.uint8) - - image = sitk.GetImageFromArray(img.squeeze()) - roi_mask = sitk.GetImageFromArray(mask.squeeze()) - - shape_calculator = shape2D.RadiomicsShape2D(inputImage=image, inputMask=roi_mask, label=255) - # Calculate the shape-based radiomic features - shape_features = shape_calculator.execute() - - return np.array(list(shape_features.values())) - -def extract_intensity_features(image, mask): - """ - Extract intensity-based features from an image within the region defined by the mask. - - Args: - - image (np.ndarray): The input image. - - mask (np.ndarray): The mask defining the region of interest. - - Returns: - - np.ndarray: An array containing the extracted intensity-based features: - median intensity, mean intensity, 25th/75th percentile intensity within the masked region. - - """ - - features = {} - - # Ensure the image and mask have the same dimensions - - if image.shape != mask.shape: - raise ValueError("Image and mask must have the same dimensions") - - masked_image = image[(mask>0)] - # features["min_intensity"] = np.min(masked_image) - # features["max_intensity"] = np.max(masked_image) - features["median_intensity"] = np.median(masked_image) - features["mean_intensity"] = np.mean(masked_image) - features["25th_percentile_intensity"] = np.percentile(masked_image, 25) - features["75th_percentile_intensity"] = np.percentile(masked_image, 75) - - return np.array(list(features.values())) - -def create_dataset_for_rf(imgs, masks): - """ - Extract intensity-based features from an image within the region defined by the mask. - - Args: - - imgs (List): A list of all input images. - - mask (List): A list of all corresponding masks defining the region of interest. - - Returns: - - List: A list of arrays containing shape and intensity-based features - - """ - X = [] - for img, mask in zip(imgs, masks): - - shape_features = get_shape_features(img, mask) - intensity_features = extract_intensity_features(img, mask) - features_list = np.concatenate((shape_features, intensity_features), axis=0) - X.append(features_list) - - return X \ No newline at end of file diff --git a/src/server/dcp_server/utils/__init__.py b/src/server/dcp_server/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py new file mode 100644 index 00000000..d89025b3 --- /dev/null +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -0,0 +1,315 @@ +import os +from typing import Optional, List +import numpy as np +from skimage.io import imread, imsave +from skimage.transform import resize, rescale + +from dcp_server.utils import helpers +from dcp_server.utils.processing import pad_image, normalise + + +class FilesystemImageStorage: + """ + Class used to deal with everything related to image storing and processing - loading, saving, transforming. + """ + + def __init__(self, data_config: dict, model_used: str) -> None: + self.root_dir = data_config["data_root"] + self.seg_name_string = data_config["seg_name_string"] + self.accepted_types = data_config["accepted_types"] + self.gray = bool(data_config["gray"]) + self.rescale = bool(data_config["rescale"]) + self.model_used = model_used + self.channel_ax = None + self.img_height = None + self.img_width = None + + def load_image( + self, cur_selected_img: str, gray: Optional[bool] = None + ) -> Optional[np.ndarray]: + """Load the image (using skiimage) + + :param cur_selected_img: full path of the image that needs to be loaded + :type cur_selected_img: str + :param gray: whether to load the image as a grayscale or not + :type gray: bool or None, default=Nonee + :return: loaded image + :rtype: ndarray + """ + if gray is None: + gray = self.gray + try: + return imread(os.path.join(self.root_dir, cur_selected_img), as_gray=gray) + except ValueError: + return None + + def save_image(self, to_save_path: str, img: np.ndarray) -> None: + """Save given image using skimage. + + :param to_save_path: full path to the directory that the image needs to be save into (use also image name in the path, eg. '/users/new_image.png') + :type to_save_path: str + :param img: image you wish to save + :type img: ndarray + """ + imsave(os.path.join(self.root_dir, to_save_path), img) + + def search_images(self, directory: str) -> List[str]: + """Get a list of full paths of the images in the directory. + + :param directory: Path to the directory to search for images. + :type directory: str + :return: List of image paths found in the directory (only image types that are supported - see config.cfg 'setup' section). + :rtype: list + """ + # Take all segmentations of the image from the current directory: + directory = os.path.join(self.root_dir, directory) + seg_files = [ + file_name + for file_name in os.listdir(directory) + if self.seg_name_string in file_name + ] + # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted + image_files = [ + os.path.join(directory, file_name) + for file_name in os.listdir(directory) + if (file_name not in seg_files) + and (helpers.get_file_extension(file_name) in self.accepted_types) + ] + return image_files + + def search_segs(self, cur_selected_img: str) -> List[str]: + """Returns a list of full paths of segmentations for an image. + + :param cur_selected_img: Full path of the image for which segmentations are needed. + :type cur_selected_img: str + :return: List of segmentation paths for the given image. + :rtype: list + """ + + # Check the directory the image was selected from: + img_directory = helpers.get_path_parent( + os.path.join(self.root_dir, cur_selected_img) + ) + # Take all segmentations of the image from the current directory: + search_string = helpers.get_path_stem(cur_selected_img) + self.seg_name_string + # seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] + # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) + + seg_files = [ + os.path.join(img_directory, file_name) + for file_name in os.listdir(img_directory) + if ( + search_string == helpers.get_path_stem(file_name) + or str(file_name).startswith(search_string) + ) + ] + + return seg_files + + def get_image_seg_pairs(self, directory: str) -> List[tuple]: + """Get pairs of (image, image_seg). + + Used, e.g., in training to create training data-training labels pairs. + + :param directory: Path to the directory to search images and segmentations in. + :type directory: str + :return: List of tuple pairs (image, image_seg). + :rtype: list + """ + + image_files = self.search_images(os.path.join(self.root_dir, directory)) + seg_files = [] + for image in image_files: + seg = self.search_segs(image) + # TODO - the search seg returns all the segs, but here we need only one, hence the seg[0]. Check if it is from training path? + seg_files.append(seg[0]) + return list(zip(image_files, seg_files)) + + def get_unsupported_files(self, directory: str) -> List[str]: + """Get unsupported files found in the given directory. + + :param directory: Directory path to search for files in. + :type directory: str + :return: List of unsupported files. + :rtype: list + """ + return [ + file_name + for file_name in os.listdir(os.path.join(self.root_dir, directory)) + if not file_name.startswith(".") + and helpers.get_file_extension(file_name) not in self.accepted_types + ] + + def get_image_size_properties(self, img: np.ndarray, file_extension: str) -> None: + """Set properties of the image size + + :param img: Image (numpy array). + :type img: ndarray + :param file_extension: File extension of the image as saved in the directory. + :type file_extension: str + """ + # TODO simplify! + + orig_size = img.shape + # png and jpeg will be RGB by default and 2D + # tif can be grayscale 2D or 3D [Z, H, W] + # image channels have already been removed in imread if self.gray=True + # skimage.imread reads RGB or RGBA images in always with channel axis in dim=2 + if file_extension in (".jpg", ".jpeg", ".png") and self.gray == False: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = 2 + elif file_extension in (".jpg", ".jpeg", ".png") and self.gray == True: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = None + elif file_extension in (".tiff", ".tif") and len(orig_size) == 2: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = None + # if we have 3 dimensions the [Z, H, W] + elif file_extension in (".tiff", ".tif") and len(orig_size) == 3: + print( + "Warning: 3D image stack found. We are assuming your last dimension is your channel dimension. Please cross check this." + ) + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = 2 + else: + print("File not currently supported. See documentation for accepted types") + + def rescale_image(self, img: np.ndarray, order: int = 2) -> np.ndarray: + """rescale image + + :param img: Image. + :type img: ndarray + :param order: Order of interpolation. + :type order: int + :return: Rescaled image. + :rtype: ndarray + """ + + if self.model_used == "UNet": + return pad_image( + img, self.img_height, self.img_width, self.channel_ax, dividable=16 + ) + else: + # Cellpose segmentation runs best with 512 size? TODO: check + max_dim = max(self.img_height, self.img_width) + rescale_factor = max_dim / 512 + return rescale( + img, 1 / rescale_factor, order=order, channel_axis=self.channel_ax + ) + + def resize_mask( + self, mask: np.ndarray, channel_ax: Optional[int] = None, order: int = 0 + ) -> np.ndarray: + """resize the mask so it matches the original image size + + :param mask: Image. + :type mask: ndarray + :param height: Height of the image. + :type height: int + :param width: Width of the image. + :type width: int + :param order: From scikit-image - the order of the spline interpolation. Default is 0 if image.dtype is bool and 1 otherwise. + :type order: int + :return: Resized image. + :rtype: ndarray + """ + + if self.model_used == "UNet": + # we assume an order C, H, W + if channel_ax is not None and channel_ax == 0: + height_pad = mask.shape[1] - self.img_height + width_pad = mask.shape[2] - self.img_width + return mask[:, :-height_pad, :-width_pad] + elif channel_ax is not None and channel_ax == 2: + height_pad = mask.shape[0] - self.img_height + width_pad = mask.shape[1] - self.img_width + return mask[:-height_pad, :-width_pad, :] + elif channel_ax is not None and channel_ax == 1: + height_pad = mask.shape[2] - self.img_height + width_pad = mask.shape[0] - self.img_width + return mask[:-width_pad, :, :-height_pad] + else: + height_pad = mask.shape[0] - self.img_height + width_pad = mask.shape[1] - self.img_width + return mask[:-height_pad, :-width_pad] + + else: + if channel_ax is not None: + n_channel_dim = mask.shape[channel_ax] + output_size = [self.img_height, self.img_width] + output_size.insert(channel_ax, n_channel_dim) + else: + output_size = [self.img_height, self.img_width] + return resize(mask, output_size, order=order) + + def prepare_images_and_masks_for_training( + self, train_img_mask_pairs: List[tuple] + ) -> tuple: + """Image and mask processing for training. + + :param train_img_mask_pairs: List pairs of (image, image_seg) (as returned by get_image_seg_pairs() function). + :type train_img_mask_pairs: list + :return: Lists of processed images and masks. + :rtype: tuple + """ + + imgs = [] + masks = [] + for img_file, mask_file in train_img_mask_pairs: + img = self.load_image(img_file) + img = normalise(img) + mask = self.load_image(mask_file, gray=False) + self.get_image_size_properties(img, helpers.get_file_extension(img_file)) + # Unet only accepts image sizes divisable by 16 + if self.model_used == "UNet": + img = pad_image( + img, + self.img_height, + self.img_width, + channel_ax=self.channel_ax, + dividable=16, + ) + mask = pad_image( + mask, self.img_height, self.img_width, channel_ax=0, dividable=16 + ) + if self.model_used == "CustomCellpose" and len(mask.shape) == 3: + # if we also have class mask drop it + mask = masks[0] # assuming mask_channel_axis=0 + imgs.append(img) + masks.append(mask) + return imgs, masks + + def prepare_img_for_eval(self, img_file: str) -> np.ndarray: + """Image processing for model inference. + + :param img_file: the path to the image + :type img_file: str + :return: the loaded and processed image + :rtype: np.ndarray + """ + # Load and normalise the image + img = self.load_image(img_file) + img = normalise(img) + # Get size properties + self.get_image_size_properties(img, helpers.get_file_extension(img_file)) + if self.rescale: + img = self.rescale_image(img) + return img + + def prepare_mask_for_save(self, mask: np.ndarray, channel_ax: int) -> np.ndarray: + """Prepares the mask output of the model to be saved. + + :param mask: the mask + :type mask: np.ndarray + :param channel_ax: the channel dimension of the mask + :rype channel_ax: int + :return: the ready to save mask + :rtype: np.ndarray + """ + # Resize the mask if rescaling took place before + if self.rescale is True: + if len(mask.shape) < 3: + channel_ax = None + return self.resize_mask(mask, channel_ax) + else: + return mask diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py new file mode 100644 index 00000000..b4cb15c6 --- /dev/null +++ b/src/server/dcp_server/utils/helpers.py @@ -0,0 +1,46 @@ +from pathlib import Path +import yaml + + +def read_config(name: str, config_path: str) -> dict: + """Reads the configuration file + + :param name: name of the section you want to read (e.g. 'setup','train') + :type name: string + :param config_path: path to the configuration file + :type config_path: str + :return: dictionary from the config section given by name + :rtype: dict + """ + with open(config_path) as config_file: + config_dict = yaml.safe_load( + config_file + ) # json.load(config_file) for .cfg file + # Check if config file has main mandatory keys + assert all( + [ + i in config_dict.keys() + for i in ["setup", "service", "model", "train", "eval"] + ] + ) + return config_dict[name] + + +def get_path_stem(filepath: str) -> str: + return str(Path(filepath).stem) + + +def get_path_name(filepath: str) -> str: + return str(Path(filepath).name) + + +def get_path_parent(filepath: str) -> str: + return str(Path(filepath).parent) + + +def join_path(root_dir: str, filepath: str) -> str: + return str(Path(root_dir, filepath)) + + +def get_file_extension(file: str) -> str: + return str(Path(file).suffix) diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py new file mode 100644 index 00000000..9c7f4b03 --- /dev/null +++ b/src/server/dcp_server/utils/processing.py @@ -0,0 +1,490 @@ +from copy import deepcopy +from typing import List, Optional, Union +import numpy as np + +from scipy.ndimage import find_objects +from skimage import measure +import SimpleITK as sitk +from radiomics import shape2D +import torch + + +def normalise(img: np.ndarray, norm: str = "min-max") -> np.ndarray: + """Normalises the image based on the chosen method. Currently available methods are: + - min max normalisation. + + :param img: image to be normalised + :type img: np.ndarray + :param norm: the normalisation method to apply + :type norm: str + :return: the normalised image + :rtype: np.ndarray + """ + if norm == "min-max": + return (img - np.min(img)) / (np.max(img) - np.min(img)) + + +def pad_image( + img: np.ndarray, + height: int, + width: int, + channel_ax: Optional[int] = None, + dividable: int = 16, +) -> np.ndarray: + """Pads the image such that it is dividable by a given number. + + :param img: image to be padded + :type img: np.ndarray + :param height: image height + :type height: int + :param width: image width + :type width: int + :param channel_ax: + :type channel_ax: int or None + :param dividable: the number with which the new image size should be perfectly dividable by + :type dividable: int + :return: the padded image + :rtype: np.ndarray + """ + height_pad = (height // dividable + 1) * dividable - height + width_pad = (width // dividable + 1) * dividable - width + if channel_ax == 0: + img = np.pad(img, ((0, 0), (0, height_pad), (0, width_pad))) + elif channel_ax == 2: + img = np.pad(img, ((0, height_pad), (0, width_pad), (0, 0))) + else: + img = np.pad(img, ((0, height_pad), (0, width_pad))) + return img + + +def convert_to_tensor( + imgs: List[np.ndarray], dtype: type, unsqueeze: bool = True +) -> torch.Tensor: + """Convert the imgs to tensors of type dtype and add extra dimension if input bool is true. + + :param imgs: the list of images to convert + :type img: List[np.ndarray] + :param dtype: the data type to convert the image tensor + :type dtype: type + :param unsqueeze: If True an extra dim will be added at location zero + :type unsqueeze: bool + :return: the converted image + :rtype: torch.Tensor + """ + # Convert images tensors + imgs = torch.stack([torch.from_numpy(img.astype(dtype)) for img in imgs]) + imgs = imgs.unsqueeze(1) if imgs.ndim == 3 and unsqueeze is True else imgs + return imgs + + +def crop_centered_padded_patch( + img: np.ndarray, + patch_center_xy: tuple, + patch_size: tuple, + obj_label: int, + mask: np.ndarray = None, + noise_intensity: int = None, +) -> np.ndarray: + """Crop a patch from an array centered at coordinates patch_center_xy with size patch_size, + and apply padding if necessary. + + :param img: the input array from which the patch will be cropped + :type img: np.ndarray + :param patch_center_xy: the coordinates (row, column) at the center of the patch + :type patch_center_xy: tuple + :param patch_size: the size of the patch to be cropped (height, width) + :type patch_size: tuple + :param obj_label: the instance label of the mask at the patch + :type obj_label: int + :param mask: The mask array associated with the array x. + Mask is used during training to mask out non-central elements. + For RandomForest, it is used to calculate pyradiomics features. + :type mask: np.ndarray, optional + :param noise_intensity: intensity of noise to be added to the background + :type noise_intensity: float, optional + :return: the cropped patch with applied padding + :rtype: np.ndarray + """ + + height, width = patch_size # Size of the patch + img_height, img_width = img.shape[0], img.shape[1] # Size of the input image + + # Calculate the boundaries of the patch + top = patch_center_xy[0] - height // 2 + bottom = top + height + left = patch_center_xy[1] - width // 2 + right = left + width + + # Crop the patch from the input array + if mask is not None: + mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask + # Zero out values in the patch where the mask is not equal to the central label + mask_other_objs = (mask_ != obj_label) & (mask_ > 0) + img[mask_other_objs] = 0 + # Add random noise at locations where other objects are present if noise_intensity is given + if noise_intensity is not None: + img[mask_other_objs] = np.random.normal( + scale=noise_intensity, size=img[mask_other_objs].shape + ) + mask[mask_other_objs] = 0 + # crop the mask + mask = mask[ + max(top, 0) : min(bottom, img_height), + max(left, 0) : min(right, img_width), + :, + ] + + patch = img[ + max(top, 0) : min(bottom, img_height), max(left, 0) : min(right, img_width), : + ] + + # Calculate the required padding amounts and apply padding if necessary + if left < 0: + patch = np.hstack( + ( + np.random.normal( + scale=noise_intensity, + size=(patch.shape[0], abs(left), patch.shape[2]), + ).astype(np.uint8), + patch, + ) + ) + if mask is not None: + mask = np.hstack( + ( + np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype( + np.uint8 + ), + mask, + ) + ) + # Apply padding on the right side if necessary + if right > img_width: + patch = np.hstack( + ( + patch, + np.random.normal( + scale=noise_intensity, + size=(patch.shape[0], (right - img_width), patch.shape[2]), + ).astype(np.uint8), + ) + ) + if mask is not None: + mask = np.hstack( + ( + mask, + np.zeros( + (mask.shape[0], (right - img_width), mask.shape[2]) + ).astype(np.uint8), + ) + ) + # Apply padding on the top side if necessary + if top < 0: + patch = np.vstack( + ( + np.random.normal( + scale=noise_intensity, + size=(abs(top), patch.shape[1], patch.shape[2]), + ).astype(np.uint8), + patch, + ) + ) + if mask is not None: + mask = np.vstack( + ( + np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), + mask, + ) + ) + # Apply padding on the bottom side if necessary + if bottom > img_height: + patch = np.vstack( + ( + patch, + np.random.normal( + scale=noise_intensity, + size=(bottom - img_height, patch.shape[1], patch.shape[2]), + ).astype(np.uint8), + ) + ) + if mask is not None: + mask = np.vstack( + ( + mask, + np.zeros( + (bottom - img_height, mask.shape[1], mask.shape[2]) + ).astype(np.uint8), + ) + ) + return patch, mask + + +def get_center_of_mass_and_label(mask: np.ndarray) -> tuple: + """Computes the centers of mass for each object in a mask. + + :param mask: the input mask containing labeled objects + :type mask: np.ndarray + :return: + - A list of tuples representing the coordinates (row, column) of the centers of mass for each object. + - A list of ints representing the labels for each object in the mask. + :rtype: + - List [tuple] + - List [int] + """ + + # Compute the centers of mass for each labeled object in the mask + + # return [(int(x[0]), int(x[1])) + # for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] + + centers = [] + labels = [] + for region in measure.regionprops(mask): + center = region.centroid + centers.append((int(center[0]), int(center[1]))) + labels.append(region.label) + return centers, labels + + +def get_centered_patches( + img: np.ndarray, + mask: np.ndarray, + p_size: int, + noise_intensity: int = 5, + mask_class: Optional[int] = None, + include_mask: bool = False, +) -> tuple: + """Extracts centered patches from the input image based on the centers of objects identified in the mask. + + :param img: The input image. + :type img: numpy.ndarray + :param mask: The mask representing the objects in the image. + :type mask: numpy.ndarray + :param p_size: The size of the patches to extract. + :type p_size: int + :param noise_intensity: The intensity of noise to add to the patches. + :type noise_intensity: float + :param mask_class: The class represented in the patch. + :type mask_class: int + :param include_mask: Whether or not to include the mask as an input argument to the model. + :type include_mask: bool + :return: A tuple containing the following elements: + - patches (numpy.ndarray): Extracted patches. + - patch_masks (numpy.ndarray): Masks corresponding to the extracted patches. + - instance_labels (list): Labels identifying each object instance in the extracted patches. + - class_labels (list): Labels identifying the class of each object instance in the extracted patches. + :rtype: tuple + """ + + patches, patch_masks, instance_labels, class_labels = [], [], [], [] + # if image is 2D add an additional dim for channels + if img.ndim < 3: + img = img[:, :, np.newaxis] + if mask.ndim < 3: + mask = mask[:, :, np.newaxis] + # compute center of mass of objects + centers_of_mass, instance_labels = get_center_of_mass_and_label(mask) + # Crop patches around each center of mass + for c, obj_label in zip(centers_of_mass, instance_labels): + c_x, c_y = c + patch, patch_mask = crop_centered_padded_patch( + img.copy(), + (c_x, c_y), + (p_size, p_size), + obj_label, + mask=deepcopy(mask), + noise_intensity=noise_intensity, + ) + if include_mask is True: + patch_mask = 255 * (patch_mask > 0).astype(np.uint8) + patch = np.concatenate((patch, patch_mask), axis=-1) + + patches.append(patch) + patch_masks.append(patch_mask) + if mask_class is not None: + # get the class instance for the specific object + instance_labels.append(obj_label) + class_l = np.unique(mask_class[mask[:, :, 0] == obj_label]) + assert class_l.shape[0] == 1, "ERROR" + str(class_l) + class_l = int(class_l[0]) + # -1 because labels from mask start from 1, we want classes to start from 0 + class_labels.append(class_l - 1) + + return patches, patch_masks, instance_labels, class_labels + + +def get_objects(mask: np.ndarray) -> List: + """Finds labeled connected components in a binary mask. + + :param mask: The binary mask representing objects. + :type mask: numpy.ndarray + :return: A list of slices indicating the bounding boxes of the found objects. + :rtype: list + """ + return find_objects(mask) + + +def find_max_patch_size(mask: np.ndarray) -> float: + """Finds the maximum patch size in a mask. + + :param mask: The binary mask representing objects. + :type mask: numpy.ndarray + :return: The maximum size of the bounding box edge for objects in the mask. + :rtype: float + """ + + # Find objects in the mask + objects = get_objects(mask) + + # Initialize variables to store the maximum patch size + max_patch_size = 0 + + # Iterate over the found objects + for obj in objects: + # Extract start and stop values from the slice object + slices = [s for s in obj] + start = [s.start for s in slices] + stop = [s.stop for s in slices] + + # Calculate the size of the patch along each axis + patch_size = tuple(stop[i] - start[i] for i in range(len(start))) + + # Calculate the total size (area) of the patch + total_size = 1 + for size in patch_size: + total_size *= size + + # Check if the current patch size is larger than the maximum + if total_size > max_patch_size: + max_patch_size = total_size + + max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) + + return max_patch_size_edge + + +def create_patch_dataset( + imgs: List[np.ndarray], + masks_classes: Optional[Union[List[np.ndarray], torch.Tensor]], + masks_instances: Optional[Union[List[np.ndarray], torch.Tensor]], + noise_intensity: int, + max_patch_size: int, + include_mask: bool, +) -> tuple: + """Splits images and masks into patches of equal size centered around the cells. + + :param imgs: A list of input images. + :type imgs: list of numpy.ndarray or torch.Tensor + :param masks_classes: A list of binary masks representing classes. + :type masks_classes: list of numpy.ndarray or torch.Tensor + :param masks_instances: A list of binary masks representing instances. + :type masks_instances: list of numpy.ndarray or torch.Tensor + :param noise_intensity: The intensity of noise to add to the patches. + :type noise_intensity: int + :param max_patch_size: The maximum size of the bounding box edge for objects in the mask. + :type max_patch_size: int + :param include_mask: A flag indicating whether to include the mask along with patches. + :type include_mask: bool + :return: A tuple containing the patches, patch masks, and labels. + :rtype: tuple + + .. note:: + If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use + the max cell size to define the patch size. All patches and masks should then be returned + in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same + convention of dims, e.g. CxHxW) + """ + if max_patch_size is None: + max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances]) + + patches, patch_masks, labels = [], [], [] + for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): + # mask_instance has dimension WxH + # mask_class has dimension WxH + patch, patch_mask, _, label = get_centered_patches( + img=img, + mask=mask_instance, + p_size=max_patch_size, + noise_intensity=noise_intensity, + mask_class=mask_class, + include_mask=include_mask, + ) + patches.extend(patch) + patch_masks.extend(patch_mask) + labels.extend(label) + return patches, patch_masks, labels + + +def get_shape_features(img: np.ndarray, mask: np.ndarray) -> np.ndarray: + """Calculate shape-based radiomic features from an image within the region defined by the mask. + + :param img: The input image. + :type img: numpy.ndarray + :param mask: The mask corresponding to the image. + :type mask: numpy.ndarray + :return: An array containing the calculated shape-based radiomic features, such as elongation, sphericity, and perimeter surface. + :rtype: numpy.ndarray + """ + + mask = 255 * ((mask) > 0).astype(np.uint8) + image = sitk.GetImageFromArray(img.squeeze()) + roi_mask = sitk.GetImageFromArray(mask.squeeze()) + + shape_calculator = shape2D.RadiomicsShape2D( + inputImage=image, inputMask=roi_mask, label=255 + ) + # Calculate the shape-based radiomic features + shape_features = shape_calculator.execute() + + return np.array(list(shape_features.values())) + + +def extract_intensity_features(image: np.ndarray, mask: np.ndarray) -> np.ndarray: + """Extracts intensity-based features from an image within the region defined by the mask. + + :param image: The input image. + :type image: numpy.ndarray + :param mask: The mask defining the region of interest. + :type mask: numpy.ndarray + :return: An array containing the extracted intensity-based features, including median intensity, mean intensity, and 25th/75th percentile intensity within the masked region. + :rtype: numpy.ndarray + """ + + features = {} + + # Ensure the image and mask have the same dimensions + + if image.shape != mask.shape: + raise ValueError("Image and mask must have the same dimensions") + + masked_image = image[(mask > 0)] + # features["min_intensity"] = np.min(masked_image) + # features["max_intensity"] = np.max(masked_image) + features["median_intensity"] = np.median(masked_image) + features["mean_intensity"] = np.mean(masked_image) + features["25th_percentile_intensity"] = np.percentile(masked_image, 25) + features["75th_percentile_intensity"] = np.percentile(masked_image, 75) + + return np.array(list(features.values())) + + +def create_dataset_for_rf( + imgs: List[np.ndarray], masks: List[np.ndarray] +) -> List[np.ndarray]: + """Extracts shape and intensity-based features from images within regions defined by masks. + + :param imgs: A list of input images. + :type imgs: list + :param masks: A list of corresponding masks defining regions of interest. + :type masks: list + :return: A list of arrays containing shape and intensity-based features. + :rtype: list + """ + X = [] + for img, mask in zip(imgs, masks): + shape_features = get_shape_features(img, mask) + intensity_features = extract_intensity_features(img, mask) + features_list = np.concatenate((shape_features, intensity_features), axis=0) + X.append(features_list) + + return X diff --git a/src/server/pyproject.toml b/src/server/pyproject.toml index 5833f351..4acd006c 100644 --- a/src/server/pyproject.toml +++ b/src/server/pyproject.toml @@ -9,11 +9,10 @@ packages = ['dcp_server'] dependencies = {file = ["requirements.txt"]} [project] -name = "data-centric-tool-server" +name = "data-centric-platform-server" version = "0.1" -requires-python = ">=3.8" -description = "" -# license = {file = "LICENSE.txt"} +requires-python = ">=3.9" +description = "The server of the data centric platform for microscopy image segmentation" keywords = [] classifiers = [ "Programming Language :: Python :: 3", @@ -22,23 +21,27 @@ classifiers = [ readme = "README.md" dynamic = ["dependencies"] authors = [ - {name="Christina Bukas", email="christina.bukas@helmholtz-muenchen.de"}, - {name="Helena Pelin", email="helena.pelin@helmholtz-muenchen.de"} + {name="Christina Bukas", email="christina.bukas@helmholtz-munich.de"}, + {name="Helena Pelin", email="helena.pelin@helmholtz-munich.de"}, + {name="Mariia Koren", email="mariia.koren@helmholtz-munich.de"}, + {name="Marie Piraud", email="marie.piraud@helmholtz-munich.de"}, ] maintainers = [ - {name="Christina Bukas", email="christina.bukas@helmholtz-muenchen.de"}, - {name="Helena Pelin", email="helena.pelin@helmholtz-muenchen.de"} + {name="Christina Bukas", email="christina.bukas@helmholtz-munich.de"}, + {name="Helena Pelin", email="helena.pelin@helmholtz-munich.de"} ] [project.optional-dependencies] dev = [ - "pytest", + "pytest>=7.4.3", + "sphinx", + "sphinx-rtd-theme" ] [project.urls] repository = "https://github.com/HelmholtzAI-Consultants-Munich/data-centric-platform" # homepage = "https://example.com" -# documentation = "https://readthedocs.org" +documentation = "https://readthedocs.org/projects/data-centric-platform" [project.scripts] dcp-server = "dcp_server.main:main" diff --git a/src/server/requirements.txt b/src/server/requirements.txt index d84b7f43..a42ba5eb 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -4,7 +4,6 @@ bentoml==1.0.16 scikit-image>=0.19.3 torchmetrics>=0.11.4 torch>=2.1.0 -pytest>=7.4.3 numpy scikit-learn>=1.2.2 SimpleITK>=2.2.1 diff --git a/src/server/test/configs/test_config_CustomCellpose.yaml b/src/server/test/configs/test_config_CustomCellpose.yaml new file mode 100644 index 00000000..5e3c0436 --- /dev/null +++ b/src/server/test/configs/test_config_CustomCellpose.yaml @@ -0,0 +1,47 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "CustomCellpose", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor_name": Cellpose, + "segmentor": { + "model_type": "cyto" + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "mask_channel_axis": null + } +} \ No newline at end of file diff --git a/src/server/test/configs/test_config_fcnn.cfg b/src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml similarity index 64% rename from src/server/test/configs/test_config_fcnn.cfg rename to src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml index 02039f68..20e5c96a 100644 --- a/src/server/test/configs/test_config_fcnn.cfg +++ b/src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml @@ -1,49 +1,49 @@ { "setup": { "segmentation": "GeneralSegmentation", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { - "model_to_use": "CustomCellposeModel", - "save_model_path": "mito", - "runner_name": "cellpose_runner", + "runner_name": "bento_runner", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": Cellpose, "segmentor": { "model_type": "cyto" }, - "classifier":{ - "model_class": "FCNN", + "classifier_name": "PatchClassifier", + "classifier":{ "in_channels": 1, "num_classes": 3, "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" + "black_bg": False, + "include_mask": False } }, "data": { - "data_root": "data" + "data_root": "data", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, "n_epochs": 20, "lr": 0.005, "batch_size": 5, @@ -59,10 +59,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } diff --git a/src/server/test/configs/test_config_RF.cfg b/src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml similarity index 54% rename from src/server/test/configs/test_config_RF.cfg rename to src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml index c09c6af5..0734bcf7 100644 --- a/src/server/test/configs/test_config_RF.cfg +++ b/src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml @@ -1,53 +1,44 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CustomCellposeModel", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { - "runner_name": "cellpose_runner", - "bento_model_path": "mito", + "runner_name": "bento_runner", + "bento_model_path": "test", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": Cellpose, "segmentor": { "model_type": "cyto" }, + "classifier_name": "RandomForest", "classifier":{ - "model_class": "RandomForest", - "in_channels": 1, - "num_classes": 3, - "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" } }, "data": { - "data_root": "data" + "data_root": "data", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, - "n_epochs": 10, - "lr": 0.001, - "batch_size": 1, - "optimizer": "Adam" } }, @@ -59,10 +50,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } diff --git a/src/server/test/configs/test_config_MultiCellpose.yaml b/src/server/test/configs/test_config_MultiCellpose.yaml new file mode 100644 index 00000000..46b913d7 --- /dev/null +++ b/src/server/test/configs/test_config_MultiCellpose.yaml @@ -0,0 +1,50 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "MultiCellpose", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor_name": Cellpose, + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "num_classes": 3 + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "segmentor":{ + "n_epochs": 30, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/configs/test_config_UNet.yaml b/src/server/test/configs/test_config_UNet.yaml new file mode 100644 index 00000000..f4eba079 --- /dev/null +++ b/src/server/test/configs/test_config_UNet.yaml @@ -0,0 +1,46 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "UNet", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "classifier":{ + "in_channels": 1, + "num_classes": 3, + "features":[64,128,256,512] + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "classifier":{ + "n_epochs": 30, + "lr": 0.005, + "batch_size": 5, + "optimizer": "Adam" + } + }, + + "eval":{ + "classifier": { + + }, + compute_instance: True, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index c3adcdec..5c9e0fcb 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -14,9 +14,15 @@ def assign_unique_colors(labels, colors): - ''' - Assigns unique colors to each label in the given label array. - ''' + """Assigns unique colors to each label in the given label array. + + :param labels: The array containing labels. + :type labels: numpy.ndarray + :param colors: The list of colors to assign to each label. + :type colors: list + :return: A dictionary containing the color assignment for each unique label. + :rtype: dict + """ unique_labels = np.unique(labels) # Create a dictionary to store the color assignment for each label label_colors = {} @@ -35,10 +41,22 @@ def assign_unique_colors(labels, colors): return label_colors -def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha=0.5): - ''' + +def custom_label2rgb(labels, colors=["red", "green", "blue"], bg_label=0, alpha=0.5): + """ Converts a label array to an RGB image using assigned colors for each label. - ''' + + :param labels: The array containing labels. + :type labels: numpy.ndarray + :param colors: The list of colors to assign to each label. Defaults to ['red', 'green', 'blue']. + :type colors: list, optional + :param bg_label: The label representing the background. Defaults to 0. + :type bg_label: int, optional + :param alpha: The transparency level of the colors. Defaults to 0.5. + :type alpha: float, optional + :return: The RGB image representing the labels with assigned colors. + :rtype: numpy.ndarray + """ label_colors = assign_unique_colors(labels, colors) @@ -47,20 +65,27 @@ def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha= for label in np.unique(labels): mask = labels == label if label in label_colors: - rgb = color.label2rgb(mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha) + rgb = color.label2rgb( + mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha + ) rgb_image += rgb return rgb_image + def add_padding_for_rotation(image, angle): - ''' - Apply padding and rotation to an image. + """ + Apply padding and rotation to an image. + The purpose of this function is to ensure that the rotated image fits within its original dimensions by adding padding, preventing any parts of the image from being cropped. - Args: - image (numpy.ndarray): The input image. - angle (float): The rotation angle in degrees. - ''' + :param image: The input image. + :type image: numpy.ndarray + :param angle: The rotation angle in degrees. + :type angle: float + :return: The rotated and padded image. + :rtype: numpy.ndarray + """ # Calculate rotated bounding box h, w = image.shape[:2] @@ -76,48 +101,73 @@ def add_padding_for_rotation(image, angle): pad_h = (new_h - h) // 2 # Add padding to the image - padded_image = cv2.copyMakeBorder(image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT) + padded_image = cv2.copyMakeBorder( + image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT + ) # Rotate the padded image center = (padded_image.shape[1] // 2, padded_image.shape[0] // 2) rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) - rotated_image = cv2.warpAffine(padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0])) + rotated_image = cv2.warpAffine( + padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0]) + ) return rotated_image + def get_object_images(objects): - ''' + """ Load object images from file paths. - ''' + + :param objects: A list of dictionaries containing information about the objects such as name, path, intensity + :type objects: list[dict] + :return: A list of object images loaded from the specified file paths. + :rtype: list[numpy.ndarray] + """ object_images = [] for obj in objects: - img = cv2.imread(obj['path']) + img = cv2.imread(obj["path"]) # img = cv2.resize(img, obj['size']) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) object_images.append(img) return object_images -def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, noise_intensity=None, max_rotation_angle=None): - ''' - Generate a synthetic dataset with images and masks. - Args: - num_samples (int): The number of samples to generate. - objects (list): List of object descriptions. - canvas_size (tuple): Size of the canvas to place objects on. - max_object_counts (list, optional): Maximum object counts for each class. Default is None. - noise_intensity (float, optional): intensity of the additional noise to the image +def generate_dataset( + num_samples, + objects, + canvas_size, + max_object_counts=None, + noise_intensity=None, + max_rotation_angle=None, +): + """ + Generate a synthetic dataset with images and masks. - ''' + :param num_samples: The number of samples to generate. + :type num_samples: int + :param objects: List of object descriptions. + :type objects: list + :param canvas_size: Size of the canvas to place objects on. + :type canvas_size: tuple + :param max_object_counts: Maximum object counts for each class. Default is None. + :type max_object_counts: list, optional + :param noise_intensity: Intensity of the additional noise to the image. Default is None. + :type noise_intensity: float, optional + :param max_rotation_angle: Maximum rotation angle in degrees. Default is None. + :type max_rotation_angle: float, optional + :return: A tuple containing the generated images and masks. + :rtype: tuple + """ dataset_images = [] dataset_masks = [] object_images = get_object_images(objects) - class_intensities = [ (obj['intensity'][0], obj['intensity'][1]) for obj in objects] + class_intensities = [(obj["intensity"][0], obj["intensity"][1]) for obj in objects] if len(object_images[0].shape) == 3: num_of_img_channels = object_images[0].shape[-1] @@ -128,8 +178,12 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, max_object_counts = [10] * len(object_images) for _ in range(num_samples): - canvas = np.zeros((canvas_size[0], canvas_size[1], num_of_img_channels), dtype=np.uint8) - mask = np.zeros((canvas_size[0], canvas_size[1], len(object_images)), dtype=np.uint8) + canvas = np.zeros( + (canvas_size[0], canvas_size[1], num_of_img_channels), dtype=np.uint8 + ) + mask = np.zeros( + (canvas_size[0], canvas_size[1], len(object_images)), dtype=np.uint8 + ) for object_index, object_img in enumerate(object_images): @@ -137,70 +191,104 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, object_count = random.randint(1, max_count) for _ in range(object_count): - + canvas_range = max(canvas_size) - object_size = random.randint(canvas_range//20, canvas_range//5) + object_size = random.randint(canvas_range // 20, canvas_range // 5) object_img_resized = cv2.resize(object_img, (object_size, object_size)) # object_img_resized = (object_img_resized>0).astype(np.uint8)*(255 - object_size) - intensity_mean = (class_intensities[object_index][1] - class_intensities[object_index][0])/2 - intensity_scale = (class_intensities[object_index][1] - intensity_mean)/3 - class_intensity = np.random.normal(loc=intensity_mean, scale=intensity_scale) - class_intensity = np.clip(class_intensity, class_intensities[object_index][0], class_intensities[object_index][1]) + intensity_mean = ( + class_intensities[object_index][1] + - class_intensities[object_index][0] + ) / 2 + intensity_scale = ( + class_intensities[object_index][1] - intensity_mean + ) / 3 + class_intensity = np.random.normal( + loc=intensity_mean, scale=intensity_scale + ) + class_intensity = np.clip( + class_intensity, + class_intensities[object_index][0], + class_intensities[object_index][1], + ) # class_intensity = random.randint(int(class_intensities[object_index][0]), int(class_intensities[object_index][1])) - object_img_resized = (object_img_resized>0).astype(np.uint8)*(class_intensity)*255 + object_img_resized = ( + (object_img_resized > 0).astype(np.uint8) * (class_intensity) * 255 + ) if num_of_img_channels == 1: - + if max_rotation_angle is not None: # Randomly rotate the object image - rotation_angle = random.uniform(-max_rotation_angle, max_rotation_angle) - object_img_transformed = add_padding_for_rotation(object_img_resized, rotation_angle) + rotation_angle = random.uniform( + -max_rotation_angle, max_rotation_angle + ) + object_img_transformed = add_padding_for_rotation( + object_img_resized, rotation_angle + ) else: object_img_transformed = object_img_resized - - object_size_x, object_size_y = object_img_transformed.shape - + object_size_x, object_size_y = object_img_transformed.shape object_mask = np.zeros((object_size_x, object_size_y), dtype=np.uint8) if num_of_img_channels == 1: # Grayscale image object_mask[object_img_transformed > 0] = object_index + 1 # object_img_resized = np.expand_dims(object_img_resized, axis=-1) - object_img_transformed = np.expand_dims(object_img_transformed, axis=-1) + object_img_transformed = np.expand_dims( + object_img_transformed, axis=-1 + ) else: # Color image with alpha channel object_mask[object_img_resized[:, :, -1] > 0] = object_index + 1 - x = random.randint(0, canvas_size[1] - object_size_x) y = random.randint(0, canvas_size[0] - object_size_y) - intersecting_mask = mask[y:y + object_size_y, x:x + object_size_x].max(axis=-1) + intersecting_mask = mask[ + y : y + object_size_y, x : x + object_size_x + ].max(axis=-1) if (intersecting_mask > 0).any(): continue # Skip if there is an intersection with objects from other classes - - assert mask[y:y + object_size_y, x:x + object_size_x, object_index].shape == object_mask.shape - canvas[y:y + object_size_y, x:x + object_size_x] = object_img_transformed - mask[y:y + object_size_y, x:x + object_size_x, object_index] = np.maximum( - mask[y:y + object_size_y, x:x + object_size_x, object_index], object_mask + assert ( + mask[ + y : y + object_size_y, x : x + object_size_x, object_index + ].shape + == object_mask.shape + ) + + canvas[y : y + object_size_y, x : x + object_size_x] = ( + object_img_transformed + ) + mask[y : y + object_size_y, x : x + object_size_x, object_index] = ( + np.maximum( + mask[ + y : y + object_size_y, x : x + object_size_x, object_index + ], + object_mask, + ) ) - # Add noise to the canvas if noise_intensity is not None: if num_of_img_channels == 1: - noise = np.random.normal(scale=noise_intensity, size=(canvas_size[0], canvas_size[1], 1)) + noise = np.random.normal( + scale=noise_intensity, size=(canvas_size[0], canvas_size[1], 1) + ) # noise = random_noise(canvas, mode='speckle', mean=noise_intensity) - + else: - noise = np.random.normal(scale=noise_intensity, size=(canvas_size[0], canvas_size[1], num_of_img_channels)) + noise = np.random.normal( + scale=noise_intensity, + size=(canvas_size[0], canvas_size[1], num_of_img_channels), + ) noisy_canvas = canvas + noise.astype(np.uint8) - dataset_images.append(noisy_canvas.squeeze(2)) - + dataset_images.append(noisy_canvas.squeeze(2)) + else: dataset_images.append(canvas.squeeze(2)) @@ -218,26 +306,37 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, return dataset_images, dataset_masks -def get_synthetic_dataset(num_samples, canvas_size=(512,512), max_object_counts=[15, 15, 15]): - + +def get_synthetic_dataset( + num_samples, canvas_size=(512, 512), max_object_counts=[15, 15, 15] +): + """Generates a synthetic dataset with images and masks. + + :param num_samples: The number of samples to generate. + :type num_samples: int + :param canvas_size: Size of the canvas to place objects on. Default is (512, 512). + :type canvas_size: tuple, optional + :param max_object_counts: Maximum object counts for each class. Default is [15, 15, 15]. + :type max_object_counts: list, optional + :return: A tuple containing the generated images and masks. + :rtype: tuple + """ objects = [ - { - - 'name': 'triangle', - 'path': 'test/shapes/triangle.png', - 'intensity' : [0, 0.33] - }, - { - 'name': 'circle', - 'path': 'test/shapes/circle.png', - 'intensity' : [0.34, 0.66] - }, - { - 'name': 'square', - 'path': 'test/shapes/square.png', - 'intensity' : [0.67, 1.0] - }, + { + "name": "triangle", + "path": "test/shapes/triangle.png", + "intensity": [0, 0.33], + }, + {"name": "circle", "path": "test/shapes/circle.png", "intensity": [0.34, 0.66]}, + {"name": "square", "path": "test/shapes/square.png", "intensity": [0.67, 1.0]}, ] - - images, masks = generate_dataset(num_samples, objects, canvas_size=canvas_size, max_object_counts=max_object_counts, noise_intensity=5, max_rotation_angle=30) + + images, masks = generate_dataset( + num_samples, + objects, + canvas_size=canvas_size, + max_object_counts=max_object_counts, + noise_intensity=5, + max_rotation_angle=30, + ) return images, masks diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index ced69cdc..6e37ea22 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,73 +1,148 @@ import sys + +sys.path.append(".") + from glob import glob -import inspect +import pytest + +# import inspect import random import numpy as np - -import torch +import torch from torchmetrics import JaccardIndex -# from importlib.machinery import SourceFileLoader - -sys.path.append(".") - -import dcp_server.models as models -from dcp_server.utils import read_config +from dcp_server.models import * +from dcp_server.utils.helpers import read_config from synthetic_dataset import get_synthetic_dataset -import pytest - seed_value = 2023 random.seed(seed_value) torch.manual_seed(seed_value) np.random.seed(seed_value) -# retrieve models names -model_classes = [ - cls_obj for cls_name, cls_obj in inspect.getmembers(models) \ - if inspect.isclass(cls_obj) \ - and cls_obj.__module__ == models.__name__ \ - and not cls_name.startswith("CellClassifier") - ] +model_mapping = { + "CustomCellpose": CustomCellpose, + "Inst2MultiSeg": Inst2MultiSeg, + "MultiCellpose": MultiCellpose, + "UNet": UNet, +} -config_paths = glob("test/configs/*.cfg") +config_paths = glob("test/configs/*.yaml") -@pytest.fixture(params=model_classes) -def model_class(request): - return request.param @pytest.fixture(params=config_paths) def config_path(request): return request.param -@pytest.fixture() -def model(model_class, config_path): - - model_config = read_config('model', config_path=config_path) - train_config = read_config('train', config_path=config_path) - eval_config = read_config('eval', config_path=config_path) - - model = model_class(model_config, train_config, eval_config, str(model_class)) +@pytest.fixture() +# def model(model_class, config_path): +def model(config_path): + + setup_config = read_config("setup", config_path=config_path) + model_config = read_config("model", config_path=config_path) + data_config = read_config("data", config_path=config_path) + train_config = read_config("train", config_path=config_path) + eval_config = read_config("eval", config_path=config_path) + + model_name = setup_config["model_to_use"] + model_class = model_mapping.get(model_name) + model = model_class( + model_name, model_config, data_config, train_config, eval_config + ) + # str(model_class) return model + @pytest.fixture def data_train(): - images, masks = get_synthetic_dataset(num_samples=4, canvas_size=(512,768)) + images, masks = get_synthetic_dataset(num_samples=4, canvas_size=(512, 768)) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - masks_ = [np.stack((instances, classes)) for instances, classes in zip(masks_instances, masks_classes)] + masks_ = [ + np.stack((instances, classes)) + for instances, classes in zip(masks_instances, masks_classes) + ] return images, masks_ + @pytest.fixture -def data_eval(): +def data_eval(): img, msk = get_synthetic_dataset(num_samples=1) msk = np.array(msk) - msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) + msk_ = np.stack( + (msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0 + ).transpose(1, 0, 2, 3) return img, msk_ + +def test_train_eval_run(data_train, data_eval, model): + """ + Performs testing, training, and evaluation with the provided data and model. + """ + # train + images, masks = data_train + if model.model_name == "CustomCellpose": + masks = [mask[0] for mask in masks] + model.train(images, masks) + + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + + if "metric" in attrs: + assert model.metric > 0.1 + if "loss" in attrs: + assert model.loss < 0.83 + + # validate + imgs_test, masks_test = data_eval + if model.model_name == "CustomCellpose": + masks = [mask[0] for mask in masks_test] + + jaccard_index_instances = 0 + jaccard_index_classes = 0 + + jaccard_metric_binary = JaccardIndex( + task="multiclass", num_classes=2, average="macro", ignore_index=0 + ) + jaccard_metric_multi = JaccardIndex( + task="multiclass", num_classes=4, average="macro", ignore_index=0 + ) + + for img, mask in zip(imgs_test, masks_test): + + # mask - instance segmentation mask + classes (2, 512, 512) + # pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + + pred_mask = model.eval(img) + + if pred_mask.ndim > 2: + pred_mask_bin = torch.tensor((pred_mask[0] > 0).astype(bool).astype(int)) + else: + pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) + + bin_mask = torch.tensor((mask[0] > 0).astype(bool).astype(int)) + + jaccard_index_instances += jaccard_metric_binary(pred_mask_bin, bin_mask) + + if pred_mask.ndim > 2: + + jaccard_index_classes += jaccard_metric_multi( + torch.tensor(pred_mask[1].astype(int)), + torch.tensor(mask[1].astype(int)), + ) + + jaccard_index_instances /= len(imgs_test) + assert jaccard_index_instances > 0.2 + + if pred_mask.ndim > 2: + + jaccard_index_classes /= len(imgs_test) + assert jaccard_index_classes > 0.1 + + # def test_train_run(data_train, model): # images, masks = data_train @@ -83,12 +158,12 @@ def data_eval(): # assert(model.metric>0.1) # if "loss" in attrs: # assert(model.loss<0.3) - + # def test_eval_run(data_train, data_eval, model): # images, masks = data_train # model.train(images, masks) - + # imgs_test, masks_test = data_eval # jaccard_index_instances = 0 @@ -103,7 +178,7 @@ def data_eval(): # #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) # pred_mask = model.eval(img) #, channels=[0,0]) - + # if pred_mask.ndim > 2: # pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) # else: @@ -112,78 +187,22 @@ def data_eval(): # bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) # jaccard_index_instances += jaccard_metric_binary( -# pred_mask_bin, +# pred_mask_bin, # bin_mask # ) # if pred_mask.ndim > 2: # jaccard_index_classes += jaccard_metric_multi( -# torch.tensor(pred_mask[1].astype(int)), +# torch.tensor(pred_mask[1].astype(int)), # torch.tensor(mask[1].astype(int)) # ) - + # jaccard_index_instances /= len(imgs_test) # assert(jaccard_index_instances>0.2) -# # for PatchCNN model +# # for PatchCNN model # if pred_mask.ndim > 2: # jaccard_index_classes /= len(imgs_test) # assert(jaccard_index_classes>0.1) - -def test_train_eval_run(data_train, data_eval, model): - - images, masks = data_train - model.train(images, masks) - - imgs_test, masks_test = data_eval - - jaccard_index_instances = 0 - jaccard_index_classes = 0 - - jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) - jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) - - for img, mask in zip(imgs_test, masks_test): - - #mask - instance segmentation mask + classes (2, 512, 512) - #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) - - pred_mask = model.eval(img) - - if pred_mask.ndim > 2: - pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) - else: - pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) - - bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) - - jaccard_index_instances += jaccard_metric_binary( - pred_mask_bin, - bin_mask - ) - - if pred_mask.ndim > 2: - - jaccard_index_classes += jaccard_metric_multi( - torch.tensor(pred_mask[1].astype(int)), - torch.tensor(mask[1].astype(int)) - ) - - jaccard_index_instances /= len(imgs_test) - assert(jaccard_index_instances>0.2) - - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() - - if "metric" in attrs: - assert(model.metric>0.1) - if "loss" in attrs: - assert(model.loss<0.75) - - # for PatchCNN model - if pred_mask.ndim > 2: - - jaccard_index_classes /= len(imgs_test) - assert(jaccard_index_classes>0.1) \ No newline at end of file diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py index 84b203c3..eddf8f94 100644 --- a/src/server/test/test_models.py +++ b/src/server/test/test_models.py @@ -2,35 +2,33 @@ import numpy as np import dcp_server.models as models -from dcp_server.utils import read_config +from dcp_server.models.classifiers import FeatureClassifier +from dcp_server.utils.helpers import read_config -def test_eval_rf_not_fitted(): - - model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') - train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') - eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') - - model_rf = models.CellClassifierShallowModel(model_config,train_config,eval_config) - X_test = np.array([[1, 2, 3]]) +def test_eval_rf_not_fitted(): + """ + Tests the evaluation of a random forest model that has not been fitted. + """ + + model_config = read_config( + "model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + data_config = read_config( + "data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + train_config = read_config( + "train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + eval_config = read_config( + "eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + + model_rf = FeatureClassifier( + "Random Forest", model_config, data_config, train_config, eval_config + ) + + X_test = np.array([[1, 2, 3]]) # if we don't fit the model then the model returns zeros - assert np.all(model_rf.eval(X_test)== np.zeros(X_test.shape)) - -def test_update_configs(): - - model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') - train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') - eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') - - model = models.CustomCellposeModel(model_config,train_config,eval_config, "Cellpose") - - new_train_config = {"param1": "value1"} - new_eval_config = {"param2": "value2"} - - model.update_configs(new_train_config, new_eval_config) - - assert model.train_config == new_train_config - assert model.eval_config == new_eval_config - - + assert np.all(model_rf.eval(X_test) == np.zeros(X_test.shape)) diff --git a/src/server/test/test_utils.py b/src/server/test/test_utils.py index 35678a22..b0c4f71f 100644 --- a/src/server/test/test_utils.py +++ b/src/server/test/test_utils.py @@ -1,6 +1,7 @@ import numpy as np import pytest -from dcp_server.utils import find_max_patch_size +from dcp_server.utils.processing import find_max_patch_size + @pytest.fixture def sample_mask(): @@ -9,12 +10,9 @@ def sample_mask(): mask[7:9, 2:5] = 1 return mask + def test_find_max_patch_size(sample_mask): # Test when the function is called with a sample mask result = find_max_patch_size(sample_mask) assert isinstance(result, float) assert result > 0 - - - -