Skip to content

Commit

Permalink
Merge pull request #62 from HelmholtzAI-Consultants-Munich/client-add…
Browse files Browse the repository at this point in the history
…-in-prog

Client add in prog
  • Loading branch information
christinab12 authored Jan 10, 2024
2 parents 8df82bb + abeb4db commit 32fc3cc
Show file tree
Hide file tree
Showing 14 changed files with 641 additions and 167 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ jobs:
pip install coverage
pip install -e ".[testing]"
working-directory: src/client

- name: Install server dependencies (for communication tests)
run: |
pip install -e ".[testing]"
working-directory: src/server

- name: Test with pytest
run: |
Expand Down
57 changes: 37 additions & 20 deletions src/client/dcp_client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,44 @@ def __init__(
self.seg_filepaths = []

def upload_data_to_server(self):
"""
Uploads the train and eval data to the server.
"""
self.syncer.first_sync(path=self.train_data_path)
self.syncer.first_sync(path=self.eval_data_path)

def try_server_connection(self):
"""
Checks if the ml model is connected to server and attempts to connect if not.
"""
connection_success = self.ml_model.connect(ip=self.server_ip, port=self.server_port)
return connection_success

def run_train(self):
""" Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path """
if not self.ml_model.is_connected:
connection_success = self.ml_model.connect(ip=self.server_ip, port=self.server_port)
if not connection_success: return "Connection could not be established. Please check if the server is running and try again."
if not self.ml_model.is_connected and not self.try_server_connection():
message_title = "Warning"
message_text = "Connection could not be established. Please check if the server is running and try again."
return message_text, message_title
# if syncer.host name is None then local machine is used to train
message_title = "Information"
if self.syncer.host_name=="local":
return self.ml_model.run_train(self.train_data_path)
message_text = self.ml_model.run_train(self.train_data_path)
else:
srv_relative_path = self.syncer.sync(src='client', dst='server', path=self.train_data_path)
return self.ml_model.run_train(srv_relative_path)

message_text = self.ml_model.run_train(srv_relative_path)
if message_text is None:
message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider."
message_title = "Error"
return message_text, message_title

def run_inference(self):
""" Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path """
if not self.ml_model.is_connected:
connection_success = self.ml_model.connect(ip=self.server_ip, port=self.server_port)
if not connection_success:
message_text = "Connection could not be established. Please check if the server is running and try again."
return message_text, "Warning"

if not self.ml_model.is_connected and not self.try_server_connection():
message_title = "Warning"
message_text = "Connection could not be established. Please check if the server is running and try again."
return message_text, message_title

if self.syncer.host_name=="local":
# model serving directly from local
list_of_files_not_suported = self.ml_model.run_inference(self.eval_data_path)
Expand All @@ -98,16 +112,19 @@ def run_inference(self):
list_of_files_not_suported = self.ml_model.run_inference(srv_relative_path)
# sync data so that client gets new masks
_ = self.syncer.sync(src='server', dst='client', path=self.eval_data_path)

# check if serving could not be performed for some files and prepare message
list_of_files_not_suported = list(list_of_files_not_suported)
if len(list_of_files_not_suported) > 0:
message_text = "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \
Currently supported image file formats are: " + ", ".join(settings.accepted_types)+ ". The files that were not supported are: " + ", ".join(list_of_files_not_suported)
message_title = "Warning"
if list_of_files_not_suported is None:
message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider."
message_title = "Error"
else:
message_text = "Success! Masks generated for all images"
message_title="Success"
list_of_files_not_suported = list(list_of_files_not_suported)
if len(list_of_files_not_suported) > 0:
message_text = "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \
Currently supported image file formats are: " + ", ".join(settings.accepted_types)+ ". The files that were not supported are: " + ", ".join(list_of_files_not_suported)
message_title = "Warning"
else:
message_text = "Success! Masks generated for all images"
message_title = "Information"
return message_text, message_title

def load_image(self, image_name=None):
Expand Down
35 changes: 35 additions & 0 deletions src/client/dcp_client/gui/_my_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from PyQt5.QtWidgets import QWidget, QMessageBox
from PyQt5.QtCore import QTimer

class MyWidget(QWidget):

msg = None
sim = False # will be used for testing to simulate user click

def create_warning_box(self, message_text: str=" ", message_title: str="Information", add_cancel_btn: bool=False, custom_dialog=None) -> None:
#setup box
if custom_dialog is not None: self.msg = custom_dialog
else: self.msg = QMessageBox()

if message_title=="Warning":
message_type = QMessageBox.Warning
elif message_title=="Error":
message_type = QMessageBox.Critical
else:
message_type = QMessageBox.Information
self.msg.setIcon(message_type)
self.msg.setText(message_text)
self.msg.setWindowTitle(message_title)
# if specified add a cancel button else only an ok
if add_cancel_btn:
self.msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
# simulate button click if specified - workaround used for testing
if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked)
else:
self.msg.setStandardButtons(QMessageBox.Ok)
# simulate button click if specified - workaround used for testing
if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked)
# return if user clicks Ok and False otherwise
usr_response = self.msg.exec()
if usr_response == QMessageBox.Ok: return True
else: return False
131 changes: 110 additions & 21 deletions src/client/dcp_client/gui/main_window.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,68 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from PyQt5.QtWidgets import QWidget, QPushButton, QVBoxLayout, QFileSystemModel, QHBoxLayout, QLabel, QTreeView
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QPushButton, QVBoxLayout, QFileSystemModel, QHBoxLayout, QLabel, QTreeView, QProgressBar
from PyQt5.QtCore import Qt, QThread, pyqtSignal

