ENH: 2 columns training with penalization
Jeanneclre committed Jun 25, 2024
1 parent 3b9ab85 commit febae6a
Showing 14 changed files with 1,198 additions and 127 deletions.
3 changes: 2 additions & 1 deletion Preprocess/create_CBCTmask.py
@@ -42,6 +42,7 @@ def applyMask(image, mask, label,dilation_radius):

# Define the structuring element for dilation
kernel = sitk.sitkBall
kernel=sitk.sitkBox
radius = dilation_radius
dilate_filter = sitk.BinaryDilateImageFilter()
dilate_filter.SetKernelType(kernel)
@@ -88,7 +89,7 @@ def applyMask(image, mask, label,dilation_radius):
parser = argparse.ArgumentParser(description="Apply a mask to an image")
parser.add_argument("--img", help="Input image")
parser.add_argument("--mask", help="Input mask")
parser.add_argument("--label", nargs='+', help="Label to apply the mask",default=1)
parser.add_argument("--label", nargs='+', help="Label to apply the mask. Ex: 1 or 1 2",required=True)
parser.add_argument("--output", help="Output image")
parser.add_argument("--dilatation_radius", type=int, help="Radius of the dilatation to apply to the mask",default=None)
args = parser.parse_args()
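For reference, a minimal standalone sketch of the box-kernel dilation that applyMask now configures; SimpleITK is assumed to be installed, and the radius and mask variable are placeholders rather than values from the script:

import SimpleITK as sitk

# Hypothetical illustration of the dilation set up above (sitkBox overrides the
# sitkBall kernel assigned on the previous line of applyMask).
dilate_filter = sitk.BinaryDilateImageFilter()
dilate_filter.SetKernelType(sitk.sitkBox)
dilate_filter.SetKernelRadius(2)  # in the script this comes from --dilatation_radius
# dilated = dilate_filter.Execute(binary_mask)  # binary_mask: a placeholder sitk.Image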
79 changes: 79 additions & 0 deletions Preprocess/create_csv_right_left.py
@@ -0,0 +1,79 @@
import csv
import argparse
import os
import pandas as pd

def main(args):
    input_file = args.input
    output_file = args.output
    # Create a dictionary to store the combined rows
    combined_rows = {}

    # Read the input CSV file
    with open(input_file, 'r') as file:
        reader = csv.reader(file)
        next(reader)  # Skip the header row

        # Iterate over each row in the input file
        for row in reader:
            path, name, label = row

            # Extract the patient number and side from the name
            patient_number = name.split('_')[0]
            side = name.split('_')[1]
            # Check if a row with the same patient number already exists
            if patient_number in combined_rows:
                # Add the label to the corresponding column
                if side == 'R':
                    combined_rows[patient_number]['Label R'] = label
                elif side == 'L':
                    combined_rows[patient_number]['Label L'] = label
            else:
                # Create a new entry for the patient number
                if side == 'R':
                    combined_rows[patient_number] = {
                        'Path': path,
                        'Name': name,
                        'Label R': label,
                        'Label L': 'nan'
                    }
                elif side == 'L':
                    combined_rows[patient_number] = {
                        'Path': path,
                        'Name': name,
                        'Label R': 'nan',
                        'Label L': label
                    }

    # Write the combined rows to the output CSV file
    output_dir = os.path.dirname(input_file)
    output_path = os.path.join(output_dir, output_file)
    with open(output_path, 'w', newline='') as file:
        writer = csv.writer(file)

        # Write the header row
        writer.writerow(['Path', 'Name', 'Label R', 'Label L', 'Label comb'])

        # Write the combined rows
        for row in combined_rows.values():
            if row['Label R'] != 'nan' and row['Label L'] != 'nan':
                if int(row['Label R']) > int(row['Label L']):
                    # Concatenate the labels in ascending order (lower label first)
                    comb_label = row['Label L'] + row['Label R']
                else:
                    comb_label = row['Label R'] + row['Label L']
            elif row['Label R'] != 'nan' and row['Label L'] == 'nan':
                comb_label = row['Label R']
            elif row['Label R'] == 'nan' and row['Label L'] != 'nan':
                comb_label = row['Label L']

            writer.writerow([row['Path'], row['Name'], row['Label R'], row['Label L'], comb_label])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Combine labels for left and right images')
    parser.add_argument('--input', required=True, type=str, help='Input CSV file')
    parser.add_argument('--output', required=True, type=str, help='Output CSV filename only')
    args = parser.parse_args()
    main(args)
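As a quick illustration of the 'Label comb' rule above, a patient whose right and left rows carry labels 2 and 1 would end up with the combined label '12'; the values below are made up:

# Hypothetical worked example of the label combination logic.
row = {'Label R': '2', 'Label L': '1'}
if int(row['Label R']) > int(row['Label L']):
    comb_label = row['Label L'] + row['Label R']  # lower label first
else:
    comb_label = row['Label R'] + row['Label L']
print(comb_label)  # prints '12'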
15 changes: 15 additions & 0 deletions Preprocess/dataset_info.py
@@ -70,5 +70,20 @@ def count_classes(csv_file,word_class='Classification',dict_classes={}):
"nan": '' ,
}

dict_classes = {
    None: 0,
    0: 1,
    1: 2,
    2: 3,
}

# dict_classes = {
#     None: 4,
#     3: 5,
#     4: 6,
#     5: 7,
# }

classes = count_classes(args.input, args.class_column, dict_classes)
print(classes)
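As a sketch of what the newly active mapping does (assuming missing values are stored as None, as the dict suggests), labels 0-2 are shifted up by one and missing entries become class 0:

# Illustrative only: effect of the active dict_classes mapping above.
dict_classes = {None: 0, 0: 1, 1: 2, 2: 3}
print([dict_classes[label] for label in [None, 0, 1, 2]])  # [0, 1, 2, 3]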
126 changes: 126 additions & 0 deletions Preprocess/dataset_info_8classes.py
@@ -0,0 +1,126 @@
'''
Date: November 2023
Author: Jeanne Claret
Description: This file contains the information about the dataset,
such as the number of data per type and the number of classes.
Used to rename the classes in the dataset.
'''

import csv
import argparse
import pandas as pd

# Count the number of different classes in the column "Position" of the csv file

def count_classes(csv_file, word_class='Classification', dict_classes={}):
    reader = pd.read_csv(csv_file)
    classes = {}
    output_file = csv_file.split('.')[0] + '_classes.csv'

    for index, row in reader.iterrows():
        # If key_name is already an integer, count the number of similar int
        # Remove all empty values for row[word_class]
        if pd.isnull(row[word_class]):
            # delete the row
            # reader = reader.drop(index)
            # reader.to_csv(output_file, index=False)
            # print(f'[INFO] Deleted row {index} because of empty value')
            # Put the empty value to 10
            reader.loc[index, word_class] = 10
            key_name = reader.loc[index, word_class]

        else:
            if isinstance(row[word_class], int):
                key_name = row[word_class]

            elif isinstance(row[word_class], float):
                # Change type of float to int
                key_name = int(row[word_class])
                # Rewrite the csv file
                reader.loc[index, word_class] = key_name
                reader.to_csv(output_file, index=False)

            elif isinstance(row[word_class], str):
                key_name = str(row[word_class]).split(' ')[0]
                ## For Sara's dataset
                if key_name == 'Lingual':
                    key_name = 'Palatal'
                if key_name == 'Bucccal':
                    key_name = 'Buccal'
                if key_name == 'BuccaL':
                    key_name = 'Buccal'

        if classes.get(key_name) is None:
            classes[key_name] = 1
        else:
            classes[key_name] += 1

        if not args.change_classes:
            if not isinstance(row[word_class], int):
                if isinstance(row[word_class], float):
                    continue
                # Change the name of the classes
                reader.loc[index, word_class] = dict_classes[key_name]
                reader.to_csv(output_file, index=False)
        else:
            reader.loc[index, word_class] = dict_classes[key_name]
            reader.to_csv(output_file, index=False)

    return classes

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Dataset information')
    parser.add_argument('--input', required=True, type=str, help='CSV to count and rename classes')
    parser.add_argument('--class_column', type=str, default='Label', help='Name of class column')
    parser.add_argument('--change_classes', type=bool, default=False, help='Change the name of the classes (uses the dict)')

    args = parser.parse_args()
    # Classification of the position
    # dict_classes = {
    #     "Buccal": 0,
    #     "Bicortical": 1,
    #     "Palatal": 2,
    #     "nan": '',
    # }

    # Classification No Damage (0) / Damage (1)
    # dict_classes = {
    #     0: 0,
    #     1: 1,
    #     2: 1,
    #     3: 1,
    # }

    # No damage: 0, Mild damage: 1, Severe + Extreme damage: 2
    # dict_classes = {
    #     0: 0,
    #     1: 1,
    #     2: 2,
    #     3: 2,
    # }

    # dict_classes = {
    #     10: 0,
    #     0: 1,
    #     1: 2,
    #     2: 3,
    # }

    # dict_classes = {
    #     10: 4,
    #     3: 5,
    #     4: 6,
    #     5: 7,
    # }

    # NOTE: one of the dict_classes mappings above must be uncommented before running,
    # otherwise dict_classes is undefined at the call below.
    classes = count_classes(args.input, args.class_column, dict_classes)
    print(classes)
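A small sketch of the string clean-up performed in count_classes (the sample strings are made up): only the first word of the cell is kept, the misspellings are folded into 'Buccal', and 'Lingual' is mapped to 'Palatal':

# Hypothetical inputs illustrating the normalisation above.
for raw in ['Lingual position', 'Bucccal', 'BuccaL', 'Palatal impaction']:
    key = raw.split(' ')[0]
    if key == 'Lingual':
        key = 'Palatal'
    if key in ('Bucccal', 'BuccaL'):  # the script uses two separate if statements
        key = 'Buccal'
    print(raw, '->', key)  # e.g. 'Lingual position -> Palatal'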
58 changes: 41 additions & 17 deletions Preprocess/resample.py
@@ -35,19 +35,35 @@ def resample_image_with_custom_size(img,segmentation, args):
axes_to_pad_Up = [0]*img.GetDimension()
axes_to_pad_Down = [0]*img.GetDimension()
for dim in range(img.GetDimension()):
    if roi_img.GetSize()[dim] < target_size[dim]:
        pad_size = target_size[dim] - roi_img.GetSize()[dim]
        pad_size_Up = pad_size // 2
        pad_size_Down = pad_size - pad_size_Up
        axes_to_pad_Up[dim] = pad_size_Up
        axes_to_pad_Down[dim] = pad_size_Down
        pad_filter = sitk.ConstantPadImageFilter()
        # Get the minimum value of the image
        min_val = float(np.min(sitk.GetArrayFromImage(img)))
        pad_filter.SetConstant(min_val)
        pad_filter.SetPadLowerBound(axes_to_pad_Down)
        pad_filter.SetPadUpperBound(axes_to_pad_Up)
        img_padded = pad_filter.Execute(roi_img)
    if args.crop == False:
        if roi_img.GetSize()[dim] < target_size[dim]:
            pad_size = target_size[dim] - roi_img.GetSize()[dim]
            pad_size_Up = pad_size // 2
            pad_size_Down = pad_size - pad_size_Up
            axes_to_pad_Up[dim] = pad_size_Up
            axes_to_pad_Down[dim] = pad_size_Down
            pad_filter = sitk.ConstantPadImageFilter()
            # Get the minimum value of the image
            min_val = float(np.min(sitk.GetArrayFromImage(img)))
            pad_filter.SetConstant(min_val)
            pad_filter.SetPadLowerBound(axes_to_pad_Down)
            pad_filter.SetPadUpperBound(axes_to_pad_Up)
            img_padded = pad_filter.Execute(roi_img)

        # if roi_img.GetSize()[dim] > target_size[dim]:
        #     # Crop the image
        #     crop_size_up = (roi_img.GetSize()[dim] - target_size[dim]) // 2
        #     crop_size_down = roi_img.GetSize()[dim] - target_size[dim] - crop_size_up
        #     axes_to_pad_Up[dim] = crop_size_up
        #     axes_to_pad_Down[dim] = crop_size_down
        #     crop_filter = sitk.CropImageFilter()
        #     crop_filter.SetLowerBoundaryCropSize(axes_to_pad_Up)
        #     crop_filter.SetUpperBoundaryCropSize(axes_to_pad_Down)
        #     img_padded = crop_filter.Execute(roi_img)

        # # Check if the image is the same size as the target size
        # if img_padded.GetSize()[dim] == target_size[dim]:
        #     img_padded = img_padded

    else:
        img_padded = roi_img
@@ -66,8 +82,6 @@ def resample_fn(img, args):
pixel_dimension = args.pixel_dimension
center = args.center

print('FIT SPACING:', fit_spacing)
print('ISO SPACING:', iso_spacing)
# if(pixel_dimension == 1):
# zeroPixel = 0
# else:
@@ -114,6 +128,7 @@ def resample_fn(img, args):
print("Input size:", size)
print("Input spacing:", spacing)
print("Output size:", output_size)

print("Output spacing:", output_spacing)
print("Output origin:", output_origin)

@@ -148,6 +163,7 @@ def Resample(img_filename, segm, args):
    seg = sitk.ReadImage(segm)
    return resample_image_with_custom_size(img, seg, args)
else:
    return resample_fn(img, args)

@@ -181,7 +197,8 @@ def Resample(img_filename, segm, args):
transform_group.add_argument('--linear', type=bool, help='Use linear interpolation.', default=True)
transform_group.add_argument('--center', type=bool, help='Center the image in the space', default=True)
transform_group.add_argument('--fit_spacing', type=bool, help='Fit spacing to output', default=False)
transform_group.add_argument('--iso_spacing', type=bool, help='Same spacing for resampled output', default=False)
transform_group.add_argument('--iso_spacing', type=bool, help='Same spacing for resampled output', default=True)
transform_group.add_argument('--crop', type=bool, help='Only crop the image to the segmentation (used only if args.segmentation is set)', default=False)

img_group = parser.add_argument_group('Image parameters')
img_group.add_argument('--image_dimension', type=int, help='Image dimension', default=3)
@@ -333,7 +350,15 @@ def Resample(img_filename, segm, args):
if args.size is not None:
    img = Resample(fobj["img"], args.segmentation, args)
else:
    img = sitk.ReadImage(fobj["img"])
    size = img.GetSize()
    physical_size = np.array(size) * np.array(img.GetSpacing())
    new_size = [int(physical_size[i] // args.spacing[i]) for i in range(img.GetDimension())]
    args.size = new_size
    img = Resample(fobj["img"], args.segmentation, args)

print("Writing:", fobj["out"])
writer = sitk.ImageFileWriter()
@@ -360,4 +385,3 @@ def Resample(img_filename, segm, args):
writer.UseCompressionOn()
writer.Execute(img)
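Finally, a minimal sketch of the new auto-size behaviour in resample.py when --size is not given: the physical extent of the input is preserved and the voxel grid is recomputed from the requested spacing. The numbers below are illustrative, not taken from the repository:

import numpy as np

# Illustrative values; the script reads size and spacing from the input image.
size = (256, 256, 200)            # input size in voxels
spacing = (0.3, 0.3, 0.3)         # input spacing in mm
target_spacing = (0.5, 0.5, 0.5)  # value passed via --spacing
physical_size = np.array(size) * np.array(spacing)
new_size = [int(physical_size[i] // target_spacing[i]) for i in range(3)]
print(new_size)  # [153, 153, 120]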

