diff --git a/Preprocess/create_CBCTmask.py b/Preprocess/create_CBCTmask.py index 5a28d30..34692d4 100644 --- a/Preprocess/create_CBCTmask.py +++ b/Preprocess/create_CBCTmask.py @@ -42,6 +42,7 @@ def applyMask(image, mask, label,dilation_radius): # Define the structuring element for dilation kernel = sitk.sitkBall + kernel=sitk.sitkBox radius = dilation_radius dilate_filter = sitk.BinaryDilateImageFilter() dilate_filter.SetKernelType(kernel) @@ -88,7 +89,7 @@ def applyMask(image, mask, label,dilation_radius): parser = argparse.ArgumentParser(description="Apply a mask to an image") parser.add_argument("--img", help="Input image") parser.add_argument("--mask", help="Input mask") - parser.add_argument("--label", nargs='+', help="Label to apply the mask",default=1) + parser.add_argument("--label", nargs='+', help="Label to apply the mask. Ex: 1 or 1 2",required=True) parser.add_argument("--output", help="Output image") parser.add_argument("--dilatation_radius", type=int, help="Radius of the dilatation to apply to the mask",default=None) args = parser.parse_args() diff --git a/Preprocess/create_csv_right_left.py b/Preprocess/create_csv_right_left.py new file mode 100644 index 0000000..b351516 --- /dev/null +++ b/Preprocess/create_csv_right_left.py @@ -0,0 +1,79 @@ +import csv +import argparse +import os +import pandas as pd + +def main(args): + input_file = args.input + output_file = args.output + # Create a dictionary to store the combined rows + combined_rows = {} + + # Read the input CSV file + with open(input_file, 'r') as file: + reader = csv.reader(file) + next(reader) # Skip the header row + + # Iterate over each row in the input file + for row in reader: + path, name, label = row + + # Extract the patient number from the name + patient_number = name.split('_')[0] + side = name.split('_')[1] + # Check if a row with the same patient number already exists + if patient_number in combined_rows: + # Add the label to the corresponding column + if side == 'R': + combined_rows[patient_number]['Label R'] = label + + elif side == 'L': + combined_rows[patient_number]['Label L'] = label + + else: + # Create a new entry for the patient number + if side == 'R': + combined_rows[patient_number] = { + 'Path': path, + 'Name': name, + 'Label R': label, + 'Label L': 'nan' + } + elif side == 'L': + combined_rows[patient_number] = { + 'Path': path, + 'Name': name, + 'Label R': 'nan', + 'Label L': label + } + + # Write the combined rows to the output CSV file + output_dir = os.path.dirname(input_file) + output_path = os.path.join(output_dir, output_file) + with open(output_path, 'w', newline='') as file: + writer = csv.writer(file) + + # Write the header row + writer.writerow(['Path', 'Name', 'Label R', 'Label L','Label comb']) + + # Write the combined rows + for row in combined_rows.values(): + if row['Label R'] != 'nan' and row['Label L'] != 'nan': + if int(row['Label R']) > int(row['Label L']): + #concat the labels in the order of the highest label + comb_label = row['Label L']+ row['Label R'] + else : + comb_label = row['Label R']+ row['Label L'] + elif row['Label R'] != 'nan' and row['Label L'] == 'nan': + comb_label = row['Label R'] + elif row['Label R'] == 'nan' and row['Label L'] != 'nan': + comb_label = row['Label L'] + + writer.writerow([row['Path'], row['Name'], row['Label R'], row['Label L'],comb_label]) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Combine labels for left and right images') + parser.add_argument('--input', required=True, type=str, help='Input CSV 
file') + parser.add_argument('--output', required=True, type=str, help='Output CSV filename only') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/Preprocess/dataset_info.py b/Preprocess/dataset_info.py index ec28651..fdf11ca 100644 --- a/Preprocess/dataset_info.py +++ b/Preprocess/dataset_info.py @@ -70,5 +70,20 @@ def count_classes(csv_file,word_class='Classification',dict_classes={}): "nan": '' , } + dict_classes ={ + None: 0, + 0: 1, + 1: 2, + 2: 3, + + } + + # dict_classes = { + # None: 4, + # 3: 5, + # 4: 6, + # 5: 7, + # } + classes = count_classes(args.input,args.class_column,dict_classes) print(classes) \ No newline at end of file diff --git a/Preprocess/dataset_info_8classes.py b/Preprocess/dataset_info_8classes.py new file mode 100644 index 0000000..bc7940f --- /dev/null +++ b/Preprocess/dataset_info_8classes.py @@ -0,0 +1,126 @@ +''' +Date: November 2023 +Author: Jeanne Claret + +Description: This file contains the information about the dataset +such as the number of data per type, the number of classes. + +Used to rename the classes in the dataset. +''' + +import csv +import argparse +import pandas as pd + +# Count the number of different classes in the column "Position" of te csv file + +def count_classes(csv_file,word_class='Classification',dict_classes={}): + reader = pd.read_csv(csv_file) + classes = {} + output_file = csv_file.split('.')[0] + '_classes.csv' + + + for index, row in reader.iterrows(): + #if key_name is already an integer, count the number of similar int + #remove all empty value for row[word_class] + if pd.isnull(row[word_class]): + #delete the row + # reader = reader.drop(index) + # reader.to_csv(output_file, index=False) + # print(f'[INFO] Deleted row {index} because of empty value') + # put the empty value to None + reader.loc[index,word_class] = 10 + key_name = reader.loc[index,word_class] + + + else: + if isinstance(row[word_class],int): + key_name = row[word_class] + + + elif isinstance(row[word_class],float): + #change type of float to int + key_name = int(row[word_class]) + #rewrite the csv file + reader.loc[index,word_class] = key_name + reader.to_csv(output_file, index=False) + + elif isinstance(row[word_class],str): + key_name = str(row[word_class]).split(' ')[0] + ## For Sara's dataset + if key_name == 'Lingual': + key_name = 'Palatal' + if key_name == 'Bucccal': + key_name = 'Buccal' + if key_name == 'BuccaL': + key_name = 'Buccal' + + + if classes.get(key_name) is None: + classes[key_name] = 1 + else: + classes[key_name] += 1 + + + if not args.change_classes: + if not isinstance(row[word_class],int): + if isinstance(row[word_class],float): + continue + # Change the name of the classes + reader.loc[index,word_class] = dict_classes[key_name] + reader.to_csv(output_file, index=False) + else: + reader.loc[index,word_class] = dict_classes[key_name] + reader.to_csv(output_file, index=False) + + + return classes + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Dataset information') + parser.add_argument('--input', required=True, type=str, help='CSV to count and rename classes') + parser.add_argument('--class_column', type=str, default='Label', help='Name of class column') + parser.add_argument('--change_classes', type=bool, default=False, help='Change the name of the classes (uses the dict)') + + args = parser.parse_args() + # Classification of the position + # dict_classes = { + # "Buccal": 0, + # "Bicortical":1, + # "Palatal": 2, + # "nan": '' , + # } + + # Classification No Damage 
(0) /Damage (1) + # dict_classes = { + # 0:0, + # 1:1, + # 2:1, + # 3:1, + # } + + # No damage:0, Mild damage:1, Severe + Extreme damage:2 + # dict_classes={ + # 0:0, + # 1:1, + # 2:2, + # 3:2, + # } + + # dict_classes ={ + # 10: 0, + # 0: 1, + # 1: 2, + # 2: 3, + + # } + + # dict_classes = { + # 10: 4, + # 3: 5, + # 4: 6, + # 5: 7, + # } + classes = count_classes(args.input,args.class_column,dict_classes) + print(classes) \ No newline at end of file diff --git a/Preprocess/resample.py b/Preprocess/resample.py index bc1f542..e849b1b 100644 --- a/Preprocess/resample.py +++ b/Preprocess/resample.py @@ -35,19 +35,35 @@ def resample_image_with_custom_size(img,segmentation, args): axes_to_pad_Up = [0]*img.GetDimension() axes_to_pad_Down = [0]*img.GetDimension() for dim in range(img.GetDimension()): - if roi_img.GetSize()[dim] < target_size[dim]: - pad_size = target_size[dim] - roi_img.GetSize()[dim] - pad_size_Up = pad_size // 2 - pad_size_Down = pad_size - pad_size_Up - axes_to_pad_Up[dim] = pad_size_Up - axes_to_pad_Down[dim] = pad_size_Down - pad_filter = sitk.ConstantPadImageFilter() - # Get the minimum value of the image - min_val = float(np.min(sitk.GetArrayFromImage(img))) - pad_filter.SetConstant(min_val) - pad_filter.SetPadLowerBound(axes_to_pad_Down) - pad_filter.SetPadUpperBound(axes_to_pad_Up) - img_padded = pad_filter.Execute(roi_img) + if args.crop == False: + if roi_img.GetSize()[dim] < target_size[dim]: + pad_size = target_size[dim] - roi_img.GetSize()[dim] + pad_size_Up = pad_size // 2 + pad_size_Down = pad_size - pad_size_Up + axes_to_pad_Up[dim] = pad_size_Up + axes_to_pad_Down[dim] = pad_size_Down + pad_filter = sitk.ConstantPadImageFilter() + # Get the minimum value of the image + min_val = float(np.min(sitk.GetArrayFromImage(img))) + pad_filter.SetConstant(min_val) + pad_filter.SetPadLowerBound(axes_to_pad_Down) + pad_filter.SetPadUpperBound(axes_to_pad_Up) + img_padded = pad_filter.Execute(roi_img) + + # if roi_img.GetSize()[dim] > target_size[dim]: + # # Crop the image + # crop_size_up = (roi_img.GetSize()[dim] - target_size[dim]) // 2 + # crop_size_down = roi_img.GetSize()[dim] - target_size[dim] - crop_size_up + # axes_to_pad_Up[dim] = crop_size_up + # axes_to_pad_Down[dim] = crop_size_down + # crop_filter = sitk.CropImageFilter() + # crop_filter.SetLowerBoundaryCropSize(axes_to_pad_Up) + # crop_filter.SetUpperBoundaryCropSize(axes_to_pad_Down) + # img_padded = crop_filter.Execute(roi_img) + + # #check if the image is the same size as the target size + # if img_padded.GetSize()[dim] == target_size[dim]: + # img_padded = img_padded else: img_padded = roi_img @@ -66,8 +82,6 @@ def resample_fn(img, args): pixel_dimension = args.pixel_dimension center = args.center - print('FIT SPACING:', fit_spacing) - print('ISO SPACING:', iso_spacing) # if(pixel_dimension == 1): # zeroPixel = 0 # else: @@ -114,6 +128,7 @@ def resample_fn(img, args): print("Input size:", size) print("Input spacing:", spacing) print("Output size:", output_size) + print("Output spacing:", output_spacing) print("Output origin:", output_origin) @@ -148,6 +163,7 @@ def Resample(img_filename, segm, args): seg = sitk.ReadImage(segm) return resample_image_with_custom_size(img, seg, args) else: + return resample_fn(img, args) @@ -181,7 +197,8 @@ def Resample(img_filename, segm, args): transform_group.add_argument('--linear', type=bool, help='Use linear interpolation.', default=True) transform_group.add_argument('--center', type=bool, help='Center the image in the space', default=True) 
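# NOTE: argparse's type=bool does not coerce the string "False" to False -- bool() of any
# non-empty string is True, so "--iso_spacing False" or "--crop False" still enable the flag.
# A minimal sketch of a converter (str2bool is illustrative only, it is not defined in this
# patch; action='store_true' would work as well):
#   def str2bool(v):
#       return str(v).strip().lower() in ('1', 'true', 'yes', 'y')
#   transform_group.add_argument('--crop', type=str2bool, default=False,
#                                help='Only crop the image to the segmentation')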
transform_group.add_argument('--fit_spacing', type=bool, help='Fit spacing to output', default=False) - transform_group.add_argument('--iso_spacing', type=bool, help='Same spacing for resampled output', default=False) + transform_group.add_argument('--iso_spacing', type=bool, help='Same spacing for resampled output', default=True) + transform_group.add_argument('--crop', type=bool, help='Only Crop the image to the segmentation (used only if args.segmentation is)', default=False) img_group = parser.add_argument_group('Image parameters') img_group.add_argument('--image_dimension', type=int, help='Image dimension', default=3) @@ -333,7 +350,15 @@ def Resample(img_filename, segm, args): if args.size is not None: img = Resample(fobj["img"], args.segmentation, args) else: + img = sitk.ReadImage(fobj["img"]) + size = img.GetSize() + physical_size = np.array(size)*np.array(img.GetSpacing()) + new_size = [int(physical_size[i]//args.spacing[i]) for i in range(img.GetDimension())] + args.size = new_size + img = Resample(fobj["img"],args.segmentation, args) + + print("Writing:", fobj["out"]) writer = sitk.ImageFileWriter() @@ -360,4 +385,3 @@ def Resample(img_filename, segm, args): writer.UseCompressionOn() writer.Execute(img) - diff --git a/classification_eval_VAXI.py b/classification_eval_VAXI.py index 38a86d4..ff0a268 100644 --- a/classification_eval_VAXI.py +++ b/classification_eval_VAXI.py @@ -5,7 +5,7 @@ import json import sys from sklearn.metrics import confusion_matrix -from sklearn.metrics import roc_curve, auc, roc_auc_score +from sklearn.metrics import roc_curve, auc, roc_auc_score,f1_score,precision_score,recall_score,accuracy_score from sklearn.metrics import classification_report import pandas as pd @@ -15,6 +15,7 @@ import itertools import pickle +from useful_readibility import printRed, printBlue,printGreen COLORS={ @@ -118,7 +119,6 @@ def classification_eval(df, args, y_true_arr, y_pred_arr): confusion_filename = os.path.join(output_dir,fn_cf) fig.savefig(confusion_filename) - # Plot normalized confusion matrix fig2 = plt.figure(figsize=args.figsize) cm = plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True, title=args.title + ' - normalized') @@ -201,6 +201,120 @@ def classification_eval(df, args, y_true_arr, y_pred_arr): return score +def ClassificationMultiLabel_eval(df, args, y_true_arr, y_pred_arr): + ''' + function to evaluate a multi-label column classification model. 
+ Test file example: + Path, Name, Label1, Label2, Pred1, Pred2 + /path/to/image1, image1, 1, 3, 1, 3 + /path/to/image2, image2, None, 4, None, 4 + /path/to/image3, image3, 2, 5, 2, 4 + ''' + input_dir = os.path.dirname(args.csv) + output_dir = os.path.join(args.mount_point, input_dir) + output_dir= output_dir + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if '_' in args.diff[0]: + column1_nm = args.csv_true_column + args.diff[0] + column2_nm = args.csv_true_column + args.diff[1] + + pred1_nm = args.csv_prediction_column + args.diff[0] + pred2_nm = args.csv_prediction_column + args.diff[1] + + else: + column1_nm = args.csv_true_column + ' ' + args.diff[0] + column2_nm = args.csv_true_column + ' ' + args.diff[1] + + pred1_nm = args.csv_prediction_column + ' ' + args.diff[0] + pred2_nm = args.csv_prediction_column + ' ' + args.diff[1] + + + #concatenate the 2 columns to get the class names + df_combined = pd.concat([df[column1_nm], df[column2_nm]]) + class_names = pd.unique(df_combined) + + #remove nan + class_names = [x for x in class_names if str(x) != 'nan'] + class_names.sort() + + + print("Class names:", class_names) + # Count false predictions (case where the true label is None and the prediction is not None) + fail_fp_R=0 + fail_fp_L=0 + # Count wrong predictions (case where the true label is not None and the prediction is None) + fail_wp_R=0 + fail_wp_L=0 + for idx,row in df.iterrows(): + if str(row[column1_nm]) != 'nan' and str(row[pred1_nm]) != 'nan': + + y_true_arr.append(str(row[column1_nm])) + y_pred_arr.append(str(row[pred1_nm])) + elif str(row[column1_nm]) !='nan' and str(row[pred1_nm]) == 'nan': + fail_wp_R+=1 + elif str(row[column1_nm])=='nan' and str(row[pred1_nm]) != 'nan': + fail_fp_R+=1 + else: + pass + + + for idx,row in df.iterrows(): + if str(row[column2_nm]) != 'nan' and str(row[pred2_nm]) != 'nan': + y_true_arr.append(str(row[column2_nm])) + y_pred_arr.append(str(row[pred2_nm])) + elif str(row[column2_nm]) !='nan' and str(row[pred2_nm]) == 'nan': + fail_wp_L+=1 + elif str(row[column2_nm])=='nan' and str(row[pred2_nm]) != 'nan': + fail_fp_L+=1 + else: + pass + + report = classification_report(y_true_arr, y_pred_arr, output_dict=True, zero_division=1) + + cnf_matrix = confusion_matrix(y_true_arr, y_pred_arr) + np.set_printoptions(precision=3) + + # Plot non-normalized confusion matrix + fig = plt.figure(figsize=args.figsize) + plot_confusion_matrix(cnf_matrix, classes=class_names, title=args.title) + + fn_cf = os.path.splitext(args.out)[0] + "_confusion.png" + confusion_filename = os.path.join(output_dir,fn_cf) + #add legend with the number of failed predictions + fig.text(0.25, 0.01, f'Predicted ghost: {fail_fp_R} ({args.diff[0]}), {fail_fp_L} ({args.diff[1]}), Missed Prediction: {fail_wp_R} ({args.diff[0]}), {fail_wp_L} ({args.diff[1]})', ha='center', va='center', color='red') + fig.savefig(confusion_filename) + + # Plot normalized confusion matrix + fig2 = plt.figure(figsize=args.figsize) + cm = plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True, title=args.title + ' - normalized') + + fn =os.path.splitext(args.out)[0] + "_norm_confusion.png" + norm_confusion_filename = os.path.join(output_dir, fn) + print('norm_confusion_filename',norm_confusion_filename ) + fig2.text(0.25, 0.01, f'Predicted ghost: {fail_fp_R} ({args.diff[0]}), {fail_fp_L} ({args.diff[1]}), Missed Prediction: {fail_wp_R} ({args.diff[0]}), {fail_wp_L} ({args.diff[1]})', ha='center', va='center',color='red') + fig2.savefig(norm_confusion_filename) + + # save 
report to csv + + df_report = pd.DataFrame(report).transpose() + # if 'accuracy' + if 'accuracy' in df_report.columns: + df_report.loc['accuracy'] = '' + + df_report.loc['accuracy','accuracy']=report['accuracy'] + df_report.loc['accuracy','support']= df_report.loc['weighted avg','support'] + + fn = os.path.splitext(args.out)[0] + "_classification_report.csv" + report_filename = os.path.join(output_dir, fn) + df_report.to_csv(report_filename) + + args.eval_metric = 'F1' + score = choose_score(args,report) + return score + def main(args): y_true_arr = [] @@ -212,8 +326,11 @@ def main(args): else: df = pd.read_parquet(path_to_csv) - score = classification_eval(df, args, y_true_arr, y_pred_arr) - + if args.mode == 'CV': + score = classification_eval(df, args, y_true_arr, y_pred_arr) + pass + elif args.mode == 'CV_2pred': + score = ClassificationMultiLabel_eval(df, args, y_true_arr, y_pred_arr) return score @@ -224,15 +341,22 @@ def get_argparse(): parser = argparse.ArgumentParser(description='Evaluate classification result', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--csv', type=str, help='CSV file', required=True) - parser.add_argument('--csv_true_column', type=str, help='Which column to do the stats on', default="class") - parser.add_argument('--csv_prediction_column', type=str, help='csv prediction class', default='pred') + parser.add_argument('--csv_true_column', type=str, help='Which column to do the stats on, if Multi like Label L and Label R, write Label', default="class") + parser.add_argument('--csv_prediction_column', type=str, help='csv prediction class, if Multi write common word', default='pred') parser.add_argument('--title', type=str, help='Title for the image', default='Confusion matrix') parser.add_argument('--figsize', type=str, nargs='+', help='Figure size', default=(8, 8)) parser.add_argument('--eval_metric', type=str, help='Score you want to choose for picking the best model : F1 or AUC', default='F1', choices=['F1', 'AUC']) parser.add_argument('--mount_point', type=str, help='Mount point for the data', default='./') - parser.add_argument('--out', type=str, help='Output filename for the plot', default="out.png") + parser.add_argument('--out', type=str, help='Output filename for the plot', default="Final_evaluation.png") + + parser.add_argument('--mode', type=str, help='Mode of the evaluation', default='CV', choices=['CV', 'CV_2pred']) + # For MultiLabel evaluation + parser.add_argument('--diff',nargs='+', help='Differentiator between the 2 Label/predict columns. 
Ex: Label 1, Label 2 --> 1 2', default=['_R','_L']) + + + return parser @@ -240,9 +364,8 @@ def get_argparse(): if __name__ == "__main__": parser = get_argparse() args = parser.parse_args() - main(args) - + main(args) diff --git a/classification_predict.py b/classification_predict.py index 9e7bf90..d093fde 100644 --- a/classification_predict.py +++ b/classification_predict.py @@ -8,27 +8,264 @@ import torch from torch.utils.data import DataLoader -from nets.classification import Net, SegNet -from loaders.cleft_dataset import BasicDataset, SegDataset +from nets.classification import Net +from loaders.cleft_dataset import BasicDataset, Datasetarget from transforms.volumetric import EvalTransforms, SegEvalTransforms, NoEvalTransform from callbacks.logger import ImageLogger from sklearn.utils import class_weight -from sklearn.metrics import classification_report +from sklearn.metrics import classification_report, roc_curve, auc +from monai.metrics import compute_roc_auc + +from useful_readibility import printRed, printBlue, printGreen from tqdm import tqdm import pickle - +import matplotlib.pyplot as plt import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') -def main(args): +def plot_roc_curve_final(probs, truths, class_idx,out_path): + ''' + function used as final step to plot the last interested class roc curve to the plot + ''' + fpr, tpr, _ = roc_curve(truths, probs) + roc_auc = auc(fpr, tpr) + plt.plot(fpr, tpr, lw=2, label=f'Class {class_idx} (AUC = {roc_auc:.2f})') + plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title(f'ROC AUC ') + plt.legend(loc="lower right") + plt.savefig(out_path) + plt.close() + +def add_roc_curve(probs, truths, class_idx): + ''' + function used to add a roc curve to the plot + ''' + fpr, tpr, _ = roc_curve(truths, probs) + roc_auc = auc(fpr, tpr) + plt.plot(fpr, tpr, lw=2, label=f'Class {class_idx} (AUC = {roc_auc:.2f})') + + +def MultiPred(args): + ''' + Function to handle prediction from a dataset with 2 different class columns. + for example, csv file: + Path, Patient, Label R, Label L + img1.nii.gz, patient1, 1, 4 + img2.nii.gz, patient2, 0, 5 + img3.nii.gz, patient3, 2, 3 + + The function will predict the target vector for each image and save the prediction in a new column in the csv file. + Each probability of class is between 0 and 1. The sum is not equal to 1. + In the example, the sum would give maximum 6. + + The target vector must be splitable in 2 parts (in the middle), one for each class column. + idx 0 to 2 are for Label R and idx 3 to 5 are for Label L in the example. + target vector = [0, 1, 0, 0, 1, 0] for the first row. 
+ ''' + model = Net(seed=args.seed).load_from_checkpoint(args.model) + model.eval() + model.cuda() + + + if(os.path.splitext(args.csv)[1] == ".csv"): + df_train = pd.read_csv(args.csv_train) + df_test = pd.read_csv(args.csv) + else: + df_train = pd.read_parquet(args.csv_train) + df_test = pd.read_parquet(args.csv) + + test_ds = Datasetarget(df_test, mount_point=args.mount_point, img_column=args.img_column, class_column1=args.class_column1, class_column2=args.class_column2,nb_classes=args.nb_classes, transform=EvalTransforms(args.img_size)) + + test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, prefetch_factor=4) + + with torch.no_grad(): + predictionsR = [] + predictionsL = [] + probs_sig = [] + probs_softmax = [] + features = [] + + probR,probL = [],[] + trueR,trueL = [],[] + predictions_model = [] + idx_batch = [] + for idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)): + X, Y = batch + X = X.cuda().contiguous() + if args.extract_features: + pred, x_f = model(X) + features.append(x_f.cpu().numpy()) + else: + pred = model(X) + pred_softmax = torch.nn.functional.softmax(pred, dim=1) #dim 0 + pred_sigmoid = torch.nn.functional.sigmoid(pred) + + demi_len = pred_sigmoid.shape[1]//2 + + for j in range(pred_sigmoid.shape[0]): + + noneR=0 + noneL=0 + + best_probR,idxR = torch.max(pred_sigmoid[j,:demi_len],dim=0) + best_probL,idxL = torch.max(pred_sigmoid[j,demi_len:],dim=0) + + if best_probR.item() > 0.6: + predictionsR.append(idxR.item()) + elif best_probR.item() <= 0.8: + predictionsR.append(None) + noneR=1 + + if best_probL.item() > 0.6: + idxL = idxL.item() + demi_len + predictionsL.append(idxL) + elif best_probL.item() <= 0.8 : + predictionsL.append(None) + noneL=1 + + if noneR == 1 and noneL == 1: + #get the highest probability and replace the None + best_prob,idx = torch.max(pred_sigmoid[j,:],dim=0) + if idx.item() < demi_len: + predictionsR[-1] = idx.item() + else: + predictionsL[-1] = idx.item() + + + #find label R and L from Y[j] + #look for the index of the non zero value + target_lst = Y[j].tolist() + idx_non_zero_target = [i for i, val in enumerate(target_lst) if val > 0.0] + + if len(idx_non_zero_target) > 1: + #when 2 classes are given, target_vector has 2 values of 0.5 so we need to change the value to 1 + trueR_value = [value*2 for value in target_lst[:demi_len]] + trueR.append(trueR_value) + probR.append(pred_softmax[j,:demi_len].cpu().numpy()) + + trueL_value = [value*2 for value in target_lst[demi_len:]] + trueL.append(trueL_value) + probL.append(pred_softmax[j,demi_len:].cpu().numpy()) + else: + if idx_non_zero_target[0] < demi_len: + trueR.append(target_lst[:demi_len]) + probR.append(pred_softmax[j,:demi_len].cpu().numpy()) + else: + trueL.append(target_lst[demi_len:]) + probL.append(pred_softmax[j,demi_len:].cpu().numpy()) + + probs_sig.append(pred_sigmoid[j,:].cpu().numpy()) + probs_softmax.append(pred_softmax[j,:].cpu().numpy()) + predictions_model.append(pred[j,:].cpu().numpy()) + + + + print('args.class_column1',args.class_column1) + if '_' in args.class_column1: + predR_column = args.pred_column+'_R' + predL_column = args.pred_column+'_L' + else: + predR_column = args.pred_column + ' R' + predL_column = args.pred_column + ' L' + df_test[predR_column] = predictionsR + df_test[predL_column] = predictionsL + output_dir =args.out + + df_test['Prob sigmoid'] = probs_sig + df_test['Prob softmax'] = probs_softmax + + filename = os.path.basename(args.csv).replace('.csv', 
'_prediction.csv') + df_test.to_csv(os.path.join(output_dir, filename), index=False) + # create csv with the predictions target before sigmoid + prob_filenameCsv = os.path.basename(args.csv).replace('.csv', '_prob.csv') + prob_filenamePickle = os.path.basename(args.csv).replace('.csv', '_prob-pred.pickle') + df_pred = pd.DataFrame() + df_pred['Before Sigmoid'] = predictions_model + df_pred.to_pickle(os.path.join(output_dir, prob_filenamePickle)) + df_pred['Sigmoid'] = probs_sig + df_pred.to_csv(os.path.join(output_dir, prob_filenameCsv)) + + + ## Compute AUC and ROC AUC curve and save it + auc_fn = "auc_evaluation.csv" + auc_path = os.path.join(output_dir, "AUC/") + if not os.path.exists(auc_path): + os.makedirs(auc_path) + auc_path = os.path.join(auc_path, auc_fn) + + roc_curve_fn = "roc_curve.png" + roc_curve_path = os.path.join(output_dir, "AUC/") + if not os.path.exists(roc_curve_path): + os.makedirs(roc_curve_path) + roc_curve_path = os.path.join(roc_curve_path, roc_curve_fn) + + auc_data_lst = [] + probR,trueR = np.array(probR), np.array(trueR) + probL, trueL = np.array(probL), np.array(trueL) + aucR_tot =compute_roc_auc(torch.tensor(probR), torch.tensor(trueR)) + aucL_tot =compute_roc_auc(torch.tensor(probL), torch.tensor(trueL)) + if aucR_tot>0.6: + printGreen(f'AUC R: {aucR_tot}') + else: + printRed(f'AUC R: {aucR_tot}') + if aucL_tot>0.6: + printGreen(f'AUC L: {aucL_tot}') + else: + printRed(f'AUC L: {aucL_tot}') + + #Compute AUC for each class + first_column = [] + for i in range(args.nb_classes): + if i==demi_len: + first_column.append('auc 1') + auc_data_lst.append(aucR_tot) + first_column.append(i) + else: + first_column.append(i) + if i< demi_len: + aucR =compute_roc_auc(torch.tensor(probR)[:,i], torch.tensor(trueR)[:,i]) + auc_data_lst.append(aucR) + + if i< demi_len-1: + plt.figure(1) + add_roc_curve(torch.tensor(probR)[:,i], torch.tensor(trueR)[:,i], i) + else: + right_curve_path = roc_curve_path.replace('.png','_R.png') + plot_roc_curve_final(torch.tensor(probR)[:,i], torch.tensor(trueR)[:,i], i,right_curve_path) + else: + aucL =compute_roc_auc(torch.tensor(probL)[:,i-demi_len], torch.tensor(trueL)[:,i-demi_len]) + auc_data_lst.append(aucL) + if i< args.nb_classes-1: + plt.figure(2) + add_roc_curve(torch.tensor(probL)[:,i-demi_len], torch.tensor(trueL)[:,i-demi_len], i) + else: + left_curve_path = roc_curve_path.replace('.png','_L.png') + plot_roc_curve_final(torch.tensor(probL)[:,i-demi_len], torch.tensor(trueL)[:,i-demi_len], i,left_curve_path) + + first_column.append('auc 2') + auc_data_lst.append(aucL_tot) + + df_auc = pd.DataFrame({'Class':first_column,"AUC": auc_data_lst}) + df_auc.to_csv(auc_path, index=False) + + + +def NormalPred(args): + ''' + Prediction function for a single class column. + Sum of probabilities of each class is equal to 1 (Use Softmax activation function). 
+ ''' if args.seg_column is None: model = Net(seed=args.seed).load_from_checkpoint(args.model) - else: - model = SegNet().load_from_checkpoint(args.model) + model.eval() model.cuda() @@ -63,8 +300,6 @@ def main(args): if args.seg_column is None: test_ds = BasicDataset(df_test, img_column=args.img_column, mount_point=args.mount_point, class_column=args.class_column, transform=EvalTransforms(args.img_size)) - else: - test_ds = SegDataset(df_test, img_column=args.img_column, mount_point=args.mount_point, class_column=args.class_column, seg_column=args.seg_column, transform=SegEvalTransforms(args.img_size)) else: test_ds = BasicDataset(df_test, img_column=args.img_column, mount_point=args.mount_point, transform=EvalTransforms(args.img_size)) @@ -88,7 +323,9 @@ def main(args): features.append(x_f.cpu().numpy()) else: pred = model(X) - probs.append(pred.cpu().numpy()) + pred_prob = torch.nn.functional.softmax(pred, dim=1) + + probs.append(pred_prob.cpu().numpy()) predictions.append(torch.argmax(pred, dim=1).cpu().numpy()) df_test[args.pred_column] = np.concatenate(predictions, axis=0) @@ -114,16 +351,24 @@ def main(args): pickle.dump(features, open(os.path.join(args.mount_point, args.out, os.path.basename(args.csv).replace(ext, "_prediction.pickle")), 'wb')) +def main(args): + if not os.path.exists(os.path.dirname(args.out)): + os.makedirs(os.path.dirname(args.out)) + if args.mode == 'CV_2pred': + MultiPred(args) + elif args.mode == 'CV': + NormalPred(args) + def get_argparse(): parser = argparse.ArgumentParser(description='Classification predict') parser.add_argument('--csv', type=str, help='CSV file for testing', required=True) parser.add_argument('--csv_train', type=str, help='CSV file to compute class replace', required=True) parser.add_argument('--extract_features', type=bool, help='Extract the features', default=False) - parser.add_argument('--img_column', type=str, help='Column name in the csv file with image path', default="img") - parser.add_argument('--class_column', type=str, help='Column name in the csv file with classes', default="Classification") + parser.add_argument('--img_column', type=str, help='Column name in the csv file with image path', default="Path") + parser.add_argument('--class_column', type=str, help='Column name in the csv file with classes', default="Label") parser.add_argument('--seg_column', type=str, help='Column name in the csv file with image segmentation path', default=None) parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, help='Learning rate') - parser.add_argument('--model', type=str, help='Model path to use for the predictions', default='./') + parser.add_argument('--model', type=str, help='Model path to use for the predictions',required=True, default='./') parser.add_argument('--epochs', help='Max number of epochs', type=int, default=200) parser.add_argument('--out', help='Output directory', type=str, default="./") parser.add_argument('--pred_column', help='Output column name', type=str, default="pred") @@ -134,14 +379,25 @@ def get_argparse(): parser.add_argument('--img_size', help='Image size of the dataset', type=int, default=224) parser.add_argument('--seed', help='Seed for reproducibility', type=int, default=42) + + parser.add_argument('--mode', type=str, help='Mode used for the model', default='CV', choices=['CV', 'CV_2pred']) + # target vector prediction + parser.add_argument('--nb_classes', help='Number of classes', type=int, default=6) + parser.add_argument('--class_column1', type=str, help='Column name in the csv file with 
classes', default="Label_R") + parser.add_argument('--class_column2', type=str, help='Column name in the csv file with classes', default="Label_L") + parser.add_argument('--diff', nargs='+',help='Differentiator between the 2 Label/predict columns. Ex: Label 1, Label 2 --> 1 2', default=['_R','_L']) + + + return parser + if __name__ == '__main__': parser = get_argparse() args = parser.parse_args() if args.model =='./': - print('Please provide a model path') + printRed('Please provide a path to a model') exit() - main(args) + main(args) \ No newline at end of file diff --git a/classification_train_v2.py b/classification_train_v2.py index 1410246..0a93d58 100644 --- a/classification_train_v2.py +++ b/classification_train_v2.py @@ -9,9 +9,9 @@ import torch -from nets.classification import Net, SegNet -from loaders.cleft_dataset import DataModule, SegDataModule -from transforms.volumetric import TrainTransforms, EvalTransforms,SpecialTransforms, SegTrainTransforms, SegEvalTransforms, NoTransform, NoEvalTransform +from nets.classification import Net, NetTarget +from loaders.cleft_dataset import DataModule, DataModuleT +from transforms.volumetric import TrainTransforms, EvalTransforms,SpecialTransforms, NoTransform, NoEvalTransform import classification_predict import classification_eval_VAXI from useful_readibility import printRed, printGreen,printOrange, printBlue @@ -93,6 +93,7 @@ def main(args): for cn, cl in enumerate(unique_classes): class_replace[int(cl)] = cn print(unique_classes, unique_class_weights, class_replace) + unique_class_weights = None #save the parameters of the model outpath_modelInfo = args.out + "/modelParams.csv" @@ -184,9 +185,13 @@ def main(args): if args.seg_column is None: - data = DataModule(df_train_inner, df_val,df_test, df_filtered_special, mount_point=args.mount_point, batch_size=args.batch_size, num_workers=args.num_workers, img_column=args.img_column, class_column=args.class_column, - train_transform= TrainTransforms(img_size,pad_size), valid_transform=EvalTransforms(img_size),test_transform=EvalTransforms(img_size), special_transform = special_tf,seed=args.seed) - + if args.mode == 'CV_2pred': + data = DataModuleT(df_train_inner, df_val,df_test, df_filtered_special, mount_point=args.mount_point, batch_size=args.batch_size, num_workers=args.num_workers, + img_column=args.img_column, nb_classes=args.nb_classes, class_column1="Label_R", class_column2="Label_L", + train_transform= TrainTransforms(img_size,pad_size), valid_transform=EvalTransforms(img_size),test_transform=EvalTransforms(img_size), special_transform = special_tf,seed=args.seed) + else: + data = DataModule(df_train_inner, df_val,df_test, df_filtered_special, mount_point=args.mount_point, batch_size=args.batch_size, num_workers=args.num_workers,img_column=args.img_column,class_column=args.class_column, + train_transform= TrainTransforms(img_size,pad_size), valid_transform=EvalTransforms(img_size),test_transform=EvalTransforms(img_size), special_transform = special_tf,drop_last= False, seed=args.seed) #restart the training to the fold of the model then continue if args.checkpoint is not None: @@ -200,24 +205,26 @@ def main(args): prediction_folder = os.path.dirname(args.checkpoint).replace(folder_last_model,'Predictions')+f'/{folder_last_model}' if os.path.exists(prediction_folder): continue - model = Net.load_from_checkpoint(args.checkpoint, num_classes=unique_classes.shape[0], class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + if args.mode == 'CV_2pred': + model = 
NetTarget.load_from_checkpoint(args.checkpoint, num_classes=args.nb_classes, class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + else: + model = Net.load_from_checkpoint(args.checkpoint, num_classes=args.nb_classes, class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) ckpt_path = args.checkpoint else: - model = Net(args, num_classes=unique_classes.shape[0], class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + if args.mode == 'CV_2pred': + model= NetTarget(args, num_classes=args.nb_classes, class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + else: + model = Net(args, num_classes=args.nb_classes, class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) ckpt_path = None - # if args.model is not None and i==0: - # model = Net.load_from_checkpoint(args.model, num_classes=unique_classes.shape[0], class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + else: - model = Net(args, num_classes=unique_classes.shape[0], class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + if args.mode == 'CV_2pred': + model= NetTarget(args, num_classes=args.nb_classes, class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) + else: + model = Net(args, num_classes=args.nb_classes, class_weights=unique_class_weights, base_encoder=base_encoder,seed=args.seed) ckpt_path = None - # torch.backends.cudnn.benchmark = True - # else: - # data = SegDataModule(df_train_inner, df_val,df_test, mount_point=args.mount_point, batch_size=args.batch_size, num_workers=args.num_workers, img_column=args.img_column, class_column=args.class_column, - # train_transform=SegTrainTransforms(img_size), valid_transform=SegEvalTransforms(img_size),test_transform=SegEvalTransforms(img_size)) - - # model = SegNet(args, num_classes=unique_classes.shape[0], class_weights=unique_class_weights, base_encoder=base_encoder) # Create a folder for each fold checkpoint_dir =args.out + f"/fold_{i}" @@ -294,6 +301,13 @@ def main(args): prediction_args['seed']= args.seed + ## Prediction for 2 classes columns + prediction_args['mode']=args.mode + if args.mode == 'CV_2pred': + prediction_args['class_column1']= "Label_R" + prediction_args['class_column2']= "Label_L" + prediction_args['nb_classes']= args.nb_classes + prediction_args= Namespace(**prediction_args) ext = os.path.splitext(outpath_test)[1] out_prediction = os.path.join(prediction_args.out, os.path.basename(best_model), os.path.basename(outpath_test).replace(ext, "_prediction" + ext)) @@ -310,11 +324,18 @@ def main(args): predict_csv_path = outdir_prediction + os.path.basename(outpath_test).replace(ext, "_prediction" + ext) evaluation_args['csv']= predict_csv_path - evaluation_args['csv_true_column']= args.class_column evaluation_args['csv_prediction_column']= "Prediction" evaluation_args['title']= f"Confusion matrix fold {i}" evaluation_args['out']= f"fold_{i}_eval.png" + ## Evaluation for 2 classes columns + if args.mode == 'CV_2pred': + evaluation_args['csv_true_column']= "Label" + evaluation_args['mode']=args.mode + evaluation_args['diff']= ['_R','_L'] + else: + evaluation_args['csv_true_column']= args.class_column + evaluation_args= Namespace(**evaluation_args) metric = classification_eval_VAXI.main(evaluation_args) # AUC or F1 @@ -347,6 +368,7 @@ def main(args): parser.add_argument('--img_column', type=str, default='img', help='Name of image column') parser.add_argument('--class_column', type=str, 
default='Label', help='Name of class column') parser.add_argument('--seg_column', type=str, default=None, help='Name of segmentation column') + parser.add_argument('--nb_classes', type=int, default=6, help='Number of classes') parser.add_argument('--base_encoder', nargs="+", default='efficientnet-b0', help='Type of base encoder') parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, help='Learning rate') parser.add_argument('--epochs', help='Max number of epochs', type=int, default=400) @@ -374,6 +396,9 @@ def main(args): # seed parser.add_argument('--seed', help='Seed', type=int, default=42) + #mode + parser.add_argument('--mode', help='Mode for training', type=str, default='CV', choices=['CV', 'CV_2pred']) + #Cross validation cv_group = parser.add_argument_group('Cross validation') cv_group.add_argument('--split', type=int, default=5, help='Number of splits for cross validation') diff --git a/count_cases.py b/count_cases.py new file mode 100644 index 0000000..e933d4b --- /dev/null +++ b/count_cases.py @@ -0,0 +1,57 @@ +import csv +import pandas as pd +import argparse + +# Path to the CSV file +parser = argparse.ArgumentParser(description="Count the number of cases with labels") +parser.add_argument("--csv_file", help="Path to the CSV file") +args = parser.parse_args() +csv_file = args.csv_file +# Initialize counters +l_cases = 0 +r_cases = 0 +both_cases = 0 + +# Read the CSV file +with open(csv_file, 'r') as file: + reader = csv.DictReader(file) + ct=0 + for row in reader: + ct+=1 + label_r = row['Label_R'] + label_l = row['Label_L'] + #change type of row + label_r = int(label_r) if label_r else None + label_l = int(label_l) if label_l else None + print(f'======row {ct} {row}======') + print('Name', row['Name'] ) + print('label r', label_r) + print('label l', label_l) + + # Check if Label_R is None and Label_L has a class + #use pd.isna + if (pd.isna(label_r) and label_l is not None) or ( label_r is None and label_l is not None): + l_cases += 1 + + + # Check if Label_R has a class and Label_L is None + elif (label_r is not None and pd.isna(label_l)) or (label_r is not None and label_l is None): + r_cases += 1 + + + # Check if both Label_R and Label_L have classes + elif label_r is not None and label_l is not None: + both_cases += 1 + + + else: + print("This case has no labels") + +# Print the counts +print('number lines csv', reader.line_num) +print(f"Number of L cases: {l_cases}") +print(f"Number of R cases: {r_cases}") +print(f"Number of cases with both labels: {both_cases}") +sum_cases = l_cases + r_cases + both_cases +if sum_cases == reader.line_num-1: + print('[SUCCESS] All cases have been counted') \ No newline at end of file diff --git a/gradcam3D_monai.py b/gradcam3D_monai.py index 396e4e1..c167c4b 100644 --- a/gradcam3D_monai.py +++ b/gradcam3D_monai.py @@ -158,7 +158,7 @@ def get_cam_sum(data,model,class_index, layers): def main(args): # Parameters - nb_of_classes = args.nb_classes + nb_of_classes = args.nb_class class_index = args.class_index layer_name = args.layer_name out_dir = args.out @@ -169,7 +169,10 @@ def main(args): unique_classes = np.unique(df[args.class_column]) class_replace = {} for cn, cl in enumerate(unique_classes): - class_replace[int(cl)] = cn + if pd.isna(cl) : + continue + else: + class_replace[int(cl)] = cn df[args.class_column] = df[args.class_column].replace(class_replace) # img_path = ['Preprocess/Preprocessed_data/Resampled/Left/MN080_scan_MB_Masked.nii.gz'] # dataset = CustomDataset(img_path, transforms=EvalTransforms(256)) @@ 
-211,7 +214,7 @@ def main(args): given_path = False # Save directories - cam_save_dir = f'{out_dir}/cam_images' + cam_save_dir = f'{out_dir}/class{class_index}' if os.path.exists(cam_save_dir) is False: os.makedirs(cam_save_dir) @@ -250,6 +253,8 @@ def main(args): true_class = df.loc[batch][args.class_column] predicted_class = df.loc[batch][args.pred_column] img_fn = df.loc[batch][args.img_column] + #extract the value from the tensor + true_class = true_class.item() else: true_class = 'X' @@ -284,22 +289,29 @@ def main(args): group.add_argument('--csv_test', type=str, help='Testing set csv to load') parser.add_argument('--img_column', type=str, default='Path', help='Name of image column') parser.add_argument('--class_column', type=str, default='Label', help='Name of class column with the labels') - parser.add_argument('--pred_column', type=str, default='Predictions', help='Name of class column with the predicted labels') + parser.add_argument('--pred_column', type=str, default='pred', help='Name of class column with the predicted labels') group.add_argument('--img_path', type=str, help='Path to the image to load') - parser.add_argument('--mount_point', help='Dataset mount directory', type=str, default="./") parser.add_argument('--model_path', help='Model path to use', type=str, default='') - parser.add_argument('--out', help='Output folder with vizualisation files', type=str, default="Training_Left/SEResNet50/Predictions/GRADCAM/onebyone") + parser.add_argument('--mount_point', help='Dataset mount directory', type=str, default="./") + parser.add_argument('--out', help='Output folder with vizualisation files', type=str, default="./GRADCAM") + - parser.add_argument('--img_size', help='Image size of the dataset', type=int, default=256) - parser.add_argument('--nb_classes', help='Number of classes', type=int, default=3) + parser.add_argument('--img_size', help='Image size of the dataset', type=int, default=224) + parser.add_argument('--nb_class', help='Number of classes', type=int, default=3) parser.add_argument('--class_index', help='Class index for GradCAM', type=int, default=1) + + parser.add_argument('--base_encoder', type=str, default="DenseNet201", help='Type of base encoder') parser.add_argument('--layer_name', help='Layer name for GradCAM', nargs="+", default=['model.layer4']) - parser.add_argument('--base_encoder', type=str, default="SEResNet50", help='Type of base encoder') parser.add_argument('--show_plots', help='Show plots', type=bool, default=False) parser.add_argument('--slice_idx', help='Slice index to plot', type=int, default=120) args = parser.parse_args() + #if args.model == '': print error message + if args.model_path == '': + printRed('Please provide a model path') + exit(1) + main(args) \ No newline at end of file diff --git a/loaders/cleft_dataset.py b/loaders/cleft_dataset.py index 96aacfa..27e3513 100644 --- a/loaders/cleft_dataset.py +++ b/loaders/cleft_dataset.py @@ -15,11 +15,10 @@ import pandas as pd import numpy as np + class BasicDataset(Dataset): def __init__(self, df, mount_point = "./", img_column='img', class_column='Classification', transform=None): self.df = df - - self.mount_point = mount_point self.transform = transform self.img_column = img_column @@ -30,21 +29,53 @@ def __len__(self): def __getitem__(self, idx): - # df_filtered = self.df.dropna(subset=[self.class_column]) - # df_filtered.reset_index(drop=True) row = self.df.loc[idx] - # print('*********idx***********',idx) - img = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(self.mount_point, 
row[self.img_column]))) if self.transform: img = self.transform(img) - cl = int(row[self.class_column]) + if pd.isna(row[self.class_column]): + cl =11 #11 is the class for missing data + else: + cl = int(row[self.class_column]) return img, torch.tensor(cl, dtype=torch.long) +class Datasetarget(Dataset): + def __init__(self, df, mount_point = "./", img_column='img', nb_classes=2,class_column1 ="Right", class_column2="Left", transform=None): + self.df = df + self.mount_point = mount_point + self.transform = transform + self.img_column = img_column + self.class_column_R = class_column1 + self.class_column_L = class_column2 + self.nb_classes = nb_classes + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + + row = self.df.loc[idx] + img = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(self.mount_point, row[self.img_column]))) + + if self.transform: + img = self.transform(img) + + target_vector = np.zeros(self.nb_classes) + + #convert into str both columns + if not pd.isna(row[self.class_column_R]): + target_vector[int(row[self.class_column_R])] = 1 + if not pd.isna(row[self.class_column_L]): + target_vector[int(row[self.class_column_L])] = 1 + if not pd.isna(row[self.class_column_R]) and not pd.isna(row[self.class_column_L]): + target_vector[int(row[self.class_column_R])] = 0.5 + target_vector[int(row[self.class_column_L])] = 0.5 + + return img, torch.tensor(target_vector, dtype=torch.float) class DataModule(pl.LightningDataModule): def __init__(self, df_train,df_val,df_test,df_special, mount_point="./", batch_size=32, num_workers=4, img_column='img_path', class_column='Classification', train_transform=None, valid_transform=None, test_transform=None,special_transform=None,drop_last=False,seed=42): @@ -71,15 +102,14 @@ def _set_seed(self, seed): def setup(self, stage=None): # Assign train/val datasets for use in dataloaders self.train_ds = BasicDataset(self.df_train, mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, transform=self.train_transform) - # self.train_ds = SmartCacheDataset(base_train_ds, num_replace_workers=self.num_workers,replace_rate=0.3, cache_rate=1.0, cache_num=1) + if self.df_special is not None: self.special_ds = BasicDataset(self.df_special, mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, transform=self.special_transform) self.val_ds = BasicDataset(self.df_val, mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, transform=self.valid_transform) - # self.val_ds = SmartCacheDataset(base_val_ds,num_replace_workers=self.num_workers,replace_rate=0.3, cache_rate=1.0, cache_num=1) self.test_ds = BasicDataset(self.df_test, mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, transform=self.test_transform) - # self.test_ds = SmartCacheDataset(base_test_ds,num_replace_workers=self.num_workers,replace_rate=0.3, cache_rate=1.0, cache_num=1) + def train_dataloader(self): if self.df_special is not None: @@ -97,69 +127,53 @@ def test_dataloader(self): return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=self.drop_last) -class SegDataset(Dataset): - def __init__(self, df, mount_point = "./", img_column='img', seg_column='seg', class_column='Classification', transform=None): - self.df = df - self.mount_point = mount_point - self.transform = transform - self.img_column = img_column - self.class_column = class_column - self.seg_column = seg_column - - def 
__len__(self): - return len(self.df) - - def __getitem__(self, idx): - - row = self.df.loc[idx] - img = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(self.mount_point, row[self.img_column]))) - - seg = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(self.mount_point, row[self.seg_column]))) - - if self.transform: - obj = self.transform({"img": img, "seg": seg}) - img = obj["img"] - seg = obj["seg"] - - cl = row[self.class_column] - - return img, seg, torch.tensor(cl, dtype=torch.long) - - -class SegDataModule(pl.LightningDataModule): - def __init__(self, df_train, df_val,df_test, mount_point="./", batch_size=32, num_workers=4, img_column='img_path', class_column='Classification', seg_column='seg', train_transform=None, valid_transform=None,test_transform=None, drop_last=False): +class DataModuleT(pl.LightningDataModule): + def __init__(self, df_train,df_val,df_test,df_special, mount_point="./", batch_size=32, num_workers=4, img_column='img_path', class_column1= 'Right',class_column2='Left', nb_classes=2,train_transform=None, valid_transform=None, test_transform=None,special_transform=None,drop_last=False,seed=42): super().__init__() self.df_train = df_train self.df_val = df_val self.df_test = df_test + self.df_special = df_special self.mount_point = mount_point self.batch_size = batch_size self.num_workers = num_workers self.img_column = img_column - self.seg_column = seg_column - self.class_column = class_column + self.class_column1 = class_column1 + self.class_column2 = class_column2 + self.nb_classes = nb_classes self.train_transform = train_transform self.valid_transform = valid_transform self.test_transform = test_transform + self.special_transform = special_transform self.drop_last=drop_last + def _set_seed(self, seed): + torch.manual_seed(seed) + def setup(self, stage=None): # Assign train/val datasets for use in dataloaders - base_train_ds = SegDataset(self.df_train, mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, seg_column=self.seg_column, transform=self.train_transform) - self.train_ds = SmartCacheDataset(base_train_ds, num_replace_workers=self.num_workers,replace_rate=0.3, cache_rate=1.0, cache_num=1) + self.train_ds = Datasetarget(self.df_train, mount_point=self.mount_point, img_column=self.img_column, class_column1= self.class_column1, class_column2 = self.class_column2,nb_classes=self.nb_classes,transform=self.train_transform) - base_val_ds = SegDataset(self.df_val, mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, seg_column=self.seg_column, transform=self.valid_transform) - self.val_ds = SmartCacheDataset(base_val_ds, num_replace_workers=self.num_workers,replace_rate=0.3, cache_rate=1.0, cache_num=1) + if self.df_special is not None: + self.special_ds = Datasetarget(self.df_special, mount_point=self.mount_point, img_column=self.img_column, class_column1=self.class_column1,class_column2=self.class_column2,nb_classes=self.nb_classes, transform=self.special_transform) + + self.val_ds = Datasetarget(self.df_val, mount_point=self.mount_point, img_column=self.img_column, class_column1=self.class_column1,class_column2=self.class_column2, nb_classes=self.nb_classes, transform=self.valid_transform) + + self.test_ds = Datasetarget(self.df_test, mount_point=self.mount_point, img_column=self.img_column, class_column1 =self.class_column1, class_column2=self.class_column2,nb_classes=self.nb_classes, transform=self.test_transform) - base_test_ds =SegDataset(self.df_test, 
mount_point=self.mount_point, img_column=self.img_column, class_column=self.class_column, seg_column=self.seg_column, transform=self.test_transform) - self.test_ds = SmartCacheDataset(base_test_ds, num_replace_workers=self.num_workers,replace_rate=0.3, cache_rate=1.0, cache_num=1) def train_dataloader(self): - return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=self.drop_last, shuffle=True) + if self.df_special is not None: + #concatenate the special dataset with the training dataset + train_special_ds = torch.utils.data.ConcatDataset([self.train_ds, self.special_ds]) + return DataLoader(train_special_ds, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=self.drop_last, shuffle=True) + else: + return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=self.drop_last, shuffle=True) def val_dataloader(self): return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=self.drop_last) def test_dataloader(self): - return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=self.drop_last) \ No newline at end of file + return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=self.drop_last) + diff --git a/nets/classification.py b/nets/classification.py index 4540519..5024965 100644 --- a/nets/classification.py +++ b/nets/classification.py @@ -15,6 +15,7 @@ from monai.networks.nets import AutoEncoder from monai.networks.blocks import Convolution +from useful_readibility import printRed, printBlue,printGreen import pytorch_lightning as pl from pl_bolts.transforms.dataset_normalizations import ( @@ -83,6 +84,188 @@ def forward(self, input_seq): return output +class CustomAccuracy(nn.Module): + def __init__(self): + super(CustomAccuracy, self).__init__() + self.sigmoid = nn.Sigmoid() + + def forward(self,preds,targets): + preds_sig = self.sigmoid(preds) + + + demi_len = int(preds_sig.shape[1]/2) + for j in range(preds_sig.shape[0]): + + classR = None + classL = None + + best_probR,idxR = torch.max(preds_sig[j,:demi_len],dim=0) + best_probL,idxL = torch.max(preds_sig[j,demi_len:],dim=0) + + if best_probR.item() > 0.59: + classR = idxR + if best_probL.item() > 0.59: + classL = idxL + if classR is None and classL is None: + if best_probR > best_probL: + classR = idxR + else: + classL = idxL + #find class from target vector,where '1' or '0.5' is present + #if 1 is present + if 1 in targets[j,:demi_len]: + targetR = torch.where(targets[j,:demi_len] == 1)[0] + if 1 in targets[j,demi_len:]: + targetL = torch.where(targets[j,demi_len:] == 1)[0] + if 0.5 in targets[j,:]: + targetR = torch.where(targets[j,:demi_len] == 0.5)[0] + targetL = torch.where(targets[j,demi_len:] == 0.5)[0] + + + +class CustomLossTarget(torch.nn.Module): + def __init__(self, penalty_weight=0.1, class_weights=None): + super(CustomLossTarget, self).__init__() + self.base_loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights) + self.penalty_weight = penalty_weight + self.sigmoid = nn.Sigmoid() + + + + def forward(self, preds, targets): + # Compute the base loss + base_loss = self.base_loss_fn(preds,targets) + #create prediction_vector + preds_sig = self.sigmoid(preds) + preds_sigF = F.sigmoid(preds) + demi_len = int(preds_sig.shape[1]/2) + 
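+ # The output vector is split in half: with nb_classes = 6, indices 0-2 carry the
+ # right-canine (Label R) scores and indices 3-5 the left-canine (Label L) scores,
+ # mirroring the target encoding described in MultiPred (e.g. target [0, 1, 0, 0, 1, 0]
+ # means class 1 on the right and class 4 on the left), so demi_len marks the boundary
+ # between the two per-side probability blocks.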
prediction_vector = torch.zeros(preds_sig.shape[0],preds_sig.shape[1]) + + penalty_fp = 0 + penalty_fn = 0 + penalty_class = 0 + penalty_bonus = 0 + penalty_weight = self.penalty_weight + for j in range(preds_sig.shape[0]): + + noneR=0 + noneL=0 + + best_probR,idxR = torch.max(preds_sig[j,:demi_len],dim=0) + best_probL,idxL = torch.max(preds_sig[j,demi_len:],dim=0) + + if best_probR.item() > 0.65: + prediction_vector[j,idxR] = 1 + + if best_probL.item() > 0.65: + prediction_vector[j,idxL+demi_len] = 1 + + #check the sum of the vector, if it's 2, we need to divide by 2 each value, if it's 0, we need to put the best score to 1 + if prediction_vector[j].sum() > 1: + prediction_vector[j] = prediction_vector[j]/prediction_vector[j].sum() + elif prediction_vector[j].sum() == 0: + if best_probR > best_probL: + prediction_vector[j,idxR] = 1 + else: + prediction_vector[j,idxL+demi_len] = 1 + + # # # # If there is a prediction for a non-existing canine, penalize + # left_false_positive = True if (prediction_vector[j, :3].sum() > 0) & (targets[j, :3].sum() == 0) else False + # right_false_positive = True if (prediction_vector[j, 3:].sum() > 0) & (targets[j, 3:].sum() == 0) else False + + # If there is no prediction for an existing canine, penalize + left_false_negative = True if (prediction_vector[j, :3].sum() == 0) & (targets[:, :3].sum() > 0) else False + right_false_negative = True if (prediction_vector[:, 3:].sum() == 0) & (targets[:, 3:].sum() > 0) else False + + # if (left_false_positive or right_false_positive) or (left_false_positive and right_false_positive): + # penalty_fp +=1 + if left_false_negative or right_false_negative: + penalty_fn +=1 + + # # penalize if it predicts a class 2 (resp 5) when it's a class 0 (resp 3) and vice versa + # if (prediction_vector[j, 0] == 1 and targets[j, 2] == 1) or (prediction_vector[j, 2] == 1 and targets[j, 0] == 1): + # penalty_class+=0.5 + # if (prediction_vector[j, 3] == 1 and targets[j, 5] == 1) or (prediction_vector[j, 5] == 1 and targets[j, 3] == 1): + # penalty_class+=0.5 + + + + print('preds_sig',preds_sig) + printBlue(f'targets {targets}') + + penalty = penalty_weight * (penalty_fp + penalty_fn + penalty_class-penalty_bonus) + # penalty = penalty_weight * penalty_fn + printRed(f'penalty: {round(penalty,3)}') + + total_loss = base_loss + penalty + return total_loss + +class NetTarget(pl.LightningModule): + def __init__(self, args = None, class_weights=None, base_encoder="DenseNet",seed=42,num_classes=6): + super(NetTarget, self).__init__() + + self.save_hyperparameters() + self.args = args + self._set_seed(seed) + self.class_weights = class_weights + + if(class_weights is not None): + class_weights = torch.tensor(class_weights).to(torch.float32) + + self.loss = CustomLossTarget(penalty_weight=0.1,class_weights=class_weights) + self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=self.hparams.num_classes) + + if self.hparams.base_encoder == 'DenseNet': + self.model = monai.networks.nets.DenseNet(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'DenseNet169': + self.model = monai.networks.nets.DenseNet169(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'DenseNet201': + self.model = monai.networks.nets.DenseNet201(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'DenseNet264': + self.model = monai.networks.nets.DenseNet264(spatial_dims=3, 
in_channels=1,out_channels=self.hparams.num_classes) + + if self.hparams.base_encoder == 'SEResNet50': + self.model = monai.networks.nets.SEResNet50(spatial_dims=3, in_channels=1, num_classes=self.hparams.num_classes) + # elif self.hparams.base_encoder == 'ResNet': + # self.model = monai.networks.nets.ResNet(spatial_dims=3, n_input_channels=1, num_classes=self.hparams.num_classes) + elif self.hparams.base_encoder == 'resnet18' or base_encoder=='ResNet': + self.model = monai.networks.nets.resnet18(spatial_dims=3, n_input_channels=1, num_classes=self.hparams.num_classes) + if base_encoder == 'efficientnet-b0' or base_encoder == 'efficientnet-b1' or base_encoder == 'efficientnet-b2' or base_encoder == 'efficientnet-b3' or base_encoder == 'efficientnet-b4' or base_encoder == 'efficientnet-b5' or base_encoder == 'efficientnet-b6' or base_encoder == 'efficientnet-b7' or base_encoder == 'efficientnet-b8': + self.model = monai.networks.nets.EfficientNetBN(base_encoder, spatial_dims=3, in_channels=1, num_classes=self.hparams.num_classes) + + def _set_seed(self, seed): + torch.manual_seed(seed) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr) + return optimizer + + def forward(self, x): + ret =self.model(x) + return ret + + def training_step(self, train_batch, batch_idx): + x, y = train_batch + x = self(x) + + loss = self.loss(x, y) + self.log('train_loss', loss) + + self.accuracy(x, y) + self.log("train_acc", self.accuracy) + + return loss + + def validation_step(self, val_batch, batch_idx): + x, y = val_batch + x = self(x) + + loss = self.loss(x, y) + self.log('val_loss', loss, sync_dist=True) + + self.accuracy(x, y) + self.log("val_acc", self.accuracy) + class Net(pl.LightningModule): def __init__(self, args = None, class_weights=None, base_encoder="efficientnet-b0", seed = 42,num_classes=3): super(Net, self).__init__() @@ -96,11 +279,21 @@ def __init__(self, args = None, class_weights=None, base_encoder="efficientnet-b class_weights = torch.tensor(class_weights).to(torch.float32) self.loss = nn.CrossEntropyLoss(weight=class_weights) + self.bce_loss = nn.BCEWithLogitsLoss(weight=class_weights) + self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=self.hparams.num_classes) self.softmax = nn.Softmax(dim=1) + self.sigmoid = nn.Sigmoid() if self.hparams.base_encoder == 'DenseNet': self.model = monai.networks.nets.DenseNet(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'DenseNet169': + self.model = monai.networks.nets.DenseNet169(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'DenseNet201': + self.model = monai.networks.nets.DenseNet201(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'DenseNet264': + self.model = monai.networks.nets.DenseNet264(spatial_dims=3, in_channels=1,out_channels=self.hparams.num_classes) + if self.hparams.base_encoder == 'SEResNet50': self.model = monai.networks.nets.SEResNet50(spatial_dims=3, in_channels=1, num_classes=self.hparams.num_classes) # elif self.hparams.base_encoder == 'ResNet': @@ -127,10 +320,12 @@ def training_step(self, train_batch, batch_idx): x, y = train_batch x = self(x) - x= self.softmax(x) loss = self.loss(x, y) self.log('train_loss', loss) + bce_loss = self.bce_loss(x, y) + self.log('train_bce_loss', bce_loss) + self.accuracy(x, y) self.log("train_acc", self.accuracy) @@ -140,9 +335,10 @@ def validation_step(self, 
val_batch, batch_idx): x, y = val_batch x = self(x) - x= self.softmax(x) loss = self.loss(x, y) self.log('val_loss', loss, sync_dist=True) + bce_loss = self.bce_loss(x, y) + self.log('val_bce_loss', bce_loss, sync_dist=True) self.accuracy(x, y) self.log("val_acc", self.accuracy) diff --git a/test.py b/test.py index d170b80..952ecc2 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ import sklearn from sklearn.model_selection import StratifiedKFold +import torch.nn as nn import os import argparse @@ -49,10 +50,150 @@ def main(args): df_train.to_csv(os.path.join(args.csv.split('.')[0] + f'_train_{i}.csv'), index=False) df_test.to_csv(os.path.join(args.csv.split('.')[0] + f'_test_{i}.csv'), index=False) +def pred_vector(): + import torch + penalty_weight =0.1 + prediction_vector = [[0, 1, 0, 0, 1, 0],[1,0,0,0,0,0],[0,0,0,0,0,1]] + targets = [[0, 0, 0, 0, 1, 0],[1,0,0,0,0,0],[0,0.5,0,0,0,0.5]] + targets2= [[0,1,0,0,1,0],[1,0,0,0,0,0],[0,0,0,0,0,1]] + prediction_vector = torch.tensor(prediction_vector, dtype=torch.float32) + targets = torch.tensor(targets, dtype=torch.float32) + targets2 = torch.tensor(targets2, dtype=torch.float32) + + logits = torch.log((prediction_vector + 1e-9) / (1 - prediction_vector + 1e-9)) + print('LOGITS',logits) + + base_loss = nn.functional.binary_cross_entropy(prediction_vector, targets, reduction='none') + print('base_loss',base_loss) + base_loss_S = nn.functional.binary_cross_entropy(prediction_vector, targets, reduction='sum') + print('base_loss_S',base_loss_S) + base_loss_M = nn.functional.binary_cross_entropy(prediction_vector, targets, reduction='mean') + print('base_loss_M',base_loss_M) + + CE_loss = torch.nn.CrossEntropyLoss() + CE_loss = CE_loss(prediction_vector, targets.argmax(dim=1)) + print('CE_loss',CE_loss) + + True_CE_loss = torch.nn.CrossEntropyLoss() + True_CE_loss = True_CE_loss(prediction_vector, targets2.argmax(dim=1)) + print('True_CE_loss',True_CE_loss) + + loss_function = torch.nn.BCEWithLogitsLoss() + # Compute the loss + loss = loss_function(prediction_vector, targets) + print('loss',loss) + + loss_true = loss_function(logits, targets2) + print('loss_true',loss_true) + + print('----- Loss with probability multi-class -----') + prediction_vector = torch.tensor([[0.1, 0.9, 0.1, 0.1, 0.9, 0.1],[0.9,0.1,0.1,0.1,0.1,0.1],[0.1,0.1,0.1,0.1,0.1,0.9]], dtype=torch.float32) + bce_loss = torch.nn.BCELoss() + bce_loss_false = bce_loss(prediction_vector, targets) + bce_loss_true = bce_loss(prediction_vector, targets2) + print('bce_loss_false',bce_loss_false) + print('bce_loss_true',bce_loss_true) + + bce_logit_loss = torch.nn.BCEWithLogitsLoss() + bce_logit_loss_false = bce_logit_loss(prediction_vector, targets) + bce_logit_loss_true = bce_logit_loss(prediction_vector, targets2) + print('bce_logit_loss_false',bce_logit_loss_false) + print('bce_logit_loss_true',bce_logit_loss_true) + + # Calculate penalty for false positives on left or right + print('prediction_vector',prediction_vector.shape) + penalty_fp = 0 + penalty_fn = 0 + for j in range(prediction_vector.shape[0]): + print('prediction_vector[j, :3]',prediction_vector[j, :3]) + left_false_positive = True if (prediction_vector[j, :3].sum() > 0) & (targets[j, :3].sum() == 0) else False + print('fp',left_false_positive) + right_false_positive = True if (prediction_vector[j, 3:].sum() > 0) & (targets[j, 3:].sum() == 0) else False + + # Calculate penalty for false negatives on left or right + left_false_negative = True if (prediction_vector[j, :3].sum() == 0) & (targets[j, :3].sum() > 0) else False +
print('fn',left_false_negative) + right_false_negative = True if (prediction_vector[j, 3:].sum() == 0) & (targets[j, 3:].sum() > 0) else False + + if left_false_positive or right_false_positive: + penalty_fp +=1 + if left_false_negative or right_false_negative: + penalty_fn +=1 + + + penalty = penalty_weight * (penalty_fp + penalty_fn) + + print('penalty',penalty) + + total_loss = True_CE_loss + penalty + print('total',total_loss) + + +def test_vector(): + import torch + targets= [[0, 0, 0, 0, 1, 0],[1,0,0,0,0,0],[0,0.5,0,0,0,0.5]] + prediction_vector = [[0, 1, 0, 0, 1, 0],[1,0,0,0,0,0],[0,0.5,0,0.5,0,0]] + + #to tensor + targets = torch.tensor(targets, dtype=torch.float32) + prediction_vector = torch.tensor(prediction_vector, dtype=torch.float32) + penalty_bonus=0 + for j in range(prediction_vector.shape[0]): + best_probL = prediction_vector[j, :3].max() + best_probR = prediction_vector[j, 3:].max() + idx_non_zero_target = torch.where(targets[j] != 0) + idx_non_zero_pred = torch.where(prediction_vector[j] != 0) + # get the value from the tensor + # idx_non_zero_target = idx_non_zero_target[0].tolist() + # idx_non_zero_pred = dict(zip(idx_non_zero_pred[0].tolist(), idx_non_zero_pred[1].tolist())) + print('idx_non_zero_target shape',[tensor.size()[0] for tensor in idx_non_zero_target]) + print('idx_non_zero_target',[tensor[0] for tensor in idx_non_zero_target]) + print('idx_non_zero_pred',idx_non_zero_pred) + # if idx_non_zero_pred == idx_non_zero_target: + # if best_probR > 0.7: + # penalty_bonus += 1 + # if best_probL > 0.7: + # penalty_bonus += 1 + print('penalty_bonus',penalty_bonus) + + +def rocAUC(): + import torch + import sklearn.metrics as mt + from sklearn.metrics import roc_auc_score + from monai.metrics import compute_roc_auc + + preds_sig = [[3.9941e-01, 1.5503e-01, 2.3230e-01, 7.5146e-01, 9.2676e-01, 6.6345e-02], + [5.0812e-02, 1.1871e-02, 9.9756e-01, 4.8267e-01, 1.5479e-01, 2.6050e-01], + [1.1467e-02, 4.6015e-04, 8.5156e-01, 5.1807e-01, 9.7119e-01, 9.3945e-01], + [3.3283e-04, 1.0321e-01, 9.2334e-01, 6.2132e-04, 9.8975e-01, 9.8682e-01]] + targets =[[0.0000, 1, 0.0000, 0.0000,1, 0.0000], + [0.0000, 0.0000, 1, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1], + [0.0000, 0.0000, 0.0000, 0.0000, 0, 1]] + + target_label = [[2],[5],[4]] + + preds_soft =nn.functional.softmax(torch.tensor(preds_sig), dim=1) + print('preds_soft',preds_soft) + + preds_sig = torch.tensor(preds_sig, dtype=torch.float32) + targets = torch.tensor(targets, dtype=torch.float32) + target_label = torch.tensor(target_label, dtype=torch.float32) + + print('targets:',targets[:,2]) + print('preds_soft[:,2]:',preds_soft[:,2]) + + auc_monai = compute_roc_auc(preds_soft[:,2], targets[:,2]) + print('AUC monai',auc_monai) + auc = roc_auc_score(targets[:,2], preds_soft[:,2], average='macro',multi_class='ovr') + print('AUC',auc) + if __name__ == "__main__": parser = argparse.ArgumentParser(description='Dataset information') - parser.add_argument('--csv', required=True, type=str, help='CSV to count and rename classes') + parser.add_argument('--csv', required=False, type=str, help='CSV to count and rename classes') parser.add_argument('--class_column', type=str, default='Classification', help='Name of class column') args = parser.parse_args() - main(args) + + test_vector() diff --git a/transforms/volumetric.py b/transforms/volumetric.py index 35b9293..201d454 100644 --- a/transforms/volumetric.py +++ b/transforms/volumetric.py @@ -89,19 +89,21 @@ def __call__(self, inp): class TrainTransforms: def
__init__(self, size=256, pad=10): # image augmentation functions + calcul_rotate = math.pi/4 + print("Calcul Rotate", calcul_rotate) self.train_transform = Compose( [ EnsureChannelFirst(channel_dim='no_channel'), # RandFlip(prob=0.5), - RandRotate(prob=0.5, range_x=math.pi, range_y=math.pi, range_z=math.pi, mode="nearest", padding_mode='zeros'), + RandRotate(prob=0.5, range_x=calcul_rotate, range_y=calcul_rotate, range_z=calcul_rotate, mode="nearest", padding_mode='zeros'), SpatialPad(spatial_size=size + pad), RandSpatialCrop(roi_size=size, random_size=False), RandGaussianNoise(prob=0.5), - RandGaussianSmooth(prob=0.5), + # RandGaussianSmooth(prob=0.5), # ScaleIntensityRangePercentiles(2,99,0,1), ScaleIntensity(0,1), RandAdjustContrast(prob=0.5), - # NormalizeIntensity(), + # NormalizeIntensity(subtrahend=0,divisor=10), ToTensor(dtype=torch.float32, track_meta=False) ] )
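
As a quick sanity check of the updated training pipeline (rotation range limited to math.pi/4, i.e. roughly +/-45 degrees per axis, and Gaussian smoothing disabled), the snippet below applies the composed transform to a dummy volume. This is only a sketch: the import path transforms.volumetric, the direct use of the train_transform attribute, and the dummy volume size are assumptions based on the class definition above, not part of the change itself.

import numpy as np
from transforms.volumetric import TrainTransforms  # assumed import path

tr = TrainTransforms(size=128, pad=10)

# Dummy CBCT-like volume without a channel axis; EnsureChannelFirst(channel_dim='no_channel') adds it.
vol = np.random.rand(160, 160, 160).astype(np.float32)

out = tr.train_transform(vol)
print(out.shape)             # expected: torch.Size([1, 128, 128, 128]) after padding and random crop
print(out.min(), out.max())  # intensities rescaled to roughly [0, 1] by ScaleIntensity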