diff --git a/ASO_CBCT/LinearTransform_t.tfm b/ASO_CBCT/LinearTransform_t.tfm new file mode 100644 index 0000000..de1d214 --- /dev/null +++ b/ASO_CBCT/LinearTransform_t.tfm @@ -0,0 +1,5 @@ +#Insight Transform File V1.0 +#Transform 0 +Transform: AffineTransform_double_3_3 +Parameters: 0.9763290732702393 0.15616745925611628 0.14964379491567228 -0.1421477793365056 0.9847545451025459 -0.10026213008697676 -0.16302008930489087 0.07661729941316797 0.9836433499565058 3.8520097928318386 10.915991569159537 -4.582919224464414 +FixedParameters: 0 0 0 diff --git a/AutoMatrix/AutoMatrix.py b/AutoMatrix/AutoMatrix.py index 0842385..207feac 100644 --- a/AutoMatrix/AutoMatrix.py +++ b/AutoMatrix/AutoMatrix.py @@ -468,7 +468,7 @@ def ProcessVolume(self)-> None: print("not a .nii.gz") self.UpdateTime() - if extension_scan!=".nii.gz": + if extension_scan!=".nii.gz" and extension_scan!=".nrrd": model = slicer.util.loadModel(scan) else : model = slicer.util.loadVolume(scan) @@ -616,7 +616,7 @@ def onProcessStarted(self)->None: Initialize the variables and progress bar. """ if os.path.isdir(self.ui.LineEditPatient.text): - self.nbFiles = len(self.dico_patient[".vtk"]) + len(self.dico_patient['.vtp']) + len(self.dico_patient['.stl']) + len(self.dico_patient['.off']) + len(self.dico_patient['.obj']) + len(self.dico_patient['.nii.gz']) + self.nbFiles = len(self.dico_patient[".vtk"]) + len(self.dico_patient['.vtp']) + len(self.dico_patient['.stl']) + len(self.dico_patient['.off']) + len(self.dico_patient['.obj']) + len(self.dico_patient['.nii.gz']) + len(self.dico_patient['nrrd']) else: self.nbFiles = 1 self.ui.progressBar.setValue(0) @@ -645,10 +645,10 @@ def CheckGoodEntre(self)->bool: warning_text = warning_text + "Enter file patient" + "\n" else : if self.ui.ComboBoxPatient.currentIndex==1 : #folder option - self.dico_patient=search(self.ui.LineEditPatient.text,'.vtk','.vtp','.stl','.off','.obj','.nii.gz') - if len(self.dico_patient['.vtk'])==0 and len(self.dico_patient['.vtp']) and len(self.dico_patient['.stl']) and len(self.dico_patient['.off']) and len(self.dico_patient['.obj']) and len(self.dico_patient['.nii.gz']) : + self.dico_patient=search(self.ui.LineEditPatient.text,'.vtk','.vtp','.stl','.off','.obj','.nii.gz','nrrd') + if len(self.dico_patient['.vtk'])==0 and len(self.dico_patient['.vtp']) and len(self.dico_patient['.stl']) and len(self.dico_patient['.off']) and len(self.dico_patient['.obj']) and len(self.dico_patient['.nii.gz']) and len(self.dico_patient['.nrrd']) : warning_text = warning_text + "Folder empty or wrong type of file patient" + "\n" - warning_text = warning_text + "File authorized : .vtk / .vtp / .stl / .off / .obj / .nii.gz" + "\n" + warning_text = warning_text + "File authorized : .vtk / .vtp / .stl / .off / .obj / .nii.gz / .nrrd" + "\n" elif self.ui.ComboBoxPatient.currentIndex==0 : # file option fname, extension = os.path.splitext(os.path.basename(self.ui.LineEditPatient.text)) try : @@ -656,9 +656,9 @@ def CheckGoodEntre(self)->bool: extension = extension2+extension except : print("not a .nii.gz") - if extension != ".vtk" and extension != ".vtp" and extension != ".stl" and extension != ".off" and extension != ".obj" and extension != ".nii.gz" : + if extension != ".vtk" and extension != ".vtp" and extension != ".stl" and extension != ".off" and extension != ".obj" and extension != ".nii.gz" and extension != ".nrrd": warning_text = warning_text + "Wrong type of file patient detected" + "\n" - warning_text = warning_text + "File authorized : .vtk / .vtp / .stl / .off / .obj / .nii.gz" + "\n" + warning_text = warning_text + "File authorized : .vtk / .vtp / .stl / .off / .obj / .nii.gz / .nrrd" + "\n" if self.ui.LineEditMatrix.text=="": diff --git a/AutoMatrix/Method/General_tools.py b/AutoMatrix/Method/General_tools.py index efd0c37..4c419ab 100644 --- a/AutoMatrix/Method/General_tools.py +++ b/AutoMatrix/Method/General_tools.py @@ -45,7 +45,7 @@ def GetPatients(file_path:str,matrix_path:str): files = [] if Path(file_path).is_dir(): - files_original = search(file_path,'.vtk','.vtp','.stl','.off','.obj','.nii.gz') + files_original = search(file_path,'.vtk','.vtp','.stl','.off','.obj','.nii.gz','.nrrd') files = [] for i in range(len(files_original['.vtk'])): files.append(files_original['.vtk'][i]) @@ -64,6 +64,9 @@ def GetPatients(file_path:str,matrix_path:str): for i in range(len(files_original['.nii.gz'])): files.append(files_original['.nii.gz'][i]) + + for i in range(len(files_original['.nrrd'])): + files.append(files_original['.nrrd'][i]) for i in range(len(files)): file = files[i] @@ -87,7 +90,7 @@ def GetPatients(file_path:str,matrix_path:str): except : print("not a .nii.gz") - if extension ==".vtk" or extension ==".vtp" or extension ==".stl" or extension ==".off" or extension ==".obj" or extension==".nii.gz" : + if extension ==".vtk" or extension ==".vtp" or extension ==".stl" or extension ==".off" or extension ==".obj" or extension==".nii.gz" or extension==".nrrd" : files = [file_path] file_pat = os.path.basename(file_path).split('_Seg')[0].split('_seg')[0].split('_Scan')[0].split('_scan')[0].split('_Or')[0].split('_OR')[0].split('_MAND')[0].split('_MD')[0].split('_MAX')[0].split('_MX')[0].split('_CB')[0].split('_lm')[0].split('_T2')[0].split('_T1')[0].split('_Cl')[0].split('.')[0].replace('.','') for i in range(50): diff --git a/CMakeLists.txt b/CMakeLists.txt index 39bea8c..c3d08db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,6 +35,8 @@ add_subdirectory(AREG) add_subdirectory(AutoMatrix) add_subdirectory(AutoCrop3D) +add_subdirectory(MRI2CBCT) +add_subdirectory(MRI2CBCT_CLI) ## NEXT_MODULE #----------------------------------------------------------------------------- diff --git a/MRI2CBCT/CMakeLists.txt b/MRI2CBCT/CMakeLists.txt new file mode 100644 index 0000000..068a372 --- /dev/null +++ b/MRI2CBCT/CMakeLists.txt @@ -0,0 +1,37 @@ +#----------------------------------------------------------------------------- +set(MODULE_NAME MRI2CBCT) + +#----------------------------------------------------------------------------- +set(MODULE_PYTHON_SCRIPTS + ${MODULE_NAME}.py + utils/Method.py + utils/Preprocess_CBCT_MRI.py + utils/Preprocess_CBCT.py + utils/Preprocess_MRI.py + utils/Reg_MRI2CBCT.py + utils/utils_CBCT.py + ) + +set(MODULE_PYTHON_RESOURCES + Resources/Icons/${MODULE_NAME}.png + Resources/UI/${MODULE_NAME}.ui + ) + +#----------------------------------------------------------------------------- +slicerMacroBuildScriptedModule( + NAME ${MODULE_NAME} + SCRIPTS ${MODULE_PYTHON_SCRIPTS} + RESOURCES ${MODULE_PYTHON_RESOURCES} + WITH_GENERIC_TESTS + ) + +#----------------------------------------------------------------------------- +if(BUILD_TESTING) + + # Register the unittest subclass in the main script as a ctest. + # Note that the test will also be available at runtime. + slicer_add_python_unittest(SCRIPT ${MODULE_NAME}.py) + + # Additional build-time testing + add_subdirectory(Testing) +endif() diff --git a/MRI2CBCT/MRI2CBCT.py b/MRI2CBCT/MRI2CBCT.py new file mode 100644 index 0000000..2c23d99 --- /dev/null +++ b/MRI2CBCT/MRI2CBCT.py @@ -0,0 +1,1392 @@ +import logging +import os +from typing import Annotated, Optional +from qt import QApplication, QWidget, QTableWidget, QDoubleSpinBox, QTableWidgetItem, QHeaderView,QSpinBox, QVBoxLayout, QLabel, QSizePolicy, QCheckBox, QFileDialog,QMessageBox, QApplication, QProgressDialog +import qt +from utils.Preprocess_CBCT import Process_CBCT +from utils.Preprocess_MRI import Process_MRI +from utils.Preprocess_CBCT_MRI import Preprocess_CBCT_MRI +from utils.Reg_MRI2CBCT import Registration_MRI2CBCT +import time + +import vtk +import shutil +import urllib +import zipfile + +import slicer +from functools import partial +from slicer.i18n import tr as _ +from slicer.i18n import translate +from slicer.ScriptedLoadableModule import * +from slicer.util import VTKObservationMixin +from slicer.parameterNodeWrapper import ( + parameterNodeWrapper, + WithinRange, +) + +from slicer import vtkMRMLScalarVolumeNode + + +# +# MRI2CBCT +# + + +class MRI2CBCT(ScriptedLoadableModule): + """Uses ScriptedLoadableModule base class, available at: + https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py + """ + + def __init__(self, parent): + ScriptedLoadableModule.__init__(self, parent) + self.parent.title = _("MRI2CBCT") # TODO: make this more human readable by adding spaces + # TODO: set categories (folders where the module shows up in the module selector) + self.parent.categories = ["Automated Dental Tools"] + self.parent.dependencies = [] # TODO: add here list of module names that this module requires + self.parent.contributors = ["John Doe (AnyWare Corp.)"] # TODO: replace with "Firstname Lastname (Organization)" + # TODO: update with short description of the module and a link to online module documentation + # _() function marks text as translatable to other languages + self.parent.helpText = _(""" +This is an example of scripted loadable module bundled in an extension. +See more information in module documentation. +""") + # TODO: replace with organization, grant and thanks + self.parent.acknowledgementText = _(""" +This file was originally developed by Jean-Christophe Fillion-Robin, Kitware Inc., Andras Lasso, PerkLab, +and Steve Pieper, Isomics, Inc. and was partially funded by NIH grant 3P41RR013218-12S1. +""") + + # Additional initialization step after application startup is complete + slicer.app.connect("startupCompleted()", registerSampleData) + + +# +# Register sample data sets in Sample Data module +# + + +def registerSampleData(): + """Add data sets to Sample Data module.""" + # It is always recommended to provide sample data for users to make it easy to try the module, + # but if no sample data is available then this method (and associated startupCompeted signal connection) can be removed. + + import SampleData + + iconsPath = os.path.join(os.path.dirname(__file__), "Resources/Icons") + + # To ensure that the source code repository remains small (can be downloaded and installed quickly) + # it is recommended to store data sets that are larger than a few MB in a Github release. + + # MRI2CBCT1 + SampleData.SampleDataLogic.registerCustomSampleDataSource( + # Category and sample name displayed in Sample Data module + category="MRI2CBCT", + sampleName="MRI2CBCT1", + # Thumbnail should have size of approximately 260x280 pixels and stored in Resources/Icons folder. + # It can be created by Screen Capture module, "Capture all views" option enabled, "Number of images" set to "Single". + thumbnailFileName=os.path.join(iconsPath, "MRI2CBCT1.png"), + # Download URL and target file name + uris="https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/998cb522173839c78657f4bc0ea907cea09fd04e44601f17c82ea27927937b95", + fileNames="MRI2CBCT1.nrrd", + # Checksum to ensure file integrity. Can be computed by this command: + # import hashlib; print(hashlib.sha256(open(filename, "rb").read()).hexdigest()) + checksums="SHA256:998cb522173839c78657f4bc0ea907cea09fd04e44601f17c82ea27927937b95", + # This node name will be used when the data set is loaded + nodeNames="MRI2CBCT1", + ) + + # MRI2CBCT2 + SampleData.SampleDataLogic.registerCustomSampleDataSource( + # Category and sample name displayed in Sample Data module + category="MRI2CBCT", + sampleName="MRI2CBCT2", + thumbnailFileName=os.path.join(iconsPath, "MRI2CBCT2.png"), + # Download URL and target file name + uris="https://github.com/Slicer/SlicerTestingData/releases/download/SHA256/1a64f3f422eb3d1c9b093d1a18da354b13bcf307907c66317e2463ee530b7a97", + fileNames="MRI2CBCT2.nrrd", + checksums="SHA256:1a64f3f422eb3d1c9b093d1a18da354b13bcf307907c66317e2463ee530b7a97", + # This node name will be used when the data set is loaded + nodeNames="MRI2CBCT2", + ) + + +# +# MRI2CBCTParameterNode +# + + +@parameterNodeWrapper +class MRI2CBCTParameterNode: + """ + The parameters needed by module. + + inputVolume - The volume to threshold. + imageThreshold - The value at which to threshold the input volume. + invertThreshold - If true, will invert the threshold. + thresholdedVolume - The output volume that will contain the thresholded volume. + invertedVolume - The output volume that will contain the inverted thresholded volume. + """ + + inputVolume: vtkMRMLScalarVolumeNode + imageThreshold: Annotated[float, WithinRange(-100, 500)] = 100 + invertThreshold: bool = False + thresholdedVolume: vtkMRMLScalarVolumeNode + invertedVolume: vtkMRMLScalarVolumeNode + + +# +# MRI2CBCTWidget +# + + +class MRI2CBCTWidget(ScriptedLoadableModuleWidget, VTKObservationMixin): + """Uses ScriptedLoadableModuleWidget base class, available at: + https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py + """ + + def __init__(self, parent=None) -> None: + """Called when the user opens the module the first time and the widget is initialized.""" + ScriptedLoadableModuleWidget.__init__(self, parent) + VTKObservationMixin.__init__(self) # needed for parameter node observation + self.logic = None + self.checked_cells = set() + self.minus_checked_rows = set() + self._parameterNode = None + self._parameterNodeGuiTag = None + + + + def setup(self) -> None: + """Called when the user opens the module the first time and the widget is initialized.""" + ScriptedLoadableModuleWidget.setup(self) + + # Load widget from .ui file (created by Qt Designer). + # Additional widgets can be instantiated manually and added to self.layout. + uiWidget = slicer.util.loadUI(self.resourcePath("UI/MRI2CBCT.ui")) + self.layout.addWidget(uiWidget) + self.ui = slicer.util.childWidgetVariables(uiWidget) + + # Set scene in MRML widgets. Make sure that in Qt designer the top-level qMRMLWidget's + # "mrmlSceneChanged(vtkMRMLScene*)" signal in is connected to each MRML widget's. + # "setMRMLScene(vtkMRMLScene*)" slot. + uiWidget.setMRMLScene(slicer.mrmlScene) + + # Create logic class. Logic implements all computations that should be possible to run + # in batch mode, without a graphical user interface. + self.logic = MRI2CBCTLogic() + + documentsLocation = qt.QStandardPaths.DocumentsLocation + self.documents = qt.QStandardPaths.writableLocation(documentsLocation) + self.SlicerDownloadPath = os.path.join( + self.documents, + slicer.app.applicationName + "Downloads", + "MRI2CBCT", + "MRI2CBCT_" + "CBCT", + ) + self.preprocess_cbct = Process_CBCT(self) + self.preprocess_mri = Process_MRI(self) + self.preprocess_mri_cbct = Preprocess_CBCT_MRI(self) + self.registration_mri2cbct = Registration_MRI2CBCT(self) + + # Connections + # LineEditOutputReg + # These connections ensure that we update parameter node when scene is closed + self.addObserver(slicer.mrmlScene, slicer.mrmlScene.StartCloseEvent, self.onSceneStartClose) + self.addObserver(slicer.mrmlScene, slicer.mrmlScene.EndCloseEvent, self.onSceneEndClose) + + # Buttons + self.ui.registrationButton.connect("clicked(bool)", self.registration_MR2CBCT) + self.ui.SearchButtonCBCT.connect("clicked(bool)",partial(self.openFinder,"InputCBCT")) + self.ui.SearchButtonMRI.connect("clicked(bool)",partial(self.openFinder,"InputMRI")) + self.ui.SearchButtonRegMRI.connect("clicked(bool)",partial(self.openFinder,"InputRegMRI")) + self.ui.SearchButtonRegCBCT.connect("clicked(bool)",partial(self.openFinder,"InputRegCBCT")) + self.ui.SearchButtonRegLabel.connect("clicked(bool)",partial(self.openFinder,"InputRegLabel")) + self.ui.SearchOutputFolderOrientCBCT.connect("clicked(bool)",partial(self.openFinder,"OutputOrientCBCT")) + self.ui.SearchOutputFolderOrientMRI.connect("clicked(bool)",partial(self.openFinder,"OutputOrientMRI")) + self.ui.SearchOutputFolderResample.connect("clicked(bool)",partial(self.openFinder,"OutputOrientResample")) + self.ui.SearchButtonOutput.connect("clicked(bool)",partial(self.openFinder,"OutputReg")) + self.ui.pushButtonOrientCBCT.connect("clicked(bool)",self.orientCBCT) + self.ui.pushButtonResample.connect("clicked(bool)",self.resampleMRICBCT) + self.ui.pushButtonOrientMRI.connect("clicked(bool)",self.orientCenterMRI) + self.ui.pushButtonDownloadOrientCBCT.connect("clicked(bool)",partial(self.downloadModel,self.ui.lineEditOrientCBCT, "Orientation", True)) + self.ui.pushButtonDownloadSegCBCT.connect("clicked(bool)",partial(self.downloadModel,self.ui.lineEditSegCBCT, "Segmentation", True)) + + + # Make sure parameter node is initialized (needed for module reload) + self.initializeParameterNode() + self.ui.ComboBoxCBCT.setCurrentIndex(1) + self.ui.ComboBoxCBCT.setEnabled(False) + self.ui.ComboBoxMRI.setCurrentIndex(1) + self.ui.ComboBoxMRI.setEnabled(False) + self.ui.comboBoxRegMRI.setCurrentIndex(1) + self.ui.comboBoxRegMRI.setEnabled(False) + self.ui.comboBoxRegCBCT.setCurrentIndex(1) + self.ui.comboBoxRegCBCT.setEnabled(False) + self.ui.comboBoxRegLabel.setCurrentIndex(1) + self.ui.comboBoxRegLabel.setEnabled(False) + + self.ui.label_time.setHidden(True) + self.ui.label_info.setHidden(True) + self.ui.progressBar.setHidden(True) + + self.ui.outputCollapsibleButton.setText("Registration") + self.ui.inputsCollapsibleButton.setText("Preprocess") + + self.ui.outputCollapsibleButton.setChecked(True) # True to expand, False to collapse + self.ui.inputsCollapsibleButton.setChecked(False) + ################################################################################################## + ### Orientation Table + self.tableWidgetOrient = self.ui.tableWidgetOrient + self.tableWidgetOrient.setRowCount(3) # Rows for New Direction X, Y, Z + self.tableWidgetOrient.setColumnCount(4) # Columns for X, Y, Z, and Minus + + # Set the headers + self.tableWidgetOrient.setHorizontalHeaderLabels(["X", "Y", "Z", "Negative"]) + self.tableWidgetOrient.setVerticalHeaderLabels(["New Direction X", "New Direction Y", "New Direction Z"]) + + # Set the horizontal header to stretch and fill the available space + header = self.tableWidgetOrient.horizontalHeader() + header.setSectionResizeMode(QHeaderView.Stretch) + + # Set a fixed height for the table to avoid stretching + self.tableWidgetOrient.setFixedHeight(self.tableWidgetOrient.horizontalHeader().height + + self.tableWidgetOrient.verticalHeader().sectionSize(0) * self.tableWidgetOrient.rowCount) + + # Add widgets for each cell + for row in range(3): + for col in range(4): # Columns X, Y, Z, and Minus + if col!=3 : + checkBox = QCheckBox('0') + checkBox.stateChanged.connect(lambda state, r=row, c=col: self.onCheckboxOrientClicked(r, c, state)) + self.tableWidgetOrient.setCellWidget(row, col, checkBox) + else : + checkBox = QCheckBox('No') + checkBox.stateChanged.connect(lambda state, r=row, c=col: self.onCheckboxOrientClicked(r, c, state)) + self.tableWidgetOrient.setCellWidget(row, col, checkBox) + + self.ui.ButtonDefaultOrientMRI.connect("clicked(bool)",self.defaultOrientMRI) + self.defaultOrientMRI() + + ################################################################################################## + ### Normalization Table + self.tableWidgetNorm = self.ui.tableWidgetNorm + + self.tableWidgetNorm.setRowCount(2) # MRI and CBCT rows + header row + self.tableWidgetNorm.setColumnCount(4) # Min, Max for Normalization and Percentile + + # Set the horizontal header to stretch and fill the available space + header = self.tableWidgetNorm.horizontalHeader() + header.setSectionResizeMode(QHeaderView.Stretch) + + # Set a fixed height for the table to avoid stretching + self.tableWidgetNorm.setFixedHeight(self.tableWidgetNorm.horizontalHeader().height + + self.tableWidgetNorm.verticalHeader().sectionSize(0) * self.tableWidgetNorm.rowCount) + + # Set the headers + self.tableWidgetNorm.setHorizontalHeaderLabels(["Normalization Min", "Normalization Max", "Percentile Min", "Percentile Max"]) + self.tableWidgetNorm.setVerticalHeaderLabels([ "MRI", "CBCT"]) + + + for row in range(2): + for col in range(4): + spinBox = QSpinBox() + if col in [2, 3]: # Columns for Percentile Min and Percentile Max + spinBox.setMaximum(100) + else: + spinBox.setMaximum(10000) + self.tableWidgetNorm.setCellWidget(row, col, spinBox) + + self.ui.ButtonCheckBoxDefaultNorm1.connect("clicked(bool)",partial(self.DefaultNorm,"1")) + self.ui.ButtonCheckBoxDefaultNorm2.connect("clicked(bool)",partial(self.DefaultNorm,"2")) + + self.DefaultNorm("1",_) + + ################################################################################################## + # RESAMPLE TABLE + self.tableWidgetResample = self.ui.tableWidgetResample + + # Increase the row and column count + self.tableWidgetResample.setRowCount(2) # Adding a second row + self.tableWidgetResample.setColumnCount(4) # Adding a new column + + # Set the horizontal header to stretch and fill the available space + header = self.tableWidgetResample.horizontalHeader() + header.setSectionResizeMode(QHeaderView.Stretch) + + # Set a fixed height for the table to avoid stretching + self.tableWidgetResample.setFixedHeight( + self.tableWidgetResample.horizontalHeader().height + + self.tableWidgetResample.verticalHeader().sectionSize(0) * self.tableWidgetResample.rowCount + ) + + # Set the headers + self.tableWidgetResample.setHorizontalHeaderLabels(["X", "Y", "Z", "Keep File "]) + self.tableWidgetResample.setVerticalHeaderLabels(["Number of slices", "Spacing"]) + + # Add QSpinBoxes for the first row + spinBox1 = QSpinBox() + spinBox1.setMaximum(10000) + spinBox1.setValue(119) + self.tableWidgetResample.setCellWidget(0, 0, spinBox1) + + spinBox2 = QSpinBox() + spinBox2.setMaximum(10000) + spinBox2.setValue(443) + self.tableWidgetResample.setCellWidget(0, 1, spinBox2) + + spinBox3 = QSpinBox() + spinBox3.setMaximum(10000) + spinBox3.setValue(443) + self.tableWidgetResample.setCellWidget(0, 2, spinBox3) + + # Add QSpinBoxes for the new row + spinBox4 = QDoubleSpinBox() + spinBox4.setMaximum(10000) + spinBox4.setSingleStep(0.1) + spinBox4.setValue(0.3) + self.tableWidgetResample.setCellWidget(1, 0, spinBox4) + + spinBox5 = QDoubleSpinBox() + spinBox5.setMaximum(10000) + spinBox5.setSingleStep(0.1) + spinBox5.setValue(0.3) + self.tableWidgetResample.setCellWidget(1, 1, spinBox5) + + spinBox6 = QDoubleSpinBox() + spinBox6.setMaximum(10000) + spinBox6.setSingleStep(0.1) + spinBox6.setValue(0.3) + self.tableWidgetResample.setCellWidget(1, 2, spinBox6) + # Add QCheckBox for the "Keep File" column + checkBox1 = QCheckBox("Keep the same size as the input scan") + checkBox1.stateChanged.connect(lambda state: self.toggleSpinBoxes(state, [spinBox1, spinBox2, spinBox3])) + self.tableWidgetResample.setCellWidget(0, 3, checkBox1) + + checkBox2 = QCheckBox("Keep the same spacing as the input scan") + checkBox2.stateChanged.connect(lambda state: self.toggleSpinBoxes(state, [spinBox4, spinBox5, spinBox6])) + self.tableWidgetResample.setCellWidget(1, 3, checkBox2) + + def toggleSpinBoxes(self, state, spinBoxes): + """ + Enable or disable a list of QSpinBox widgets based on the provided state. + + Parameters: + - state: An integer representing the state (2 for disabled, any other value for enabled). + - spinBoxes: A list of QSpinBox widgets to be toggled. + + The function iterates through each QSpinBox in the provided list. If the state is 2, + the QSpinBox is disabled and its text color is set to gray. Otherwise, the QSpinBox + is enabled and its default stylesheet is restored. + + This function is connected to the "keep file" checkbox. When the checkbox is checked + (state == 2), the spin boxes are disabled and shown in gray. If the checkbox is unchecked, + the spin boxes are enabled and restored to their default style. + """ + for spinBox in spinBoxes: + if state == 2: + spinBox.setEnabled(False) + spinBox.setStyleSheet("color: gray;") + else: + spinBox.setEnabled(True) + spinBox.setStyleSheet("") + + + def get_resample_values(self): + """ + Retrieves the resample values (X, Y, Z) from the QTableWidget. + + :return: A tuple of two lists representing the resample values for the two rows. + Each list contains three values (X, Y, Z) or None if the "Keep File" checkbox is checked. + First output : number of slices. + Second output : spacing + """ + resample_values_row1 = [] + resample_values_row2 = [] + + # Check the "Keep File" checkbox for the first row + if self.tableWidgetResample.cellWidget(0, 3).isChecked(): + resample_values_row1 = "None" + else: + resample_values_row1 = [ + self.tableWidgetResample.cellWidget(0, 0).value, + self.tableWidgetResample.cellWidget(0, 1).value, + self.tableWidgetResample.cellWidget(0, 2).value + ] + + # Check the "Keep File" checkbox for the second row + if self.tableWidgetResample.cellWidget(1, 3).isChecked(): + resample_values_row2 = "None" + else: + resample_values_row2 = [ + self.tableWidgetResample.cellWidget(1, 0).value, + self.tableWidgetResample.cellWidget(1, 1).value, + self.tableWidgetResample.cellWidget(1, 2).value + ] + + return resample_values_row1, resample_values_row2 + + + + def onCheckboxOrientClicked(self, row, col, state): + """ + Handle the click event of the orientation checkboxes in the table. + + Parameters: + - row: The row index of the clicked checkbox. + - col: The column index of the clicked checkbox. + - state: The state of the clicked checkbox (2 for checked, 0 for unchecked). + + This function updates the orientation checkboxes in the table based on the user's selection. + It ensures that only one checkbox per row can be set to '1' (or '-1' if the "Minus" column is checked) + and that the rest are set to '0'. Additionally, if the "Minus" column checkbox is checked, it sets + the text to 'Yes' and updates related checkboxes in the same row accordingly. The function also handles + unchecking a checkbox and updating the styles and texts of other checkboxes in the same row and column. + + This function is connected to the checkboxes for the orientation of the MRI. When a checkbox is clicked, + it ensures the correct orientation is set, following the specified rules. + """ + if col == 3: # If the "Minus" column checkbox is clicked + if state == 2: # Checkbox is checked + self.minus_checked_rows.add(row) + checkBox = self.tableWidgetOrient.cellWidget(row, col) + checkBox.setText('Yes') + for c in range(3): + checkBox = self.tableWidgetOrient.cellWidget(row, c) + if checkBox.text=="1": + checkBox.setText('-1') + else: # Checkbox is unchecked + self.minus_checked_rows.discard(row) + checkBox = self.tableWidgetOrient.cellWidget(row, col) + checkBox.setText('No') + for c in range(3): + checkBox = self.tableWidgetOrient.cellWidget(row, c) + if checkBox.text=="-1": + checkBox.setText('1') + else : + if state == 2: # Checkbox is checked + # Set the clicked checkbox to '1' and uncheck all others in the same row + for c in range(3): + checkBox = self.tableWidgetOrient.cellWidget(row, c) + if checkBox: + if c == col: + if row in self.minus_checked_rows: + checkBox.setText('-1') + else : + checkBox.setText('1') + checkBox.setStyleSheet("color: black;") + checkBox.setStyleSheet("font-weight: bold;") + self.checked_cells.add((row, col)) + else: + checkBox.setText('0') + checkBox.setChecked(False) + self.checked_cells.discard((row, c)) + + # Check for other '1' in the same column and set them to '0' + for r in range(3): + if r != row: + checkBox = self.tableWidgetOrient.cellWidget(r, col) + if checkBox and (checkBox.text == '1' or checkBox.text == '-1'): + checkBox.setText('0') + checkBox.setChecked(False) + checkBox.setStyleSheet("color: gray;") + checkBox.setStyleSheet("font-weight: normal;") + self.checked_cells.discard((r, col)) + + # Check if two checkboxes are checked in different rows, then check the third one + if len(self.checked_cells) == 2: + all_rows = {0, 1, 2} + all_cols = {0, 1, 2} + checked_rows = {r for r, c in self.checked_cells} + unchecked_row = list(all_rows - checked_rows)[0] + + # Find the unchecked column + unchecked_cols = list(all_cols - {c for r, c in self.checked_cells}) + # print("unchecked_cols : ",unchecked_cols) + for c in range(3): + checkBox = self.tableWidgetOrient.cellWidget(unchecked_row, c) + if c in unchecked_cols: + checkBox.setStyleSheet("color: black;") + checkBox.setStyleSheet("font-weight: bold;") + checkBox.setChecked(True) + if unchecked_row in self.minus_checked_rows: + checkBox.setText('-1') + else : + checkBox.setText('1') + self.checked_cells.add((unchecked_row, c)) + else : + checkBox.setText('0') + checkBox.setChecked(False) + self.checked_cells.discard((row, c)) + + else: # Checkbox is unchecked + checkBox = self.tableWidgetOrient.cellWidget(row, col) + if checkBox: + checkBox.setText('0') + checkBox.setStyleSheet("color: black;") + checkBox.setStyleSheet("font-weight: normal;") + self.checked_cells.discard((row, col)) + + # Reset the style of all checkboxes in the same row + for c in range(3): + checkBox = self.tableWidgetOrient.cellWidget(row, c) + if checkBox: + checkBox.setStyleSheet("color: black;") + checkBox.setStyleSheet("font-weight: normal;") + + def getCheckboxValuesOrient(self): + """ + Retrieve the values of the orientation checkboxes in the table. + + This function iterates through each checkbox in a 3x3 grid within the tableWidgetOrient. + It collects the integer value (text) of each checkbox and stores them in a list, which is + then converted to a tuple and returned. + + Returns: + - A tuple containing the integer values of the checkboxes, representing the orientation of the MRI. + """ + values = [] + for row in range(3): + for col in range(3): + checkBox = self.tableWidgetOrient.cellWidget(row, col) + if checkBox: + values.append(int(checkBox.text)) + return tuple(values) + + def defaultOrientMRI(self): + """ + Set the default orientation values for the MRI checkboxes in the table. + + This function initializes the orientation of the MRI by setting specific checkboxes + to predefined values. It iterates through a list of initial states, where each state + is a tuple containing the row, column, and value to set. The value can be 1, -1, or 0. + The corresponding checkbox is checked and its text is set accordingly. Additionally, + the checkbox style is updated to make the checked state bold, and the respective sets + (checked_cells and minus_checked_rows) are updated. + + The initial states are: + - Row 0, Column 2: Set to -1 + - Row 1, Column 0: Set to 1 + - Row 2, Column 1: Set to -1 + """ + initial_states = [ + (0, 2, -1), + (1, 0, 1), + (2, 1, -1) + ] + for row, col, value in initial_states: + checkBox = self.tableWidgetOrient.cellWidget(row, col) + if checkBox: + if value == 1: + checkBox.setChecked(True) + checkBox.setText('1') + checkBox.setStyleSheet("font-weight: bold;") + self.checked_cells.add((row, col)) + elif value == -1: + checkBox.setChecked(True) + checkBox.setText('-1') + checkBox.setStyleSheet("font-weight: bold;") + minus_checkBox = self.tableWidgetOrient.cellWidget(row, 3) + if minus_checkBox: + minus_checkBox.setChecked(True) + minus_checkBox.setText("Yes") + self.minus_checked_rows.add(row) + + def cleanup(self) -> None: + """Called when the application closes and the module widget is destroyed.""" + self.removeObservers() + + def enter(self) -> None: + """Called each time the user opens this module.""" + # Make sure parameter node exists and observed + self.initializeParameterNode() + + def exit(self) -> None: + """Called each time the user opens a different module.""" + # Do not react to parameter node changes (GUI will be updated when the user enters into the module) + if self._parameterNode: + self._parameterNode.disconnectGui(self._parameterNodeGuiTag) + self._parameterNodeGuiTag = None + self.removeObserver(self._parameterNode, vtk.vtkCommand.ModifiedEvent, self._checkCanApply) + + def onSceneStartClose(self, caller, event) -> None: + """Called just before the scene is closed.""" + pass + + def onSceneEndClose(self, caller, event) -> None: + """Called just after the scene is closed.""" + if self.parent.isEntered: + self.initializeParameterNode() + + def initializeParameterNode(self) -> None: + """Ensure parameter node exists and observed.""" + pass + + + def _checkCanApply(self, caller=None, event=None) -> None: + pass + + def getNormalization(self): + """ + Retrieve the normalization values from the table. + + This function iterates through each cell in the tableWidgetNorm, collecting the values + of QSpinBox widgets. It stores these values in a nested list, where each sublist represents + a row of values. The collected values are then returned as a list of lists. + + Returns: + - A list of lists containing the values of the QSpinBox widgets in the tableWidgetNorm. + """ + values = [] + for row in range(self.tableWidgetNorm.rowCount): + rowData = [] + for col in range(self.tableWidgetNorm.columnCount): + widget = self.tableWidgetNorm.cellWidget(row, col) + if isinstance(widget, QSpinBox): + rowData.append(widget.value) + values.append(rowData) + return(values) + + def DefaultNorm(self,num : str,_)->None: + """ + Set default normalization values in the tableWidgetNorm based on the identifier 'num'. + + If 'num' is "1", set specific default values; otherwise, use another set of values. + + Parameters: + - num: Identifier to select the set of default values. + - _: Unused parameter. + """ + # Define the default values for each cell + if num=="1": + default_values = [ + [0, 100, 0, 100], + [0, 75, 10, 95] + ] + else : + default_values = [ + [0, 100, 10, 95], + [0, 100, 10, 95] + ] + + for row in range(self.tableWidgetNorm.rowCount): + for col in range(self.tableWidgetNorm.columnCount): + spinBox = QSpinBox() + spinBox.setMaximum(10000) + spinBox.setValue(default_values[row][col]) + self.tableWidgetNorm.setCellWidget(row, col, spinBox) + + def openFinder(self,nom : str,_) -> None : + """ + Open finder to let the user choose is files or folder + """ + if nom=="InputMRI": + if self.ui.ComboBoxMRI.currentIndex==1: + print("oui") + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + else : + surface_folder = QFileDialog.getOpenFileName(self.parent,'Open a file',) + + self.ui.LineEditMRI.setText(surface_folder) + + elif nom=="InputCBCT": + if self.ui.ComboBoxCBCT.currentIndex==1: + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + else : + surface_folder = QFileDialog.getOpenFileName(self.parent,'Open a file',) + self.ui.LineEditCBCT.setText(surface_folder) + + elif nom=="InputRegCBCT": + if self.ui.comboBoxRegCBCT.currentIndex==1: + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + else : + surface_folder = QFileDialog.getOpenFileName(self.parent,'Open a file',) + self.ui.lineEditRegCBCT.setText(surface_folder) + + elif nom=="InputRegMRI": + if self.ui.comboBoxRegMRI.currentIndex==1: + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + else : + surface_folder = QFileDialog.getOpenFileName(self.parent,'Open a file',) + self.ui.lineEditRegMRI.setText(surface_folder) + + elif nom=="InputRegLabel": + if self.ui.comboBoxRegLabel.currentIndex==1: + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + else : + surface_folder = QFileDialog.getOpenFileName(self.parent,'Open a file',) + self.ui.lineEditRegLabel.setText(surface_folder) + + + elif nom=="OutputOrientCBCT": + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + self.ui.lineEditOutputOrientCBCT.setText(surface_folder) + + elif nom=="OutputOrientMRI": + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + self.ui.lineEditOutputOrientMRI.setText(surface_folder) + + elif nom=="OutputOrientResample": + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + self.ui.lineEditOuputResample.setText(surface_folder) + + elif nom=="OutputReg": + surface_folder = QFileDialog.getExistingDirectory(self.parent, "Select a scan folder") + self.ui.LineEditOutput.setText(surface_folder) + + + + def downloadModel(self, lineEdit, name, test,_): + """ + Download model files from the URL(s) provided by the getModelUrl function. + + Parameters: + - lineEdit: The QLineEdit widget to update with the model folder path. + - name: The name of the model to download. + - test: A flag for testing purposes (unused in this function). + - _: Unused parameter for compatibility. + + This function fetches the model URL(s) using getModelUrl, downloads the files, + unzips them to the appropriate directory, and updates the lineEdit with the model + folder path. It also runs a test on the downloaded model and shows a warning message + if any errors occur. + """ + + # To select the reference files (CBCT Orientation and Registration mode only) + listmodel = self.preprocess_cbct.getModelUrl() + print("listmodel : ",listmodel) + + urls = listmodel[name] + if isinstance(urls, str): + url = urls + _ = self.DownloadUnzip( + url=url, + directory=os.path.join(self.SlicerDownloadPath), + folder_name=os.path.join("Models", name), + num_downl=1, + total_downloads=1, + ) + model_folder = os.path.join(self.SlicerDownloadPath, "Models", name) + + elif isinstance(urls, dict): + for i, (name_bis, url) in enumerate(urls.items()): + _ = self.DownloadUnzip( + url=url, + directory=os.path.join(self.SlicerDownloadPath), + folder_name=os.path.join("Models", name, name_bis), + num_downl=i + 1, + total_downloads=len(urls), + ) + model_folder = os.path.join(self.SlicerDownloadPath, "Models", name) + + if not model_folder == "": + error = self.preprocess_cbct.TestModel(model_folder, lineEdit.name) + + if isinstance(error, str): + QMessageBox.warning(self.parent, "Warning", error) + + else: + lineEdit.setText(model_folder) + + def DownloadUnzip( + self, url, directory, folder_name=None, num_downl=1, total_downloads=1 + ): + """ + Download and unzip a file from a given URL to a specified directory. + + Parameters: + - url: The URL of the zip file to download. + - directory: The directory where the file should be downloaded and unzipped. + - folder_name: The name of the folder to create and unzip the contents into. + - num_downl: The current download number (for progress display). + - total_downloads: The total number of downloads (for progress display). + + Returns: + - out_path: The path to the unzipped folder. + """ + + out_path = os.path.join(directory, folder_name) + + if not os.path.exists(out_path): + # print("Downloading {}...".format(folder_name.split(os.sep)[0])) + os.makedirs(out_path) + + temp_path = os.path.join(directory, "temp.zip") + + # Download the zip file from the url + with urllib.request.urlopen(url) as response, open( + temp_path, "wb" + ) as out_file: + # Pop up a progress bar with a QProgressDialog + progress = QProgressDialog( + "Downloading {} (File {}/{})".format( + folder_name.split(os.sep)[0], num_downl, total_downloads + ), + "Cancel", + 0, + 100, + self.parent, + ) + progress.setCancelButton(None) + progress.setWindowModality(qt.Qt.WindowModal) + progress.setWindowTitle( + "Downloading {}...".format(folder_name.split(os.sep)[0]) + ) + # progress.setWindowFlags(qt.Qt.WindowStaysOnTopHint) + progress.show() + length = response.info().get("Content-Length") + if length: + length = int(length) + blocksize = max(4096, length // 100) + read = 0 + while True: + buffer = response.read(blocksize) + if not buffer: + break + read += len(buffer) + out_file.write(buffer) + progress.setValue(read * 100.0 / length) + QApplication.processEvents() + shutil.copyfileobj(response, out_file) + + # Unzip the file + with zipfile.ZipFile(temp_path, "r") as zip: + zip.extractall(out_path) + + # Delete the zip file + os.remove(temp_path) + + return out_path + + def orientCBCT(self)->None: + """ + This function is called when the button "pushButtonOrientCBCT" is click. + Orient CBCT images using specified parameters and initiate the processing pipeline. + + This function sets up the parameters for CBCT image orientation, tests the process and scan, + and starts the processing pipeline if all checks pass. It handles the initial setup, + parameter passing, and process initiation, including setting up observers for process updates. + """ + + param = {"input_t1_folder":self.ui.LineEditCBCT.text, + "folder_output":self.ui.lineEditOutputOrientCBCT.text, + "model_folder_1":self.ui.lineEditSegCBCT.text, + "merge_seg":False, + "isDCMInput":False, + "slicerDownload":self.SlicerDownloadPath} + + ok,mess = self.preprocess_cbct.TestProcess(**param) + if not ok : + self.showMessage(mess) + return + ok,mess = self.preprocess_cbct.TestScan(param["input_t1_folder"]) + if not ok : + self.showMessage(mess) + return + + self.list_Processes_Parameters = self.preprocess_cbct.Process(**param) + + self.onProcessStarted() + + # /!\ Launch of the first process /!\ + print("module name : ",self.list_Processes_Parameters[0]["Module"]) + print("Parameters : ",self.list_Processes_Parameters[0]["Parameter"]) + + self.process = slicer.cli.run( + self.list_Processes_Parameters[0]["Process"], + None, + self.list_Processes_Parameters[0]["Parameter"], + ) + + self.module_name = self.list_Processes_Parameters[0]["Module"] + self.processObserver = self.process.AddObserver( + "ModifiedEvent", self.onProcessUpdate + ) + + del self.list_Processes_Parameters[0] + + def orientCenterMRI(self): + """ + This function is called when the button "pushButtonOrientMRI" is click. + Orient and center MRI images using specified parameters and initiate the processing pipeline. + + This function sets up the parameters for MRI image orientation and centering, tests the process and scan, + and starts the processing pipeline if all checks pass. It handles the initial setup, parameter passing, + and process initiation, including setting up observers for process updates. + """ + + param = {"input_folder":self.ui.LineEditMRI.text, + "direction":self.getCheckboxValuesOrient(), + "output_folder":self.ui.lineEditOutputOrientMRI.text} + + ok,mess = self.preprocess_mri.TestProcess(**param) + if not ok : + self.showMessage(mess) + return + ok,mess = self.preprocess_mri.TestScan(param["input_folder"]) + if not ok : + self.showMessage(mess) + return + + self.list_Processes_Parameters = self.preprocess_mri.Process(**param) + + self.onProcessStarted() + + # /!\ Launch of the first process /!\ + print("module name : ",self.list_Processes_Parameters[0]["Module"]) + print("Parameters : ",self.list_Processes_Parameters[0]["Parameter"]) + + self.process = slicer.cli.run( + self.list_Processes_Parameters[0]["Process"], + None, + self.list_Processes_Parameters[0]["Parameter"], + ) + + self.module_name = self.list_Processes_Parameters[0]["Module"] + self.processObserver = self.process.AddObserver( + "ModifiedEvent", self.onProcessUpdate + ) + + del self.list_Processes_Parameters[0] + + def resampleMRICBCT(self): + """ + Resample MRI and/or CBCT images based on the selected options and initiate the processing pipeline. + + This function determines which input folders (MRI, CBCT, or both) to use based on the user's selection + in the comboBoxResample widget. It sets up the resampling parameters, tests the process and scans, + and starts the processing pipeline if all checks pass. The function handles the initial setup, parameter + passing, and process initiation, including setting up observers for process updates. + """ + + if self.ui.comboBoxResample.currentText=="CBCT": + LineEditMRI = "None" + LineEditCBCT = self.ui.LineEditCBCT.text + elif self.ui.comboBoxResample.currentText=="MRI": + LineEditMRI = self.ui.LineEditMRI.text + LineEditCBCT = "None" + else : + LineEditMRI = self.ui.LineEditMRI.text + LineEditCBCT = self.ui.LineEditCBCT.text + + param = {"input_folder_MRI": LineEditMRI, + "input_folder_CBCT": LineEditCBCT, + "output_folder": self.ui.lineEditOuputResample.text, + "resample_size": self.get_resample_values()[0], + "spacing" : self.get_resample_values()[1] + } + + ok,mess = self.preprocess_mri_cbct.TestProcess(**param) + if not ok : + self.showMessage(mess) + return + + ok,mess = self.preprocess_mri_cbct.TestScan(param["input_folder_MRI"]) + + if not ok : + mess = mess + "MRI folder" + self.showMessage(mess) + return + + ok,mess = self.preprocess_mri_cbct.TestScan(param["input_folder_CBCT"]) + if not ok : + mess = mess + "CBCT folder" + self.showMessage(mess) + return + + + self.list_Processes_Parameters = self.preprocess_mri_cbct.Process(**param) + + self.onProcessStarted() + + # /!\ Launch of the first process /!\ + print("module name : ",self.list_Processes_Parameters[0]["Module"]) + print("Parameters : ",self.list_Processes_Parameters[0]["Parameter"]) + + self.process = slicer.cli.run( + self.list_Processes_Parameters[0]["Process"], + None, + self.list_Processes_Parameters[0]["Parameter"], + ) + + self.module_name = self.list_Processes_Parameters[0]["Module"] + self.processObserver = self.process.AddObserver( + "ModifiedEvent", self.onProcessUpdate + ) + + del self.list_Processes_Parameters[0] + + + def registration_MR2CBCT(self) -> None: + """ + Register MRI images to CBCT images using specified parameters and initiate the processing pipeline. + + This function sets up the parameters for MRI to CBCT registration, tests the process and scans, + and starts the processing pipeline if all checks pass. It handles the initial setup, parameter passing, + and process initiation, including setting up observers for process updates. The function also checks + for normalization parameters and validates input folders for the presence of necessary files. + """ + + param = {"folder_general": self.ui.LineEditOutput.text, + "mri_folder": self.ui.lineEditRegMRI.text, + "cbct_folder": self.ui.lineEditRegCBCT.text, + "cbct_label2": self.ui.lineEditRegLabel.text, + "normalization" : [self.getNormalization()], + "tempo_fold" : self.ui.checkBoxTompraryFold.isChecked()} + + ok,mess = self.registration_mri2cbct.TestProcess(**param) + if not ok : + self.showMessage(mess) + return + + ok1,mess = self.registration_mri2cbct.TestScan(param["mri_folder"]) + ok2,mess2 = self.registration_mri2cbct.TestScan(param["cbct_folder"]) + ok3,mess3 = self.registration_mri2cbct.TestScan(param["cbct_label2"]) + + error_messages = [] + + if not ok1: + error_messages.append("MRI folder") + if not ok2: + error_messages.append("CBCT folder") + if not ok3: + error_messages.append("CBCT label2 folder") + + if error_messages: + error_message = "No files to run has been found in the following folders: " + ", ".join(error_messages) + self.showMessage(error_message) + return + + ok,mess = self.registration_mri2cbct.CheckNormalization(param["normalization"]) + if not ok : + self.showMessage(mess) + return + + self.list_Processes_Parameters = self.registration_mri2cbct.Process(**param) + + self.onProcessStarted() + + # /!\ Launch of the first process /!\ + print("module name : ",self.list_Processes_Parameters[0]["Module"]) + print("Parameters : ",self.list_Processes_Parameters[0]["Parameter"]) + + self.process = slicer.cli.run( + self.list_Processes_Parameters[0]["Process"], + None, + self.list_Processes_Parameters[0]["Parameter"], + ) + + self.module_name = self.list_Processes_Parameters[0]["Module"] + self.processObserver = self.process.AddObserver( + "ModifiedEvent", self.onProcessUpdate + ) + + del self.list_Processes_Parameters[0] + + + + def onProcessStarted(self): + """ + Initialize and update the UI components when a process starts. + + This function sets the start time, initializes the progress bar and related UI elements, + and updates the process-related attributes such as the number of extensions and modules. + It also enables the running state UI to reflect that a process is in progress. + """ + self.startTime = time.time() + + # self.ui.progressBar.setMaximum(self.nb_patient) + self.ui.progressBar.setValue(0) + self.ui.progressBar.setTextVisible(True) + self.ui.progressBar.setFormat("0%") + + self.ui.label_info.setText(f"Starting process") + + self.nb_extnesion_did = 0 + self.nb_extension_launch = len(self.list_Processes_Parameters) + + self.module_name_before = 0 + self.nb_change_bystep = 0 + + self.RunningUI(True) + + def onProcessUpdate(self, caller, event): + """ + Update the UI components during the process execution and handle process completion. + + This function updates the progress bar, time label, and information label during the process execution. + It handles the completion of each process step, manages errors, and initiates the next process if available. + + Parameters: + - caller: The process that triggered the update. + - event: The event that triggered the update. + """ + + self.ui.progressBar.setVisible(False) + # timer = f"Time : {time.time()-self.startTime:.2f}s" + currentTime = time.time() - self.startTime + if currentTime < 60: + timer = f"Time : {int(currentTime)}s" + elif currentTime < 3600: + timer = f"Time : {int(currentTime/60)}min and {int(currentTime%60)}s" + else: + timer = f"Time : {int(currentTime/3600)}h, {int(currentTime%3600/60)}min and {int(currentTime%60)}s" + + self.ui.label_time.setText(timer) + # self.module_name = caller.GetModuleTitle() if self.module_name_bis is None else self.module_name_bis + self.ui.label_info.setText(f"Extension {self.module_name} is running. \nNumber of extension runned : {self.nb_extnesion_did} / {self.nb_extension_launch}") + # self.displayModule = self.displayModule_bis if self.displayModule_bis is not None else self.display[self.module_name.split(' ')[0]] + + if self.module_name_before != self.module_name: + print("Valeur progress barre : ",100*self.nb_extnesion_did/self.nb_extension_launch) + self.ui.progressBar.setValue(self.nb_extnesion_did/self.nb_extension_launch) + self.ui.progressBar.setFormat(f"{100*self.nb_extnesion_did/self.nb_extension_launch}%") + self.nb_extnesion_did += 1 + self.ui.label_info.setText( + f"Extension {self.module_name} is running. \nNumber of extension runned : {self.nb_extnesion_did} / {self.nb_extension_launch}" + ) + + + self.module_name_before = self.module_name + self.nb_change_bystep = 0 + + + if caller.GetStatus() & caller.Completed: + if caller.GetStatus() & caller.ErrorsMask: + # error + print("\n\n ========= PROCESSED ========= \n") + + print(self.process.GetOutputText()) + print("\n\n ========= ERROR ========= \n") + errorText = self.process.GetErrorText() + print("CLI execution failed: \n \n" + errorText) + # error + # errorText = caller.GetErrorText() + # print("\n"+ 70*"=" + "\n\n" + errorText) + # print(70*"=") + self.onCancel() + + else: + print("\n\n ========= PROCESSED ========= \n") + # print("PROGRESS :",self.displayModule.progress) + + print(self.process.GetOutputText()) + try: + print("name process : ",self.list_Processes_Parameters[0]["Process"]) + self.process = slicer.cli.run( + self.list_Processes_Parameters[0]["Process"], + None, + self.list_Processes_Parameters[0]["Parameter"], + ) + self.module_name = self.list_Processes_Parameters[0]["Module"] + self.processObserver = self.process.AddObserver( + "ModifiedEvent", self.onProcessUpdate + ) + del self.list_Processes_Parameters[0] + # self.displayModule.progress = 0 + except IndexError: + self.OnEndProcess() + + def OnEndProcess(self): + """ + Finalize the process execution and update the UI components accordingly. + + This function increments the number of completed extensions, updates the information label, + resets the progress bar, calculates the total time taken, and displays a message box indicating + the completion of the process. It also disables the running state UI. + """ + + self.nb_extnesion_did += 1 + self.ui.label_info.setText( + f"Process end" + ) + self.ui.progressBar.setValue(0) + + self.module_name_before = self.module_name + self.nb_change_bystep = 0 + total_time = time.time() - self.startTime + + + print("PROCESS DONE.") + print( + "Done in {} min and {} sec".format( + int(total_time / 60), int(total_time % 60) + ) + ) + + self.RunningUI(False) + + stopTime = time.time() + + msg = QMessageBox() + msg.setIcon(QMessageBox.Information) + + # setting message for Message Box + msg.setText(f"Processing completed in {int(total_time / 60)} min and {int(total_time % 60)} sec") + + # setting Message box window title + msg.setWindowTitle("Information") + + # declaring buttons on Message Box + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + + + def onCancel(self): + self.process.Cancel() + + self.RunningUI(False) + + def RunningUI(self, run=False): + + self.ui.progressBar.setVisible(run) + self.ui.label_time.setVisible(run) + self.ui.label_info.setVisible(run) + + def showMessage(self,mess): + msg = QMessageBox() + msg.setIcon(QMessageBox.Information) + + # setting message for Message Box + msg.setText(mess) + + # setting Message box window title + msg.setWindowTitle("Information") + + # declaring buttons on Message Box + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + + + + + + +# +# MRI2CBCTLogic +# + + +class MRI2CBCTLogic(ScriptedLoadableModuleLogic): + """This class should implement all the actual + computation done by your module. The interface + should be such that other python code can import + this class and make use of the functionality without + requiring an instance of the Widget. + Uses ScriptedLoadableModuleLogic base class, available at: + https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py + """ + + def __init__(self) -> None: + """Called when the logic class is instantiated. Can be used for initializing member variables.""" + ScriptedLoadableModuleLogic.__init__(self) + + def getParameterNode(self): + return MRI2CBCTParameterNode(super().getParameterNode()) + + def process(self, + inputVolume: vtkMRMLScalarVolumeNode, + outputVolume: vtkMRMLScalarVolumeNode, + imageThreshold: float, + invert: bool = False, + showResult: bool = True) -> None: + """ + Run the processing algorithm. + Can be used without GUI widget. + :param inputVolume: volume to be thresholded + :param outputVolume: thresholding result + :param imageThreshold: values above/below this threshold will be set to 0 + :param invert: if True then values above the threshold will be set to 0, otherwise values below are set to 0 + :param showResult: show output volume in slice viewers + """ + + if not inputVolume or not outputVolume: + raise ValueError("Input or output volume is invalid") + + import time + + startTime = time.time() + logging.info("Processing started") + + # Compute the thresholded output volume using the "Threshold Scalar Volume" CLI module + cliParams = { + "InputVolume": inputVolume.GetID(), + "OutputVolume": outputVolume.GetID(), + "ThresholdValue": imageThreshold, + "ThresholdType": "Above" if invert else "Below", + } + cliNode = slicer.cli.run(slicer.modules.thresholdscalarvolume, None, cliParams, wait_for_completion=True, update_display=showResult) + # We don't need the CLI module node anymore, remove it to not clutter the scene with it + slicer.mrmlScene.RemoveNode(cliNode) + + stopTime = time.time() + logging.info(f"Processing completed in {stopTime-startTime:.2f} seconds") + + +# +# MRI2CBCTTest +# + + +class MRI2CBCTTest(ScriptedLoadableModuleTest): + """ + This is the test case for your scripted module. + Uses ScriptedLoadableModuleTest base class, available at: + https://github.com/Slicer/Slicer/blob/main/Base/Python/slicer/ScriptedLoadableModule.py + """ + + def setUp(self): + """Do whatever is needed to reset the state - typically a scene clear will be enough.""" + slicer.mrmlScene.Clear() + + def runTest(self): + """Run as few or as many tests as needed here.""" + self.setUp() + self.test_MRI2CBCT1() + + def test_MRI2CBCT1(self): + """Ideally you should have several levels of tests. At the lowest level + tests should exercise the functionality of the logic with different inputs + (both valid and invalid). At higher levels your tests should emulate the + way the user would interact with your code and confirm that it still works + the way you intended. + One of the most important features of the tests is that it should alert other + developers when their changes will have an impact on the behavior of your + module. For example, if a developer removes a feature that you depend on, + your test should break so they know that the feature is needed. + """ + + self.delayDisplay("Starting the test") + + # Get/create input data + + import SampleData + + registerSampleData() + inputVolume = SampleData.downloadSample("MRI2CBCT1") + self.delayDisplay("Loaded test data set") + + inputScalarRange = inputVolume.GetImageData().GetScalarRange() + self.assertEqual(inputScalarRange[0], 0) + self.assertEqual(inputScalarRange[1], 695) + + outputVolume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLScalarVolumeNode") + threshold = 100 + + # Test the module logic + + logic = MRI2CBCTLogic() + + # Test algorithm with non-inverted threshold + logic.process(inputVolume, outputVolume, threshold, True) + outputScalarRange = outputVolume.GetImageData().GetScalarRange() + self.assertEqual(outputScalarRange[0], inputScalarRange[0]) + self.assertEqual(outputScalarRange[1], threshold) + + # Test algorithm with inverted threshold + logic.process(inputVolume, outputVolume, threshold, False) + outputScalarRange = outputVolume.GetImageData().GetScalarRange() + self.assertEqual(outputScalarRange[0], inputScalarRange[0]) + self.assertEqual(outputScalarRange[1], inputScalarRange[1]) + + self.delayDisplay("Test passed") diff --git a/MRI2CBCT/Resources/Icons/MRI2CBCT.png b/MRI2CBCT/Resources/Icons/MRI2CBCT.png new file mode 100644 index 0000000..5d83ab4 Binary files /dev/null and b/MRI2CBCT/Resources/Icons/MRI2CBCT.png differ diff --git a/MRI2CBCT/Resources/UI/MRI2CBCT.ui b/MRI2CBCT/Resources/UI/MRI2CBCT.ui new file mode 100644 index 0000000..b7d7e11 --- /dev/null +++ b/MRI2CBCT/Resources/UI/MRI2CBCT.ui @@ -0,0 +1,581 @@ + + + Matrix_bis + + + + 0 + 0 + 735 + 1210 + + + + + + + Inputs + + + + + + + + + + Search + + + + + + + Search + + + + + + + + File + + + + + Folder + + + + + + + + Input MRI files(s): + + + + + + + + + + Input CBCT file(s) : + + + + + + + + File + + + + + Folder + + + + + + + + + + + + + + + __________________________________________________________________________________________________________________________________ + + + + + + + Orientation + Segmentation of the CBCT + + + + + + + + + Download + + + + + + + Output folder : + + + + + + + + + + + + + Search + + + + + + + + + + Orientation Model : + + + + + + + Segmentation Model : + + + + + + + Download + + + + + + + + + + + + + + + Orient and Segment CBCT + + + + + + + __________________________________________________________________________________________________________________________________ + + + + + + + Orientation + centering of MRI + + + + + + + + + + + + + + + + Default + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + + + + Output folder : + + + + + + + + + + Search + + + + + + + + + Orient and centering MRI + + + + + + + __________________________________________________________________________________________________________________________________ + + + + + + + Resample + + + + + + + + + + + + Output folder : + + + + + + + + + + Search + + + + + + + + + + + + MRI & CBCT + + + + + MRI + + + + + CBCT + + + + + + + + Run resample + + + + + + + + + + + + Output + + + + + + + + + + + + + + + + Output folder : + + + + + + + Suffix : + + + + + + + Search + + + + + + + _reg + + + + + + + Keep the temporary folder + + + + + + + + + + + + + Default 1 + + + + + + + Default 2 + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + + + + Input CBCT : + + + + + + + + + + + + + Input MRI : + + + + + + + Search + + + + + + + + + + Search + + + + + + + Input Seg CBCT : + + + + + + + Search + + + + + + + + File + + + + + Folder + + + + + + + + + File + + + + + Folder + + + + + + + + + File + + + + + Folder + + + + + + + + + + true + + + Run the algorithm. + + + Registration + + + + + + + + + + + + + + true + + + 0 + + + + + + + + + + Number of processed files : + + + + + + + time : + + + + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + + ctkCollapsibleButton + QWidget +
ctkCollapsibleButton.h
+ 1 +
+ + qMRMLWidget + QWidget +
qMRMLWidget.h
+ 1 +
+
+ + +
diff --git a/MRI2CBCT/Testing/CMakeLists.txt b/MRI2CBCT/Testing/CMakeLists.txt new file mode 100644 index 0000000..655007a --- /dev/null +++ b/MRI2CBCT/Testing/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Python) diff --git a/MRI2CBCT/Testing/Python/CMakeLists.txt b/MRI2CBCT/Testing/Python/CMakeLists.txt new file mode 100644 index 0000000..5658d8b --- /dev/null +++ b/MRI2CBCT/Testing/Python/CMakeLists.txt @@ -0,0 +1,2 @@ + +#slicer_add_python_unittest(SCRIPT ${MODULE_NAME}ModuleTest.py) diff --git a/MRI2CBCT/utils/Method.py b/MRI2CBCT/utils/Method.py new file mode 100644 index 0000000..d81a426 --- /dev/null +++ b/MRI2CBCT/utils/Method.py @@ -0,0 +1,122 @@ +from abc import ABC, abstractmethod +import os +import glob +import json + + +class Method(ABC): + def __init__(self, widget): + self.widget = widget + self.diccheckbox = {} + self.diccheckbox2 = {} + + @abstractmethod + def NumberScan(self, scan_folder_t1: str, scan_folder_t2: str): + """ + Count the number of patient in folder + Args: + scan_folder_t1 (str): folder path with Scan for T1 + scan_folder_t2 (str): folder path with Scan for T2 + + Return: + int : return the number of patient. + """ + pass + + @abstractmethod + def TestScan(self, scan_folder_t1: str, scan_folder_t2): + """Verify if the input folder seems good (have everything required to run the mode selected), if something is wrong the function return string with error message + + This function is called when the user want to import scan + + Args: + scan_folder (str): path of folder with scan + + Returns: + str and bool: Return str with error message if something is wrong and a boolean to indicate if there is a message + pass + """ + + + + @abstractmethod + def TestProcess(self, **kwargs) -> str: + """Check if everything is OK before launching the process, if something is wrong return string with all error + + + + Returns: + str or None: return None if there no problem with input of the process, else return str with all error + """ + pass + + @abstractmethod + def Process(self, **kwargs): + """Launch extension""" + + pass + + def search(self, path, *args): + """ + Return a dictionary with args element as key and a list of file in path directory finishing by args extension for each key + + Example: + args = ('json',['.nii.gz','.nrrd']) + return: + { + 'json' : ['path/a.json', 'path/b.json','path/c.json'], + '.nii.gz' : ['path/a.nii.gz', 'path/b.nii.gz'] + '.nrrd.gz' : ['path/c.nrrd'] + } + """ + arguments = [] + for arg in args: + if type(arg) == list: + arguments.extend(arg) + else: + arguments.append(arg) + return { + key: [ + i + for i in glob.iglob( + os.path.normpath("/".join([path, "**", "*"])), recursive=True + ) + if i.endswith(key) + ] + for key in arguments + } + + + + def getTestFileListDCM(self): + """Return a tuple with both the name and the Download link of the test files but only for DCM files (AREG CBCT) + tuple = ('name','link') + """ + pass + + def TestScanDCM(self, scan_folder_t1: str, scan_folder_t2) -> str: + """Verify if the input folder seems good (have everything required to run the mode selected), if something is wrong the function return string with error message for DCM as input + + This function is called when the user want to import scan + + Args: + scan_folder (str): path of folder with scan + + Returns: + str or None: Return str with error message if something is wrong, else return None + """ + pass + + def NumberScanDCM(self, scan_folder_t1: str, scan_folder_t2: str): + """ + Count the number of patient in folder for DCM as input + Args: + scan_folder_t1 (str): folder path with Scan for T1 + scan_folder_t2 (str): folder path with Scan for T2 + + Return: + int : return the number of patient. + """ + pass + + \ No newline at end of file diff --git a/MRI2CBCT/utils/Preprocess_CBCT.py b/MRI2CBCT/utils/Preprocess_CBCT.py new file mode 100644 index 0000000..a93b6f2 --- /dev/null +++ b/MRI2CBCT/utils/Preprocess_CBCT.py @@ -0,0 +1,143 @@ +from utils.Method import Method +from utils.utils_CBCT import GetDictPatients, GetPatients +import os, sys + +import SimpleITK as sitk +import numpy as np + +from glob import iglob +import slicer +import time +import qt +import platform + + +class Process_CBCT(Method): + def __init__(self, widget): + super().__init__(widget) + documentsLocation = qt.QStandardPaths.DocumentsLocation + documents = qt.QStandardPaths.writableLocation(documentsLocation) + self.tempAMASSS_folder = os.path.join( + documents, slicer.app.applicationName + "_temp_AMASSS" + ) + + def getGPUUsage(self): + if platform.system() == "Darwin": + return 1 + else: + return 5 + + def NumberScan(self, scan_folder_t1: str, scan_folder_t2: str): + return len(GetDictPatients(scan_folder_t1, scan_folder_t2)) + + + + def TestModel(self, model_folder: str,lineEdit:str) -> str: + if lineEdit == "lineEditSegCBCT": + if len(super().search(model_folder, "pth")["pth"]) == 0: + return "Folder must have models for mask segmentation" + else: + return None + + + def TestScan(self, scan_folder: str): + extensions = ['.nii', '.nii.gz', '.nrrd'] + found_files = self.search(scan_folder, extensions) + if any(found_files[ext] for ext in extensions): + return True, "" + else: + return False, "No files to run has been found in the input folder" + + + def TestProcess(self, **kwargs) -> str: + out = "" + ok = True + + if kwargs["input_t1_folder"] == "": + out += "Please select an input folder for CBCT scans\n" + ok = False + + if kwargs["folder_output"] == "": + out += "Please select an output folder\n" + ok = False + + if kwargs["model_folder_1"] == "": + out += "Please select a folder for segmentation models\n" + ok = False + + if out == "": + out = None + + return ok,out + + def getModelUrl(self): + return { + "Segmentation": { + "Full Face Models": "https://github.com/lucanchling/AMASSS_CBCT/releases/download/v1.0.2/AMASSS_Models.zip", + "Mask Models": "https://github.com/lucanchling/AMASSS_CBCT/releases/download/v1.0.2/Masks_Models.zip", + }, + "Orientation": { + "PreASO": "https://github.com/lucanchling/ASO_CBCT/releases/download/v01_preASOmodels/PreASOModels.zip", + "Occlusal and Midsagittal Plane": "https://github.com/lucanchling/ASO_CBCT/releases/download/v01_goldmodels/Occlusal_Midsagittal_Plane.zip", + "Frankfurt Horizontal and Midsagittal Plane": "https://github.com/lucanchling/ASO_CBCT/releases/download/v01_goldmodels/Frankfurt_Horizontal_Midsagittal_Plane.zip", + }, + } + + + def Process(self, **kwargs): + centered_T1 = kwargs["folder_output"] + "CBCT_Center" + centered_T1 = os.path.join(kwargs["folder_output"], "_CBCT_Center") + parameter_pre_aso = { + "input": kwargs["input_t1_folder"], + "output_folder": centered_T1, + "model_folder": os.path.join( + kwargs["slicerDownload"], "Models", "Orientation", "PreASO" + ), + "SmallFOV": False, + "temp_folder": "../", + "DCMInput": kwargs["isDCMInput"], + } + + PreOrientProcess = slicer.modules.pre_aso_cbct + list_process = [ + { + "Process": PreOrientProcess, + "Parameter": parameter_pre_aso, + "Module": "Centering CBCT", + # "Display": DisplayASOCBCT(nb_scan), + } + ] + + + + # AMASSS PROCESS - SEGMENTATION + AMASSSProcess = slicer.modules.amasss_cli + parameter_amasss_seg_t1 = { + "inputVolume": centered_T1, + "modelDirectory": kwargs["model_folder_1"], + "highDefinition": False, + "skullStructure": "CB", + "merge": "MERGE" if kwargs["merge_seg"] else "SEPARATE", + "genVtk": True, + "save_in_folder": True, + "output_folder": os.path.join(kwargs["folder_output"],'CBCT_Segmentation'), + "precision": 50, + "vtk_smooth": 5, + "prediction_ID": "Pred", + "gpu_usage": self.getGPUUsage(), + "cpu_usage": 1, + "temp_fold": self.tempAMASSS_folder, + "SegmentInput": False, + "DCMInput": False, + } + + list_process.append( + { + "Process": AMASSSProcess, + "Parameter": parameter_amasss_seg_t1, + "Module": "AMASSS_CBCT Segmentation of CBCT", + # "Display": DisplayAMASSS(nb_scan, len(full_seg_struct)), + } + ) + + return list_process diff --git a/MRI2CBCT/utils/Preprocess_CBCT_MRI.py b/MRI2CBCT/utils/Preprocess_CBCT_MRI.py new file mode 100644 index 0000000..6285c35 --- /dev/null +++ b/MRI2CBCT/utils/Preprocess_CBCT_MRI.py @@ -0,0 +1,94 @@ +from utils.Method import Method +from utils.utils_CBCT import GetDictPatients, GetPatients +import os, sys + +import SimpleITK as sitk +import numpy as np + +from glob import iglob +import slicer +import time +import qt +import platform + + +class Preprocess_CBCT_MRI(Method): + def __init__(self, widget): + super().__init__(widget) + documentsLocation = qt.QStandardPaths.DocumentsLocation + documents = qt.QStandardPaths.writableLocation(documentsLocation) + + def getGPUUsage(self): + if platform.system() == "Darwin": + return 1 + else: + return 5 + + def NumberScan(self, scan_folder_t1: str, scan_folder_t2: str): + return len(GetDictPatients(scan_folder_t1, scan_folder_t2)) + + + + def TestScan(self, scan_folder: str): + extensions = ['.nii', '.nii.gz', '.nrrd'] + if scan_folder!="None" : + found_files = self.search(scan_folder, extensions) + if any(found_files[ext] for ext in extensions): + return True, "" + else: + return False, "No files to run has been found in the " + return True,"" + + + def TestProcess(self, **kwargs) -> str: + out = "" + ok = True + + if kwargs["input_folder_MRI"] == "": + out += "Please select an input folder for MRI scans\n" + ok = False + + if kwargs["input_folder_CBCT"] == "": + out += "Please select an input folder for CBCT scans\n" + ok = False + + if kwargs["output_folder"] == "": + out += "Please select an output folder\n" + ok = False + + if kwargs["resample_size"] == "": + out += "Please select a new resample size\n" + ok = False + + if kwargs["spacing"] == "": + out += "Please select a new spacing\n" + ok = False + + if out == "": + out = None + + return ok,out + + def Process(self, **kwargs): + list_process=[] + # MRI2CBCT_ORIENT_CENTER_MRI + MRI2CBCT_RESAMPLE_CBCT_MRI = slicer.modules.mri2cbct_resample_cbct_mri + parameter_mri2cbct_resample_cbct_mri = { + "input_folder_MRI": kwargs["input_folder_MRI"], + "input_folder_CBCT": kwargs["input_folder_CBCT"], + "output_folder": kwargs["output_folder"], + "resample_size": kwargs["resample_size"], + "spacing" : kwargs["spacing"] + } + + list_process.append( + { + "Process": MRI2CBCT_RESAMPLE_CBCT_MRI, + "Parameter": parameter_mri2cbct_resample_cbct_mri, + "Module": "Resample files", + } + ) + + return list_process + + diff --git a/MRI2CBCT/utils/Preprocess_MRI.py b/MRI2CBCT/utils/Preprocess_MRI.py new file mode 100644 index 0000000..372a47b --- /dev/null +++ b/MRI2CBCT/utils/Preprocess_MRI.py @@ -0,0 +1,82 @@ +from utils.Method import Method +from utils.utils_CBCT import GetDictPatients, GetPatients +import os, sys + +import SimpleITK as sitk +import numpy as np + +from glob import iglob +import slicer +import time +import qt +import platform + + +class Process_MRI(Method): + def __init__(self, widget): + super().__init__(widget) + documentsLocation = qt.QStandardPaths.DocumentsLocation + documents = qt.QStandardPaths.writableLocation(documentsLocation) + + def getGPUUsage(self): + if platform.system() == "Darwin": + return 1 + else: + return 5 + + def NumberScan(self, scan_folder_t1: str, scan_folder_t2: str): + return len(GetDictPatients(scan_folder_t1, scan_folder_t2)) + + + + def TestScan(self, scan_folder: str): + extensions = ['.nii', '.nii.gz', '.nrrd'] + found_files = self.search(scan_folder, extensions) + if any(found_files[ext] for ext in extensions): + return True, "" + else: + return False, "No files to run has been found in the input folder" + + + def TestProcess(self, **kwargs) -> str: + out = "" + ok = True + + if kwargs["input_folder"] == "": + out += "Please select an input folder for CBCT scans\n" + ok = False + + if kwargs["output_folder"] == "": + out += "Please select an output folder\n" + ok = False + + if kwargs["direction"] == "": + out += "Please select a new direction for X,Y and Z\n" + ok = False + + if out == "": + out = None + + return ok,out + + def Process(self, **kwargs): + list_process=[] + # MRI2CBCT_ORIENT_CENTER_MRI + MRI2CBCT_ORIENT_CENTER_MRI = slicer.modules.mri2cbct_orient_center_mri + parameter_mri2cbct_orient_center_mri = { + "input_folder": kwargs["input_folder"], + "direction": kwargs["direction"], + "output_folder": kwargs["output_folder"] + } + + list_process.append( + { + "Process": MRI2CBCT_ORIENT_CENTER_MRI, + "Parameter": parameter_mri2cbct_orient_center_mri, + "Module": "Orientation and Centering of the MRI", + } + ) + + return list_process + + diff --git a/MRI2CBCT/utils/Reg_MRI2CBCT.py b/MRI2CBCT/utils/Reg_MRI2CBCT.py new file mode 100644 index 0000000..5ca93c0 --- /dev/null +++ b/MRI2CBCT/utils/Reg_MRI2CBCT.py @@ -0,0 +1,117 @@ +from utils.Method import Method +from utils.utils_CBCT import GetDictPatients, GetPatients +import os, sys + +import SimpleITK as sitk +import numpy as np + +from glob import iglob +import slicer +import time +import qt +import platform +import re + + +class Registration_MRI2CBCT(Method): + def __init__(self, widget): + super().__init__(widget) + documentsLocation = qt.QStandardPaths.DocumentsLocation + documents = qt.QStandardPaths.writableLocation(documentsLocation) + + def getGPUUsage(self): + if platform.system() == "Darwin": + return 1 + else: + return 5 + + def NumberScan(self, scan_folder_t1: str, scan_folder_t2: str): + return len(GetDictPatients(scan_folder_t1, scan_folder_t2)) + + + def TestScan(self, scan_folder: str): + extensions = ['.nii', '.nii.gz', '.nrrd'] + found_files = self.search(scan_folder, extensions) + if any(found_files[ext] for ext in extensions): + return True, "" + else: + return False, "No files to run has been found in the input folder" + + def CheckNormalization(self, norm: str): + mri_min_norm, mri_max_norm, mri_lower_p, mri_upper_p = norm[0][0] + cbct_min_norm, cbct_max_norm, cbct_lower_p, cbct_upper_p = norm[0][1] + + ok = True + messages = [] + + if mri_max_norm <= mri_min_norm: + ok = False + messages.append("MRI normalization max must be greater than min") + if mri_upper_p <= mri_lower_p: + ok = False + messages.append("MRI percentile max must be greater than min") + + if cbct_max_norm <= cbct_min_norm: + ok = False + messages.append("CBCT normalization max must be greater than min") + if cbct_upper_p <= cbct_lower_p: + ok = False + messages.append("CBCT percentile max must be greater than min") + + message = "\n".join(messages) + + return ok, message + + def TestProcess(self, **kwargs) -> str: + out = "" + ok = True + + if kwargs["folder_general"] == "": + out += "Please select an input folder for CBCT scans\n" + ok = False + + if kwargs["mri_folder"] == "": + out += "Please select an input folder for MRI scans\n" + ok = False + + if kwargs["cbct_folder"] == "": + out += "Please select an input folder for CBCT scans\n" + ok = False + + if kwargs["cbct_label2"] == "": + out += "Please select an input folder for CBCT segmentation\n" + ok = False + + if kwargs["normalization"] == "": + out += "Please select some values for the normalization\n" + ok = False + + if out == "": + out = None + + return ok,out + + def Process(self, **kwargs): + list_process=[] + + MRI2CBCT_RESAMPLE_REG = slicer.modules.mri2cbct_reg + parameter_mri2cbct_reg = { + "folder_general": kwargs["folder_general"], + "mri_folder": kwargs["mri_folder"], + "cbct_folder": kwargs["cbct_folder"], + "cbct_label2": kwargs["cbct_label2"], + "normalization" : kwargs["normalization"], + "tempo_fold" : kwargs["tempo_fold"] + } + + list_process.append( + { + "Process": MRI2CBCT_RESAMPLE_REG, + "Parameter": parameter_mri2cbct_reg, + "Module": "Resample files", + } + ) + + return list_process + + diff --git a/MRI2CBCT/utils/utils_CBCT.py b/MRI2CBCT/utils/utils_CBCT.py new file mode 100644 index 0000000..84c7291 --- /dev/null +++ b/MRI2CBCT/utils/utils_CBCT.py @@ -0,0 +1,157 @@ +import os +from glob import iglob + +def GetListFiles(folder_path, file_extension): + """Return a list of files in folder_path finishing by file_extension""" + file_list = [] + for extension_type in file_extension: + file_list += search(folder_path, file_extension)[extension_type] + return file_list + + +def GetPatients(folder_path, time_point="T1", segmentationType=None): + """Return a dictionary with patient id as key""" + file_extension = [".nii.gz", ".nii", ".nrrd", ".nrrd.gz", ".gipl", ".gipl.gz"] + json_extension = [".json"] + file_list = GetListFiles(folder_path, file_extension + json_extension) + + patients = {} + + for file in file_list: + basename = os.path.basename(file) + patient = ( + basename.split("_Scan")[0] + .split("_scan")[0] + .split("_Or")[0] + .split("_OR")[0] + .split("_MAND")[0] + .split("_MD")[0] + .split("_MAX")[0] + .split("_MX")[0] + .split("_CB")[0] + .split("_lm")[0] + .split("_T2")[0] + .split("_T1")[0] + .split("_Cl")[0] + .split(".")[0] + ) + + if patient not in patients: + patients[patient] = {} + + if True in [i in basename for i in file_extension]: + # if segmentationType+'MASK' in basename: + if True in [i in basename.lower() for i in ["mask", "seg", "pred"]]: + if segmentationType is None: + patients[patient]["seg" + time_point] = file + else: + if True in [ + i in basename.lower() + for i in GetListNamesSegType(segmentationType) + ]: + patients[patient]["seg" + time_point] = file + + else: + patients[patient]["scan" + time_point] = file + + if True in [i in basename for i in json_extension]: + if time_point == "T2": + patients[patient]["lm" + time_point] = file + + return patients + + +def GetMatrixPatients(folder_path): + """Return a dictionary with patient id as key and matrix path as data""" + file_extension = [".tfm"] + file_list = GetListFiles(folder_path, file_extension) + + patients = {} + for file in file_list: + basename = os.path.basename(file) + patient = basename.split("reg_")[1].split("_Cl")[0] + if patient not in patients and True in [i in basename for i in file_extension]: + patients[patient] = {} + patients[patient]["mat"] = file + + return patients + + +def GetDictPatients( + folder_t1_path, + folder_t2_path, + segmentationType=None, + todo_str="", + matrix_folder=None, +): + """Return a dictionary with patients for both time points""" + patients_t1 = GetPatients( + folder_t1_path, time_point="T1", segmentationType=segmentationType + ) + patients_t2 = GetPatients(folder_t2_path, time_point="T2", segmentationType=None) + patients = MergeDicts(patients_t1, patients_t2) + + if matrix_folder is not None: + patient_matrix = GetMatrixPatients(matrix_folder) + patients = MergeDicts(patients, patient_matrix) + patients = ModifiedDictPatients(patients, todo_str) + return patients + + +def MergeDicts(dict1, dict2): + """Merge t1 and t2 dictionaries for each patient""" + patients = {} + for patient in dict1: + patients[patient] = dict1[patient] + try: + patients[patient].update(dict2[patient]) + except KeyError: + continue + return patients + + +def ModifiedDictPatients(patients, todo_str): + """Modify the dictionary of patients to only keep the ones in the todo_str""" + + if todo_str != "": + liste_todo = todo_str.split(",") + todo_patients = {} + for i in liste_todo: + patient = list(patients.keys())[int(i) - 1] + todo_patients[patient] = patients[patient] + patients = todo_patients + + return patients + + +def search(path, *args): + """ + Return a dictionary with args element as key and a list of file in path directory finishing by args extension for each key + + Example: + args = ('json',['.nii.gz','.nrrd']) + return: + { + 'json' : ['path/a.json', 'path/b.json','path/c.json'], + '.nii.gz' : ['path/a.nii.gz', 'path/b.nii.gz'] + '.nrrd.gz' : ['path/c.nrrd'] + } + """ + arguments = [] + for arg in args: + if type(arg) == list: + arguments.extend(arg) + else: + arguments.append(arg) + return { + key: sorted( + [ + i + for i in iglob( + os.path.normpath("/".join([path, "**", "*"])), recursive=True + ) + if i.endswith(key) + ] + ) + for key in arguments + } \ No newline at end of file diff --git a/MRI2CBCT_CLI/CMakeLists.txt b/MRI2CBCT_CLI/CMakeLists.txt new file mode 100644 index 0000000..9a74ba1 --- /dev/null +++ b/MRI2CBCT_CLI/CMakeLists.txt @@ -0,0 +1,11 @@ +#----------------------------------------------------------------------------- +add_subdirectory(MRI2CBCT_ORIENT_CENTER_MRI) +add_subdirectory(MRI2CBCT_RESAMPLE_CBCT_MRI) +add_subdirectory(MRI2CBCT_REG) +# add_subdirectory(MRI2CBCT_CLI_utils) +# add_subdirectory(MRI2CBCT_RESAMPLE_CBCT_MRI) + +include(ImportLibrary.cmake) + + + diff --git a/MRI2CBCT_CLI/ImportLibrary.cmake b/MRI2CBCT_CLI/ImportLibrary.cmake new file mode 100644 index 0000000..9bdc896 --- /dev/null +++ b/MRI2CBCT_CLI/ImportLibrary.cmake @@ -0,0 +1,15 @@ +#----------------------------------------------------------------------------- +set(MODULE_NAME MRI2CBCT_CLI_utils) + +#----------------------------------------------------------------------------- +set(MODULE_PYTHON_SCRIPTS + ${MODULE_NAME}/__init__.py + ${MODULE_NAME}/resample_create_csv.py + ${MODULE_NAME}/resample.py +) + +#----------------------------------------------------------------------------- +slicerMacroBuildScriptedModule( + NAME ${MODULE_NAME} + SCRIPTS ${MODULE_PYTHON_SCRIPTS} +) diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/AREG_MRI.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/AREG_MRI.py new file mode 100644 index 0000000..29be884 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/AREG_MRI.py @@ -0,0 +1,194 @@ +import argparse +import os +import itk +import SimpleITK as sitk +import numpy as np + +def ComputeFinalMatrix(Transforms): + """Compute the final matrix from the list of matrices and translations""" + Rotation, Translation = [], [] + for i in range(len(Transforms)): + Rotation.append(Transforms[i].GetMatrix()) + Translation.append(Transforms[i].GetTranslation()) + + # Compute the final rotation matrix + final_rotation = np.reshape(np.asarray(Rotation[0]), (3, 3)) + for i in range(1, len(Rotation)): + final_rotation = final_rotation @ np.reshape(np.asarray(Rotation[i]), (3, 3)) + + # Compute the final translation matrix + final_translation = np.reshape(np.asarray(Translation[0]), (1, 3)) + for i in range(1, len(Translation)): + final_translation = final_translation + np.reshape( + np.asarray(Translation[i]), (1, 3) + ) + + # Create the final transform + final_transform = sitk.Euler3DTransform() + final_transform.SetMatrix(final_rotation.flatten().tolist()) + final_transform.SetTranslation(final_translation[0].tolist()) + + return final_transform + + +def ElastixReg(fixed_image, moving_image, initial_transform=None): + """Perform a registration using elastix with a rigid transform and possibly an initial transform""" + + elastix_object = itk.ElastixRegistrationMethod.New(fixed_image, moving_image) + + # ParameterMap + parameter_object = itk.ParameterObject.New() + default_rigid_parameter_map = parameter_object.GetDefaultParameterMap("rigid") + parameter_object.AddParameterMap(default_rigid_parameter_map) + parameter_object.SetParameter("ErodeMask", "true") + parameter_object.SetParameter("WriteResultImage", "false") + parameter_object.SetParameter("MaximumNumberOfIterations", "10000") + parameter_object.SetParameter("NumberOfResolutions", "1") + parameter_object.SetParameter("NumberOfSpatialSamples", "10000") + + elastix_object.SetParameterObject(parameter_object) + if initial_transform is not None: + elastix_object.SetInitialTransformParameterObject(initial_transform) + + # Additional parameters + elastix_object.SetLogToConsole(False) + + # Execute registration + elastix_object.UpdateLargestPossibleRegion() + + TransParamObj = elastix_object.GetTransformParameterObject() + + return TransParamObj + +def MatrixRetrieval(TransformParameterMapObject): + """Retrieve the matrix from the transform parameter map""" + ParameterMap = TransformParameterMapObject.GetParameterMap(0) + + if ParameterMap["Transform"][0] == "AffineTransform": + matrix = [float(i) for i in ParameterMap["TransformParameters"]] + # Convert to a sitk transform + transform = sitk.AffineTransform(3) + transform.SetParameters(matrix) + + elif ParameterMap["Transform"][0] == "EulerTransform": + A = [float(i) for i in ParameterMap["TransformParameters"][0:3]] + B = [float(i) for i in ParameterMap["TransformParameters"][3:6]] + # Convert to a sitk transform + transform = sitk.Euler3DTransform() + transform.SetRotation(angleX=A[0], angleY=A[1], angleZ=A[2]) + transform.SetTranslation(B) + + return transform + +def get_corresponding_file(folder, patient_id, modality): + """Get the corresponding file for a given patient ID and modality.""" + for root, _, files in os.walk(folder): + for file in files: + if file.startswith(patient_id) and modality in file and file.endswith(".nii.gz"): + return os.path.join(root, file) + return None + +def registration(cbct_folder,mri_folder,cbct_mask_folder,output_folder,mri_original_folder): + """ + Registers CBCT and MRI images using CBCT masks, saving the results in the specified output folder. + + Arguments: + cbct_folder (str): Folder containing CBCT files (.nii.gz). + mri_folder (str): Folder containing corresponding MRI files (.nii.gz). + cbct_mask_folder (str): Folder containing CBCT masks (.nii.gz). + output_folder (str): Folder to save the registration results. + mri_original_folder (str): Folder containing original MRI files (.nii.gz), if available. + + For each CBCT file in cbct_folder: + - Extract patient ID from the filename. + - Find corresponding MRI and CBCT mask files. + - Optionally, find the original MRI file. + - Call process_images to perform registration and save the results. + """ + + for cbct_file in os.listdir(cbct_folder): + if cbct_file.endswith(".nii.gz") and "_CBCT_" in cbct_file: + patient_id = cbct_file.split("_CBCT_")[0] + + mri_path = get_corresponding_file(mri_folder, patient_id, "_MR_") + if mri_original_folder!="None": + mri_path_original = get_corresponding_file(mri_original_folder, patient_id, "_MR_") + + + cbct_mask_path = get_corresponding_file(cbct_mask_folder, patient_id, "_CBCT_") + + process_images(mri_path, cbct_mask_path, output_folder,patient_id,mri_path_original,) + +def process_images(mri_path, cbct_mask_path, output_folder, patient_id,mri_path_original): + """ + Processes MRI and CBCT mask images, performs registration, and saves the results. + + Arguments: + mri_path (str): Path to the MRI file. + cbct_mask_path (str): Path to the CBCT mask file. + output_folder (str): Folder to save the registration results. + patient_id (str): Identifier for the patient. + mri_path_original (str): Path to the original MRI file. + + Steps: + - Reads the MRI and CBCT mask images. + - Performs registration using Elastix. + - Retrieves the transformation matrix and computes the final transform. + - Saves the transformed image and transformation matrix in the output folder. + """ + + try : + mri_path = itk.imread(mri_path, itk.F) + cbct_mask_path = itk.imread(cbct_mask_path, itk.F) + except KeyError as e: + print("An error occurred while reading the images of the patient : {patient_id}") + print(e) + print(f"{patient_id} failed") + return + + Transforms = [] + + try : + TransformObj_Fine = ElastixReg(cbct_mask_path, mri_path, initial_transform=None) + except Exception as e: + print("An error occurred during the registration process on the patient {patient_id} :") + print(e) + return + + transforms_Fine = MatrixRetrieval(TransformObj_Fine) + Transforms.append(transforms_Fine) + transform = ComputeFinalMatrix(Transforms) + + os.makedirs(output_folder, exist_ok=True) + + output_image_path = os.path.join(output_folder,os.path.basename(mri_path_original).replace('.nii.gz', f'_reg.nii.gz')) + output_image_path_transform = os.path.join(output_folder,os.path.basename(mri_path_original).replace('.nii.gz', f'_reg_transform.tfm')) + + sitk.WriteTransform(transform, output_image_path_transform) + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='AREG MRI folder') + + parser.add_argument("--cbct_folder", type=str, help="Folder containing CBCT images.", default=".") + parser.add_argument("--cbct_mask_folder", type=str, help="Folder containing CBCT masks.", default=".") + + parser.add_argument("--mri_folder", type=str, help="Folder containing MRI images.", default=".") + parser.add_argument("--mri_original_folder", type=str, help="Folder containing original MRI.", default=".") + + parser.add_argument("--output_folder", type=str, help="Folder to save the output files.",default=".") + + args = parser.parse_args() + + if not os.path.exists(args.output_folder): + os.makedirs(args.output_folder) + + cbct_folder = args.cbct_folder + mri_folder = args.mri_folder + cbct_mask_folder = args.cbct_mask_folder + output_folder = args.output_folder + mri_original_folder = args.mri_original_folder + + registration(cbct_folder,mri_folder,cbct_mask_folder,output_folder,mri_original_folder) diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/__init__.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/__init__.py new file mode 100644 index 0000000..683fe84 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/__init__.py @@ -0,0 +1,11 @@ +from .resample_create_csv import ( + create_csv, +) + +# from .Net import DenseNet +from .resample import resample_images, run_resample + +from .mri_inverse import invert_mri_intensity +from .normalize_percentile import normalize +from .apply_mask import apply_mask_f +from .AREG_MRI import registration diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/apply_mask.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/apply_mask.py new file mode 100644 index 0000000..614b425 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/apply_mask.py @@ -0,0 +1,124 @@ +import SimpleITK as sitk +import os +import argparse +import numpy as np + + +def MaskedImage(fixed_image_path, fixed_seg_path, folder_output, suffix, SegLabel=None): + """ + Mask the fixed image with the fixed segmentation and write it to a file + + Arguments: + fixed_image_path (str): Path to the fixed image file. + fixed_seg_path (str): Path to the fixed segmentation file. + folder_output (str): Folder to save the masked image. + suffix (str): Suffix to add to the output file name. + SegLabel (int, optional): Segmentation label to use for masking. + """ + fixed_image_sitk = sitk.ReadImage(fixed_image_path) + fixed_seg_sitk = sitk.ReadImage(fixed_seg_path) + + fixed_image_masked = applyMask(fixed_image_sitk, fixed_seg_sitk, label=SegLabel) + if fixed_image_masked=="failed": + print("failed process on : ",fixed_image_sitk) + return + + base_name, ext = os.path.splitext(fixed_image_path) + if base_name.endswith('.nii'): # Case for .nii.gz + ext = '.nii.gz' + + file_name = os.path.basename(fixed_image_path) + file_name_without_ext = os.path.splitext(os.path.splitext(file_name)[0])[0] + + output_path = os.path.join(folder_output, f"{file_name_without_ext}_{suffix}{ext}") + + sitk.WriteImage(sitk.Cast(fixed_image_masked, sitk.sitkInt16), output_path) + + return output_path + + +def applyMask(image, mask, label): + """ + Apply a mask to an image. + + Arguments: + image (SimpleITK.Image): The image to be masked. + mask (SimpleITK.Image): The mask image. + label (int): The label value to use for masking. + """ + try : + array = sitk.GetArrayFromImage(mask) + if label is not None and label in np.unique(array): + array = np.where(array == label, 1, 0) + mask = sitk.GetImageFromArray(array) + mask.CopyInformation(image) + except KeyError as e : + print(e) + return "failed" + + return sitk.Mask(image, mask) + + +def find_segmentation_file(image_file, seg_folder): + """ + Find the corresponding segmentation file for a given image file. + + Arguments: + image_file (str): Path to the image file. + seg_folder (str): Folder containing segmentation files. + """ + base_name = os.path.basename(image_file) + patient_id = base_name.split('_CBCT')[0].split('_MR')[0] + + for seg_file in os.listdir(seg_folder): + if seg_file.startswith(patient_id) and "_CBCT" in seg_file: + return os.path.join(seg_folder, seg_file) + + return None + + +def apply_mask_f(folder_path, seg_folder, folder_output, suffix, seg_label): + """ + Processes all image files in the specified folder by applying the corresponding segmentation masks. + + Arguments: + folder_path (str): Path to the folder containing image files. + seg_folder (str): Folder containing segmentation files. + folder_output (str): Folder to save the masked images. + suffix (str): Suffix to add to the output file names. + seg_label (int): Segmentation label to use for masking. + """ + + for root, _, files in os.walk(folder_path): + for file in files: + if file.endswith(('.nii', '.nii.gz')) and ('_CBCT' in file or '_MR' in file): + fixed_image_path = os.path.join(root, file) + fixed_seg_path = find_segmentation_file(file, seg_folder) + + if fixed_seg_path: + try : + MaskedImage(fixed_image_path, fixed_seg_path, folder_output, suffix, seg_label) + print(f"Mask apply for the file {fixed_image_path} succedeed.") + except KeyError as e: + print(f"Mask apply for the file {fixed_image_path}failed.") + print(e) + continue + else: + print(f"Segmentation file for {fixed_image_path} not found.") + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Apply segmentation mask to all MRI files in a folder.") + parser.add_argument("--folder_path", type=str, default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/b2_CBCT_norm/test_percentile=[10,95]_norm=[0,75]", help="The path to the folder containing the MRI files.") + parser.add_argument("--seg_folder", type=str, default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/d0_CBCT_seg_sep/label_2", help="The path to the segmentation file.") + parser.add_argument("--folder_output", type=str, default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/b3_CBCT_inv_norm_mask:l2/a03_test_percentile=[10,95]_norm=[0,75]", help="The path to the output folder for the masked files.") + parser.add_argument("--suffix", type=str, default="mask", help="The suffix to add to the output filenames.") + parser.add_argument("--seg_label", type=int, default=1, help="Label of the segmentation.") + + args = parser.parse_args() + + if not os.path.exists(args.folder_output): + os.makedirs(args.folder_output) + + apply_mask_f(args.folder_path, args.seg_folder, args.folder_output, args.suffix, args.seg_label) diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/mri_inverse.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/mri_inverse.py new file mode 100644 index 0000000..c61b19c --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/mri_inverse.py @@ -0,0 +1,61 @@ +import SimpleITK as sitk +import os +import argparse + +def invert_mri_intensity(path_folder, folder_output, suffix): + """ + Inverts the intensity values of MRI images in the specified folder and saves the results. + + Arguments: + path_folder (str): Path to the folder containing MRI files. + folder_output (str): Folder to save the inverted images. + suffix (str): Suffix to add to the output file names. + """ + + # Check if the output folder exists, if not create it + if not os.path.exists(folder_output): + os.makedirs(folder_output) + + # Iterate through the files in the input folder + for filename in os.listdir(path_folder): + if filename.endswith('.nii') or filename.endswith('.nii.gz'): + filepath = os.path.join(path_folder, filename) + image = sitk.ReadImage(filepath) + + # Convert the image to a numpy array to manipulate intensities + image_array = sitk.GetArrayFromImage(image) + + # Find the maximum intensity value in the image + max_intensity = image_array.max() + + # Invert the intensities while keeping the background (where intensity is 0) unchanged + inverted_image_array = max_intensity - image_array + inverted_image_array[image_array == 0] = 0 + + # Convert the inverted numpy array back to a SimpleITK image + inverted_image = sitk.GetImageFromArray(inverted_image_array) + + # Copy the original image information (such as spacing, origin, etc.) to the inverted image + inverted_image.CopyInformation(image) + + # Generate the new filename with the suffix + base_name, ext = os.path.splitext(filename) + if base_name.endswith('.nii'): # Case for .nii.gz + base_name, ext2 = os.path.splitext(base_name) + ext = ext2 + ext + + output_filename = os.path.join(folder_output, f"{base_name}_{suffix}{ext}") + + # Save the inverted image + sitk.WriteImage(inverted_image, output_filename) + + print(f"Inversion completed for {filename}, saved as {output_filename}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Invert the intensity of MRI images while keeping the background at 0.") + parser.add_argument("--path_folder", type=str, help="The path to the folder containing the MRI files", default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/a0_MRI") + parser.add_argument("--folder_output", type=str, help="The path to the output folder for the inverted files",default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/a1_MRI_inv") + parser.add_argument("--suffix", type=str, help="The suffix to add to the output filenames",default="inv") + + args = parser.parse_args() + invert_mri_intensity(args.path_folder, args.folder_output, args.suffix) diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/normalize_percentile.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/normalize_percentile.py new file mode 100644 index 0000000..2248c1c --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/normalize_percentile.py @@ -0,0 +1,92 @@ +import argparse +import os +import SimpleITK as sitk +import numpy as np + +def compute_thresholds(image, lower_percentile=10, upper_percentile=90): + """ + Computes intensity thresholds for an image based on specified percentiles. + + Arguments: + image (SimpleITK.Image): The input image. + lower_percentile (float): The lower percentile for threshold computation (default is 10). + upper_percentile (float): The upper percentile for threshold computation (default is 90). + """ + array = sitk.GetArrayFromImage(image) + lower_threshold = np.percentile(array, lower_percentile) + upper_threshold = np.percentile(array, upper_percentile) + return lower_threshold, upper_threshold + +def enhance_contrast(image,upper_percentile,lower_percentile, min_norm, max_norm): + """ + Enhances the contrast of the image while normalizing its intensity values. + + Arguments: + image (SimpleITK.Image): The input image. + upper_percentile (float): The upper percentile for threshold computation. + lower_percentile (float): The lower percentile for threshold computation. + min_norm (float): The minimum normalization value. + max_norm (float): The maximum normalization value. + """ + # Compute thresholds + lower_threshold, upper_threshold = compute_thresholds(image,lower_percentile,upper_percentile) + + + # Normalize the image using the computed thresholds + array = sitk.GetArrayFromImage(image) + normalized_array = np.clip((array - lower_threshold) / (upper_threshold - lower_threshold), 0, 1) + scaled_array = normalized_array * max_norm - min_norm + + return sitk.GetImageFromArray(scaled_array) + +def normalize(input_folder, output_folder,upper_percentile,lower_percentile,min_norm, max_norm): + """ + Processes and normalizes all .nii.gz images in the input folder, enhancing their contrast. + + Arguments: + input_folder (str): Path to the folder containing the input images. + output_folder (str): Path to the folder to save the normalized images. + upper_percentile (float): Upper percentile for threshold computation. + lower_percentile (float): Lower percentile for threshold computation. + min_norm (float): Minimum normalization value. + max_norm (float): Maximum normalization value. + """ + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + for filename in os.listdir(input_folder): + if filename.endswith('.nii.gz'): + input_path = os.path.join(input_folder, filename) + img = sitk.ReadImage(input_path) + + # Enhance the contrast of the image + enhanced_img = enhance_contrast(img,upper_percentile,lower_percentile,min_norm, max_norm) + + # Copy original metadata to the enhanced image + enhanced_img.CopyInformation(img) + + # Save the enhanced image with the new suffix + output_filename = filename.replace('.nii.gz', f'_percentile=[{lower_percentile},{upper_percentile}]_norm=[{min_norm},{max_norm}].nii.gz') + output_path = os.path.join(output_folder, output_filename) + sitk.WriteImage(enhanced_img, output_path) + print(f'Saved enhanced image to {output_path}') + +def main(): + parser = argparse.ArgumentParser(description='Enhance contrast of NIfTI images and save with a new suffix.') + parser.add_argument('--input_folder', type=str, help='Path to the input folder containing .nii.gz images.', default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/b0_CBCT") + parser.add_argument('--output_folder', type=str, help='Path to the output folder to save normalized images.', default="/home/lucia/Documents/Gaelle/Data/MultimodelReg/Segmentation/a3_Registration_closer_all/b2_CBCT_norm") + parser.add_argument('--upper_percentile', type=int, help='upper percentile to apply, choose between 0 and 100',default=95) + parser.add_argument('--lower_percentile', type=int, help='lower percentile to apply, choose between 0 and 100',default=10) + parser.add_argument('--max_norm', type=int, help='max value after normalization',default=75) + parser.add_argument('--min_norm', type=int, help='min value after normalization',default=0) + + args = parser.parse_args() + + output_path = os.path.join(args.output_folder,f"test_percentile=[{args.lower_percentile},{args.upper_percentile}]_norm=[{args.min_norm},{args.max_norm}]") + if not os.path.exists(output_path): + os.makedirs(output_path) + + normalize(args.input_folder, output_path, args.upper_percentile,args.lower_percentile,args.min_norm, args.max_norm) + +if __name__ == '__main__': + main() diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/resample.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/resample.py new file mode 100644 index 0000000..1a5c6ac --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/resample.py @@ -0,0 +1,271 @@ +import SimpleITK as sitk +import numpy as np +import argparse +import os +import glob +import sys +import csv + +def resample_fn(img, args): + ''' + Resamples the given image based on the specified arguments. + + Arguments: + img (SimpleITK.Image): The image to be resampled. + args (dict): Dictionary containing the following keys: + - size (tuple): Desired size of the output image. + - fit_spacing (bool): Flag to fit spacing. + - iso_spacing (bool): Flag for isotropic spacing. + - pixel_dimension (int): Pixel dimension of the image. + - center (int): Flag to center the image. + - linear (bool): Flag to use linear interpolation. + - spacing (tuple): Desired spacing of the output image (optional). + - origin (tuple): Desired origin of the output image (optional). + ''' + output_size = args['size'] + fit_spacing = args['fit_spacing'] + iso_spacing = args['iso_spacing'] + pixel_dimension = args['pixel_dimension'] + center = args['center'] + + if args['linear']: + InterpolatorType = sitk.sitkLinear + else: + InterpolatorType = sitk.sitkNearestNeighbor + + + + spacing = img.GetSpacing() + size = img.GetSize() + + output_origin = img.GetOrigin() + output_size = [si if o_si == -1 else o_si for si, o_si in zip(size, output_size)] + + if(fit_spacing): + output_spacing = [sp*si/o_si for sp, si, o_si in zip(spacing, size, output_size)] + else: + output_spacing = spacing + + + if(iso_spacing=="True"): + output_spacing_filtered = [sp for si, sp in zip(args['size'], output_spacing) if si != -1] + max_spacing = np.max(output_spacing_filtered) + output_spacing = [sp if si == -1 else max_spacing for si, sp in zip(args['size'], output_spacing)] + + + if(args['spacing'] is not None): + output_spacing = args['spacing'] + + if(args['origin'] is not None): + output_origin = args['origin'] + + if(center): + output_physical_size = np.array(output_size)*np.array(output_spacing) + input_physical_size = np.array(size)*np.array(spacing) + output_origin = np.array(output_origin) - (output_physical_size - input_physical_size)/2.0 + + + resampleImageFilter = sitk.ResampleImageFilter() + resampleImageFilter.SetInterpolator(InterpolatorType) + resampleImageFilter.SetOutputSpacing(output_spacing) + resampleImageFilter.SetSize(output_size) + resampleImageFilter.SetOutputDirection(img.GetDirection()) + resampleImageFilter.SetOutputOrigin(output_origin) + # resampleImageFilter.SetDefaultPixelValue(zeroPixel) + + + return resampleImageFilter.Execute(img) + + +def Resample(img_filename, args): + """ + Resamples an image based on the provided arguments. + + Arguments: + img_filename (str): Path to the image file to resample. + args (dict): Dictionary containing the following keys: + - size (tuple): Desired size of the output image. + - fit_spacing (bool): Flag to fit spacing. + - iso_spacing (bool): Flag for isotropic spacing. + - image_dimension (int): Dimension of the image. + - pixel_dimension (int): Pixel dimension of the image. + - img_spacing (tuple): Spacing of the input image. + + Steps: + - Reads the image from the specified file. + - Sets the image spacing if provided in the arguments. + - Calls the resample function with the image and arguments. + - Returns the resampled image. + """ + + img = sitk.ReadImage(img_filename) + + if(args['img_spacing']): + img.SetSpacing(args['img_spacing']) + + return resample_fn(img, args) + + +def resample_images(args): + """ + Resamples images based on the provided arguments and saves the output. + """ + + filenames = [] + if args['img']: + fobj = {"img": args['img'], "out": args['out']} + filenames.append(fobj) + elif args['dir']: + out_dir = args['out'] + normpath = os.path.normpath("/".join([args['dir'], '**', '*'])) + for img in glob.iglob(normpath, recursive=True): + if os.path.isfile(img) and any(ext in img for ext in [".nrrd", ".nii", ".nii.gz", ".mhd", ".dcm", ".DCM", ".jpg", ".png"]): + fobj = {"img": img, "out": os.path.normpath(out_dir + "/" + img.replace(args['dir'], ''))} + if args['out_ext'] is not None: + out_ext = args['out_ext'] if args['out_ext'].startswith(".") else "." + args['out_ext'] + fobj["out"] = os.path.splitext(fobj["out"])[0] + out_ext + if not os.path.exists(os.path.dirname(fobj["out"])): + os.makedirs(os.path.dirname(fobj["out"])) + if not os.path.exists(fobj["out"]) or args.ow: + filenames.append(fobj) + elif args['csv']: + replace_dir_name = args['csv_root_path'] + with open(args['csv']) as csvfile: + csv_reader = csv.DictReader(csvfile) + for row in csv_reader: + fobj = {"img": row[args['csv_column']], "out": row[args['csv_column']]} + if replace_dir_name: + fobj["out"] = fobj["out"].replace(replace_dir_name, args['out']) + if args['csv_use_spc']: + img_spacing = [ + row[args['csv_column_spcx']] if args['csv_column_spcx'] else None, + row[args['csv_column_spcy']] if args['csv_column_spcy'] else None, + row[args['csv_column_spcz']] if args['csv_column_spcz'] else None, + ] + fobj["img_spacing"] = [spc for spc in img_spacing if spc] + + if "ref" in row: + fobj["ref"] = row["ref"] + + if args['out_ext'] is not None: + out_ext = args['out_ext'] if args['out_ext'].startswith(".") else "." + args['out_ext'] + fobj["out"] = os.path.splitext(fobj["out"])[0] + out_ext + if not os.path.exists(os.path.dirname(fobj["out"])): + os.makedirs(os.path.dirname(fobj["out"])) + if not os.path.exists(fobj["out"]) or args.ow: + filenames.append(fobj) + else: + raise ValueError("Set img or dir to resample!") + + if args['rgb']: + if args['pixel_dimension'] == 3: + print("Using: RGB type pixel with unsigned char") + elif args['pixel_dimension'] == 4: + print("Using: RGBA type pixel with unsigned char") + else: + print("WARNING: Pixel size not supported!") + + if args['ref'] is not None: + ref = sitk.ReadImage(args['ref']) + args['size'] = ref.GetSize() + args['spacing'] = ref.GetSpacing() + args['origin'] = ref.GetOrigin() + + for fobj in filenames: + try: + if "ref" in fobj and fobj["ref"] is not None: + ref = sitk.ReadImage(fobj["ref"]) + args['size'] = ref.GetSize() + args['spacing'] = ref.GetSpacing() + args['origin'] = ref.GetOrigin() + + if args['size'] is not None: + img = Resample(fobj["img"], args) + else: + img = sitk.ReadImage(fobj["img"]) + + print("Writing:", fobj["out"]) + writer = sitk.ImageFileWriter() + writer.SetFileName(fobj["out"]) + writer.UseCompressionOn() + writer.Execute(img) + + except Exception as e: + print(e, file=sys.stderr) + + +def run_resample(img=None, dir=None, csv=None, csv_column='image', csv_root_path=None, csv_use_spc=0, + csv_column_spcx=None, csv_column_spcy=None, csv_column_spcz=None, ref=None, size=None, + img_spacing=None, spacing=None, origin=None, linear=False, center=0, fit_spacing=False, + iso_spacing=False, image_dimension=2, pixel_dimension=1, rgb=False, ow=1, out="./out.nrrd", + out_ext=None): + ''' + Sets up and runs the resampling of images based on the provided parameters. + ''' + args = { + 'img': img, # Path to a single image file to resample + 'dir': dir, # Directory containing image files to resample + 'csv': csv, # Path to a CSV file listing images to resample + 'csv_column': csv_column, # CSV column name that contains image paths + 'csv_root_path': csv_root_path, # Root path to prepend to CSV image paths + 'csv_use_spc': csv_use_spc, # Flag to use spacing from CSV + 'csv_column_spcx': csv_column_spcx, # CSV column name for X spacing + 'csv_column_spcy': csv_column_spcy, # CSV column name for Y spacing + 'csv_column_spcz': csv_column_spcz, # CSV column name for Z spacing + 'ref': ref, # Reference image path for resampling + 'size': size, # Desired size of the output image + 'img_spacing': img_spacing, # Spacing of the input image + 'spacing': spacing, # Desired spacing of the output image + 'origin': origin, # Origin of the output image + 'linear': linear, # Flag to use linear interpolation + 'center': center, # Flag to center the image + 'fit_spacing': fit_spacing, # Flag to fit spacing + 'iso_spacing': iso_spacing, # Flag for isotropic spacing + 'image_dimension': image_dimension, # Dimension of the image + 'pixel_dimension': pixel_dimension, # Pixel dimension of the image + 'rgb': rgb, # Flag for RGB images + 'ow': ow, # Overwrite flag + 'out': out, # Output file path + 'out_ext': out_ext, # Output file extension + } + resample_images(args) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Resample an image', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + in_group = parser.add_mutually_exclusive_group(required=True) + in_group.add_argument('--img', type=str, help='image to resample') + in_group.add_argument('--dir', type=str, help='Directory with image to resample') + in_group.add_argument('--csv', type=str, help='CSV file with column img with paths to images to resample') + + csv_group = parser.add_argument_group('CSV extra parameters') + csv_group.add_argument('--csv_column', type=str, default='image', help='CSV column name (Only used if flag csv is used)') + csv_group.add_argument('--csv_root_path', type=str, default=None, help='Replaces a root path directory to empty, this is use to recreate a directory structure in the output directory, otherwise, the output name will be the name in the csv (only if csv flag is used)') + csv_group.add_argument('--csv_use_spc', type=int, default=0, help='Use the spacing information in the csv instead of the image') + csv_group.add_argument('--csv_column_spcx', type=str, default=None, help='Column name in csv') + csv_group.add_argument('--csv_column_spcy', type=str, default=None, help='Column name in csv') + csv_group.add_argument('--csv_column_spcz', type=str, default=None, help='Column name in csv') + + transform_group = parser.add_argument_group('Transform parameters') + transform_group.add_argument('--ref', type=str, help='Reference image. Use an image as reference for the resampling', default=None) + transform_group.add_argument('--size', nargs="+", type=int, help='Output size, -1 to leave unchanged', default=None) + transform_group.add_argument('--img_spacing', nargs="+", type=float, default=None, help='Use this spacing information instead of the one in the image') + transform_group.add_argument('--spacing', nargs="+", type=float, default=None, help='Output spacing') + transform_group.add_argument('--origin', nargs="+", type=float, default=None, help='Output origin') + transform_group.add_argument('--linear', type=bool, help='Use linear interpolation.', default=False) + transform_group.add_argument('--center', type=int, help='Center the image in the space', default=0) + transform_group.add_argument('--fit_spacing', type=bool, help='Fit spacing to output', default=False) + transform_group.add_argument('--iso_spacing', type=bool, help='Same spacing for resampled output', default=False) + + img_group = parser.add_argument_group('Image parameters') + img_group.add_argument('--image_dimension', type=int, help='Image dimension', default=2) + img_group.add_argument('--pixel_dimension', type=int, help='Pixel dimension', default=1) + img_group.add_argument('--rgb', type=bool, help='Use RGB type pixel', default=False) + + out_group = parser.add_argument_group('Output parameters') + out_group.add_argument('--ow', type=int, help='Overwrite', default=1) + out_group.add_argument('--out', type=str, help='Output image/directory', default="./out.nrrd") + out_group.add_argument('--out_ext', type=str, help='Output extension type', default=None) + + args = parser.parse_args() + resample_images(args) diff --git a/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/resample_create_csv.py b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/resample_create_csv.py new file mode 100644 index 0000000..79af099 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_CLI_utils/resample_create_csv.py @@ -0,0 +1,66 @@ +import SimpleITK as sitk +import os +import pandas as pd +import argparse + +def get_nifti_info(file_path,output_resample): + """ + Retrieves information about a nifti file and prepares the output path for resampling. + + Arguments: + file_path (str): Path to the input nifti file. + output_resample (str): Path to the folder to save resampled nifti files. + """ + + # Read the NIfTI file + image = sitk.ReadImage(file_path) + + # Get information + info = { + "in": file_path, + "out" : file_path.replace(os.path.dirname(file_path),output_resample), + "size": image.GetSize(), + "Spacing": image.GetSpacing(), + } + + return info + +def create_csv(input:str,output_resample:str,output_csv:str,name_csv:str): + """ + Creates a CSV file with information about nifti files in the input folder, resampling them if needed. + + Arguments: + input (str): Path to the input folder containing nifti files. + output_resample (str): Path to the folder to save resampled nifti files. + output_csv (str): Path to the folder to save the output CSV file. + name_csv (str): Name of the output CSV file. + """ + + if not os.path.exists(output_resample): + os.makedirs(output_resample) + + if not os.path.exists(output_csv): + os.makedirs(output_csv) + + input_folder = input + # Get all nifti files in the folder + nifti_files = [] + for root, dirs, files in os.walk(input_folder): + for file in files: + if file.endswith(".nii") or file.endswith(".nii.gz"): + nifti_files.append(os.path.join(root, file)) + + # Get nifti info for every nifti file + nifti_info = [] + for file in nifti_files: + info = get_nifti_info(file,output_resample) + nifti_info.append(info) + + # Créez un seul DataFrame avec toutes les informations + df = pd.DataFrame(nifti_info) + outpath = os.path.join(output_csv,name_csv) + df.to_csv(outpath, index=False) + + return outpath + + diff --git a/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/CMakeLists.txt b/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/CMakeLists.txt new file mode 100644 index 0000000..d8762f7 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/CMakeLists.txt @@ -0,0 +1,7 @@ + +set(MODULE_NAME MRI2CBCT_ORIENT_CENTER_MRI) + +SlicerMacroBuildScriptedCLI( + NAME ${MODULE_NAME} +) + diff --git a/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/MRI2CBCT_ORIENT_CENTER_MRI.py b/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/MRI2CBCT_ORIENT_CENTER_MRI.py new file mode 100644 index 0000000..8d1020f --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/MRI2CBCT_ORIENT_CENTER_MRI.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python-real + +import os +import SimpleITK as sitk +import argparse + +def extract_id(filename): + """ + Extracts and returns the ID from a filename, removing common NIfTI extensions. + + Parameters: + filename (str): The filename from which to extract the ID. + + Returns: + str: The extracted ID without the extension. + """ + # Remove the extension using os.path.splitext + type_file = 0 + base = os.path.splitext(filename)[0] + # If the file has a double extension (commonly .nii.gz), remove the second extension + if base.endswith('.nii'): + base = os.path.splitext(base)[0] + type_file=1 + + + return base,type_file + +def calculate_new_origin(image): + """ + Calculate the new origin to center the image in the Slicer viewport across all axes. + """ + size = image.GetSize() + spacing = image.GetSpacing() + # Calculate the center offset for each axis + new_origin = [(size[i] * spacing[i]) / 2 for i in range(len(size))] + new_origin = [new_origin[2],-new_origin[0],new_origin[1]] # FOR MRI + # new_origin = [-new_origin[0]*1.5,new_origin[1],-new_origin[2]*0.5] # FOR CBCT + # new_origin = [-new_origin[0]*1,new_origin[1],-new_origin[2]*1] # SAVE INSIDE BUT NOT CENTER + return tuple(new_origin) + +def modify_image_properties(nifti_file_path, new_direction, output_file_path=None): + """ + Read a NIfTI file, change its Direction and optionally center and save the modified image. + """ + image = sitk.ReadImage(nifti_file_path) + # Set the new direction + image.SetDirection(new_direction) + + # Calculate and set the new origin + new_origin = calculate_new_origin(image) + image.SetOrigin(new_origin) + + if output_file_path: + sitk.WriteImage(image, output_file_path) + print(f"Modified image saved to {output_file_path}") + + return image + +def main(args): + new_direction = tuple(map(float, args.direction.split(','))) # Assumes direction as comma-separated values + input_folder = args.input_folder + output_folder = args.output_folder if args.output_folder else input_folder # Default to input folder if no output folder is provided. + + # Ensure the output folder exists + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + # Get all nifti files in the folder + nifti_files = [] + for root, dirs, files in os.walk(input_folder): + for file in files: + if file.endswith(".nii") or file.endswith(".nii.gz"): + nifti_files.append(os.path.join(root, file)) + + # Process each file + for file_path in nifti_files: + filename = os.path.basename(file_path) + file_id,type_file = extract_id(filename) + if type_file==0: + output_file_path = os.path.join(output_folder, f"{file_id}_OR.nii") + else : + output_file_path = os.path.join(output_folder, f"{file_id}_OR.nii.gz") + modify_image_properties(file_path, new_direction, output_file_path) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Modify NIfTI file directions and center them.") + parser.add_argument('input_folder', default = '.', help='Path to the input folder containing NIfTI files.') + parser.add_argument('direction', default = "-1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0", help='New direction for the NIfTI files, specified as a comma-separated string of floats. ') + parser.add_argument('output_folder', default = '.', help='Path to the output folder where modified NIfTI files will be saved.') + args = parser.parse_args() + main(args) + +# USE THIS DIRECTION FOR MRI : "0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0" +# FOR CBCT : "1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0" + + + +# "-1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0" \ No newline at end of file diff --git a/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/MRI2CBCT_ORIENT_CENTER_MRI.xml b/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/MRI2CBCT_ORIENT_CENTER_MRI.xml new file mode 100644 index 0000000..04322f0 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_ORIENT_CENTER_MRI/MRI2CBCT_ORIENT_CENTER_MRI.xml @@ -0,0 +1,40 @@ + + + Automated Dental Tools.Advanced + 2 + MRI2CBCT_ORIENT_CENTER_MRI + + 0.0.1 + https://github.com/username/project + Slicer + FirstName LastName (Institution), FirstName LastName (Institution) + This work was partially funded by NIH grant NXNNXXNNNNNN-NNXN + + + + + + + + input_folder + + 0 + Path for the input folder. + + + + direction + + 1 + Direction for the new orientation of the scan + + + + output_folder + + 2 + Output Path + + + + diff --git a/MRI2CBCT_CLI/MRI2CBCT_REG/CMakeLists.txt b/MRI2CBCT_CLI/MRI2CBCT_REG/CMakeLists.txt new file mode 100644 index 0000000..aabe381 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_REG/CMakeLists.txt @@ -0,0 +1,7 @@ +#----------------------------------------------------------------------------- +set(MODULE_NAME MRI2CBCT_REG) + + +SlicerMacroBuildScriptedCLI( + NAME ${MODULE_NAME} +) diff --git a/MRI2CBCT_CLI/MRI2CBCT_REG/MRI2CBCT_REG.py b/MRI2CBCT_CLI/MRI2CBCT_REG/MRI2CBCT_REG.py new file mode 100644 index 0000000..7569335 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_REG/MRI2CBCT_REG.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python-real + +import argparse +import os +import re +import shutil + +import sys +fpath = os.path.join(os.path.dirname(__file__), "..") +sys.path.append(fpath) +from MRI2CBCT_CLI_utils import invert_mri_intensity, normalize, apply_mask_f, registration + +def create_folder(folder): + """ + Creates a folder if it does not already exist. + + Arguments: + folder (str): Path of the folder to create. + """ + if not os.path.exists(folder): + os.makedirs(folder) + + +def run_script_inverse_mri(mri_folder, folder_general): + """ + Inverts the intensity of MRI images and saves the results. + + Arguments: + mri_folder (str): Folder containing MRI files. + folder_general (str): General folder for output. + """ + + folder_mri_inverse = os.path.join(folder_general,"a01_MRI_inv") + create_folder(folder_mri_inverse) + invert_mri_intensity(mri_folder, folder_mri_inverse, "inv") + return folder_mri_inverse + +def run_script_normalize_percentile(file_type,input_folder, folder_general, upper_percentile, lower_percentile, max_norm, min_norm): + """ + Normalizes images based on specified percentiles and saves the results. + + Arguments: + file_type (str): Type of files to normalize ('MRI' or other). + input_folder (str): Folder containing the input files. + folder_general (str): General folder for output. + upper_percentile (float): Upper percentile for normalization. + lower_percentile (float): Lower percentile for normalization. + max_norm (float): Maximum value for normalization. + min_norm (float): Minimum value for normalization. + """ + + if file_type=="MRI": + output_folder_norm_general = os.path.join(folder_general,"a2_MRI_inv_norm") + else : + output_folder_norm_general = os.path.join(folder_general,"b2_CBCT_norm") + create_folder(output_folder_norm_general) + + output_folder_norm = os.path.join(output_folder_norm_general,f"percentile=[{lower_percentile},{upper_percentile}]_norm=[{min_norm},{max_norm}]") + create_folder(output_folder_norm) + + normalize(input_folder, output_folder_norm,upper_percentile,lower_percentile,min_norm, max_norm) + return output_folder_norm + + +def run_script_apply_mask(cbct_folder, cbct_label2,folder_general, suffix,upper_percentile, lower_percentile, max_norm, min_norm): + """ + Applies a mask to CBCT images and saves the normalized results. + + Arguments: + cbct_folder (str): Folder containing CBCT files. + cbct_label2 (str): Folder containing the segmentation labels. + folder_general (str): General folder for output. + suffix (str): Suffix for the output files. + upper_percentile (float): Upper percentile for normalization. + lower_percentile (float): Lower percentile for normalization. + max_norm (float): Maximum value for normalization. + min_norm (float): Minimum value for normalization. + """ + cbct_mask_folder = os.path.join(folder_general,"b3_CBCT_norm_mask:l2",f"percentile=[{lower_percentile},{upper_percentile}]_norm=[{min_norm},{max_norm}]") + create_folder(cbct_mask_folder) + apply_mask_f(folder_path=cbct_folder, seg_folder=cbct_label2, folder_output=cbct_mask_folder, suffix=suffix, seg_label=1) + return cbct_mask_folder + +def run_script_AREG_MRI_folder(cbct_folder, cbct_mask_folder,mri_folder,mri_original_folder,folder_general,mri_lower_p,mri_upper_p,mri_min_norm,mri_max_norm,cbct_lower_p,cbct_upper_p,cbct_min_norm,cbct_max_norm): + """ + Runs the registration script for MRI and CBCT folders, applying normalization and percentile adjustments. + + Arguments: + cbct_folder (str): Folder containing CBCT files. + cbct_mask_folder (str): Folder containing CBCT mask files. + mri_folder (str): Folder containing MRI files. + mri_original_folder (str): Folder containing original MRI files. + folder_general (str): General folder for output. + mri_lower_p (float): Lower percentile for MRI normalization. + mri_upper_p (float): Upper percentile for MRI normalization. + mri_min_norm (float): Minimum value for MRI normalization. + mri_max_norm (float): Maximum value for MRI normalization. + cbct_lower_p (float): Lower percentile for CBCT normalization. + cbct_upper_p (float): Upper percentile for CBCT normalization. + cbct_min_norm (float): Minimum value for CBCT normalization. + cbct_max_norm (float): Maximum value for CBCT normalization. + """ + + output_folder = os.path.join(folder_general,f"mri:inv+norm[{mri_min_norm},{mri_max_norm}]+p[{mri_lower_p},{mri_upper_p}]_cbct:norm[{cbct_min_norm},{cbct_max_norm}]+p[{cbct_lower_p},{cbct_upper_p}]+mask") + create_folder(output_folder) + registration(cbct_folder,mri_folder,cbct_mask_folder,output_folder,mri_original_folder) + return cbct_mask_folder + +def extract_values(input_string): + """ + Extracts 8 integers from the input string and returns them as a tuple. + + Arguments: + input_string (str): String containing the integers. + """ + + numbers = re.findall(r'\d+', input_string) + + numbers = list(map(int, numbers)) + + if len(numbers) != 8: + raise ValueError("The input need to contains 8 numbers") + + a, b, c, d, e, f, g, h = numbers + + return a, b, c, d, e, f, g, h + +def delete_folder(folder_path): + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + print(f"The folder '{folder_path}' has been deleted successfully.") + else: + print(f"The folder '{folder_path}' does not exist.") + +def main(): + parser = argparse.ArgumentParser(description="Run multiple Python scripts with arguments") + parser.add_argument('folder_general', type=str, help="Folder general where to make all the output") + parser.add_argument('mri_folder', type=str, help="Folder containing original MRI images.") + parser.add_argument('cbct_folder', type=str, help="Folder containing original CBCT images.") + parser.add_argument('cbct_label2', type=str, help="Folder containing CBCT masks.") + parser.add_argument('normalization', type=str, help="Folder containing CBCT masks.") + parser.add_argument('tempo_fold', type=str, help="Indicate to keep the temporary fold or not") + args = parser.parse_args() + + mri_min_norm, mri_max_norm, mri_lower_p, mri_upper_p, cbct_min_norm, cbct_max_norm, cbct_lower_p, cbct_upper_p = extract_values(args.normalization) + + # MRI + folder_mri_inverse = run_script_inverse_mri(args.mri_folder, args.folder_general) + input_path_norm_mri = run_script_normalize_percentile("MRI",folder_mri_inverse, args.folder_general, upper_percentile=mri_upper_p, lower_percentile=mri_lower_p, max_norm=mri_max_norm, min_norm=mri_min_norm) + + # CBCT + output_path_norm_cbct = run_script_normalize_percentile("CBCT",args.cbct_folder, args.folder_general, upper_percentile=cbct_upper_p, lower_percentile=cbct_lower_p, max_norm=cbct_max_norm, min_norm=cbct_min_norm) + input_path_cbct_norm_mask = run_script_apply_mask(output_path_norm_cbct,args.cbct_label2,args.folder_general,"mask",upper_percentile=cbct_upper_p, lower_percentile=cbct_lower_p, max_norm=cbct_max_norm, min_norm=cbct_min_norm) + + # REG + run_script_AREG_MRI_folder(cbct_folder=args.cbct_folder,cbct_mask_folder=input_path_cbct_norm_mask,mri_folder=input_path_norm_mri,mri_original_folder=args.mri_folder,folder_general=args.folder_general,mri_lower_p=mri_lower_p,mri_upper_p=mri_upper_p,mri_min_norm=mri_min_norm,mri_max_norm=mri_max_norm,cbct_lower_p=cbct_lower_p,cbct_upper_p=cbct_upper_p,cbct_min_norm=cbct_min_norm,cbct_max_norm=cbct_max_norm) + + + if args.tempo_fold=="false": + delete_folder(folder_mri_inverse) + delete_folder(input_path_norm_mri) + delete_folder(os.path.dirname(input_path_norm_mri)) + delete_folder(output_path_norm_cbct) + delete_folder(os.path.dirname(output_path_norm_cbct)) + delete_folder(input_path_cbct_norm_mask) + delete_folder(os.path.dirname(input_path_cbct_norm_mask)) + + + +if __name__ == "__main__": + main() diff --git a/MRI2CBCT_CLI/MRI2CBCT_REG/MRI2CBCT_REG.xml b/MRI2CBCT_CLI/MRI2CBCT_REG/MRI2CBCT_REG.xml new file mode 100644 index 0000000..bfedc32 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_REG/MRI2CBCT_REG.xml @@ -0,0 +1,61 @@ + + + Automated Dental Tools.Advanced + 1 + MRI2CBCT_REG + + 0.0.1 + https://github.com/username/project + Slicer + FirstName LastName (Institution), FirstName LastName (Institution) + This work was partially funded by NIH grant NXNNXXNNNNNN-NNXN + + + + + + + + folder_general + + 0 + output_folder + + + + mri_folder + + 1 + Input Path MRI folder + + + + cbct_folder + + 2 + Input Path CBCT folder + + + + cbct_label2 + + 3 + Input path CB CBCT label + + + + normalization + + 4 + Normalization to use for MRI and CBCT + + + + tempo_fold + + 5 + If keeping temporary fold or not + + + + diff --git a/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/CMakeLists.txt b/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/CMakeLists.txt new file mode 100644 index 0000000..7495865 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/CMakeLists.txt @@ -0,0 +1,7 @@ +set(MODULE_NAME MRI2CBCT_RESAMPLE_CBCT_MRI) + + +SlicerMacroBuildScriptedCLI( + NAME ${MODULE_NAME} +) + diff --git a/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/MRI2CBCT_RESAMPLE_CBCT_MRI.py b/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/MRI2CBCT_RESAMPLE_CBCT_MRI.py new file mode 100644 index 0000000..0d84e04 --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/MRI2CBCT_RESAMPLE_CBCT_MRI.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python-real + +import csv +import argparse +import os + +import sys +fpath = os.path.join(os.path.dirname(__file__), "..") +sys.path.append(fpath) + +from MRI2CBCT_CLI_utils import create_csv, resample_images +import csv + + +def run_resample(img=None, dir=None, csv=None, csv_column='image', csv_root_path=None, csv_use_spc=0, + csv_column_spcx=None, csv_column_spcy=None, csv_column_spcz=None, ref=None, size=None, + img_spacing=None, spacing=None, origin=None, linear=False, center=0, fit_spacing=False, + iso_spacing=False, image_dimension=2, pixel_dimension=1, rgb=False, ow=1, out="./out.nrrd", + out_ext=None): + args = { + 'img': img, + 'dir': dir, + 'csv': csv, + 'csv_column': csv_column, + 'csv_root_path': csv_root_path, + 'csv_use_spc': csv_use_spc, + 'csv_column_spcx': csv_column_spcx, + 'csv_column_spcy': csv_column_spcy, + 'csv_column_spcz': csv_column_spcz, + 'ref': ref, + 'size': size, + 'img_spacing': img_spacing, + 'spacing': spacing, + 'origin': origin, + 'linear': linear, + 'center': center, + 'fit_spacing': fit_spacing, + 'iso_spacing': iso_spacing, + 'image_dimension': image_dimension, + 'pixel_dimension': pixel_dimension, + 'rgb': rgb, + 'ow': ow, + 'out': out, + 'out_ext': out_ext, + } + resample_images(args) + +def transform_size(size_str): + """ + Transforms a string '[x,y,z]' into 'x y z' with x, y, z as integers. + + :param size_str: String in the format '[x,y,z]' + :return: String in the format 'x y z' + """ + # Remove the brackets and split by comma + size_list = size_str.strip('[]').split(',') + + # Convert each element to int and join with space + size_transformed = ' '.join(map(str, map(int, size_list))) + + return size_transformed + +def main(input_folder,output_folder,resample_size,spacing,iso_spacing): + csv_path = create_csv(input_folder,output_folder,output_csv=output_folder,name_csv="resample_csv.csv") + + with open(csv_path, mode='r') as csv_file: + csv_reader = csv.DictReader(csv_file) + for row in csv_reader: + size_file = tuple(map(int, row["size"].strip("()").split(","))) + spacing_file = tuple(map(float, row["Spacing"].strip("()").split(","))) + input_path = row["in"] + out_path = row["out"] + if resample_size != "None" and spacing=="None" : + run_resample(img=input_path,out=out_path,size=list(map(int, resample_size.split(','))),fit_spacing=True,center=0,iso_spacing=False,linear=False,image_dimension=3,pixel_dimension=1,rgb=False,ow=0) + elif resample_size == "None" and spacing!="None" : + run_resample(img=input_path,out=out_path,spacing=list(map(float, spacing.split(','))),size=[size_file[0],size_file[1],size_file[2]],fit_spacing=False,center=0,iso_spacing=False,linear=False,image_dimension=3,pixel_dimension=1,rgb=False,ow=0) + elif resample_size != "None" and spacing!="None" : + run_resample(img=input_path,out=out_path,spacing=list(map(float, spacing.split(','))),size=list(map(int, resample_size.split(','))),fit_spacing=True,center=0,iso_spacing=False,linear=False,image_dimension=3,pixel_dimension=1,rgb=False,ow=0) + + delete_csv(csv_path) + +def delete_csv(file_path): + """Delete a CSV file if it exists.""" + try: + if os.path.exists(file_path): + os.remove(file_path) + print(f"File {file_path} has been deleted successfully.") + else: + print(f"File {file_path} does not exist.") + except Exception as e: + print(f"An error occurred while trying to delete the file {file_path}: {e}") + + + + +if __name__=="__main__": + parser = argparse.ArgumentParser(description='Get nifti info') + parser.add_argument('input_folder_MRI', type=str, help='Input path') + parser.add_argument('input_folder_CBCT', type=str, help='Input path') + parser.add_argument('output_folder', type=str, help='Output path') + parser.add_argument('resample_size', type=str, help='size_resample') + parser.add_argument('spacing', type=str, help='size_resample') + args = parser.parse_args() + + + if os.path.isdir(args.input_folder_MRI): + main(args.input_folder_MRI,args.output_folder,args.resample_size,args.spacing,iso_spacing=True) + if os.path.isdir(args.input_folder_CBCT): + main(args.input_folder_CBCT,args.output_folder,args.resample_size,args.spacing,iso_spacing=False) \ No newline at end of file diff --git a/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/MRI2CBCT_RESAMPLE_CBCT_MRI.xml b/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/MRI2CBCT_RESAMPLE_CBCT_MRI.xml new file mode 100644 index 0000000..e4a083b --- /dev/null +++ b/MRI2CBCT_CLI/MRI2CBCT_RESAMPLE_CBCT_MRI/MRI2CBCT_RESAMPLE_CBCT_MRI.xml @@ -0,0 +1,55 @@ + + + Automated Dental Tools.Advanced + 2 + MRI2CBCT_RESAMPLE_CBCT_MRI + + 0.0.1 + https://github.com/username/project + Slicer + FirstName LastName (Institution), FirstName LastName (Institution) + This work was partially funded by NIH grant NXNNXXNNNNNN-NNXN + + + + + + + + input_folder_MRI + + 0 + Path for the input MRI. + + + + input_folder_CBCT + + 1 + Path for the input CBCT. + + + + output_folder + + 2 + output folder + + + + resample_size + + 3 + size + + + + spacing + + 4 + spacing + + + + +