Skip to content

Commit

Permalink
Merge pull request #99 from GaelleLeroux/AREG_windows
Browse files Browse the repository at this point in the history
ENH : upadate AREG_IOS on windows
  • Loading branch information
allemangD authored May 30, 2024
2 parents 53084b1 + d442d5f commit 6b331d1
Show file tree
Hide file tree
Showing 6 changed files with 498 additions and 134 deletions.
365 changes: 363 additions & 2 deletions AREG/AREG.py

Large diffs are not rendered by default.

119 changes: 98 additions & 21 deletions AREG/AREG_Method/IOS.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import os
import vtk
import shutil
import platform
import csv


class Auto_IOS(Method):
Expand Down Expand Up @@ -63,6 +65,47 @@ def TestModel(self, model_folder: str, lineEditName) -> str:
out = "Please give folder with only one .ckpt file \n"

return out

def create_csv(self,input_dir,name_csv):
'''
create a csv with the complete path of the files in the folder (used for segmentation only)
'''
file_path = os.path.abspath(__file__)
folder_path = os.path.dirname(file_path)
csv_file = os.path.join(folder_path,f"{name_csv}.csv")
with open(csv_file, 'w', newline='') as fichier:
writer = csv.writer(fichier)
# Écrire l'en-tête du CSV
writer.writerow(["surf"])

# Parcourir le dossier et ses sous-dossiers
for root, dirs, files in os.walk(input_dir):
for file in files:
if file.endswith(".vtk") or file.endswith(".stl"):
# Écrire le chemin complet du fichier dans le CSV
if platform.system() != "Windows" :
writer.writerow([os.path.join(root, file)])
else :
file_path = os.path.join(root, file)
norm_file_path = os.path.normpath(file_path)
writer.writerow([self.windows_to_linux_path(norm_file_path)])


return csv_file

def windows_to_linux_path(self,windows_path):
'''
Convert a windows path to a wsl path
'''
windows_path = windows_path.strip()

path = windows_path.replace('\\', '/')

if ':' in path:
drive, path_without_drive = path.split(':', 1)
path = "/mnt/" + drive.lower() + path_without_drive

return path

def TestReference(self, ref_folder: str):

Expand Down Expand Up @@ -162,12 +205,15 @@ def __BypassCrownseg__(self, folder, folder_toseg, folder_bypass):
files = self.search(folder, ".vtk")[".vtk"]
toseg = 0
for file in files:
name = os.path.basename(file)
base_name = os.path.basename(file)
if self.__isSegmented__(file):
shutil.copy(file, os.path.join(folder_bypass, name))
name, ext = os.path.splitext(base_name)
new_name = f"{name}_Seg{ext}"
print("new_name : ",new_name)
shutil.copy(file, os.path.join(folder_bypass, new_name))

else:
shutil.copy(file, os.path.join(folder_toseg, name))
shutil.copy(file, os.path.join(folder_toseg, base_name))
toseg += 1

return toseg
Expand Down Expand Up @@ -239,30 +285,61 @@ def Process(self, **kwargs):
number_scan_toseg_T2 = self.__BypassCrownseg__(
kwargs["input_t2_folder"], path_input_T2, path_seg_T2
)
slicer_path = slicer.app.applicationDirPath()
dentalmodelseg_path = os.path.join(slicer_path,"..","lib","Python","bin","dentalmodelseg")

surf_T1 = "None"
input_csv_T1 = "None"
vtk_folder_T1 = "None"
if os.path.isfile(path_input_T1):
extension = os.path.splitext(self.input)[1]
if extension == ".vtk" or extension == ".stl":
surf_T1 = path_input_T1

elif os.path.isdir(path_input_T1):
input_csv_T1 = self.create_csv(path_input_T1,"liste_csv_file_T1")
vtk_folder_T1 = path_input_T1

parameter_segteeth_T1 = {
"input": path_input_T1,
"output": path_seg_T1,
"subdivision_level": 2,
"resolution": 320,
"model": self.getModel(kwargs["model_folder_1"], extension="pth"),
"predictedId": "Universal_ID",
"sepOutputs": 0,
"chooseFDI": 0,
"logPath": kwargs["logPath"],
"surf": surf_T1,
"input_csv": input_csv_T1,
"out": path_seg_T1,
"overwrite": "0",
"model": "latest",
"crown_segmentation": "0",
"array_name": "Universal_ID",
"fdi": 0,
"suffix": "Seg",
"vtk_folder": vtk_folder_T1,
"dentalmodelseg_path":dentalmodelseg_path
}

surf_T2 = "None"
input_csv_T2 = "None"
vtk_folder_T2 = "None"
if os.path.isfile(path_input_T2):
extension = os.path.splitext(self.input)[1]
if extension == ".vtk" or extension == ".stl":
surf_T2 = path_input_T2

elif os.path.isdir(path_input_T2):
input_csv_T2 = self.create_csv(path_input_T2,"liste_csv_file_T2")
vtk_folder_T2 = path_input_T2

parameter_segteeth_T2 = {
"input": path_input_T2,
"output": path_seg_T2,
"subdivision_level": 2,
"resolution": 320,
"model": self.getModel(kwargs["model_folder_1"], extension="pth"),
"predictedId": "Universal_ID",
"sepOutputs": 0,
"chooseFDI": 0,
"logPath": kwargs["logPath"],
"surf": surf_T2,
"input_csv": input_csv_T2,
"out": path_seg_T2,
"overwrite": "0",
"model": "latest",
"crown_segmentation": "0",
"array_name": "Universal_ID",
"fdi": 0,
"suffix": "Seg",
"vtk_folder": vtk_folder_T2,
"dentalmodelseg_path":dentalmodelseg_path
}


