diff --git a/README.md b/README.md index a7ad3dc..7fea7e6 100755 --- a/README.md +++ b/README.md @@ -8,15 +8,14 @@ * **Refinement MaskGAN version**: An extended model, which refined images using a simple, yet effective multi-stage, multi-plane approach is develop to improve the volumetric definition of synthetic images. * **Model enhancements**: We include selection strategies to choose similar MRI/CT matches based on the position of slices. - ## MaskGAN Framework - A novel unsupervised MR-to-CT synthesis method that preserves the anatomy under the explicit supervision of coarse masks without using costly manual annotations. MaskGAN bypasses the need for precise annotations, replacing them with standard (unsupervised) image processing techniques, which can produce coarse anatomical masks. +A novel unsupervised MR-to-CT synthesis method that preserves the anatomy under the explicit supervision of coarse masks without using costly manual annotations. MaskGAN bypasses the need for precise annotations, replacing them with standard (unsupervised) image processing techniques, which can produce coarse anatomical masks. Such masks, although imperfect, provide sufficient cues for MaskGAN to capture anatomical outlines and produce structurally consistent images. ![Framework](./imgs/maskgan_v2.svg) -## Comparsion with State-of-the-Art Methods on Paediatric MR-CT Synthesis +## Comparison with State-of-the-Art Methods on Paediatric MR-CT Synthesis ![Result](./imgs/results.jpg) diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index 1144493..9b0f2f3 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -49,6 +49,14 @@ def __init__(self, opt): # self.transform_maskA = get_transform(self.opt, grayscale=(self.input_nc == 1), mask=True) # self.transform_maskB = get_transform(self.opt, grayscale=(self.output_nc == 1), mask=True) + # Save relative position of each img Input images should be given in the format xxxx_RELATIVEPOSITION.jpg + self.relative_pos_A = [int(img.split(".")[-2].split("_")[-1]) for img in self.A_paths] + self.relative_pos_B = [int(img.split(".")[-2].split("_")[-1]) for img in self.B_paths] + # Define range of adjacent slices to consider + if opt.phase == 'train': + self.position_based_range = opt.position_based_range*10 + + def __getitem__(self, index): """Return a data point and its metadata information. @@ -66,7 +74,17 @@ def __getitem__(self, index): if self.opt.serial_batches: # make sure index is within then range index_B = index % self.B_size else: # randomize the index for domain B to avoid fixed pairs. - index_B = random.randint(0, self.B_size - 1) + # Check the relative position of the image (Position based selection PBS) + A_path_spplited = A_path.split(".") + A_relative_position = A_path_spplited[-2].split("_")[-1] + # Convert to a number + A_relative_position = float(A_relative_position) + # Obtain the images in a similar range (Position based selection) + potential_indexes = [index for index, value in enumerate(self.relative_pos_B) if (A_relative_position-self.position_based_range) <= value <= (A_relative_position + self.position_based_range)] + # Define position of B image + potential_indexes = list(set(potential_indexes) & set(potential_indexes)) + index_position = random.randint(0, len(potential_indexes) - 1) + index_B = potential_indexes[index_position] B_path = self.B_paths[index_B] maskA_path = self.maskA_paths[index_A] maskB_path = self.maskB_paths[index_B] diff --git a/options/train_options.py b/options/train_options.py index 1cdb59f..66db0c5 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -37,6 +37,7 @@ def initialize(self, parser): parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--position_based_range', type = int, default=3, help='Define the range for the position-based selection strategy (PBS). In percentage') self.isTrain = True return parser diff --git a/preprocess/README.md b/preprocess/README.md index 087cb90..17843b5 100644 --- a/preprocess/README.md +++ b/preprocess/README.md @@ -1,7 +1,7 @@ # Preprocess MR-CT data and generate masks -- For simplicity, we assume the dataset have all pairs MRI-CT. -- The simplified code only has a single for-loop to partition 80/20/20 train/val/test. -- If you have an unpaired training set, i.e., source and target modalities do not match. You can simply separate the training data preprocessing and val/test data preprocessing by copy-paste the for-loop. +- For simplicity, we assume the dataset have MRI and CT scans. +- This code assumes that you have the raw data in three folders: **train**, **val** and **test**. +- train set contains only **unpaired images**, whereas val and set contain **paired images**. ## Environment installation Setup using `pip install -r requirements.txt` @@ -10,19 +10,30 @@ Setup using `pip install -r requirements.txt` - Refer to your root folder as `root`. Assume your data structure is as follows ```bash ├── root/ -│ ├── MRI/ -│ │ ├── filename001.nii -│ │ ├── filename002.nii -│ │ └── ... -│ └── CT/ -│ ├── filename001.nii -│ ├── filename002.nii -│ └── ... -``` -- If your data structure is different, please modify the pattern matching expression at lines 138-139 in `preprocess/main.py`: -```python -root_a = f'{data_dir}/MRI/*.nii' -root_b = f'{data_dir}/CT/*.nii' + ├── train/ + | ├── MRI/ + │ │ ├── filename001.nii + │ │ ├── filename002.nii + │ │ └── ... + | └── CT/ + | ├── filename001.nii + | ├── filename002.nii + | └── ... + ├── val/ + | ├── MRI/ + | | ├── filename003.nii + | | └── filename004.nii + | └── CT/ + | ├── filename003.nii + | └── filename004.nii + └── test/ + ├── MRI/ + | ├── filename005.nii + | └── filename006.nii + └── CT/ + ├── filename005.nii + └── filename006.nii + ``` ## Preprocess @@ -32,3 +43,7 @@ root_b = f'{data_dir}/CT/*.nii' - `--resample`: resample the resolution of the medical scans, default is [1.0, 1.0, 1.0] mm^3. Since our paediatric scans have irregular sizes, we need to crop the depth and height dimensions in function `crop_scan()` at Ln 47. When running, the preprocessed 2D slice visualizations are saved under `vis` for your inspection. Use them to modify data augmentation `crop_scan()` as needed. + +## Update! + +Now, to use the based position selection strategy, our preprocessing stage generate files as: filename_XXX.jpg, where XXX corresponds to the relative position of the slice respect to the entire volumetric image. In that way, our model can choose slices of similar position. diff --git a/preprocess/main.py b/preprocess/main.py index b8dd72d..b3acc23 100644 --- a/preprocess/main.py +++ b/preprocess/main.py @@ -15,8 +15,8 @@ import glob import imageio import argparse -from tqdm import tqdm - +from scipy import ndimage +from skimage.morphology import binary_erosion, binary_dilation def visualize(img, filename, step=10): shapes = img.shape @@ -43,28 +43,21 @@ def normalize(img, min_=None, max_=None): max_ = img.max() return (img - min_)/(max_ - min_) - def crop_scan(img, mask, crop=0, crop_h=0, ignore_zero=True): - # Swap dimensions for visualizability - modify as needed img = np.transpose(img, (0,2,1))[:,::-1,::-1] if mask is not None: mask = np.transpose(mask, (0,2,1))[:,::-1,::-1] - - # Exclude all zero (air only) slices during training - modify as needed if ignore_zero: mask_ = img.sum(axis=(1,2)) > 0 img = img[mask_] if mask is not None: mask = mask[mask_] - - # Crop depth dimension - modify as needed if crop > 0: length = img.shape[0] img = img[int(crop*length): int((1-crop)*length)] if mask is not None: mask = mask[int(crop*length): int((1-crop)*length)] - # Crop height dimension - modify as needed if crop_h > 0: if img.shape[1] > 200: crop_h = 0.8 @@ -75,26 +68,55 @@ def crop_scan(img, mask, crop=0, crop_h=0, ignore_zero=True): return img, mask + +def crop_scan_paired(img1, img2, mask, crop=0, crop_h=0, ignore_zero=True): + img1 = np.transpose(img1, (0,2,1))[:,::-1,::-1] + img2 = np.transpose(img2, (0,2,1))[:,::-1,::-1] + if mask is not None: + mask = np.transpose(mask, (0,2,1))[:,::-1,::-1] + if ignore_zero: + mask1_ = img1.sum(axis=(1,2)) > 0 + mask2_ = img2.sum(axis=(1,2)) > 0 + mask_ = mask1_ * mask2_ + img1 = img1[mask_] + img2 = img2[mask_] + if mask is not None: + mask = mask[mask_] + if crop > 0: + length = img1.shape[0] + img1 = img1[int(crop*length): int((1-crop)*length)] + img2 = img2[int(crop*length): int((1-crop)*length)] + if mask is not None: + mask = mask[int(crop*length): int((1-crop)*length)] + + if crop_h > 0: + if img1.shape[1] > 200: + crop_h = 0.8 + new_h = int(crop_h*img1.shape[1]) + img1 = img1[:, :new_h] + img2 = img2[:, :new_h] + if mask is not None: + mask = mask[:, :new_h] + + return img1, img2, mask + def getLargestCC(segmentation): labels = label(segmentation) assert( labels.max() != 0 ) # assume at least 1 CC largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 return largestCC - - - def get_3d_mask(img, min_, max_=None, th=50, width=2): if max_ is None: max_ = img.max() img = np.clip(img, min_, max_) img = np.uint8(255*normalize(img, min_, max_)) - ## Remove artifacts - mask = np.zeros(img.shape).astype(np.int32) + ## Remove holes + mask = np.zeros(img.shape).astype(int) mask[img > th] = 1 - ## Remove artifacts and small holes with binary opening + ## Opening np.ones((3,3,3)) mask = morphology.binary_opening(mask, ) remove_holes = morphology.remove_small_holes( @@ -102,26 +124,53 @@ def get_3d_mask(img, min_, max_=None, th=50, width=2): area_threshold=width ** 3 ) - largest_cc = getLargestCC(remove_holes) - return img, largest_cc.astype(np.int32) + return img, largest_cc.astype(int) + +def float_to_padded_string(number, total_digits=3): + formatted_number = format(number, f".{total_digits}f") + return formatted_number.lstrip('0.') or '0' + +def resize_volume(img,desired_depth,desired_width, desired_height): + + current_depth = img.shape[0] + current_width = img.shape[1] + current_height = img.shape[2] + + depth = current_depth / desired_depth + width = current_width / desired_width + height = current_height / desired_height + + depth_factor = 1 / depth + width_factor = 1 / width + height_factor = 1 / height + + img = ndimage.zoom(img, (depth_factor, width_factor, height_factor), order=1) + return img def save_slice(img, mask, data_dir, data_mask_dir, filename): assert img.shape == mask.shape, f"Shape not match - img {img.shape} vs mask {mask.shape}" + pad_width = ((5, 5), (5, 5), (5, 5)) + # Resize to 204,204,204 + img = resize_volume(img, 214,214,214) + img = np.pad(img, pad_width, mode='constant', constant_values=0) + mask = resize_volume(mask, 214,214,214) + mask = np.pad(mask, pad_width, mode='constant', constant_values=0) + for i in range(len(img)): - im = np.uint8(255*normalize(img[i])) - m = np.uint8(255*normalize(mask[i])) - imageio.imwrite(f'{data_dir}/{filename}_{i}.png', im) - imageio.imwrite(f'{data_mask_dir}/{filename}_{i}.png', m) + #im = np.uint8(255*normalize(img[i])) + im = img[i] + m = 255*mask[i].astype(np.uint8) + + #m = np.uint8(255*normalize(mask[i])) + imageio.imwrite(f'{data_dir}/{filename}_{str(i).zfill(3)}_{float_to_padded_string(round(i/len(img),2), 3)}.png', im) + imageio.imwrite(f'{data_mask_dir}/{filename}_{str(i).zfill(3)}_{float_to_padded_string(round(i/len(img),2), 3)}.png', m) -def parse_float_array(s): - try: - values = [float(x.strip()) for x in s.split(',')] - return values - except ValueError: - raise argparse.ArgumentTypeError('Invalid float array format. Example: 1.0, 1.0, 1.0') +######################################################## +### DEFINITION OF DIRECTORIES ### +######################################################## if __name__ == '__main__': parser = argparse.ArgumentParser(description='Preprocess data') @@ -130,115 +179,299 @@ def parse_float_array(s): parser.add_argument('--out', type=str, default='processed-mr-ct', help='Output directory') parser.add_argument('--resample', nargs='+', type=float, default=[1.0, 1.0, 1.0], help='Resample scan resolutions') args = parser.parse_args() - + # Declaring variables data_dir = args.root out_dir = args.out resample = args.resample + # Defining root folders (nii files) + train_root = f'{data_dir}/train/' + val_root = f'{data_dir}/val/' + test_root = f'{data_dir}/test/' + + # Defining output folders + train_ct = train_root + 'ct/*.nii.gz' + train_mri = train_root + 'mri/*.nii.gz' + val_ct = val_root + 'ct/*.nii.gz' + val_mri = val_root + 'mri/*.nii.gz' + test_ct = test_root + 'ct/*.nii.gz' + test_mri = test_root + 'mri/*.nii.gz' + + output_ct_dir = f'{out_dir}/train_B' + output_mri_dir = f'{out_dir}/train_A' + output_ct_mask_dir = f'{out_dir}/train_maskB' + output_mri_mask_dir = f'{out_dir}/train_maskA' + + # Create main directories + os.makedirs(output_ct_dir, exist_ok=True) + os.makedirs(output_mri_dir, exist_ok=True) + os.makedirs(output_ct_mask_dir, exist_ok=True) + os.makedirs(output_mri_mask_dir, exist_ok=True) + + ct_files_train = glob.glob(train_ct) + mri_files_train = glob.glob(train_mri) + #min_ct, max_ct = -800, 1900 + th_ct = 50 + th_mri = 10 # Set clip CT intensity - min_ct, max_ct = -1000, 2000 - th = 10 # Consider pixel less than certain threshold as background (remove noise, artifacts) - - # Modify pattern matching expression to list MR and CT nifti files - root_a = f'{data_dir}/MRI/*.nii' - root_b = f'{data_dir}/CT/*.nii' - output_a_dir = f'{out_dir}/train_A' - output_b_dir = f'{out_dir}/train_B' - output_a_mask_dir = f'{out_dir}/train_maskA' - output_b_mask_dir = f'{out_dir}/train_maskB' - - output_a_val_dir = f'{out_dir}/val_A' - output_b_val_dir = f'{out_dir}/val_B' - output_a_mask_val_dir = f'{out_dir}/val_maskA' - output_b_mask_val_dir = f'{out_dir}/val_maskB' - - output_a_test_dir = f'{out_dir}/test_A' - output_b_test_dir = f'{out_dir}/test_B' - output_a_mask_test_dir = f'{out_dir}/test_maskA' - output_b_mask_test_dir = f'{out_dir}/test_maskB' - - os.makedirs(output_a_dir, exist_ok=True) - os.makedirs(output_b_dir, exist_ok=True) - os.makedirs(output_a_mask_dir, exist_ok=True) - os.makedirs(output_b_mask_dir, exist_ok=True) - - os.makedirs(output_a_val_dir, exist_ok=True) - os.makedirs(output_b_val_dir, exist_ok=True) - os.makedirs(output_a_mask_val_dir, exist_ok=True) - os.makedirs(output_b_mask_val_dir, exist_ok=True) - - os.makedirs(output_a_test_dir, exist_ok=True) - os.makedirs(output_b_test_dir, exist_ok=True) - os.makedirs(output_a_mask_test_dir, exist_ok=True) - os.makedirs(output_b_mask_test_dir, exist_ok=True) - - ## Partition scan into 80/10/10 for train/val/test - a_files = sorted(glob.glob(root_a)) - b_files = sorted(glob.glob(root_b)) - train_len = int(len(a_files)*0.8) - val_len = int(len(a_files)*0.1) - train_idx = np.arange(0, train_len) - val_idx = np.arange(train_len, train_len + val_len) - - - ## Resample resolutions as needed - - results = 'vis' # Visualize 2D slices for debugging purpose + min_ct, max_ct = -800, 2000 + #th = 10 # Consider pixel less than certain threshold as background (remove noise, artifacts) + + results = 'vis' os.makedirs(results, exist_ok=True) - crop = 0.0 # Crop first dimension, ignore some air only scans - crop_h = 0.9 # Crop height dimension (dim 2) + crop = 0.0 + crop_h = 0.9 + resample = [1.0, 1.0, 1.0] - ## CT preprocess - for idx, filepath in enumerate(tqdm(b_files)): - img = ants.image_read(filepath) - img = ants.resample_image(img, resample, False, 1) - img = img.numpy() - filename = os.path.splitext(os.path.basename(filepath))[0] - img, mask = get_3d_mask(img, min_=min_ct, max_=max_ct, th=th) + print("Creating MR images for training") + for idx, filepath in enumerate(mri_files_train): + + mri = ants.image_read(filepath) + mri = ants.resample_image(mri, resample, False, 1).numpy() + #filename = os.path.splitext(os.path.basename(filepath))[0] + mri, mask = get_3d_mask(mri, min_=0, th=th_mri, width=10) # Our scans have irregular size, crop to adjust, comment out as needed - img, mask = crop_scan(img, mask, crop, ignore_zero=(idx in train_idx)) - if idx in train_idx: - output_ct_dir = output_b_dir - output_ct_mask_dir = output_b_mask_dir - elif idx in val_idx: - output_ct_dir = output_b_val_dir - output_ct_mask_dir = output_b_mask_val_dir - else: - output_ct_dir = output_b_test_dir - output_ct_mask_dir = output_b_mask_test_dir - save_slice(img, mask, output_ct_dir, output_ct_mask_dir, filename) - visualize(img, f'{results}/ct') - visualize(mask, f'{results}/ct_mask') - - ## MRI preprocess - for idx, filepath in enumerate(tqdm(a_files)): - img = ants.image_read(filepath) - img = ants.resample_image(img, resample, False, 1) - img = img.numpy() - filename = os.path.splitext(os.path.basename(filepath))[0] - img, mask = get_3d_mask(img, min_=0, th=10, width=10) + mri, mask = crop_scan(mri, mask, crop,crop_h) + + mri = mri.astype('uint8') + mask = mask.astype('uint8') + + # Remove noise + # Define the structure element for erosion + selem_ero = np.ones((1, 1, 1), dtype=bool) + selem_dil = np.ones((10, 10, 10), dtype=bool) + # Perform erosion on the entire 3D array + eroded_mask = binary_erosion(mask, selem_ero) + mask[eroded_mask == 0] = 0 + dilated_mask = binary_dilation(mask, selem_dil) + mri[dilated_mask == 0] = 0 + mask[dilated_mask== 0] = 0 + + # Remove images with zero values in the mask + non_zero_slices_mask_axis1_2 = np.any(mask, axis=(1, 2)) + mri = mri[non_zero_slices_mask_axis1_2] + mask = mask[non_zero_slices_mask_axis1_2] + non_zero_slices_mask_axis0_1 = np.any(mask, axis=(0, 1)) + mri = mri[:,:,non_zero_slices_mask_axis0_1] + mask = mask[:,:,non_zero_slices_mask_axis0_1] + non_zero_slices_mask_axis0_2 = np.any(mask, axis=(0, 2)) + mri = mri[:,non_zero_slices_mask_axis0_2,:] + mask = mask[:,non_zero_slices_mask_axis0_2,:] + + # Enter the name of the file + filename = filepath.split('/')[-1].replace('.nii.gz','') + # Create a generic format + filename = filename.zfill(3) + save_slice(mri, mask, output_mri_dir, output_mri_mask_dir, filename) + + print("Creating CT images for training") + for idx, filepath in enumerate(ct_files_train): + + ct = ants.image_read(filepath) + ct = ants.resample_image(ct, resample, False, 1).numpy() + ct, mask = get_3d_mask(ct, min_=min_ct, max_=max_ct, th=th_ct) # Our scans have irregular size, crop to adjust, comment out as needed - img, mask = crop_scan(img, mask, crop, ignore_zero=(idx in train_idx)) - if idx in train_idx: - output_mri_dir = output_a_dir - output_mri_mask_dir = output_a_mask_dir - elif idx in val_idx: - output_mri_dir = output_a_val_dir - output_mri_mask_dir = output_a_mask_val_dir - else: - output_mri_dir = output_a_test_dir - output_mri_mask_dir = output_a_mask_test_dir - save_slice(img, mask, output_mri_dir, output_mri_mask_dir, filename) + ct, mask = crop_scan(ct, mask, crop,crop_h) + + ct = ct.astype('uint8') + mask = mask.astype('uint8') + + # Remove noise + # Define the structure element for erosion + selem_ero = np.ones((1, 1, 1), dtype=bool) + selem_dil = np.ones((10, 10, 10), dtype=bool) + # Perform erosion on the entire 3D array + eroded_mask = binary_erosion(mask, selem_ero) + mask[eroded_mask == 0] = 0 + dilated_mask = binary_dilation(mask, selem_dil) + ct[dilated_mask == 0] = 0 + mask[dilated_mask== 0] = 0 + + # Remove images with zero values in the mask + non_zero_slices_mask_axis1_2 = np.any(mask, axis=(1, 2)) + ct = ct[non_zero_slices_mask_axis1_2] + mask = mask[non_zero_slices_mask_axis1_2] + + non_zero_slices_mask_axis0_1 = np.any(mask, axis=(0, 1)) + ct = ct[:,:,non_zero_slices_mask_axis0_1] + mask = mask[:,:,non_zero_slices_mask_axis0_1] + + non_zero_slices_mask_axis0_2 = np.any(mask, axis=(0, 2)) + ct = ct[:,non_zero_slices_mask_axis0_2,:] + mask = mask[:,non_zero_slices_mask_axis0_2,:] + + # Enter the name of the file + filename = filepath.split('/')[-1].replace('.nii.gz','') + # Create a generic format + filename = filename.zfill(3) + + save_slice(ct, mask, output_ct_dir, output_ct_mask_dir, filename) + + ### VALIDATION + output_ct_dir = f'{out_dir}/val_B' + output_mri_dir = f'{out_dir}/val_A' + output_ct_mask_dir = f'{out_dir}/val_maskB' + output_mri_mask_dir = f'{out_dir}/val_maskA' + + os.makedirs(output_ct_dir, exist_ok=True) + os.makedirs(output_mri_dir, exist_ok=True) + os.makedirs(output_ct_mask_dir, exist_ok=True) + os.makedirs(output_mri_mask_dir, exist_ok=True) + + ct_files_val = glob.glob(val_ct) + mri_files_val = glob.glob(val_mri) + + print("Creating MRI images and CT scans for validation") + for mri_path, ct_path in zip(mri_files_val,ct_files_val): + + mri = ants.image_read(mri_path) + mri = ants.resample_image(mri, resample, False, 1).numpy() + #filename = os.path.splitext(os.path.basename(filepath))[0] + mri, mri_mask = get_3d_mask(mri, min_=0, th=th_mri, width=10) + + ct = ants.image_read(ct_path) + ct = ants.resample_image(ct, resample, False, 1).numpy() + ct, ct_mask = get_3d_mask(ct, min_=min_ct, max_=max_ct, th=th_ct) + + # Getting a uniform mask template for paired images + uniform_mask = mri_mask * ct_mask + + ct, mri, uniform_mask = crop_scan_paired(ct, mri, uniform_mask,crop,crop_h) + mri_mask = uniform_mask + ct_mask = uniform_mask + + ct = ct.astype('uint8') + ct_mask = ct_mask.astype('uint8') + mri = mri.astype('uint8') + mri_mask = mri_mask.astype('uint8') + + # Remove noise + selem_ero = np.ones((1, 1, 1), dtype=bool) + selem_dil = np.ones((10, 10, 10), dtype=bool) + # Perform erosion on the entire 3D array + eroded_mask = binary_erosion(ct_mask, selem_ero) + ct_mask[eroded_mask == 0] = 0 + dilated_mask = binary_dilation(ct_mask, selem_dil) + ct[dilated_mask == 0] = 0 + ct_mask[dilated_mask== 0] = 0 + # Perform erosion on the entire 3D array + eroded_mask = binary_erosion(mri_mask, selem_ero) + mri_mask[eroded_mask == 0] = 0 + dilated_mask = binary_dilation(mri_mask, selem_dil) + mri[dilated_mask == 0] = 0 + mri_mask[dilated_mask== 0] = 0 + + # Remove images with zero values in the mask + non_zero_slices_mask_axis1_2 = np.any(ct_mask, axis=(1, 2)) + ct = ct[non_zero_slices_mask_axis1_2] + ct_mask = ct_mask[non_zero_slices_mask_axis1_2] + non_zero_slices_mask_axis0_1 = np.any(ct_mask, axis=(0, 1)) + ct = ct[:,:,non_zero_slices_mask_axis0_1] + ct_mask = ct_mask[:,:,non_zero_slices_mask_axis0_1] + non_zero_slices_mask_axis0_2 = np.any(ct_mask, axis=(0, 2)) + ct = ct[:,non_zero_slices_mask_axis0_2,:] + ct_mask = ct_mask[:,non_zero_slices_mask_axis0_2,:] + + non_zero_slices_mask_axis1_2 = np.any(mri_mask, axis=(1, 2)) + mri = mri[non_zero_slices_mask_axis1_2] + mri_mask = mri_mask[non_zero_slices_mask_axis1_2] + non_zero_slices_mask_axis0_1 = np.any(mri_mask, axis=(0, 1)) + mri = mri[:,:,non_zero_slices_mask_axis0_1] + mri_mask = mri_mask[:,:,non_zero_slices_mask_axis0_1] + non_zero_slices_mask_axis0_2 = np.any(mri_mask, axis=(0, 2)) + mri = mri[:,non_zero_slices_mask_axis0_2,:] + mri_mask = mri_mask[:,non_zero_slices_mask_axis0_2,:] + + # Enter the name of the file + filename = mri_path.split('/')[-1].replace('.nii.gz','') + # Create a generic format + filename = filename.zfill(3) - visualize(img, f'{results}/mri') - visualize(mask, f'{results}/mri_mask') + save_slice(mri, mri_mask, output_mri_dir, output_mri_mask_dir, filename) + save_slice(ct, ct_mask, output_ct_dir, output_ct_mask_dir, filename) + ### TESTING + output_ct_dir = f'{out_dir}/test_B' + output_mri_dir = f'{out_dir}/test_A' + output_ct_mask_dir = f'{out_dir}/test_maskB' + output_mri_mask_dir = f'{out_dir}/test_maskA' + os.makedirs(output_ct_dir, exist_ok=True) + os.makedirs(output_mri_dir, exist_ok=True) + os.makedirs(output_ct_mask_dir, exist_ok=True) + os.makedirs(output_mri_mask_dir, exist_ok=True) + mri_files_test = glob.glob(test_mri) + ct_files_test = glob.glob(test_ct) + print("Creating MRI images and CT scans for testing") + for mri_path, ct_path in zip(mri_files_test,ct_files_test): + mri = ants.image_read(mri_path) + mri = ants.resample_image(mri, resample, False, 1).numpy() + #filename = os.path.splitext(os.path.basename(filepath))[0] + mri, mri_mask = get_3d_mask(mri, min_=0, th=th_mri, width=10) + ct = ants.image_read(ct_path) + ct = ants.resample_image(ct, resample, False, 1).numpy() + ct, ct_mask = get_3d_mask(ct, min_=min_ct, max_=max_ct, th=th_ct) + # Getting a uniform mask template for paired images + uniform_mask = mri_mask * ct_mask + # Our scans have irregular size, crop to adjust, comment out as needed + ct, mri, uniform_mask = crop_scan_paired(ct, mri, uniform_mask,crop,crop_h) + mri_mask = uniform_mask + ct_mask = uniform_mask + + ct = ct.astype('uint8') + ct_mask = ct_mask.astype('uint8') + mri = mri.astype('uint8') + mri_mask = mri_mask.astype('uint8') + + # Remove noise + selem_ero = np.ones((1, 1, 1), dtype=bool) + selem_dil = np.ones((10, 10, 10), dtype=bool) + # Perform erosion on the entire 3D array + eroded_mask = binary_erosion(ct_mask, selem_ero) + ct_mask[eroded_mask == 0] = 0 + dilated_mask = binary_dilation(ct_mask, selem_dil) + ct[dilated_mask == 0] = 0 + ct_mask[dilated_mask== 0] = 0 + # Perform erosion on the entire 3D array + eroded_mask = binary_erosion(mri_mask, selem_ero) + mri_mask[eroded_mask == 0] = 0 + dilated_mask = binary_dilation(mri_mask, selem_dil) + mri[dilated_mask == 0] = 0 + mri_mask[dilated_mask== 0] = 0 + + # Remove images with zero values in the mask + non_zero_slices_mask_axis1_2 = np.any(ct_mask, axis=(1, 2)) + ct = ct[non_zero_slices_mask_axis1_2] + ct_mask = ct_mask[non_zero_slices_mask_axis1_2] + non_zero_slices_mask_axis0_1 = np.any(ct_mask, axis=(0, 1)) + ct = ct[:,:,non_zero_slices_mask_axis0_1] + ct_mask = ct_mask[:,:,non_zero_slices_mask_axis0_1] + non_zero_slices_mask_axis0_2 = np.any(ct_mask, axis=(0, 2)) + ct = ct[:,non_zero_slices_mask_axis0_2,:] + ct_mask = ct_mask[:,non_zero_slices_mask_axis0_2,:] + + non_zero_slices_mask_axis1_2 = np.any(mri_mask, axis=(1, 2)) + mri = mri[non_zero_slices_mask_axis1_2] + mri_mask = mri_mask[non_zero_slices_mask_axis1_2] + non_zero_slices_mask_axis0_1 = np.any(mri_mask, axis=(0, 1)) + mri = mri[:,:,non_zero_slices_mask_axis0_1] + mri_mask = mri_mask[:,:,non_zero_slices_mask_axis0_1] + non_zero_slices_mask_axis0_2 = np.any(mri_mask, axis=(0, 2)) + mri = mri[:,non_zero_slices_mask_axis0_2,:] + mri_mask = mri_mask[:,non_zero_slices_mask_axis0_2,:] + + # Enter the name of the file + filename = mri_path.split('/')[-1].replace('.nii.gz','') + # Create a generic format + filename = filename.zfill(3) + + save_slice(mri, mri_mask, output_mri_dir, output_mri_mask_dir, filename) + save_slice(ct, ct_mask, output_ct_dir, output_ct_mask_dir, filename)