diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 255b0ac..3c45b11 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,6 +49,7 @@ jobs: pip install pytest-xvfb pip install coverage pip install -e ".[testing]" + pip install matplotlib working-directory: src/client - name: Install server dependencies (for communication tests) diff --git a/src/client/dcp_client/app.py b/src/client/dcp_client/app.py index 912bd01..c95c35e 100644 --- a/src/client/dcp_client/app.py +++ b/src/client/dcp_client/app.py @@ -62,7 +62,7 @@ def __init__( self.inprogr_data_path = inprogr_data_path self.cur_selected_img = '' self.cur_selected_path = '' - self.seg_filepaths = [] + self.seg_filepaths = [] def upload_data_to_server(self): """ @@ -155,4 +155,4 @@ def delete_images(self, image_names): if os.path.exists(os.path.join(self.cur_selected_path, image_name)): self.fs_image_storage.delete_image(self.cur_selected_path, image_name) - + diff --git a/src/client/dcp_client/gui/main_window.py b/src/client/dcp_client/gui/main_window.py index 2dc4cb6..d6d03a7 100644 --- a/src/client/dcp_client/gui/main_window.py +++ b/src/client/dcp_client/gui/main_window.py @@ -212,7 +212,7 @@ def on_launch_napari_button_clicked(self): _ = self.create_warning_box(message_text, message_title="Warning") else: self.nap_win = NapariWindow(self.app) - self.nap_win.show() + self.nap_win.show() def on_finished(self, result): ''' diff --git a/src/client/dcp_client/gui/napari_window.py b/src/client/dcp_client/gui/napari_window.py index 74fbdcc..f1f9846 100644 --- a/src/client/dcp_client/gui/napari_window.py +++ b/src/client/dcp_client/gui/napari_window.py @@ -1,13 +1,14 @@ from __future__ import annotations -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING +from copy import deepcopy -from PyQt5.QtWidgets import QWidget, QPushButton, QVBoxLayout, QHBoxLayout +from qtpy.QtWidgets import QPushButton, QComboBox, QLabel, QGridLayout +from qtpy.QtCore import Qt import napari if TYPE_CHECKING: from dcp_client.app import Application - -from dcp_client.utils.utils import get_path_stem +from dcp_client.utils.utils import get_path_stem, check_equal_arrays, Compute4Mask from dcp_client.gui._my_widget import MyWidget class NapariWindow(MyWidget): @@ -25,35 +26,197 @@ def __init__(self, app: Application): # Load image and get corresponding segmentation filenames img = self.app.load_image() self.app.search_segs() + self.seg_files = self.app.seg_filepaths.copy() # Set the viewer self.viewer = napari.Viewer(show=False) self.viewer.add_image(img, name=get_path_stem(self.app.cur_selected_img)) - for seg_file in self.app.seg_filepaths: + for seg_file in self.seg_files: self.viewer.add_labels(self.app.load_image(seg_file), name=get_path_stem(seg_file)) main_window = self.viewer.window._qt_window - layout = QVBoxLayout() - layout.addWidget(main_window) + layout = QGridLayout() + layout.addWidget(main_window, 0, 0, 1, 4) + + # select the first seg as the currently selected layer if there are any segs + if len(self.seg_files): + self.cur_selected_seg = self.viewer.layers.selection.active.name + self.layer = self.viewer.layers[self.cur_selected_seg] + self.viewer.layers.selection.events.changed.connect(self.on_seg_channel_changed) + # set first mask as active by default + self.active_mask_index = 0 + self.viewer.dims.events.current_step.connect(self.axis_changed) + self.original_instance_mask = {} + self.original_class_mask = {} + self.instances = {} + self.contours_mask = {} + for seg_file in self.seg_files: + layer_name = get_path_stem(seg_file) + # get unique instance labels for each seg + self.original_instance_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[0]) + self.original_class_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[1]) + # compute unique instance ids + self.instances[layer_name] = Compute4Mask.get_unique_objects(self.original_instance_mask[layer_name]) + # remove border from class mask + self.contours_mask[layer_name] = Compute4Mask.get_contours(self.original_instance_mask[layer_name]) + self.viewer.layers[layer_name].data[1][self.contours_mask[layer_name]!=0] = 0 + + self.qctrl = self.viewer.window.qt_viewer.controls.widgets[self.layer] + + if self.layer.data.shape[0] >= 2: + # User hint + message_label = QLabel('Choose an active mask') + message_label.setAlignment(Qt.AlignRight) + layout.addWidget(message_label, 1, 0) + + # Drop list to choose which is an active mask + self.mask_choice_dropdown = QComboBox() + self.mask_choice_dropdown.setEnabled(False) + self.mask_choice_dropdown.addItem('Instance Segmentation Mask', userData=0) + self.mask_choice_dropdown.addItem('Labels Mask', userData=1) + layout.addWidget(self.mask_choice_dropdown, 1, 1) - buttons_layout = QHBoxLayout() + # when user has chosen the mask, we don't want to change it anymore to avoid errors + lock_button = QPushButton("Confirm Final Choice") + lock_button.setEnabled(False) + lock_button.clicked.connect(self.set_editable_mask) + layout.addWidget(lock_button, 1, 2) + else: + self.layer = None + + # add buttons for moving images to other dirs add_to_inprogress_button = QPushButton('Move to \'Curatation in progress\' folder') - buttons_layout.addWidget(add_to_inprogress_button) + layout.addWidget(add_to_inprogress_button, 2, 0, 1, 2) add_to_inprogress_button.clicked.connect(self.on_add_to_inprogress_button_clicked) - + add_to_curated_button = QPushButton('Move to \'Curated dataset\' folder') - buttons_layout.addWidget(add_to_curated_button) + layout.addWidget(add_to_curated_button, 2, 2, 1, 2) add_to_curated_button.clicked.connect(self.on_add_to_curated_button_clicked) - layout.addLayout(buttons_layout) - self.setLayout(layout) - self.show() + + def set_editable_mask(self): + """ + This function is not implemented. In theory the use can choose between which mask to edit. + Currently painting and erasing is only possible on instance mask and in the class mask only + the class labels can be changed. + """ + pass + + def on_seg_channel_changed(self, event): + """ + Is triggered each time the user selects a different layer in the viewer. + """ + if (act := self.viewer.layers.selection.active) is not None: + # updater cur_selected_seg with the new selection from the user + self.cur_selected_seg = act.name + if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image: pass + # set self.layer to new selection from user + elif self.layer is not None: self.layer = self.viewer.layers[self.cur_selected_seg] + else: pass + + def axis_changed(self, event): + """ + Is triggered each time the user switches the viewer between the mask channels. At this point the class mask + needs to be updated according to the changes made tot the instance segmentation mask. + """ + self.active_mask_index = self.viewer.dims.current_step[0] + masks = deepcopy(self.layer.data) + # if user has switched to the instance mask + if self.active_mask_index==0: + class_mask_with_contours = Compute4Mask.add_contour(masks[1], masks[0], self.contours_mask[self.cur_selected_seg]) + if not check_equal_arrays(class_mask_with_contours.astype(bool), self.original_class_mask[self.cur_selected_seg].astype(bool)): + self.update_instance_mask(masks[0], masks[1]) + self.switch_to_instance_mask() + # else if user has switched to the class mask + elif self.active_mask_index==1: + if not check_equal_arrays(masks[0], self.original_instance_mask[self.cur_selected_seg]): + self.update_labels_mask(masks[0]) + self.switch_to_labels_mask() + + def switch_to_instance_mask(self): + """ + Switch the application to the active mask mode by enabling 'paint_button', 'erase_button' + and 'fill_button'. + """ + self.switch_controls("paint_button", True) + self.switch_controls("erase_button", True) + self.switch_controls("fill_button", True) + + def switch_to_labels_mask(self): + """ + Switch the application to non-active mask mode by enabling 'fill_button' and disabling 'paint_button' and 'erase_button'. + """ + if self.cur_selected_seg in [layer.name for layer in self.viewer.layers]: + self.viewer.layers[self.cur_selected_seg].mode = 'pan_zoom' + info_message_paint = "Painting objects is only possible in the instance layer for now." + info_message_erase = "Erasing objects is only possible in the instance layer for now." + self.switch_controls("paint_button", False, info_message_paint) + self.switch_controls("erase_button", False, info_message_erase) + self.switch_controls("fill_button", True) + + def update_labels_mask(self, instance_mask): + """ + If the instance mask has changed since the last switch between channels the class mask needs to be updated accordingly. + + Parameters: + - instance_mask (numpy.ndarray): The updated instance mask, changed by the user. + - labels_mask (numpy.ndarray): The existing labels mask, which needs to be updated. + """ + self.original_class_mask[self.cur_selected_seg] = Compute4Mask.compute_new_labels_mask(self.original_class_mask[self.cur_selected_seg], + instance_mask, + self.original_instance_mask[self.cur_selected_seg], + self.instances[self.cur_selected_seg]) + # update original instance mask and instances + self.original_instance_mask[self.cur_selected_seg] = instance_mask + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + # compute contours to remove from class mask visualisation + self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours(instance_mask) + vis_labels_mask = deepcopy(self.original_class_mask[self.cur_selected_seg]) + vis_labels_mask[self.contours_mask[self.cur_selected_seg]!=0] = 0 + # update the viewer + self.layer.data[1] = vis_labels_mask + self.layer.refresh() + + def update_instance_mask(self, instance_mask, labels_mask): + """ + If the labels mask has changed **only if an object has been removed** the instance mask is updated. + + Parameters: + - instance_mask (numpy.ndarray): The existing instance mask, which needs to be updated. + - labels_mask (numpy.ndarray): The updated labels mask, changed by the user. + """ + # add contours back to labels mask + labels_mask = Compute4Mask.add_contour(labels_mask, instance_mask, self.contours_mask[self.cur_selected_seg]) + # and compute the updated instance mask + self.original_instance_mask[self.cur_selected_seg] = Compute4Mask.compute_new_instance_mask(labels_mask, + instance_mask) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.original_class_mask[self.cur_selected_seg] = labels_mask + # update the viewer + self.layer.data[0] = self.original_instance_mask[self.cur_selected_seg] + self.layer.refresh() + + def switch_controls(self, target_widget, status: bool, info_message=None): + """ + Enable or disable a specific widget. + + Parameters: + - target_widget (str): The name of the widget to be controlled within the QCtrl object. + - status (bool): If True, the widget will be enabled; if False, it will be disabled. + - info_message (str or None): Optionally add an info message when hovering over some widget. + """ + try: + getattr(self.qctrl, target_widget).setEnabled(status) + if info_message is not None: + getattr(self.qctrl, target_widget).setToolTip(info_message) + except: + pass def on_add_to_curated_button_clicked(self): ''' - Defines what happens when the button is clicked. + Defines what happens when the "Move to curated dataset folder" button is clicked. ''' if self.app.cur_selected_path == str(self.app.train_data_path): message_text = "Image is already in the \'Curated data\' folder and should not be changed again" @@ -61,29 +224,35 @@ def on_add_to_curated_button_clicked(self): return # take the name of the currently selected layer (by the user) - cur_seg_selected = self.viewer.layers.selection.active.name + seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in cur_seg_selected: - message_text = "Please select the segmenation you wish to save from the layer list" + if '_seg' not in seg_name_to_save: + message_text = ( + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + ) _ = self.create_warning_box(message_text, message_title="Warning") return - seg = self.viewer.layers[cur_seg_selected].data # Move original image self.app.move_images(self.app.train_data_path) - + # Save the (changed) seg - self.app.save_image(self.app.train_data_path, cur_seg_selected+'.tiff', seg) + seg = self.viewer.layers[seg_name_to_save].data + contours = Compute4Mask.get_contours(seg[0]) + seg[1] = Compute4Mask.add_contour(seg[1], seg[0], contours) + self.app.save_image(self.app.train_data_path, seg_name_to_save+'.tiff', seg) # We remove seg from the current directory if it exists (both eval and inprogr allowed) - self.app.delete_images(self.app.seg_filepaths) + self.app.delete_images(self.seg_files) # TODO Create the Archive folder for the rest? Or move them as well? + self.viewer.close() self.close() def on_add_to_inprogress_button_clicked(self): ''' - Defines what happens when the button is clicked. + Defines what happens when the "Move to curation in progress folder" button is clicked. ''' # TODO: Do we allow this? What if they moved it by mistake? User can always manually move from their folders?) if self.app.cur_selected_path == str(self.app.train_data_path): @@ -92,18 +261,24 @@ def on_add_to_inprogress_button_clicked(self): return # take the name of the currently selected layer (by the user) - cur_seg_selected = self.viewer.layers.selection.active.name + seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in cur_seg_selected: - message_text = "Please select the segmenation you wish to save from the layer list" + if '_seg' not in seg_name_to_save: + message_text = ( + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + ) _ = self.create_warning_box(message_text, message_title="Warning") return # Move original image self.app.move_images(self.app.inprogr_data_path, move_segs=True) - # Save the (changed) seg - this will overwrite existing seg if seg name hasn't been changed in viewer - seg = self.viewer.layers[cur_seg_selected].data - self.app.save_image(self.app.inprogr_data_path, cur_seg_selected+'.tiff', seg) + seg = self.viewer.layers[seg_name_to_save].data + contours = Compute4Mask.get_contours(seg[0]) + seg[1] = Compute4Mask.add_contour(seg[1], seg[0], contours) + self.app.save_image(self.app.inprogr_data_path, seg_name_to_save+'.tiff', seg) - self.close() \ No newline at end of file + self.viewer.close() + self.close() + \ No newline at end of file diff --git a/src/client/dcp_client/gui/welcome_window.py b/src/client/dcp_client/gui/welcome_window.py index f32d5a7..9d4b536 100644 --- a/src/client/dcp_client/gui/welcome_window.py +++ b/src/client/dcp_client/gui/welcome_window.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from PyQt5.QtWidgets import QPushButton, QVBoxLayout, QHBoxLayout, QLabel, QFileDialog, QLineEdit -from PyQt5.QtCore import Qt +from qtpy.QtWidgets import QPushButton, QVBoxLayout, QHBoxLayout, QLabel, QFileDialog, QLineEdit +from qtpy.QtCore import Qt from dcp_client.gui.main_window import MainWindow from dcp_client.gui._my_widget import MyWidget @@ -43,8 +43,14 @@ def __init__(self, app: Application): self.text_layout.addWidget(train_label) self.val_textbox = QLineEdit(self) + self.val_textbox.textEdited.connect(lambda x: self.on_text_changed(self.val_textbox, "eval", x)) + self.inprogr_textbox = QLineEdit(self) + self.inprogr_textbox.textEdited.connect(lambda x: self.on_text_changed(self.inprogr_textbox, "inprogress", x)) + self.train_textbox = QLineEdit(self) + self.train_textbox.textEdited.connect(lambda x: self.on_text_changed(self.train_textbox, "train", x)) + self.path_layout.addWidget(self.val_textbox) self.path_layout.addWidget(self.inprogr_textbox) self.path_layout.addWidget(self.train_textbox) @@ -107,6 +113,21 @@ def browse_train_clicked(self): self.app.train_data_path = fd.selectedFiles()[0] self.train_textbox.setText(self.app.train_data_path) + def on_text_changed(self, field_obj, field_name, text): + ''' + Update data paths based on text changes in input fields. + Used for copying paths in the welcome window. + ''' + + if field_name == "train": + self.app.train_data_path = text + elif field_name == "eval": + self.app.eval_data_path = text + elif field_name == "inprogress": + self.app.inprogr_data_path = text + field_obj.setText(text) + + def browse_inprogr_clicked(self): ''' @@ -122,15 +143,20 @@ def browse_inprogr_clicked(self): def start_main(self): ''' - Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen. + Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique. ''' - if self.app.train_data_path and self.app.eval_data_path: + if len({self.app.inprogr_data_path, self.app.train_data_path, self.app.eval_data_path})<3: + self.message_text = "All directory names must be distinct." + _ = self.create_warning_box(self.message_text, message_title="Warning") + + + elif self.app.train_data_path and self.app.eval_data_path: self.hide() self.mw = MainWindow(self.app) else: - message_text = "You need to specify a folder both for your uncurated and curated dataset (even if the curated folder is currently empty). Please go back and select folders for both." - _ = self.create_warning_box(message_text, message_title="Warning") + self.message_text = "You need to specify a folder both for your uncurated and curated dataset (even if the curated folder is currently empty). Please go back and select folders for both." + _ = self.create_warning_box(self.message_text, message_title="Warning") def start_upload_and_main(self): ''' diff --git a/src/client/dcp_client/utils/fsimagestorage.py b/src/client/dcp_client/utils/fsimagestorage.py index d4b10e2..98af9af 100644 --- a/src/client/dcp_client/utils/fsimagestorage.py +++ b/src/client/dcp_client/utils/fsimagestorage.py @@ -10,6 +10,7 @@ def load_image(self, from_directory, cur_selected_img): return imread(os.path.join(from_directory, cur_selected_img)) def move_image(self, from_directory, to_directory, cur_selected_img): + print(f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}") os.replace(os.path.join(from_directory, cur_selected_img), os.path.join(to_directory, cur_selected_img)) def save_image(self, to_directory, cur_selected_img, img): diff --git a/src/client/dcp_client/utils/utils.py b/src/client/dcp_client/utils/utils.py index 0560d8c..51dc71a 100644 --- a/src/client/dcp_client/utils/utils.py +++ b/src/client/dcp_client/utils/utils.py @@ -1,6 +1,11 @@ from PyQt5.QtWidgets import QFileIconProvider from PyQt5.QtCore import QSize from PyQt5.QtGui import QPixmap, QIcon +import numpy as np +from skimage.feature import canny +from skimage.morphology import closing, square +from skimage.measure import find_contours +from skimage.draw import polygon_perimeter from pathlib import Path, PurePath import json @@ -50,3 +55,127 @@ def get_path_parent(filepath): return str(Path(filepath).parent) def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) +def check_equal_arrays(array1, array2): + return np.array_equal(array1, array2) + +class Compute4Mask: + + @staticmethod + def get_contours(instance_mask): + ''' + Find contours of objects in the instance mask. + This function is used to identify the contours of the objects to prevent + the problem of the merged objects in napari window (mask). + + Parameters: + - instance_mask (numpy.ndarray): The instance mask array. + + Returns: + - contour_mask (numpy.ndarray): A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background. + ''' + instance_ids = Compute4Mask.get_unique_objects(instance_mask) # get object instance labels ignoring background + contour_mask= np.zeros_like(instance_mask) + for instance_id in instance_ids: + # get a binary mask only of object + single_obj_mask = np.zeros_like(instance_mask) + single_obj_mask[instance_mask==instance_id] = 1 + # compute contours for mask + contours = find_contours(single_obj_mask, 0.8) + # sometimes little dots appeas as additional contours so remove these + if len(contours)>1: + contour_sizes = [contour.shape[0] for contour in contours] + contour = contours[contour_sizes.index(max(contour_sizes))].astype(int) + else: contour = contours[0] + # and draw onto contours mask + rr, cc = polygon_perimeter(contour[:, 0], contour[:, 1], contour_mask.shape) + contour_mask[rr, cc] = instance_id + return contour_mask + + @staticmethod + def add_contour(labels_mask, instance_mask, contours_mask): + ''' + Add contours of objects to the labels mask. + + Parameters: + - labels_mask (numpy.ndarray): The class mask array without the contour pixels annotated. + - instance_mask (numpy.ndarray): The instance mask array. + - contours_mask (numpy.ndarray): The contours mask array, where each contour holds the instance_id. + + Returns: + - labels_mask (numpy.ndarray): The updated class mask including contours. + ''' + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + where_instances = np.where(instance_mask==instance_id) + class_vals = Compute4Mask.get_unique_objects(labels_mask[where_instances]) + if len(class_vals)==0: continue + else: + labels_mask[np.where(contours_mask==instance_id)] = class_vals[-1] + return labels_mask + + + @staticmethod + def compute_new_instance_mask(labels_mask, instance_mask): + ''' + Given an updated labels mask, update also the instance mask accordingly. So far the user can only remove an entire object in the labels mask view. + Therefore the instance mask can only change by entirely removing an object. + + Parameters: + - labels_mask (numpy.ndarray): The labels mask array, with changes made by the user. + - instance_mask (numpy.ndarray): The existing instance mask, which needs to be updated. + Returns: + - instance_mask (numpy.ndarray): The updated instance mask. + ''' + instance_ids = Compute4Mask.get_unique_objects(instance_mask) + for instance_id in instance_ids: + unique_items_in_class_mask = list(np.unique(labels_mask[instance_mask==instance_id])) + if len(unique_items_in_class_mask)==1 and unique_items_in_class_mask[0]==0: + instance_mask[instance_mask==instance_id] = 0 + return instance_mask + + + @staticmethod + def compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances): + ''' + Given the existing labels mask, the updated instance mask is used to update the labels mask. + + Parameters: + - labels_mask (numpy.ndarray): The existing labels mask, which needs to be updated. + - instance_mask (numpy.ndarray): The instance mask array, with changes made by the user. + - original_instance_mask (numpy.ndarray): The instance mask array, before the changes made by the user. + - old_instances (List): A list of the instance label ids in original_instance_mask. + Returns: + - new_labels_mask (numpy.ndarray): The new labels mask, with updated changes according to those the user has made in the instance mask. + ''' + new_labels_mask = np.zeros_like(labels_mask) + for instance_id in np.unique(instance_mask): + where_instance = np.where(instance_mask==instance_id) + # if the label is background skip + if instance_id==0: continue + # if the label is a newly added object, add with the same id to the labels mask + # this is an indication to the user that this object needs to be assigned a class + elif instance_id not in old_instances: + new_labels_mask[where_instance] = instance_id + else: + where_instance_orig = np.where(original_instance_mask==instance_id) + # if the locations of the instance haven't changed, means object wasn't changed, do nothing + num_classes = np.unique(labels_mask[where_instance]) + # if area was erased and object retains same class + if len(num_classes)==1: + new_labels_mask[where_instance] = num_classes[0] + # area was added where there is background or other class + else: + old_class_id = np.unique(labels_mask[where_instance_orig]) + #assert len(old_class_id)==1 + old_class_id = old_class_id[0] + new_labels_mask[where_instance] = old_class_id + + return new_labels_mask + + @staticmethod + def get_unique_objects(active_mask): + """ + Get unique objects from the active mask. + """ + return list(np.unique(active_mask)[1:]) + \ No newline at end of file diff --git a/src/client/test/test_compute4mask.py b/src/client/test/test_compute4mask.py new file mode 100644 index 0000000..b678ebc --- /dev/null +++ b/src/client/test/test_compute4mask.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest +from dcp_client.utils.utils import Compute4Mask + +@pytest.fixture +def sample_data(): + instance_mask = np.array([[0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [2, 2, 0, 0], + [0, 0, 3, 3]]) + labels_mask = np.array([[0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [2, 2, 0, 0], + [0, 0, 1, 1]]) + return instance_mask, labels_mask + +def test_get_unique_objects(sample_data): + instance_mask, _ = sample_data + unique_objects = Compute4Mask.get_unique_objects(instance_mask) + assert unique_objects == [1, 2, 3] + +def test_get_contours(sample_data): + instance_mask, _ = sample_data + contour_mask = Compute4Mask.get_contours(instance_mask) + assert contour_mask.shape == instance_mask.shape + assert contour_mask[0,1] == 1 # randomly check a contour location is present + +def test_add_contour(sample_data): + instance_mask, labels_mask = sample_data + contours_mask = Compute4Mask.get_contours(instance_mask) + labels_mask_wo_contour = np.copy(labels_mask) + labels_mask_wo_contour[contours_mask!=0] = 0 + updated_labels_mask = Compute4Mask.add_contour(labels_mask_wo_contour, instance_mask, contours_mask) + assert np.array_equal(updated_labels_mask[:3], labels_mask[:3]) + +def test_compute_new_instance_mask(sample_data): + instance_mask, labels_mask = sample_data + labels_mask[labels_mask==1] = 0 + updated_instance_mask = Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask) + assert list(np.unique(updated_instance_mask))==[0,2] + +def test_compute_new_labels_mask_obj_added(sample_data): + instance_mask, labels_mask = sample_data + original_instance_mask = np.copy(instance_mask) + instance_mask[0, 0] = 4 + old_instances = Compute4Mask.get_unique_objects(original_instance_mask) + new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) + assert new_labels_mask[0,0]==4 + +def test_compute_new_labels_mask_obj_updated(sample_data): + instance_mask, labels_mask = sample_data + original_instance_mask = np.copy(instance_mask) + instance_mask[0] = 0 + old_instances = Compute4Mask.get_unique_objects(original_instance_mask) + new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances) + assert np.all(new_labels_mask[0])==0 + assert np.array_equal(new_labels_mask[1:], labels_mask[1:]) diff --git a/src/client/test/test_main_window.py b/src/client/test/test_main_window.py index ca80c6d..ec56459 100644 --- a/src/client/test/test_main_window.py +++ b/src/client/test/test_main_window.py @@ -139,7 +139,7 @@ def test_launch_napari_button_click_without_selection(qtbot, app): def test_launch_napari_button_click(qtbot, app): settings.accepted_types = setup_global_variable - # Simulate selection of an image to view before clivking on view button + # Simulate selection of an image to view before clicking on view button index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click @@ -154,6 +154,9 @@ def test_launch_napari_button_click(qtbot, app): assert app.nap_win.isVisible() + + + @pytest.fixture(scope='session', autouse=True) def cleanup_files(request): # This code runs after all tests from all files have completed diff --git a/src/client/test/test_napari_window.py b/src/client/test/test_napari_window.py new file mode 100644 index 0000000..61a7839 --- /dev/null +++ b/src/client/test/test_napari_window.py @@ -0,0 +1,108 @@ +import os + +from skimage import data +from skimage.io import imsave +import numpy as np + +import pytest +from dcp_client.app import Application +from dcp_client.gui.napari_window import NapariWindow + +from dcp_client.app import Application +from dcp_client.utils.bentoml_model import BentomlModel +from dcp_client.utils.fsimagestorage import FilesystemImageStorage +from dcp_client.utils.sync_src_dst import DataRSync +from dcp_client.utils import settings + +# @pytest.fixture +# def napari_app(): +# app = Application([]) +# napari_app = QtViewer() +# yield napari_app +# napari_app.close() + +@pytest.fixture +def napari_window(qtbot): + + img1 = data.astronaut() + img2 = data.coffee() + img3 = data.cat() + + if not os.path.exists('train_data_path'): + os.mkdir('train_data_path') + imsave('train_data_path/astronaut.png', img1) + + if not os.path.exists('in_prog'): + os.mkdir('in_prog') + imsave('in_prog/coffee.png', img2) + + if not os.path.exists('eval_data_path'): + os.mkdir('eval_data_path') + imsave('eval_data_path/cat.png', img3) + imsave('eval_data_path/cat_seg.png', img3) + + imsave('eval_data_path/cat_test.png', img3) + imsave('eval_data_path/cat_test_seg.png', img3) + + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') + application = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + os.path.join(os.getcwd(), 'eval_data_path'), + os.path.join(os.getcwd(), 'train_data_path'), + os.path.join(os.getcwd(), 'in_prog') + ) + + application.cur_selected_img = 'cat_test.png' + application.cur_selected_path = application.eval_data_path + + widget = NapariWindow(application) + qtbot.addWidget(widget) + yield widget + widget.close() + +def test_napari_window_initialization(napari_window): + assert napari_window.viewer is not None + assert napari_window.qctrl is not None + assert napari_window.mask_choice_dropdown is not None + +def test_on_add_to_curated_button_clicked(napari_window, monkeypatch): + # Mock the create_warning_box method + def mock_create_warning_box(message_text, message_title): + return None + + monkeypatch.setattr(napari_window, 'create_warning_box', mock_create_warning_box) + + # assert napari_window.app.cur_selected_path == 'eval_data_path' + + napari_window.app.cur_selected_img = 'cat_test.png' + napari_window.app.cur_selected_path = napari_window.app.eval_data_path + + napari_window.viewer.layers.selection.active.name = 'cat_test_seg' + + # Simulate the button click + napari_window.on_add_to_curated_button_clicked() + + assert os.path.exists('train_data_path/cat_test_seg.tiff') + assert os.path.exists('train_data_path/cat_test.png') + assert not os.path.exists('eval_data_path/cat_test.png') + +# @pytest.fixture(scope='session', autouse=True) +# def cleanup_files(request): +# # This code runs after all tests from all files have completed +# yield +# # Clean up +# for fname in os.listdir('train_data_path'): +# os.remove(os.path.join('train_data_path', fname)) +# os.rmdir('train_data_path') + +# for fname in os.listdir('in_prog'): +# os.remove(os.path.join('in_prog', fname)) +# os.rmdir('in_prog') + +# for fname in os.listdir('eval_data_path'): +# os.remove(os.path.join('eval_data_path', fname)) +# os.rmdir('eval_data_path') \ No newline at end of file diff --git a/src/client/test/test_welcome_window.py b/src/client/test/test_welcome_window.py index 204a1cc..8b453ae 100644 --- a/src/client/test/test_welcome_window.py +++ b/src/client/test/test_welcome_window.py @@ -3,6 +3,7 @@ sys.path.append('../') from PyQt5.QtCore import Qt +from PyQt5.QtWidgets import QMessageBox from dcp_client.gui.welcome_window import WelcomeWindow from dcp_client.app import Application @@ -21,10 +22,10 @@ def app(qtbot): rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) # Create an instance of WelcomeWindow - #q_app = QApplication([]) + # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() def test_welcome_window_initialization(app): @@ -32,7 +33,35 @@ def test_welcome_window_initialization(app): assert app.val_textbox.text() == "" assert app.inprogr_textbox.text() == "" assert app.train_textbox.text() == "" - + +def test_warning_for_same_paths(qtbot, app, monkeypatch): + app.app.eval_data_path = "/same/path" + app.app.train_data_path = "/same/path" + app.app.inprogr_data_path = "/same/path" + + # Define a custom exec method that always returns QMessageBox.Ok + def custom_exec(self): + return QMessageBox.Ok + + monkeypatch.setattr(QMessageBox, 'exec', custom_exec) + qtbot.mouseClick(app.start_button, Qt.LeftButton) + + assert app.create_warning_box + assert app.message_text == "All directory names must be distinct." + +def test_on_text_changed(qtbot, app): + app.app.train_data_path = "/initial/train/path" + app.app.eval_data_path = "/initial/eval/path" + app.app.inprogr_data_path = "/initial/inprogress/path" + + app.on_text_changed(field_obj=app.train_textbox, field_name="train", text="/new/train/path") + assert app.app.train_data_path == "/new/train/path" + + app.on_text_changed(field_obj=app.val_textbox, field_name="eval", text="/new/eval/path") + assert app.app.eval_data_path == "/new/eval/path" + + app.on_text_changed(field_obj=app.inprogr_textbox, field_name="inprogress", text="/new/inprogress/path") + assert app.app.inprogr_data_path == "/new/inprogress/path" '''' # TODO wait for github respose def test_browse_eval_clicked(qtbot, app, monkeypatch): @@ -90,6 +119,10 @@ def test_start_main_not_selected(qtbot, app): def test_start_main(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable + + # app.app.cur_selected_path = app.app.eval_data_path + # app.app.cur_selected_img = 'cat.png' + # Set some paths for testing app.app.eval_data_path = "/path/to/eval" app.app.train_data_path = "/path/to/train"