From cd9275af952748b8c3d6761bca8842135f6aa89d Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 23 Nov 2023 13:46:41 -0500 Subject: [PATCH 01/14] add ae - finta --- dwi_ml/data/hdf5/hdf5_creation.py | 23 +-- .../streamlines/data_augmentation.py | 2 +- .../processing/streamlines/post_processing.py | 2 +- dwi_ml/models/main_models.py | 7 +- dwi_ml/models/projects/ae_models.py | 182 ++++++++++++++++++ dwi_ml/training/batch_loaders.py | 7 +- dwi_ml/training/trainers.py | 55 +++++- scripts_python/ae_train_model.py | 148 ++++++++++++++ 8 files changed, 406 insertions(+), 20 deletions(-) create mode 100644 dwi_ml/models/projects/ae_models.py create mode 100755 scripts_python/ae_train_model.py diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index e46e534b..a01194a3 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -593,7 +593,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, def _process_one_streamline_group( self, subj_dir: Path, group: str, subj_id: str, - header: nib.Nifti1Header): + header: nib.Nifti1Header, remove_invalid=False): """ Loads and processes a group of tractograms and merges all streamlines together. @@ -681,12 +681,13 @@ def _process_one_streamline_group( # with Path. save_tractogram(final_sft, str(output_fname)) - # Removing invalid streamlines - logging.debug(' *Total: {:,.0f} streamlines. Now removing ' - 'invalid streamlines.'.format(len(final_sft))) - final_sft.remove_invalid_streamlines() - logging.info(" Final number of streamlines: {:,.0f}." - .format(len(final_sft))) + if remove_invalid: + # Removing invalid streamlines + logging.debug(' *Total: {:,.0f} streamlines. Now removing ' + 'invalid streamlines.'.format(len(final_sft))) + final_sft.remove_invalid_streamlines() + logging.info(" Final number of streamlines: {:,.0f}." + .format(len(final_sft))) conn_matrix = None conn_info = None @@ -741,10 +742,10 @@ def _load_and_process_sft(self, tractogram_file, tractogram_name, header): "We do not support file's type: {}. We only support .trk " "and .tck files.".format(tractogram_file)) if file_extension == '.trk': - if not is_header_compatible(str(tractogram_file), header): - raise ValueError("Streamlines group is not compatible with " - "volume groups\n ({})" - .format(tractogram_file)) + if header: + if not is_header_compatible(str(tractogram_file), header): + raise ValueError("Streamlines group is not compatible with " + "volume groups\n ({})".format(tractogram_file)) # overriding given header. 
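            # (In dipy, passing 'same' as the reference means: use the
            # .trk file's own header. Comment added for clarity; the exact
            # consumer of `header` below is an assumption.)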
                header = 'same'

diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py
index c3cedb9b..96d997d6 100644
--- a/dwi_ml/data/processing/streamlines/data_augmentation.py
+++ b/dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -7,7 +7,7 @@
 from dipy.io.stateful_tractogram import StatefulTractogram
 from nibabel.streamlines.tractogram import (PerArrayDict,
                                             PerArraySequenceDict)
 import numpy as np
-from scilpy.tractograms.streamline_operations import resample_streamlines_step_size
+from scilpy.tracking.tools import resample_streamlines_step_size
 from scilpy.utils.streamlines import compress_sft
 
diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py
index 67bed921..9f319451 100644
--- a/dwi_ml/data/processing/streamlines/post_processing.py
+++ b/dwi_ml/data/processing/streamlines/post_processing.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 
-from scilpy.tractograms.uncompress import uncompress
+from scilpy.tractanalysis.uncompress import uncompress
 from scilpy.tractanalysis.tools import \
     extract_longest_segments_from_profile as segmenting_func
 
diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py
index 14576222..f9043a19 100644
--- a/dwi_ml/models/main_models.py
+++ b/dwi_ml/models/main_models.py
@@ -16,8 +16,7 @@
 from dwi_ml.data.processing.space.neighborhood import prepare_neighborhood_vectors
 from dwi_ml.experiment_utils.prints import format_dict_to_str
 from dwi_ml.io_utils import add_resample_or_compress_arg
-from dwi_ml.models.direction_getter_models import keys_to_direction_getters, \
-    AbstractDirectionGetterModel
+from dwi_ml.models.direction_getter_models import keys_to_direction_getters
 from dwi_ml.models.embeddings import keys_to_embeddings, NNEmbedding, NoEmbedding
 from dwi_ml.models.utils.direction_getters import add_direction_getter_args
 
@@ -72,7 +71,7 @@ def __init__(self, experiment_name: str,
         # To tell our trainer what to send to the forward / loss methods.
         self.forward_uses_streamlines = False
         self.loss_uses_streamlines = False
-        
+
         # To tell our batch loader how to resample streamlines during training
         # (should also be the step size during tractography).
         if step_size and compress_lines:
@@ -208,7 +207,7 @@ def _load_state(cls, model_dir):
     def forward(self, *inputs, **kw):
         raise NotImplementedError
 
-    def compute_loss(self, *model_outputs, **kw):
+    def compute_loss(self, model_outputs, targets):
         raise NotImplementedError
 
 
diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py
new file mode 100644
index 00000000..fb750d6d
--- /dev/null
+++ b/dwi_ml/models/projects/ae_models.py
@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+import logging
+from typing import List
+
+import torch
+from torch.nn import functional as F
+
+from dwi_ml.models.main_models import MainModelAbstract
+
+
+class ModelAE(MainModelAbstract):
+    """
+    FINTA-style convolutional autoencoder for streamlines.
+
+    Composed of a 1D-convolutional encoder that compresses a whole
+    streamline (a fixed-length sequence of 3D coordinates) into a single
+    latent vector of size `latent_space_dims`, and a mirrored
+    convolutional decoder that reconstructs the streamline from that
+    vector.
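+
+    With the fixed fc1 layer below (8192 = 1024 * 8 after five stride-2
+    convolutions), input streamlines are implicitly expected to be
+    resampled to 256 points, i.e. a batch of shape (B, 256, 3).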
+    """
+    def __init__(self, kernel_size, latent_space_dims,
+                 experiment_name: str,
+                 # Target preprocessing params for the batch loader + tracker
+                 step_size: float = None,
+                 compress_lines: float = False,
+                 # Other
+                 log_level=logging.root.level):
+        super().__init__(experiment_name, step_size, compress_lines, log_level)
+
+        self.kernel_size = kernel_size
+        self.latent_space_dims = latent_space_dims
+
+        self.pad = torch.nn.ReflectionPad1d(1)
+
+        def pre_pad(m):
+            return torch.nn.Sequential(self.pad, m)
+
+        self.fc1 = torch.nn.Linear(8192,
+                                   self.latent_space_dims)  # 8192 = 1024*8
+        self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192)
+
+        """
+        Encode convolutions
+        """
+        self.encod_conv1 = pre_pad(
+            torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0)
+        )
+        self.encod_conv2 = pre_pad(
+            torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0)
+        )
+        self.encod_conv3 = pre_pad(
+            torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0)
+        )
+        self.encod_conv4 = pre_pad(
+            torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0)
+        )
+        self.encod_conv5 = pre_pad(
+            torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0)
+        )
+        self.encod_conv6 = pre_pad(
+            torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0)
+        )
+
+        """
+        Decode convolutions
+        """
+        self.decod_conv1 = pre_pad(
+            torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0)
+        )
+        self.upsampl1 = torch.nn.Upsample(
+            scale_factor=2, mode="linear", align_corners=False
+        )
+        self.decod_conv2 = pre_pad(
+            torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0)
+        )
+        self.upsampl2 = torch.nn.Upsample(
+            scale_factor=2, mode="linear", align_corners=False
+        )
+        self.decod_conv3 = pre_pad(
+            torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0)
+        )
+        self.upsampl3 = torch.nn.Upsample(
+            scale_factor=2, mode="linear", align_corners=False
+        )
+        self.decod_conv4 = pre_pad(
+            torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0)
+        )
+        self.upsampl4 = torch.nn.Upsample(
+            scale_factor=2, mode="linear", align_corners=False
+        )
+        self.decod_conv5 = pre_pad(
+            torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0)
+        )
+        self.upsampl5 = torch.nn.Upsample(
+            scale_factor=2, mode="linear", align_corners=False
+        )
+        self.decod_conv6 = pre_pad(
+            torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0)
+        )
+
+        self.forward_uses_streamlines = True
+        self.loss_uses_streamlines = True
+
+    def forward(self,
+                input_streamlines: List[torch.tensor],
+                ):
+        """Run the model on a batch of sequences.
+
+        Parameters
+        ----------
+        input_streamlines: List[torch.tensor],
+            Batch of streamlines to reconstruct. All streamlines must have
+            the same number of points (resampled beforehand); they are
+            stacked into a single (B, n_points, 3) tensor below.
+
+        Returns
+        -------
+        model_outputs : Tensor
+            Output data, ready to be passed to either `compute_loss()` or
+            `get_tracking_directions()`.
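+            For this autoencoder, that is the reconstructed batch as a
+            single tensor of shape (B, 3, n_points), with coordinates as
+            channels.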
+ """ + input_streamlines = torch.stack(input_streamlines) + input_streamlines = torch.swapaxes(input_streamlines, 1, 2) + + x = self.decode(self.encode(input_streamlines)) + return x + + def encode(self, x): + h1 = F.relu(self.encod_conv1(x)) + h2 = F.relu(self.encod_conv2(h1)) + h3 = F.relu(self.encod_conv3(h2)) + h4 = F.relu(self.encod_conv4(h3)) + h5 = F.relu(self.encod_conv5(h4)) + h6 = self.encod_conv6(h5) + + self.encoder_out_size = (h6.shape[1], h6.shape[2]) + + # Flatten + h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + + fc1 = self.fc1(h7) + + return fc1 + + def decode(self, z): + fc = self.fc2(z) + fc_reshape = fc.view( + -1, self.encoder_out_size[0], self.encoder_out_size[1] + ) + h1 = F.relu(self.decod_conv1(fc_reshape)) + h2 = self.upsampl1(h1) + h3 = F.relu(self.decod_conv2(h2)) + h4 = self.upsampl2(h3) + h5 = F.relu(self.decod_conv3(h4)) + h6 = self.upsampl3(h5) + h7 = F.relu(self.decod_conv4(h6)) + h8 = self.upsampl4(h7) + h9 = F.relu(self.decod_conv5(h8)) + h10 = self.upsampl5(h9) + h11 = self.decod_conv6(h10) + + return h11 + + def compute_loss(self, model_outputs, targets, average_results=True): + print("COMPARISON\n") + targets = torch.stack(targets) + targets = torch.swapaxes(targets, 1, 2) + print(targets[0, :, 0:5]) + print(model_outputs[0, :, 0:5]) + reconstruction_loss = torch.nn.MSELoss(reduction="mean") + mse = reconstruction_loss(model_outputs, targets) + + # loss_function_vae + # See Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) + # kld = torch.sum(kld_element).__mul__(-0.5) + + return mse, 1 diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 0173caaf..5c4acbb9 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -190,7 +190,11 @@ def _data_augmentation_sft(self, sft): self.context_subset.compress == self.model.compress_lines: logger.debug("Compression rate is the same as when creating " "the hdf5 dataset. Not compressing again.") - else: + elif self.model.step_size is not None and \ + self.model.compress_lines is not None: + logger.debug("Resample streamlines using: \n" + + "- step_size: {}\n".format(self.model.step_size) + + "- compress_lines: {}".format(self.model.compress_lines)) sft = resample_or_compress(sft, self.model.step_size, self.model.compress_lines) @@ -306,6 +310,7 @@ def load_batch_streamlines( sft.to_vox() sft.to_corner() batch_streamlines.extend(sft.streamlines) + batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] return batch_streamlines, final_s_ids_per_subj diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 297cdeeb..f7535c6e 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -4,7 +4,7 @@ import logging import os import shutil -from typing import Union, List, Tuple +from typing import Union, List from comet_ml import (Experiment as CometExperiment, ExistingExperiment) import numpy as np @@ -972,7 +972,58 @@ def run_one_batch(self, data): the list of all streamlines (if we concatenate all sfts' streamlines) """ - raise NotImplementedError + # Data interpolation has not been done yet. GPU computations are done + # here in the main thread. 
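+        # (As returned by the batch loader's load_batch_streamlines:
+        # the list of streamline tensors and the dict of streamline ids
+        # per subject.)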
+ targets, ids_per_subj = data + + # Dataloader always works on CPU. Sending to right device. + # (model is already moved). + targets = [s.to(self.device, non_blocking=True, dtype=torch.float) + for s in targets] + + # Getting the inputs points from the volumes. + # Uses the model's method, with the batch_loader's data. + # Possibly skipping the last point if not useful. + streamlines_f = targets + if isinstance(self.model, ModelWithDirectionGetter) and \ + not self.model.direction_getter.add_eos: + # No EOS = We don't use the last coord because it does not have an + # associated target direction. + streamlines_f = [s[:-1, :] for s in streamlines_f] + + # Possibly add noise to inputs here. + logger.debug('*** Computing forward propagation') + if self.model.forward_uses_streamlines: + # Now possibly add noise to streamlines (training / valid) + streamlines_f = self.batch_loader.add_noise_streamlines_forward( + streamlines_f, self.device) + + # Possibly computing directions twice (during forward and loss) + # but ok, shouldn't be too heavy. Easier to deal with multiple + # projects' requirements by sending whole streamlines rather + # than only directions. + model_outputs = self.model(streamlines_f) + del streamlines_f + else: + del streamlines_f + model_outputs = self.model() + + logger.debug('*** Computing loss') + if self.model.loss_uses_streamlines: + targets = self.batch_loader.add_noise_streamlines_loss( + targets, self.device) + + results = self.model.compute_loss(model_outputs, targets, + average_results=True) + else: + results = self.model.compute_loss(model_outputs, + average_results=True) + + if self.use_gpu: + log_gpu_memory_usage(logger) + + # The mean tensor is a single value. Converting to float using item(). + return results def fix_parameters(self): """ diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py new file mode 100755 index 00000000..4120ce88 --- /dev/null +++ b/scripts_python/ae_train_model.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Train a model for Autoencoders +""" +import argparse +import logging +import os + +# comet_ml not used, but comet_ml requires to be imported before torch. +# See bug report here https://github.com/Lightning-AI/lightning/issues/5829 +# Importing now to solve issues later. 
+import comet_ml +import torch + +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist + +from dwi_ml.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.experiment_utils.prints import format_dict_to_str +from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.io_utils import add_logging_arg, add_memory_args +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.training.trainers import DWIMLAbstractTrainer +from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) +from dwi_ml.training.batch_loaders import DWIMLAbstractBatchLoader +from dwi_ml.training.utils.experiment import ( + add_mandatory_args_experiment_and_hdf5_path) +from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ + format_lr + + +def prepare_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + add_mandatory_args_experiment_and_hdf5_path(p) + add_args_batch_sampler(p) + add_args_batch_loader(p) + training_group = add_training_args(p) + p.add_argument('streamline_group_name', + help="Name of the group in hdf5") + add_memory_args(p, add_lazy_options=True, add_rng=True) + add_logging_arg(p) + + # Additional arg for projects + training_group.add_argument( + '--clip_grad', type=float, default=None, + help="Value to which the gradient norms to avoid exploding gradients." + "\nDefault = None (not clipping).") + + return p + + +def init_from_args(args, sub_loggers_level): + torch.manual_seed(args.rng) # Set torch seed + + # Prepare the dataset + dataset = prepare_multisubjectdataset(args, load_testing=False, + log_level=sub_loggers_level) + + # Preparing the model + # (Direction getter) + # (Nb features) + # Final model + with Timer("\n\nPreparing model", newline=True, color='yellow'): + # INPUTS: verifying args + model = ModelAE( + experiment_name=args.experiment_name, + step_size=args.step_size, compress_lines=args.compress, + kernel_size=3, latent_space_dims=32, + log_level=sub_loggers_level) + + logging.info("AEmodel final parameters:" + + format_dict_to_str(model.params_for_checkpoint)) + + logging.info("Computed parameters:" + + format_dict_to_str(model.computed_params_for_display)) + + # Preparing the batch samplers + batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) + # Preparing the batch loader. 
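+    # The autoencoder trains on (and reconstructs) streamlines only: there
+    # is no input volume group, so the generic batch loader is sufficient
+    # (no need for DWIMLBatchLoaderOneInput here).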
+ with Timer("\nPreparing batch loader...", newline=True, color='pink'): + batch_loader = DWIMLAbstractBatchLoader( + dataset=dataset, model=model, + streamline_group_name=args.streamline_group_name, + # OTHER + rng=args.rng, log_level=sub_loggers_level) + + logging.info("Loader user-defined parameters: " + + format_dict_to_str(batch_loader.params_for_checkpoint)) + + # Instantiate trainer + with Timer("\n\nPreparing trainer", newline=True, color='red'): + lr = format_lr(args.learning_rate) + trainer = DWIMLAbstractTrainer( + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, + # COMET + comet_project=args.comet_project, + comet_workspace=args.comet_workspace, + # TRAINING + learning_rates=lr, weight_decay=args.weight_decay, + optimizer=args.optimizer, max_epochs=args.max_epochs, + max_batches_per_epoch_training=args.max_batches_per_epoch_training, + max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, + patience=args.patience, patience_delta=args.patience_delta, + from_checkpoint=False, clip_grad=args.clip_grad, + # MEMORY + nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, + log_level=args.logging) + logging.info("Trainer params : " + + format_dict_to_str(trainer.params_for_checkpoint)) + + return trainer + + +def main(): + p = prepare_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.logging + if args.logging == 'DEBUG': + sub_loggers_level = 'INFO' + + logging.getLogger().setLevel(level=logging.INFO) + + # Check that all files exist + assert_inputs_exist(p, [args.hdf5_file]) + assert_outputs_exist(p, args, args.experiments_path) + + # Verify if a checkpoint has been saved. Else create an experiment. + if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError("This experiment already exists. Delete or use " + "script l2t_resume_training_from_checkpoint.py.") + + trainer = init_from_args(args, sub_loggers_level) + + run_experiment(trainer) + + +if __name__ == '__main__': + main() From 31e2e85d32cb393610fc1c4dab7c2d4e9098ef49 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 28 Nov 2023 09:17:25 -0500 Subject: [PATCH 02/14] modif with em --- dwi_ml/models/projects/ae_models.py | 30 ++++++++++++++++++++++++++--- dwi_ml/testing/visu_loss.py | 3 +-- scripts_python/ae_train_model.py | 2 +- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index fb750d6d..7acaaa40 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -101,6 +101,27 @@ def pre_pad(m): self.forward_uses_streamlines = True self.loss_uses_streamlines = True + @property + def params_for_checkpoint(self): + """All parameters necessary to create again the same model. Will be + used in the trainer, when saving the checkpoint state. Params here + will be used to re-create the model when starting an experiment from + checkpoint. 
You should be able to re-create an instance of your + model with those params.""" + #p = super().params_for_checkpoint() + p = {'kernel_size': self.kernel_size, + 'latent_space_dims': self.latent_space_dims, + 'experiment_name': self.experiment_name} + return p + + @classmethod + def _load_params(cls, model_dir): + p = super()._load_params(model_dir) + p['kernel_size'] = 3 + p['latent_space_dims'] = 32 + return p + + def forward(self, input_streamlines: List[torch.tensor], ): @@ -119,13 +140,16 @@ def forward(self, Output data, ready to be passed to either `compute_loss()` or `get_tracking_directions()`. """ - input_streamlines = torch.stack(input_streamlines) - input_streamlines = torch.swapaxes(input_streamlines, 1, 2) + x = self.decode(self.encode(input_streamlines)) return x def encode(self, x): + # x: list of tensors + x = torch.stack(x) + x = torch.swapaxes(x, 1, 2) + h1 = F.relu(self.encod_conv1(x)) h2 = F.relu(self.encod_conv2(h1)) h3 = F.relu(self.encod_conv3(h2)) @@ -167,7 +191,7 @@ def compute_loss(self, model_outputs, targets, average_results=True): targets = torch.swapaxes(targets, 1, 2) print(targets[0, :, 0:5]) print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="mean") + reconstruction_loss = torch.nn.MSELoss(reduction="sum") mse = reconstruction_loss(model_outputs, targets) # loss_function_vae diff --git a/dwi_ml/testing/visu_loss.py b/dwi_ml/testing/visu_loss.py index 680d22b1..f867b6fe 100644 --- a/dwi_ml/testing/visu_loss.py +++ b/dwi_ml/testing/visu_loss.py @@ -15,8 +15,7 @@ from matplotlib import pyplot as plt from scilpy.io.utils import add_overwrite_arg -from dwi_ml.io_utils import add_logging_arg, add_arg_existing_experiment_path -from dwi_ml.io_utils import add_memory_args +from dwi_ml.io_utils import add_arg_existing_experiment_path, add_logging_arg, add_memory_args from dwi_ml.models.main_models import ModelWithDirectionGetter from dwi_ml.testing.utils import add_args_testing_subj_hdf5 diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 4120ce88..9b58e77d 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -68,7 +68,7 @@ def init_from_args(args, sub_loggers_level): # INPUTS: verifying args model = ModelAE( experiment_name=args.experiment_name, - step_size=args.step_size, compress_lines=args.compress, + step_size=None, compress_lines=None, kernel_size=3, latent_space_dims=32, log_level=sub_loggers_level) From 346c383e5bc0909dcfa8bb21b0683595562b04a0 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 28 Nov 2023 09:17:38 -0500 Subject: [PATCH 03/14] add visu --- scripts_python/ae_visualize_streamlines.py | 111 +++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 scripts_python/ae_visualize_streamlines.py diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py new file mode 100644 index 00000000..a2cb8881 --- /dev/null +++ b/scripts_python/ae_visualize_streamlines.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import os.path + +import torch + +from scilpy.io.utils import add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_reference_arg +from scilpy.io.streamlines import load_tractogram_with_reference +from dipy.io.streamline import save_tractogram + +from dwi_ml.io_utils import add_logging_arg, add_arg_existing_experiment_path, add_memory_args +from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from 
dwi_ml.testing.testers import TesterOneInput +from dwi_ml.testing.visu_loss import \ + prepare_args_visu_loss, pick_a_few, run_visu_save_colored_displacement +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.testing.testers import Tester +from dwi_ml.testing.utils import add_args_testing_subj_hdf5 + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + #add_args_testing_subj_hdf5(p) + + p.add_argument('in_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + p.add_argument('out_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_logging_arg(p) + return p + + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Loggers + sub_logger_level = args.logging.upper() + if sub_logger_level == 'DEBUG': + sub_logger_level = 'INFO' + logging.getLogger().setLevel(level=args.logging) + + # Verify output names + # Check experiment_path exists and best_model folder exists + #assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, args.out_tractogram) + + # Device + device = (torch.device('cuda') if torch.cuda.is_available() and + args.use_gpu else None) + + # 1. Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_logger_level) + #model.set_context('training') + # 2. 
Compute loss + #tester = TesterOneInput(args.experiment_path, model, args.batch_size, device) + #tester = Tester(args.experiment_path, model, args.batch_size, device) + #sft = tester.load_and_format_data(args.subj_id, args.hdf5_file, args.subset) + + sft = load_tractogram_with_reference(p, args, args.in_tractogram) + sft.to_vox() + sft.to_corner() + bundle = sft.streamlines[0:5000] + + logging.info("Running model to compute loss") + + new_sft = sft.from_sft(bundle, sft) + save_tractogram(new_sft, 'orig_5000.trk') + + with torch.no_grad(): + streamlines = [ + torch.as_tensor(s, dtype=torch.float32, device=device) + for s in bundle] + tmp_outputs = model(streamlines) + #latent = model.encode(streamlines) + + streamlines_output = [tmp_outputs[i, :, :].transpose(0,1).cpu().numpy() for i in range(len(bundle))] + + #print(streamlines_output[0].shape) + new_sft = sft.from_sft(streamlines_output, sft) + save_tractogram(new_sft, args.out_tractogram) + + #latent_output = [s.cpu().numpy() for s in latent] + + + + #outputs, losses = tester.run_model_on_sft( + # sft, uncompress_loss=args.uncompress_loss, + # force_compress_loss=args.force_compress_loss, + # weight_with_angle=args.weight_with_angle) + +if __name__ == '__main__': + main() From a341891052614bf2563f730c7d2a9d7b150101dd Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Fri, 13 Sep 2024 15:14:28 -0400 Subject: [PATCH 04/14] fix pep8 --- dwi_ml/models/projects/ae_models.py | 4 +- scripts_python/ae_visualize_streamlines.py | 50 +++++++++++----------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 7acaaa40..66986da2 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -108,7 +108,7 @@ def params_for_checkpoint(self): will be used to re-create the model when starting an experiment from checkpoint. You should be able to re-create an instance of your model with those params.""" - #p = super().params_for_checkpoint() + # p = super().params_for_checkpoint() p = {'kernel_size': self.kernel_size, 'latent_space_dims': self.latent_space_dims, 'experiment_name': self.experiment_name} @@ -120,7 +120,6 @@ def _load_params(cls, model_dir): p['kernel_size'] = 3 p['latent_space_dims'] = 32 return p - def forward(self, input_streamlines: List[torch.tensor], @@ -141,7 +140,6 @@ def forward(self, `get_tracking_directions()`. 
""" - x = self.decode(self.encode(input_streamlines)) return x diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index a2cb8881..1d2e0f3f 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -2,22 +2,20 @@ # -*- coding: utf-8 -*- import argparse import logging -import os.path import torch -from scilpy.io.utils import add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_reference_arg +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg) from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram -from dwi_ml.io_utils import add_logging_arg, add_arg_existing_experiment_path, add_memory_args -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.testing.testers import TesterOneInput -from dwi_ml.testing.visu_loss import \ - prepare_args_visu_loss, pick_a_few, run_visu_save_colored_displacement +from dwi_ml.io_utils import (add_logging_arg, + add_arg_existing_experiment_path, + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE -from dwi_ml.testing.testers import Tester -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 + def _build_arg_parser(): p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, @@ -25,7 +23,7 @@ def _build_arg_parser(): # Mandatory # Should only be False for debugging tests. add_arg_existing_experiment_path(p) - #add_args_testing_subj_hdf5(p) + # Add_args_testing_subj_hdf5(p) p.add_argument('in_tractogram', help="If set, saves the tractogram with the loss per point " @@ -58,7 +56,7 @@ def main(): # Verify output names # Check experiment_path exists and best_model folder exists - #assert_inputs_exist(p, args.hdf5_file) + # Assert_inputs_exist(p, args.hdf5_file) assert_outputs_exist(p, args, args.out_tractogram) # Device @@ -69,11 +67,16 @@ def main(): logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( args.experiment_path + '/best_model', log_level=sub_logger_level) - #model.set_context('training') + # model.set_context('training') # 2. 
Compute loss - #tester = TesterOneInput(args.experiment_path, model, args.batch_size, device) - #tester = Tester(args.experiment_path, model, args.batch_size, device) - #sft = tester.load_and_format_data(args.subj_id, args.hdf5_file, args.subset) + # tester = TesterOneInput(args.experiment_path, + # model, + # args.batch_size, + # device) + # tester = Tester(args.experiment_path, model, args.batch_size, device) + # sft = tester.load_and_format_data(args.subj_id, + # args.hdf5_file, + # args.subset) sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() @@ -90,22 +93,21 @@ def main(): torch.as_tensor(s, dtype=torch.float32, device=device) for s in bundle] tmp_outputs = model(streamlines) - #latent = model.encode(streamlines) - - streamlines_output = [tmp_outputs[i, :, :].transpose(0,1).cpu().numpy() for i in range(len(bundle))] - - #print(streamlines_output[0].shape) - new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram) + # latent = model.encode(streamlines) - #latent_output = [s.cpu().numpy() for s in latent] + streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + # print(streamlines_output[0].shape) + new_sft = sft.from_sft(streamlines_output, sft) + save_tractogram(new_sft, args.out_tractogram) + # latent_output = [s.cpu().numpy() for s in latent] - #outputs, losses = tester.run_model_on_sft( + # outputs, losses = tester.run_model_on_sft( # sft, uncompress_loss=args.uncompress_loss, # force_compress_loss=args.force_compress_loss, # weight_with_angle=args.weight_with_angle) + if __name__ == '__main__': main() From 72789f3fa7a04a18157f2264bad9a36f74164ed3 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Fri, 13 Sep 2024 15:21:07 -0400 Subject: [PATCH 05/14] answer em comments from nov 2023 --- dwi_ml/models/projects/ae_models.py | 3 --- dwi_ml/training/batch_loaders.py | 4 ++-- dwi_ml/training/trainers.py | 37 ++++++++++++----------------- scripts_python/ae_train_model.py | 4 ++-- 4 files changed, 19 insertions(+), 29 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 66986da2..0818b6ad 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -98,9 +98,6 @@ def pre_pad(m): torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) ) - self.forward_uses_streamlines = True - self.loss_uses_streamlines = True - @property def params_for_checkpoint(self): """All parameters necessary to create again the same model. 
Will be diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index d82bed39..9ac3afaf 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -58,7 +58,7 @@ logger = logging.getLogger('batch_loader_logger') -class DWIMLAbstractBatchLoader: +class DWIMLStreamlinesBatchLoader: def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, streamline_group_name: str, rng: int, split_ratio: float = 0., @@ -355,7 +355,7 @@ def load_batch_connectivity_matrices( connectivity_nb_blocs, connectivity_labels) -class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader): +class DWIMLStreamlinesBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): """ Loads: input = one volume group diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index a86442db..fa65cee1 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1034,31 +1034,24 @@ def run_one_batch(self, data): # Possibly add noise to inputs here. logger.debug('*** Computing forward propagation') - if self.model.forward_uses_streamlines: - # Now possibly add noise to streamlines (training / valid) - streamlines_f = self.batch_loader.add_noise_streamlines_forward( - streamlines_f, self.device) - - # Possibly computing directions twice (during forward and loss) - # but ok, shouldn't be too heavy. Easier to deal with multiple - # projects' requirements by sending whole streamlines rather - # than only directions. - model_outputs = self.model(streamlines_f) - del streamlines_f - else: - del streamlines_f - model_outputs = self.model() + + # Now possibly add noise to streamlines (training / valid) + streamlines_f = self.batch_loader.add_noise_streamlines_forward( + streamlines_f, self.device) + + # Possibly computing directions twice (during forward and loss) + # but ok, shouldn't be too heavy. Easier to deal with multiple + # projects' requirements by sending whole streamlines rather + # than only directions. + model_outputs = self.model(streamlines_f) + del streamlines_f logger.debug('*** Computing loss') - if self.model.loss_uses_streamlines: - targets = self.batch_loader.add_noise_streamlines_loss( - targets, self.device) + targets = self.batch_loader.add_noise_streamlines_loss( + targets, self.device) - results = self.model.compute_loss(model_outputs, targets, - average_results=True) - else: - results = self.model.compute_loss(model_outputs, - average_results=True) + results = self.model.compute_loss(model_outputs, targets, + average_results=True) if self.use_gpu: log_gpu_memory_usage(logger) diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 9b58e77d..6bcd6046 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -25,7 +25,7 @@ from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) -from dwi_ml.training.batch_loaders import DWIMLAbstractBatchLoader +from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ @@ -82,7 +82,7 @@ def init_from_args(args, sub_loggers_level): batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) # Preparing the batch loader. 
with Timer("\nPreparing batch loader...", newline=True, color='pink'): - batch_loader = DWIMLAbstractBatchLoader( + batch_loader = DWIMLStreamlinesBatchLoader( dataset=dataset, model=model, streamline_group_name=args.streamline_group_name, # OTHER From 02cab7bb2be3a93c25f0679bcbd83388ff2fe077 Mon Sep 17 00:00:00 2001 From: bora2502 Date: Tue, 17 Sep 2024 10:06:12 -0400 Subject: [PATCH 06/14] fix naming class --- dwi_ml/training/batch_loaders.py | 2 +- dwi_ml/training/trainers.py | 8 ++++---- scripts_python/ae_train_model.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 9ac3afaf..61f13640 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -355,7 +355,7 @@ def load_batch_connectivity_matrices( connectivity_nb_blocs, connectivity_labels) -class DWIMLStreamlinesBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): +class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): """ Loads: input = one volume group diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index fa65cee1..2aae19ac 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -17,7 +17,7 @@ from dwi_ml.models.main_models import (MainModelAbstract, ModelWithDirectionGetter) from dwi_ml.training.batch_loaders import ( - DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput) + DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput) from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.utils.gradient_norm import compute_gradient_norm from dwi_ml.training.utils.monitoring import ( @@ -53,7 +53,7 @@ class DWIMLAbstractTrainer: def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, learning_rates: Union[List, float] = None, weight_decay: float = 0.01, optimizer: str = 'Adam', max_epochs: int = 10, @@ -78,7 +78,7 @@ def __init__(self, batch_sampler: DWIMLBatchIDSampler Instantiated class used for sampling batches. Data in batch_sampler.dataset must be already loaded. - batch_loader: DWIMLAbstractBatchLoader + batch_loader: DWIMLStreamlinesBatchLoader Instantiated class with a load_batch method able to load data associated to sampled batch ids. Data in batch_sampler.dataset must be already loaded. 
@@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict: def init_from_checkpoint( cls, model: MainModelAbstract, experiments_path, experiment_name, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, checkpoint_state: dict, new_patience, new_max_epochs, log_level): """ Loads checkpoint information (parameters and states) to instantiate diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 6bcd6046..a291fec4 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -14,12 +14,12 @@ import comet_ml import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, add_verbose_arg from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg, add_memory_args +from dwi_ml.io_utils import add_memory_args from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.training.trainers import DWIMLAbstractTrainer from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, @@ -38,17 +38,17 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - training_group = add_training_args(p) + #training_group = add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") add_memory_args(p, add_lazy_options=True, add_rng=True) - add_logging_arg(p) + add_verbose_arg(p) # Additional arg for projects - training_group.add_argument( - '--clip_grad', type=float, default=None, - help="Value to which the gradient norms to avoid exploding gradients." - "\nDefault = None (not clipping).") + #training_group.add_argument( + # '--clip_grad', type=float, default=None, + # help="Value to which the gradient norms to avoid exploding gradients." 
+ # "\nDefault = None (not clipping).") return p From 77918758d8fdbe2a3c29e2be1351798e09599669 Mon Sep 17 00:00:00 2001 From: bora2502 Date: Tue, 17 Sep 2024 13:55:00 -0400 Subject: [PATCH 07/14] fix script --- scripts_python/ae_train_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index a291fec4..20d416b1 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -25,6 +25,7 @@ from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) +from dwi_ml.training.utils.trainer import add_training_args from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) @@ -39,6 +40,7 @@ def prepare_arg_parser(): add_args_batch_sampler(p) add_args_batch_loader(p) #training_group = add_training_args(p) + add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") add_memory_args(p, add_lazy_options=True, add_rng=True) @@ -110,7 +112,7 @@ def init_from_args(args, sub_loggers_level): from_checkpoint=False, clip_grad=args.clip_grad, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=args.logging) + log_level=sub_loggers_level) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) @@ -123,11 +125,10 @@ def main(): # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. - sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' - logging.getLogger().setLevel(level=logging.INFO) + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Check that all files exist assert_inputs_exist(p, [args.hdf5_file]) From f4701be2c284b95ea991ecd95d433eae43a4993a Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Wed, 18 Sep 2024 10:29:25 -0400 Subject: [PATCH 08/14] fix viz --- scripts_python/ae_visualize_streamlines.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 1d2e0f3f..ef942d96 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -7,13 +7,12 @@ from scilpy.io.utils import (add_overwrite_arg, assert_outputs_exist, - add_reference_arg) + add_reference_arg, + add_verbose_arg) from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram - -from dwi_ml.io_utils import (add_logging_arg, - add_arg_existing_experiment_path, - add_memory_args) +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE @@ -40,7 +39,7 @@ def _build_arg_parser(): p.add_argument('--pick_at_random', action='store_true') add_reference_arg(p) add_overwrite_arg(p) - add_logging_arg(p) + add_verbose_arg(p) return p From c4bd181c16a2b80bcb768b4ac1e2aa4eee63da5b Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Wed, 18 Sep 2024 10:44:24 -0400 Subject: [PATCH 09/14] quick fix ae_vis_streamline --- scripts_python/ae_visualize_streamlines.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff 
--git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index ef942d96..31526453 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -47,11 +47,12 @@ def main(): p = _build_arg_parser() args = p.parse_args() - # Loggers - sub_logger_level = args.logging.upper() - if sub_logger_level == 'DEBUG': - sub_logger_level = 'INFO' - logging.getLogger().setLevel(level=args.logging) + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Verify output names # Check experiment_path exists and best_model folder exists @@ -65,7 +66,7 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_logger_level) + args.experiment_path + '/best_model', log_level=sub_loggers_level) # model.set_context('training') # 2. Compute loss # tester = TesterOneInput(args.experiment_path, From 448043b17b33480c5022887ba31b788496846897 Mon Sep 17 00:00:00 2001 From: Antoine Theberge Date: Thu, 19 Sep 2024 07:54:19 -0400 Subject: [PATCH 10/14] WIP: transformer ae --- command_ae.sh | 22 +++++ dwi_ml/models/projects/ae_models.py | 101 ++++++++++++++++----- scripts_python/ae_visualize_streamlines.py | 4 +- 3 files changed, 103 insertions(+), 24 deletions(-) create mode 100755 command_ae.sh diff --git a/command_ae.sh b/command_ae.sh new file mode 100755 index 00000000..17f3e3de --- /dev/null +++ b/command_ae.sh @@ -0,0 +1,22 @@ +experiments=experiments +experiment_name=fibercup_september24 + +rm -rf $experiments/$experiment_name + +ae_train_model.py $experiments \ + $experiment_name \ + fibercup.hdf5 \ + target \ + -v INFO \ + --batch_size_training 80 \ + --batch_size_units nb_streamlines \ + --nb_subjects_per_batch 1 \ + --learning_rate 0.001 \ + --weight_decay 0.05 \ + --optimizer Adam \ + --max_epochs 1000 \ + --max_batches_per_epoch_training 20 \ + --comet_workspace dwi_ml \ + --comet_project ae-fibercup \ + --patience 100 \ + --use_gpu diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 0818b6ad..c4703406 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -1,13 +1,48 @@ # -*- coding: utf-8 -*- import logging +import math + from typing import List import torch +from torch import nn +from torch import Tensor from torch.nn import functional as F from dwi_ml.models.main_models import MainModelAbstract +class PositionalEncoding(nn.Module): + """ Modified from + https://pytorch.org/tutorials/beginner/transformer_tutorial.htm://pytorch.org/tutorials/beginner/transformer_tutorial.html # noqa E504 + """ + + def __init__( + self, d_model: int, dropout: float = 0.1, max_len: int = 5000 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) + * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x: Tensor) -> Tensor: + """ + Arguments: + x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` + """ + x = x.permute(1, 0, 2) + x = x 
+ self.pe[:x.size(0)] + x = self.dropout(x) + x = x.permute(1, 0, 2) + return x + + class ModelAE(MainModelAbstract): """ Recurrent tracking model. @@ -27,23 +62,45 @@ def __init__(self, kernel_size, latent_space_dims, log_level=logging.root.level): super().__init__(experiment_name, step_size, compress_lines, log_level) - self.kernel_size = kernel_size - self.latent_space_dims = latent_space_dims + # Embedding size, could be defined by the user ? + self.embedding_size = 32 + # Embedding layer + self.embedding = nn.Sequential( + *(nn.Linear(3, self.embedding_size), + nn.ReLU())) - self.pad = torch.nn.ReflectionPad1d(1) + # Positional encoding layer + self.pos_encoding = PositionalEncoding( + self.embedding_size, max_len=(256)) + # Transformer encoder layer + layer = nn.TransformerEncoderLayer( + self.embedding_size, 4, batch_first=True) - def pre_pad(m): - return torch.nn.Sequential(self.pad, m) + # Transformer encoder + self.encoder = nn.TransformerEncoder(layer, 2) + self.decoder = nn.TransformerEncoder(layer, 2) + + self.reconstruction_loss = torch.nn.MSELoss() + + self.pad = torch.nn.ReflectionPad1d(1) + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims self.fc1 = torch.nn.Linear(8192, self.latent_space_dims) # 8192 = 1024*8 self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + self.fc3 = torch.nn.Linear(self.embedding_size, 3) + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) + """ Encode convolutions """ self.encod_conv1 = pre_pad( - torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + torch.nn.Conv1d(self.embedding_size, 32, + self.kernel_size, stride=2, padding=0) ) self.encod_conv2 = pre_pad( torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) @@ -95,7 +152,8 @@ def pre_pad(m): scale_factor=2, mode="linear", align_corners=False ) self.decod_conv6 = pre_pad( - torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) + torch.nn.Conv1d(32, 32, + self.kernel_size, stride=1, padding=0) ) @property @@ -143,6 +201,11 @@ def forward(self, def encode(self, x): # x: list of tensors x = torch.stack(x) + + x = self.embedding(x) * math.sqrt(self.embedding_size) + x = self.pos_encoding(x) + x = self.encoder(x) + x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) @@ -162,6 +225,7 @@ def encode(self, x): return fc1 def decode(self, z): + fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] @@ -178,24 +242,17 @@ def decode(self, z): h10 = self.upsampl5(h9) h11 = self.decod_conv6(h10) - return h11 + h11 = h11.permute(0, 2, 1) + + h12 = self.decoder(h11) + + x = self.fc3(h12) + + return x.permute(0, 2, 1) def compute_loss(self, model_outputs, targets, average_results=True): - print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - print(targets[0, :, 0:5]) - print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="sum") - mse = reconstruction_loss(model_outputs, targets) - - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) + mse = self.reconstruction_loss(model_outputs, targets) return mse, 1 diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 31526453..37e95352 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -66,7 +66,7 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) # model.set_context('training') # 2. Compute loss # tester = TesterOneInput(args.experiment_path, @@ -81,7 +81,7 @@ def main(): sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines[0:5000] + bundle = sft.streamlines logging.info("Running model to compute loss") From 7bf9fe308ae36cd8fe0f852d4557fd1c38ad1179 Mon Sep 17 00:00:00 2001 From: Antoine Theberge Date: Thu, 19 Sep 2024 07:55:36 -0400 Subject: [PATCH 11/14] Revert "WIP: transformer ae" This reverts commit 448043b17b33480c5022887ba31b788496846897. --- command_ae.sh | 22 ----- dwi_ml/models/projects/ae_models.py | 101 +++++---------------- scripts_python/ae_visualize_streamlines.py | 4 +- 3 files changed, 24 insertions(+), 103 deletions(-) delete mode 100755 command_ae.sh diff --git a/command_ae.sh b/command_ae.sh deleted file mode 100755 index 17f3e3de..00000000 --- a/command_ae.sh +++ /dev/null @@ -1,22 +0,0 @@ -experiments=experiments -experiment_name=fibercup_september24 - -rm -rf $experiments/$experiment_name - -ae_train_model.py $experiments \ - $experiment_name \ - fibercup.hdf5 \ - target \ - -v INFO \ - --batch_size_training 80 \ - --batch_size_units nb_streamlines \ - --nb_subjects_per_batch 1 \ - --learning_rate 0.001 \ - --weight_decay 0.05 \ - --optimizer Adam \ - --max_epochs 1000 \ - --max_batches_per_epoch_training 20 \ - --comet_workspace dwi_ml \ - --comet_project ae-fibercup \ - --patience 100 \ - --use_gpu diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index c4703406..0818b6ad 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -1,48 +1,13 @@ # -*- coding: utf-8 -*- import logging -import math - from typing import List import torch -from torch import nn -from torch import Tensor from torch.nn import functional as F from dwi_ml.models.main_models import MainModelAbstract -class PositionalEncoding(nn.Module): - """ Modified from - https://pytorch.org/tutorials/beginner/transformer_tutorial.htm://pytorch.org/tutorials/beginner/transformer_tutorial.html # noqa E504 - """ - - def __init__( - self, d_model: int, dropout: float = 0.1, max_len: int = 5000 - ): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) - * (-math.log(10000.0) / d_model)) - pe = torch.zeros(max_len, 1, d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - - def forward(self, x: Tensor) -> Tensor: - """ - Arguments: - x: Tensor, shape 
``[seq_len, batch_size, embedding_dim]`` - """ - x = x.permute(1, 0, 2) - x = x + self.pe[:x.size(0)] - x = self.dropout(x) - x = x.permute(1, 0, 2) - return x - - class ModelAE(MainModelAbstract): """ Recurrent tracking model. @@ -62,45 +27,23 @@ def __init__(self, kernel_size, latent_space_dims, log_level=logging.root.level): super().__init__(experiment_name, step_size, compress_lines, log_level) - # Embedding size, could be defined by the user ? - self.embedding_size = 32 - # Embedding layer - self.embedding = nn.Sequential( - *(nn.Linear(3, self.embedding_size), - nn.ReLU())) - - # Positional encoding layer - self.pos_encoding = PositionalEncoding( - self.embedding_size, max_len=(256)) - # Transformer encoder layer - layer = nn.TransformerEncoderLayer( - self.embedding_size, 4, batch_first=True) - - # Transformer encoder - self.encoder = nn.TransformerEncoder(layer, 2) - self.decoder = nn.TransformerEncoder(layer, 2) - - self.reconstruction_loss = torch.nn.MSELoss() - - self.pad = torch.nn.ReflectionPad1d(1) self.kernel_size = kernel_size self.latent_space_dims = latent_space_dims - self.fc1 = torch.nn.Linear(8192, - self.latent_space_dims) # 8192 = 1024*8 - self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) - - self.fc3 = torch.nn.Linear(self.embedding_size, 3) + self.pad = torch.nn.ReflectionPad1d(1) def pre_pad(m): return torch.nn.Sequential(self.pad, m) + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + """ Encode convolutions """ self.encod_conv1 = pre_pad( - torch.nn.Conv1d(self.embedding_size, 32, - self.kernel_size, stride=2, padding=0) + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) ) self.encod_conv2 = pre_pad( torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) @@ -152,8 +95,7 @@ def pre_pad(m): scale_factor=2, mode="linear", align_corners=False ) self.decod_conv6 = pre_pad( - torch.nn.Conv1d(32, 32, - self.kernel_size, stride=1, padding=0) + torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) ) @property @@ -201,11 +143,6 @@ def forward(self, def encode(self, x): # x: list of tensors x = torch.stack(x) - - x = self.embedding(x) * math.sqrt(self.embedding_size) - x = self.pos_encoding(x) - x = self.encoder(x) - x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) @@ -225,7 +162,6 @@ def encode(self, x): return fc1 def decode(self, z): - fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] @@ -242,17 +178,24 @@ def decode(self, z): h10 = self.upsampl5(h9) h11 = self.decod_conv6(h10) - h11 = h11.permute(0, 2, 1) - - h12 = self.decoder(h11) - - x = self.fc3(h12) - - return x.permute(0, 2, 1) + return h11 def compute_loss(self, model_outputs, targets, average_results=True): + print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - mse = self.reconstruction_loss(model_outputs, targets) + print(targets[0, :, 0:5]) + print(model_outputs[0, :, 0:5]) + reconstruction_loss = torch.nn.MSELoss(reduction="sum") + mse = reconstruction_loss(model_outputs, targets) + + # loss_function_vae + # See Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) + # kld = torch.sum(kld_element).__mul__(-0.5) return mse, 1 diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 37e95352..31526453 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -66,7 +66,7 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + args.experiment_path + '/best_model', log_level=sub_loggers_level) # model.set_context('training') # 2. Compute loss # tester = TesterOneInput(args.experiment_path, @@ -81,7 +81,7 @@ def main(): sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines + bundle = sft.streamlines[0:5000] logging.info("Running model to compute loss") From 5149706fa48aa6382c7d7a09d5a01b496686da04 Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 20 Sep 2024 12:51:58 -0400 Subject: [PATCH 12/14] Latent space visualization integration --- dwi_ml/models/projects/ae_models.py | 34 ++-- dwi_ml/viz/latent_streamlines.py | 183 +++++++++++++++++++++ scripts_python/ae_visualize_bundles.py | 106 ++++++++++++ scripts_python/ae_visualize_streamlines.py | 44 ++--- 4 files changed, 322 insertions(+), 45 deletions(-) create mode 100644 dwi_ml/viz/latent_streamlines.py create mode 100644 scripts_python/ae_visualize_bundles.py diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 0818b6ad..50aace59 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -31,6 +31,7 @@ def __init__(self, kernel_size, latent_space_dims, self.latent_space_dims = latent_space_dims self.pad = torch.nn.ReflectionPad1d(1) + self.post_encoding_hooks = [] def pre_pad(m): return torch.nn.Sequential(self.pad, m) @@ -137,13 +138,20 @@ def forward(self, `get_tracking_directions()`. """ - x = self.decode(self.encode(input_streamlines)) + encoded = self.encode(input_streamlines) + + for hook in self.post_encoding_hooks: + hook(encoded) + + x = self.decode(encoded) return x def encode(self, x): - # x: list of tensors - x = torch.stack(x) - x = torch.swapaxes(x, 1, 2) + # X input shape is (batch_size, nb_points, 3) + if isinstance(x, list): + x = torch.stack(x) + + x = torch.swapaxes(x, 1, 2) # input of the network should be (N, 3, nb_points) h1 = F.relu(self.encod_conv1(x)) h2 = F.relu(self.encod_conv2(h1)) @@ -181,21 +189,13 @@ def decode(self, z): return h11 def compute_loss(self, model_outputs, targets, average_results=True): - print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - print(targets[0, :, 0:5]) - print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="sum") + reconstruction_loss = torch.nn.MSELoss() mse = reconstruction_loss(model_outputs, targets) - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) - return mse, 1 + + def register_hook_post_encoding(self, hook): + self.post_encoding_hooks.append(hook) + diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py new file mode 100644 index 00000000..297bced6 --- /dev/null +++ b/dwi_ml/viz/latent_streamlines.py @@ -0,0 +1,183 @@ +import logging + +from typing import Union, List, Tuple +from sklearn.manifold import TSNE +import numpy as np +import torch + +import matplotlib.pyplot as plt + +def plot_latent_streamlines( + encoded_streamlines: Union[np.ndarray, torch.Tensor], + save_path: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None + ): + """ + Projects and plots the latent space representation + of the streamlines using t-SNE dimensionality reduction. + + Parameters + ---------- + encoded_streamlines: Union[np.ndarray, torch.Tensor] + Latent space streamlines to plot of shape (N, latent_space_dim). + save_path: str + Path to save the figure. If not specified, the figure will be shown. + fig_size: List[int] or Tuple[int] + 2-valued figure size (x, y) + random_state: int + Random state for t-SNE. + max_subset_size: int: + In case of performance issues, you can limit the number of streamlines to plot. + """ + + if isinstance(encoded_streamlines, torch.Tensor): + latent_space_streamlines = encoded_streamlines.cpu().numpy() + else: + latent_space_streamlines = encoded_streamlines + + if max_subset_size is not None: + if not (max_subset_size > 0): + raise ValueError("A max_subset_size of an integer value greater than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(latent_space_streamlines) > max_subset_size): + sample_indices = np.random.choice(len(latent_space_streamlines), max_subset_size, replace=False) + latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) + + # Project the data into 2 dimensions. + tsne = TSNE(n_components=2, random_state=random_state) + X_tsne = tsne.fit_transform(latent_space_streamlines) # Output (N, 2) + + + logging.info("New figure for t-SNE visualisation.") + fig, ax = plt.subplots() + if fig_size is not None: + fig.set_figheight(fig_size[0]) + fig.set_figwidth(fig_size[1]) + + ax.scatter(X_tsne[:, 0], X_tsne[:, 1], alpha=0.9, edgecolors='black', linewidths=0.5) + + if save_path is not None: + fig.savefig(save_path) + else: + plt.show() + + +class BundlesLatentSpaceVisualizer(object): + """ + Utility class that wraps a t-SNE projection of the latent space for multiple bundles. + The usage of this class is intented as follows: + 1. Create an instance of this class, + 2. Add the latent space streamlines for each bundle using "add_data_to_plot" + with its corresponding label. + 3. Fit and plot the t-SNE projection using the "plot" method. + + t-SNE projection can only leverage the fit_transform() with all the data that needs to + be projected at the same time since it aims to preserve the local structure of the data. 
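+
+    A minimal usage sketch (``latent_a`` and ``latent_b`` stand for arrays
+    of shape (N, latent_space_dim) returned by the model's encode() method;
+    the names are illustrative only):
+
+        ls_viz = BundlesLatentSpaceVisualizer(save_path='tsne.png')
+        ls_viz.add_data_to_plot(latent_a, label='bundle_a')
+        ls_viz.add_data_to_plot(latent_b, label='bundle_b')
+        ls_viz.plot()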
+ """ + def __init__(self, + save_path: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None + ): + """ + Parameters + ---------- + save_path: str + Path to save the figure. If not specified, the figure will be shown. + fig_size: List[int] or Tuple[int] + 2-valued figure size (x, y) + random_state: List + Random state for t-SNE. + max_subset_size: + In case of performance issues, you can limit the number of streamlines to plot + for each bundle. + """ + self.save_path = save_path + self.fig_size = fig_size + self.random_state = random_state + self.max_subset_size = max_subset_size + + self.tsne = TSNE(n_components=2, random_state=self.random_state) + self.bundles = {} + + + def add_data_to_plot(self, data: np.ndarray, label: str = '_'): + """ + Add unprojected data (no t-SNE, no PCA, etc.). + This should be directly the output of the model as a numpy array. + + Parameters + ---------- + data: str + Unprojected latent space streamlines (N, latent_space_dim). + label: str + Name of the bundle. Used for the legend. + """ + if isinstance(data, torch.Tensor): + latent_space_streamlines = data.cpu().numpy() + else: + latent_space_streamlines = data + + if self.max_subset_size is not None: + if not (self.max_subset_size > 0): + raise ValueError("A max_subset_size of an integer value greater than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(latent_space_streamlines) > self.max_subset_size): + sample_indices = np.random.choice(len(latent_space_streamlines), self.max_subset_size, replace=False) + latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) + + self.bundles[label] = latent_space_streamlines + + def plot(self): + """ + Fit and plot the t-SNE projection of the latent space streamlines. + This should be called once after adding all the data to plot using "add_data_to_plot". 
+ """ + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) + logging.info("Plotting a total of {} streamlines".format(nb_streamlines)) + + bundles_indices = {} + current_start = 0 + for (bname, bdata) in self.bundles.items(): + bundles_indices[bname] = np.arange(current_start, current_start + bdata.shape[0]) + current_start += bdata.shape[0] + + assert current_start == nb_streamlines + + all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) + + logging.info("Fitting TSNE projection.") + all_projected_streamlines = self.tsne.fit_transform(all_streamlines) + + logging.info("New figure for t-SNE visualisation.") + fig, ax = plt.subplots() + if self.fig_size is not None: + fig.set_figheight(self.fig_size[0]) + fig.set_figwidth(self.fig_size[1]) + + for (bname, bdata) in self.bundles.items(): + bindices = bundles_indices[bname] + proj_data = all_projected_streamlines[bindices] + ax.scatter( + proj_data[:, 0], + proj_data[:, 1], + label=bname, + alpha=0.9, + edgecolors='black', + linewidths=0.5, + ) + + ax.legend() + + if self.save_path is not None: + fig.savefig(self.save_path) + else: + plt.show() + diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py new file mode 100644 index 00000000..c4b7c0a4 --- /dev/null +++ b/scripts_python/ae_visualize_bundles.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import pathlib +import torch +import numpy as np +from glob import glob +from os.path import expanduser +from dipy.tracking.streamline import set_number_of_points + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_bundles', + help="The 'glob' path to several bundles identified by their file name." + "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + +def load_bundles(p, args, files_list: list): + bundles = [] + for bundle_file in files_list: + bundle_sft = load_tractogram_with_reference(p, args, bundle_file) + bundle_sft.to_vox() + bundle_sft.to_corner() + bundles.append(bundle_sft) + return bundles + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, []) + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 1. 
Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + expanded = expanduser(args.in_bundles) + bundles_files = glob(expanded) + if isinstance(bundles_files, str): + bundles_files = [bundles_files] + + bundles_label = [pathlib.Path(l).stem for l in bundles_files] + bundles_sft = load_bundles(p, args, bundles_files) + + logging.info("Running model to compute loss") + + ls_viz = BundlesLatentSpaceVisualizer( + save_path="/home/local/USHERBROOKE/levj1404/Documents/dwi_ml/data/out.png" + ) + + with torch.no_grad(): + for i, bundle_sft in enumerate(bundles_sft): + + # Resample + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), + dtype=torch.float32, device=device) + + latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) + ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) + + ls_viz.plot() + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 31526453..79542aba 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -4,7 +4,7 @@ import logging import torch - +import numpy as np from scilpy.io.utils import (add_overwrite_arg, assert_outputs_exist, add_reference_arg, @@ -14,6 +14,8 @@ from dwi_ml.io_utils import (add_arg_existing_experiment_path, add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer +from dipy.tracking.streamline import set_number_of_points def _build_arg_parser(): @@ -31,6 +33,8 @@ def _build_arg_parser(): p.add_argument('out_tractogram', help="If set, saves the tractogram with the loss per point " "as a data per point (color)") + p.add_argument('--viz_save_path', type=str, default=None, + help="Path to save the figure. If not specified, the figure will be shown.") # Options p.add_argument('--batch_size', type=int) @@ -60,23 +64,16 @@ def main(): assert_outputs_exist(p, args, args.out_tractogram) # Device - device = (torch.device('cuda') if torch.cuda.is_available() and - args.use_gpu else None) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) - # model.set_context('training') - # 2. 
Compute loss - # tester = TesterOneInput(args.experiment_path, - # model, - # args.batch_size, - # device) - # tester = Tester(args.experiment_path, model, args.batch_size, device) - # sft = tester.load_and_format_data(args.subj_id, - # args.hdf5_file, - # args.subset) + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + # Setup vizualisation + ls_viz = BundlesLatentSpaceVisualizer(save_path=args.viz_save_path) + model.register_hook_post_encoding(lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() @@ -89,24 +86,15 @@ def main(): save_tractogram(new_sft, 'orig_5000.trk') with torch.no_grad(): - streamlines = [ - torch.as_tensor(s, dtype=torch.float32, device=device) - for s in bundle] + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle, 256)), + dtype=torch.float32, device=device) tmp_outputs = model(streamlines) - # latent = model.encode(streamlines) - streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + ls_viz.plot() - # print(streamlines_output[0].shape) + streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram) - - # latent_output = [s.cpu().numpy() for s in latent] - - # outputs, losses = tester.run_model_on_sft( - # sft, uncompress_loss=args.uncompress_loss, - # force_compress_loss=args.force_compress_loss, - # weight_with_angle=args.weight_with_angle) + save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) if __name__ == '__main__': From c665cf2298030cea0a62d30ab0b6bd1af72ae42e Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 27 Sep 2024 12:16:36 -0400 Subject: [PATCH 13/14] Viz latent space each n epochs --- dwi_ml/viz/latent_streamlines.py | 138 +++++++++++++++---------------- scripts_python/ae_train_model.py | 40 +++++++++ 2 files changed, 105 insertions(+), 73 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 297bced6..7e260c3a 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -1,5 +1,5 @@ +import os import logging - from typing import Union, List, Tuple from sklearn.manifold import TSNE import numpy as np @@ -7,64 +7,7 @@ import matplotlib.pyplot as plt -def plot_latent_streamlines( - encoded_streamlines: Union[np.ndarray, torch.Tensor], - save_path: str = None, - fig_size: Union[List, Tuple] = None, - random_state: int = 42, - max_subset_size: int = None - ): - """ - Projects and plots the latent space representation - of the streamlines using t-SNE dimensionality reduction. - - Parameters - ---------- - encoded_streamlines: Union[np.ndarray, torch.Tensor] - Latent space streamlines to plot of shape (N, latent_space_dim). - save_path: str - Path to save the figure. If not specified, the figure will be shown. - fig_size: List[int] or Tuple[int] - 2-valued figure size (x, y) - random_state: int - Random state for t-SNE. - max_subset_size: int: - In case of performance issues, you can limit the number of streamlines to plot. 
- """ - - if isinstance(encoded_streamlines, torch.Tensor): - latent_space_streamlines = encoded_streamlines.cpu().numpy() - else: - latent_space_streamlines = encoded_streamlines - - if max_subset_size is not None: - if not (max_subset_size > 0): - raise ValueError("A max_subset_size of an integer value greater than 0 is required.") - - # Only sample if we need to reduce the number of latent streamlines - # to show on the plot. - if (len(latent_space_streamlines) > max_subset_size): - sample_indices = np.random.choice(len(latent_space_streamlines), max_subset_size, replace=False) - latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) - - # Project the data into 2 dimensions. - tsne = TSNE(n_components=2, random_state=random_state) - X_tsne = tsne.fit_transform(latent_space_streamlines) # Output (N, 2) - - - logging.info("New figure for t-SNE visualisation.") - fig, ax = plt.subplots() - if fig_size is not None: - fig.set_figheight(fig_size[0]) - fig.set_figwidth(fig_size[1]) - - ax.scatter(X_tsne[:, 0], X_tsne[:, 1], alpha=0.9, edgecolors='black', linewidths=0.5) - - if save_path is not None: - fig.savefig(save_path) - else: - plt.show() - +LOGGER = logging.getLogger(__name__) class BundlesLatentSpaceVisualizer(object): """ @@ -79,10 +22,12 @@ class BundlesLatentSpaceVisualizer(object): be projected at the same time since it aims to preserve the local structure of the data. """ def __init__(self, - save_path: str = None, + save_dir: str = None, fig_size: Union[List, Tuple] = None, random_state: int = 42, - max_subset_size: int = None + max_subset_size: int = None, + prefix_numbering: bool = False, + reset_warning: bool = True ): """ Parameters @@ -93,18 +38,43 @@ def __init__(self, 2-valued figure size (x, y) random_state: List Random state for t-SNE. - max_subset_size: + max_subset_size: int In case of performance issues, you can limit the number of streamlines to plot for each bundle. + prefix_numbering: bool + If True, the saved figures will be numbered with the current plot number. + The plot number is incremented after each call to the "plot" method. + reset_warning: bool + If True, a warning will be displayed when calling "plot" several times + without calling "reset_data" in between to clear the data. """ - self.save_path = save_path + self.save_dir = save_dir + + # Make sure that self.save_dir is a directory and exists. + if self.save_dir is not None: + if not os.path.isdir(self.save_dir): + raise ValueError("The save_dir should be a directory.") + self.fig_size = fig_size self.random_state = random_state self.max_subset_size = max_subset_size + self.prefix_numbering = prefix_numbering + self.reset_warning = reset_warning + + self.current_plot_number = 0 + self.should_call_reset_before_plot = False self.tsne = TSNE(n_components=2, random_state=self.random_state) self.bundles = {} - + + def reset_data(self): + """ + Reset the data to plot. If you call plot several times without + calling this method, the data will be accumulated. + """ + # Not sure if resetting the TSNE object is necessary. + self.tsne = TSNE(n_components=2, random_state=self.random_state) + self.bundles = {} def add_data_to_plot(self, data: np.ndarray, label: str = '_'): """ @@ -119,7 +89,7 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): Name of the bundle. Used for the legend. 
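+
+        A short sketch of the expected input, where ``model`` is a trained
+        ModelAE and ``streamlines`` a batch resampled to 256 points
+        (illustrative names; encode() returns shape (N, latent_space_dims)):
+
+            latent = model.encode(streamlines)
+            ls_viz.add_data_to_plot(latent, label='bundle_0')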
""" if isinstance(data, torch.Tensor): - latent_space_streamlines = data.cpu().numpy() + latent_space_streamlines = data.detach().numpy() else: latent_space_streamlines = data @@ -135,13 +105,28 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): self.bundles[label] = latent_space_streamlines - def plot(self): + def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): """ Fit and plot the t-SNE projection of the latent space streamlines. This should be called once after adding all the data to plot using "add_data_to_plot". + + Parameters + ---------- + figure_name_prefix: str + Name of the figure to be saved. This is just the prefix of the full file + name as it will be suffixed with the current plot number if enabled. """ + if self.should_call_reset_before_plot and self.reset_warning: + LOGGER.warning("You plotted another time without resetting the data. " + "The data will be accumulated, which might lead to unexpected results.") + self.should_call_reset_before_plot = False + elif not self.current_plot_number > 0: + # Only enable the flag for the first plot. + # So that the warning above is only displayed once. + self.should_call_reset_before_plot = True + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - logging.info("Plotting a total of {} streamlines".format(nb_streamlines)) + LOGGER.info("Plotting a total of {} streamlines".format(nb_streamlines)) bundles_indices = {} current_start = 0 @@ -153,11 +138,12 @@ def plot(self): all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) - logging.info("Fitting TSNE projection.") + LOGGER.info("Fitting TSNE projection.") all_projected_streamlines = self.tsne.fit_transform(all_streamlines) - logging.info("New figure for t-SNE visualisation.") + LOGGER.info("New figure for t-SNE visualisation.") fig, ax = plt.subplots() + ax.set_title(title) if self.fig_size is not None: fig.set_figheight(self.fig_size[0]) fig.set_figwidth(self.fig_size[1]) @@ -174,10 +160,16 @@ def plot(self): linewidths=0.5, ) - ax.legend() - - if self.save_path is not None: - fig.savefig(self.save_path) + if len(self.bundles) > 1: + ax.legend() + + if self.save_dir is not None: + filename = '{}_{}.png'.format(figure_name_prefix, self.current_plot_number) \ + if self.prefix_numbering else '{}.png'.format(figure_name_prefix) + filename = os.path.join(self.save_dir, filename) + fig.savefig(filename) else: plt.show() + self.current_plot_number += 1 + diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 20d416b1..9ea9eadc 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -27,6 +27,7 @@ from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) from dwi_ml.training.utils.trainer import add_training_args from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ @@ -43,6 +44,9 @@ def prepare_arg_parser(): add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") + p.add_argument('--viz_latent_space_freq', type=int, default=None, + help="Frequency at which to visualize latent space.\n" + "This is expressed in number of epochs.") add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) @@ -116,6 +120,42 @@ def init_from_args(args, 
sub_loggers_level): logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) + if args.viz_latent_space_freq is not None: + # Setup to visualize latent space + save_dir = os.path.join(args.experiments_path, args.experiment_name, 'latent_space_plots') + os.makedirs(save_dir, exist_ok=True) + + ls_viz = BundlesLatentSpaceVisualizer(save_dir, + prefix_numbering=True, + max_subset_size=1000) + current_epoch = 0 + def visualize_latent_space(encoding): + """ + This is not a clean way to do it. This would require changes in the + trainer to allow for a callback system where we could register a + function to be called at the end of each epoch to plot the latent space + of the data accumulated during the epoch (at each batch). + + Also, using this method, the latent space of the last epoch will not be + plotted. We would need to calculate which batch step would be the last in + the epoch and then plot accordingly. + """ + nonlocal current_epoch, trainer, ls_viz + + # Only execute the following if we are in training + if not trainer.model.context == 'training': + return + + changed_epoch = current_epoch != trainer.current_epoch + if not changed_epoch: + ls_viz.add_data_to_plot(encoding) + elif changed_epoch and trainer.current_epoch % args.viz_latent_space_freq == 0: + current_epoch = trainer.current_epoch + ls_viz.plot(title="Latent space at epoch {}".format(current_epoch)) + ls_viz.reset_data() + ls_viz.add_data_to_plot(encoding) + model.register_hook_post_encoding(visualize_latent_space) + return trainer From 8df0c0f3a60ad6ef0d66e06f447c53b51c366ecf Mon Sep 17 00:00:00 2001 From: Jeremi Levesque Date: Fri, 27 Sep 2024 13:20:57 -0400 Subject: [PATCH 14/14] autopep8 pass --- dwi_ml/viz/latent_streamlines.py | 102 ++++++++++++--------- scripts_python/ae_train_model.py | 28 +++--- scripts_python/ae_visualize_bundles.py | 11 ++- scripts_python/ae_visualize_streamlines.py | 10 +- 4 files changed, 91 insertions(+), 60 deletions(-) diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py index 7e260c3a..ddd95e5b 100644 --- a/dwi_ml/viz/latent_streamlines.py +++ b/dwi_ml/viz/latent_streamlines.py @@ -9,47 +9,52 @@ LOGGER = logging.getLogger(__name__) + class BundlesLatentSpaceVisualizer(object): """ - Utility class that wraps a t-SNE projection of the latent space for multiple bundles. - The usage of this class is intented as follows: + Utility class that wraps a t-SNE projection of the latent + space for multiple bundles. The usage of this class is + intented as follows: 1. Create an instance of this class, - 2. Add the latent space streamlines for each bundle using "add_data_to_plot" - with its corresponding label. + 2. Add the latent space streamlines for each bundle + using "add_data_to_plot" with its corresponding label. 3. Fit and plot the t-SNE projection using the "plot" method. - - t-SNE projection can only leverage the fit_transform() with all the data that needs to - be projected at the same time since it aims to preserve the local structure of the data. + + t-SNE projection can only leverage the fit_transform() with all + the data that needs to be projected at the same time since it aims + to preserve the local structure of the data. 
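+
+    Concretely, plot() performs a single fit over the stacked bundles,
+    roughly as below (``all_latent`` is an illustrative name for the
+    concatenated latent vectors of every bundle added so far):
+
+        tsne = TSNE(n_components=2, random_state=42)
+        projected = tsne.fit_transform(all_latent)  # (N, 2), all bundles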
""" + def __init__(self, - save_dir: str = None, - fig_size: Union[List, Tuple] = None, - random_state: int = 42, - max_subset_size: int = None, - prefix_numbering: bool = False, - reset_warning: bool = True - ): + save_dir: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None, + prefix_numbering: bool = False, + reset_warning: bool = True + ): """ Parameters ---------- save_path: str - Path to save the figure. If not specified, the figure will be shown. + Path to save the figure. If not specified, the figure will show. fig_size: List[int] or Tuple[int] 2-valued figure size (x, y) random_state: List Random state for t-SNE. max_subset_size: int - In case of performance issues, you can limit the number of streamlines to plot - for each bundle. + In case of performance issues, you can limit the number of + streamlines to plot for each bundle. prefix_numbering: bool - If True, the saved figures will be numbered with the current plot number. - The plot number is incremented after each call to the "plot" method. + If True, the saved figures will be numbered with the current + plot number. The plot number is incremented after each call + to the "plot" method. reset_warning: bool - If True, a warning will be displayed when calling "plot" several times - without calling "reset_data" in between to clear the data. + If True, a warning will be displayed when calling "plot"several + times without calling "reset_data" in between to clear the data. """ self.save_dir = save_dir - + # Make sure that self.save_dir is a directory and exists. if self.save_dir is not None: if not os.path.isdir(self.save_dir): @@ -60,7 +65,7 @@ def __init__(self, self.max_subset_size = max_subset_size self.prefix_numbering = prefix_numbering self.reset_warning = reset_warning - + self.current_plot_number = 0 self.should_call_reset_before_plot = False @@ -92,33 +97,43 @@ def add_data_to_plot(self, data: np.ndarray, label: str = '_'): latent_space_streamlines = data.detach().numpy() else: latent_space_streamlines = data - + if self.max_subset_size is not None: if not (self.max_subset_size > 0): - raise ValueError("A max_subset_size of an integer value greater than 0 is required.") - + raise ValueError( + "A max_subset_size of an integer value greater" + "than 0 is required.") + # Only sample if we need to reduce the number of latent streamlines # to show on the plot. if (len(latent_space_streamlines) > self.max_subset_size): - sample_indices = np.random.choice(len(latent_space_streamlines), self.max_subset_size, replace=False) - latent_space_streamlines = latent_space_streamlines[sample_indices] # (max_subset_size, 2) - + sample_indices = np.random.choice( + len(latent_space_streamlines), + self.max_subset_size, replace=False) + + latent_space_streamlines = \ + latent_space_streamlines[sample_indices] + self.bundles[label] = latent_space_streamlines def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): """ Fit and plot the t-SNE projection of the latent space streamlines. - This should be called once after adding all the data to plot using "add_data_to_plot". - + This should be called once after adding all the data to plot using + "add_data_to_plot". + Parameters ---------- figure_name_prefix: str - Name of the figure to be saved. This is just the prefix of the full file - name as it will be suffixed with the current plot number if enabled. + Name of the figure to be saved. 
This is just the prefix of the + full file name as it will be suffixed with the current plot + number if enabled. """ if self.should_call_reset_before_plot and self.reset_warning: - LOGGER.warning("You plotted another time without resetting the data. " - "The data will be accumulated, which might lead to unexpected results.") + LOGGER.warning( + "You plotted another time without resetting the data. " + "The data will be accumulated, which might lead to " + "unexpected results.") self.should_call_reset_before_plot = False elif not self.current_plot_number > 0: # Only enable the flag for the first plot. @@ -126,12 +141,14 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): self.should_call_reset_before_plot = True nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) - LOGGER.info("Plotting a total of {} streamlines".format(nb_streamlines)) + LOGGER.info( + "Plotting a total of {} streamlines".format(nb_streamlines)) bundles_indices = {} current_start = 0 for (bname, bdata) in self.bundles.items(): - bundles_indices[bname] = np.arange(current_start, current_start + bdata.shape[0]) + bundles_indices[bname] = np.arange( + current_start, current_start + bdata.shape[0]) current_start += bdata.shape[0] assert current_start == nb_streamlines @@ -159,17 +176,20 @@ def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): edgecolors='black', linewidths=0.5, ) - + if len(self.bundles) > 1: ax.legend() if self.save_dir is not None: - filename = '{}_{}.png'.format(figure_name_prefix, self.current_plot_number) \ - if self.prefix_numbering else '{}.png'.format(figure_name_prefix) + if self.prefix_numbering: + filename = '{}_{}.png'.format( + figure_name_prefix, self.current_plot_number) + else: + filename = '{}.png'.format(figure_name_prefix) + filename = os.path.join(self.save_dir, filename) fig.savefig(filename) else: plt.show() self.current_plot_number += 1 - diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 9ea9eadc..1252f27c 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -14,7 +14,8 @@ import comet_ml import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, add_verbose_arg +from scilpy.io.utils import (assert_inputs_exist, assert_outputs_exist, + add_verbose_arg) from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str @@ -40,7 +41,7 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - #training_group = add_training_args(p) + # training_group = add_training_args(p) add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") @@ -51,7 +52,7 @@ def prepare_arg_parser(): add_verbose_arg(p) # Additional arg for projects - #training_group.add_argument( + # training_group.add_argument( # '--clip_grad', type=float, default=None, # help="Value to which the gradient norms to avoid exploding gradients." 
# "\nDefault = None (not clipping).") @@ -122,13 +123,15 @@ def init_from_args(args, sub_loggers_level): if args.viz_latent_space_freq is not None: # Setup to visualize latent space - save_dir = os.path.join(args.experiments_path, args.experiment_name, 'latent_space_plots') + save_dir = os.path.join(args.experiments_path, + args.experiment_name, 'latent_space_plots') os.makedirs(save_dir, exist_ok=True) ls_viz = BundlesLatentSpaceVisualizer(save_dir, - prefix_numbering=True, - max_subset_size=1000) + prefix_numbering=True, + max_subset_size=1000) current_epoch = 0 + def visualize_latent_space(encoding): """ This is not a clean way to do it. This would require changes in the @@ -151,7 +154,8 @@ def visualize_latent_space(encoding): ls_viz.add_data_to_plot(encoding) elif changed_epoch and trainer.current_epoch % args.viz_latent_space_freq == 0: current_epoch = trainer.current_epoch - ls_viz.plot(title="Latent space at epoch {}".format(current_epoch)) + ls_viz.plot(title="Latent space at epoch {}".format( + current_epoch)) ls_viz.reset_data() ls_viz.add_data_to_plot(encoding) model.register_hook_post_encoding(visualize_latent_space) @@ -175,10 +179,12 @@ def main(): assert_outputs_exist(p, args, args.experiments_path) # Verify if a checkpoint has been saved. Else create an experiment. - if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, - "checkpoint")): - raise FileExistsError("This experiment already exists. Delete or use " - "script l2t_resume_training_from_checkpoint.py.") + if os.path.exists(os.path.join( + args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError( + "This experiment already exists. Delete or use " + "script l2t_resume_training_from_checkpoint.py.") trainer = init_from_args(args, sub_loggers_level) diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py index c4b7c0a4..2e7aad1e 100644 --- a/scripts_python/ae_visualize_bundles.py +++ b/scripts_python/ae_visualize_bundles.py @@ -15,7 +15,7 @@ add_verbose_arg) from scilpy.io.streamlines import load_tractogram_with_reference from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer @@ -42,6 +42,7 @@ def _build_arg_parser(): add_verbose_arg(p) return p + def load_bundles(p, args, files_list: list): bundles = [] for bundle_file in files_list: @@ -51,6 +52,7 @@ def load_bundles(p, args, files_list: list): bundles.append(bundle_sft) return bundles + def main(): p = _build_arg_parser() args = p.parse_args() @@ -91,12 +93,13 @@ def main(): with torch.no_grad(): for i, bundle_sft in enumerate(bundles_sft): - + # Resample streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), dtype=torch.float32, device=device) - - latent_streamlines = model.encode(streamlines).cpu().numpy() # output of (N, 32) + + latent_streamlines = model.encode( + streamlines).cpu().numpy() # output of (N, 32) ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) ls_viz.plot() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 79542aba..69a91fb3 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -12,7 +12,7 @@ from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram from dwi_ml.io_utils import 
(add_arg_existing_experiment_path, - add_memory_args) + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer from dipy.tracking.streamline import set_number_of_points @@ -73,7 +73,8 @@ def main(): # Setup vizualisation ls_viz = BundlesLatentSpaceVisualizer(save_path=args.viz_save_path) - model.register_hook_post_encoding(lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) + model.register_hook_post_encoding( + lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() @@ -87,12 +88,13 @@ def main(): with torch.no_grad(): streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle, 256)), - dtype=torch.float32, device=device) + dtype=torch.float32, device=device) tmp_outputs = model(streamlines) ls_viz.plot() - streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + streamlines_output = [tmp_outputs[i, :, :].transpose( + 0, 1).cpu().numpy() for i in range(len(bundle))] new_sft = sft.from_sft(streamlines_output, sft) save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
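
Taken together, the patches above enable the following end-to-end latent-space visualization. This is a minimal sketch rather than part of the series: the experiment path, tractogram file name, bundle label and log level are illustrative, and it assumes the hook-based API introduced in patches 12 to 14.

import logging
import os

import numpy as np
import torch
from dipy.io.streamline import load_tractogram
from dipy.tracking.streamline import set_number_of_points

from dwi_ml.models.projects.ae_models import ModelAE
from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Illustrative experiment layout; adapt to your own.
experiment_path = 'experiments/fibercup_september24'
save_dir = 'latent_space_plots'
os.makedirs(save_dir, exist_ok=True)

# 1. Load the trained autoencoder.
model = ModelAE.load_model_from_params_and_state(
    experiment_path + '/best_model', log_level='WARNING').to(device)

# 2. Route every encoding through the visualizer via the post-encoding
#    hook (forward() calls the hook right after encode()).
ls_viz = BundlesLatentSpaceVisualizer(save_dir,
                                      prefix_numbering=True,
                                      max_subset_size=1000)
model.register_hook_post_encoding(
    lambda encoded: ls_viz.add_data_to_plot(encoded.detach().cpu(),
                                            label='bundle_0'))

# 3. Load a bundle and resample to the 256 points the convolutional
#    stack expects ('bundle_0.trk' is an illustrative file name).
sft = load_tractogram('bundle_0.trk', 'same')
sft.to_vox()
sft.to_corner()
streamlines = torch.as_tensor(
    np.asarray(set_number_of_points(sft.streamlines, 256)),
    dtype=torch.float32, device=device)

# 4. One forward pass fills the visualizer; then fit t-SNE and save the
#    figure to save_dir (lt_space_0.png with prefix numbering enabled).
with torch.no_grad():
    model(streamlines)
ls_viz.plot(title='Latent space, bundle_0')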