parameter_pre_aso_T1 = {
"input": path_seg_T1,
Expand Down
147 changes: 36 additions & 111 deletions AREG_IOS/AREG_IOS.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,10 @@
#!/usr/bin/env python-real


# def installPackages():
# from slicer.util import pip_install, pip_uninstall
# import sys
# import os

# try:
# import pandas
# except ImportError:
# pip_install("pandas")

# try:
# import torch

# pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
# version_str = "".join(
# [
# f"py3{sys.version_info.minor}_cu",
# torch.version.cuda.replace(".", ""),
# f"_pyt{pyt_version_str}",
# ]
# )
# if version_str != "py39_cu113_pyt1120":
# raise ImportError
# except ImportError:
# # pip_install('--no-cache-dir torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113')
# pip_install(
# "--force-reinstall torch==1.12.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113"
# )

# try:
# import monai
# except ImportError:
# pip_install("monai")

# from platform import system # to know which OS is used

# if system() == "Darwin": # MACOS
# try:
# import pytorch3d
# except ImportError:
# pip_install("pytorch3d")
# import pytorch3d

# else: # Linux or Windows
# try:
# import pytorch3d

# if pytorch3d.__version__ != "0.7.0":
# raise ImportError
# except ImportError:
# # try:
# # # import torch
# # pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
# # version_str="".join([f"py3{sys.version_info.minor}_cu",torch.version.cuda.replace(".",""),f"_pyt{pyt_version_str}"])
# # pip_install('--upgrade pip')
# # pip_install('fvcore==0.1.5.post20220305')
# # pip_install('--no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html')
# # except: # install correct torch version
# # pip_install('--no-cache-dir torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113')
# # pip_install('--no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu113_pyt1110/download.html')

# try:
# code_path = os.sep.join(
# os.path.dirname(os.path.abspath(__file__)).split(os.sep)
# )
# # print(code_path)
# pip_install(
# os.path.join(
# code_path,
# "AREG_IOS_utils",
# "pytorch3d-0.7.0-cp39-cp39-linux_x86_64.whl",
# )
# ) # py39_cu113_pyt1120
# except:
# pip_install(
# "--force-reinstall --no-deps --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu113_pyt1120/download.html"
# )

# try:
# import pytorch_lightning
# except ImportError:
# pip_install("pytorch_lightning==1.7.7")

# import numpy

# if float(".".join(numpy.__version__.split(".")[:2])) >= 1.23:
# pip_install("numpy==1.21.1")


# installPackages()
import os
import sys
import argparse

import pandas
import torch
import monai
import pytorch3d
import pytorch_lightning
import numpy
import platform


# from tqdm import tqdm
Expand All @@ -109,25 +13,45 @@
fpath = os.path.join(os.path.dirname(__file__), "..")
sys.path.append(fpath)

from AREG_IOS_utils import (
DatasetPatch,
PredPatch,
vtkMeshTeeth,
vtkICP,
ICP,
WriteSurf,
TransformSurf,
)
def check_platform():
if platform.system() == 'Windows':
return "Windows"
elif platform.system() == 'Linux':
if 'Microsoft' in platform.release():
return "WSL"
else:
return "Linux"
else:
return "Unknown"

if check_platform()=="WSL":
from AREG_IOS_utils.dataset import DatasetPatch
from AREG_IOS_utils.PredPatch import PredPatch
from AREG_IOS_utils.vtkSegTeeth import vtkMeshTeeth
from AREG_IOS_utils.ICP import vtkICP
from AREG_IOS_utils.ICP import ICP
from AREG_IOS_utils.utils import WriteSurf
from AREG_IOS_utils.transformation import TransformSurf

else :
from AREG_IOS_utils import (
DatasetPatch,
PredPatch,
vtkMeshTeeth,
vtkICP,
ICP,
WriteSurf,
TransformSurf,
)


def main(args):

if not os.path.exists(os.path.split(args.log_path)[0]):
os.mkdir(os.path.split(args.log_path)[0])

with open(args.log_path, "w") as log_f:
log_f.truncate(0)

dataset = DatasetPatch(args.T1, args.T2, "Universal_ID")
Patched = PredPatch(args.model)

Expand All @@ -138,17 +62,21 @@ def main(args):
lower = False
if dataset.isLower():
lower = True


# pbar = tqdm(total=len(dataset)*3,desc='Segment Palate')
for idx in range(len(dataset)):
print("idx : ",idx)

name = os.path.basename(dataset.getUpperPath(idx, "T1"))

# pbar.set_description(f'Patch {name}')
surf_T1 = dataset.getUpperSurf(idx, "T1")
surf_T1 = Patched(dataset[idx, "T1"], surf_T1)
# pbar.update(1)

name = os.path.basename(dataset.getUpperPath(idx, "T1"))

WriteSurf(surf_T1, args.output, name, args.suffix)

with open(args.log_path, "r+") as log_f:
Expand Down Expand Up @@ -192,9 +120,7 @@ def main(args):

print("Starting")
print(sys.argv)

parser = argparse.ArgumentParser()

parser.add_argument("T1", type=str)
parser.add_argument("T2", type=str)
parser.add_argument("output", type=str)
Expand All @@ -203,5 +129,4 @@ def main(args):
parser.add_argument("log_path", type=str)

args = parser.parse_args()

main(args)
Binary file modified ASO_IOS/ASO_IOS_utils/cache/source.npy
Binary file not shown.
Binary file modified ASO_IOS/ASO_IOS_utils/cache/target.npy
Binary file not shown.
1 change: 1 addition & 0 deletions ASO_IOS/PRE_ASO_IOS/PRE_ASO_IOS.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import argparse
import numpy as np

from tqdm import tqdm
from itertools import chain

Expand Down

0 comments on commit 6b331d1

Please sign in to comment.