from dcp_client.utils import settings
from dcp_client.utils.utils import IconProvider, create_warning_box
from dcp_client.utils.utils import IconProvider

from dcp_client.gui.napari_window import NapariWindow
from dcp_client.gui._my_widget import MyWidget

if TYPE_CHECKING:
from dcp_client.app import Application

class WorkerThread(QThread):
''' Worker thread for displaying Pulse ProgressBar during model serving '''
task_finished = pyqtSignal(tuple)
def __init__(self, app: Application, task: str = None, parent = None,):
super().__init__(parent)
self.app = app
self.task = task

def run(self):
''' Once run_inference the tuple of (message_text, message_title) will be returned to on_finished'''
try:
if self.task == 'inference':
message_text, message_title = self.app.run_inference()
elif self.task == 'train':
message_text, message_title = self.app.run_train()
else:
message_text, message_title = "Unknown task", "Error"

except Exception as e:
# Log any exceptions that might occur in the thread
message_text, message_title = f"Exception in WorkerThread: {e}", "Error"

class MainWindow(QWidget):
'''Main Window Widget object.
self.task_finished.emit((message_text, message_title))

class MainWindow(MyWidget):
'''
Main Window Widget object.
Opens the main window of the app where selected images in both directories are listed.
User can view the images, train the mdoel to get the labels, and visualise the result.
:param eval_data_path: Chosen path to images without labeles, selected by the user in the WelcomeWindow
:type eval_data_path: string
:param train_data_path: Chosen path to images with labeles, selected by the user in the WelcomeWindow
:type train_data_path: string
'''
'''

def __init__(self, app: Application):
super().__init__()
self.app = app
self.title = "Data Overview"
self.worker_thread = None
self.main_window()

def main_window(self):
'''
Sets up the GUI
'''
self.setWindowTitle(self.title)
#self.resize(1000, 1500)
self.main_layout = QHBoxLayout()
main_layout = QVBoxLayout()
dir_layout = QHBoxLayout()

self.uncurated_layout = QVBoxLayout()
self.inprogress_layout = QVBoxLayout()
Expand Down Expand Up @@ -61,7 +92,7 @@ def main_window(self):
self.inference_button.clicked.connect(self.on_run_inference_button_clicked) # add selected image
self.uncurated_layout.addWidget(self.inference_button, alignment=Qt.AlignCenter)

self.main_layout.addLayout(self.uncurated_layout)
dir_layout.addLayout(self.uncurated_layout)

# In progress layout
self.inprogr_dir_layout = QVBoxLayout()
Expand All @@ -87,7 +118,7 @@ def main_window(self):
self.launch_nap_button.clicked.connect(self.on_launch_napari_button_clicked) # add selected image
self.inprogress_layout.addWidget(self.launch_nap_button, alignment=Qt.AlignCenter)

self.main_layout.addLayout(self.inprogress_layout)
dir_layout.addLayout(self.inprogress_layout)

