diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7f1627f8..90a92f78 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,17 +44,15 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools - pip install pytest - pip install pytest-qt pip install pytest-xvfb pip install coverage - pip install -e ".[testing]" + pip install -e ".[dev]" pip install matplotlib working-directory: src/client - name: Install server dependencies (for communication tests) run: | - pip install -e ".[testing]" + pip install -e ".[dev]" working-directory: src/server - name: Test with pytest @@ -94,10 +92,9 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade setuptools pip install numpy - pip install pytest pip install wheel pip install coverage - pip install -e ".[testing]" + pip install -e ".[dev]" working-directory: src/server - name: Test with pytest @@ -115,7 +112,6 @@ jobs: env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - deploy: # this will run when you have tagged a commit, starting with "v*" # and requires that you have put your twine API key in your diff --git a/.gitignore b/.gitignore index 5aab7354..0f64af41 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,7 @@ __pycache__/ # Distribution / packaging .Python -# build/ +build/ develop-eggs/ dist/ downloads/ @@ -153,4 +153,6 @@ dmypy.json .idea/ .DS_Store +data/ +BentoML/ diff --git a/README.md b/README.md index 97e28581..bf0e3499 100644 --- a/README.md +++ b/README.md @@ -29,5 +29,4 @@ Our platform encourages the use of data centric practices. 
With the user friendl - Focus on data curation: no interaction with model parameters during training and inference #### *Get more with less!* - diff --git a/docs/source/conf.py b/docs/source/conf.py index 1145052d..c0df37ab 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,7 +25,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = 'alabaster' -html_static_path = ['_static'] +#html_static_path = ['_static'] import os import sys @@ -35,7 +35,7 @@ # Add parent dir to known paths p = Path(__file__).parents[2] sys.path.insert(0, os.path.abspath(p)) - +sys.path.insert(0, os.path.join(p, 'src/server/dcp_server')) # Add the following extensions extensions = [ 'sphinx.ext.autodoc', diff --git a/docs/source/dcp_client.rst b/docs/source/dcp_client.rst index ab9b9ecd..f921a211 100644 --- a/docs/source/dcp_client.rst +++ b/docs/source/dcp_client.rst @@ -1,13 +1,32 @@ dcp\_client package =================== +The dcp_client package contains modules and subpackages for interacting with a server for model inference and training. It provides functionalities for managing GUI windows, handling image storage, and connecting to the server for model operations. -.. toctree:: - :maxdepth: 4 +dcp_client.app + Defines the core application class and related functionalities. + - ``dcp_client.app.Application``: Represents the main application and provides methods for image management, model interaction, and server connectivity. + - ``dcp_client.app.DataSync``: Abstract base class for data synchronization operations. + - ``dcp_client.app.ImageStorage``: Abstract base class for image storage operations. + - ``dcp_client.app.Model``: Abstract base class for model operations. - dcp_client.gui - dcp_client.utils +dcp_client.gui + Contains modules for GUI components. + - ``dcp_client.gui.main_window``: Defines the main application window and associated event functions. 
+ - ``dcp_client.gui.napari_window``: Manages the Napari window and its functionalities. + - ``dcp_client.gui.welcome_window``: Implements the welcome window and its interactions. +dcp_client.utils + Contains utility modules for various tasks. + - ``dcp_client.utils.bentoml_model``: Handles interactions with BentoML for model inference and training. + - ``dcp_client.utils.fsimagestorage``: Provides functions for managing images stored in the filesystem. + - ``dcp_client.utils.settings``: Defines initialization functions and settings. + - ``dcp_client.utils.sync_src_dst``: Implements data synchronization between source and destination. + - ``dcp_client.utils.utils``: Offers various utility functions for common tasks. + + +Submodules +---------- dcp\_client.app module ---------------------- @@ -17,3 +36,17 @@ dcp\_client.app module :undoc-members: :show-inheritance: +dcp\_client.gui module +---------------------- +.. toctree:: + :maxdepth: 4 + + dcp_client.gui + +dcp\_client.utils module +------------------------ +.. toctree:: + :maxdepth: 4 + + dcp_client.utils + \ No newline at end of file diff --git a/docs/source/dcp_client_installation.rst b/docs/source/dcp_client_installation.rst index 2b010e42..b4a883b2 100644 --- a/docs/source/dcp_client_installation.rst +++ b/docs/source/dcp_client_installation.rst @@ -114,7 +114,7 @@ DCP Shortcuts - In the Data Overview window, clicking on an image and the hitting the **Enter** key, is equivalent to clicking the 'View Image and Fix Label' button - The viewer accepts all Napari Shortcuts. The current list of the shortcuts for macOS can be see below: -.. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/add-documentation/src/client/readme_figs/napari_shortcuts.png +.. 
image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/napari_shortcuts.png :width: 600 :height: 500 :align: center diff --git a/docs/source/dcp_server.rst b/docs/source/dcp_server.rst index dfb482d9..78a1ef23 100644 --- a/docs/source/dcp_server.rst +++ b/docs/source/dcp_server.rst @@ -1,23 +1,25 @@ dcp\_server package =================== -.. automodule:: dcp_server - :members: - :undoc-members: - :show-inheritance: - :exclude-members: dcp\_server.main module +The dcp_server package is structured to handle various server-side functionalities related to model serving for segmentation and training. -Submodules ----------- +dcp_server.models + Defines various models for cell classification and segmentation, including CellClassifierFCNN, CellClassifierShallowModel, CellposePatchCNN, CustomCellposeModel, and UNet. + These models handle tasks such as evaluation, forward pass, training, and updating configurations. -dcp\_server.fsimagestorage module ---------------------------------- +dcp_server.segmentationclasses + Defines segmentation classes for specific projects, such as GFPProjectSegmentation, GeneralSegmentation, and MitoProjectSegmentation. + These classes contain methods for segmenting images and training models on images and masks. -.. automodule:: dcp_server.fsimagestorage - :members: - :undoc-members: - :show-inheritance: +dcp_server.serviceclasses + Defines service classes, such as CustomBentoService and CustomRunnable, for serving the models with BentoML and handling computation on remote Python workers. +dcp_server.utils + Provides various utility functions for dealing with image storage, image processing, feature extraction, file handling, configuration reading, and path manipulation. 
+ + +Submodules +---------- dcp\_server.models module ------------------------- @@ -44,9 +46,9 @@ dcp\_server.serviceclasses module :show-inheritance: dcp\_server.utils module ------------------------- +--------------------------------- -.. automodule:: dcp_server.utils - :members: - :undoc-members: - :show-inheritance: +.. toctree:: + :maxdepth: 4 + + dcp_server.utils \ No newline at end of file diff --git a/docs/source/dcp_server.utils.rst b/docs/source/dcp_server.utils.rst new file mode 100644 index 00000000..a6334330 --- /dev/null +++ b/docs/source/dcp_server.utils.rst @@ -0,0 +1,37 @@ +dcp\_server.utils package +========================= + +.. automodule:: dcp_server.utils + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +dcp\_server.utils.fsimagestorage module +--------------------------------------- + +.. automodule:: dcp_server.utils.fsimagestorage + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.utils.helpers module +--------------------------------------- + +.. automodule:: dcp_server.utils.helpers + :members: + :undoc-members: + :show-inheritance: + +dcp\_server.utils.processing module +----------------------------------- + +.. automodule:: dcp_server.utils.processing + :members: + :undoc-members: + :show-inheritance: + + + diff --git a/docs/source/dcp_server_installation.rst b/docs/source/dcp_server_installation.rst index a34bd6a8..823f97af 100644 --- a/docs/source/dcp_server_installation.rst +++ b/docs/source/dcp_server_installation.rst @@ -19,7 +19,7 @@ Before starting make sure you have navigated to ``data-centric-platform/src/serv .. code-block:: bash - pip install -e . + pip install -e ".[dev]" Launch DCP Server ------------------ @@ -83,17 +83,18 @@ The models are currently integrated into DCP: - **Instance** Segmentation: - - ``CustomCellposeModel``: Inherits from cellpose.models.CellposeModel, see `here `__ for more information. 
+ - ``CustomCellpose``: Inherits from cellpose.models.CellposeModel, see `here `__ for more information. - **Semantic** Segmentation: - ``UNet``: A vanilla U-Net model, trained on the full images -- **Panoptic** Segmentation: +- **Multi Class Instance** Segmentation: - - ``CellposePatchCNN``: Includes a segmentor for instance segmentation, sequentially followed by a classifier for semantic segmentation. The segmentor can only be ``CustomCellposeModel`` model, while the classifier can be one of: + - ``Inst2MultiSeg``: Includes a segmentor for instance segmentation, sequentially followed by a classifier for semantic segmentation. The segmentor can only be ``CustomCellpose`` model, while the classifier can be one of: - - ``CellClassifierFCNN`` or "FCNN" (in config): A CNN model for obtaining class labels, trained on images patches of individual objects, extarcted using the instance mask from the previous step - - ``CellClassifierShallowModel`` or "RandomForest" (in config): A Random Forest model for obtaining class labels, trained on shape and intensity features of the objects, extracted using the instance mask from the previous step. - - UNet: If the post-processing argument is set, then the instance mask is deduced from the labels mask. Will not be able to handle touching objects + - ``PatchClassifier`` or "FCNN" (in config): A CNN model for obtaining class labels, trained on image patches of individual objects, extracted using the instance mask from the previous step + - ``FeatureClassifier`` or "RandomForest" (in config): A Random Forest model for obtaining class labels, trained on shape and intensity features of the objects, extracted using the instance mask from the previous step. + - ``MultiCellpose``: Includes **n** CustomCellpose models, where n equals the number of classes, stacked such that each model predicts only the object corresponding to each class. 
+ - ``UNet``: If the post-processing argument is set, then the instance mask is deduced from the labels mask. Will not be able to handle touching objects Running with Docker diff --git a/docs/source/index.rst b/docs/source/index.rst index f220d007..c6532633 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -31,7 +31,7 @@ DCP handles all kinds of **segmentation tasks**! Try it out if you need to do: - **Instance** segmentation - **Semantic** segmentation -- **Panoptic** segmentation +- **Multi-class instance** segmentation Toy data -------- diff --git a/src/client/MANIFEST.in b/src/client/MANIFEST.in index 8809a659..c6c02f12 100644 --- a/src/client/MANIFEST.in +++ b/src/client/MANIFEST.in @@ -1 +1 @@ -include dcp_client/*.cfg \ No newline at end of file +include dcp_client/*.yaml \ No newline at end of file diff --git a/src/client/dcp_client/__init__.py b/src/client/dcp_client/__init__.py index 65273344..f4ffe44b 100644 --- a/src/client/dcp_client/__init__.py +++ b/src/client/dcp_client/__init__.py @@ -39,4 +39,4 @@ This package structure allows for easy management of GUI components, image storage, model interactions, and server connectivity within the dcp_client application. 
-""" \ No newline at end of file +""" diff --git a/src/client/dcp_client/app.py b/src/client/dcp_client/app.py index 9adbb587..8ae4e8a9 100644 --- a/src/client/dcp_client/app.py +++ b/src/client/dcp_client/app.py @@ -11,7 +11,7 @@ class Model(ABC): @abstractmethod def run_train(self, path: str) -> None: pass - + @abstractmethod def run_inference(self, path: str) -> None: pass @@ -21,7 +21,7 @@ class DataSync(ABC): @abstractmethod def sync(self, src: str, dst: str, path: str) -> None: pass - + class ImageStorage(ABC): @abstractmethod @@ -35,22 +35,29 @@ def save_image(self, to_directory, cur_selected_img, img) -> None: def search_segs(self, img_directory, cur_selected_img): """Returns a list of full paths of segmentations for an image""" # Take all segmentations of the image from the current directory: - search_string = utils.get_path_stem(cur_selected_img) + '_seg' - seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + search_string = utils.get_path_stem(cur_selected_img) + "_seg" + seg_files = [ + file_name + for file_name in os.listdir(img_directory) + if ( + search_string == utils.get_path_stem(file_name) + or str(file_name).startswith(search_string) + ) + ] return seg_files class Application: def __init__( - self, + self, ml_model: Model, syncer: DataSync, image_storage: ImageStorage, server_ip: str, server_port: int, - eval_data_path: str = '', - train_data_path: str = '', - inprogr_data_path: str = '', + eval_data_path: str = "", + train_data_path: str = "", + inprogr_data_path: str = "", ): self.ml_model = ml_model self.syncer = syncer @@ -60,73 +67,90 @@ def __init__( self.eval_data_path = eval_data_path self.train_data_path = train_data_path self.inprogr_data_path = inprogr_data_path - self.cur_selected_img = '' - self.cur_selected_path = '' - self.seg_filepaths = [] + self.cur_selected_img = "" + self.cur_selected_path = "" + self.seg_filepaths 
= [] def upload_data_to_server(self): """ Uploads the train and eval data to the server. """ - success_f1, message1 = self.syncer.first_sync(path=self.train_data_path) - success_f2, message2 = self.syncer.first_sync(path=self.eval_data_path) + success_f1, message1 = self.syncer.first_sync(path=self.train_data_path) + success_f2, message2 = self.syncer.first_sync(path=self.eval_data_path) return success_f1, success_f2, message1, message2 def try_server_connection(self): """ Checks if the ml model is connected to server and attempts to connect if not. """ - connection_success = self.ml_model.connect(ip=self.server_ip, port=self.server_port) + connection_success = self.ml_model.connect( + ip=self.server_ip, port=self.server_port + ) return connection_success - + def run_train(self): - """ Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path """ - if not self.ml_model.is_connected and not self.try_server_connection(): + """Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path""" + if not self.ml_model.is_connected and not self.try_server_connection(): message_title = "Warning" message_text = "Connection could not be established. Please check if the server is running and try again." 
return message_text, message_title # if syncer.host name is None then local machine is used to train message_title = "Information" - if self.syncer.host_name=="local": + if self.syncer.host_name == "local": message_text = self.ml_model.run_train(self.train_data_path) else: - success_sync, srv_relative_path = self.syncer.sync(src='client', dst='server', path=self.train_data_path) + success_sync, srv_relative_path = self.syncer.sync( + src="client", dst="server", path=self.train_data_path + ) # make sure syncing of folders was successful - if success_sync=="Success": message_text = self.ml_model.run_train(srv_relative_path) - else: message_text = None - if message_text is None: + if success_sync == "Success": + message_text = self.ml_model.run_train(srv_relative_path) + else: + message_text = None + if message_text is None: message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider." message_title = "Error" return message_text, message_title - + def run_inference(self): - """ Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path """ - if not self.ml_model.is_connected and not self.try_server_connection(): + """Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path""" + if not self.ml_model.is_connected and not self.try_server_connection(): message_title = "Warning" message_text = "Connection could not be established. Please check if the server is running and try again." 
return message_text, message_title - - if self.syncer.host_name=="local": + + if self.syncer.host_name == "local": # model serving directly from local - list_of_files_not_suported = self.ml_model.run_inference(self.eval_data_path) - success_sync = "Success" + list_of_files_not_suported = self.ml_model.run_inference( + self.eval_data_path + ) + success_sync = "Success" else: # sync data so that server gets updated files in client - e.g. if file was moved to curated srv_relative_path = utils.get_relative_path(self.eval_data_path) - success_sync, _ = self.syncer.sync(src='client', dst='server', path=self.eval_data_path) + success_sync, _ = self.syncer.sync( + src="client", dst="server", path=self.eval_data_path + ) # model serving from server list_of_files_not_suported = self.ml_model.run_inference(srv_relative_path) - # sync data so that client gets new masks - success_sync, _ = self.syncer.sync(src='server', dst='client', path=self.eval_data_path) + # sync data so that client gets new masks + success_sync, _ = self.syncer.sync( + src="server", dst="client", path=self.eval_data_path + ) # check if serving could not be performed for some files and prepare message - if list_of_files_not_suported is None or success_sync=="Error": + if list_of_files_not_suported is None or success_sync == "Error": message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider." message_title = "Error" else: list_of_files_not_suported = list(list_of_files_not_suported) if len(list_of_files_not_suported) > 0: - message_text = "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \ - Currently supported image file formats are: " + ", ".join(settings.accepted_types)+ ". The files that were not supported are: " + ", ".join(list_of_files_not_suported) + message_text = ( + "Image types not supported. 
Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \ + Currently supported image file formats are: " + + ", ".join(settings.accepted_types) + + ". The files that were not supported are: " + + ", ".join(list_of_files_not_suported) + ) message_title = "Warning" else: message_text = "Success! Masks generated for all images" @@ -145,51 +169,58 @@ def load_image(self, image_name=None): """ if image_name is None: - return self.fs_image_storage.load_image(self.cur_selected_path, self.cur_selected_img) - else: return self.fs_image_storage.load_image(self.cur_selected_path, image_name) - + return self.fs_image_storage.load_image( + self.cur_selected_path, self.cur_selected_img + ) + else: + return self.fs_image_storage.load_image(self.cur_selected_path, image_name) + def search_segs(self): - """ Searches in cur_selected_path for all possible segmentation files associated to cur_selected_img. - These files should have a _seg extension to the cur_selected_img filename. """ - self.seg_filepaths = self.fs_image_storage.search_segs(self.cur_selected_path, self.cur_selected_img) - + """Searches in cur_selected_path for all possible segmentation files associated to cur_selected_img. + These files should have a _seg extension to the cur_selected_img filename.""" + self.seg_filepaths = self.fs_image_storage.search_segs( + self.cur_selected_path, self.cur_selected_img + ) + def save_image(self, dst_directory, image_name, img): - """ Saves img array image in the dst_directory with filename cur_selected_img - + """Saves img array image in the dst_directory with filename cur_selected_img + :param dst_directory: The destination directory where the image will be saved. :type dst_directory: str :param image_name: The name of the image file. :type image_name: str - :param img: The image that will be saved. + :param img: The image that will be saved. 
:type img: numpy.ndarray """ self.fs_image_storage.save_image(dst_directory, image_name, img) def move_images(self, dst_directory, move_segs=False): """ - Moves cur_selected_img image from the current directory to the dst_directory. - + Moves cur_selected_img image from the current directory to the dst_directory. + :param dst_directory: The destination directory where the images will be moved. :type dst_directory: str :param move_segs: If True, moves the corresponding segmentation along with the image. Default is False. :type move_segs: bool - + """ - #if image_name is None: - self.fs_image_storage.move_image(self.cur_selected_path, dst_directory, self.cur_selected_img) + # if image_name is None: + self.fs_image_storage.move_image( + self.cur_selected_path, dst_directory, self.cur_selected_img + ) if move_segs: for seg_name in self.seg_filepaths: - self.fs_image_storage.move_image(self.cur_selected_path, dst_directory, seg_name) + self.fs_image_storage.move_image( + self.cur_selected_path, dst_directory, seg_name + ) def delete_images(self, image_names): - """ If image_name in the image_names list exists in the current directory it is deleted. - + """If image_name in the image_names list exists in the current directory it is deleted. + :param image_names: A list of image names to be deleted. 
:type image_names: list[str] """ for image_name in image_names: - if os.path.exists(os.path.join(self.cur_selected_path, image_name)): + if os.path.exists(os.path.join(self.cur_selected_path, image_name)): self.fs_image_storage.delete_image(self.cur_selected_path, image_name) - - diff --git a/src/client/dcp_client/config.cfg b/src/client/dcp_client/config.yaml similarity index 100% rename from src/client/dcp_client/config.cfg rename to src/client/dcp_client/config.yaml diff --git a/src/client/dcp_client/config_remote.cfg b/src/client/dcp_client/config_remote.yaml similarity index 100% rename from src/client/dcp_client/config_remote.cfg rename to src/client/dcp_client/config_remote.yaml diff --git a/src/client/dcp_client/gui/_my_widget.py b/src/client/dcp_client/gui/_my_widget.py index 8298360e..acf54b61 100644 --- a/src/client/dcp_client/gui/_my_widget.py +++ b/src/client/dcp_client/gui/_my_widget.py @@ -1,18 +1,25 @@ from PyQt5.QtWidgets import QWidget, QMessageBox from PyQt5.QtCore import QTimer + class MyWidget(QWidget): """ This class represents a custom widget. """ msg = None - sim = False # will be used for testing to simulate user click + sim = False # will be used for testing to simulate user click - def create_warning_box(self, message_text: str=" ", message_title: str="Information", add_cancel_btn: bool=False, custom_dialog=None) -> None: + def create_warning_box( + self, + message_text: str = " ", + message_title: str = "Information", + add_cancel_btn: bool = False, + custom_dialog=None, + ) -> None: """Creates a warning box with the specified message and options. - :param message_text: The text to be displayed in the message box. + :param message_text: The text to be displayed in the message box. :type message_text: str :param message_title: The title of the message box. Default is "Information". 
:type message_title: str @@ -21,14 +28,16 @@ def create_warning_box(self, message_text: str=" ", message_title: str="Informat :param custom_dialog: An optional custom dialog to use instead of creating a new QMessageBox instance. Default is None. :type custom_dialog: Any :return: None - """ - #setup box - if custom_dialog is not None: self.msg = custom_dialog - else: self.msg = QMessageBox() + """ + # setup box + if custom_dialog is not None: + self.msg = custom_dialog + else: + self.msg = QMessageBox() - if message_title=="Warning": + if message_title == "Warning": message_type = QMessageBox.Warning - elif message_title=="Error": + elif message_title == "Error": message_type = QMessageBox.Critical else: message_type = QMessageBox.Information @@ -39,12 +48,16 @@ def create_warning_box(self, message_text: str=" ", message_title: str="Informat if add_cancel_btn: self.msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) # simulate button click if specified - workaround used for testing - if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked) + if self.sim: + QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked) else: self.msg.setStandardButtons(QMessageBox.Ok) # simulate button click if specified - workaround used for testing - if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked) + if self.sim: + QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked) # return if user clicks Ok and False otherwise usr_response = self.msg.exec() - if usr_response == QMessageBox.Ok: return True - else: return False \ No newline at end of file + if usr_response == QMessageBox.Ok: + return True + else: + return False diff --git a/src/client/dcp_client/gui/main_window.py b/src/client/dcp_client/gui/main_window.py index 8407c1dc..c1eec891 100644 --- a/src/client/dcp_client/gui/main_window.py +++ b/src/client/dcp_client/gui/main_window.py @@ -1,8 +1,17 @@ from __future__ import annotations from typing import TYPE_CHECKING 
-from PyQt5.QtWidgets import QPushButton, QVBoxLayout, QFileSystemModel, QHBoxLayout, QLabel, QTreeView, QProgressBar, QShortcut -from PyQt5.QtCore import Qt, QThread, pyqtSignal +from PyQt5.QtWidgets import ( + QPushButton, + QVBoxLayout, + QFileSystemModel, + QHBoxLayout, + QLabel, + QTreeView, + QProgressBar, + QShortcut, +) +from PyQt5.QtCore import Qt, QThread, QModelIndex, pyqtSignal from PyQt5.QtGui import QKeySequence from dcp_client.utils import settings @@ -14,13 +23,21 @@ if TYPE_CHECKING: from dcp_client.app import Application + class WorkerThread(QThread): """ - Worker thread for displaying Pulse ProgressBar during model serving. - + Worker thread for displaying Pulse ProgressBar during model serving. + """ + task_finished = pyqtSignal(tuple) - def __init__(self, app: Application, task: str = None, parent = None,): + + def __init__( + self, + app: Application, + task: str = None, + parent=None, + ): """ Initialize the WorkerThread. @@ -34,15 +51,15 @@ def __init__(self, app: Application, task: str = None, parent = None,): self.app = app self.task = task - def run(self): + def run(self) -> None: """ - Once run_inference or run_train is executed, the tuple of + Once run_inference or run_train is executed, the tuple of (message_text, message_title) will be returned to on_finished. """ try: - if self.task == 'inference': + if self.task == "inference": message_text, message_title = self.app.run_inference() - elif self.task == 'train': + elif self.task == "train": message_text, message_title = self.app.run_train() else: message_text, message_title = "Unknown task", "Error" @@ -53,19 +70,20 @@ def run(self): self.task_finished.emit((message_text, message_title)) + class MainWindow(MyWidget): """ Main Window Widget object. - Opens the main window of the app where selected images in both directories are listed. + Opens the main window of the app where selected images in both directories are listed. 
User can view the images, train the model to get the labels, and visualise the result. - + :param eval_data_path: Chosen path to images without labeles, selected by the user in the WelcomeWindow :type eval_data_path: string :param train_data_path: Chosen path to images with labeles, selected by the user in the WelcomeWindow :type train_data_path: string - """ + """ - def __init__(self, app: Application): + def __init__(self, app: Application) -> None: """ Initializes the MainWindow. @@ -81,21 +99,20 @@ def __init__(self, app: Application): self.title = "Data Overview" self.worker_thread = None self.main_window() - - def main_window(self): - """Sets up the GUI - """ + + def main_window(self) -> None: + """Sets up the GUI""" self.setWindowTitle(self.title) - #self.resize(1000, 1500) + # self.resize(1000, 1500) main_layout = QVBoxLayout() - dir_layout = QHBoxLayout() - + dir_layout = QHBoxLayout() + self.uncurated_layout = QVBoxLayout() self.inprogress_layout = QVBoxLayout() self.curated_layout = QVBoxLayout() - self.eval_dir_layout = QVBoxLayout() - self.eval_dir_layout.setContentsMargins(0,0,0,0) + self.eval_dir_layout = QVBoxLayout() + self.eval_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_eval = QLabel(self) self.label_eval.setText("Uncurated dataset") self.eval_dir_layout.addWidget(self.label_eval) @@ -104,45 +121,55 @@ def main_window(self): model_eval.setIconProvider(IconProvider()) self.list_view_eval = QTreeView(self) self.list_view_eval.setModel(model_eval) - for i in range(1,4): + for i in range(1, 4): self.list_view_eval.hideColumn(i) - #self.list_view_eval.setFixedSize(600, 600) - self.list_view_eval.setRootIndex(model_eval.setRootPath(self.app.eval_data_path)) + # self.list_view_eval.setFixedSize(600, 600) + self.list_view_eval.setRootIndex( + model_eval.setRootPath(self.app.eval_data_path) + ) self.list_view_eval.clicked.connect(self.on_item_eval_selected) - + self.eval_dir_layout.addWidget(self.list_view_eval) 
self.uncurated_layout.addLayout(self.eval_dir_layout) # add buttons self.inference_button = QPushButton("Generate Labels", self) - self.inference_button.clicked.connect(self.on_run_inference_button_clicked) # add selected image + self.inference_button.clicked.connect( + self.on_run_inference_button_clicked + ) # add selected image self.uncurated_layout.addWidget(self.inference_button, alignment=Qt.AlignCenter) dir_layout.addLayout(self.uncurated_layout) # In progress layout - self.inprogr_dir_layout = QVBoxLayout() - self.inprogr_dir_layout.setContentsMargins(0,0,0,0) + self.inprogr_dir_layout = QVBoxLayout() + self.inprogr_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_inprogr = QLabel(self) self.label_inprogr.setText("Curation in progress") self.inprogr_dir_layout.addWidget(self.label_inprogr) # add in progress dir list model_inprogr = QFileSystemModel() - #self.list_view = QListView(self) + # self.list_view = QListView(self) self.list_view_inprogr = QTreeView(self) model_inprogr.setIconProvider(IconProvider()) self.list_view_inprogr.setModel(model_inprogr) - for i in range(1,4): + for i in range(1, 4): self.list_view_inprogr.hideColumn(i) - #self.list_view_inprogr.setFixedSize(600, 600) - self.list_view_inprogr.setRootIndex(model_inprogr.setRootPath(self.app.inprogr_data_path)) + # self.list_view_inprogr.setFixedSize(600, 600) + self.list_view_inprogr.setRootIndex( + model_inprogr.setRootPath(self.app.inprogr_data_path) + ) self.list_view_inprogr.clicked.connect(self.on_item_inprogr_selected) self.inprogr_dir_layout.addWidget(self.list_view_inprogr) self.inprogress_layout.addLayout(self.inprogr_dir_layout) self.launch_nap_button = QPushButton("View image and fix label", self) - self.launch_nap_button.clicked.connect(self.on_launch_napari_button_clicked) # add selected image - self.inprogress_layout.addWidget(self.launch_nap_button, alignment=Qt.AlignCenter) + self.launch_nap_button.clicked.connect( + self.on_launch_napari_button_clicked + ) # add selected 
image + self.inprogress_layout.addWidget( + self.launch_nap_button, alignment=Qt.AlignCenter + ) # Create a shortcut for the Enter key to click the button enter_shortcut = QShortcut(QKeySequence(Qt.Key_Return), self) enter_shortcut.activated.connect(self.on_launch_napari_button_clicked) @@ -150,27 +177,31 @@ def main_window(self): dir_layout.addLayout(self.inprogress_layout) # Curated layout - self.train_dir_layout = QVBoxLayout() - self.train_dir_layout.setContentsMargins(0,0,0,0) + self.train_dir_layout = QVBoxLayout() + self.train_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_train = QLabel(self) self.label_train.setText("Curated dataset") self.train_dir_layout.addWidget(self.label_train) # add train dir list model_train = QFileSystemModel() - #self.list_view = QListView(self) + # self.list_view = QListView(self) self.list_view_train = QTreeView(self) model_train.setIconProvider(IconProvider()) self.list_view_train.setModel(model_train) - for i in range(1,4): + for i in range(1, 4): self.list_view_train.hideColumn(i) - #self.list_view_train.setFixedSize(600, 600) - self.list_view_train.setRootIndex(model_train.setRootPath(self.app.train_data_path)) + # self.list_view_train.setFixedSize(600, 600) + self.list_view_train.setRootIndex( + model_train.setRootPath(self.app.train_data_path) + ) self.list_view_train.clicked.connect(self.on_item_train_selected) self.train_dir_layout.addWidget(self.list_view_train) self.curated_layout.addLayout(self.train_dir_layout) - + self.train_button = QPushButton("Train Model", self) - self.train_button.clicked.connect(self.on_train_button_clicked) # add selected image + self.train_button.clicked.connect( + self.on_train_button_clicked + ) # add selected image self.curated_layout.addWidget(self.train_button, alignment=Qt.AlignCenter) dir_layout.addLayout(self.curated_layout) @@ -178,16 +209,16 @@ def main_window(self): # add progress bar progress_layout = QHBoxLayout() - progress_layout.addStretch(1) + 
progress_layout.addStretch(1) self.progress_bar = QProgressBar(self) - self.progress_bar.setRange(0,1) + self.progress_bar.setRange(0, 1) progress_layout.addWidget(self.progress_bar) main_layout.addLayout(progress_layout) self.setLayout(main_layout) self.show() - def on_item_train_selected(self, item): + def on_item_train_selected(self, item: QModelIndex) -> None: """ Is called once an image is selected in the 'curated dataset' folder. @@ -197,7 +228,7 @@ def on_item_train_selected(self, item): self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.train_data_path - def on_item_eval_selected(self, item): + def on_item_eval_selected(self, item: QModelIndex) -> None: """ Is called once an image is selected in the 'uncurated dataset' folder. @@ -207,7 +238,7 @@ def on_item_eval_selected(self, item): self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.eval_data_path - def on_item_inprogr_selected(self, item): + def on_item_inprogr_selected(self, item: QModelIndex) -> None: """ Is called once an image is selected in the 'in progress' folder. @@ -217,50 +248,50 @@ def on_item_inprogr_selected(self, item): self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.inprogr_data_path - def on_train_button_clicked(self): + def on_train_button_clicked(self) -> None: """ Is called once user clicks the "Train Model" button. """ self.train_button.setEnabled(False) - self.progress_bar.setRange(0,0) + self.progress_bar.setRange(0, 0) # initialise the worker thread - self.worker_thread = WorkerThread(app=self.app, task='train') + self.worker_thread = WorkerThread(app=self.app, task="train") self.worker_thread.task_finished.connect(self.on_finished) # start the worker thread to train self.worker_thread.start() - def on_run_inference_button_clicked(self): + def on_run_inference_button_clicked(self) -> None: """ Is called once user clicks the "Generate Labels" button. 
""" self.inference_button.setEnabled(False) - self.progress_bar.setRange(0,0) + self.progress_bar.setRange(0, 0) # initialise the worker thread - self.worker_thread = WorkerThread(app=self.app, task='inference') + self.worker_thread = WorkerThread(app=self.app, task="inference") self.worker_thread.task_finished.connect(self.on_finished) # start the worker thread to run inference self.worker_thread.start() - def on_launch_napari_button_clicked(self): - """ + def on_launch_napari_button_clicked(self) -> None: + """ Launches the napari window after the image is selected. """ - if not self.app.cur_selected_img or '_seg.tiff' in self.app.cur_selected_img: + if not self.app.cur_selected_img or "_seg.tiff" in self.app.cur_selected_img: message_text = "Please first select an image you wish to visualise. The selected image must be an original image, not a mask." _ = self.create_warning_box(message_text, message_title="Warning") else: self.nap_win = NapariWindow(self.app) - self.nap_win.show() + self.nap_win.show() - def on_finished(self, result): + def on_finished(self, result: tuple) -> None: """ Is called once the worker thread emits the on finished signal. :param result: The result emitted by the worker thread. 
See return type of WorkerThread.run :type result: tuple - """ + """ # Stop the pulsation - self.progress_bar.setRange(0,1) + self.progress_bar.setRange(0, 1) # Display message of result message_text, message_title = result _ = self.create_warning_box(message_text, message_title) @@ -282,20 +313,21 @@ def on_finished(self, result): from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils import settings from dcp_client.utils.sync_src_dst import DataRSync + settings.init() image_storage = FilesystemImageStorage() ml_model = BentomlModel() - data_sync = DataRSync(user_name="local", - host_name="local", - server_repo_path=None) + data_sync = DataRSync(user_name="local", host_name="local", server_repo_path=None) app = QApplication(sys.argv) - app_ = Application(ml_model=ml_model, - syncer=data_sync, - image_storage=image_storage, - server_ip='0.0.0.0', - server_port=7010, - eval_data_path='data', - train_data_path='', # set path - inprogr_data_path='') # set path + app_ = Application( + ml_model=ml_model, + syncer=data_sync, + image_storage=image_storage, + server_ip="0.0.0.0", + server_port=7010, + eval_data_path="data", + train_data_path="", # set path + inprogr_data_path="", + ) # set path window = MainWindow(app=app_) - sys.exit(app.exec()) \ No newline at end of file + sys.exit(app.exec()) diff --git a/src/client/dcp_client/gui/napari_window.py b/src/client/dcp_client/gui/napari_window.py index e888ac67..001720cf 100644 --- a/src/client/dcp_client/gui/napari_window.py +++ b/src/client/dcp_client/gui/napari_window.py @@ -1,16 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from copy import deepcopy from qtpy.QtWidgets import QPushButton, QComboBox, QLabel, QGridLayout from qtpy.QtCore import Qt import napari +import numpy as np if TYPE_CHECKING: from dcp_client.app import Application -from dcp_client.utils.utils import get_path_stem, check_equal_arrays, 
Compute4Mask +from dcp_client.utils.utils import get_path_stem, check_equal_arrays +from dcp_client.utils.compute4mask import Compute4Mask from dcp_client.gui._my_widget import MyWidget + class NapariWindow(MyWidget): """Napari Window Widget object. Opens the napari image viewer to view and fix the labeles. @@ -18,7 +21,7 @@ class NapariWindow(MyWidget): :type app: Application """ - def __init__(self, app: Application): + def __init__(self, app: Application) -> None: """Initializes the NapariWindow. :param app: The Application instance. @@ -37,17 +40,24 @@ def __init__(self, app: Application): self.viewer = napari.Viewer(show=False) self.viewer.add_image(img, name=get_path_stem(self.app.cur_selected_img)) for seg_file in self.seg_files: - self.viewer.add_labels(self.app.load_image(seg_file), name=get_path_stem(seg_file)) + self.viewer.add_labels( + self.app.load_image(seg_file), name=get_path_stem(seg_file) + ) main_window = self.viewer.window._qt_window layout = QGridLayout() layout.addWidget(main_window, 0, 0, 1, 4) # select the first seg as the currently selected layer if there are any segs - if len(self.seg_files): + if ( + len(self.seg_files) + and len(self.viewer.layers[get_path_stem(self.seg_files[0])].data.shape) > 2 + ): self.cur_selected_seg = self.viewer.layers.selection.active.name self.layer = self.viewer.layers[self.cur_selected_seg] - self.viewer.layers.selection.events.changed.connect(self.on_seg_channel_changed) + self.viewer.layers.selection.events.changed.connect( + self.on_seg_channel_changed + ) # set first mask as active by default self.active_mask_index = 0 self.viewer.dims.events.current_step.connect(self.axis_changed) @@ -58,27 +68,39 @@ def __init__(self, app: Application): for seg_file in self.seg_files: layer_name = get_path_stem(seg_file) # get unique instance labels for each seg - self.original_instance_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[0]) - self.original_class_mask[layer_name] = 
deepcopy(self.viewer.layers[layer_name].data[1]) + self.original_instance_mask[layer_name] = deepcopy( + self.viewer.layers[layer_name].data[0] + ) + self.original_class_mask[layer_name] = deepcopy( + self.viewer.layers[layer_name].data[1] + ) # compute unique instance ids - self.instances[layer_name] = Compute4Mask.get_unique_objects(self.original_instance_mask[layer_name]) + self.instances[layer_name] = Compute4Mask.get_unique_objects( + self.original_instance_mask[layer_name] + ) # remove border from class mask - self.contours_mask[layer_name] = Compute4Mask.get_contours(self.original_instance_mask[layer_name], contours_level=0.8) - self.viewer.layers[layer_name].data[1][self.contours_mask[layer_name]!=0] = 0 - + self.contours_mask[layer_name] = Compute4Mask.get_contours( + self.original_instance_mask[layer_name], contours_level=0.8 + ) + self.viewer.layers[layer_name].data[1][ + self.contours_mask[layer_name] != 0 + ] = 0 + self.qctrl = self.viewer.window.qt_viewer.controls.widgets[self.layer] - if self.layer.data.shape[0] >= 2: + if len(self.layer.data.shape) > 2: # User hint - message_label = QLabel('Choose an active mask') + message_label = QLabel("Choose an active mask") message_label.setAlignment(Qt.AlignRight) layout.addWidget(message_label, 1, 0) - + # Drop list to choose which is an active mask self.mask_choice_dropdown = QComboBox() self.mask_choice_dropdown.setEnabled(False) - self.mask_choice_dropdown.addItem('Instance Segmentation Mask', userData=0) - self.mask_choice_dropdown.addItem('Labels Mask', userData=1) + self.mask_choice_dropdown.addItem( + "Instance Segmentation Mask", userData=0 + ) + self.mask_choice_dropdown.addItem("Labels Mask", userData=1) layout.addWidget(self.mask_choice_dropdown, 1, 1) # when user has chosen the mask, we don't want to change it anymore to avoid errors @@ -91,17 +113,21 @@ def __init__(self, app: Application): self.layer = None # add buttons for moving images to other dirs - add_to_inprogress_button = 
QPushButton('Move to \'Curatation in progress\' folder') + add_to_inprogress_button = QPushButton( + "Move to 'Curatation in progress' folder" + ) layout.addWidget(add_to_inprogress_button, 2, 0, 1, 2) - add_to_inprogress_button.clicked.connect(self.on_add_to_inprogress_button_clicked) + add_to_inprogress_button.clicked.connect( + self.on_add_to_inprogress_button_clicked + ) - add_to_curated_button = QPushButton('Move to \'Curated dataset\' folder') + add_to_curated_button = QPushButton("Move to 'Curated dataset' folder") layout.addWidget(add_to_curated_button, 2, 2, 1, 2) add_to_curated_button.clicked.connect(self.on_add_to_curated_button_clicked) self.setLayout(layout) - def set_editable_mask(self): + def set_editable_mask(self) -> None: """ This function is not implemented. In theory the use can choose between which mask to edit. Currently painting and erasing is only possible on instance mask and in the class mask only @@ -109,59 +135,71 @@ def set_editable_mask(self): """ pass - def on_seg_channel_changed(self, event): + def on_seg_channel_changed(self, event) -> None: """ Is triggered each time the user selects a different layer in the viewer. """ if (act := self.viewer.layers.selection.active) is not None: # updater cur_selected_seg with the new selection from the user self.cur_selected_seg = act.name - if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image: pass + if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image: + pass # set self.layer to new selection from user - elif self.layer is not None: self.layer = self.viewer.layers[self.cur_selected_seg] - else: pass - - def axis_changed(self, event): + elif self.layer is not None: + self.layer = self.viewer.layers[self.cur_selected_seg] + else: + pass + + def axis_changed(self, event) -> None: """ - Is triggered each time the user switches the viewer between the mask channels. 
At this point the class mask + Is triggered each time the user switches the viewer between the mask channels. At this point the class mask needs to be updated according to the changes made tot the instance segmentation mask. """ self.active_mask_index = self.viewer.dims.current_step[0] masks = deepcopy(self.layer.data) # if user has switched to the instance mask - if self.active_mask_index==0: + if self.active_mask_index == 0: class_mask_with_contours = Compute4Mask.add_contour(masks[1], masks[0]) - if not check_equal_arrays(class_mask_with_contours.astype(bool), self.original_class_mask[self.cur_selected_seg].astype(bool)): + if not check_equal_arrays( + class_mask_with_contours.astype(bool), + self.original_class_mask[self.cur_selected_seg].astype(bool), + ): self.update_instance_mask(masks[0], masks[1]) self.switch_to_instance_mask() # else if user has switched to the class mask - elif self.active_mask_index==1: - if not check_equal_arrays(masks[0], self.original_instance_mask[self.cur_selected_seg]): + elif self.active_mask_index == 1: + if not check_equal_arrays( + masks[0], self.original_instance_mask[self.cur_selected_seg] + ): self.update_labels_mask(masks[0]) self.switch_to_labels_mask() - def switch_to_instance_mask(self): + def switch_to_instance_mask(self) -> None: """ - Switch the application to the active mask mode by enabling 'paint_button', 'erase_button' + Switch the application to the active mask mode by enabling 'paint_button', 'erase_button' and 'fill_button'. """ self.switch_controls("paint_button", True) self.switch_controls("erase_button", True) self.switch_controls("fill_button", True) - def switch_to_labels_mask(self): + def switch_to_labels_mask(self) -> None: """ Switch the application to non-active mask mode by enabling 'fill_button' and disabling 'paint_button' and 'erase_button'. 
""" if self.cur_selected_seg in [layer.name for layer in self.viewer.layers]: - self.viewer.layers[self.cur_selected_seg].mode = 'pan_zoom' - info_message_paint = "Painting objects is only possible in the instance layer for now." - info_message_erase = "Erasing objects is only possible in the instance layer for now." + self.viewer.layers[self.cur_selected_seg].mode = "pan_zoom" + info_message_paint = ( + "Painting objects is only possible in the instance layer for now." + ) + info_message_erase = ( + "Erasing objects is only possible in the instance layer for now." + ) self.switch_controls("paint_button", False, info_message_paint) self.switch_controls("erase_button", False, info_message_erase) - self.switch_controls("fill_button", True) + self.switch_controls("fill_button", True) - def update_labels_mask(self, instance_mask): + def update_labels_mask(self, instance_mask: np.ndarray) -> None: """Updates the class mask based on changes in the instance mask. If the instance mask has changed since the last switch between channels, the class mask needs to be updated accordingly. 
@@ -170,22 +208,32 @@ def update_labels_mask(self, instance_mask): :type instance_mask: numpy.ndarray :return: None """ - self.original_class_mask[self.cur_selected_seg] = Compute4Mask.compute_new_labels_mask(self.original_class_mask[self.cur_selected_seg], - instance_mask, - self.original_instance_mask[self.cur_selected_seg], - self.instances[self.cur_selected_seg]) + self.original_class_mask[self.cur_selected_seg] = ( + Compute4Mask.compute_new_labels_mask( + self.original_class_mask[self.cur_selected_seg], + instance_mask, + self.original_instance_mask[self.cur_selected_seg], + self.instances[self.cur_selected_seg], + ) + ) # update original instance mask and instances self.original_instance_mask[self.cur_selected_seg] = instance_mask - self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects( + self.original_instance_mask[self.cur_selected_seg] + ) # compute contours to remove from class mask visualisation - self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours(instance_mask, contours_level=0.8) + self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours( + instance_mask, contours_level=0.8 + ) vis_labels_mask = deepcopy(self.original_class_mask[self.cur_selected_seg]) - vis_labels_mask[self.contours_mask[self.cur_selected_seg]!=0] = 0 + vis_labels_mask[self.contours_mask[self.cur_selected_seg] != 0] = 0 # update the viewer self.layer.data[1] = vis_labels_mask self.layer.refresh() - def update_instance_mask(self, instance_mask, labels_mask): + def update_instance_mask( + self, instance_mask: np.ndarray, labels_mask: np.ndarray + ) -> None: """Updates the instance mask based on changes in the labels mask. If the labels mask has changed, but only if an object has been removed, the instance mask is updated accordingly. 
@@ -198,15 +246,20 @@ def update_instance_mask(self, instance_mask, labels_mask): # add contours back to labels mask labels_mask = Compute4Mask.add_contour(labels_mask, instance_mask) # and compute the updated instance mask - self.original_instance_mask[self.cur_selected_seg] = Compute4Mask.compute_new_instance_mask(labels_mask, - instance_mask) - self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.original_instance_mask[self.cur_selected_seg] = ( + Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask) + ) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects( + self.original_instance_mask[self.cur_selected_seg] + ) self.original_class_mask[self.cur_selected_seg] = labels_mask # update the viewer self.layer.data[0] = self.original_instance_mask[self.cur_selected_seg] self.layer.refresh() - def switch_controls(self, target_widget, status: bool, info_message=None): + def switch_controls( + self, target_widget: str, status: bool, info_message: Optional[str] = None + ) -> None: """Enables or disables a specific widget. :param target_widget: The name of the widget to be controlled within the QCtrl object. @@ -223,68 +276,76 @@ def switch_controls(self, target_widget, status: bool, info_message=None): except: pass - def on_add_to_curated_button_clicked(self): - """Defines what happens when the "Move to curated dataset folder" button is clicked. 
- """ - if self.app.cur_selected_path == str(self.app.train_data_path): - message_text = "Image is already in the \'Curated data\' folder and should not be changed again" + def on_add_to_curated_button_clicked(self) -> None: + """Defines what happens when the "Move to curated dataset folder" button is clicked.""" + if self.app.cur_selected_path == str(self.app.train_data_path): + message_text = "Image is already in the 'Curated data' folder and should not be changed again" _ = self.create_warning_box(message_text, message_title="Warning") return - + # take the name of the currently selected layer (by the user) seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in seg_name_to_save: + if "_seg" not in seg_name_to_save: message_text = ( - "Please select the segmenation you wish to save from the layer list." - "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." ) _ = self.create_warning_box(message_text, message_title="Warning") return - + # Save the (changed) seg seg = self.viewer.layers[seg_name_to_save].data seg[1] = Compute4Mask.add_contour(seg[1], seg[0]) - annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(seg) + annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = ( + Compute4Mask.assert_consistent_labels(seg) + ) if annot_error: - message_text = ("There seems to be a problem with your mask. We expect each object to be a connected component. For object(s) with ID(s) \n" - +str(faulty_ids_annot)+"\n" - "more than one connected component was found. Please go back and fix this.") + message_text = ( + "There seems to be a problem with your mask. 
We expect each object to be a connected component. For object(s) with ID(s) \n" + + str(faulty_ids_annot) + + "\n" + "more than one connected component was found. Please go back and fix this." + ) self.create_warning_box(message_text, "Warning") elif mask_mismatch_error: - message_text = ("There seems to be a mismatch between your class and instance masks for object(s) with ID(s) \n" - +str(faulty_ids_missmatch)+"\n" - "This should not occur and will cause a problem later during model training. Please go back and check.") + message_text = ( + "There seems to be a mismatch between your class and instance masks for object(s) with ID(s) \n" + + str(faulty_ids_missmatch) + + "\n" + "This should not occur and will cause a problem later during model training. Please go back and check." + ) self.create_warning_box(message_text, "Warning") - else: + else: # Move original image self.app.move_images(self.app.train_data_path) - self.app.save_image(self.app.train_data_path, seg_name_to_save+'.tiff', seg) + self.app.save_image( + self.app.train_data_path, seg_name_to_save + ".tiff", seg + ) # We remove seg from the current directory if it exists (both eval and inprogr allowed) self.app.delete_images(self.seg_files) - # TODO Create the Archive folder for the rest? Or move them as well? + # TODO Create the Archive folder for the rest? Or move them as well? self.viewer.close() self.close() - def on_add_to_inprogress_button_clicked(self): - """Defines what happens when the "Move to curation in progress folder" button is clicked. - """ + def on_add_to_inprogress_button_clicked(self) -> None: + """Defines what happens when the "Move to curation in progress folder" button is clicked.""" # TODO: Do we allow this? What if they moved it by mistake? User can always manually move from their folders?) if self.app.cur_selected_path == str(self.app.train_data_path): - message_text = "Images from '\Curated data'\ folder can not be moved back to \'Curatation in progress\' folder." 
+ message_text = "Images from '\Curated data'\ folder can not be moved back to 'Curatation in progress' folder." _ = self.create_warning_box(message_text, message_title="Warning") return - + # take the name of the currently selected layer (by the user) seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in seg_name_to_save: + if "_seg" not in seg_name_to_save: message_text = ( - "Please select the segmenation you wish to save from the layer list." - "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." ) _ = self.create_warning_box(message_text, message_title="Warning") return @@ -294,8 +355,7 @@ def on_add_to_inprogress_button_clicked(self): # Save the (changed) seg - this will overwrite existing seg if seg name hasn't been changed in viewer seg = self.viewer.layers[seg_name_to_save].data seg[1] = Compute4Mask.add_contour(seg[1], seg[0]) - self.app.save_image(self.app.inprogr_data_path, seg_name_to_save+'.tiff', seg) - + self.app.save_image(self.app.inprogr_data_path, seg_name_to_save + ".tiff", seg) + self.viewer.close() self.close() - \ No newline at end of file diff --git a/src/client/dcp_client/gui/welcome_window.py b/src/client/dcp_client/gui/welcome_window.py index 74b3c55d..f4bd73da 100644 --- a/src/client/dcp_client/gui/welcome_window.py +++ b/src/client/dcp_client/gui/welcome_window.py @@ -1,7 +1,14 @@ from __future__ import annotations from typing import TYPE_CHECKING -from qtpy.QtWidgets import QPushButton, QVBoxLayout, QHBoxLayout, QLabel, QFileDialog, QLineEdit +from qtpy.QtWidgets import ( + QPushButton, + QVBoxLayout, + QHBoxLayout, + QLabel, + QFileDialog, + QLineEdit, +) from qtpy.QtCore import Qt from dcp_client.gui.main_window import MainWindow @@ 
-10,14 +17,15 @@ if TYPE_CHECKING: from dcp_client.app import Application + class WelcomeWindow(MyWidget): """Welcome Window Widget object. - The first window of the application providing a dialog that allows users to select directories. + The first window of the application providing a dialog that allows users to select directories. Currently supported image file types that can be selected for segmentation are: .jpg, .jpeg, .png, .tiff, .tif. By clicking 'start' the MainWindow is called. """ - def __init__(self, app: Application): + def __init__(self, app: Application) -> None: """Initializes the WelcomeWindow. :param app: The Application instance. @@ -30,7 +38,9 @@ def __init__(self, app: Application): self.main_layout = QVBoxLayout() input_layout = QHBoxLayout() label = QLabel(self) - label.setText('Welcome to Helmholtz AI data centric tool! Please select your dataset folder') + label.setText( + "Welcome to Helmholtz AI data centric tool! Please select your dataset folder" + ) self.main_layout.addWidget(label) self.text_layout = QVBoxLayout() @@ -38,35 +48,41 @@ def __init__(self, app: Application): self.button_layout = QVBoxLayout() val_label = QLabel(self) - val_label.setText('Uncurated dataset path:') + val_label.setText("Uncurated dataset path:") inprogr_label = QLabel(self) - inprogr_label.setText('Curation in progress path:') + inprogr_label.setText("Curation in progress path:") train_label = QLabel(self) - train_label.setText('Curated dataset path:') + train_label.setText("Curated dataset path:") self.text_layout.addWidget(val_label) self.text_layout.addWidget(inprogr_label) self.text_layout.addWidget(train_label) self.val_textbox = QLineEdit(self) - self.val_textbox.textEdited.connect(lambda x: self.on_text_changed(self.val_textbox, "eval", x)) + self.val_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.val_textbox, "eval", x) + ) self.inprogr_textbox = QLineEdit(self) - self.inprogr_textbox.textEdited.connect(lambda x: 
self.on_text_changed(self.inprogr_textbox, "inprogress", x)) + self.inprogr_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.inprogr_textbox, "inprogress", x) + ) self.train_textbox = QLineEdit(self) - self.train_textbox.textEdited.connect(lambda x: self.on_text_changed(self.train_textbox, "train", x)) + self.train_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.train_textbox, "train", x) + ) self.path_layout.addWidget(self.val_textbox) self.path_layout.addWidget(self.inprogr_textbox) self.path_layout.addWidget(self.train_textbox) - - self.file_open_button_val = QPushButton('Browse',self) + + self.file_open_button_val = QPushButton("Browse", self) self.file_open_button_val.show() self.file_open_button_val.clicked.connect(self.browse_eval_clicked) - self.file_open_button_prog = QPushButton('Browse',self) + self.file_open_button_prog = QPushButton("Browse", self) self.file_open_button_prog.show() self.file_open_button_prog.clicked.connect(self.browse_inprogr_clicked) - self.file_open_button_train = QPushButton('Browse',self) + self.file_open_button_train = QPushButton("Browse", self) self.file_open_button_train.show() self.file_open_button_train.clicked.connect(self.browse_train_clicked) self.button_layout.addWidget(self.file_open_button_val) @@ -78,11 +94,11 @@ def __init__(self, app: Application): input_layout.addLayout(self.button_layout) self.main_layout.addLayout(input_layout) - self.start_button = QPushButton('Start', self) + self.start_button = QPushButton("Start", self) self.start_button.setFixedSize(120, 30) self.start_button.show() # check if we need to upload data to server - self.done_upload = False # we only do once + self.done_upload = False # we only do once if self.app.syncer.host_name == "local": self.start_button.clicked.connect(self.start_main) else: @@ -92,8 +108,8 @@ def __init__(self, app: Application): self.show() - def browse_eval_clicked(self): - """Activates when the user clicks the button to choose the 
evaluation directory (QFileDialog) and + def browse_eval_clicked(self) -> None: + """Activates when the user clicks the button to choose the evaluation directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). """ self.fd = QFileDialog() @@ -104,9 +120,9 @@ def browse_eval_clicked(self): self.val_textbox.setText(self.app.eval_data_path) finally: self.fd = None - - def browse_train_clicked(self): - """Activates when the user clicks the button to choose the train directory (QFileDialog) and + + def browse_train_clicked(self) -> None: + """Activates when the user clicks the button to choose the train directory (QFileDialog) and displays the name of the train directory chosen in the train textbox line (QLineEdit). """ @@ -116,9 +132,9 @@ def browse_train_clicked(self): self.app.train_data_path = fd.selectedFiles()[0] self.train_textbox.setText(self.app.train_data_path) - def on_text_changed(self, field_obj, field_name, text): + def on_text_changed(self, field_obj: QLineEdit, field_name: str, text: str) -> None: """ - Update data paths based on text changes in input fields. + Update data paths based on text changes in input fields. Used for copying paths in the welcome window. :param field_obj: The QLineEdit object. @@ -136,30 +152,37 @@ def on_text_changed(self, field_obj, field_name, text): elif field_name == "inprogress": self.app.inprogr_data_path = text field_obj.setText(text) - - - def browse_inprogr_clicked(self): + def browse_inprogr_clicked(self) -> None: """ - Activates when the user clicks the button to choose the curation in progress directory (QFileDialog) and + Activates when the user clicks the button to choose the curation in progress directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). 
""" fd = QFileDialog() fd.setFileMode(QFileDialog.Directory) - if fd.exec_(): # Browse clicked - self.app.inprogr_data_path = fd.selectedFiles()[0] #TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd() + if fd.exec_(): # Browse clicked + self.app.inprogr_data_path = fd.selectedFiles()[ + 0 + ] # TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd() self.inprogr_textbox.setText(self.app.inprogr_data_path) - - def start_main(self): - """Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique. - """ - - if len({self.app.inprogr_data_path, self.app.train_data_path, self.app.eval_data_path})<3: + + def start_main(self) -> None: + """Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique.""" + + if ( + len( + { + self.app.inprogr_data_path, + self.app.train_data_path, + self.app.eval_data_path, + } + ) + < 3 + ): self.message_text = "All directory names must be distinct." _ = self.create_warning_box(self.message_text, message_title="Warning") - elif self.app.train_data_path and self.app.eval_data_path: self.hide() self.mw = MainWindow(self.app) @@ -167,27 +190,35 @@ def start_main(self): self.message_text = "You need to specify a folder both for your uncurated and curated dataset (even if the curated folder is currently empty). Please go back and select folders for both." _ = self.create_warning_box(self.message_text, message_title="Warning") - def start_upload_and_main(self): + def start_upload_and_main(self) -> None: """ If the configs are set to use remote not local server then the user is asked to confirm the upload of their data to the server and the upload starts before launching the main window. """ if self.done_upload is False: - message_text = ("Your current configurations are set to run some operations on the cloud. 
\n" - "For this we need to upload your data to our server." - "We will now upload your data. Click ok to continue. \n" - "If you do not agree close the application and contact your software provider.") - usr_response = self.create_warning_box(message_text, message_title="Warning", add_cancel_btn=True) - if usr_response: + message_text = ( + "Your current configurations are set to run some operations on the cloud. \n" + "For this we need to upload your data to our server." + "We will now upload your data. Click ok to continue. \n" + "If you do not agree close the application and contact your software provider." + ) + usr_response = self.create_warning_box( + message_text, message_title="Warning", add_cancel_btn=True + ) + if usr_response: success_up1, success_up2, _, _ = self.app.upload_data_to_server() - if success_up1=="Error" or success_up2=="Error": - message_text = ("An error has occured during data upload to the server. \n" - "Please check your configuration file and ensure that the server connection settings are correct and you have been given access to the server. \n" - "If the problem persists contact your software provider. Exiting now.") - usr_response = self.create_warning_box(message_text, message_title="Error") - self.close() - else: + if success_up1 == "Error" or success_up2 == "Error": + message_text = ( + "An error has occured during data upload to the server. \n" + "Please check your configuration file and ensure that the server connection settings are correct and you have been given access to the server. \n" + "If the problem persists contact your software provider. Exiting now." 
+ ) + usr_response = self.create_warning_box( + message_text, message_title="Error" + ) + self.close() + else: self.done_upload = True self.start_upload_and_main() - else: self.start_main() - \ No newline at end of file + else: + self.start_main() diff --git a/src/client/dcp_client/main.py b/src/client/dcp_client/main.py index 57e917bc..ef16a971 100644 --- a/src/client/dcp_client/main.py +++ b/src/client/dcp_client/main.py @@ -18,6 +18,7 @@ def main(): settings.init() + dir_name = path.dirname(path.abspath(__file__)) parser = argparse.ArgumentParser() @@ -32,11 +33,11 @@ def main(): if args.mode == "local": server_config = read_config( - "server", config_path=path.join(dir_name, "config.cfg") + "server", config_path=path.join(dir_name, "config.yaml") ) elif args.mode == "remote": server_config = read_config( - "server", config_path=path.join(dir_name, "config_remote.cfg") + "server", config_path=path.join(dir_name, "config_remote.yaml") ) image_storage = FilesystemImageStorage() diff --git a/src/client/dcp_client/utils/bentoml_model.py b/src/client/dcp_client/utils/bentoml_model.py index 25204ac5..5f57b421 100644 --- a/src/client/dcp_client/utils/bentoml_model.py +++ b/src/client/dcp_client/utils/bentoml_model.py @@ -1,18 +1,16 @@ import asyncio -from typing import Optional +from typing import Optional, List from bentoml.client import Client as BentoClient from bentoml.exceptions import BentoMLException +import numpy as np from dcp_client.app import Model + class BentomlModel(Model): - """BentomlModel class for connecting to a BentoML server and running training and inference tasks. - """ + """BentomlModel class for connecting to a BentoML server and running training and inference tasks.""" - def __init__( - self, - client: Optional[BentoClient] = None - ): + def __init__(self, client: Optional[BentoClient] = None): """Initializes the BentomlModel. :param client: Optional BentoClient instance. If None, it will be initialized during connection. 
@@ -20,7 +18,7 @@ def __init__( """ self.client = client - def connect(self, ip: str = '0.0.0.0', port: int = 7010): + def connect(self, ip: str = "0.0.0.0", port: int = 7010) -> bool: """Connects to the BentoML server. :param ip: IP address of the BentoML server. Default is '0.0.0.0'. @@ -30,14 +28,15 @@ def connect(self, ip: str = '0.0.0.0', port: int = 7010): :return: True if connection is successful, False otherwise. :rtype: bool """ - url = f"http://{ip}:{port}" #"http://0.0.0.0:7010" + url = f"http://{ip}:{port}" # "http://0.0.0.0:7010" try: - self.client = BentoClient.from_url(url) + self.client = BentoClient.from_url(url) return True - except : return False # except ConnectionRefusedError - + except: + return False # except ConnectionRefusedError + @property - def is_connected(self): + def is_connected(self) -> bool: """Checks if the BentomlModel is connected to the BentoML server. :return: True if connected, False otherwise. @@ -45,19 +44,21 @@ def is_connected(self): """ return bool(self.client) - async def _run_train(self, data_path): + async def _run_train(self, data_path: str) -> Optional[str]: """Runs the training task asynchronously. :param data_path: Path to the training data. :type data_path: str :return: Response from the server if successful, None otherwise. + :rtype: str, or None """ try: response = await self.client.async_train(data_path) return response - except BentoMLException: return None + except BentoMLException: + return None - def run_train(self, data_path): + def run_train(self, data_path: str): """Runs the training. :param data_path: Path to the training data. @@ -66,19 +67,21 @@ def run_train(self, data_path): """ return asyncio.run(self._run_train(data_path)) - async def _run_inference(self, data_path): + async def _run_inference(self, data_path: str) -> Optional[np.ndarray]: """Runs the inference task asynchronously. :param data_path: Path to the data for inference. 
:type data_path: str :return: List of files not supported by the server if unsuccessful, otherwise returns None. + :rtype: np.ndarray, or None """ try: response = await self.client.async_segment_image(data_path) return response - except BentoMLException: return None - - def run_inference(self, data_path): + except BentoMLException: + return None + + def run_inference(self, data_path: str) -> List: """Runs the inference. :param data_path: Path to the data for inference. @@ -86,4 +89,4 @@ def run_inference(self, data_path): :return: List of files not supported by the server if unsuccessful, otherwise returns None. """ list_of_files_not_suported = asyncio.run(self._run_inference(data_path)) - return list_of_files_not_suported \ No newline at end of file + return list_of_files_not_suported diff --git a/src/client/dcp_client/utils/compute4mask.py b/src/client/dcp_client/utils/compute4mask.py new file mode 100644 index 00000000..f14bff5d --- /dev/null +++ b/src/client/dcp_client/utils/compute4mask.py @@ -0,0 +1,210 @@ +from typing import List +import numpy as np +from skimage.measure import find_contours, label +from skimage.draw import polygon_perimeter + + +class Compute4Mask: + """ + Compute4Mask provides methods for manipulating masks to make visualisation in the viewer easier. + """ + + @staticmethod + def get_contours( + instance_mask: np.ndarray, contours_level: float = None + ) -> np.ndarray: + """Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged + objects in napari window (mask). + + :param instance_mask: The instance mask array. + :type instance_mask: numpy.ndarray + :param contours_level: Value along which to find contours in the array. See skimage.measure.find_contours for more. + :type: None or float + :return: A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background. 
+ :rtype: numpy.ndarray + + """ + instance_ids = Compute4Mask.get_unique_objects( + instance_mask + ) # get object instance labels ignoring background + contour_mask = np.zeros_like(instance_mask) + for instance_id in instance_ids: + # get a binary mask only of object + single_obj_mask = np.zeros_like(instance_mask) + single_obj_mask[instance_mask == instance_id] = 1 + try: + # compute contours for mask + contours = find_contours(single_obj_mask, contours_level) + # sometimes little dots appeas as additional contours so remove these + if len(contours) > 1: + contour_sizes = [contour.shape[0] for contour in contours] + contour = contours[contour_sizes.index(max(contour_sizes))].astype( + int + ) + else: + contour = contours[0] + # and draw onto contours mask + rr, cc = polygon_perimeter( + contour[:, 0], contour[:, 1], contour_mask.shape + ) + contour_mask[rr, cc] = instance_id + except: + print("Could not create contour for instance id", instance_id) + return contour_mask + + @staticmethod + def add_contour(labels_mask: np.ndarray, instance_mask: np.ndarray) -> np.ndarray: + """Add contours of objects to the labels mask. + + :param labels_mask: The class mask array without the contour pixels annotated. + :type labels_mask: numpy.ndarray + :param instance_mask: The instance mask array. + :type instance_mask: numpy.ndarray + :return: The updated class mask including contours. 
+ :rtype: numpy.ndarray + """ + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + where_instances = np.where(instance_mask == instance_id) + # get unique class ids where the object is present + class_vals, counts = np.unique( + labels_mask[where_instances], return_counts=True + ) + # and take the class id which is most heavily represented + class_id = class_vals[np.argmax(counts)] + # make sure instance mask and class mask match + labels_mask[np.where(instance_mask == instance_id)] = class_id + return labels_mask + + @staticmethod + def compute_new_instance_mask( + labels_mask: np.ndarray, instance_mask: np.ndarray + ) -> np.ndarray: + """Given an updated labels mask, update also the instance mask accordingly. + So far the user can only remove an entire object in the labels mask view by + setting the color of the object to the background. + Therefore the instance mask can only change by entirely removing an object. + + :param labels_mask: The labels mask array, with changes made by the user. + :type labels_mask: numpy.ndarray + :param instance_mask: The existing instance mask, which needs to be updated. + :type instance_mask: numpy.ndarray + :return: The updated instance mask. + :rtype: numpy.ndarray + """ + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + unique_items_in_class_mask = list( + np.unique(labels_mask[instance_mask == instance_id]) + ) + if ( + len(unique_items_in_class_mask) == 1 + and unique_items_in_class_mask[0] == 0 + ): + instance_mask[instance_mask == instance_id] = 0 + return instance_mask + + @staticmethod + def compute_new_labels_mask( + labels_mask: np.ndarray, + instance_mask: np.ndarray, + original_instance_mask: np.ndarray, + old_instances: np.ndarray, + ) -> np.ndarray: + """Given the existing labels mask, the updated instance mask is used to update the labels mask. 
+ + :param labels_mask: The existing labels mask, which needs to be updated. + :type labels_mask: numpy.ndarray + :param instance_mask: The instance mask array, with changes made by the user. + :type instance_mask: numpy.ndarray + :param original_instance_mask: The instance mask array, before the changes made by the user. + :type original_instance_mask: numpy.ndarray + :param old_instances: A list of the instance label ids in original_instance_mask. + :type old_instances: list + :return: The new labels mask, with updated changes according to those the user has made in the instance mask. + :rtype: numpy.ndarray + """ + new_labels_mask = np.zeros_like(labels_mask) + for instance_id in np.unique(instance_mask): + where_instance = np.where(instance_mask == instance_id) + # if the label is background skip + if instance_id == 0: + continue + # if the label is a newly added object, add with the same id to the labels mask + # this is an indication to the user that this object needs to be assigned a class + elif instance_id not in old_instances: + new_labels_mask[where_instance] = instance_id + else: + where_instance_orig = np.where(original_instance_mask == instance_id) + # if the locations of the instance haven't changed, means object wasn't changed, do nothing + num_classes = np.unique(labels_mask[where_instance]) + # if area was erased and object retains same class + if len(num_classes) == 1: + new_labels_mask[where_instance] = num_classes[0] + # area was added where there is background or other class + else: + old_class_id, counts = np.unique( + labels_mask[where_instance_orig], return_counts=True + ) + # assert len(old_class_id)==1 + # old_class_id = old_class_id[0] + # and take the class id which is most heavily represented + old_class_id = old_class_id[np.argmax(counts)] + new_labels_mask[where_instance] = old_class_id + + return new_labels_mask + + @staticmethod + def get_unique_objects(active_mask: np.ndarray) -> List: + """Gets unique objects from the active 
mask. + + :param active_mask: The mask array. + :type active_mask: numpy.ndarray + :return: A list of unique object labels. + :rtype: list + """ + return list(np.unique(active_mask)[1:]) + + @staticmethod + def assert_consistent_labels(mask: np.ndarray) -> tuple: + """Before saving the final mask make sure the user has not mistakenly made an error during annotation, + such that one instance id does not correspond to exactly one class id. Also checks whether for one instance id + multiple classes exist. + :param mask: The mask which we want to test. + :type mask: numpy.ndarray + :return: + - A boolean which is True if there is more than one connected components corresponding to an instance id and Fale otherwise. + - A boolean which is True if there is a missmatch between the instance mask and class masks (not 1-1 correspondance) and Flase otherwise. + - A list with all the instance ids for which more than one connected component was found. + - A list with all the instance ids for which a missmatch between class and instance masks was found. + :rtype : + - bool + - bool + - list[int] + - list[int] + """ + user_annot_error = False + mask_mismatch_error = False + faulty_ids_annot = [] + faulty_ids_missmatch = [] + instance_mask, class_mask = mask[0], mask[1] + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + # check if there are more than one objects (connected components) with same instance_id + if np.unique(label(instance_mask == instance_id)).shape[0] > 2: + user_annot_error = True + faulty_ids_annot.append(instance_id) + # and check if there is a mismatch between class mask and instance mask - should never happen! 
+ if ( + np.unique(class_mask[np.where(instance_mask == instance_id)]).shape[0] + > 1 + ): + mask_mismatch_error = True + faulty_ids_missmatch.append(instance_id) + + return ( + user_annot_error, + mask_mismatch_error, + faulty_ids_annot, + faulty_ids_missmatch, + ) diff --git a/src/client/dcp_client/utils/fsimagestorage.py b/src/client/dcp_client/utils/fsimagestorage.py index d33371ff..3e8a5e3c 100644 --- a/src/client/dcp_client/utils/fsimagestorage.py +++ b/src/client/dcp_client/utils/fsimagestorage.py @@ -1,13 +1,14 @@ -from skimage.io import imread, imsave import os +import numpy as np +from skimage.io import imread, imsave from dcp_client.app import ImageStorage + class FilesystemImageStorage(ImageStorage): - """FilesystemImageStorage class for handling image storage operations on the local filesystem. - """ + """FilesystemImageStorage class for handling image storage operations on the local filesystem.""" - def load_image(self, from_directory, cur_selected_img): + def load_image(self, from_directory: str, cur_selected_img: str) -> np.ndarray: """Loads an image from the specified directory. :param from_directory: Path to the directory containing the image. @@ -18,8 +19,8 @@ def load_image(self, from_directory, cur_selected_img): """ # Read the selected image and read the segmentation if any: return imread(os.path.join(from_directory, cur_selected_img)) - - def move_image(self, from_directory, to_directory, cur_selected_img): + + def move_image(self, from_directory: str, to_directory: str, cur_selected_img: str) -> None: """Moves an image from one directory to another. :param from_directory: Path to the source directory. @@ -29,10 +30,15 @@ def move_image(self, from_directory, to_directory, cur_selected_img): :param cur_selected_img: Name of the image file. 
:type cur_selected_img: str """ - print(f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}") - os.replace(os.path.join(from_directory, cur_selected_img), os.path.join(to_directory, cur_selected_img)) - - def save_image(self, to_directory, cur_selected_img, img): + print( + f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}" + ) + os.replace( + os.path.join(from_directory, cur_selected_img), + os.path.join(to_directory, cur_selected_img), + ) + + def save_image(self, to_directory: str, cur_selected_img: str, img: np.ndarray) -> None: """Saves an image to the specified directory. :param to_directory: Path to the directory where the image will be saved. @@ -41,10 +47,10 @@ def save_image(self, to_directory, cur_selected_img, img): :type cur_selected_img: str :param img: Image data to be saved. """ - + imsave(os.path.join(to_directory, cur_selected_img), img) - - def delete_image(self, from_directory, cur_selected_img): + + def delete_image(self, from_directory: str, cur_selected_img: str) -> None: """Deletes an image from the specified directory. :param from_directory: Path to the directory containing the image. 
diff --git a/src/client/dcp_client/utils/settings.py b/src/client/dcp_client/utils/settings.py index 2fd6bcb2..5107fb82 100644 --- a/src/client/dcp_client/utils/settings.py +++ b/src/client/dcp_client/utils/settings.py @@ -1,5 +1,6 @@ -def init(): +def init() -> None: + """ Initialise global variables.""" global accepted_types accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") global seg_name_string - seg_name_string = '_seg' + seg_name_string = "_seg" diff --git a/src/client/dcp_client/utils/sync_src_dst.py b/src/client/dcp_client/utils/sync_src_dst.py index 66d0a4b7..0698901d 100644 --- a/src/client/dcp_client/utils/sync_src_dst.py +++ b/src/client/dcp_client/utils/sync_src_dst.py @@ -6,14 +6,16 @@ class DataRSync(DataSync): - ''' + """ Class which uses rsync bash command to sync data between client and server - ''' - def __init__(self, - user_name: str, - host_name: str, - server_repo_path: str, - ): + """ + + def __init__( + self, + user_name: str, + host_name: str, + server_repo_path: str, + ) -> None: """Constructs all the necessary attributes for the CustomRunnable. :param user_name: the user name of the server - if "local", then it is assumed that local machine is used for the server @@ -22,33 +24,30 @@ def __init__(self, :type: host_name: str :param server_repo_path: the server path where we wish to sync data - if None, then it is assumed that local machine is used for the server :type server_repo_path: str - """ + """ self.user_name = user_name self.host_name = host_name self.server_repo_path = server_repo_path - def first_sync(self, path): + def first_sync(self, path: str) -> tuple: """ During the first sync the folder structure should be created on the server - + :param path: Path to the local directory to synchronize. 
:type path: str + :return: result message of subprocess + :rtype: tuple """ - server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path + server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path try: # Run the subprocess command - result = subprocess.run(["rsync", - "-azP" , - path, - server], - check=True) + result = subprocess.run(["rsync", "-azP", path, server], check=True) return ("Success", result.stdout) except subprocess.CalledProcessError as e: return ("Error", e) - - def sync(self, src, dst, path): - """ Syncs the data between the src and the dst. Both src and dst can be one of either + def sync(self, src: str, dst: str, path: str) -> tuple: + """Syncs the data between the src and the dst. Both src and dst can be one of either 'client' or 'server', whereas path is the local path we wish to sync :param src: A string specifying the source, from where the data will be sent to dst. Can be 'client' or 'server'. @@ -57,16 +56,18 @@ def sync(self, src, dst, path): :type dst: str :param path: Path to the directory we want to synchronize. :type path: str - + :return: result message of subprocess + :rtype: tuple + """ - path += '/' # otherwise it doesn't go in the directory - rel_path = get_relative_path(path) # get last folder, i.e. uncurated, curated + path += "/" # otherwise it doesn't go in the directory + rel_path = get_relative_path(path) # get last folder, i.e. 
uncurated, curated server_full_path = os.path.join(self.server_repo_path, rel_path) - server_full_path += '/' - server = self.user_name + "@" + self.host_name + ":" + server_full_path - print('server is: ', server) - - if src=='server': + server_full_path += "/" + server = self.user_name + "@" + self.host_name + ":" + server_full_path + print("server is: ", server) + + if src == "server": src = server dst = path else: @@ -74,19 +75,14 @@ def sync(self, src, dst, path): dst = server try: # Run the subprocess command - _ = subprocess.run(["rsync", - "-r" , - "--delete", - src, - dst], - check=True) + _ = subprocess.run(["rsync", "-r", "--delete", src, dst], check=True) return ("Success", server_full_path) except subprocess.CalledProcessError as e: return ("Error", e) - -if __name__=="__main__": - ds = DataRSync() #vm2 + +if __name__ == "__main__": + ds = DataRSync() # vm2 # These combinations work for me: # ubuntu@jusuf-vm2:/path... # jusuf-vm2:/path... @@ -94,6 +90,8 @@ def sync(self, src, dst, path): src = "client" # dst = 'client' # src = 'server' - #path = "data/" - path = "/Users/christina.bukas/Documents/AI_projects/code/data-centric-platform/data" - ds.sync(src, dst, path) \ No newline at end of file + # path = "data/" + path = ( + "/Users/christina.bukas/Documents/AI_projects/code/data-centric-platform/data" + ) + ds.sync(src, dst, path) diff --git a/src/client/dcp_client/utils/utils.py b/src/client/dcp_client/utils/utils.py index c9060a3b..eb08f881 100644 --- a/src/client/dcp_client/utils/utils.py +++ b/src/client/dcp_client/utils/utils.py @@ -1,24 +1,22 @@ -from PyQt5.QtWidgets import QFileIconProvider -from PyQt5.QtCore import QSize -from PyQt5.QtGui import QPixmap, QIcon -import numpy as np -from skimage.measure import find_contours, label -from skimage.draw import polygon_perimeter +from qtpy.QtWidgets import QFileIconProvider +from qtpy.QtCore import QSize +from qtpy.QtGui import QPixmap, QIcon from pathlib import Path, PurePath -import json +import 
yaml +import numpy as np from dcp_client.utils import settings + class IconProvider(QFileIconProvider): def __init__(self) -> None: - """ Initializes the IconProvider with the default icon size. - """ + """Initializes the IconProvider with the default icon size.""" super().__init__() - self.ICON_SIZE = QSize(512,512) + self.ICON_SIZE = QSize(512, 512) - def icon(self, type: 'QFileIconProvider.IconType'): - """ Returns the icon for the specified file type. + def icon(self, type: QFileIconProvider.IconType) -> QIcon: + """Returns the icon for the specified file type. :param type: The type of the file for which the icon is requested. :type type: QFileIconProvider.IconType @@ -27,7 +25,8 @@ def icon(self, type: 'QFileIconProvider.IconType'): """ try: fn = type.filePath() - except AttributeError: return super().icon(type) # TODO handle exception differently? + except AttributeError: + return super().icon(type) # TODO handle exception differently? if fn.endswith(settings.accepted_types): a = QPixmap(self.ICON_SIZE) @@ -36,24 +35,28 @@ def icon(self, type: 'QFileIconProvider.IconType'): else: return super().icon(type) -def read_config(name, config_path = 'config.cfg') -> dict: - """ Reads the configuration file + +def read_config(name: str, config_path: str = "config.yaml") -> dict: + """Reads the configuration file :param name: name of the section you want to read (e.g. 
'setup','train') :type name: string - :param config_path: path to the configuration file, defaults to 'config.cfg' + :param config_path: path to the configuration file, defaults to 'config.yaml' :type config_path: str, optional :return: dictionary from the config section given by name :rtype: dict - """ + """ with open(config_path) as config_file: - config_dict = json.load(config_file) + config_dict = yaml.safe_load( + config_file + ) # json.load(config_file) for .cfg file # Check if config file has main mandatory keys - assert all([i in config_dict.keys() for i in ['server']]) + assert all([i in config_dict.keys() for i in ["server"]]) return config_dict[name] -def get_relative_path(filepath): - """ Returns the name of the file from the given filepath. + +def get_relative_path(filepath: str) -> str: + """Returns the name of the file from the given filepath. :param filepath: The path of the file. :type filepath: str @@ -62,8 +65,9 @@ def get_relative_path(filepath): """ return PurePath(filepath).name -def get_path_stem(filepath): - """ Returns the stem (filename without its extension) from the given filepath. + +def get_path_stem(filepath: str) -> str: + """Returns the stem (filename without its extension) from the given filepath. :param filepath: The path of the file. :type filepath: str @@ -72,8 +76,9 @@ def get_path_stem(filepath): """ return str(Path(filepath).stem) -def get_path_name(filepath): - """ Returns the name of the file from the given filepath. + +def get_path_name(filepath: str) -> str: + """Returns the name of the file from the given filepath. :param filepath: The path of the file. :type filepath: str @@ -82,8 +87,9 @@ def get_path_name(filepath): """ return str(Path(filepath).name) -def get_path_parent(filepath): - """ Returns the parent directory of the given filepath. + +def get_path_parent(filepath: str) -> str: + """Returns the parent directory of the given filepath. :param filepath: The path of the file. 
:type filepath: str @@ -92,8 +98,9 @@ def get_path_parent(filepath): """ return str(Path(filepath).parent) -def join_path(root_dir, filepath): - """ Joins the root directory path with the given filepath. + +def join_path(root_dir: str, filepath: str) -> str: + """Joins the root directory path with the given filepath. :param root_dir: The root directory. :type root_dir: str @@ -104,8 +111,9 @@ def join_path(root_dir, filepath): """ return str(Path(root_dir, filepath)) -def check_equal_arrays(array1, array2): - """ Checks if two arrays are equal. + +def check_equal_arrays(array1: np.ndarray, array2: np.ndarray) -> bool: + """Checks if two arrays are equal. :param array1: The first array. :type array1: numpy.ndarray @@ -115,173 +123,3 @@ def check_equal_arrays(array1, array2): :rtype: bool """ return np.array_equal(array1, array2) - -class Compute4Mask: - """ - Compute4Mask provides methods for manipulating masks. - """ - - @staticmethod - def get_contours(instance_mask, contours_level=None): - """ Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged - objects in napari window (mask). - - :param instance_mask: The instance mask array. - :type instance_mask: numpy.ndarray - :param contours_level: Value along which to find contours in the array. See skimage.measure.find_contours for more. - :type: None or float - :return: A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background. 
- :rtype: numpy.ndarray - - """ - instance_ids = Compute4Mask.get_unique_objects(instance_mask) # get object instance labels ignoring background - contour_mask= np.zeros_like(instance_mask) - for instance_id in instance_ids: - # get a binary mask only of object - single_obj_mask = np.zeros_like(instance_mask) - single_obj_mask[instance_mask==instance_id] = 1 - # compute contours for mask - contours = find_contours(single_obj_mask, contours_level) - # sometimes little dots appeas as additional contours so remove these - if len(contours)>1: - contour_sizes = [contour.shape[0] for contour in contours] - contour = contours[contour_sizes.index(max(contour_sizes))].astype(int) - else: contour = contours[0] - # and draw onto contours mask - rr, cc = polygon_perimeter(contour[:, 0], contour[:, 1], contour_mask.shape) - contour_mask[rr, cc] = instance_id - return contour_mask - - @staticmethod - def add_contour(labels_mask, instance_mask): - """ Add contours of objects to the labels mask. - - :param labels_mask: The class mask array without the contour pixels annotated. - :type labels_mask: numpy.ndarray - :param instance_mask: The instance mask array. - :type instance_mask: numpy.ndarray - :return: The updated class mask including contours. - :rtype: numpy.ndarray - """ - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - where_instances = np.where(instance_mask==instance_id) - # get unique class ids where the object is present - class_vals, counts = np.unique(labels_mask[where_instances], return_counts=True) - # and take the class id which is most heavily represented - class_id = class_vals[np.argmax(counts)] - # make sure instance mask and class mask match - labels_mask[np.where(instance_mask==instance_id)] = class_id - return labels_mask - - - @staticmethod - def compute_new_instance_mask(labels_mask, instance_mask): - """ Given an updated labels mask, update also the instance mask accordingly. 
- So far the user can only remove an entire object in the labels mask view by - setting the color of the object to the background. - Therefore the instance mask can only change by entirely removing an object. - - :param labels_mask: The labels mask array, with changes made by the user. - :type labels_mask: numpy.ndarray - :param instance_mask: The existing instance mask, which needs to be updated. - :type instance_mask: numpy.ndarray - :return: The updated instance mask. - :rtype: numpy.ndarray - """ - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - unique_items_in_class_mask = list(np.unique(labels_mask[instance_mask==instance_id])) - if len(unique_items_in_class_mask)==1 and unique_items_in_class_mask[0]==0: - instance_mask[instance_mask==instance_id] = 0 - return instance_mask - - - @staticmethod - def compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances): - """ Given the existing labels mask, the updated instance mask is used to update the labels mask. - - :param labels_mask: The existing labels mask, which needs to be updated. - :type labels_mask: numpy.ndarray - :param instance_mask: The instance mask array, with changes made by the user. - :type instance_mask: numpy.ndarray - :param original_instance_mask: The instance mask array, before the changes made by the user. - :type original_instance_mask: numpy.ndarray - :param old_instances: A list of the instance label ids in original_instance_mask. - :type old_instances: list - :return: The new labels mask, with updated changes according to those the user has made in the instance mask. 
- :rtype: numpy.ndarray - """ - new_labels_mask = np.zeros_like(labels_mask) - for instance_id in np.unique(instance_mask): - where_instance = np.where(instance_mask==instance_id) - # if the label is background skip - if instance_id==0: continue - # if the label is a newly added object, add with the same id to the labels mask - # this is an indication to the user that this object needs to be assigned a class - elif instance_id not in old_instances: - new_labels_mask[where_instance] = instance_id - else: - where_instance_orig = np.where(original_instance_mask==instance_id) - # if the locations of the instance haven't changed, means object wasn't changed, do nothing - num_classes = np.unique(labels_mask[where_instance]) - # if area was erased and object retains same class - if len(num_classes)==1: - new_labels_mask[where_instance] = num_classes[0] - # area was added where there is background or other class - else: - old_class_id, counts = np.unique(labels_mask[where_instance_orig], return_counts=True) - #assert len(old_class_id)==1 - #old_class_id = old_class_id[0] - # and take the class id which is most heavily represented - old_class_id = old_class_id[np.argmax(counts)] - new_labels_mask[where_instance] = old_class_id - - return new_labels_mask - - @staticmethod - def get_unique_objects(active_mask): - """ Gets unique objects from the active mask. - - :param active_mask: The mask array. - :type active_mask: numpy.ndarray - :return: A list of unique object labels. - :rtype: list - """ - return list(np.unique(active_mask)[1:]) - - @staticmethod - def assert_consistent_labels(mask): - """ Before saving the final mask make sure the user has not mistakenly made an error during annotation, - such that one instance id does not correspond to exactly one class id. Also checks whether for one instance id - multiple classes exist. - :param mask: The mask which we want to test. 
- :type mask: numpy.ndarray - :return: - - A boolean which is True if there is more than one connected components corresponding to an instance id and Fale otherwise. - - A boolean which is True if there is a missmatch between the instance mask and class masks (not 1-1 correspondance) and Flase otherwise. - - A list with all the instance ids for which more than one connected component was found. - - A list with all the instance ids for which a missmatch between class and instance masks was found. - :rtype : - - bool - - bool - - list[int] - - list[int] - """ - user_annot_error = False - mask_mismatch_error = False - faulty_ids_annot = [] - faulty_ids_missmatch = [] - instance_mask, class_mask = mask[0], mask[1] - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - # check if there are more than one objects (connected components) with same instance_id - if np.unique(label(instance_mask==instance_id)).shape[0] > 2: - user_annot_error = True - faulty_ids_annot.append(instance_id) - # and check if there is a mismatch between class mask and instance mask - should never happen! 
- if np.unique(class_mask[np.where(instance_mask==instance_id)]).shape[0]>1: - mask_mismatch_error = True - faulty_ids_missmatch.append(instance_id) - - return user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch \ No newline at end of file diff --git a/src/client/pyproject.toml b/src/client/pyproject.toml index 2e521a11..93af7bd7 100644 --- a/src/client/pyproject.toml +++ b/src/client/pyproject.toml @@ -33,7 +33,10 @@ maintainers = [ [project.optional-dependencies] dev = [ - "pytest", + "pytest>=7.4.3", + "pytest-qt>=4.2.0", + "sphinx", + "sphinx-rtd-theme" ] [project.urls] diff --git a/src/client/requirements.txt b/src/client/requirements.txt index 98109d47..e47ad839 100644 --- a/src/client/requirements.txt +++ b/src/client/requirements.txt @@ -1,6 +1,2 @@ napari[pyqt5]>=0.4.17 -bentoml[grpc]==1.0.16 -pytest>=7.4.3 -pytest-qt>=4.2.0 -sphinx -sphinx-rtd-theme \ No newline at end of file +bentoml[grpc]==1.0.16 \ No newline at end of file diff --git a/src/client/test/test_app.py b/src/client/test/test_app.py index e4e6d1f9..ad31285a 100644 --- a/src/client/test/test_app.py +++ b/src/client/test/test_app.py @@ -1,5 +1,6 @@ import os import sys + sys.path.append("../") import pytest import subprocess @@ -13,88 +14,101 @@ from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils.sync_src_dst import DataRSync + @pytest.fixture def app(): img1 = data.astronaut() img2 = data.coffee() img3 = data.cat() - if not os.path.exists('in_prog'): - os.mkdir('in_prog') - imsave('in_prog/coffee.png', img2) + if not os.path.exists("in_prog"): + os.mkdir("in_prog") + imsave("in_prog/coffee.png", img2) - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img3) + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img3) - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - app = 
Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010, - os.path.join(os.getcwd(), 'eval_data_path')) + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + app = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + os.path.join(os.getcwd(), "eval_data_path"), + ) return app, img1, img2, img3 + def test_load_image(app): app, img, img2, _ = app # Unpack the app, img, and img2 from the fixture - - app.cur_selected_img = 'coffee.png' - app.cur_selected_path = 'in_prog' + + app.cur_selected_img = "coffee.png" + app.cur_selected_path = "in_prog" img_test = app.load_image() # if image_name is None assert img.all() == img_test.all() - app.cur_selected_path = 'eval_data_path' - img_test2 = app.load_image('cat.png') # if a filename is given + app.cur_selected_path = "eval_data_path" + img_test2 = app.load_image("cat.png") # if a filename is given assert img2.all() == img_test2.all() + def test_run_inference_no_connection(app): - app, _, _, _ = app + app, _, _, _ = app message_text, message_title = app.run_inference() - assert message_text=="Connection could not be established. Please check if the server is running and try again." - assert message_title=="Warning" + assert ( + message_text + == "Connection could not be established. Please check if the server is running and try again." 
+ ) + assert message_title == "Warning" + def test_run_inference_run(app): - app, _, _, _ = app + app, _, _, _ = app # start the sevrer in the background locally command = [ "bentoml", - "serve", - '--working-dir', - '../server/dcp_server', + "serve", + "--working-dir", + "../server/dcp_server", "service:svc", "--reload", "--port=7010", ] process = subprocess.Popen(command, stdin=subprocess.PIPE, shell=False) # and wait until it is setup - if sys.platform == 'win32' or sys.platform == 'cygwin': time.sleep(240) - else: time.sleep(60) + if sys.platform == "win32" or sys.platform == "cygwin": + time.sleep(240) + else: + time.sleep(60) # then do model serving message_text, message_title = app.run_inference() # and assert returning message print(f"HERE: {message_text, message_title}") - assert message_text== "Success! Masks generated for all images" - assert message_title=="Information" + assert message_text == "Success! Masks generated for all images" + assert message_title == "Information" # finally clean up process process.terminate() process.wait() process.kill() + def test_search_segs(app): - app, _, _, _ = app - app.cur_selected_img = 'cat.png' - app.cur_selected_path = 'eval_data_path' + app, _, _, _ = app + app.cur_selected_img = "cat.png" + app.cur_selected_path = "eval_data_path" app.search_segs() - res = app.seg_filepaths - assert len(res)==1 - assert res[0]=='cat_seg.tiff' + res = app.seg_filepaths + assert len(res) == 1 + assert res[0] == "cat_seg.tiff" # also remove the seg as it is not needed for other scripts - os.remove('eval_data_path/cat_seg.tiff') + os.remove("eval_data_path/cat_seg.tiff") -''' + +""" def test_run_train(): pass @@ -107,7 +121,4 @@ def test_move_images(): def test_delete_images(): pass -''' - - - +""" diff --git a/src/client/test/test_compute4mask.py b/src/client/test/test_compute4mask.py index e76dfc1c..5304e2ee 100644 --- a/src/client/test/test_compute4mask.py +++ b/src/client/test/test_compute4mask.py @@ -1,83 +1,114 @@ import 
numpy as np import pytest -from dcp_client.utils.utils import Compute4Mask +from dcp_client.utils.compute4mask import Compute4Mask + @pytest.fixture def sample_data(): - instance_mask = np.array([[0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [2, 2, 0, 0, 0], - [0, 0, 3, 3, 0]]) - labels_mask = np.array([[0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [2, 2, 0, 0, 0], - [0, 0, 1, 1, 0]]) + instance_mask = np.array( + [ + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [2, 2, 0, 0, 0], + [0, 0, 3, 3, 0], + ] + ) + labels_mask = np.array( + [ + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [2, 2, 0, 0, 0], + [0, 0, 1, 1, 0], + ] + ) return instance_mask, labels_mask + def test_get_unique_objects(sample_data): instance_mask, _ = sample_data unique_objects = Compute4Mask.get_unique_objects(instance_mask) assert unique_objects == [1, 2, 3] + def test_get_contours(sample_data): instance_mask, _ = sample_data contour_mask = Compute4Mask.get_contours(instance_mask) assert contour_mask.shape == instance_mask.shape - assert contour_mask[0,1] == 1 # randomly check a contour location is present + assert contour_mask[0, 1] == 1 # randomly check a contour location is present + def test_add_contour(sample_data): instance_mask, labels_mask = sample_data contours_mask = Compute4Mask.get_contours(instance_mask, contours_level=0.1) labels_mask_wo_contour = np.copy(labels_mask) - labels_mask_wo_contour[contours_mask!=0] = 0 - updated_labels_mask = Compute4Mask.add_contour(labels_mask_wo_contour, instance_mask) + labels_mask_wo_contour[contours_mask != 0] = 0 + updated_labels_mask = Compute4Mask.add_contour( + labels_mask_wo_contour, instance_mask + ) assert np.array_equal(updated_labels_mask[:3], labels_mask[:3]) + def test_compute_new_instance_mask(sample_data): instance_mask, labels_mask = sample_data - labels_mask[labels_mask==1] = 0 - updated_instance_mask = Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask) - assert 
list(np.unique(updated_instance_mask))==[0,2] + labels_mask[labels_mask == 1] = 0 + updated_instance_mask = Compute4Mask.compute_new_instance_mask( + labels_mask, instance_mask + ) + assert list(np.unique(updated_instance_mask)) == [0, 2] + def test_compute_new_labels_mask_obj_added(sample_data): instance_mask, labels_mask = sample_data original_instance_mask = np.copy(instance_mask) instance_mask[0, 0] = 4 old_instances = Compute4Mask.get_unique_objects(original_instance_mask) - new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) - assert new_labels_mask[0,0]==4 + new_labels_mask = Compute4Mask.compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ) + assert new_labels_mask[0, 0] == 4 + def test_compute_new_labels_mask_obj_erased(sample_data): instance_mask, labels_mask = sample_data original_instance_mask = np.copy(instance_mask) instance_mask[0] = 0 old_instances = Compute4Mask.get_unique_objects(original_instance_mask) - new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) - assert np.all(new_labels_mask[0])==0 + new_labels_mask = Compute4Mask.compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ) + assert np.all(new_labels_mask[0]) == 0 assert np.array_equal(new_labels_mask[1:], labels_mask[1:]) + def test_compute_new_labels_mask_obj_added(sample_data): instance_mask, labels_mask = sample_data original_instance_mask = np.copy(instance_mask) instance_mask[:, -1] = 1 old_instances = Compute4Mask.get_unique_objects(original_instance_mask) - new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) - assert np.all(new_labels_mask[:, -1])==1 + new_labels_mask = Compute4Mask.compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ) + assert 
np.all(new_labels_mask[:, -1]) == 1 + def assert_consistent_labels(sample_data): instance_mask, labels_mask = sample_data - user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(sample_data) - assert user_annot_error==False - assert mask_mismatch_error==False - assert len(faulty_ids_annot)==len(faulty_ids_missmatch)==0 - instance_mask[instance_mask==3] = 1 - labels_mask[1,2] = 2 - user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(np.stack(instance_mask, labels_mask)) - assert user_annot_error==True - assert mask_mismatch_error==True - assert len(faulty_ids_annot)==1 - assert faulty_ids_annot[0]==1 - assert len(faulty_ids_missmatch)==1 - assert faulty_ids_missmatch[0]==1 + user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = ( + Compute4Mask.assert_consistent_labels(sample_data) + ) + assert user_annot_error == False + assert mask_mismatch_error == False + assert len(faulty_ids_annot) == len(faulty_ids_missmatch) == 0 + instance_mask[instance_mask == 3] = 1 + labels_mask[1, 2] = 2 + user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = ( + Compute4Mask.assert_consistent_labels(np.stack(instance_mask, labels_mask)) + ) + assert user_annot_error == True + assert mask_mismatch_error == True + assert len(faulty_ids_annot) == 1 + assert faulty_ids_annot[0] == 1 + assert len(faulty_ids_missmatch) == 1 + assert faulty_ids_missmatch[0] == 1 diff --git a/src/client/test/test_fsimagestorage.py b/src/client/test/test_fsimagestorage.py index 275e5f0b..f971fbfe 100644 --- a/src/client/test/test_fsimagestorage.py +++ b/src/client/test/test_fsimagestorage.py @@ -5,42 +5,48 @@ from dcp_client.utils.fsimagestorage import FilesystemImageStorage + @pytest.fixture def fis(): return FilesystemImageStorage() + @pytest.fixture def sample_image(): # Create a sample image img = data.astronaut() - fname = 
'test_img.png' + fname = "test_img.png" imsave(fname, img) return fname - + + def test_load_image(fis, sample_image): - img_test = fis.load_image('.', sample_image) + img_test = fis.load_image(".", sample_image) assert img_test.all() == data.astronaut().all() os.remove(sample_image) + def test_move_image(fis, sample_image): - temp_dir = 'temp' + temp_dir = "temp" os.mkdir(temp_dir) - fis.move_image('.', temp_dir, sample_image) - assert os.path.exists(os.path.join(temp_dir, 'test_img.png')) - os.remove(os.path.join(temp_dir, 'test_img.png')) + fis.move_image(".", temp_dir, sample_image) + assert os.path.exists(os.path.join(temp_dir, "test_img.png")) + os.remove(os.path.join(temp_dir, "test_img.png")) os.rmdir(temp_dir) + def test_save_image(fis): img = data.astronaut() - fname = 'output.png' - fis.save_image('.', fname, img) + fname = "output.png" + fis.save_image(".", fname, img) assert os.path.exists(fname) os.remove(fname) + def test_delete_image(fis, sample_image): - temp_dir = 'temp' + temp_dir = "temp" os.mkdir(temp_dir) - fis.move_image('.', temp_dir, sample_image) - fis.delete_image(temp_dir, 'test_img.png') - assert not os.path.exists(os.path.join(temp_dir, 'test_img.png')) + fis.move_image(".", temp_dir, sample_image) + fis.delete_image(temp_dir, "test_img.png") + assert not os.path.exists(os.path.join(temp_dir, "test_img.png")) os.rmdir(temp_dir) diff --git a/src/client/test/test_main_window.py b/src/client/test/test_main_window.py index d5fae533..788dea3c 100644 --- a/src/client/test/test_main_window.py +++ b/src/client/test/test_main_window.py @@ -1,7 +1,8 @@ import os import pytest import sys -sys.path.append('../') + +sys.path.append("../") from skimage import data from skimage.io import imsave @@ -16,11 +17,13 @@ from dcp_client.utils.sync_src_dst import DataRSync from dcp_client.utils import settings + @pytest.fixture() def setup_global_variable(): settings.accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") yield settings.accepted_types + 
@pytest.fixture def app(qtbot, setup_global_variable): @@ -30,112 +33,121 @@ def app(qtbot, setup_global_variable): img2 = data.coffee() img3 = data.cat() - if not os.path.exists('train_data_path'): - os.mkdir('train_data_path') - imsave('train_data_path/astronaut.png', img1) - - if not os.path.exists('in_prog'): - os.mkdir('in_prog') - imsave('in_prog/coffee.png', img2) - - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img3) - - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - application = Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010, - 'eval_data_path', - 'train_data_path', - 'in_prog') + if not os.path.exists("train_data_path"): + os.mkdir("train_data_path") + imsave("train_data_path/astronaut.png", img1) + + if not os.path.exists("in_prog"): + os.mkdir("in_prog") + imsave("in_prog/coffee.png", img2) + + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img3) + + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + application = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + "eval_data_path", + "train_data_path", + "in_prog", + ) # Create an instance of MainWindow widget = MainWindow(application) qtbot.addWidget(widget) yield widget widget.close() - + + def test_main_window_setup(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable assert app.title == "Data Overview" + def test_item_train_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view - #index = app.list_view_train.model().index(0, 0) + # index = app.list_view_train.model().index(0, 0) index = app.list_view_train.indexAt(app.list_view_train.viewport().rect().topLeft()) pos = app.list_view_train.visualRect(index).center() # Simulate file 
click - QTest.mouseClick(app.list_view_train.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_train.viewport(), Qt.LeftButton, pos=pos) app.on_item_train_selected(index) # Assert that the selected item matches the expected item assert app.list_view_train.selectionModel().currentIndex() == index - assert app.app.cur_selected_img=='astronaut.png' - assert app.app.cur_selected_path==app.app.train_data_path + assert app.app.cur_selected_img == "astronaut.png" + assert app.app.cur_selected_path == app.app.train_data_path + def test_item_inprog_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view - index = app.list_view_inprogr.indexAt(app.list_view_inprogr.viewport().rect().topLeft()) + index = app.list_view_inprogr.indexAt( + app.list_view_inprogr.viewport().rect().topLeft() + ) pos = app.list_view_inprogr.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_inprogr.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_inprogr.viewport(), Qt.LeftButton, pos=pos) app.on_item_inprogr_selected(index) # Assert that the selected item matches the expected item assert app.list_view_inprogr.selectionModel().currentIndex() == index assert app.app.cur_selected_img == "coffee.png" assert app.app.cur_selected_path == app.app.inprogr_data_path + def test_item_eval_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_eval.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_eval.viewport(), Qt.LeftButton, pos=pos) app.on_item_eval_selected(index) # Assert that the selected item matches the expected item assert 
app.list_view_eval.selectionModel().currentIndex() == index - assert app.app.cur_selected_img=='cat.png' - assert app.app.cur_selected_path==app.app.eval_data_path + assert app.app.cur_selected_img == "cat.png" + assert app.app.cur_selected_path == app.app.eval_data_path + def test_train_button_click(qtbot, app): # Click the "Train Model" button app.sim = True QTest.mouseClick(app.train_button, Qt.LeftButton) # Wait until the worker thread is done - while app.worker_thread.isRunning(): QTest.qSleep(1000) + while app.worker_thread.isRunning(): + QTest.qSleep(1000) # The train functionality of the thread is tested with app tests + def test_inference_button_click(qtbot, app): # Click the "Generate Labels" button app.sim = True QTest.mouseClick(app.inference_button, Qt.LeftButton) # Wait until the worker thread is done - while app.worker_thread.isRunning(): QTest.qSleep(1000) - #QTest.qWaitForWindowActive(app, timeout=5000) + while app.worker_thread.isRunning(): + QTest.qSleep(1000) + # QTest.qWaitForWindowActive(app, timeout=5000) # The inference functionality of the thread is tested with app tests + def test_on_finished(qtbot, app): # Assert that the on_finished function re-enabled the buttons and set the worker thread to None assert app.train_button.isEnabled() assert app.inference_button.isEnabled() assert app.worker_thread is None + def test_launch_napari_button_click_without_selection(qtbot, app): # Try clicking the view button without having selected an image app.sim = True qtbot.mouseClick(app.launch_nap_button, Qt.LeftButton) - assert not hasattr(app, 'nap_win') + assert not hasattr(app, "nap_win") + def test_launch_napari_button_click(qtbot, app): settings.accepted_types = setup_global_variable @@ -143,28 +155,28 @@ def test_launch_napari_button_click(qtbot, app): index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click - 
QTest.mouseClick(app.list_view_eval.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_eval.viewport(), Qt.LeftButton, pos=pos) app.on_item_eval_selected(index) # Now click the view button qtbot.mouseClick(app.launch_nap_button, Qt.LeftButton) # Assert that the napari window has launched - assert hasattr(app, 'nap_win') + assert hasattr(app, "nap_win") assert app.nap_win.isVisible() -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def cleanup_files(request): # This code runs after all tests from all files have completed yield # Clean up - paths_to_clean = ['train_data_path', 'in_prog', 'eval_data_path'] + paths_to_clean = ["train_data_path", "in_prog", "eval_data_path"] for path in paths_to_clean: try: for fname in os.listdir(path): os.remove(os.path.join(path, fname)) os.rmdir(path) - except FileNotFoundError: pass + except FileNotFoundError: + pass except Exception as e: # Handle other exceptions - print(f"An error occurred while cleaning up {path}: {e}") \ No newline at end of file + print(f"An error occurred while cleaning up {path}: {e}") diff --git a/src/client/test/test_mywidget.py b/src/client/test/test_mywidget.py index e75172c1..7e10f53f 100644 --- a/src/client/test/test_mywidget.py +++ b/src/client/test/test_mywidget.py @@ -1,35 +1,47 @@ import pytest import sys -sys.path.append('../') + +sys.path.append("../") from PyQt5.QtWidgets import QMessageBox from dcp_client.gui._my_widget import MyWidget + @pytest.fixture def app(qtbot): - #q_app = QApplication([]) + # q_app = QApplication([]) widget = MyWidget() qtbot.addWidget(widget) yield widget widget.close() + def test_create_warning_box_ok(qtbot, app): result = None app.sim = True + def execute_warning_box(): nonlocal result box = QMessageBox() result = app.create_warning_box("Test Message", custom_dialog=box) - qtbot.waitUntil(execute_warning_box, timeout=5000) - assert result is True + + qtbot.waitUntil(execute_warning_box, 
timeout=5000) + assert result is True + def test_create_warning_box_cancel(qtbot, app): result = None app.sim = True + def execute_warning_box(): nonlocal result box = QMessageBox() - result = app.create_warning_box("Test Message", add_cancel_btn=True, custom_dialog=box) - qtbot.waitUntil(execute_warning_box, timeout=5000) # Add a timeout for the function to execute - assert result is False + result = app.create_warning_box( + "Test Message", add_cancel_btn=True, custom_dialog=box + ) + + qtbot.waitUntil( + execute_warning_box, timeout=5000 + ) # Add a timeout for the function to execute + assert result is False diff --git a/src/client/test/test_napari_window.py b/src/client/test/test_napari_window.py index 06978ebf..8c31ebcf 100644 --- a/src/client/test/test_napari_window.py +++ b/src/client/test/test_napari_window.py @@ -21,11 +21,12 @@ # yield napari_app # napari_app.close() + @pytest.fixture def napari_window(qtbot): - #img1 = data.astronaut() - #img2 = data.coffee() + # img1 = data.astronaut() + # img2 = data.coffee() img = data.cat() img_mask = np.zeros((2, img.shape[0], img.shape[1]), dtype=np.uint8) img_mask[0, 50:50, 50:50] = 1 @@ -34,61 +35,63 @@ def napari_window(qtbot): img_mask[1, 100:200, 100:200] = 1 img_mask[0, 200:300, 200:300] = 3 img_mask[1, 200:300, 200:300] = 2 - #img3_mask = img2_mask.copy() + # img3_mask = img2_mask.copy() + + if not os.path.exists("train_data_path"): + os.mkdir("train_data_path") - if not os.path.exists('train_data_path'): - os.mkdir('train_data_path') + if not os.path.exists("in_prog"): + os.mkdir("in_prog") - if not os.path.exists('in_prog'): - os.mkdir('in_prog') + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img) - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img) - - imsave('eval_data_path/cat_seg.tiff', img_mask) + imsave("eval_data_path/cat_seg.tiff", img_mask) - rsyncer = 
DataRSync(user_name="local", host_name="local", server_repo_path='.') + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") application = Application( - BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", 7010, - os.path.join(os.getcwd(), 'eval_data_path'), - os.path.join(os.getcwd(), 'train_data_path'), - os.path.join(os.getcwd(), 'in_prog') + os.path.join(os.getcwd(), "eval_data_path"), + os.path.join(os.getcwd(), "train_data_path"), + os.path.join(os.getcwd(), "in_prog"), ) - application.cur_selected_img = 'cat.png' + application.cur_selected_img = "cat.png" application.cur_selected_path = application.eval_data_path widget = NapariWindow(application) - qtbot.addWidget(widget) - yield widget + qtbot.addWidget(widget) + yield widget widget.close() + def test_napari_window_initialization(napari_window): assert napari_window.viewer is not None assert napari_window.qctrl is not None assert napari_window.mask_choice_dropdown is not None + def test_on_add_to_curated_button_clicked(napari_window, monkeypatch): # Mock the create_warning_box method def mock_create_warning_box(message_text, message_title): - return None - monkeypatch.setattr(napari_window, 'create_warning_box', mock_create_warning_box) + return None + + monkeypatch.setattr(napari_window, "create_warning_box", mock_create_warning_box) - napari_window.app.cur_selected_img = 'cat.png' + napari_window.app.cur_selected_img = "cat.png" napari_window.app.cur_selected_path = napari_window.app.eval_data_path - napari_window.viewer.layers.selection.active.name = 'cat_seg' + napari_window.viewer.layers.selection.active.name = "cat_seg" # Simulate the button click napari_window.on_add_to_curated_button_clicked() - assert not os.path.exists('eval_data_path/cat.tiff') - assert not os.path.exists('eval_data_path/cat_seg.tiff') - assert os.path.exists('train_data_path/cat.png') - assert 
os.path.exists('train_data_path/cat_seg.tiff') - + assert not os.path.exists("eval_data_path/cat.tiff") + assert not os.path.exists("eval_data_path/cat_seg.tiff") + assert os.path.exists("train_data_path/cat.png") + assert os.path.exists("train_data_path/cat_seg.tiff") diff --git a/src/client/test/test_sync_src_dst.py b/src/client/test/test_sync_src_dst.py index ca652644..15ed79d3 100644 --- a/src/client/test/test_sync_src_dst.py +++ b/src/client/test/test_sync_src_dst.py @@ -1,26 +1,25 @@ import pytest -from dcp_client.utils.sync_src_dst import DataRSync +from dcp_client.utils.sync_src_dst import DataRSync @pytest.fixture def rsyncer(): - syncer = DataRSync(user_name="local", - host_name="local", - server_repo_path='.') + syncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") return syncer + def test_init(rsyncer): - assert rsyncer.user_name=="local" - assert rsyncer.host_name=="local" - assert rsyncer.server_repo_path=='.' + assert rsyncer.user_name == "local" + assert rsyncer.host_name == "local" + assert rsyncer.server_repo_path == "." 
+ def test_first_sync_e(rsyncer): msg, _ = rsyncer.first_sync("eval_data_path") - assert msg=="Error" + assert msg == "Error" + def test_sync(rsyncer): msg, _ = rsyncer.sync("server", "client", "eval_data_path") - assert msg=="Error" - - + assert msg == "Error" diff --git a/src/client/test/test_utils.py b/src/client/test/test_utils.py index d09c8df6..88d2ce5b 100644 --- a/src/client/test/test_utils.py +++ b/src/client/test/test_utils.py @@ -3,35 +3,38 @@ sys.path.append("../") from dcp_client.utils import utils + def test_get_relative_path(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_relative_path(filepath)== 'something.txt' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_relative_path(filepath) == "something.txt" + def test_get_path_stem(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_stem(filepath)== 'something' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_stem(filepath) == "something" + def test_get_path_name(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_name(filepath)== 'something.txt' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_name(filepath) == "something.txt" + def test_get_path_parent(): - if sys.platform == 'win32' or sys.platform == 'cygwin': - filepath = '\\here\\we\\are\\testing\\something.txt' - assert utils.get_path_parent(filepath)== '\\here\\we\\are\\testing' + if sys.platform == "win32" or sys.platform == "cygwin": + filepath = "\\here\\we\\are\\testing\\something.txt" + assert utils.get_path_parent(filepath) == "\\here\\we\\are\\testing" else: - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_parent(filepath)== '/here/we/are/testing' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_parent(filepath) == "/here/we/are/testing" + def test_join_path(): - if sys.platform == 'win32' or sys.platform == 'cygwin': - filepath = 
'\\here\\we\\are\\testing\\something.txt' - path1 = '\\here\\we\\are\\testing' - path2 = 'something.txt' + if sys.platform == "win32" or sys.platform == "cygwin": + filepath = "\\here\\we\\are\\testing\\something.txt" + path1 = "\\here\\we\\are\\testing" + path2 = "something.txt" else: - filepath = '/here/we/are/testing/something.txt' - path1 = '/here/we/are/testing' - path2 = 'something.txt' + filepath = "/here/we/are/testing/something.txt" + path1 = "/here/we/are/testing" + path2 = "something.txt" assert utils.join_path(path1, path2) == filepath - - diff --git a/src/client/test/test_welcome_window.py b/src/client/test/test_welcome_window.py index 4b15803d..9fdaa49c 100644 --- a/src/client/test/test_welcome_window.py +++ b/src/client/test/test_welcome_window.py @@ -1,6 +1,7 @@ import pytest import sys -sys.path.append('../') + +sys.path.append("../") from PyQt5.QtCore import Qt from PyQt5.QtWidgets import QMessageBox @@ -12,39 +13,48 @@ from dcp_client.utils.sync_src_dst import DataRSync from dcp_client.utils import settings + @pytest.fixture def setup_global_variable(): settings.accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") yield settings.accepted_types + @pytest.fixture def app(qtbot): - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + application = Application( + BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010 + ) # Create an instance of WelcomeWindow # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() + @pytest.fixture def app_remote(qtbot): - rsyncer = DataRSync(user_name="remote", host_name="remote", server_repo_path='.') - application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + rsyncer = 
DataRSync(user_name="remote", host_name="remote", server_repo_path=".") + application = Application( + BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010 + ) # Create an instance of WelcomeWindow # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() + def test_welcome_window_initialization(app): assert app.title == "Select Dataset" assert app.val_textbox.text() == "" assert app.inprogr_textbox.text() == "" assert app.train_textbox.text() == "" + def test_warning_for_same_paths(qtbot, app, monkeypatch): app.app.eval_data_path = "/same/path" app.app.train_data_path = "/same/path" @@ -54,32 +64,43 @@ def test_warning_for_same_paths(qtbot, app, monkeypatch): def custom_exec(self): return QMessageBox.Ok - monkeypatch.setattr(QMessageBox, 'exec', custom_exec) - qtbot.mouseClick(app.start_button, Qt.LeftButton) + monkeypatch.setattr(QMessageBox, "exec", custom_exec) + qtbot.mouseClick(app.start_button, Qt.LeftButton) assert app.create_warning_box assert app.message_text == "All directory names must be distinct." 
+ def test_on_text_changed(qtbot, app): app.app.train_data_path = "/initial/train/path" app.app.eval_data_path = "/initial/eval/path" app.app.inprogr_data_path = "/initial/inprogress/path" - app.on_text_changed(field_obj=app.train_textbox, field_name="train", text="/new/train/path") + app.on_text_changed( + field_obj=app.train_textbox, field_name="train", text="/new/train/path" + ) assert app.app.train_data_path == "/new/train/path" - app.on_text_changed(field_obj=app.val_textbox, field_name="eval", text="/new/eval/path") + app.on_text_changed( + field_obj=app.val_textbox, field_name="eval", text="/new/eval/path" + ) assert app.app.eval_data_path == "/new/eval/path" - app.on_text_changed(field_obj=app.inprogr_textbox, field_name="inprogress", text="/new/inprogress/path") + app.on_text_changed( + field_obj=app.inprogr_textbox, + field_name="inprogress", + text="/new/inprogress/path", + ) assert app.app.inprogr_data_path == "/new/inprogress/path" + def test_start_main_not_selected(qtbot, app): app.app.train_data_path = None app.app.eval_data_path = None app.sim = True qtbot.mouseClick(app.start_button, Qt.LeftButton) - assert not hasattr(app, 'mw') + assert not hasattr(app, "mw") + def test_start_main(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable @@ -93,11 +114,12 @@ def test_start_main(qtbot, app, setup_global_variable): # Simulate clicking the start button qtbot.mouseClick(app.start_button, Qt.LeftButton) # Check if the main window is created - #assert qtbot.waitUntil(lambda: hasattr(app, 'mw'), timeout=1000) - assert hasattr(app, 'mw') + # assert qtbot.waitUntil(lambda: hasattr(app, 'mw'), timeout=1000) + assert hasattr(app, "mw") # Check if the WelcomeWindow is hidden assert app.isHidden() + def test_start_upload_and_main(qtbot, app_remote, setup_global_variable, monkeypatch): settings.accepted_types = setup_global_variable app_remote.app.eval_data_path = "/path/to/eval" @@ -107,15 +129,15 @@ def 
test_start_upload_and_main(qtbot, app_remote, setup_global_variable, monkeyp def custom_exec(self): return QMessageBox.Ok - monkeypatch.setattr(QMessageBox, 'exec', custom_exec) - qtbot.mouseClick(app_remote.start_button, Qt.LeftButton) + monkeypatch.setattr(QMessageBox, "exec", custom_exec) + qtbot.mouseClick(app_remote.start_button, Qt.LeftButton) # should close because error on upload! - assert app_remote.done_upload==False + assert app_remote.done_upload == False assert not app_remote.isVisible() - assert not hasattr(app_remote, 'mw') - + assert not hasattr(app_remote, "mw") + -'''' +"""' # TODO wait for github respose def test_browse_eval_clicked(qtbot, app, monkeypatch): # Mock the QFileDialog so that it immediately returns a directory @@ -162,4 +184,4 @@ def test_browse_inprogr_clicked(qtbot, app): # Check if the textbox is updated with the selected path assert app.inprogr_textbox.text() == app.app.inprogr_data_path -''' \ No newline at end of file +""" diff --git a/src/server/MANIFEST.in b/src/server/MANIFEST.in new file mode 100644 index 00000000..ffd67494 --- /dev/null +++ b/src/server/MANIFEST.in @@ -0,0 +1 @@ +include dcp_server/*.yaml \ No newline at end of file diff --git a/src/server/dcp_server/__init__.py b/src/server/dcp_server/__init__.py index ffbb8826..a125355e 100644 --- a/src/server/dcp_server/__init__.py +++ b/src/server/dcp_server/__init__.py @@ -2,15 +2,11 @@ Overview of dcp_server Package ============================== -The dcp_server package is structured to handle various server-side functionalities related to image processing, segmentation, and model serving. +The dcp_server package is structured to handle various server-side functionalities related model serving for segmentation and training. Submodules: ------------ -dcp_server.fsimagestorage - Provides a class FilesystemImageStorage for dealing with image storage, loading, saving, and processing. 
- Contains methods for retrieving image-segmentation pairs, getting image size properties, loading images, preparing images and masks for training, rescaling images, resizing masks, saving images, and searching for images and segmentations in directories. - dcp_server.models Defines various models for cell classification and segmentation, including CellClassifierFCNN, CellClassifierShallowModel, CellposePatchCNN, CustomCellposeModel, and UNet. These models handle tasks such as evaluation, forward pass, training, and updating configurations. @@ -23,6 +19,6 @@ Defines service classes, such as CustomBentoService and CustomRunnable, for serving the models with BentoML and handling computation on remote Python workers. dcp_server.utils - Provides various utility functions for image processing, feature extraction, file handling, configuration reading, and path manipulation. + Provides various utility functions for dealing with image storage, image processing, feature extraction, file handling, configuration reading, and path manipulation. 
""" diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.yaml similarity index 65% rename from src/server/dcp_server/config.cfg rename to src/server/dcp_server/config.yaml index bb92bda5..5652469a 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.yaml @@ -1,34 +1,40 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CellposePatchCNN", - "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], - "seg_name_string": "_seg" + "model_to_use": "Inst2MultiSeg" }, "service": { "runner_name": "bento_runner", - "bento_model_path": "cppcnn", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": "Cellpose", "segmentor": { "model_type": "cyto" }, + "classifier_name": "PatchClassifier", "classifier":{ - "model_class": "FCNN", "in_channels": 1, "num_classes": 2, - "black_bg": "False", - "include_mask": "False" + "features":[64,128,256,512], + "black_bg": False, + "include_mask": True } }, "data": { - "data_root": "data" + "data_root": "data", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ @@ -38,12 +44,7 @@ "min_train_masks": 1 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 2 - }, - "n_epochs": 10, + "n_epochs": 20, "lr": 0.001, "batch_size": 1, "optimizer": "Adam" @@ -58,10 +59,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } diff --git a/src/server/dcp_server/config_instance.cfg b/src/server/dcp_server/config_instance.yaml similarity index 72% rename from src/server/dcp_server/config_instance.cfg rename to src/server/dcp_server/config_instance.yaml index dc841189..db266da0 100644 --- a/src/server/dcp_server/config_instance.cfg +++ b/src/server/dcp_server/config_instance.yaml 
@@ -1,14 +1,12 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CustomCellposeModel", - "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], - "seg_name_string": "_seg" + "model_to_use": "CustomCellpose" }, "service": { - "runner_name": "cellpose_runner", - "bento_model_path": "cells", + "runner_name": "bento_runner", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, @@ -20,12 +18,16 @@ }, "data": { - "data_root": "data" + "data_root": "data", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg", + "gray": True, + "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 10, + "n_epochs": 5, "channels": [0,0], "min_train_masks": 1 } diff --git a/src/server/dcp_server/config_semantic.cfg b/src/server/dcp_server/config_semantic.yaml similarity index 73% rename from src/server/dcp_server/config_semantic.cfg rename to src/server/dcp_server/config_semantic.yaml index 03f5e86f..e72459ac 100644 --- a/src/server/dcp_server/config_semantic.cfg +++ b/src/server/dcp_server/config_semantic.yaml @@ -1,9 +1,7 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "UNet", - "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], - "seg_name_string": "_seg" + "model_to_use": "UNet" }, "service": { @@ -15,15 +13,18 @@ "model": { "classifier":{ - "model_class": "UNet", "in_channels": 1, - "num_classes": 3, + "num_classes": 2, "features":[64,128,256,512] } }, "data": { - "data_root": "data" + "data_root": "data", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg", + "gray": True, + "rescale": True }, "train":{ @@ -37,7 +38,9 @@ "eval":{ "classifier": { + }, - "mask_channel_axis": null + "compute_instance": True, + "mask_channel_axis": 0 } } \ No newline at end of file diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py deleted file mode 100644 index c77d329b..00000000 
--- a/src/server/dcp_server/fsimagestorage.py +++ /dev/null @@ -1,225 +0,0 @@ -import os -import numpy as np -from skimage.io import imread, imsave -from skimage.transform import resize, rescale - -from dcp_server import utils - -# Import configuration -dirname = os.path.dirname(__file__) -setup_config = utils.read_config('setup', config_path = os.path.join(dirname, 'config.cfg')) - -class FilesystemImageStorage(): - """ - Class used to deal with everything related to image storing and processing - loading, saving, transforming. - """ - def __init__(self, data_root, model_used): - self.root_dir = data_root - self.model_used = model_used - - def load_image(self, cur_selected_img, is_gray=True): - """ Load the image using skimage. - - :param cur_selected_img: full path of the image that needs to be loaded - :type cur_selected_img: str - :return: loaded image - :rtype: ndarray - """ - try: - return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=is_gray) - except ValueError: return None - - def save_image(self, to_save_path, img): - """ Save given image using skimage. - - :param to_save_path: full path to the directory that the image needs to be save into (use also image name in the path, eg. '/users/new_image.png') - :type to_save_path: str - :param img: image you wish to save - :type img: ndarray - """ - imsave(os.path.join(self.root_dir, to_save_path), img) - - def search_images(self, directory): - """ Get a list of full paths of the images in the directory. - - :param directory: Path to the directory to search for images. - :type directory: str - :return: List of image paths found in the directory (only image types that are supported - see config.cfg 'setup' section). 
- :rtype: list - """ - # Take all segmentations of the image from the current directory: - directory = os.path.join(self.root_dir, directory) - seg_files = [file_name for file_name in os.listdir(directory) if setup_config['seg_name_string'] in file_name] - # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted - image_files = [os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) and (utils.get_file_extension(file_name) in setup_config['accepted_types'])] - return image_files - - def search_segs(self, cur_selected_img): - """ Returns a list of full paths of segmentations for an image. - - :param cur_selected_img: Full path of the image for which segmentations are needed. - :type cur_selected_img: str - :return: List of segmentation paths for the given image. - :rtype: list - """ - - # Check the directory the image was selected from: - img_directory = utils.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) - # Take all segmentations of the image from the current directory: - search_string = utils.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] - #seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] - # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) - - seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] - - return seg_files - - def get_image_seg_pairs(self, directory): - """ Get pairs of (image, image_seg). - - Used, e.g., in training to create training data-training labels pairs. - - :param directory: Path to the directory to search images and segmentations in. 
- :type directory: str - :return: List of tuple pairs (image, image_seg). - :rtype: list - """ - - image_files = self.search_images(os.path.join(self.root_dir, directory)) - seg_files = [] - for image in image_files: - seg = self.search_segs(image) - #TODO - the search seg returns all the segs, but here we need only one, hence the seg[0]. Check if it is from training path? - seg_files.append(seg[0]) - return list(zip(image_files, seg_files)) - - def get_unsupported_files(self, directory): - """ Get unsupported files found in the given directory. - - :param directory: Directory path to search for files in. - :type directory: str - :return: List of unsupported files. - :rtype: list - """ - return [file_name for file_name in os.listdir(os.path.join(self.root_dir, directory)) - if not file_name.startswith('.') and utils.get_file_extension(file_name) not in setup_config['accepted_types']] - - def get_image_size_properties(self, img, file_extension): - """ Get properties of the image size. - - :param img: Image (numpy array). - :type img: ndarray - :param file_extension: File extension of the image as saved in the directory. - :type file_extension: str - :return: Size properties: - - height - - width - - z_axis - :rtype: dict - """ - orig_size = img.shape - # png and jpeg will be RGB by default and 2D - # tif can be grayscale 2D or 3D [Z, H, W] - # image channels have already been removed in imread with is_gray=True - if file_extension in (".jpg", ".jpeg", ".png"): - height, width = orig_size[0], orig_size[1] - z_axis = None - elif file_extension in (".tiff", ".tif") and len(orig_size)==2: - height, width = orig_size[0], orig_size[1] - z_axis = None - # if we have 3 dimensions the [Z, H, W] - elif file_extension in (".tiff", ".tif") and len(orig_size)==3: - print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. 
Please cross check this.') - height, width = orig_size[1], orig_size[2] - z_axis = 0 - else: - print('File not currently supported. See documentation for accepted types') - - return height, width, z_axis - - def rescale_image(self, img, height, width, channel_ax=None, order=2): - """ Rescale image. - - :param img: Image. - :type img: ndarray - :param height: Height of the image. - :type height: int - :param width: Width of the image. - :type width: int - :param channel_ax: Channel axis. - :type channel_ax: int - :return: Rescaled image. - :rtype: ndarray - """ - - if self.model_used == "UNet": - height_pad = (height//16 + 1)*16 - height - width_pad = (width//16 + 1)*16 - width - return np.pad(img, ((0, height_pad),(0, width_pad))) - else: - # Cellpose segmentation runs best with 512 size? TODO: check - max_dim = max(height, width) - rescale_factor = max_dim/512 - return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax) - - def resize_mask(self, mask, height, width, channel_ax=None, order=2): - """ Resize the mask so it matches the original image size. - - :param mask: Image. - :type mask: ndarray - :param height: Height of the image. - :type height: int - :param width: Width of the image. - :type width: int - :param order: From scikit-image - the order of the spline interpolation. Default is 0 if image.dtype is bool and 1 otherwise. - :type order: int - :return: Resized image. 
- :rtype: ndarray - """ - - if self.model_used == "UNet": - # we assume an order C, H, W - if channel_ax is not None and channel_ax==0: - height_pad = mask.shape[1] - height - width_pad = mask.shape[2]- width - return mask[:, :-height_pad, :-width_pad] - elif channel_ax is not None and channel_ax==2: - height_pad = mask.shape[0] - height - width_pad = mask.shape[1]- width - return mask[:-height_pad, :-width_pad, :] - elif channel_ax is not None and channel_ax==1: - height_pad = mask.shape[2] - height - width_pad = mask.shape[0]- width - return mask[:-width_pad, :, :-height_pad] - - else: - if channel_ax is not None: - n_channel_dim = mask.shape[channel_ax] - output_size = [height, width] - output_size.insert(channel_ax, n_channel_dim) - else: output_size = [height, width] - return resize(mask, output_size, order=order) - - def prepare_images_and_masks_for_training(self, train_img_mask_pairs): - """ Image and mask processing for training. - - :param train_img_mask_pairs: List pairs of (image, image_seg) (as returned by get_image_seg_pairs() function). - :type train_img_mask_pairs: list - :return: Lists of processed images and masks. 
- :rtype: tuple - """ - - imgs=[] - masks=[] - for img_file, mask_file in train_img_mask_pairs: - img = self.load_image(img_file) - mask = imread(mask_file) - if self.model_used == "UNet": - # Unet only accepts image sizes divisable by 16 - height_pad = (img.shape[0]//16 + 1)*16 - img.shape[0] - width_pad = (img.shape[1]//16 + 1)*16 - img.shape[1] - img = np.pad(img, ((0, height_pad),(0, width_pad))) - mask = np.pad(mask, ((0, 0), (0, height_pad),(0, width_pad))) - imgs.append(img) - masks.append(mask) - return imgs, masks \ No newline at end of file diff --git a/src/server/dcp_server/main.py b/src/server/dcp_server/main.py index d670047d..9c149b5b 100644 --- a/src/server/dcp_server/main.py +++ b/src/server/dcp_server/main.py @@ -1,9 +1,11 @@ -import subprocess from os import path import sys -from utils import read_config +import subprocess + +from dcp_server.utils.helpers import read_config + -def main(): +def main() -> None: """ Contains main functionality related to the server. """ @@ -15,21 +17,24 @@ def main(): # else: # config_path = 'config.cfg' - local_path = path.join(__file__, '..') + local_path = path.join(__file__, "..") dir_name = path.dirname(path.abspath(sys.argv[0])) - service_config = read_config('service', config_path = path.join(dir_name, 'config.cfg')) - port = str(service_config['port']) + service_config = read_config( + "service", config_path=path.join(dir_name, "config.yaml") + ) + port = str(service_config["port"]) - subprocess.run([ - "bentoml", - "serve", - '--working-dir', - local_path, - "service:svc", - "--reload", - "--port="+port, - ]) - + subprocess.run( + [ + "bentoml", + "serve", + "--working-dir", + local_path, + "service:svc", + "--reload", + "--port=" + port, + ] + ) if __name__ == "__main__": diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py deleted file mode 100644 index 509fc1d3..00000000 --- a/src/server/dcp_server/models.py +++ /dev/null @@ -1,810 +0,0 @@ -from cellpose import models, utils 
-import torch -from torch import nn -from torch.optim import Adam -from torch.utils.data import TensorDataset, DataLoader -from torchmetrics import F1Score -from copy import deepcopy -from tqdm import tqdm -import numpy as np -from scipy.ndimage import label -from skimage.measure import label as label_mask - - -from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import f1_score, log_loss -from sklearn.exceptions import NotFittedError - -from cellpose.metrics import aggregated_jaccard_index -from cellpose.dynamics import labels_to_flows -#from segment_anything import SamPredictor, sam_model_registry -#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator - -from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, create_dataset_for_rf - -class CustomCellposeModel(models.CellposeModel, nn.Module): - """ - Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing - additional attributes and methods needed for this project. - """ - def __init__(self, model_config, train_config, eval_config, model_name): - """ Construct all the necessary attributes for the CustomCellposeModel. - The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts. - Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted. 
- - :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization - :type model_config: dict - :param train_config: dictionary passed from the config file with all the arguments for training function - :type train_config: dict - :param eval_config: dictionary passed from the config file with all the arguments for eval function - :type eval_config: dict - """ - - # Initialize the cellpose model - #super().__init__(**model_config["segmentor"]) - nn.Module.__init__(self) - models.CellposeModel.__init__(self, **model_config["segmentor"]) - self.mkldnn = False # otherwise we get error with saving model - self.train_config = train_config - self.eval_config = eval_config - self.loss = 1e6 - self.model_name = model_name - - def update_configs(self, train_config, eval_config): - """ Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def eval_all_outputs(self, img): - """ Get all outputs of the model when running eval. - - :param img: Input image for segmentation. - :type img: numpy.ndarray - :return: Probability mask for the input image. - :rtype: numpy.ndarray - """ - - return super().eval(x=img, **self.eval_config["segmentor"]) - - def eval(self, img): - """ Evaluate the model - find mask of the given image - Calls the original eval function. - - :param img: image to evaluate on - :type img: np.ndarray - :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. - :rtype: np.ndarray - """ - return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask - - def train(self, imgs, masks): - """ Train the given model - Calls the original train function. 
- - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] - """ - - if not isinstance(masks, np.ndarray): # TODO Remove: all these should be taken care of in fsimagestorage - masks = np.array(masks) - - if masks[0].shape[0] == 2: - masks = list(masks[:,0,...]) - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) - - # compute loss and metric - true_bin_masks = [mask>0 for mask in masks] # get binary masks - true_flows = labels_to_flows(masks) # get cellpose flows - # get predicted flows and cell probability - pred_masks = [] - pred_flows = [] - true_lbl = [] - for idx, img in enumerate(imgs): - mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"]) - pred_masks.append(mask) - pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack cell probability map, horizontal and vertical flow - true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]])) - - true_lbl = np.stack(true_lbl) - pred_flows=np.stack(pred_flows) - pred_flows = torch.from_numpy(pred_flows).float().to('cpu') - # compute loss, combination of mse for flows and bce for cell probability - self.loss = self.loss_fn(true_lbl, pred_flows) - self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - - def masks_to_outlines(self, mask): - """ Get outlines of masks as a 0-1 array - Calls the original cellpose.utils.masks_to_outlines function. - - :param mask: int, 2D or 3D array, mask of an image - :type mask: ndarray - :return: outlines - :rtype: ndarray - """ - return utils.masks_to_outlines(mask) #[True, False] outputs - - -class CellClassifierFCNN(nn.Module): - - """ - Fully convolutional classifier for cell images. 
NOTE -> This model cannot be used as a standalone model in DCP - """ - - def __init__(self, model_config, train_config, eval_config): - """ Initialize the fully convolutional classifier. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - super().__init__() - - self.in_channels = model_config["classifier"].get("in_channels",1) - self.num_classes = model_config["classifier"].get("num_classes",3) - - self.train_config = train_config["classifier"] - self.eval_config = eval_config["classifier"] - - self.include_mask = model_config["classifier"]["include_mask"] - self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels - - self.layer1 = nn.Sequential( - nn.Conv2d(self.in_channels, 16, 3, 2, 5), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer2 = nn.Sequential( - nn.Conv2d(16, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer3 = nn.Sequential( - nn.Conv2d(64, 128, 3, 2, 4), - nn.BatchNorm2d(128), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - self.final_conv = nn.Conv2d(128, self.num_classes, 1) - self.pooling = nn.AdaptiveMaxPool2d(1) - - self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") - - def update_configs(self, train_config, eval_config): - """ Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def forward(self, x): - """ Performs forward pass of the CellClassifierFCNN. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor after passing through the network. 
- :rtype: torch.Tensor - """ - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.final_conv(x) - x = self.pooling(x) - x = x.view(x.size(0), -1) - return x - - def train (self, imgs, labels): - """ Trains the given model. - - :param imgs: List of input images with shape (3, dx, dy). - :type imgs: List[np.ndarray[np.uint8]] - :param labels: List of classification labels. - :type labels: List[int] - """ - - lr = self.train_config['lr'] - epochs = self.train_config['n_epochs'] - batch_size = self.train_config['batch_size'] - # optimizer_class = self.train_config['optimizer'] - - # Convert input images and labels to tensors - - # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = torch.permute(imgs, (0, 3, 1, 2)) - # Your classification label mask - labels = torch.LongTensor([label for label in labels]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, labels) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - # TODO check if we should replace self.parameters with super.parameters() - - for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): - self.loss, self.metric = 0, 0 - for data in train_dataloader: - imgs, labels = data - - optimizer.zero_grad() - preds = self.forward(imgs) - - l = loss_fn(preds, labels) - l.backward() - optimizer.step() - self.loss += l.item() - - self.metric += self.metric_fn(preds, labels) - - self.loss /= len(train_dataloader) - self.metric /= len(train_dataloader) - - def eval(self, img): - """ Evaluates the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. 
- :type img: np.ndarray[np.uint8] - :return: y_hat - predicted label. - :rtype: torch.Tensor - """ - # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) - # convert to tensor - img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) - preds = self.forward(img) - y_hat = torch.argmax(preds, 1) - return y_hat - - -class CellposePatchCNN(nn.Module): - """ - Cellpose & patches of cells and then cnn to classify each patch. - """ - - def __init__(self, model_config, train_config, eval_config, model_name): - """ Construct all the necessary attributes for the CellposePatchCNN. - - :param model_config: Model configuration. - :type model_config: dict - - :param train_config: Training configuration. - :type train_config: dict - - :param eval_config: Evaluation configuration. - :type eval_config: dict - - :param model_name: Name of the model. - :type model_name: str - """ - super().__init__() - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.include_mask = self.model_config["classifier"]["include_mask"] - self.model_name = model_name - self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") - - # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - "Cellpose") - - if self.classifier_class == "FCNN": - self.classifier = CellClassifierFCNN(self.model_config, - self.train_config, - self.eval_config) - - elif self.classifier_class == "RandomForest": - self.classifier = CellClassifierShallowModel(self.model_config, - self.train_config, - self.eval_config) - # make sure include mask is set to False if we are using the random forest model - self.include_mask = False - - def update_configs(self, train_config, eval_config): - """ Update the training and evaluation configurations. 
- - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def train(self, imgs, masks): - """ Trains the given model. First trains the segmentor and then the clasiffier. - - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, - second channel classes, so [2, H, W] or [2, 3, H, W] for 3D - - """ - # train cellpose - masks = np.array(masks) - masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks - self.segmentor.train(deepcopy(imgs), masks_instances) - # create patch dataset to train classifier - masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - patches, patch_masks, labels = create_patch_dataset(imgs, - masks_classes, - masks_instances, - noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], - include_mask = self.include_mask) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # train classifier - self.classifier.train(x, labels) - # and compute metric and loss - self.metric = (self.segmentor.metric + self.classifier.metric) / 2 - self.loss = (self.segmentor.loss + self.classifier.loss)/2 - - def eval(self, img): - """ Evaluate the model on the provided image and return the final mask. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: Final mask containing instance mask and class masks. 
- :rtype: np.ndarray[np.uint16] - """ - - # TBD we assume image is 2D [H, W] (see fsimage storage) - # The final mask which is returned should have - # first channel the output of cellpose and the rest are the class channels - with torch.no_grad(): - # get instance mask from segmentor - instance_mask = self.segmentor.eval(img) - # find coordinates of detected objects - class_mask = np.zeros(instance_mask.shape) - - max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] - if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) - noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] - - # get patches centered around detected objects - patches, patch_masks, instance_labels, _ = get_centered_patches(img, - instance_mask, - max_patch_size, - noise_intensity=noise_intensity, - include_mask=self.include_mask) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # loop over patches and create classification mask - for idx in range(len(x)): - patch_class = self.classifier.eval(x[idx]) - # Assign predicted class to corresponding location in final_mask - patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class - class_mask[instance_mask==instance_labels[idx]] = patch_class + 1 - # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW - - return final_mask - -class CellClassifierShallowModel: - """ - This class implements a shallow model for cell classification using scikit-learn. - """ - - def __init__(self, model_config, train_config, eval_config): - """ Construct all the necessary attributes for the CellClassifierShallowModel. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. 
- :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - - self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params - - - def train(self, X_train, y_train): - """ Trains the model using the provided training data. - - :param X_train: Features of the training data. - :type X_train: numpy.ndarray - :param y_train: Labels of the training data. - :type y_train: numpy.ndarray - """ - - self.model.fit(X_train,y_train) - - y_hat = self.model.predict(X_train) - y_hat_proba = self.model.predict_proba(X_train) - - self.metric = f1_score(y_train, y_hat, average='micro') - # Binary Cross Entrop Loss - self.loss = log_loss(y_train, y_hat_proba) - - - def eval(self, X_test): - """ Evaluates the model on the provided test data. - - :param X_test: Features of the test data. - :type X_test: numpy.ndarray - :return: y_hat - predicted labels. - :rtype: numpy.ndarray - """ - - X_test = X_test.reshape(1,-1) - - try: - y_hat = self.model.predict(X_test) - except NotFittedError as e: - y_hat = np.zeros(X_test.shape[0]) - - return y_hat - -class UNet(nn.Module): - - """ - Unet is a convolutional neural network architecture for semantic segmentation. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - class DoubleConv(nn.Module): - """ - DoubleConv module consists of two consecutive convolutional layers with - batch normalization and ReLU activation functions. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize DoubleConv module. - - :param in_channels: Number of input channels. 
- :type in_channels: int - :param out_channels: Number of output channels. - :type out_channels: int - """ - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), - nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), - ) - - def forward(self, x): - """Forward pass through the DoubleConv module. - - :param x: Input tensor. - :type x: torch.Tensor - """ - return self.conv(x) - - - def __init__(self, model_config, train_config, eval_config, model_name): - """ Construct all the necessary attributes for the UNet model. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - super().__init__() - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - - # self.in_channels = self.model_config["unet"]["in_channels"] - # self.out_channels = self.model_config["unet"]["out_channels"] - # self.features = self.model_config["unet"]["features"] - - self.in_channels = self.model_config["classifier"]["in_channels"] - self.out_channels = self.model_config["classifier"]["num_classes"] + 1 - self.features = self.model_config["classifier"]["features"] - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - - # Encoder - for feature in self.features: - self.encoder.append( - UNet.DoubleConv(self.in_channels, feature) - ) - self.in_channels = feature - - # Decoder - for feature in self.features[::-1]: - self.decoder.append( - nn.ConvTranspose2d( - feature*2, feature, kernel_size=2, stride=2 - ) - ) - self.decoder.append( - UNet.DoubleConv(feature*2, feature) 
- ) - - self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2) - self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) - - def forward(self, x): - """ Forward pass of the UNet model. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor. - :rtype: torch.Tensor - """ - skip_connections = [] - for encoder in self.encoder: - x = encoder(x) - skip_connections.append(x) - x = self.pool(x) - - x = self.bottle_neck(x) - skip_connections = skip_connections[::-1] - - for i in np.arange(len(self.decoder), step=2): - x = self.decoder[i](x) - skip_connection = skip_connections[i//2] - concatenate_skip = torch.cat((skip_connection, x), dim=1) - x = self.decoder[i+1](concatenate_skip) - - return self.output_conv(x) - - def train(self, imgs, masks): - """ Trains the UNet model using the provided images and masks. - - :param imgs: Input images for training. - :type imgs: list[numpy.ndarray] - :param masks: Masks corresponding to the input images. 
- :type masks: list[numpy.ndarray] - """ - - lr = self.train_config["classifier"]['lr'] - epochs = self.train_config["classifier"]['n_epochs'] - batch_size = self.train_config["classifier"]['batch_size'] - - # Convert input images and labels to tensors - # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs - - # Classification label mask - masks = np.array(masks) - masks = torch.stack([torch.from_numpy(mask[1].astype(np.int16)) for mask in masks]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, masks) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) - - for _ in tqdm(range(epochs), desc="Running UNet training"): - - self.loss = 0 - - for imgs, masks in train_dataloader: - imgs = imgs.float() - masks = masks.long() - - #forward path - preds = self.forward(imgs) - loss = loss_fn(preds, masks) - - #backward path - optimizer.zero_grad() - loss.backward() - optimizer.step() - - self.loss += loss.detach().mean().item() - - self.loss /= len(train_dataloader) - - def eval(self, img): - """ Evaluate the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. 
- :type img: np.ndarray[np.uint8] - :return: predicted mask consists of instance and class masks - :rtype: numpy.ndarray - """ - with torch.no_grad(): - # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) - img = torch.from_numpy(img).float().unsqueeze(0) - - img = img.unsqueeze(1) if img.ndim == 3 else img - - preds = self.forward(img) - class_mask = torch.argmax(preds, 1).numpy()[0] - - instance_mask = label((class_mask > 0).astype(int))[0] - - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) - - return final_mask - -class CellposeMultichannel(): - """ - Multichannel image segmentation model. - Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type. - """ - - def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"): - """ Constructs all the necessary attributes for the CellposeMultichannel model. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - self.num_of_channels = self.model_config["classifier"]["num_classes"] - - self.cellpose_models = [ - CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - self.model_name - ) for _ in range(self.num_of_channels) - ] - - def train(self, imgs, masks): - """ Train the model on the provided images and masks. - - :param imgs: Input images for training. - :type imgs: list[numpy.ndarray] - :param masks: Masks corresponding to the input images. 
- :type masks: list[numpy.ndarray] - """ - - for i in range(self.num_of_channels): - - masks_class = [] - - for mask in masks: - mask_class = mask.copy() - # set all instances in the instance mask not corresponding to the class in question to zero - mask_class[0][mask_class[1]!=(i+1)] = 0 - masks_class.append(mask_class) - - self.cellpose_models[i].train(imgs, masks_class) - - self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) - self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)]) - - - def eval(self, img): - """ Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of - each object is assigned based on majority voting between the models. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: predicted mask consists of instance and class masks - :rtype: numpy.ndarray - """ - - instance_masks, class_masks, model_confidences = [], [], [] - - for i in range(self.num_of_channels): - # get the instance mask and pixel-wise cell probability mask - instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) - confidence = probs[2] - # assign the appropriate class to all objects detected by this model - class_mask = np.zeros_like(instance_mask) - class_mask[instance_mask>0]=(i + 1) - - instance_masks.append(instance_mask) - class_masks.append(class_mask) - model_confidences.append(confidence) - # merge the outputs of the different models using the pixel-wise cell probability mask - merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences) - # set all connected components to the same label in the instance mask - instance_mask = label_mask(merged_mask_instances>0) - # and set the class with the most pixels to that object - for inst_id in np.unique(instance_mask)[1:]: - where_inst_id = np.where(instance_mask==inst_id) - vals, counts = 
np.unique(class_mask[where_inst_id], return_counts=True) - class_mask[where_inst_id] = vals[np.argmax(counts)] - # take the final mask by stancking instance and class mask - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) - - return final_mask - - def merge_masks(self, inst_masks, class_masks, probabilities): - """ Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model - with the maximum probability is selected for each pixel. - - :param inst_masks: List of predicted instance masks from each model. - :type inst_masks: List[np.array] - :param class_masks: List of corresponding class masks from each model. - :type class_masks: List[np.array] - :param probabilities: List of corresponding pixel-wise cell probability masks - :type probabilities: List[np.array] - :return: A tuple containing the following elements: - - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected - - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected - :rtype: tuple - """ - # Convert lists to numpy arrays - inst_masks = np.array(inst_masks) - class_masks = np.array(class_masks) - probabilities = np.array(probabilities) - - # Find the index of the mask with the maximum probability for each pixel - max_prob_indices = np.argmax(probabilities, axis=0) - - # Use the index to select the corresponding mask for each pixel - final_mask_inst = inst_masks[max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])] - final_mask_class = class_masks[max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])] - - return final_mask_inst, final_mask_class - - - - - - -# class CustomSAMModel(): -# # 
https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb -# def __init__(self): -# pass diff --git a/src/server/dcp_server/models/__init__.py b/src/server/dcp_server/models/__init__.py new file mode 100644 index 00000000..eba3d089 --- /dev/null +++ b/src/server/dcp_server/models/__init__.py @@ -0,0 +1,8 @@ +# dcp_server.models/__init__.py + +from .custom_cellpose import CustomCellpose +from .inst_to_multi_seg import Inst2MultiSeg +from .multicellpose import MultiCellpose +from .unet import UNet + +__all__ = ["CustomCellpose", "Inst2MultiSeg", "MultiCellpose", "UNet"] diff --git a/src/server/dcp_server/models/classifiers.py b/src/server/dcp_server/models/classifiers.py new file mode 100644 index 00000000..43fed489 --- /dev/null +++ b/src/server/dcp_server/models/classifiers.py @@ -0,0 +1,233 @@ +from tqdm import tqdm +from typing import List +import numpy as np + +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import F1Score + +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import f1_score, log_loss +from sklearn.exceptions import NotFittedError + + +class PatchClassifier(nn.Module): + """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Initialize the fully convolutional classifier. + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configuration. + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. 
+ :type eval_config: dict + """ + super().__init__() + + self.model_name = model_name + self.model_config = model_config["classifier"] + self.data_config = data_config + self.train_config = train_config["classifier"] + self.eval_config = eval_config["classifier"] + + self.build_model() + + def train(self, imgs: List[np.ndarray], labels: List[np.ndarray]) -> None: + """Trains the given model + + :param imgs: List of input images with shape (3, dx, dy). + :type imgs: List[np.ndarray[np.uint8]] + :param labels: List of classification labels. + :type labels: List[int] + """ + + # Convert input images and labels to tensors + imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) + imgs = torch.permute(imgs, (0, 3, 1, 2)) + # Your classification label mask + labels = torch.LongTensor([label for label in labels]) + + # Create a training dataset and dataloader + train_dataloader = DataLoader( + TensorDataset(imgs, labels), batch_size=self.train_config["batch_size"] + ) + + loss_fn = nn.CrossEntropyLoss() + optimizer = Adam(params=self.parameters(), lr=self.train_config["lr"]) + # optimizer_class = self.train_config["optimizer"] + # eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') + + # TODO check if we should replace self.parameters with super.parameters() + for _ in tqdm( + range(self.train_config["n_epochs"]), + desc="Running PatchClassifier training", + ): + + self.loss, self.metric = 0, 0 + for data in train_dataloader: + imgs, labels = data + + optimizer.zero_grad() + preds = self.forward(imgs) + + l = loss_fn(preds, labels) + l.backward() + optimizer.step() + self.loss += l.item() + + self.metric += self.metric_fn(preds, labels) + + self.loss /= len(train_dataloader) + self.metric /= len(train_dataloader) + + def eval(self, img: np.ndarray) -> torch.Tensor: + """Evaluates the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. 
+ :type img: np.ndarray[np.uint8] + :return: y_hat - predicted label. + :rtype: torch.Tensor + """ + # convert to tensor + img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze( + 0 + ) + preds = self.forward(img) + y_hat = torch.argmax(preds, 1) + return y_hat + + def build_model(self) -> None: + """Builds the PatchClassifer.""" + in_channels = self.model_config["in_channels"] + in_channels = ( + in_channels + 1 if self.model_config["include_mask"] else in_channels + ) + + self.layer1 = nn.Sequential( + nn.Conv2d(in_channels, 16, 3, 2, 5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer2 = nn.Sequential( + nn.Conv2d(16, 64, 3, 1, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer3 = nn.Sequential( + nn.Conv2d(64, 128, 3, 2, 4), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + self.final_conv = nn.Conv2d(128, self.model_config["num_classes"], 1) + self.pooling = nn.AdaptiveMaxPool2d(1) + + self.metric_fn = F1Score( + num_classes=self.model_config["num_classes"], task="multiclass" + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Performs forward pass of the PatchClassifier. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor after passing through the network. + :rtype: torch.Tensor + """ + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.final_conv(x) + x = self.pooling(x) + x = x.view(x.size(0), -1) + return x + + +class FeatureClassifier: + """This class implements a shallow model for cell classification using scikit-learn.""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the FeatureClassifier + + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configuration. 
+ :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + + self.model_name = model_name + self.model_config = model_config["classifier"] # use for initialising model + # self.data_config = data_config + # self.train_config = train_config + # self.eval_config = eval_config + + self.model = RandomForestClassifier( + **self.model_config + ) # TODO chnage config so RandomForestClassifier accepts input params + + def train(self, X_train: List[np.ndarray], y_train: List[np.ndarray]) -> None: + """Trains the model using the provided training data. + + :param X_train: Features of the training data. + :type X_train: numpy.ndarray + :param y_train: Labels of the training data. + :type y_train: numpy.ndarray + """ + self.model.fit(X_train, y_train) + + y_hat = self.model.predict(X_train) + y_hat_proba = self.model.predict_proba(X_train) + + # Binary Cross Entrop Loss + self.loss = log_loss(y_train, y_hat_proba) + self.metric = f1_score(y_train, y_hat, average="micro") + + def eval(self, X_test: np.ndarray) -> np.ndarray: + """Evaluates the model on the provided test data. + + :param X_test: Features of the test data. + :type X_test: numpy.ndarray + :return: y_hat - predicted labels. 
+ :rtype: numpy.ndarray + """ + + X_test = X_test.reshape(1, -1) + + try: + y_hat = self.model.predict(X_test) + except NotFittedError as e: + y_hat = np.zeros(X_test.shape[0]) + + return y_hat diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py new file mode 100644 index 00000000..b41d04bb --- /dev/null +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -0,0 +1,150 @@ +from copy import deepcopy +from typing import List +import numpy as np + +import torch +from torch import nn + +from cellpose import models, utils +from cellpose.metrics import aggregated_jaccard_index +from cellpose.dynamics import labels_to_flows + +from .model import Model + + +class CustomCellpose(models.CellposeModel, Model): + """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing + additional attributes and methods needed for this project. + """ + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the CustomCellpose. + The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts. + Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted. 
+ + :param model_name: The name of the current model + :type model_name: str + :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization + :type model_config: dict + :param data_config: dictionary passed from the config file with all the data configurations + :type data_config: dict + :param train_config: dictionary passed from the config file with all the arguments for training function + :type train_config: dict + :param eval_config: dictionary passed from the config file with all the arguments for eval function + :type eval_config: dict + """ + + # Initialize the cellpose model + # super().__init__(**model_config["segmentor"]) + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + models.CellposeModel.__init__(self, **model_config["segmentor"]) + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + self.model_name = model_name + self.mkldnn = False # otherwise we get error with saving model + self.loss = 1e6 + self.metric = 0 + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """Trains the given model + Calls the original train function. 
+ + :param imgs: images to train on (training data) + :type imgs: List[np.ndarray] + :param masks: masks of the given images (training labels) + :type masks: List[np.ndarray] + """ + if self.train_config["segmentor"]["n_epochs"] == 0: + return + super().train( + train_data=deepcopy(imgs), # Cellpose changes the images + train_labels=masks, + **self.train_config["segmentor"] + ) + pred_masks, pred_flows, true_flows = self.compute_masks_flows(imgs, masks) + # get loss, combination of mse for flows and bce for cell probability + self.loss = self.loss_fn(true_flows, pred_flows) + self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model - find mask of the given image + Calls the original eval function. + + :param img: image to evaluate on + :type img: np.ndarray + :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. + :rtype: np.ndarray + """ + # 0 to take only mask - inline with other models eval should always return the final mask + return super().eval(x=img, **self.eval_config["segmentor"])[0] + + def eval_all_outputs(self, img: np.ndarray) -> tuple: + """Get all outputs of the model when running eval. + + :param img: Input image for segmentation. + :type img: numpy.ndarray + :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details. 
+ :rtype: tuple + """ + + return super().eval(x=img, **self.eval_config["segmentor"]) + + # I introduced typing here as suggest by the docstring + def compute_masks_flows( + self, imgs: List[np.ndarray], masks: List[np.ndarray] + ) -> tuple: + """Computes instance, binary mask and flows in x and y - needed for loss and metric computations + + :param imgs: images to train on (training data) + :type imgs: List[np.ndarray] + :param masks: masks of the given images (training labels) + :type masks: List[np.ndarray] + :return: A tuple containing the following elements: + - pred_masks List [np.ndarray]: A list of predicted instance masks + - pred_flows (torch.Tensor): A tensor holding the stacked predicted cell probability map, horizontal and vertical flows for all images + - true_lbl (np.ndarray): A numpy array holding the stacked true binary mask, horizontal and vertical flows for all images + :rtype: tuple + """ + # compute for loss and metric + true_bin_masks = [mask > 0 for mask in masks] # get binary masks + true_flows = labels_to_flows(masks) # get cellpose flows + # get predicted flows and cell probability + pred_masks = [] + pred_flows = [] + true_lbl = [] + for idx, img in enumerate(imgs): + mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"]) + pred_masks.append(mask) + pred_flows.append( + np.stack([flows[1][0], flows[1][1], flows[2]]) + ) # stack cell probability map, horizontal and vertical flow + true_lbl.append( + np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]]) + ) + + true_lbl = np.stack(true_lbl) + pred_flows = np.stack(pred_flows) + pred_flows = torch.from_numpy(pred_flows).float().to("cpu") + return pred_masks, pred_flows, true_lbl + + def masks_to_outlines(self, mask: np.ndarray) -> np.ndarray: + """get outlines of masks as a 0-1 array + Calls the original cellpose.utils.masks_to_outlines function + + :param mask: int, 2D or 3D array, mask of an image + :type mask: ndarray + :return: outlines + :rtype: ndarray 
+ """ + return utils.masks_to_outlines(mask) # [True, False] outputs diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py new file mode 100644 index 00000000..43c3db01 --- /dev/null +++ b/src/server/dcp_server/models/inst_to_multi_seg.py @@ -0,0 +1,175 @@ +from typing import List + +import numpy as np +import torch + +from .model import Model +from .custom_cellpose import CustomCellpose +from dcp_server.models.classifiers import PatchClassifier, FeatureClassifier +from dcp_server.utils.processing import ( + get_centered_patches, + find_max_patch_size, + create_patch_dataset, + create_dataset_for_rf, +) + +# Dictionary mapping class names to their corresponding classes + +segmentor_mapping = {"Cellpose": CustomCellpose} +classifier_mapping = { + "PatchClassifier": PatchClassifier, + "RandomForest": FeatureClassifier, +} + + +class Inst2MultiSeg(Model): + """A two stage model for: 1. instance segmentation and 2. object wise classification""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the Inst2MultiSeg + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configurations + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. 
+ :type eval_config: dict + """ + # super().__init__() + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + + self.model_name = model_name + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + + self.segmentor_class = self.model_config.get("segmentor_name", "Cellpose") + self.classifier_class = self.model_config.get( + "classifier_name", "PatchClassifier" + ) + + # Initialize the cellpose model and the classifier + segmentor = segmentor_mapping.get(self.segmentor_class) + self.segmentor = segmentor( + self.segmentor_class, + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + classifier = classifier_mapping.get(self.classifier_class) + self.classifier = classifier( + self.classifier_class, + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + + # make sure include mask is set to False if we are using the random forest model + if self.classifier_class == "RandomForest": + if ( + "include_mask" not in self.model_config["classifier"].keys() + or self.model_config["classifier"]["include_mask"] is True + ): + # print("Include mask=True was found, but for Random Forest, this parameter must be set to False. Doing this now.") + self.model_config["classifier"]["include_mask"] = False + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """Trains the given model. First trains the segmentor and then the clasiffier. + + :param imgs: images to train on (training data) + :type imgs: List[np.ndarray] + :param masks: masks of the given images (training labels) + :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, + second channel classes, so [2, H, W] or [2, 3, H, W] for 3D. 
+ """ + # train cellpose + masks_instances = [mask[0] for mask in masks] + # masks_instances = list(np.array(masks)[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks + self.segmentor.train(imgs, masks_instances) + masks_classes = [mask[1] for mask in masks] + # create patch dataset to train classifier + # masks_classes = list( + # masks[:,1,...] + # ) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + x, patch_masks, labels = create_patch_dataset( + imgs, + masks_classes, + masks_instances, + noise_intensity=self.data_config["noise_intensity"], + max_patch_size=self.data_config["patch_size"], + include_mask=self.model_config["classifier"]["include_mask"], + ) + # additionally extract features from the patches if you are in RF model + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(x, patch_masks) + # train classifier + self.classifier.train(x, labels) + # and compute metric and loss + self.metric = (self.segmentor.metric + self.classifier.metric) / 2 + self.loss = (self.segmentor.loss + self.classifier.loss) / 2 + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image and return the final mask. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: Final mask containing instance mask and class masks. 
+ :rtype: np.ndarray[np.uint16] + """ + # TBD we assume image is 2D [H, W] (see fsimage storage) + # The final mask which is returned should have + # first channel the output of cellpose and the rest are the class channels + with torch.no_grad(): + # get instance mask from segmentor + instance_mask = self.segmentor.eval(img) + # find coordinates of detected objects + class_mask = np.zeros(instance_mask.shape) + + max_patch_size = self.data_config["patch_size"] + if max_patch_size is None: + max_patch_size = find_max_patch_size(instance_mask) + + # get patches centered around detected objects + x, patch_masks, instance_labels, _ = get_centered_patches( + img, + instance_mask, + max_patch_size, + noise_intensity=self.data_config["noise_intensity"], + include_mask=self.model_config["classifier"]["include_mask"], + ) + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(x, patch_masks) + # loop over patches and create classification mask + for idx in range(len(x)): + patch_class = self.classifier.eval(x[idx]) + # Assign predicted class to corresponding location in final_mask + patch_class = ( + patch_class.item() + if isinstance(patch_class, torch.Tensor) + else patch_class + ) + class_mask[instance_mask == instance_labels[idx]] = patch_class + 1 + # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 + final_mask = np.stack( + (instance_mask, class_mask), axis=self.eval_config["mask_channel_axis"] + ).astype( + np.uint16 + ) # size 2xHxW + + return final_mask diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py new file mode 100644 index 00000000..3cda12c1 --- /dev/null +++ b/src/server/dcp_server/models/model.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +from typing import List +import numpy as np + + +class Model(ABC): + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + + 
self.model_name = model_name + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + + self.loss = 1e6 + self.metric = 0 + + @abstractmethod + def train(self, imgs: List[np.array], masks: List[np.array]) -> None: + pass + + @abstractmethod + def eval(self, img: np.array) -> np.array: + pass + + ''' + def update_configs(self, + config: dict, + ctype: str + ) -> None: + """ Update the training or evaluation configurations. + + :param config: Dictionary containing the updated configuration. + :type config: dict + :param ctype:type of config to update, will be train or eval + :type ctype: str + """ + if ctype=='train': self.train_config = config + else: self.eval_config = config + ''' + + +# from segment_anything import SamPredictor, sam_model_registry +# from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator +# class CustomSAMModel(): +# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb +# def __init__(self): +# pass diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py new file mode 100644 index 00000000..5ece6b97 --- /dev/null +++ b/src/server/dcp_server/models/multicellpose.py @@ -0,0 +1,165 @@ +from typing import List +import numpy as np +from skimage.measure import label as label_mask + +from .model import Model +from .custom_cellpose import CustomCellpose + + +class MultiCellpose(Model): + """ + Multichannel image segmentation model. + Run the separate CustomCellpose models for each channel return the mask corresponding to each object type. + """ + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the MultiCellpose model. + + :param model_name: Name of the model. 
+ :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + self.model_name = model_name + self.num_of_channels = self.model_config["classifier"]["num_classes"] + + self.cellpose_models = [ + CustomCellpose( + "Cellpose", + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + for _ in range(self.num_of_channels) + ] + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """ + Train one CustomCellpose model per class on the provided images and masks, then average their metric and loss. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images; index 0 of each mask is read as the instance mask and index 1 as the class mask. + :type masks: list[numpy.ndarray] + """ + + for i in range(self.num_of_channels): + + masks_class = [] + + for mask in masks: + mask_class = mask[0].copy() # copy the instance mask so the caller's array is left untouched + # set all instances in the instance mask not corresponding to the class in question to zero + mask_class[mask[1] != (i + 1)] = 0 + masks_class.append(mask_class) + self.cellpose_models[i].train(imgs, masks_class) + + self.metric = np.mean( + [self.cellpose_models[i].metric for i in range(self.num_of_channels)] + ) + self.loss = np.mean( + [self.cellpose_models[i].loss for i in range(self.num_of_channels)] + ) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of + each object is assigned based on majority voting between the models. + + :param img: Input image for evaluation. 
+ :type img: np.ndarray[np.uint8] + :return: predicted mask: the instance mask and the class mask stacked along eval_config["mask_channel_axis"] and cast to uint16 + :rtype: numpy.ndarray + """ + + instance_masks, class_masks, model_confidences = [], [], [] + + for i in range(self.num_of_channels): + # get the instance mask and pixel-wise cell probability mask + instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) + confidence_map = probs[2] + # assign the appropriate class to all objects detected by this model + class_mask = np.zeros_like(instance_mask) + class_mask[instance_mask > 0] = i + 1 + + instance_masks.append(instance_mask) + class_masks.append(class_mask) + model_confidences.append(confidence_map) + # merge the outputs of the different models using the pixel-wise cell probability mask + merged_mask_instances, class_mask = self.merge_masks( + instance_masks, class_masks, model_confidences + ) + # set all connected components to the same label in the instance mask + instance_mask = label_mask(merged_mask_instances > 0) + # and set the class with the most pixels to that object + for inst_id in np.unique(instance_mask)[1:]: + where_inst_id = np.where(instance_mask == inst_id) + vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) + class_mask[where_inst_id] = vals[np.argmax(counts)] + # build the final mask by stacking the instance and class masks + final_mask = np.stack( + (instance_mask, class_mask), axis=self.eval_config["mask_channel_axis"] + ).astype(np.uint16) + + return final_mask + + def merge_masks( + self, + inst_masks: List[np.ndarray], + class_masks: List[np.ndarray], + probabilities: List[np.ndarray], + ) -> tuple: + """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model + with the maximum probability is selected for each pixel. + + :param inst_masks: List of predicted instance masks from each model. 
+ :type inst_masks: List[np.array] + :param class_masks: List of corresponding class masks from each model. + :type class_masks: List[np.array] + :param probabilities: List of corresponding pixel-wise cell probability masks + :type probabilities: List[np.array] + :return: A tuple containing the following elements: + - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected + - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected + :rtype: tuple + """ + # Convert lists to numpy arrays + inst_masks = np.array(inst_masks) + class_masks = np.array(class_masks) + probabilities = np.array(probabilities) + + # Find the index of the mask with the maximum probability for each pixel + max_prob_indices = np.argmax(probabilities, axis=0) + + # Use the index to select the corresponding mask for each pixel + final_mask_inst = inst_masks[ + max_prob_indices, + np.arange(inst_masks.shape[1])[:, None], + np.arange(inst_masks.shape[2]), + ] + final_mask_class = class_masks[ + max_prob_indices, + np.arange(class_masks.shape[1])[:, None], + np.arange(class_masks.shape[2]), + ] + + return final_mask_inst, final_mask_class diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py new file mode 100644 index 00000000..9d85a5f7 --- /dev/null +++ b/src/server/dcp_server/models/unet.py @@ -0,0 +1,235 @@ +from typing import List +from tqdm import tqdm +import numpy as np +from scipy.ndimage import label + +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import JaccardIndex + +from .model import Model +from dcp_server.utils.processing import convert_to_tensor + + +class UNet(nn.Module, Model): + """ + Unet is a convolutional neural network architecture for semantic segmentation. 
+ + :param in_channels: Number of input channels (default: 3). + :type in_channels: int + :param out_channels: Number of output channels (default: 4). + :type out_channels: int + :param features: List of feature channels for each encoder level (default: [64,128,256,512]). + :type features: list + """ + + class DoubleConv(nn.Module): + """ + DoubleConv module consists of two consecutive convolutional layers with + batch normalization and ReLU activation functions. + """ + + def __init__(self, in_channels: int, out_channels: int) -> None: + """ + Initialize DoubleConv module. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + """ + + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the DoubleConv module. + + :param x: Input tensor. + :type x: torch.Tensor + """ + return self.conv(x) + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the UNet model. + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configurations + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. 
+ :type eval_config: dict + """ + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) + nn.Module.__init__(self) + # super().__init__() + + self.model_name = model_name + self.model_config = model_config + self.data_config = data_config + self.train_config = train_config + self.eval_config = eval_config + + self.loss = 1e6 + self.metric = 0 + self.num_classes = self.model_config["classifier"]["num_classes"] + 1 + self.metric_f = JaccardIndex( + task="multiclass", num_classes=self.num_classes, average="macro", ignore_index=0 + ) + + self.build_model() + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """ + Trains the UNet model using the provided images and masks, then stores the mean loss of the last epoch and the mean metric over the training batches. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images; only index 1 (the class mask) of each mask is used. + :type masks: list[numpy.ndarray] + """ + + imgs = convert_to_tensor(imgs, np.float32) + masks = convert_to_tensor( + [mask[1] for mask in masks], np.int16, unsqueeze=False + ) + + # Create a training dataset and dataloader + train_dataloader = DataLoader( + TensorDataset(imgs, masks), + batch_size=self.train_config["classifier"]["batch_size"], + ) + + loss_fn = nn.CrossEntropyLoss() + optimizer = Adam( + params=self.parameters(), lr=self.train_config["classifier"]["lr"] + ) + + for _ in tqdm( + range(self.train_config["classifier"]["n_epochs"]), + desc="Running UNet training", + ): + + self.loss = 0 + + for imgs, masks in train_dataloader: + # forward path + preds = self.forward(imgs.float()) + loss = loss_fn(preds, masks.long()) + + # backward path + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.loss += loss.detach().mean().item() + + self.loss /= len(train_dataloader) + + self.metric = 0 # reset so a previous run's value is not averaged in; the metric below is computed on the training batches + for imgs, masks in train_dataloader: + pred_masks = self.forward(imgs.float()) + self.metric += self.metric_f(pred_masks, masks) + self.metric /= 
len(train_dataloader) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: predicted mask consists of instance and class masks + :rtype: numpy.ndarray + """ + with torch.no_grad(): + + # img = torch.from_numpy(img).float().unsqueeze(0) + # img = img.unsqueeze(1) if img.ndim == 3 else img + img = convert_to_tensor([img], np.float32) + + preds = self.forward(img) + class_mask = torch.argmax(preds, 1).numpy()[0] + if self.eval_config["compute_instance"] is True: + instance_mask = label((class_mask > 0).astype(int))[0] + final_mask = np.stack( + [instance_mask, class_mask], + axis=self.eval_config["mask_channel_axis"], + ).astype(np.uint16) + else: + final_mask = class_mask.astype(np.uint16) + + return final_mask + + def build_model(self) -> None: + """Builds the UNet.""" + in_channels = self.model_config["classifier"]["in_channels"] + out_channels = self.num_classes + features = self.model_config["classifier"]["features"] + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Encoder + for feature in features: + self.encoder.append(UNet.DoubleConv(in_channels, feature)) + in_channels = feature + + # Decoder + for feature in features[::-1]: + self.decoder.append( + nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2) + ) + self.decoder.append(UNet.DoubleConv(feature * 2, feature)) + + self.bottle_neck = UNet.DoubleConv(features[-1], features[-1] * 2) + self.output_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the UNet model. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor. 
+ :rtype: torch.Tensor + """ + skip_connections = [] + for encoder in self.encoder: + x = encoder(x) + skip_connections.append(x) + x = self.pool(x) + + x = self.bottle_neck(x) + skip_connections = skip_connections[::-1] + + for i in np.arange(len(self.decoder), step=2): + x = self.decoder[i](x) + skip_connection = skip_connections[i // 2] + concatenate_skip = torch.cat((skip_connection, x), dim=1) + x = self.decoder[i + 1](concatenate_skip) + + return self.output_conv(x) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index f7b66245..b3897ff7 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -1,74 +1,85 @@ -from dcp_server import utils import os -# Import configuration -dirname = os.path.dirname(__file__) -setup_config = utils.read_config('setup', config_path = os.path.join(dirname, 'config.cfg')) +from dcp_server.utils import helpers +from dcp_server.utils.fsimagestorage import FilesystemImageStorage +from dcp_server import models as DCPModels -class GeneralSegmentation(): - """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. - """ - def __init__(self, imagestorage, runner, model): - """Constructs all the necessary attributes for the GeneralSegmentation. + +class GeneralSegmentation: + """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images.""" + + def __init__( + self, imagestorage: FilesystemImageStorage, runner, model: DCPModels + ) -> None: + """Constructs all the necessary attributes for the GeneralSegmentation. 
:param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object :param runner: runner used in the service :type runner: CustomRunnable class object - :param model: model used for segmentation + :param model: model used for segmentation :type model: class object from the models.py - """ + """ self.imagestorage = imagestorage - self.runner = runner + self.runner = runner self.model = model self.no_files_msg = "No image-label pairs found in curated directory" - - async def segment_image(self, input_path, list_of_images): + + async def segment_image(self, input_path: str, list_of_images: str) -> None: """Segments images from the given directory - :param input_path: directory where the images are saved + :param input_path: directory where the images are saved and where segmentation results will be saved :type input_path: str :param list_of_images: list of image objects from the directory that are currently supported :type list_of_images: list - """ + """ for img_filepath in list_of_images: - # Load the image - img = self.imagestorage.load_image(img_filepath) - # Get size properties - height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) - img = self.imagestorage.rescale_image(img, height, width) + img = self.imagestorage.prepare_img_for_eval(img_filepath) # Add channel ax into the model's evaluation parameters dictionary - self.model.eval_config['segmentor']['z_axis'] = z_axis + if self.imagestorage.model_used != "UNet": + self.model.eval_config["segmentor"][ + "channel_axis" + ] = self.imagestorage.channel_ax # Evaluate the model - mask = await self.runner.evaluate.async_run(img = img) - # Resize the mask - mask = self.imagestorage.resize_mask(mask, height, width, self.model.eval_config['mask_channel_axis'], order=0) + mask = await self.runner.evaluate.async_run(img=img) + # And prepare the mask for saving + mask = 
self.imagestorage.prepare_mask_for_save( + mask, self.model.eval_config["mask_channel_axis"] + ) # Save segmentation - seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = ( + helpers.get_path_stem(img_filepath) + + self.imagestorage.seg_name_string + + ".tiff" + ) self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) - async def train(self, input_path): - """train model on images and masks in the given input directory. + async def train(self, input_path: str) -> str: + """Train model on images and masks in the given input directory. Calls the runner's train function. :param input_path: directory where the images are saved :type input_path: str :return: runner's train function output - path of the saved model :rtype: str - """ + """ train_img_mask_pairs = self.imagestorage.get_image_seg_pairs(input_path) if not train_img_mask_pairs: return self.no_files_msg - - imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs) - model_save_path = await self.runner.train.async_run(imgs, masks) + + imgs, masks = self.imagestorage.prepare_images_and_masks_for_training( + train_img_mask_pairs + ) + model_save_path = await self.runner.train.async_run(imgs, masks) return model_save_path +''' + class GFPProjectSegmentation(GeneralSegmentation): def __init__(self, imagestorage, runner): super().__init__(imagestorage, runner) @@ -79,11 +90,11 @@ async def segment_image(self, input_path, list_of_images): class MitoProjectSegmentation(GeneralSegmentation): - """Segmentation class inheriting the attributes and functions from the original GeneralSegmentation and implementing + """ Segmentation class inheriting the attributes and functions from the original GeneralSegmentation and implementing additional attributes and methods needed for this project. """ def __init__(self, imagestorage, runner, model): - """Constructs all the necessary attributes for the MitoProjectSegmentation. 
Inherits all from the GeneralSegmentation + """ Constructs all the necessary attributes for the MitoProjectSegmentation. Inherits all from the GeneralSegmentation :param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object @@ -96,7 +107,7 @@ def __init__(self, imagestorage, runner, model): # The only difference is in segment image async def segment_image(self, input_path, list_of_images): - """Segments images from the given directory. + """ Segments images from the given directory. The function differs from the parent class' function in obtaining the outlines of the masks. :param input_path: directory where the images are saved @@ -109,7 +120,7 @@ async def segment_image(self, input_path, list_of_images): # Load the image img = self.imagestorage.load_image(img_filepath) # Get size properties - height, width, channel_ax = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) + height, width, channel_ax = self.imagestorage.get_image_size_properties(img, helpers.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, width, channel_ax) # Add channel ax into the model's evaluation parameters dictionary @@ -129,5 +140,6 @@ async def segment_image(self, input_path, list_of_images): new_mask[outlines==True] = 1 # Save segmentation - seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), new_mask) +''' diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index 0be4fb0d..d464545b 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -1,9 +1,10 @@ from __future__ import annotations import bentoml import typing as t -from dcp_server.fsimagestorage import FilesystemImageStorage from 
dcp_server.serviceclasses import CustomBentoService, CustomRunnable -from dcp_server.utils import read_config + +from dcp_server.utils.fsimagestorage import FilesystemImageStorage +from dcp_server.utils.helpers import read_config import sys, inspect @@ -11,29 +12,46 @@ segmentation_module = __import__("segmentationclasses") # Import configuration -service_config = read_config('service', config_path = 'config.cfg') -model_config = read_config('model', config_path = 'config.cfg') -data_config = read_config('data', config_path = 'config.cfg') -train_config = read_config('train', config_path = 'config.cfg') -eval_config = read_config('eval', config_path = 'config.cfg') -setup_config = read_config('setup', config_path = 'config.cfg') +service_config = read_config("service", config_path="config.yaml") +model_config = read_config("model", config_path="config.yaml") +data_config = read_config("data", config_path="config.yaml") +train_config = read_config("train", config_path="config.yaml") +eval_config = read_config("eval", config_path="config.yaml") +setup_config = read_config("setup", config_path="config.yaml") # instantiate the model -model_class = getattr(models_module, setup_config['model_to_use']) -model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config, model_name=setup_config['model_to_use']) +model_class = getattr(models_module, setup_config["model_to_use"]) +model = model_class( + model_name=setup_config["model_to_use"], + model_config=model_config, + data_config=data_config, + train_config=train_config, + eval_config=eval_config, +) custom_model_runner = t.cast( - "CustomRunner", bentoml.Runner(CustomRunnable, name=service_config['runner_name'], - runnable_init_params={"model": model, "save_model_path": service_config['bento_model_path']}) + "CustomRunner", + bentoml.Runner( + CustomRunnable, + name=service_config["runner_name"], + runnable_init_params={ + "model": model, + "save_model_path": 
service_config["bento_model_path"], + }, + ), ) # instantiate the segmentation type -segm_class = getattr(segmentation_module, setup_config['segmentation']) -fsimagestorage = FilesystemImageStorage(data_config['data_root'], setup_config['model_to_use']) -segmentation = segm_class(imagestorage=fsimagestorage, - runner = custom_model_runner, - model = model) +segm_class = getattr(segmentation_module, setup_config["segmentation"]) +fsimagestorage = FilesystemImageStorage(data_config, setup_config["model_to_use"]) +segmentation = segm_class( + imagestorage=fsimagestorage, runner=custom_model_runner, model=model +) # Call the service -service = CustomBentoService(runner=segmentation.runner, segmentation=segmentation, service_name=service_config['service_name']) -svc = service.start_service() \ No newline at end of file +service = CustomBentoService( + runner=segmentation.runner, + segmentation=segmentation, + service_name=service_config["service_name"], +) +svc = service.start_service() diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index c66b81c8..bb1b8e30 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -5,24 +5,26 @@ from typing import List from dcp_server import models as DCPModels +import dcp_server.segmentationclasses as DCPSegClasses class CustomRunnable(bentoml.Runnable): - ''' + """ BentoML, Runner represents a unit of computation that can be executed on a remote Python worker and scales independently. CustomRunnable is a custom runner defined to meet all the requirements needed for this project. - ''' - SUPPORTED_RESOURCES = ("cpu",) #TODO add here? + """ + + SUPPORTED_RESOURCES = ("cpu",) # TODO add here? SUPPORTS_CPU_MULTI_THREADING = False - def __init__(self, model, save_model_path): + def __init__(self, model: DCPModels, save_model_path: str) -> None: """Constructs all the necessary attributes for the CustomRunnable. 
:param model: model to be trained or evaluated - will be one of classes in models.py :param save_model_path: full path of the model object that it will be saved into :type save_model_path: str - """ - + """ + self.model = model self.save_model_path = save_model_path # update with the latest model if it already exists to continue training from there? @@ -44,16 +46,20 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: mask = self.model.eval(img=img) return mask - - def check_and_load_model(self): + + def check_and_load_model(self) -> None: """Checks if the specified model exists in BentoML's model repository. - If the model exists, it loads the latest version of the model into - memory. + If the model exists, it loads the latest version of the model into + memory. """ bento_model_list = [model.tag.name for model in bentoml.models.list()] if self.save_model_path in bento_model_list: - loaded_model = bentoml.picklable_model.load_model(self.save_model_path+":latest") - assert loaded_model.__class__.__name__ == self.model.__class__.__name__, 'Check your config, loaded model and model to use not the same!' + loaded_model = bentoml.picklable_model.load_model( + self.save_model_path + ":latest" + ) + assert ( + loaded_model.__class__.__name__ == self.model.__class__.__name__ + ), "Check your config, loaded model and model to use not the same!" 
self.model = loaded_model @bentoml.Runnable.method(batchable=False) @@ -66,14 +72,14 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :type masks: List[np.ndarray] :return: path of the saved model :rtype: str - """ + """ self.model.train(imgs, masks) # Save the bentoml model bentoml.picklable_model.save_model( - self.save_model_path, + self.save_model_path, self.model, external_modules=[DCPModels], - ) + ) # bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store # self.model, # Model instance being saved # external_modules=[DCPModels] @@ -81,42 +87,49 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: return self.save_model_path -class CustomBentoService(): - """BentoML Service class. Contains all the functions necessary to serve the service with BentoML - """ - def __init__(self, runner, segmentation, service_name): + +class CustomBentoService: + """BentoML Service class. Contains all the functions necessary to serve the service with BentoML""" + + def __init__( + self, runner: CustomRunnable, segmentation: DCPSegClasses, service_name: str + ) -> None: """Constructs all the necessary attributes for the class CustomBentoService(): :param runner: runner used in the service :type runner: CustomRunnable class object :param segmentation: segmentation type used in the service :type segmentation: segmentation class object from the segmentationclasses.py - :param service_name: name of the service + :param service_name: name of the service :type service_name: str - """ + """ self.runner = runner self.segmentation = segmentation self.service_name = service_name - def start_service(self): + def start_service(self) -> bentoml.Service: """Starts the service :return: service object needed in service.py and for the bentoml serve call. 
- """ + """ svc = bentoml.Service(self.service_name, runners=[self.runner]) - @svc.api(input=Text(), output=NumpyNdarray()) #input path to the image output message with success and the save path - async def segment_image(input_path: str): + @svc.api( + input=Text(), output=NumpyNdarray() + ) # input path to the image output message with success and the save path + async def segment_image(input_path: str) -> np.ndarray: """function served within the service, used to segment images :param input_path: directory where the images for segmentation are saved :type input_path: str :return: list of files not supported :rtype: ndarray - """ + """ list_of_images = self.segmentation.imagestorage.search_images(input_path) - list_of_files_not_suported = self.segmentation.imagestorage.get_unsupported_files(input_path) - + list_of_files_not_suported = ( + self.segmentation.imagestorage.get_unsupported_files(input_path) + ) + if not list_of_images: return np.array(list_of_images) else: @@ -125,20 +138,19 @@ async def segment_image(input_path: str): return np.array(list_of_files_not_suported) @svc.api(input=Text(), output=Text()) - async def train(input_path): + async def train(input_path: str) -> str: """function served within the service, used to retrain the model :param input_path: directory where the images for training are saved :type input_path: str :return: message of success if training went well :rtype: str - """ + """ print("Calling retrain from server.") # Train the model msg = await self.segmentation.train(input_path) - if msg!=self.segmentation.no_files_msg: + if msg != self.segmentation.no_files_msg: msg = "Success! 
Trained model saved in: " + msg return msg - + return svc - diff --git a/src/server/dcp_server/utils/__init__.py b/src/server/dcp_server/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py new file mode 100644 index 00000000..d89025b3 --- /dev/null +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -0,0 +1,315 @@ +import os +from typing import Optional, List +import numpy as np +from skimage.io import imread, imsave +from skimage.transform import resize, rescale + +from dcp_server.utils import helpers +from dcp_server.utils.processing import pad_image, normalise + + +class FilesystemImageStorage: + """ + Class used to deal with everything related to image storing and processing - loading, saving, transforming. + """ + + def __init__(self, data_config: dict, model_used: str) -> None: + self.root_dir = data_config["data_root"] + self.seg_name_string = data_config["seg_name_string"] + self.accepted_types = data_config["accepted_types"] + self.gray = bool(data_config["gray"]) + self.rescale = bool(data_config["rescale"]) + self.model_used = model_used + self.channel_ax = None + self.img_height = None + self.img_width = None + + def load_image( + self, cur_selected_img: str, gray: Optional[bool] = None + ) -> Optional[np.ndarray]: + """Load the image (using skimage); the path is joined onto the storage root directory. + + :param cur_selected_img: full path of the image that needs to be loaded + :type cur_selected_img: str + :param gray: whether to load the image as a grayscale or not; when None, the value configured on the storage (self.gray) is used + :type gray: bool or None, default=None + :return: loaded image, or None if reading the file raises a ValueError + :rtype: ndarray or None + """ + if gray is None: + gray = self.gray + try: + return imread(os.path.join(self.root_dir, cur_selected_img), as_gray=gray) + except ValueError: + return None + + def save_image(self, to_save_path: str, img: np.ndarray) -> None: + """Save given image using skimage. 
+ + :param to_save_path: full path to the directory that the image needs to be save into (use also image name in the path, eg. '/users/new_image.png') + :type to_save_path: str + :param img: image you wish to save + :type img: ndarray + """ + imsave(os.path.join(self.root_dir, to_save_path), img) + + def search_images(self, directory: str) -> List[str]: + """Get a list of full paths of the images in the directory. + + :param directory: Path to the directory to search for images. + :type directory: str + :return: List of image paths found in the directory (only image types that are supported - see config.cfg 'setup' section). + :rtype: list + """ + # Take all segmentations of the image from the current directory: + directory = os.path.join(self.root_dir, directory) + seg_files = [ + file_name + for file_name in os.listdir(directory) + if self.seg_name_string in file_name + ] + # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted + image_files = [ + os.path.join(directory, file_name) + for file_name in os.listdir(directory) + if (file_name not in seg_files) + and (helpers.get_file_extension(file_name) in self.accepted_types) + ] + return image_files + + def search_segs(self, cur_selected_img: str) -> List[str]: + """Returns a list of full paths of segmentations for an image. + + :param cur_selected_img: Full path of the image for which segmentations are needed. + :type cur_selected_img: str + :return: List of segmentation paths for the given image. 
+ :rtype: list + """ + + # Check the directory the image was selected from: + img_directory = helpers.get_path_parent( + os.path.join(self.root_dir, cur_selected_img) + ) + # Take all segmentations of the image from the current directory: + search_string = helpers.get_path_stem(cur_selected_img) + self.seg_name_string + # seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] + # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) + + seg_files = [ + os.path.join(img_directory, file_name) + for file_name in os.listdir(img_directory) + if ( + search_string == helpers.get_path_stem(file_name) + or str(file_name).startswith(search_string) + ) + ] + + return seg_files + + def get_image_seg_pairs(self, directory: str) -> List[tuple]: + """Get pairs of (image, image_seg). + + Used, e.g., in training to create training data-training labels pairs. + + :param directory: Path to the directory to search images and segmentations in. + :type directory: str + :return: List of tuple pairs (image, image_seg). + :rtype: list + """ + + image_files = self.search_images(os.path.join(self.root_dir, directory)) + seg_files = [] + for image in image_files: + seg = self.search_segs(image) + # TODO - the search seg returns all the segs, but here we need only one, hence the seg[0]. Check if it is from training path? + seg_files.append(seg[0]) + return list(zip(image_files, seg_files)) + + def get_unsupported_files(self, directory: str) -> List[str]: + """Get unsupported files found in the given directory. + + :param directory: Directory path to search for files in. + :type directory: str + :return: List of unsupported files. 
+ :rtype: list + """ + return [ + file_name + for file_name in os.listdir(os.path.join(self.root_dir, directory)) + if not file_name.startswith(".") + and helpers.get_file_extension(file_name) not in self.accepted_types + ] + + def get_image_size_properties(self, img: np.ndarray, file_extension: str) -> None: + """Set properties of the image size + + :param img: Image (numpy array). + :type img: ndarray + :param file_extension: File extension of the image as saved in the directory. + :type file_extension: str + """ + # TODO simplify! + + orig_size = img.shape + # png and jpeg will be RGB by default and 2D + # tif can be grayscale 2D or 3D [Z, H, W] + # image channels have already been removed in imread if self.gray=True + # skimage.imread reads RGB or RGBA images in always with channel axis in dim=2 + if file_extension in (".jpg", ".jpeg", ".png") and self.gray == False: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = 2 + elif file_extension in (".jpg", ".jpeg", ".png") and self.gray == True: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = None + elif file_extension in (".tiff", ".tif") and len(orig_size) == 2: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = None + # if we have 3 dimensions the [Z, H, W] + elif file_extension in (".tiff", ".tif") and len(orig_size) == 3: + print( + "Warning: 3D image stack found. We are assuming your last dimension is your channel dimension. Please cross check this." + ) + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = 2 + else: + print("File not currently supported. See documentation for accepted types") + + def rescale_image(self, img: np.ndarray, order: int = 2) -> np.ndarray: + """rescale image + + :param img: Image. + :type img: ndarray + :param order: Order of interpolation. + :type order: int + :return: Rescaled image. 
+ :rtype: ndarray + """ + + if self.model_used == "UNet": + return pad_image( + img, self.img_height, self.img_width, self.channel_ax, dividable=16 + ) + else: + # Cellpose segmentation runs best with 512 size? TODO: check + max_dim = max(self.img_height, self.img_width) + rescale_factor = max_dim / 512 + return rescale( + img, 1 / rescale_factor, order=order, channel_axis=self.channel_ax + ) + + def resize_mask( + self, mask: np.ndarray, channel_ax: Optional[int] = None, order: int = 0 + ) -> np.ndarray: + """resize the mask so it matches the original image size + + :param mask: Image. + :type mask: ndarray + :param height: Height of the image. + :type height: int + :param width: Width of the image. + :type width: int + :param order: From scikit-image - the order of the spline interpolation. Default is 0 if image.dtype is bool and 1 otherwise. + :type order: int + :return: Resized image. + :rtype: ndarray + """ + + if self.model_used == "UNet": + # we assume an order C, H, W + if channel_ax is not None and channel_ax == 0: + height_pad = mask.shape[1] - self.img_height + width_pad = mask.shape[2] - self.img_width + return mask[:, :-height_pad, :-width_pad] + elif channel_ax is not None and channel_ax == 2: + height_pad = mask.shape[0] - self.img_height + width_pad = mask.shape[1] - self.img_width + return mask[:-height_pad, :-width_pad, :] + elif channel_ax is not None and channel_ax == 1: + height_pad = mask.shape[2] - self.img_height + width_pad = mask.shape[0] - self.img_width + return mask[:-width_pad, :, :-height_pad] + else: + height_pad = mask.shape[0] - self.img_height + width_pad = mask.shape[1] - self.img_width + return mask[:-height_pad, :-width_pad] + + else: + if channel_ax is not None: + n_channel_dim = mask.shape[channel_ax] + output_size = [self.img_height, self.img_width] + output_size.insert(channel_ax, n_channel_dim) + else: + output_size = [self.img_height, self.img_width] + return resize(mask, output_size, order=order) + + def 
prepare_images_and_masks_for_training( + self, train_img_mask_pairs: List[tuple] + ) -> tuple: + """Image and mask processing for training. + + :param train_img_mask_pairs: List pairs of (image, image_seg) (as returned by get_image_seg_pairs() function). + :type train_img_mask_pairs: list + :return: Lists of processed images and masks. + :rtype: tuple + """ + + imgs = [] + masks = [] + for img_file, mask_file in train_img_mask_pairs: + img = self.load_image(img_file) + img = normalise(img) + mask = self.load_image(mask_file, gray=False) + self.get_image_size_properties(img, helpers.get_file_extension(img_file)) + # Unet only accepts image sizes divisable by 16 + if self.model_used == "UNet": + img = pad_image( + img, + self.img_height, + self.img_width, + channel_ax=self.channel_ax, + dividable=16, + ) + mask = pad_image( + mask, self.img_height, self.img_width, channel_ax=0, dividable=16 + ) + if self.model_used == "CustomCellpose" and len(mask.shape) == 3: + # if we also have class mask drop it + mask = masks[0] # assuming mask_channel_axis=0 + imgs.append(img) + masks.append(mask) + return imgs, masks + + def prepare_img_for_eval(self, img_file: str) -> np.ndarray: + """Image processing for model inference. + + :param img_file: the path to the image + :type img_file: str + :return: the loaded and processed image + :rtype: np.ndarray + """ + # Load and normalise the image + img = self.load_image(img_file) + img = normalise(img) + # Get size properties + self.get_image_size_properties(img, helpers.get_file_extension(img_file)) + if self.rescale: + img = self.rescale_image(img) + return img + + def prepare_mask_for_save(self, mask: np.ndarray, channel_ax: int) -> np.ndarray: + """Prepares the mask output of the model to be saved. 
+ + :param mask: the mask + :type mask: np.ndarray + :param channel_ax: the channel dimension of the mask + :rype channel_ax: int + :return: the ready to save mask + :rtype: np.ndarray + """ + # Resize the mask if rescaling took place before + if self.rescale is True: + if len(mask.shape) < 3: + channel_ax = None + return self.resize_mask(mask, channel_ax) + else: + return mask diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py new file mode 100644 index 00000000..b4cb15c6 --- /dev/null +++ b/src/server/dcp_server/utils/helpers.py @@ -0,0 +1,46 @@ +from pathlib import Path +import yaml + + +def read_config(name: str, config_path: str) -> dict: + """Reads the configuration file + + :param name: name of the section you want to read (e.g. 'setup','train') + :type name: string + :param config_path: path to the configuration file + :type config_path: str + :return: dictionary from the config section given by name + :rtype: dict + """ + with open(config_path) as config_file: + config_dict = yaml.safe_load( + config_file + ) # json.load(config_file) for .cfg file + # Check if config file has main mandatory keys + assert all( + [ + i in config_dict.keys() + for i in ["setup", "service", "model", "train", "eval"] + ] + ) + return config_dict[name] + + +def get_path_stem(filepath: str) -> str: + return str(Path(filepath).stem) + + +def get_path_name(filepath: str) -> str: + return str(Path(filepath).name) + + +def get_path_parent(filepath: str) -> str: + return str(Path(filepath).parent) + + +def join_path(root_dir: str, filepath: str) -> str: + return str(Path(root_dir, filepath)) + + +def get_file_extension(file: str) -> str: + return str(Path(file).suffix) diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils/processing.py similarity index 50% rename from src/server/dcp_server/utils.py rename to src/server/dcp_server/utils/processing.py index a952e73f..9c7f4b03 100644 --- a/src/server/dcp_server/utils.py +++ 
b/src/server/dcp_server/utils/processing.py @@ -1,93 +1,91 @@ -from pathlib import Path -import json from copy import deepcopy +from typing import List, Optional, Union import numpy as np + from scipy.ndimage import find_objects from skimage import measure -from copy import deepcopy import SimpleITK as sitk -from radiomics import shape2D - -def read_config(name, config_path = 'config.cfg') -> dict: - """ Reads the configuration file - - :param name: name of the section you want to read (e.g. 'setup','train') - :type name: string - :param config_path: path to the configuration file, defaults to 'config.cfg' - :type config_path: str, optional - :return: dictionary from the config section given by name - :rtype: dict - """ - with open(config_path) as config_file: - config_dict = json.load(config_file) - # Check if config file has main mandatory keys - assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']]) - return config_dict[name] - -def get_path_stem(filepath): - """ Returns the stem of a file path. - - :param filepath: The file path. - :type filepath: str - :return: The stem of the file path. - :rtype: str - """ - return str(Path(filepath).stem) +from radiomics import shape2D +import torch -def get_path_name(filepath): - """ Returns the name of a file or directory from a given path. +def normalise(img: np.ndarray, norm: str = "min-max") -> np.ndarray: + """Normalises the image based on the chosen method. Currently available methods are: + - min max normalisation. - :param filepath: The file path. - :type filepath: str - :return: The name of the file or directory. 
- :rtype: str + :param img: image to be normalised + :type img: np.ndarray + :param norm: the normalisation method to apply + :type norm: str + :return: the normalised image + :rtype: np.ndarray """ - return str(Path(filepath).name) + if norm == "min-max": + return (img - np.min(img)) / (np.max(img) - np.min(img)) -def get_path_parent(filepath): - """ Returns the parent directory of a file or directory from a given path. +def pad_image( + img: np.ndarray, + height: int, + width: int, + channel_ax: Optional[int] = None, + dividable: int = 16, +) -> np.ndarray: + """Pads the image such that it is dividable by a given number. - :param filepath: The file path. - :type filepath: str - :return: The parent directory of the file or directory. - :rtype: str + :param img: image to be padded + :type img: np.ndarray + :param height: image height + :type height: int + :param width: image width + :type width: int + :param channel_ax: + :type channel_ax: int or None + :param dividable: the number with which the new image size should be perfectly dividable by + :type dividable: int + :return: the padded image + :rtype: np.ndarray """ - return str(Path(filepath).parent) - - -def join_path(root_dir, filepath): - """ Joins a root directory with a file or directory path. - - :param root_dir: The root directory. - :type root_dir: str - :param filepath: The file or directory path. - :type filepath: str - :return: The joined path. 
- :rtype: str + height_pad = (height // dividable + 1) * dividable - height + width_pad = (width // dividable + 1) * dividable - width + if channel_ax == 0: + img = np.pad(img, ((0, 0), (0, height_pad), (0, width_pad))) + elif channel_ax == 2: + img = np.pad(img, ((0, height_pad), (0, width_pad), (0, 0))) + else: + img = np.pad(img, ((0, height_pad), (0, width_pad))) + return img + + +def convert_to_tensor( + imgs: List[np.ndarray], dtype: type, unsqueeze: bool = True +) -> torch.Tensor: + """Convert the imgs to tensors of type dtype and add extra dimension if input bool is true. + + :param imgs: the list of images to convert + :type img: List[np.ndarray] + :param dtype: the data type to convert the image tensor + :type dtype: type + :param unsqueeze: If True an extra dim will be added at location zero + :type unsqueeze: bool + :return: the converted image + :rtype: torch.Tensor """ - return str(Path(root_dir, filepath)) - - -def get_file_extension(file): - """ Returns the extension of a file. - - :param file: The file path. - :type file: str - :return: The extension of the file. - :rtype: str - """ - return str(Path(file).suffix) - - -def crop_centered_padded_patch(img: np.ndarray, - patch_center_xy, - patch_size, - obj_label, - mask: np.ndarray=None, - noise_intensity=None) -> np.ndarray: - """ Crops a patch from an array `x` centered at coordinates `c` with size `p`, + # Convert images tensors + imgs = torch.stack([torch.from_numpy(img.astype(dtype)) for img in imgs]) + imgs = imgs.unsqueeze(1) if imgs.ndim == 3 and unsqueeze is True else imgs + return imgs + + +def crop_centered_padded_patch( + img: np.ndarray, + patch_center_xy: tuple, + patch_size: tuple, + obj_label: int, + mask: np.ndarray = None, + noise_intensity: int = None, +) -> np.ndarray: + """Crop a patch from an array centered at coordinates patch_center_xy with size patch_size, and apply padding if necessary. 
:param img: the input array from which the patch will be cropped @@ -98,20 +96,19 @@ def crop_centered_padded_patch(img: np.ndarray, :type patch_size: tuple :param obj_label: the instance label of the mask at the patch :type obj_label: int - :param mask: The mask array associated with the array x. - Mask is used during training to mask out non-central elements. - For RandomForest, it is used to calculate pyradiomics features. + :param mask: The mask array associated with the array x. + Mask is used during training to mask out non-central elements. + For RandomForest, it is used to calculate pyradiomics features. :type mask: np.ndarray, optional :param noise_intensity: intensity of noise to be added to the background :type noise_intensity: float, optional - :return: the cropped patch with applied padding :rtype: np.ndarray - """ + """ height, width = patch_size # Size of the patch - img_height, img_width = img.shape[0], img.shape[1] # Size of the input image - + img_height, img_width = img.shape[0], img.shape[1] # Size of the input image + # Calculate the boundaries of the patch top = patch_center_xy[0] - height // 2 bottom = top + height @@ -125,72 +122,121 @@ def crop_centered_padded_patch(img: np.ndarray, mask_other_objs = (mask_ != obj_label) & (mask_ > 0) img[mask_other_objs] = 0 # Add random noise at locations where other objects are present if noise_intensity is given - if noise_intensity is not None: img[mask_other_objs] = np.random.normal(scale=noise_intensity, size=img[mask_other_objs].shape) + if noise_intensity is not None: + img[mask_other_objs] = np.random.normal( + scale=noise_intensity, size=img[mask_other_objs].shape + ) mask[mask_other_objs] = 0 # crop the mask - mask = mask[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + mask = mask[ + max(top, 0) : min(bottom, img_height), + max(left, 0) : min(right, img_width), + :, + ] + + patch = img[ + max(top, 0) : min(bottom, img_height), max(left, 0) : min(right, img_width), : 
+ ] - patch = img[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] # Calculate the required padding amounts and apply padding if necessary - if left < 0: - patch = np.hstack(( - np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), - patch)) - if mask is not None: - mask = np.hstack(( - np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype(np.uint8), - mask)) + if left < 0: + patch = np.hstack( + ( + np.random.normal( + scale=noise_intensity, + size=(patch.shape[0], abs(left), patch.shape[2]), + ).astype(np.uint8), + patch, + ) + ) + if mask is not None: + mask = np.hstack( + ( + np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype( + np.uint8 + ), + mask, + ) + ) # Apply padding on the right side if necessary - if right > img_width: - patch = np.hstack(( - patch, - np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - img_width), patch.shape[2])).astype(np.uint8))) - if mask is not None: - mask = np.hstack(( - mask, - np.zeros((mask.shape[0], (right - img_width), mask.shape[2])).astype(np.uint8))) + if right > img_width: + patch = np.hstack( + ( + patch, + np.random.normal( + scale=noise_intensity, + size=(patch.shape[0], (right - img_width), patch.shape[2]), + ).astype(np.uint8), + ) + ) + if mask is not None: + mask = np.hstack( + ( + mask, + np.zeros( + (mask.shape[0], (right - img_width), mask.shape[2]) + ).astype(np.uint8), + ) + ) # Apply padding on the top side if necessary - if top < 0: - patch = np.vstack(( - np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), - patch)) - if mask is not None: - mask = np.vstack(( - np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), - mask)) + if top < 0: + patch = np.vstack( + ( + np.random.normal( + scale=noise_intensity, + size=(abs(top), patch.shape[1], patch.shape[2]), + ).astype(np.uint8), + patch, + ) + ) + if mask is not None: + mask = 
np.vstack( + ( + np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), + mask, + ) + ) # Apply padding on the bottom side if necessary - if bottom > img_height: - patch = np.vstack(( - patch, - np.random.normal(scale=noise_intensity, size=(bottom - img_height, patch.shape[1], patch.shape[2])).astype(np.uint8))) - if mask is not None: - mask = np.vstack(( - mask, - np.zeros((bottom - img_height, mask.shape[1], mask.shape[2])).astype(np.uint8))) - - return patch, mask - - -def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray: - """ Computes the centers of mass for each object in a mask. + if bottom > img_height: + patch = np.vstack( + ( + patch, + np.random.normal( + scale=noise_intensity, + size=(bottom - img_height, patch.shape[1], patch.shape[2]), + ).astype(np.uint8), + ) + ) + if mask is not None: + mask = np.vstack( + ( + mask, + np.zeros( + (bottom - img_height, mask.shape[1], mask.shape[2]) + ).astype(np.uint8), + ) + ) + return patch, mask + + +def get_center_of_mass_and_label(mask: np.ndarray) -> tuple: + """Computes the centers of mass for each object in a mask. :param mask: the input mask containing labeled objects :type mask: np.ndarray - - :return: + :return: - A list of tuples representing the coordinates (row, column) of the centers of mass for each object. - A list of ints representing the labels for each object in the mask. 
- - :rtype: + :rtype: - List [tuple] - List [int] """ # Compute the centers of mass for each labeled object in the mask - - #return [(int(x[0]), int(x[1])) - # for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] - + + # return [(int(x[0]), int(x[1])) + # for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] + centers = [] labels = [] for region in measure.regionprops(mask): @@ -198,17 +244,17 @@ def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray: centers.append((int(center[0]), int(center[1]))) labels.append(region.label) return centers, labels - - -def get_centered_patches(img, - mask, - p_size: int, - noise_intensity=5, - mask_class=None, - include_mask=False): - """ Extracts centered patches from the input image based on the centers of objects identified in the mask. +def get_centered_patches( + img: np.ndarray, + mask: np.ndarray, + p_size: int, + noise_intensity: int = 5, + mask_class: Optional[int] = None, + include_mask: bool = False, +) -> tuple: + """Extracts centered patches from the input image based on the centers of objects identified in the mask. :param img: The input image. :type img: numpy.ndarray @@ -221,49 +267,54 @@ def get_centered_patches(img, :param mask_class: The class represented in the patch. :type mask_class: int :param include_mask: Whether or not to include the mask as an input argument to the model. - :type include_mask: bool + :type include_mask: bool :return: A tuple containing the following elements: - patches (numpy.ndarray): Extracted patches. - patch_masks (numpy.ndarray): Masks corresponding to the extracted patches. - instance_labels (list): Labels identifying each object instance in the extracted patches. - class_labels (list): Labels identifying the class of each object instance in the extracted patches. 
- :rtype: tuple - """ + :rtype: tuple + """ - patches, patch_masks, instance_labels, class_labels = [], [], [], [] + patches, patch_masks, instance_labels, class_labels = [], [], [], [] # if image is 2D add an additional dim for channels - if img.ndim<3: img = img[:, :, np.newaxis] - if mask.ndim<3: mask = mask[:, :, np.newaxis] + if img.ndim < 3: + img = img[:, :, np.newaxis] + if mask.ndim < 3: + mask = mask[:, :, np.newaxis] # compute center of mass of objects centers_of_mass, instance_labels = get_center_of_mass_and_label(mask) # Crop patches around each center of mass for c, obj_label in zip(centers_of_mass, instance_labels): c_x, c_y = c - patch, patch_mask = crop_centered_padded_patch(img.copy(), - (c_x, c_y), - (p_size, p_size), - obj_label, - mask=deepcopy(mask), - noise_intensity=noise_intensity) - if include_mask: + patch, patch_mask = crop_centered_padded_patch( + img.copy(), + (c_x, c_y), + (p_size, p_size), + obj_label, + mask=deepcopy(mask), + noise_intensity=noise_intensity, + ) + if include_mask is True: patch_mask = 255 * (patch_mask > 0).astype(np.uint8) patch = np.concatenate((patch, patch_mask), axis=-1) - + patches.append(patch) patch_masks.append(patch_mask) if mask_class is not None: # get the class instance for the specific object instance_labels.append(obj_label) - class_l = np.unique(mask_class[mask[:,:,0]==obj_label]) - assert class_l.shape[0] == 1, "ERROR"+str(class_l) + class_l = np.unique(mask_class[mask[:, :, 0] == obj_label]) + assert class_l.shape[0] == 1, "ERROR" + str(class_l) class_l = int(class_l[0]) - #-1 because labels from mask start from 1, we want classes to start from 0 - class_labels.append(class_l-1) - + # -1 because labels from mask start from 1, we want classes to start from 0 + class_labels.append(class_l - 1) + return patches, patch_masks, instance_labels, class_labels -def get_objects(mask): - """ Finds labeled connected components in a binary mask. 
+ +def get_objects(mask: np.ndarray) -> List: + """Finds labeled connected components in a binary mask. :param mask: The binary mask representing objects. :type mask: numpy.ndarray @@ -272,8 +323,9 @@ def get_objects(mask): """ return find_objects(mask) -def find_max_patch_size(mask): - """ Finds the maximum patch size in a mask. + +def find_max_patch_size(mask: np.ndarray) -> float: + """Finds the maximum patch size in a mask. :param mask: The binary mask representing objects. :type mask: numpy.ndarray @@ -305,13 +357,21 @@ def find_max_patch_size(mask): # Check if the current patch size is larger than the maximum if total_size > max_patch_size: max_patch_size = total_size - + max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) return max_patch_size_edge - -def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask): - """ Splits images and masks into patches of equal size centered around the cells. + + +def create_patch_dataset( + imgs: List[np.ndarray], + masks_classes: Optional[Union[List[np.ndarray], torch.Tensor]], + masks_instances: Optional[Union[List[np.ndarray], torch.Tensor]], + noise_intensity: int, + max_patch_size: int, + include_mask: bool, +) -> tuple: + """Splits images and masks into patches of equal size centered around the cells. :param imgs: A list of input images. :type imgs: list of numpy.ndarray or torch.Tensor @@ -320,9 +380,9 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, :param masks_instances: A list of binary masks representing instances. :type masks_instances: list of numpy.ndarray or torch.Tensor :param noise_intensity: The intensity of noise to add to the patches. - :type noise_intensity: float + :type noise_intensity: int :param max_patch_size: The maximum size of the bounding box edge for objects in the mask. 
- :type max_patch_size: float + :type max_patch_size: int :param include_mask: A flag indicating whether to include the mask along with patches. :type include_mask: bool :return: A tuple containing the patches, patch masks, and labels. @@ -331,30 +391,32 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, .. note:: If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned - in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same + in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. CxHxW) """ if max_patch_size is None: max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances]) - + patches, patch_masks, labels = [], [], [] - for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): + for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # mask_instance has dimension WxH # mask_class has dimension WxH - patch, patch_mask, _, label = get_centered_patches(img, - mask_instance, - max_patch_size, - noise_intensity=noise_intensity, - mask_class=mask_class, - include_mask = include_mask) + patch, patch_mask, _, label = get_centered_patches( + img=img, + mask=mask_instance, + p_size=max_patch_size, + noise_intensity=noise_intensity, + mask_class=mask_class, + include_mask=include_mask, + ) patches.extend(patch) patch_masks.extend(patch_mask) - labels.extend(label) + labels.extend(label) return patches, patch_masks, labels -def get_shape_features(img, mask): - """ Calculate shape-based radiomic features from an image within the region defined by the mask. +def get_shape_features(img: np.ndarray, mask: np.ndarray) -> np.ndarray: + """Calculate shape-based radiomic features from an image within the region defined by the mask. 
:param img: The input image. :type img: numpy.ndarray @@ -365,18 +427,20 @@ def get_shape_features(img, mask): """ mask = 255 * ((mask) > 0).astype(np.uint8) - image = sitk.GetImageFromArray(img.squeeze()) roi_mask = sitk.GetImageFromArray(mask.squeeze()) - shape_calculator = shape2D.RadiomicsShape2D(inputImage=image, inputMask=roi_mask, label=255) + shape_calculator = shape2D.RadiomicsShape2D( + inputImage=image, inputMask=roi_mask, label=255 + ) # Calculate the shape-based radiomic features shape_features = shape_calculator.execute() return np.array(list(shape_features.values())) -def extract_intensity_features(image, mask): - """ Extracts intensity-based features from an image within the region defined by the mask. + +def extract_intensity_features(image: np.ndarray, mask: np.ndarray) -> np.ndarray: + """Extracts intensity-based features from an image within the region defined by the mask. :param image: The input image. :type image: numpy.ndarray @@ -385,40 +449,42 @@ def extract_intensity_features(image, mask): :return: An array containing the extracted intensity-based features, including median intensity, mean intensity, and 25th/75th percentile intensity within the masked region. 
:rtype: numpy.ndarray """ - + features = {} - + # Ensure the image and mask have the same dimensions if image.shape != mask.shape: raise ValueError("Image and mask must have the same dimensions") - masked_image = image[(mask>0)] + masked_image = image[(mask > 0)] # features["min_intensity"] = np.min(masked_image) # features["max_intensity"] = np.max(masked_image) features["median_intensity"] = np.median(masked_image) features["mean_intensity"] = np.mean(masked_image) features["25th_percentile_intensity"] = np.percentile(masked_image, 25) features["75th_percentile_intensity"] = np.percentile(masked_image, 75) - + return np.array(list(features.values())) -def create_dataset_for_rf(imgs, masks): - """ Extracts shape and intensity-based features from images within regions defined by masks. + +def create_dataset_for_rf( + imgs: List[np.ndarray], masks: List[np.ndarray] +) -> List[np.ndarray]: + """Extracts shape and intensity-based features from images within regions defined by masks. :param imgs: A list of input images. :type imgs: list :param masks: A list of corresponding masks defining regions of interest. :type masks: list :return: A list of arrays containing shape and intensity-based features. 
- :rtype: list + :rtype: list """ X = [] for img, mask in zip(imgs, masks): - shape_features = get_shape_features(img, mask) intensity_features = extract_intensity_features(img, mask) features_list = np.concatenate((shape_features, intensity_features), axis=0) X.append(features_list) - - return X \ No newline at end of file + + return X diff --git a/src/server/pyproject.toml b/src/server/pyproject.toml index 783e0dfb..4acd006c 100644 --- a/src/server/pyproject.toml +++ b/src/server/pyproject.toml @@ -33,7 +33,9 @@ maintainers = [ [project.optional-dependencies] dev = [ - "pytest", + "pytest>=7.4.3", + "sphinx", + "sphinx-rtd-theme" ] [project.urls] diff --git a/src/server/requirements.txt b/src/server/requirements.txt index b6e7f266..a42ba5eb 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -4,10 +4,7 @@ bentoml==1.0.16 scikit-image>=0.19.3 torchmetrics>=0.11.4 torch>=2.1.0 -pytest>=7.4.3 numpy scikit-learn>=1.2.2 SimpleITK>=2.2.1 -pyradiomics==3.0.1 -sphinx -sphinx-rtd-theme \ No newline at end of file +pyradiomics==3.0.1 \ No newline at end of file diff --git a/src/server/test/configs/test_config_CustomCellpose.yaml b/src/server/test/configs/test_config_CustomCellpose.yaml new file mode 100644 index 00000000..5e3c0436 --- /dev/null +++ b/src/server/test/configs/test_config_CustomCellpose.yaml @@ -0,0 +1,47 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "CustomCellpose", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor_name": Cellpose, + "segmentor": { + "model_type": "cyto" + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + } + }, + + 
"eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "mask_channel_axis": null + } +} \ No newline at end of file diff --git a/src/server/test/configs/test_config_fcnn.cfg b/src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml similarity index 64% rename from src/server/test/configs/test_config_fcnn.cfg rename to src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml index 02039f68..20e5c96a 100644 --- a/src/server/test/configs/test_config_fcnn.cfg +++ b/src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml @@ -1,49 +1,49 @@ { "setup": { "segmentation": "GeneralSegmentation", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { - "model_to_use": "CustomCellposeModel", - "save_model_path": "mito", - "runner_name": "cellpose_runner", + "runner_name": "bento_runner", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": Cellpose, "segmentor": { "model_type": "cyto" }, - "classifier":{ - "model_class": "FCNN", + "classifier_name": "PatchClassifier", + "classifier":{ "in_channels": 1, "num_classes": 3, "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" + "black_bg": False, + "include_mask": False } }, "data": { - "data_root": "data" + "data_root": "data", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, "n_epochs": 20, "lr": 0.005, "batch_size": 5, @@ -59,10 +59,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } diff --git 
a/src/server/test/configs/test_config_RF.cfg b/src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml similarity index 54% rename from src/server/test/configs/test_config_RF.cfg rename to src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml index c09c6af5..0734bcf7 100644 --- a/src/server/test/configs/test_config_RF.cfg +++ b/src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml @@ -1,53 +1,44 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CustomCellposeModel", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { - "runner_name": "cellpose_runner", - "bento_model_path": "mito", + "runner_name": "bento_runner", + "bento_model_path": "test", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": Cellpose, "segmentor": { "model_type": "cyto" }, + "classifier_name": "RandomForest", "classifier":{ - "model_class": "RandomForest", - "in_channels": 1, - "num_classes": 3, - "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" } }, "data": { - "data_root": "data" + "data_root": "data", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, - "n_epochs": 10, - "lr": 0.001, - "batch_size": 1, - "optimizer": "Adam" } }, @@ -59,10 +50,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } diff --git a/src/server/test/configs/test_config_MultiCellpose.yaml b/src/server/test/configs/test_config_MultiCellpose.yaml new file mode 100644 index 00000000..46b913d7 --- /dev/null +++ b/src/server/test/configs/test_config_MultiCellpose.yaml @@ -0,0 +1,50 @@ +{ + "setup": { + "segmentation": 
"GeneralSegmentation", + "model_to_use": "MultiCellpose", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor_name": Cellpose, + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "num_classes": 3 + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "segmentor":{ + "n_epochs": 30, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/configs/test_config_UNet.yaml b/src/server/test/configs/test_config_UNet.yaml new file mode 100644 index 00000000..f4eba079 --- /dev/null +++ b/src/server/test/configs/test_config_UNet.yaml @@ -0,0 +1,46 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "UNet", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "classifier":{ + "in_channels": 1, + "num_classes": 3, + "features":[64,128,256,512] + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "classifier":{ + "n_epochs": 30, + "lr": 0.005, + "batch_size": 5, + "optimizer": "Adam" + } + }, + + "eval":{ + "classifier": { + + }, + compute_instance: True, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index 58c87ba6..5c9e0fcb 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py 
@@ -41,7 +41,8 @@ def assign_unique_colors(labels, colors): return label_colors -def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha=0.5): + +def custom_label2rgb(labels, colors=["red", "green", "blue"], bg_label=0, alpha=0.5): """ Converts a label array to an RGB image using assigned colors for each label. @@ -64,14 +65,17 @@ def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha= for label in np.unique(labels): mask = labels == label if label in label_colors: - rgb = color.label2rgb(mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha) + rgb = color.label2rgb( + mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha + ) rgb_image += rgb return rgb_image + def add_padding_for_rotation(image, angle): """ - Apply padding and rotation to an image. + Apply padding and rotation to an image. The purpose of this function is to ensure that the rotated image fits within its original dimensions by adding padding, preventing any parts of the image from being cropped. @@ -97,20 +101,25 @@ def add_padding_for_rotation(image, angle): pad_h = (new_h - h) // 2 # Add padding to the image - padded_image = cv2.copyMakeBorder(image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT) + padded_image = cv2.copyMakeBorder( + image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT + ) # Rotate the padded image center = (padded_image.shape[1] // 2, padded_image.shape[0] // 2) rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) - rotated_image = cv2.warpAffine(padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0])) + rotated_image = cv2.warpAffine( + padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0]) + ) return rotated_image + def get_object_images(objects): """ Load object images from file paths. 
- :param objects: A list of dictionaries containing information about the objects such as name, path, intensity + :param objects: A list of dictionaries containing information about the objects such as name, path, intensity :type objects: list[dict] :return: A list of object images loaded from the specified file paths. :rtype: list[numpy.ndarray] @@ -119,14 +128,22 @@ def get_object_images(objects): object_images = [] for obj in objects: - img = cv2.imread(obj['path']) + img = cv2.imread(obj["path"]) # img = cv2.resize(img, obj['size']) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) object_images.append(img) return object_images -def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, noise_intensity=None, max_rotation_angle=None): + +def generate_dataset( + num_samples, + objects, + canvas_size, + max_object_counts=None, + noise_intensity=None, + max_rotation_angle=None, +): """ Generate a synthetic dataset with images and masks. @@ -150,7 +167,7 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, dataset_masks = [] object_images = get_object_images(objects) - class_intensities = [ (obj['intensity'][0], obj['intensity'][1]) for obj in objects] + class_intensities = [(obj["intensity"][0], obj["intensity"][1]) for obj in objects] if len(object_images[0].shape) == 3: num_of_img_channels = object_images[0].shape[-1] @@ -161,8 +178,12 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, max_object_counts = [10] * len(object_images) for _ in range(num_samples): - canvas = np.zeros((canvas_size[0], canvas_size[1], num_of_img_channels), dtype=np.uint8) - mask = np.zeros((canvas_size[0], canvas_size[1], len(object_images)), dtype=np.uint8) + canvas = np.zeros( + (canvas_size[0], canvas_size[1], num_of_img_channels), dtype=np.uint8 + ) + mask = np.zeros( + (canvas_size[0], canvas_size[1], len(object_images)), dtype=np.uint8 + ) for object_index, object_img in enumerate(object_images): @@ 
-170,70 +191,104 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, object_count = random.randint(1, max_count) for _ in range(object_count): - + canvas_range = max(canvas_size) - object_size = random.randint(canvas_range//20, canvas_range//5) + object_size = random.randint(canvas_range // 20, canvas_range // 5) object_img_resized = cv2.resize(object_img, (object_size, object_size)) # object_img_resized = (object_img_resized>0).astype(np.uint8)*(255 - object_size) - intensity_mean = (class_intensities[object_index][1] - class_intensities[object_index][0])/2 - intensity_scale = (class_intensities[object_index][1] - intensity_mean)/3 - class_intensity = np.random.normal(loc=intensity_mean, scale=intensity_scale) - class_intensity = np.clip(class_intensity, class_intensities[object_index][0], class_intensities[object_index][1]) + intensity_mean = ( + class_intensities[object_index][1] + - class_intensities[object_index][0] + ) / 2 + intensity_scale = ( + class_intensities[object_index][1] - intensity_mean + ) / 3 + class_intensity = np.random.normal( + loc=intensity_mean, scale=intensity_scale + ) + class_intensity = np.clip( + class_intensity, + class_intensities[object_index][0], + class_intensities[object_index][1], + ) # class_intensity = random.randint(int(class_intensities[object_index][0]), int(class_intensities[object_index][1])) - object_img_resized = (object_img_resized>0).astype(np.uint8)*(class_intensity)*255 + object_img_resized = ( + (object_img_resized > 0).astype(np.uint8) * (class_intensity) * 255 + ) if num_of_img_channels == 1: - + if max_rotation_angle is not None: # Randomly rotate the object image - rotation_angle = random.uniform(-max_rotation_angle, max_rotation_angle) - object_img_transformed = add_padding_for_rotation(object_img_resized, rotation_angle) + rotation_angle = random.uniform( + -max_rotation_angle, max_rotation_angle + ) + object_img_transformed = add_padding_for_rotation( + object_img_resized, 
rotation_angle + ) else: object_img_transformed = object_img_resized - - object_size_x, object_size_y = object_img_transformed.shape - + object_size_x, object_size_y = object_img_transformed.shape object_mask = np.zeros((object_size_x, object_size_y), dtype=np.uint8) if num_of_img_channels == 1: # Grayscale image object_mask[object_img_transformed > 0] = object_index + 1 # object_img_resized = np.expand_dims(object_img_resized, axis=-1) - object_img_transformed = np.expand_dims(object_img_transformed, axis=-1) + object_img_transformed = np.expand_dims( + object_img_transformed, axis=-1 + ) else: # Color image with alpha channel object_mask[object_img_resized[:, :, -1] > 0] = object_index + 1 - x = random.randint(0, canvas_size[1] - object_size_x) y = random.randint(0, canvas_size[0] - object_size_y) - intersecting_mask = mask[y:y + object_size_y, x:x + object_size_x].max(axis=-1) + intersecting_mask = mask[ + y : y + object_size_y, x : x + object_size_x + ].max(axis=-1) if (intersecting_mask > 0).any(): continue # Skip if there is an intersection with objects from other classes - - assert mask[y:y + object_size_y, x:x + object_size_x, object_index].shape == object_mask.shape - canvas[y:y + object_size_y, x:x + object_size_x] = object_img_transformed - mask[y:y + object_size_y, x:x + object_size_x, object_index] = np.maximum( - mask[y:y + object_size_y, x:x + object_size_x, object_index], object_mask + assert ( + mask[ + y : y + object_size_y, x : x + object_size_x, object_index + ].shape + == object_mask.shape + ) + + canvas[y : y + object_size_y, x : x + object_size_x] = ( + object_img_transformed + ) + mask[y : y + object_size_y, x : x + object_size_x, object_index] = ( + np.maximum( + mask[ + y : y + object_size_y, x : x + object_size_x, object_index + ], + object_mask, + ) ) - # Add noise to the canvas if noise_intensity is not None: if num_of_img_channels == 1: - noise = np.random.normal(scale=noise_intensity, size=(canvas_size[0], canvas_size[1], 1)) + noise 
= np.random.normal( + scale=noise_intensity, size=(canvas_size[0], canvas_size[1], 1) + ) # noise = random_noise(canvas, mode='speckle', mean=noise_intensity) - + else: - noise = np.random.normal(scale=noise_intensity, size=(canvas_size[0], canvas_size[1], num_of_img_channels)) + noise = np.random.normal( + scale=noise_intensity, + size=(canvas_size[0], canvas_size[1], num_of_img_channels), + ) noisy_canvas = canvas + noise.astype(np.uint8) - dataset_images.append(noisy_canvas.squeeze(2)) - + dataset_images.append(noisy_canvas.squeeze(2)) + else: dataset_images.append(canvas.squeeze(2)) @@ -251,7 +306,10 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, return dataset_images, dataset_masks -def get_synthetic_dataset(num_samples, canvas_size=(512,512), max_object_counts=[15, 15, 15]): + +def get_synthetic_dataset( + num_samples, canvas_size=(512, 512), max_object_counts=[15, 15, 15] +): """Generates a synthetic dataset with images and masks. :param num_samples: The number of samples to generate. 
@@ -264,23 +322,21 @@ def get_synthetic_dataset(num_samples, canvas_size=(512,512), max_object_counts= :rtype: tuple """ objects = [ - { - - 'name': 'triangle', - 'path': 'test/shapes/triangle.png', - 'intensity' : [0, 0.33] - }, - { - 'name': 'circle', - 'path': 'test/shapes/circle.png', - 'intensity' : [0.34, 0.66] - }, - { - 'name': 'square', - 'path': 'test/shapes/square.png', - 'intensity' : [0.67, 1.0] - }, + { + "name": "triangle", + "path": "test/shapes/triangle.png", + "intensity": [0, 0.33], + }, + {"name": "circle", "path": "test/shapes/circle.png", "intensity": [0.34, 0.66]}, + {"name": "square", "path": "test/shapes/square.png", "intensity": [0.67, 1.0]}, ] - - images, masks = generate_dataset(num_samples, objects, canvas_size=canvas_size, max_object_counts=max_object_counts, noise_intensity=5, max_rotation_angle=30) + + images, masks = generate_dataset( + num_samples, + objects, + canvas_size=canvas_size, + max_object_counts=max_object_counts, + noise_intensity=5, + max_rotation_angle=30, + ) return images, masks diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 3a28a3e2..6e37ea22 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,73 +1,148 @@ import sys + +sys.path.append(".") + from glob import glob -import inspect +import pytest + +# import inspect import random import numpy as np - -import torch +import torch from torchmetrics import JaccardIndex -# from importlib.machinery import SourceFileLoader - -sys.path.append(".") - -import dcp_server.models as models -from dcp_server.utils import read_config +from dcp_server.models import * +from dcp_server.utils.helpers import read_config from synthetic_dataset import get_synthetic_dataset -import pytest - seed_value = 2023 random.seed(seed_value) torch.manual_seed(seed_value) np.random.seed(seed_value) -# retrieve models names -model_classes = [ - cls_obj for cls_name, cls_obj in inspect.getmembers(models) \ - if 
inspect.isclass(cls_obj) \ - and cls_obj.__module__ == models.__name__ \ - and not cls_name.startswith("CellClassifier") - ] +model_mapping = { + "CustomCellpose": CustomCellpose, + "Inst2MultiSeg": Inst2MultiSeg, + "MultiCellpose": MultiCellpose, + "UNet": UNet, +} -config_paths = glob("test/configs/*.cfg") +config_paths = glob("test/configs/*.yaml") -@pytest.fixture(params=model_classes) -def model_class(request): - return request.param @pytest.fixture(params=config_paths) def config_path(request): return request.param -@pytest.fixture() -def model(model_class, config_path): - - model_config = read_config('model', config_path=config_path) - train_config = read_config('train', config_path=config_path) - eval_config = read_config('eval', config_path=config_path) - - model = model_class(model_config, train_config, eval_config, str(model_class)) +@pytest.fixture() +# def model(model_class, config_path): +def model(config_path): + + setup_config = read_config("setup", config_path=config_path) + model_config = read_config("model", config_path=config_path) + data_config = read_config("data", config_path=config_path) + train_config = read_config("train", config_path=config_path) + eval_config = read_config("eval", config_path=config_path) + + model_name = setup_config["model_to_use"] + model_class = model_mapping.get(model_name) + model = model_class( + model_name, model_config, data_config, train_config, eval_config + ) + # str(model_class) return model + @pytest.fixture def data_train(): - images, masks = get_synthetic_dataset(num_samples=4, canvas_size=(512,768)) + images, masks = get_synthetic_dataset(num_samples=4, canvas_size=(512, 768)) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - masks_ = [np.stack((instances, classes)) for instances, classes in zip(masks_instances, masks_classes)] + masks_ = [ + np.stack((instances, classes)) + for 
instances, classes in zip(masks_instances, masks_classes) + ] return images, masks_ + @pytest.fixture -def data_eval(): +def data_eval(): img, msk = get_synthetic_dataset(num_samples=1) msk = np.array(msk) - msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) + msk_ = np.stack( + (msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0 + ).transpose(1, 0, 2, 3) return img, msk_ + +def test_train_eval_run(data_train, data_eval, model): + """ + Performs testing, training, and evaluation with the provided data and model. + """ + # train + images, masks = data_train + if model.model_name == "CustomCellpose": + masks = [mask[0] for mask in masks] + model.train(images, masks) + + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + + if "metric" in attrs: + assert model.metric > 0.1 + if "loss" in attrs: + assert model.loss < 0.83 + + # validate + imgs_test, masks_test = data_eval + if model.model_name == "CustomCellpose": + masks = [mask[0] for mask in masks_test] + + jaccard_index_instances = 0 + jaccard_index_classes = 0 + + jaccard_metric_binary = JaccardIndex( + task="multiclass", num_classes=2, average="macro", ignore_index=0 + ) + jaccard_metric_multi = JaccardIndex( + task="multiclass", num_classes=4, average="macro", ignore_index=0 + ) + + for img, mask in zip(imgs_test, masks_test): + + # mask - instance segmentation mask + classes (2, 512, 512) + # pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + + pred_mask = model.eval(img) + + if pred_mask.ndim > 2: + pred_mask_bin = torch.tensor((pred_mask[0] > 0).astype(bool).astype(int)) + else: + pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) + + bin_mask = torch.tensor((mask[0] > 0).astype(bool).astype(int)) + + jaccard_index_instances += jaccard_metric_binary(pred_mask_bin, bin_mask) + + if pred_mask.ndim > 2: + + jaccard_index_classes += 
jaccard_metric_multi( + torch.tensor(pred_mask[1].astype(int)), + torch.tensor(mask[1].astype(int)), + ) + + jaccard_index_instances /= len(imgs_test) + assert jaccard_index_instances > 0.2 + + if pred_mask.ndim > 2: + + jaccard_index_classes /= len(imgs_test) + assert jaccard_index_classes > 0.1 + + # def test_train_run(data_train, model): # images, masks = data_train @@ -83,12 +158,12 @@ def data_eval(): # assert(model.metric>0.1) # if "loss" in attrs: # assert(model.loss<0.3) - + # def test_eval_run(data_train, data_eval, model): # images, masks = data_train # model.train(images, masks) - + # imgs_test, masks_test = data_eval # jaccard_index_instances = 0 @@ -103,7 +178,7 @@ def data_eval(): # #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) # pred_mask = model.eval(img) #, channels=[0,0]) - + # if pred_mask.ndim > 2: # pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) # else: @@ -112,81 +187,22 @@ def data_eval(): # bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) # jaccard_index_instances += jaccard_metric_binary( -# pred_mask_bin, +# pred_mask_bin, # bin_mask # ) # if pred_mask.ndim > 2: # jaccard_index_classes += jaccard_metric_multi( -# torch.tensor(pred_mask[1].astype(int)), +# torch.tensor(pred_mask[1].astype(int)), # torch.tensor(mask[1].astype(int)) # ) - + # jaccard_index_instances /= len(imgs_test) # assert(jaccard_index_instances>0.2) -# # for PatchCNN model +# # for PatchCNN model # if pred_mask.ndim > 2: # jaccard_index_classes /= len(imgs_test) # assert(jaccard_index_classes>0.1) - -def test_train_eval_run(data_train, data_eval, model): - """ - Performs testing, training, and evaluation with the provided data and model. 
- """ - - images, masks = data_train - model.train(images, masks) - - imgs_test, masks_test = data_eval - - jaccard_index_instances = 0 - jaccard_index_classes = 0 - - jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) - jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) - - for img, mask in zip(imgs_test, masks_test): - - #mask - instance segmentation mask + classes (2, 512, 512) - #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) - - pred_mask = model.eval(img) - - if pred_mask.ndim > 2: - pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) - else: - pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) - - bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) - - jaccard_index_instances += jaccard_metric_binary( - pred_mask_bin, - bin_mask - ) - - if pred_mask.ndim > 2: - - jaccard_index_classes += jaccard_metric_multi( - torch.tensor(pred_mask[1].astype(int)), - torch.tensor(mask[1].astype(int)) - ) - - jaccard_index_instances /= len(imgs_test) - assert(jaccard_index_instances>0.2) - - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() - - if "metric" in attrs: - assert(model.metric>0.1) - if "loss" in attrs: - assert(model.loss<0.75) - - # for PatchCNN model - if pred_mask.ndim > 2: - - jaccard_index_classes /= len(imgs_test) - assert(jaccard_index_classes>0.1) \ No newline at end of file diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py index 7a91fa9a..eddf8f94 100644 --- a/src/server/test/test_models.py +++ b/src/server/test/test_models.py @@ -2,41 +2,33 @@ import numpy as np import dcp_server.models as models -from dcp_server.utils import read_config +from dcp_server.models.classifiers import FeatureClassifier +from dcp_server.utils.helpers import read_config + def test_eval_rf_not_fitted(): """ 
Tests the evaluation of a random forest model that has not been fitted. """ - model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') - train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') - eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') - - model_rf = models.CellClassifierShallowModel(model_config,train_config,eval_config) - - X_test = np.array([[1, 2, 3]]) + model_config = read_config( + "model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + data_config = read_config( + "data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + train_config = read_config( + "train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + eval_config = read_config( + "eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + + model_rf = FeatureClassifier( + "Random Forest", model_config, data_config, train_config, eval_config + ) + + X_test = np.array([[1, 2, 3]]) # if we don't fit the model then the model returns zeros - assert np.all(model_rf.eval(X_test)== np.zeros(X_test.shape)) - -def test_update_configs(): - """ - Tests the update of model training and evaluation configurations. 
- """ - - model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') - train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') - eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') - - model = models.CustomCellposeModel(model_config,train_config,eval_config, "Cellpose") - - new_train_config = {"param1": "value1"} - new_eval_config = {"param2": "value2"} - - model.update_configs(new_train_config, new_eval_config) - - assert model.train_config == new_train_config - assert model.eval_config == new_eval_config - - + assert np.all(model_rf.eval(X_test) == np.zeros(X_test.shape)) diff --git a/src/server/test/test_utils.py b/src/server/test/test_utils.py index 35678a22..b0c4f71f 100644 --- a/src/server/test/test_utils.py +++ b/src/server/test/test_utils.py @@ -1,6 +1,7 @@ import numpy as np import pytest -from dcp_server.utils import find_max_patch_size +from dcp_server.utils.processing import find_max_patch_size + @pytest.fixture def sample_mask(): @@ -9,12 +10,9 @@ def sample_mask(): mask[7:9, 2:5] = 1 return mask + def test_find_max_patch_size(sample_mask): # Test when the function is called with a sample mask result = find_max_patch_size(sample_mask) assert isinstance(result, float) assert result > 0 - - - -