diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index 3221d007..f87ff2b8 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -133,6 +133,7 @@ class HDF5Creator: See the doc for an example of config file. https://dwi-ml.readthedocs.io/en/latest/config_file.html """ + def __init__(self, root_folder: Path, out_hdf_filename: Path, training_subjs: List[str], validation_subjs: List[str], testing_subjs: List[str], groups_config: dict, @@ -634,8 +635,9 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, raise ValueError( "The data_per_streamline key '{}' was not found in " "the sft. Check your tractogram file.".format(dps_key)) - - logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key)) + + logging.debug( + " Include dps \"{}\" in the HDF5.".format(dps_key)) streamlines_group.create_dataset('dps_' + dps_key, data=sft.data_per_streamline[dps_key]) @@ -652,7 +654,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, def _process_one_streamline_group( self, subj_dir: Path, group: str, subj_id: str, - header: nib.Nifti1Header): + header: nib.Nifti1Header, remove_invalid=False): """ Loads and processes a group of tractograms and merges all streamlines together. diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py new file mode 100644 index 00000000..50aace59 --- /dev/null +++ b/dwi_ml/models/projects/ae_models.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +import logging +from typing import List + +import torch +from torch.nn import functional as F + +from dwi_ml.models.main_models import MainModelAbstract + + +class ModelAE(MainModelAbstract): + """ + Recurrent tracking model. + + Composed of an embedding for the imaging data's input + for the previous + direction's input, an RNN model to process the sequences, and a direction + getter model to convert the RNN outputs to the right structure, e.g. + deterministic (3D vectors) or probabilistic (based on probability + distribution parameters). + """ + def __init__(self, kernel_size, latent_space_dims, + experiment_name: str, + # Target preprocessing params for the batch loader + tracker + step_size: float = None, + compress_lines: float = False, + # Other + log_level=logging.root.level): + super().__init__(experiment_name, step_size, compress_lines, log_level) + + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims + + self.pad = torch.nn.ReflectionPad1d(1) + self.post_encoding_hooks = [] + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) + + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + + """ + Encode convolutions + """ + self.encod_conv1 = pre_pad( + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv2 = pre_pad( + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv3 = pre_pad( + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv4 = pre_pad( + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv5 = pre_pad( + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv6 = pre_pad( + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0) + ) + + """ + Decode convolutions + """ + self.decod_conv1 = pre_pad( + torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0) + ) + self.upsampl1 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv2 = pre_pad( + torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0) + ) + self.upsampl2 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv3 = pre_pad( + torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0) + ) + self.upsampl3 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv4 = pre_pad( + torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0) + ) + self.upsampl4 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv5 = pre_pad( + torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0) + ) + self.upsampl5 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv6 = pre_pad( + torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) + ) + + @property + def params_for_checkpoint(self): + """All parameters necessary to create again the same model. Will be + used in the trainer, when saving the checkpoint state. Params here + will be used to re-create the model when starting an experiment from + checkpoint. You should be able to re-create an instance of your + model with those params.""" + # p = super().params_for_checkpoint() + p = {'kernel_size': self.kernel_size, + 'latent_space_dims': self.latent_space_dims, + 'experiment_name': self.experiment_name} + return p + + @classmethod + def _load_params(cls, model_dir): + p = super()._load_params(model_dir) + p['kernel_size'] = 3 + p['latent_space_dims'] = 32 + return p + + def forward(self, + input_streamlines: List[torch.tensor], + ): + """Run the model on a batch of sequences. + + Parameters + ---------- + input_streamlines: List[torch.tensor], + Batch of streamlines. Only used if previous directions are added to + the model. Used to compute directions; its last point will not be + used. + + Returns + ------- + model_outputs : List[Tensor] + Output data, ready to be passed to either `compute_loss()` or + `get_tracking_directions()`. + """ + + encoded = self.encode(input_streamlines) + + for hook in self.post_encoding_hooks: + hook(encoded) + + x = self.decode(encoded) + return x + + def encode(self, x): + # X input shape is (batch_size, nb_points, 3) + if isinstance(x, list): + x = torch.stack(x) + + x = torch.swapaxes(x, 1, 2) # input of the network should be (N, 3, nb_points) + + h1 = F.relu(self.encod_conv1(x)) + h2 = F.relu(self.encod_conv2(h1)) + h3 = F.relu(self.encod_conv3(h2)) + h4 = F.relu(self.encod_conv4(h3)) + h5 = F.relu(self.encod_conv5(h4)) + h6 = self.encod_conv6(h5) + + self.encoder_out_size = (h6.shape[1], h6.shape[2]) + + # Flatten + h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + + fc1 = self.fc1(h7) + + return fc1 + + def decode(self, z): + fc = self.fc2(z) + fc_reshape = fc.view( + -1, self.encoder_out_size[0], self.encoder_out_size[1] + ) + h1 = F.relu(self.decod_conv1(fc_reshape)) + h2 = self.upsampl1(h1) + h3 = F.relu(self.decod_conv2(h2)) + h4 = self.upsampl2(h3) + h5 = F.relu(self.decod_conv3(h4)) + h6 = self.upsampl3(h5) + h7 = F.relu(self.decod_conv4(h6)) + h8 = self.upsampl4(h7) + h9 = F.relu(self.decod_conv5(h8)) + h10 = self.upsampl5(h9) + h11 = self.decod_conv6(h10) + + return h11 + + def compute_loss(self, model_outputs, targets, average_results=True): + targets = torch.stack(targets) + targets = torch.swapaxes(targets, 1, 2) + reconstruction_loss = torch.nn.MSELoss() + mse = reconstruction_loss(model_outputs, targets) + + return mse, 1 + + def register_hook_post_encoding(self, hook): + self.post_encoding_hooks.append(hook) + diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 743830bb..43f16136 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -58,7 +58,7 @@ logger = logging.getLogger('batch_loader_logger') -class DWIMLAbstractBatchLoader: +class DWIMLStreamlinesBatchLoader: def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, streamline_group_name: str, rng: int, split_ratio: float = 0., @@ -197,7 +197,11 @@ def _data_augmentation_sft(self, sft): self.context_subset.compress == self.model.compress_lines: logger.debug("Compression rate is the same as when creating " "the hdf5 dataset. Not compressing again.") - else: + elif self.model.step_size is not None and \ + self.model.compress_lines is not None: + logger.debug("Resample streamlines using: \n" + + "- step_size: {}\n".format(self.model.step_size) + + "- compress_lines: {}".format(self.model.compress_lines)) sft = resample_or_compress(sft, self.model.step_size, self.model.nb_points, self.model.compress_lines) @@ -314,6 +318,7 @@ def load_batch_streamlines( sft.to_vox() sft.to_corner() batch_streamlines.extend(sft.streamlines) + batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] return batch_streamlines, final_s_ids_per_subj @@ -351,7 +356,7 @@ def load_batch_connectivity_matrices( connectivity_nb_blocs, connectivity_labels) -class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader): +class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): """ Loads: input = one volume group diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 4355ad0f..2aae19ac 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -17,7 +17,7 @@ from dwi_ml.models.main_models import (MainModelAbstract, ModelWithDirectionGetter) from dwi_ml.training.batch_loaders import ( - DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput) + DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput) from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.utils.gradient_norm import compute_gradient_norm from dwi_ml.training.utils.monitoring import ( @@ -53,7 +53,7 @@ class DWIMLAbstractTrainer: def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, learning_rates: Union[List, float] = None, weight_decay: float = 0.01, optimizer: str = 'Adam', max_epochs: int = 10, @@ -78,7 +78,7 @@ def __init__(self, batch_sampler: DWIMLBatchIDSampler Instantiated class used for sampling batches. Data in batch_sampler.dataset must be already loaded. - batch_loader: DWIMLAbstractBatchLoader + batch_loader: DWIMLStreamlinesBatchLoader Instantiated class with a load_batch method able to load data associated to sampled batch ids. Data in batch_sampler.dataset must be already loaded. @@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict: def init_from_checkpoint( cls, model: MainModelAbstract, experiments_path, experiment_name, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, checkpoint_state: dict, new_patience, new_max_epochs, log_level): """ Loads checkpoint information (parameters and states) to instantiate @@ -1013,7 +1013,51 @@ def run_one_batch(self, data): Any other data returned when computing loss. Not used in the trainer, but could be useful anywhere else. """ - raise NotImplementedError + # Data interpolation has not been done yet. GPU computations are done + # here in the main thread. + targets, ids_per_subj = data + + # Dataloader always works on CPU. Sending to right device. + # (model is already moved). + targets = [s.to(self.device, non_blocking=True, dtype=torch.float) + for s in targets] + + # Getting the inputs points from the volumes. + # Uses the model's method, with the batch_loader's data. + # Possibly skipping the last point if not useful. + streamlines_f = targets + if isinstance(self.model, ModelWithDirectionGetter) and \ + not self.model.direction_getter.add_eos: + # No EOS = We don't use the last coord because it does not have an + # associated target direction. + streamlines_f = [s[:-1, :] for s in streamlines_f] + + # Possibly add noise to inputs here. + logger.debug('*** Computing forward propagation') + + # Now possibly add noise to streamlines (training / valid) + streamlines_f = self.batch_loader.add_noise_streamlines_forward( + streamlines_f, self.device) + + # Possibly computing directions twice (during forward and loss) + # but ok, shouldn't be too heavy. Easier to deal with multiple + # projects' requirements by sending whole streamlines rather + # than only directions. + model_outputs = self.model(streamlines_f) + del streamlines_f + + logger.debug('*** Computing loss') + targets = self.batch_loader.add_noise_streamlines_loss( + targets, self.device) + + results = self.model.compute_loss(model_outputs, targets, + average_results=True) + + if self.use_gpu: + log_gpu_memory_usage(logger) + + # The mean tensor is a single value. Converting to float using item(). + return results def fix_parameters(self): """ diff --git a/dwi_ml/viz/latent_streamlines.py b/dwi_ml/viz/latent_streamlines.py new file mode 100644 index 00000000..ddd95e5b --- /dev/null +++ b/dwi_ml/viz/latent_streamlines.py @@ -0,0 +1,195 @@ +import os +import logging +from typing import Union, List, Tuple +from sklearn.manifold import TSNE +import numpy as np +import torch + +import matplotlib.pyplot as plt + +LOGGER = logging.getLogger(__name__) + + +class BundlesLatentSpaceVisualizer(object): + """ + Utility class that wraps a t-SNE projection of the latent + space for multiple bundles. The usage of this class is + intented as follows: + 1. Create an instance of this class, + 2. Add the latent space streamlines for each bundle + using "add_data_to_plot" with its corresponding label. + 3. Fit and plot the t-SNE projection using the "plot" method. + + t-SNE projection can only leverage the fit_transform() with all + the data that needs to be projected at the same time since it aims + to preserve the local structure of the data. + """ + + def __init__(self, + save_dir: str = None, + fig_size: Union[List, Tuple] = None, + random_state: int = 42, + max_subset_size: int = None, + prefix_numbering: bool = False, + reset_warning: bool = True + ): + """ + Parameters + ---------- + save_path: str + Path to save the figure. If not specified, the figure will show. + fig_size: List[int] or Tuple[int] + 2-valued figure size (x, y) + random_state: List + Random state for t-SNE. + max_subset_size: int + In case of performance issues, you can limit the number of + streamlines to plot for each bundle. + prefix_numbering: bool + If True, the saved figures will be numbered with the current + plot number. The plot number is incremented after each call + to the "plot" method. + reset_warning: bool + If True, a warning will be displayed when calling "plot"several + times without calling "reset_data" in between to clear the data. + """ + self.save_dir = save_dir + + # Make sure that self.save_dir is a directory and exists. + if self.save_dir is not None: + if not os.path.isdir(self.save_dir): + raise ValueError("The save_dir should be a directory.") + + self.fig_size = fig_size + self.random_state = random_state + self.max_subset_size = max_subset_size + self.prefix_numbering = prefix_numbering + self.reset_warning = reset_warning + + self.current_plot_number = 0 + self.should_call_reset_before_plot = False + + self.tsne = TSNE(n_components=2, random_state=self.random_state) + self.bundles = {} + + def reset_data(self): + """ + Reset the data to plot. If you call plot several times without + calling this method, the data will be accumulated. + """ + # Not sure if resetting the TSNE object is necessary. + self.tsne = TSNE(n_components=2, random_state=self.random_state) + self.bundles = {} + + def add_data_to_plot(self, data: np.ndarray, label: str = '_'): + """ + Add unprojected data (no t-SNE, no PCA, etc.). + This should be directly the output of the model as a numpy array. + + Parameters + ---------- + data: str + Unprojected latent space streamlines (N, latent_space_dim). + label: str + Name of the bundle. Used for the legend. + """ + if isinstance(data, torch.Tensor): + latent_space_streamlines = data.detach().numpy() + else: + latent_space_streamlines = data + + if self.max_subset_size is not None: + if not (self.max_subset_size > 0): + raise ValueError( + "A max_subset_size of an integer value greater" + "than 0 is required.") + + # Only sample if we need to reduce the number of latent streamlines + # to show on the plot. + if (len(latent_space_streamlines) > self.max_subset_size): + sample_indices = np.random.choice( + len(latent_space_streamlines), + self.max_subset_size, replace=False) + + latent_space_streamlines = \ + latent_space_streamlines[sample_indices] + + self.bundles[label] = latent_space_streamlines + + def plot(self, title: str = "", figure_name_prefix: str = 'lt_space'): + """ + Fit and plot the t-SNE projection of the latent space streamlines. + This should be called once after adding all the data to plot using + "add_data_to_plot". + + Parameters + ---------- + figure_name_prefix: str + Name of the figure to be saved. This is just the prefix of the + full file name as it will be suffixed with the current plot + number if enabled. + """ + if self.should_call_reset_before_plot and self.reset_warning: + LOGGER.warning( + "You plotted another time without resetting the data. " + "The data will be accumulated, which might lead to " + "unexpected results.") + self.should_call_reset_before_plot = False + elif not self.current_plot_number > 0: + # Only enable the flag for the first plot. + # So that the warning above is only displayed once. + self.should_call_reset_before_plot = True + + nb_streamlines = sum(b.shape[0] for b in self.bundles.values()) + LOGGER.info( + "Plotting a total of {} streamlines".format(nb_streamlines)) + + bundles_indices = {} + current_start = 0 + for (bname, bdata) in self.bundles.items(): + bundles_indices[bname] = np.arange( + current_start, current_start + bdata.shape[0]) + current_start += bdata.shape[0] + + assert current_start == nb_streamlines + + all_streamlines = np.concatenate(list(self.bundles.values()), axis=0) + + LOGGER.info("Fitting TSNE projection.") + all_projected_streamlines = self.tsne.fit_transform(all_streamlines) + + LOGGER.info("New figure for t-SNE visualisation.") + fig, ax = plt.subplots() + ax.set_title(title) + if self.fig_size is not None: + fig.set_figheight(self.fig_size[0]) + fig.set_figwidth(self.fig_size[1]) + + for (bname, bdata) in self.bundles.items(): + bindices = bundles_indices[bname] + proj_data = all_projected_streamlines[bindices] + ax.scatter( + proj_data[:, 0], + proj_data[:, 1], + label=bname, + alpha=0.9, + edgecolors='black', + linewidths=0.5, + ) + + if len(self.bundles) > 1: + ax.legend() + + if self.save_dir is not None: + if self.prefix_numbering: + filename = '{}_{}.png'.format( + figure_name_prefix, self.current_plot_number) + else: + filename = '{}.png'.format(figure_name_prefix) + + filename = os.path.join(self.save_dir, filename) + fig.savefig(filename) + else: + plt.show() + + self.current_plot_number += 1 diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py new file mode 100755 index 00000000..1252f27c --- /dev/null +++ b/scripts_python/ae_train_model.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Train a model for Autoencoders +""" +import argparse +import logging +import os + +# comet_ml not used, but comet_ml requires to be imported before torch. +# See bug report here https://github.com/Lightning-AI/lightning/issues/5829 +# Importing now to solve issues later. +import comet_ml +import torch + +from scilpy.io.utils import (assert_inputs_exist, assert_outputs_exist, + add_verbose_arg) + +from dwi_ml.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.experiment_utils.prints import format_dict_to_str +from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.io_utils import add_memory_args +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.training.trainers import DWIMLAbstractTrainer +from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) +from dwi_ml.training.utils.trainer import add_training_args +from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer +from dwi_ml.training.utils.experiment import ( + add_mandatory_args_experiment_and_hdf5_path) +from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ + format_lr + + +def prepare_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + add_mandatory_args_experiment_and_hdf5_path(p) + add_args_batch_sampler(p) + add_args_batch_loader(p) + # training_group = add_training_args(p) + add_training_args(p) + p.add_argument('streamline_group_name', + help="Name of the group in hdf5") + p.add_argument('--viz_latent_space_freq', type=int, default=None, + help="Frequency at which to visualize latent space.\n" + "This is expressed in number of epochs.") + add_memory_args(p, add_lazy_options=True, add_rng=True) + add_verbose_arg(p) + + # Additional arg for projects + # training_group.add_argument( + # '--clip_grad', type=float, default=None, + # help="Value to which the gradient norms to avoid exploding gradients." + # "\nDefault = None (not clipping).") + + return p + + +def init_from_args(args, sub_loggers_level): + torch.manual_seed(args.rng) # Set torch seed + + # Prepare the dataset + dataset = prepare_multisubjectdataset(args, load_testing=False, + log_level=sub_loggers_level) + + # Preparing the model + # (Direction getter) + # (Nb features) + # Final model + with Timer("\n\nPreparing model", newline=True, color='yellow'): + # INPUTS: verifying args + model = ModelAE( + experiment_name=args.experiment_name, + step_size=None, compress_lines=None, + kernel_size=3, latent_space_dims=32, + log_level=sub_loggers_level) + + logging.info("AEmodel final parameters:" + + format_dict_to_str(model.params_for_checkpoint)) + + logging.info("Computed parameters:" + + format_dict_to_str(model.computed_params_for_display)) + + # Preparing the batch samplers + batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) + # Preparing the batch loader. + with Timer("\nPreparing batch loader...", newline=True, color='pink'): + batch_loader = DWIMLStreamlinesBatchLoader( + dataset=dataset, model=model, + streamline_group_name=args.streamline_group_name, + # OTHER + rng=args.rng, log_level=sub_loggers_level) + + logging.info("Loader user-defined parameters: " + + format_dict_to_str(batch_loader.params_for_checkpoint)) + + # Instantiate trainer + with Timer("\n\nPreparing trainer", newline=True, color='red'): + lr = format_lr(args.learning_rate) + trainer = DWIMLAbstractTrainer( + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, + # COMET + comet_project=args.comet_project, + comet_workspace=args.comet_workspace, + # TRAINING + learning_rates=lr, weight_decay=args.weight_decay, + optimizer=args.optimizer, max_epochs=args.max_epochs, + max_batches_per_epoch_training=args.max_batches_per_epoch_training, + max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, + patience=args.patience, patience_delta=args.patience_delta, + from_checkpoint=False, clip_grad=args.clip_grad, + # MEMORY + nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, + log_level=sub_loggers_level) + logging.info("Trainer params : " + + format_dict_to_str(trainer.params_for_checkpoint)) + + if args.viz_latent_space_freq is not None: + # Setup to visualize latent space + save_dir = os.path.join(args.experiments_path, + args.experiment_name, 'latent_space_plots') + os.makedirs(save_dir, exist_ok=True) + + ls_viz = BundlesLatentSpaceVisualizer(save_dir, + prefix_numbering=True, + max_subset_size=1000) + current_epoch = 0 + + def visualize_latent_space(encoding): + """ + This is not a clean way to do it. This would require changes in the + trainer to allow for a callback system where we could register a + function to be called at the end of each epoch to plot the latent space + of the data accumulated during the epoch (at each batch). + + Also, using this method, the latent space of the last epoch will not be + plotted. We would need to calculate which batch step would be the last in + the epoch and then plot accordingly. + """ + nonlocal current_epoch, trainer, ls_viz + + # Only execute the following if we are in training + if not trainer.model.context == 'training': + return + + changed_epoch = current_epoch != trainer.current_epoch + if not changed_epoch: + ls_viz.add_data_to_plot(encoding) + elif changed_epoch and trainer.current_epoch % args.viz_latent_space_freq == 0: + current_epoch = trainer.current_epoch + ls_viz.plot(title="Latent space at epoch {}".format( + current_epoch)) + ls_viz.reset_data() + ls_viz.add_data_to_plot(encoding) + model.register_hook_post_encoding(visualize_latent_space) + + return trainer + + +def main(): + p = prepare_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Check that all files exist + assert_inputs_exist(p, [args.hdf5_file]) + assert_outputs_exist(p, args, args.experiments_path) + + # Verify if a checkpoint has been saved. Else create an experiment. + if os.path.exists(os.path.join( + args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError( + "This experiment already exists. Delete or use " + "script l2t_resume_training_from_checkpoint.py.") + + trainer = init_from_args(args, sub_loggers_level) + + run_experiment(trainer) + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_bundles.py b/scripts_python/ae_visualize_bundles.py new file mode 100644 index 00000000..2e7aad1e --- /dev/null +++ b/scripts_python/ae_visualize_bundles.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import pathlib +import torch +import numpy as np +from glob import glob +from os.path import expanduser +from dipy.tracking.streamline import set_number_of_points + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_bundles', + help="The 'glob' path to several bundles identified by their file name." + "e.g. FiberCupGroundTruth_filtered_bundle_0.tck") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + + +def load_bundles(p, args, files_list: list): + bundles = [] + for bundle_file in files_list: + bundle_sft = load_tractogram_with_reference(p, args, bundle_file) + bundle_sft.to_vox() + bundle_sft.to_corner() + bundles.append(bundle_sft) + return bundles + + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, []) + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 1. Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + expanded = expanduser(args.in_bundles) + bundles_files = glob(expanded) + if isinstance(bundles_files, str): + bundles_files = [bundles_files] + + bundles_label = [pathlib.Path(l).stem for l in bundles_files] + bundles_sft = load_bundles(p, args, bundles_files) + + logging.info("Running model to compute loss") + + ls_viz = BundlesLatentSpaceVisualizer( + save_path="/home/local/USHERBROOKE/levj1404/Documents/dwi_ml/data/out.png" + ) + + with torch.no_grad(): + for i, bundle_sft in enumerate(bundles_sft): + + # Resample + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle_sft.streamlines, 256)), + dtype=torch.float32, device=device) + + latent_streamlines = model.encode( + streamlines).cpu().numpy() # output of (N, 32) + ls_viz.add_data_to_plot(latent_streamlines, label=bundles_label[i]) + + ls_viz.plot() + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py new file mode 100644 index 00000000..69a91fb3 --- /dev/null +++ b/scripts_python/ae_visualize_streamlines.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging + +import torch +import numpy as np +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dipy.io.streamline import save_tractogram +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.viz.latent_streamlines import BundlesLatentSpaceVisualizer +from dipy.tracking.streamline import set_number_of_points + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + p.add_argument('out_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + p.add_argument('--viz_save_path', type=str, default=None, + help="Path to save the figure. If not specified, the figure will be shown.") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, args.out_tractogram) + + # Device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # 1. Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + + # Setup vizualisation + ls_viz = BundlesLatentSpaceVisualizer(save_path=args.viz_save_path) + model.register_hook_post_encoding( + lambda encoded_data: ls_viz.add_data_to_plot(encoded_data)) + + sft = load_tractogram_with_reference(p, args, args.in_tractogram) + sft.to_vox() + sft.to_corner() + bundle = sft.streamlines[0:5000] + + logging.info("Running model to compute loss") + + new_sft = sft.from_sft(bundle, sft) + save_tractogram(new_sft, 'orig_5000.trk') + + with torch.no_grad(): + streamlines = torch.as_tensor(np.asarray(set_number_of_points(bundle, 256)), + dtype=torch.float32, device=device) + tmp_outputs = model(streamlines) + + ls_viz.plot() + + streamlines_output = [tmp_outputs[i, :, :].transpose( + 0, 1).cpu().numpy() for i in range(len(bundle))] + new_sft = sft.from_sft(streamlines_output, sft) + save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) + + +if __name__ == '__main__': + main()