# Curated layout
self.train_dir_layout = QVBoxLayout()
Expand All @@ -112,55 +143,113 @@ def main_window(self):
self.train_button = QPushButton("Train Model", self)
self.train_button.clicked.connect(self.on_train_button_clicked) # add selected image
self.curated_layout.addWidget(self.train_button, alignment=Qt.AlignCenter)
dir_layout.addLayout(self.curated_layout)

main_layout.addLayout(dir_layout)

self.main_layout.addLayout(self.curated_layout)
# add progress bar
progress_layout = QHBoxLayout()
progress_layout.addStretch(1)
self.progress_bar = QProgressBar(self)
self.progress_bar.setRange(0,1)
progress_layout.addWidget(self.progress_bar)
main_layout.addLayout(progress_layout)

self.setLayout(self.main_layout)
self.setLayout(main_layout)
self.show()

def on_item_train_selected(self, item):
'''
Is called once an image is selected in the 'curated dataset' folder
'''
self.app.cur_selected_img = item.data()
self.app.cur_selected_path = self.app.train_data_path

def on_item_eval_selected(self, item):
'''
Is called once an image is selected in the 'uncurated dataset' folder
'''
self.app.cur_selected_img = item.data()
self.app.cur_selected_path = self.app.eval_data_path

def on_item_inprogr_selected(self, item):
'''
Is called once an image is selected in the 'in progress' folder
'''
self.app.cur_selected_img = item.data()
self.app.cur_selected_path = self.app.inprogr_data_path

def on_train_button_clicked(self):
message_text = self.app.run_train()
_ = create_warning_box(message_text)
'''
Is called once user clicks the "Train Model" button
'''
self.train_button.setEnabled(False)
self.progress_bar.setRange(0,0)
# initialise the worker thread
self.worker_thread = WorkerThread(app=self.app, task='train')
self.worker_thread.task_finished.connect(self.on_finished)
# start the worker thread to train
self.worker_thread.start()

def on_run_inference_button_clicked(self):
message_text, message_title = self.app.run_inference()
_ = create_warning_box(message_text, message_title)
'''
Is called once user clicks the "Generate Labels" button
'''
self.inference_button.setEnabled(False)
self.progress_bar.setRange(0,0)
# initialise the worker thread
self.worker_thread = WorkerThread(app=self.app, task='inference')
self.worker_thread.task_finished.connect(self.on_finished)
# start the worker thread to run inference
self.worker_thread.start()

def on_launch_napari_button_clicked(self):
'''
Launches the napari window after the image is selected.
'''
if not self.app.cur_selected_img or '_seg.tiff' in self.app.cur_selected_img:
message_text = "Please first select an image you wish to visualise. The selected image must be an original images, not a mask."
_ = create_warning_box(message_text, message_title="Warning")
message_text = "Please first select an image you wish to visualise. The selected image must be an original image, not a mask."
_ = self.create_warning_box(message_text, message_title="Warning")
else:
self.nap_win = NapariWindow(self.app)
self.nap_win.show()

def on_finished(self, result):
'''
Is called once the worker thread emits the on finished signal
'''
# Stop the pulsation
self.progress_bar.setRange(0,1)
# Display message of result
message_text, message_title = result
_ = self.create_warning_box(message_text, message_title)
# Re-enable buttons
self.inference_button.setEnabled(True)
self.train_button.setEnabled(True)
# Delete the worker thread when it's done
self.worker_thread.quit()
self.worker_thread.wait()
self.worker_thread.deleteLater()
self.worker_thread = None # Set to None to indicate it's no longer in use


if __name__ == "__main__":
import sys
from PyQt5.QtWidgets import QApplication
from app import Application
from bentoml_model import BentomlModel
from fsimagestorage import FilesystemImageStorage
import settings
from dcp_client.app import Application
from dcp_client.utils.bentoml_model import BentomlModel
from dcp_client.utils.fsimagestorage import FilesystemImageStorage
from dcp_client.utils import settings
from dcp_client.utils.sync_src_dst import DataRSync
settings.init()
image_storage = FilesystemImageStorage()
ml_model = BentomlModel()
data_sync = DataRSync(user_name="local",
host_name="local",
server_repo_path=None)
app = QApplication(sys.argv)
app_ = Application(ml_model=ml_model,
syncer=data_sync,
image_storage=image_storage,
server_ip='0.0.0.0',
server_port=7010,
Expand Down
Loading

0 comments on commit 32fc3cc

Please sign in to comment.