From 038050eb35941af1c872193fec0ea246ee273901 Mon Sep 17 00:00:00 2001
From: mmasoud1
Date: Tue, 19 Nov 2024 18:15:04 -0500
Subject: [PATCH] fixes

---
 app/code/aggregator/gradient_aggregator.py    |   7 +
 app/code/executor/loader.py                   |   8 +-
 app/code/executor/meshnet_executor.py         | 585 ++++++++++--------
 .../workflow/gradient_aggregation_workflow.py |   2 +-
 4 files changed, 331 insertions(+), 271 deletions(-)

diff --git a/app/code/aggregator/gradient_aggregator.py b/app/code/aggregator/gradient_aggregator.py
index cd162ac..ae50a11 100644
--- a/app/code/aggregator/gradient_aggregator.py
+++ b/app/code/aggregator/gradient_aggregator.py
@@ -50,6 +50,13 @@ def average_gradients(self, gradients_list):
         # Convert gradients to numpy arrays and perform averaging
         n = len(gradients_list)
 
+        # Ensure each gradient in gradients_list is in NumPy format
+        gradients_list = [
+            [np.array(grad) if isinstance(grad, list) else grad for grad in gradients]
+            for gradients in gradients_list
+        ]
+
+
         # Initialize Empty Arrays: sum_arrays is created as a list of arrays with the same shape as the gradients.
         # These arrays will accumulate the gradients from all clients.
         sum_arrays = [np.zeros_like(arr) for arr in gradients_list[0]]
diff --git a/app/code/executor/loader.py b/app/code/executor/loader.py
index b85bc46..103bf2a 100644
--- a/app/code/executor/loader.py
+++ b/app/code/executor/loader.py
@@ -75,9 +75,9 @@ def split_dataset(self):
         train_data, valid_data, infer_data = torch.utils.data.random_split(self, [train_size, valid_size, self.len - train_size - valid_size])
         return train_data, valid_data, infer_data
 
-    def get_loaders(self, batch_size=1, shuffle=True):
+    def get_loaders(self, batch_size=1, shuffle=True, num_workers=0):
         train_data, valid_data, infer_data = self.split_dataset()
-        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle)
-        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)
-        infer_loader = torch.utils.data.DataLoader(infer_data, batch_size=batch_size, shuffle=False)
+        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+        infer_loader = torch.utils.data.DataLoader(infer_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
         return train_loader, valid_loader, infer_loader
\ No newline at end of file
diff --git a/app/code/executor/meshnet_executor.py b/app/code/executor/meshnet_executor.py
index 39e05ed..b7c26e0 100644
--- a/app/code/executor/meshnet_executor.py
+++ b/app/code/executor/meshnet_executor.py
@@ -41,6 +41,12 @@ def __init__(self):
         self.model = enMesh_checkpoint(in_channels=1, n_classes=3, channels=5, config_file=config_file_path)
 
         # Check if GPU available
+        # GPU assignment
+        os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Set GPU ID explicitly
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+        # Set the environment variable PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to help PyTorch handle fragmented memory.
+        # This env var helps avoid CUDA out-of-memory failures (MeshNetExecutor: OutOfMemoryError: CUDA out of memory).
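+        # A minimal standalone sketch of this allocator setup (assumed usage, not
+        # project code): the allocator config is read when the first CUDA allocation
+        # happens, so both variables should be set before the GPU is first touched.
+        #
+        #   import os
+        #   os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+        #   os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+        #   import torch
+        #   x = torch.zeros(1, device="cuda")  # first allocation picks up the config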
+
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model.to(self.device)
@@ -63,7 +69,8 @@ def __init__(self):
         # Check if the initial weights file exists
         if os.path.exists(initial_weights_file_path):
             # If weights have already been saved, load them from the file
-            self.model.load_state_dict(torch.load(initial_weights_file_path))
+            # self.model.load_state_dict(torch.load(initial_weights_file_path))
+            self.model.load_state_dict(torch.load(initial_weights_file_path, weights_only=True))
             self.logger.log_message(f"Loaded initial weights from file for this site.")
         else:
             # Save initial weights to file (first site)
@@ -76,26 +83,37 @@ def __init__(self):
             param.requires_grad = True
 
         # Initialize the variable to store the previous learning rate (initially None)
-        self.learning_rate = self.previous_lr = 0.0001 #<<<<<<<<<<<<<<<<<<<<<<<<
+        self.learning_rate = 0.001 #<<<<<<<<<<<<<<<<<<<<<<<<
+
+        self.previous_lr = 0
 
         # Optimizer and criterion setup
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
-        self.criterion = torch.nn.CrossEntropyLoss()
+
+
+
+
+        # Criterion: CrossEntropy for loss calculation with class weights and label smoothing
+        class_weight = torch.FloatTensor([0.2, 1.0, 0.8]).to(self.device) # Adjust weights based on class balance
+        self.criterion = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=0.01)
+
         # Add learning rate scheduler (Overlooked for now) <<<<<<<<<<
         # This will reduce the learning rate by a factor of 0.1 if the validation loss does not improve for 5 epochs.
-        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.1, patience=5, verbose=True)
+        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.1, patience=5, verbose=True)
-
-
         self.current_epoch = 0
 
         # Epochs and aggregation interval
-        self.total_epochs = 50 # Set the total number of epochs #<<<<<<<<<<<<<<<<<<<<<<<<
+        self.total_epochs = 350 # Set the total number of epochs #<<<<<<<<<<<<<<<<<<<<<<<<
+
+        self.target_dice = 0.95 # training stops once this Dice score is reached
+        self.stop_training = False # Flag to control breaking out of the outer loop
+
         # self.aggregation_interval = 1 # Aggregation occurs every 5 epochs (you can modify this)
-        self.dice_threshold = 0.9 # Set the Dice score threshold
+        # self.dice_threshold = 0.9 # Set the Dice score threshold
 
         # Gradient accumulation (Overlooked for now) <<<<<<<<<<
         # self.gradient_accumulation_steps = 4
@@ -168,12 +186,26 @@ def execute(
             self.data_loader = Scanloader(db_file=db_file_path, label_type='GWlabels', num_cubes=1)
             self.trainloader, self.validloader, self.testloader = self.data_loader.get_loaders(batch_size=1)
 
+
+            # Log data loader details
+            self.logger.log_message(f"Data loader has {len(self.trainloader)} training batches")
+
             self.shape = 256
             self.current_iteration = 0
             self.train_size = len(self.trainloader)
             self.iterations = self.total_epochs * self.train_size
 
+            # Initialize the scheduler now that trainloader is available
+            self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
+                self.optimizer,
+                max_lr=0.01,
+                total_steps=self.iterations if self.iterations > 0 else 1 # Exact number of steps
+                # epochs=self.total_epochs,
+                # steps_per_epoch=len(self.trainloader) if len(self.trainloader) > 0 else 1
+            )
+
+
             # Set flag to True so the loader is only initialized once
             self.data_loader_initialized = True
             self.logger.log_message(f"Data loader initialized for site {self.site_name}")
@@ -194,128 +226,172 @@ def execute(
             self.apply_gradients(aggregated_gradients, fl_ctx)
         return Shareable()
 
-    # def train_and_get_gradients_old(self):
-    #     for epoch in range(self.total_epochs):
-    #         self.logger.log_message(f"Starting Epoch {epoch}/{self.total_epochs}, Aggregation Interval: {self.aggregation_interval}")
-    #         self.model.train()
-
-    #         # Initialize accumulators for the loss and gradients
-    #         total_loss = 0.0
-    #         gradient_accumulator = [torch.zeros_like(param).to(self.device) for param in self.model.parameters()]
-
-    #         # Training loop for one epoch (full pass through the dataset)
-    #         for batch_id, (image, label) in enumerate(self.trainloader):
-    #             image, label = image.to(self.device), label.to(self.device)
-    #             self.optimizer.zero_grad()
-    #             # Mixed precision and checkpointing
-    #             with torch.amp.autocast(device_type='cuda'):
-    #                 output = torch.utils.checkpoint.checkpoint(self.model, image, use_reentrant=False)
-    #                 label = label.squeeze(1)
-    #                 loss = self.criterion(output, label.long())
+    # def train_and_get_gradients_old(self, fl_ctx):
+    #     # Set the model to training mode, so dropout is activated
+    #     # and BatchNorm normalizes the input based on the statistics of the current mini-batch.
+    #     self.model.train()
+
+    #     # Initialize accumulators for the loss and gradients
+    #     total_loss = 0.0
+
+
+
+    #     # Training loop for one epoch (full pass through the dataset)
+
+    #     for batch_id, (image, label) in enumerate(self.trainloader):
+
+    #         # Moving data to the correct device (CPU or GPU)
+    #         image, label = image.to(self.device), label.to(self.device)
+
+    #         self.optimizer.zero_grad() # It resets/clears the gradients of all model parameters (i.e., weights).
+    #         # PyTorch by default adds new gradients to any existing gradients; to prevent this accumulation of gradients
+    #         # from previous iterations, manually set them to zero at the beginning of each new training iteration.
+
+
+    #         # Mixed precision and checkpointing
+    #         with torch.amp.autocast(device_type='cuda'):
+    #             # Forward passing input data through the model to get predictions, start training
+
+    #             # reshape the 3D image tensor into a 5D tensor with the shape [batch_size, channels, depth, height, width].
+    #             # -1: The batch size dimension is inferred.
+    #             # 1: This indicates a single channel (grayscale MRI image).
+    #             output = torch.utils.checkpoint.checkpoint(self.model, image.reshape(-1, 1, self.shape, self.shape, self.shape), use_reentrant=False)
+
+    #             labels = torch.squeeze(label) # Squeeze the label
+    #             # Training label shape: torch.Size([1, 256, 256, 256])
+    #             # Squeeze labels shape : [ 256, 256, 256]
+
+    #             labels = (labels * 2).round().long() # Multiply by 2, round the values, and cast to long
+
+
+    #             # Log the shapes and unique values of the image and label once
+    #             if self.log_image_label_shapes_once:
+    #                 # Log image and label shapes
+    #                 self.logger.log_message(f"Training image shape: {image.shape}")
+    #                 # Training image shape: torch.Size([1, 256, 256, 256])
+
+    #                 self.logger.log_message(f"Training label shape: {label.shape}")
+    #                 # Training label shape: torch.Size([1, 256, 256, 256])
+
+    #                 self.logger.log_message(f"Training output shape: {output.shape}")
+    #                 # Training output shape: torch.Size([1, 3, 256, 256, 256])
+
+    #                 # Log the unique values in both image and label
+    #                 unique_label = torch.unique(label)
+    #                 unique_labels = torch.unique(labels)
+    #                 self.logger.log_message(f"Unique values in training GT label: {unique_label.tolist()}")
+    #                 # Unique values in training label: [0.0, 0.5, 1.0] <<<<<<<<<<<<<<<<< Normalized by Pratyush
+
+    #                 self.logger.log_message(f"Unique values in training squeezed long scaled labels: {unique_labels.tolist()}")
+    #                 # Unique values in training label: [0, 1, 2]
+
+    #                 self.log_image_label_shapes_once = False # Set to False so this is only logged once
 
-    #             total_loss += loss.item()
+    #             # compute the loss between the predicted output and the ground truth labels
+    #             # The label tensor is reshaped to [batch_size, depth, height, width] to match the output shape of the model.
+    #             # .long() * 2: Double the label values. Looks like Pratyush made it 0, 0.5, 1 range.
 
-    #             # Scale loss and backward pass
-    #             self.scaler.scale(loss).backward()
+    #             # For CrossEntropyLoss, the input (predictions) should have
+    #             # shape [batch_size, num_classes, height, width, depth] (as the output does),
+    #             # while the target (label) should have shape [batch_size, height, width, depth] containing class indices as integers.
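+    #             # A tiny shape sanity check for that contract (hypothetical sizes,
+    #             # not the real 256^3 volumes):
+    #             #   logits  = torch.randn(1, 3, 8, 8, 8)          # [N, C, D, H, W]
+    #             #   targets = torch.randint(0, 3, (1, 8, 8, 8))   # [N, D, H, W], integer class ids
+    #             #   torch.nn.CrossEntropyLoss()(logits, targets)  # valid: no channel dim on targets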
-    #             # Accumulate gradients
-    #             for i, param in enumerate(self.model.parameters()):
-    #                 if param.grad is not None:
-    #                     gradient_accumulator[i] += param.grad.clone()
-    #             self.scaler.step(self.optimizer)
-    #             self.scaler.update()
-    #             torch.cuda.empty_cache()
+    #             loss = self.criterion(output, labels.reshape(-1, self.shape, self.shape, self.shape))
 
-    #         # Log the average loss per epoch
-    #         average_loss = total_loss / len(self.trainloader)
-    #         dice_score = self.calculate_dice(self.trainloader)
-    #         self.logger.log_message(f"Site {self.site_name} - Epoch {epoch}: Loss = {average_loss}, Dice = {dice_score}")
 
-    #         # Call aggregation based on your set aggregation_interval
-    #         if (epoch + 1) % self.aggregation_interval == 0:
-    #             # Perform model aggregation here
-    #             return [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None]
+    #             # Accumulate loss
+    #             total_loss += loss.item()
 
-    #     return []
+    #             # Scale loss and backward pass
+    #             # self.scaler.scale(loss).backward()
 
-    # def train_and_get_gradients_new(self):
-    #     for epoch in range(self.total_epochs):
-    #         # self.logger.log_message(f"Starting Epoch {epoch+1}/{self.total_epochs}, Aggregation Interval: {self.aggregation_interval}")
-    #         self.model.train()
+    #             # Scale loss and backward pass (Backpropagation)
+    #             # Backward pass (gradient calculation): Calculate the gradients for all the model's parameters with respect to the loss:
+    #             loss.backward() # calculate gradients
 
-    #         # Initialize accumulators for the loss and gradients
-    #         total_loss = 0.0
-    #         gradient_accumulator = [torch.zeros_like(param).to(self.device) for param in self.model.parameters()]
+    #             self.optimizer.step() # gradients are applied to the model parameters
+    #             # Training done for that round.
 
-    #         # Training loop for one epoch (full pass through the dataset)
-    #         for batch_id, (image, label) in enumerate(self.trainloader):
-    #             image, label = image.to(self.device), label.to(self.device)
-    #             self.optimizer.zero_grad()
+    #             # Update the learning rate
+    #             self.scheduler.step()
 
-    #             # Mixed precision and checkpointing
-    #             with torch.amp.autocast(device_type='cuda'):
-    #                 output = torch.utils.checkpoint.checkpoint(self.model, image, use_reentrant=False)
-    #                 label = label.squeeze(1)
-    #                 loss = self.criterion(output, label.long())
+    #             # Get the current learning rate
+    #             current_lr = self.optimizer.param_groups[0]['lr']
 
-    #             total_loss += loss.item()
+    #             # Check if the learning rate has changed, and log it if so
+    #             if current_lr != self.previous_lr:
+    #                 self.logger.log_message(f"Learning rate changed from {self.previous_lr} to: {current_lr}")
+    #                 self.previous_lr = current_lr # Update the previous learning rate
 
-    #             # Scale loss and backward pass
-    #             self.scaler.scale(loss).backward()
 
-    #             # Accumulate gradients
-    #             for i, param in enumerate(self.model.parameters()):
-    #                 if param.grad is not None:
-    #                     gradient_accumulator[i] += param.grad.clone()
+    #             # # Accumulate gradients without updating yet
+    #             # if (batch_id + 1) % self.gradient_accumulation_steps == 0:
+    #             #     # Update optimizer
+    #             #     self.scaler.step(self.optimizer)
+    #             #     self.scaler.update()
+    #             #     self.optimizer.zero_grad()
 
-    #             self.scaler.step(self.optimizer)
-    #             self.scaler.update()
-    #             torch.cuda.empty_cache()
+    #             # Clear GPU cache (No need)
+    #             # torch.cuda.empty_cache()
 
-    #         # Log the average loss per epoch
-    #         average_loss = total_loss / len(self.trainloader)
-    #         dice_score = self.calculate_dice(self.trainloader)
-    #         self.logger.log_message(f"Site {self.site_name} - Epoch {epoch+1}: Loss = {average_loss}, Dice = {dice_score}")
{dice_score}") + # # Log the average loss and Dice score per epoch + # average_loss = total_loss / len(self.trainloader) - # # Check if it's time to perform aggregation - # if (epoch + 1) % self.aggregation_interval == 0: - # # Return the gradients after completing the specified aggregation interval - # self.logger.log_message(f"Performing aggregation after epoch {epoch+1}") - # return [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None] + # # dice_score = self.calculate_dice(self.trainloader) + # # Calculate Dice score on the validation set + # # self.model.eval() # Set the model to evaluation mode for validation + # dice_score = self.calculate_dice(self.validloader, fl_ctx) # Use validation set + # self.logger.log_message(f"{self.site_name} - Epoch {self.current_epoch}: Loss = {average_loss}, Val Dice = {dice_score}") - # return [] - # def setup_site_logger(self): - # site_id = os.getenv('FL_SITE_ID', 'site_unknown') # Use environment variable or other means to set site ID - # log_dir = f'logs/{site_id}' - # os.makedirs(log_dir, exist_ok=True) - # log_filename = os.path.join(log_dir, f'training_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log') + # # Check for early stopping <<<<<<<<<<<<<<<<<< (Overlooked for now) + # # if average_loss < self.best_loss: + # # self.best_loss = average_loss + # # self.epochs_without_improvement = 0 + # # else: + # # self.epochs_without_improvement += 1 - # logging.basicConfig( - # filename=log_filename, - # level=logging.INFO, - # format='%(asctime)s - %(levelname)s - %(message)s', - # datefmt='%Y-%m-%d %H:%M:%S', - # ) + # # if self.epochs_without_improvement >= self.early_stopping_patience: + # # self.logger.log_message(f"Early stopping triggered at epoch {self.current_epoch}") + # # return [] + + # self.logger.log_message(f"{self.site_name} Preparing payload after an iteration in epoch {self.current_epoch}") + # # return [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None] + + # # Accumulate gradients + # gradients = [] + # for i, param in enumerate(self.model.parameters()): + # if param.grad is not None: + # gradients.append(param.grad.clone().cpu().numpy()) + - # self.logger = logging.getLogger() - # self.logger.info("Logging started") + # # for example : + # # gradients = [ + # # array([[ 0.01, -0.02], [ 0.03, 0.04]]), # gradient for some weight matrix (e.g shape [2, 2]) + # # array([0.005, -0.015]), # gradient for some bias vector (e.g shape [2]) + # # array([[ 0.02, 0.01], [-0.03, 0.05]]) # Another gradient for a different weight matrix and so on. 
+    #         # ]
 
-    # def load_model_weights(self, model_weights):
-    #     # Load the received model weights into the local model
-    #     # (((( FOR FUTUR USE))))
-    #     self.model.load_state_dict(model_weights)
-    #     self.logger.log_message(f"Loaded initial model weights for site {self.site_name}")
+    #         return gradients
 
+    # Define functions to save and load checkpoints
+    @staticmethod # defined without self, so mark it static for a coherent class-level helper
+    def save_checkpoint(model, optimizer, scheduler, epoch, filename="checkpoint.pth"):
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+        }, filename)
+        print(f"Checkpoint saved at epoch {epoch}")
+
     def train_and_get_gradients(self, fl_ctx):
         # Set the model to training mode, so dropout is activated
@@ -324,9 +400,9 @@ def train_and_get_gradients(self, fl_ctx):
 
         # Initialize accumulators for the loss and gradients
         total_loss = 0.0
+        total_train_dice = 0.0
 
-
         # Training loop for one epoch (full pass through the dataset)
 
         for batch_id, (image, label) in enumerate(self.trainloader):
@@ -339,117 +415,152 @@ def train_and_get_gradients(self, fl_ctx):
             # from previous iterations, manually set them to zero at the beginning of each new training iteration.
 
-            # Mixed precision and checkpointing
-            with torch.amp.autocast(device_type='cuda'):
-                # Forward passing input data through the model to get predictions, start training
-
-                # reshape the 3D image tensor into a 5D tensor with the shape [batch_size, channels, depth, height, width].
-                # -1: The batch size dimension is inferred.
-                # 1: This indicates a single channel (grayscale MRI image).
-                output = torch.utils.checkpoint.checkpoint(self.model, image.reshape(-1, 1, self.shape, self.shape, self.shape), use_reentrant=False)
-                labels = torch.squeeze(label) # Squeeze the label
-                # Training label shape: torch.Size([1, 256, 256, 256])
-                # Squeeze labels shape : [ 256, 256, 256]
+            # Forward passing input data through the model to get predictions, start training
+
+            # reshape the 3D image tensor into a 5D tensor with the shape [batch_size, channels, depth, height, width].
+            # -1: The batch size dimension is inferred.
+            # 1: This indicates a single channel (grayscale MRI image).
+            # output = torch.utils.checkpoint.checkpoint(self.model, image.reshape(-1, 1, self.shape, self.shape, self.shape), use_reentrant=False)
 
-                labels = (labels * 2).round().long() # Multiply by 2, round the values, and cast to long
+            output = self.model(image.reshape(-1, 1, self.shape, self.shape, self.shape))
 
+            labels = torch.squeeze(label) # Squeeze the label
+            # Training label shape: torch.Size([1, 256, 256, 256])
+            # Squeeze labels shape : [ 256, 256, 256]
 
-                # Log the shapes and unique values of the image and label once
-                if self.log_image_label_shapes_once:
-                    # Log image and label shapes
-                    self.logger.log_message(f"Training image shape: {image.shape}")
-                    # Training image shape: torch.Size([1, 256, 256, 256])
+            labels = (labels * 2).round().long() # Multiply by 2, round the values, and cast to long
 
-                    self.logger.log_message(f"Training label shape: {label.shape}")
-                    # Training label shape: torch.Size([1, 256, 256, 256])
 
-                    self.logger.log_message(f"Training output shape: {output.shape}")
-                    # Training output shape: torch.Size([1, 3, 256, 256, 256])
-
-                    # Log the unique values in both image and label
-                    unique_label = torch.unique(label)
-                    unique_labels = torch.unique(labels)
-                    self.logger.log_message(f"Unique values in training GT label: {unique_label.tolist()}")
-                    # Unique values in training label: [0.0, 0.5, 1.0] <<<<<<<<<<<<<<<<< Normalized by Pratyush
-
-                    self.logger.log_message(f"Unique values in training sequeezed long scaled labels: {unique_labels.tolist()}")
-                    # Unique values in training label: [0, 1, 2]
+            # Log the shapes and unique values of the image and label once
+            if self.log_image_label_shapes_once:
+                # Log image and label shapes
+                self.logger.log_message(f"Training image shape: {image.shape}")
+                # Training image shape: torch.Size([1, 256, 256, 256])
 
-                    self.log_image_label_shapes_once = False # Set to False so this is only logged once
+                self.logger.log_message(f"Training label shape: {label.shape}")
+                # Training label shape: torch.Size([1, 256, 256, 256])
 
-                # compute the loss between the predicted output and the ground truth labels
-                # The label tensor is reshaped to [batch_size, depth, height, width] to match the output shape of the model.
-                # .long() * 2: Double the label values. Looks like Pratyush made it 0, 0.5, 1 range.
+                self.logger.log_message(f"Training output shape: {output.shape}")
+                # Training output shape: torch.Size([1, 3, 256, 256, 256])
+
+                # Log the unique values in both image and label
+                unique_label = torch.unique(label)
+                unique_labels = torch.unique(labels)
+                self.logger.log_message(f"Unique values in training GT label: {unique_label.tolist()}")
+                # Unique values in training label: [0.0, 0.5, 1.0] <<<<<<<<<<<<<<<<< Normalized by Pratyush
+
+                self.logger.log_message(f"Unique values in training squeezed long scaled labels: {unique_labels.tolist()}")
+                # Unique values in training label: [0, 1, 2]
 
-                # For CrossEntropyLoss, the input (predictions) should have
-                # shape [batch_size, num_classes, height, width, depth] (as the output does),
-                # while the target (label) should have shape [batch_size, height, width, depth] containing class indices as integers.
+                self.log_image_label_shapes_once = False # Set to False so this is only logged once
 
+            # compute the loss between the predicted output and the ground truth labels
+            # The label tensor is reshaped to [batch_size, depth, height, width] to match the output shape of the model.
+            # .long() * 2: Double the label values. Looks like Pratyush made it 0, 0.5, 1 range.
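+            # e.g. (assumed values): torch.tensor([0.0, 0.5, 1.0]) * 2 -> round() -> long()
+            # yields tensor([0, 1, 2]), matching the unique label values logged above.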
+            # For CrossEntropyLoss, the input (predictions) should have
+            # shape [batch_size, num_classes, height, width, depth] (as the output does),
+            # while the target (label) should have shape [batch_size, height, width, depth] containing class indices as integers.
 
-            loss = self.criterion(output, labels.reshape(-1, self.shape, self.shape, self.shape))
+            loss = self.criterion(output, labels.reshape(-1, self.shape, self.shape, self.shape))
+            train_dice = torch.mean(faster_dice(torch.squeeze(torch.argmax(output, dim=1)).long(), labels, labels=[0, 1, 2]))
 
+            # Scale loss and backward pass (Backpropagation)
+            # Backward pass (gradient calculation): Calculate the gradients for all the model's parameters with respect to the loss:
+            loss.backward() # calculate gradients
+            self.optimizer.step() # gradients are applied to the model parameters
+            # Training done for that round.
 
+            # Accumulate loss
             total_loss += loss.item()
 
-            # Scale loss and backward pass
-            # self.scaler.scale(loss).backward()
+            # Accumulate training dice
+            total_train_dice += train_dice.item()
 
-            # Scale loss and backward pass (Backpropagation)
-            # Backward pass (gradient calculation): Calculate the gradients for all the model's parameters with respect to the loss:
-            loss.backward() # calculate gradients
+            if self.scheduler._step_count < self.scheduler.total_steps:
+                # Update the learning rate
+                self.scheduler.step()
+
+        # End of one epoch of training
+
+
+        # Validation phase
+        self.model.eval()
+        val_loss = 0.0
+        dice_scores = []
+        with torch.no_grad():
+            for input, label in self.validloader:
+                input, label = input.to(self.device), label.to(self.device)
+                output = self.model(input.reshape(-1, 1, self.shape, self.shape, self.shape))
+
+                label = torch.squeeze(label)
+                label = (label * 2).round().long()
+
+                loss = self.criterion(output, label.reshape(-1, self.shape, self.shape, self.shape))
+                val_loss += loss.item()
+
+                # Calculate Dice score
+                pred = torch.squeeze(torch.argmax(output, dim=1))
+                dice = torch.mean(faster_dice(pred, label, labels=[0, 1, 2]))
+                dice_scores.append(dice)
+
+        # Average Dice score
+        avg_train_dice = total_train_dice / len(self.trainloader)
+        avg_dice_score = sum(dice_scores) / len(dice_scores) if dice_scores else 0
+
+        self.logger.log_message(f"{self.site_name} - Epoch [{self.current_epoch+1}/{self.total_epochs}], Train Loss: {total_loss/len(self.trainloader):.4f}, "
+                                f"Val Loss: {val_loss/len(self.validloader):.4f}, Train Dice: {avg_train_dice:.4f}, Val Dice: {avg_dice_score:.4f}, lr: {self.optimizer.param_groups[0]['lr']:.6f}")
+
+        # Save checkpoint every few epochs
+        # if (epoch + 1) % save_every == 0 or (epoch + 1) == num_epochs:
+        # if (self.current_epoch + 1) == self.total_epochs:
+        #     save_checkpoint(self.model, self.optimizer, self.scheduler, self.current_epoch + 1)
+        #     self.logger.log_message(f"{self.site_name} - All training epochs finished and model saved. Stopping training.")
+
+        # # Stop training if target Dice score is reached
+        # if avg_dice_score >= self.target_dice:
+        #     self.stop_training = True
+        #     save_checkpoint(self.model, self.optimizer, self.scheduler, self.current_epoch + 1) # Save the final checkpoint
+        #     self.logger.log_message(f"{self.site_name} - Target Dice Score {self.target_dice} reached and model saved. ")
+
 
-            self.optimizer.step() # gradients are applied to the model parameters
-            # Training done for that round.
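+        # For reference, a minimal per-class Dice sketch equivalent in spirit to the
+        # faster_dice calls above (hypothetical helper, not the imported implementation;
+        # the removed calculate_dice_score below computed the same quantity):
+        #
+        #   def dice_per_class(pred, target, classes=(0, 1, 2), eps=1e-6):
+        #       scores = []
+        #       for c in classes:
+        #           p, t = (pred == c).float(), (target == c).float()
+        #           scores.append(2 * (p * t).sum() / (p.sum() + t.sum() + eps))
+        #       return torch.stack(scores)  # torch.mean(...) of this mirrors the usage above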
+            # # # Accumulate gradients without updating yet
+            # # if (batch_id + 1) % self.gradient_accumulation_steps == 0:
+            # #     # Update optimizer
+            # #     self.scaler.step(self.optimizer)
+            # #     self.scaler.update()
+            # #     self.optimizer.zero_grad()
 
-            # Get the current learning rate
-            current_lr = self.optimizer.param_groups[0]['lr']
+            # # Get the current learning rate
+            # current_lr = self.optimizer.param_groups[0]['lr']
 
-            # Check if the learning rate has changed, and log it if so
-            if current_lr != self.previous_lr:
-                self.logger.log_message(f"Learning rate changed from {self.previous_lr} to: {current_lr}")
-                previous_lr = current_lr # Update the previous learning rate
+            # # Check if the learning rate has changed, and log it if so
+            # if current_lr != self.previous_lr:
+            #     self.logger.log_message(f"Per epoch, learning rate changed from {self.previous_lr} to: {current_lr}")
+            #     self.previous_lr = current_lr # Update the previous learning rate
 
-            # # Accumulate gradients without updating yet
-            # if (batch_id + 1) % self.gradient_accumulation_steps == 0:
-            #     # Update optimizer
-            #     self.scaler.step(self.optimizer)
-            #     self.scaler.update()
-            #     self.optimizer.zero_grad()
+        # # Log the average loss and Dice score per epoch
+        # average_loss = total_loss / len(self.trainloader)
 
-            # Clear GPU cache (No need)
-            # torch.cuda.empty_cache()
+        # avg_train_dice = total_train_dice / len(self.trainloader)
 
-        # Log the average loss and Dice score per epoch
-        average_loss = total_loss / len(self.trainloader)
+        # # dice_score = self.calculate_dice(self.trainloader)
+        # # Calculate Dice score on the validation set
+        # # self.model.eval() # Set the model to evaluation mode for validation
+        # val_dice_score = self.calculate_dice(self.validloader, fl_ctx) # Use validation set
+        # self.logger.log_message(f"{self.site_name} - Epoch {self.current_epoch}: Loss = {average_loss}, Avg Train Dice = {avg_train_dice}, Val Dice = {val_dice_score}")
 
-        # dice_score = self.calculate_dice(self.trainloader)
-        # Calculate Dice score on the validation set
-        self.model.eval() # Set the model to evaluation mode for validation
-        dice_score = self.calculate_dice(self.validloader, fl_ctx) # Use validation set
-        self.logger.log_message(f"{self.site_name} - Epoch {self.current_epoch}: Loss = {average_loss}, Val Dice = {dice_score}")
 
-        # Update the learning rate based on the loss <<<<<<<<<<<<<< (Overlooked for now)
-        self.scheduler.step(average_loss)
 
-        # Check for early stopping <<<<<<<<<<<<<<<<<< (Overlooked for now)
-        # if average_loss < self.best_loss:
-        #     self.best_loss = average_loss
-        #     self.epochs_without_improvement = 0
-        # else:
-        #     self.epochs_without_improvement += 1
 
-        # if self.epochs_without_improvement >= self.early_stopping_patience:
-        #     self.logger.log_message(f"Early stopping triggered at epoch {self.current_epoch}")
-        #     return []
 
-        self.logger.log_message(f"{self.site_name} Preparing payload after an iteration in epoch {self.current_epoch}")
+        self.logger.log_message(f"{self.site_name} Preparing payload after epoch {self.current_epoch}")
         # return [grad.clone().cpu().numpy() for grad in gradient_accumulator if grad is not None]
 
         # Accumulate gradients
@@ -457,6 +568,13 @@ def train_and_get_gradients(self, fl_ctx):
         for i, param in enumerate(self.model.parameters()):
             if param.grad is not None:
                 gradients.append(param.grad.clone().cpu().numpy())
+
+
+        # Converting CUDA tensors to NumPy arrays directly isn't allowed. We need to move tensors to the CPU before calling .numpy().
+        # local_gradients = [param.grad.clone().cpu() for param in self.model.parameters()] # Move gradients to CPU
+        # numpy_arrays = [tensor.numpy() for tensor in local_gradients]
+        # gradients = [array.tolist() for array in numpy_arrays]
+
 
         # for example :
@@ -468,14 +586,14 @@ def train_and_get_gradients(self, fl_ctx):
 
 
 
-        return gradients
+        return gradients
 
 
     # New one with fast dice
     def calculate_dice(self, loader, fl_ctx):
         dice_total = 0.0
-
+
         for image, label in loader:
             image, label = image.to(self.device), label.to(self.device)
             with torch.inference_mode():
@@ -573,89 +691,6 @@ def calculate_dice(self, loader, fl_ctx):
 
 
 
-    # def calculate_dice_old(self, loader):
-    #     dice_total = 0.0
-
-    #     for image, label in loader:
-    #         image, label = image.to(self.device), label.to(self.device)
-    #         with torch.no_grad():
-    #             # Ensure consistency by reshaping image and label similarly as in the training loop
-    #             output = self.model(image.reshape(-1, 1, self.shape, self.shape, self.shape)) # Model expects this reshaped
-    #             # output shape : [1, 3, 256, 256, 256]
-
-    #             output_label = torch.argmax(output, dim=1)
-    #             # output_label shape: [1, 256, 256, 256]
-    #             # Max voxel value in output_label : 2
-
-
-    #             if self.log_shapes_once:
-
-    #                 self.logger.log_message(f"loaded image tensor shape: {image.shape}")
-    #                 self.logger.log_message(f"Max voxel value in loaded image: {image.max().item()}")
-
-    #                 self.logger.log_message(f"Model output tensor shape: {output.shape}")
-    #                 # output shape : [1, 3, 256, 256, 256]
-
-    #                 # Log the shape after applying argmax only once
-    #                 self.logger.log_message(f"Output label shape after argmax: {output_label.shape}")
-    #                 # output_label shape: [1, 256, 256, 256]
-
-    #                 # Log max voxel value
-    #                 self.logger.log_message(f"Max voxel value in output_label: {output_label.max().item()}")
-    #                 # Max voxel value in output_label : 2
-
-    #                 # Log the GT label shape
-    #                 self.logger.log_message(f" GT Label shape : {label.shape}")
-    #                 # GT label shape: [1, 256, 256, 256]
-
-
-    #                 # Log max voxel value in GT label
-    #                 self.logger.log_message(f"Max voxel value in GT label: {label.max().item()}")
-    #                 # Max voxel value in GT label: 1.0
-
-    #                 # Log the unique values in both output_label and label
-    #                 unique_output = torch.unique(output_label)
-    #                 unique_label = torch.unique(label)
-    #                 self.logger.log_message(f"Unique values in output_label: {unique_output.tolist()}")
-    #                 self.logger.log_message(f"Unique values in GT label: {unique_label.tolist()}")
-
-
-    #                 self.log_shapes_once = False
-
-
-    #             # Applying the same transformation to the labels as in training
-    #             # dice_score = self.calculate_dice_score(output_label, label.reshape(-1, self.shape, self.shape, self.shape) * 2)
-
-    #             # Multiply the label by 2 to restore original class values (0, 1, 2) before calculating Dice score
-    #             reshaped_label = label.reshape(self.shape, self.shape, self.shape).long() * 2
-
-    #             dice_score = self.calculate_dice_score(torch.squeeze(output_label), reshaped_label)
-    #             dice_total += dice_score
-
-    #     return dice_total / len(loader)
-
-
-
-
-
-
-
-    # def calculate_dice_score(self, pred, target, num_classes=3):
-    #     dice_scores = []
-
-    #     for class_idx in range(num_classes):
-    #         pred_class = (pred == class_idx).float()
-    #         target_class = (target == class_idx).float()
-
-    #         intersection = (pred_class * target_class).sum()
-    #         union = pred_class.sum() + target_class.sum()
-
-    #         # To avoid division by zero, add a small epsilon to the denominator
-    #         dice_score = (2.0 * intersection) / (union + 1e-6)
-    #         dice_scores.append(dice_score.item())
-
-    #     # Return the mean Dice score across all classes
-    #     return sum(dice_scores) / len(dice_scores)
@@ -667,12 +702,28 @@ def apply_gradients(self, aggregated_gradients, fl_ctx):
         # self.optimizer.step()
 
         # Apply aggregated gradients to the model parameters
-        self.optimizer.zero_grad()
+
+        aggregated_gradients = [np.array(array) for array in aggregated_gradients]
+
+        # self.optimizer.zero_grad()
 
         # The loop for param, grad in zip(self.model.parameters(), aggregated_gradients) is iterating
         # through the model's parameters and the aggregated gradients.
-        for param, grad in zip(self.model.parameters(), aggregated_gradients):
-            param.grad = torch.tensor(grad).to(self.device) # manually setting the .grad attribute of each model parameter with the aggregated gradient.
+        # for param, grad in zip(self.model.parameters(), aggregated_gradients):
+        #     if grad is not None:
+        #         param.grad = torch.tensor(grad).to(self.device) # manually setting the .grad attribute of each model parameter with the aggregated gradient.
+
+
+        # self.optimizer.zero_grad() # ensure that any previously accumulated gradients are cleared before applying new ones
+
+        # Apply each aggregated gradient to the corresponding model parameter
+        for param, avg_grad in zip(self.model.parameters(), aggregated_gradients):
+            if param.requires_grad:
+                avg_grad = torch.tensor(avg_grad).to(param.device)
+                avg_grad = avg_grad.to(param.dtype) # match the parameter's dtype (param.grad may still be None here)
+                param.grad = avg_grad
+
+
+        # Update model parameters based on the applied gradients
         self.optimizer.step()
 
         # Clear GPU memory cache after applying gradients
@@ -704,3 +755,5 @@ def apply_gradients(self, aggregated_gradients, fl_ctx):
 
         # Increment the epoch counter after processing
         self.current_epoch += 1
+        # self.current_iteration += 1
+
diff --git a/app/code/workflow/gradient_aggregation_workflow.py b/app/code/workflow/gradient_aggregation_workflow.py
index b220703..4672f57 100644
--- a/app/code/workflow/gradient_aggregation_workflow.py
+++ b/app/code/workflow/gradient_aggregation_workflow.py
@@ -8,7 +8,7 @@ def __init__(
         self,
         aggregator_id="gradient_aggregator",
         min_clients: int = 2,
-        num_rounds: int = 50, # <<<<<<<<<<<
+        num_rounds: int = 350, # <<<<<<<<<<<
         start_round: int = 0,
         wait_time_after_min_received: int = 10,
         train_timeout: int = 0,
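
Taken together, the aggregation path in this patch is elementwise averaging on the
server (average_gradients) followed by manual gradient assignment plus an optimizer
step on each client (apply_gradients). A minimal self-contained sketch of that round
trip, with a toy model and made-up payloads standing in for the NVFlare wiring:

    import numpy as np
    import torch

    # Toy stand-ins for the MeshNet model and two client submissions (assumed shapes).
    model = torch.nn.Linear(4, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    client_a = [np.ones(p.shape, dtype=np.float32) for p in model.parameters()]
    client_b = [(2 * np.ones(p.shape, dtype=np.float32)).tolist() for p in model.parameters()]

    # Server side: coerce JSON-decoded lists back to arrays, then average elementwise.
    submissions = [[np.asarray(g) for g in grads] for grads in (client_a, client_b)]
    averaged = [sum(stack) / len(submissions) for stack in zip(*submissions)]

    # Client side: assign the averaged gradients and apply them with the optimizer.
    for param, avg_grad in zip(model.parameters(), averaged):
        if param.requires_grad:
            param.grad = torch.as_tensor(avg_grad, dtype=param.dtype, device=param.device)
    optimizer.step()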