From cd9275af952748b8c3d6761bca8842135f6aa89d Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 23 Nov 2023 13:46:41 -0500 Subject: [PATCH 01/24] add ae - finta --- dwi_ml/data/hdf5/hdf5_creation.py | 23 +-- .../streamlines/data_augmentation.py | 2 +- .../processing/streamlines/post_processing.py | 2 +- dwi_ml/models/main_models.py | 7 +- dwi_ml/models/projects/ae_models.py | 182 ++++++++++++++++++ dwi_ml/training/batch_loaders.py | 7 +- dwi_ml/training/trainers.py | 55 +++++- scripts_python/ae_train_model.py | 148 ++++++++++++++ 8 files changed, 406 insertions(+), 20 deletions(-) create mode 100644 dwi_ml/models/projects/ae_models.py create mode 100755 scripts_python/ae_train_model.py diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index e46e534b..a01194a3 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -593,7 +593,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, def _process_one_streamline_group( self, subj_dir: Path, group: str, subj_id: str, - header: nib.Nifti1Header): + header: nib.Nifti1Header, remove_invalid=False): """ Loads and processes a group of tractograms and merges all streamlines together. @@ -681,12 +681,13 @@ def _process_one_streamline_group( # with Path. save_tractogram(final_sft, str(output_fname)) - # Removing invalid streamlines - logging.debug(' *Total: {:,.0f} streamlines. Now removing ' - 'invalid streamlines.'.format(len(final_sft))) - final_sft.remove_invalid_streamlines() - logging.info(" Final number of streamlines: {:,.0f}." - .format(len(final_sft))) + if remove_invalid: + # Removing invalid streamlines + logging.debug(' *Total: {:,.0f} streamlines. Now removing ' + 'invalid streamlines.'.format(len(final_sft))) + final_sft.remove_invalid_streamlines() + logging.info(" Final number of streamlines: {:,.0f}." + .format(len(final_sft))) conn_matrix = None conn_info = None @@ -741,10 +742,10 @@ def _load_and_process_sft(self, tractogram_file, tractogram_name, header): "We do not support file's type: {}. We only support .trk " "and .tck files.".format(tractogram_file)) if file_extension == '.trk': - if not is_header_compatible(str(tractogram_file), header): - raise ValueError("Streamlines group is not compatible with " - "volume groups\n ({})" - .format(tractogram_file)) + if header: + if not is_header_compatible(str(tractogram_file), header): + raise ValueError("Streamlines group is not compatible with " + "volume groups\n ({})".format(tractogram_file)) # overriding given header. header = 'same' diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py index c3cedb9b..96d997d6 100644 --- a/dwi_ml/data/processing/streamlines/data_augmentation.py +++ b/dwi_ml/data/processing/streamlines/data_augmentation.py @@ -7,7 +7,7 @@ from dipy.io.stateful_tractogram import StatefulTractogram from nibabel.streamlines.tractogram import (PerArrayDict, PerArraySequenceDict) import numpy as np -from scilpy.tractograms.streamline_operations import resample_streamlines_step_size +from scilpy.tracking.tools import resample_streamlines_step_size from scilpy.utils.streamlines import compress_sft diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index 67bed921..9f319451 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -5,7 +5,7 @@ import numpy as np import torch -from scilpy.tractograms.uncompress import uncompress +from scilpy.tractanalysis.uncompress import uncompress from scilpy.tractanalysis.tools import \ extract_longest_segments_from_profile as segmenting_func diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 14576222..f9043a19 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -16,8 +16,7 @@ from dwi_ml.data.processing.space.neighborhood import prepare_neighborhood_vectors from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.io_utils import add_resample_or_compress_arg -from dwi_ml.models.direction_getter_models import keys_to_direction_getters, \ - AbstractDirectionGetterModel +from dwi_ml.models.direction_getter_models import keys_to_direction_getters from dwi_ml.models.embeddings import keys_to_embeddings, NNEmbedding, NoEmbedding from dwi_ml.models.utils.direction_getters import add_direction_getter_args @@ -72,7 +71,7 @@ def __init__(self, experiment_name: str, # To tell our trainer what to send to the forward / loss methods. self.forward_uses_streamlines = False self.loss_uses_streamlines = False - + # To tell our batch loader how to resample streamlines during training # (should also be the step size during tractography). if step_size and compress_lines: @@ -208,7 +207,7 @@ def _load_state(cls, model_dir): def forward(self, *inputs, **kw): raise NotImplementedError - def compute_loss(self, *model_outputs, **kw): + def compute_loss(self, model_outputs, targets): raise NotImplementedError diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py new file mode 100644 index 00000000..fb750d6d --- /dev/null +++ b/dwi_ml/models/projects/ae_models.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +import logging +from typing import List + +import torch +from torch.nn import functional as F + +from dwi_ml.models.main_models import MainModelAbstract + + +class ModelAE(MainModelAbstract): + """ + Recurrent tracking model. + + Composed of an embedding for the imaging data's input + for the previous + direction's input, an RNN model to process the sequences, and a direction + getter model to convert the RNN outputs to the right structure, e.g. + deterministic (3D vectors) or probabilistic (based on probability + distribution parameters). + """ + def __init__(self, kernel_size, latent_space_dims, + experiment_name: str, + # Target preprocessing params for the batch loader + tracker + step_size: float = None, + compress_lines: float = False, + # Other + log_level=logging.root.level): + super().__init__(experiment_name, step_size, compress_lines, log_level) + + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims + + self.pad = torch.nn.ReflectionPad1d(1) + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) + + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + + """ + Encode convolutions + """ + self.encod_conv1 = pre_pad( + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv2 = pre_pad( + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv3 = pre_pad( + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv4 = pre_pad( + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv5 = pre_pad( + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv6 = pre_pad( + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0) + ) + + """ + Decode convolutions + """ + self.decod_conv1 = pre_pad( + torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0) + ) + self.upsampl1 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv2 = pre_pad( + torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0) + ) + self.upsampl2 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv3 = pre_pad( + torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0) + ) + self.upsampl3 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv4 = pre_pad( + torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0) + ) + self.upsampl4 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv5 = pre_pad( + torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0) + ) + self.upsampl5 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv6 = pre_pad( + torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) + ) + + self.forward_uses_streamlines = True + self.loss_uses_streamlines = True + + def forward(self, + input_streamlines: List[torch.tensor], + ): + """Run the model on a batch of sequences. + + Parameters + ---------- + input_streamlines: List[torch.tensor], + Batch of streamlines. Only used if previous directions are added to + the model. Used to compute directions; its last point will not be + used. + + Returns + ------- + model_outputs : List[Tensor] + Output data, ready to be passed to either `compute_loss()` or + `get_tracking_directions()`. + """ + input_streamlines = torch.stack(input_streamlines) + input_streamlines = torch.swapaxes(input_streamlines, 1, 2) + + x = self.decode(self.encode(input_streamlines)) + return x + + def encode(self, x): + h1 = F.relu(self.encod_conv1(x)) + h2 = F.relu(self.encod_conv2(h1)) + h3 = F.relu(self.encod_conv3(h2)) + h4 = F.relu(self.encod_conv4(h3)) + h5 = F.relu(self.encod_conv5(h4)) + h6 = self.encod_conv6(h5) + + self.encoder_out_size = (h6.shape[1], h6.shape[2]) + + # Flatten + h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1]) + + fc1 = self.fc1(h7) + + return fc1 + + def decode(self, z): + fc = self.fc2(z) + fc_reshape = fc.view( + -1, self.encoder_out_size[0], self.encoder_out_size[1] + ) + h1 = F.relu(self.decod_conv1(fc_reshape)) + h2 = self.upsampl1(h1) + h3 = F.relu(self.decod_conv2(h2)) + h4 = self.upsampl2(h3) + h5 = F.relu(self.decod_conv3(h4)) + h6 = self.upsampl3(h5) + h7 = F.relu(self.decod_conv4(h6)) + h8 = self.upsampl4(h7) + h9 = F.relu(self.decod_conv5(h8)) + h10 = self.upsampl5(h9) + h11 = self.decod_conv6(h10) + + return h11 + + def compute_loss(self, model_outputs, targets, average_results=True): + print("COMPARISON\n") + targets = torch.stack(targets) + targets = torch.swapaxes(targets, 1, 2) + print(targets[0, :, 0:5]) + print(model_outputs[0, :, 0:5]) + reconstruction_loss = torch.nn.MSELoss(reduction="mean") + mse = reconstruction_loss(model_outputs, targets) + + # loss_function_vae + # See Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) + # kld = torch.sum(kld_element).__mul__(-0.5) + + return mse, 1 diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 0173caaf..5c4acbb9 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -190,7 +190,11 @@ def _data_augmentation_sft(self, sft): self.context_subset.compress == self.model.compress_lines: logger.debug("Compression rate is the same as when creating " "the hdf5 dataset. Not compressing again.") - else: + elif self.model.step_size is not None and \ + self.model.compress_lines is not None: + logger.debug("Resample streamlines using: \n" + + "- step_size: {}\n".format(self.model.step_size) + + "- compress_lines: {}".format(self.model.compress_lines)) sft = resample_or_compress(sft, self.model.step_size, self.model.compress_lines) @@ -306,6 +310,7 @@ def load_batch_streamlines( sft.to_vox() sft.to_corner() batch_streamlines.extend(sft.streamlines) + batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] return batch_streamlines, final_s_ids_per_subj diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 297cdeeb..f7535c6e 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -4,7 +4,7 @@ import logging import os import shutil -from typing import Union, List, Tuple +from typing import Union, List from comet_ml import (Experiment as CometExperiment, ExistingExperiment) import numpy as np @@ -972,7 +972,58 @@ def run_one_batch(self, data): the list of all streamlines (if we concatenate all sfts' streamlines) """ - raise NotImplementedError + # Data interpolation has not been done yet. GPU computations are done + # here in the main thread. + targets, ids_per_subj = data + + # Dataloader always works on CPU. Sending to right device. + # (model is already moved). + targets = [s.to(self.device, non_blocking=True, dtype=torch.float) + for s in targets] + + # Getting the inputs points from the volumes. + # Uses the model's method, with the batch_loader's data. + # Possibly skipping the last point if not useful. + streamlines_f = targets + if isinstance(self.model, ModelWithDirectionGetter) and \ + not self.model.direction_getter.add_eos: + # No EOS = We don't use the last coord because it does not have an + # associated target direction. + streamlines_f = [s[:-1, :] for s in streamlines_f] + + # Possibly add noise to inputs here. + logger.debug('*** Computing forward propagation') + if self.model.forward_uses_streamlines: + # Now possibly add noise to streamlines (training / valid) + streamlines_f = self.batch_loader.add_noise_streamlines_forward( + streamlines_f, self.device) + + # Possibly computing directions twice (during forward and loss) + # but ok, shouldn't be too heavy. Easier to deal with multiple + # projects' requirements by sending whole streamlines rather + # than only directions. + model_outputs = self.model(streamlines_f) + del streamlines_f + else: + del streamlines_f + model_outputs = self.model() + + logger.debug('*** Computing loss') + if self.model.loss_uses_streamlines: + targets = self.batch_loader.add_noise_streamlines_loss( + targets, self.device) + + results = self.model.compute_loss(model_outputs, targets, + average_results=True) + else: + results = self.model.compute_loss(model_outputs, + average_results=True) + + if self.use_gpu: + log_gpu_memory_usage(logger) + + # The mean tensor is a single value. Converting to float using item(). + return results def fix_parameters(self): """ diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py new file mode 100755 index 00000000..4120ce88 --- /dev/null +++ b/scripts_python/ae_train_model.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Train a model for Autoencoders +""" +import argparse +import logging +import os + +# comet_ml not used, but comet_ml requires to be imported before torch. +# See bug report here https://github.com/Lightning-AI/lightning/issues/5829 +# Importing now to solve issues later. +import comet_ml +import torch + +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist + +from dwi_ml.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.experiment_utils.prints import format_dict_to_str +from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.io_utils import add_logging_arg, add_memory_args +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.training.trainers import DWIMLAbstractTrainer +from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) +from dwi_ml.training.batch_loaders import DWIMLAbstractBatchLoader +from dwi_ml.training.utils.experiment import ( + add_mandatory_args_experiment_and_hdf5_path) +from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ + format_lr + + +def prepare_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + add_mandatory_args_experiment_and_hdf5_path(p) + add_args_batch_sampler(p) + add_args_batch_loader(p) + training_group = add_training_args(p) + p.add_argument('streamline_group_name', + help="Name of the group in hdf5") + add_memory_args(p, add_lazy_options=True, add_rng=True) + add_logging_arg(p) + + # Additional arg for projects + training_group.add_argument( + '--clip_grad', type=float, default=None, + help="Value to which the gradient norms to avoid exploding gradients." + "\nDefault = None (not clipping).") + + return p + + +def init_from_args(args, sub_loggers_level): + torch.manual_seed(args.rng) # Set torch seed + + # Prepare the dataset + dataset = prepare_multisubjectdataset(args, load_testing=False, + log_level=sub_loggers_level) + + # Preparing the model + # (Direction getter) + # (Nb features) + # Final model + with Timer("\n\nPreparing model", newline=True, color='yellow'): + # INPUTS: verifying args + model = ModelAE( + experiment_name=args.experiment_name, + step_size=args.step_size, compress_lines=args.compress, + kernel_size=3, latent_space_dims=32, + log_level=sub_loggers_level) + + logging.info("AEmodel final parameters:" + + format_dict_to_str(model.params_for_checkpoint)) + + logging.info("Computed parameters:" + + format_dict_to_str(model.computed_params_for_display)) + + # Preparing the batch samplers + batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) + # Preparing the batch loader. + with Timer("\nPreparing batch loader...", newline=True, color='pink'): + batch_loader = DWIMLAbstractBatchLoader( + dataset=dataset, model=model, + streamline_group_name=args.streamline_group_name, + # OTHER + rng=args.rng, log_level=sub_loggers_level) + + logging.info("Loader user-defined parameters: " + + format_dict_to_str(batch_loader.params_for_checkpoint)) + + # Instantiate trainer + with Timer("\n\nPreparing trainer", newline=True, color='red'): + lr = format_lr(args.learning_rate) + trainer = DWIMLAbstractTrainer( + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, + # COMET + comet_project=args.comet_project, + comet_workspace=args.comet_workspace, + # TRAINING + learning_rates=lr, weight_decay=args.weight_decay, + optimizer=args.optimizer, max_epochs=args.max_epochs, + max_batches_per_epoch_training=args.max_batches_per_epoch_training, + max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, + patience=args.patience, patience_delta=args.patience_delta, + from_checkpoint=False, clip_grad=args.clip_grad, + # MEMORY + nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, + log_level=args.logging) + logging.info("Trainer params : " + + format_dict_to_str(trainer.params_for_checkpoint)) + + return trainer + + +def main(): + p = prepare_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.logging + if args.logging == 'DEBUG': + sub_loggers_level = 'INFO' + + logging.getLogger().setLevel(level=logging.INFO) + + # Check that all files exist + assert_inputs_exist(p, [args.hdf5_file]) + assert_outputs_exist(p, args, args.experiments_path) + + # Verify if a checkpoint has been saved. Else create an experiment. + if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError("This experiment already exists. Delete or use " + "script l2t_resume_training_from_checkpoint.py.") + + trainer = init_from_args(args, sub_loggers_level) + + run_experiment(trainer) + + +if __name__ == '__main__': + main() From 31e2e85d32cb393610fc1c4dab7c2d4e9098ef49 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 28 Nov 2023 09:17:25 -0500 Subject: [PATCH 02/24] modif with em --- dwi_ml/models/projects/ae_models.py | 30 ++++++++++++++++++++++++++--- dwi_ml/testing/visu_loss.py | 3 +-- scripts_python/ae_train_model.py | 2 +- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index fb750d6d..7acaaa40 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -101,6 +101,27 @@ def pre_pad(m): self.forward_uses_streamlines = True self.loss_uses_streamlines = True + @property + def params_for_checkpoint(self): + """All parameters necessary to create again the same model. Will be + used in the trainer, when saving the checkpoint state. Params here + will be used to re-create the model when starting an experiment from + checkpoint. You should be able to re-create an instance of your + model with those params.""" + #p = super().params_for_checkpoint() + p = {'kernel_size': self.kernel_size, + 'latent_space_dims': self.latent_space_dims, + 'experiment_name': self.experiment_name} + return p + + @classmethod + def _load_params(cls, model_dir): + p = super()._load_params(model_dir) + p['kernel_size'] = 3 + p['latent_space_dims'] = 32 + return p + + def forward(self, input_streamlines: List[torch.tensor], ): @@ -119,13 +140,16 @@ def forward(self, Output data, ready to be passed to either `compute_loss()` or `get_tracking_directions()`. """ - input_streamlines = torch.stack(input_streamlines) - input_streamlines = torch.swapaxes(input_streamlines, 1, 2) + x = self.decode(self.encode(input_streamlines)) return x def encode(self, x): + # x: list of tensors + x = torch.stack(x) + x = torch.swapaxes(x, 1, 2) + h1 = F.relu(self.encod_conv1(x)) h2 = F.relu(self.encod_conv2(h1)) h3 = F.relu(self.encod_conv3(h2)) @@ -167,7 +191,7 @@ def compute_loss(self, model_outputs, targets, average_results=True): targets = torch.swapaxes(targets, 1, 2) print(targets[0, :, 0:5]) print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="mean") + reconstruction_loss = torch.nn.MSELoss(reduction="sum") mse = reconstruction_loss(model_outputs, targets) # loss_function_vae diff --git a/dwi_ml/testing/visu_loss.py b/dwi_ml/testing/visu_loss.py index 680d22b1..f867b6fe 100644 --- a/dwi_ml/testing/visu_loss.py +++ b/dwi_ml/testing/visu_loss.py @@ -15,8 +15,7 @@ from matplotlib import pyplot as plt from scilpy.io.utils import add_overwrite_arg -from dwi_ml.io_utils import add_logging_arg, add_arg_existing_experiment_path -from dwi_ml.io_utils import add_memory_args +from dwi_ml.io_utils import add_arg_existing_experiment_path, add_logging_arg, add_memory_args from dwi_ml.models.main_models import ModelWithDirectionGetter from dwi_ml.testing.utils import add_args_testing_subj_hdf5 diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 4120ce88..9b58e77d 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -68,7 +68,7 @@ def init_from_args(args, sub_loggers_level): # INPUTS: verifying args model = ModelAE( experiment_name=args.experiment_name, - step_size=args.step_size, compress_lines=args.compress, + step_size=None, compress_lines=None, kernel_size=3, latent_space_dims=32, log_level=sub_loggers_level) From 346c383e5bc0909dcfa8bb21b0683595562b04a0 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 28 Nov 2023 09:17:38 -0500 Subject: [PATCH 03/24] add visu --- scripts_python/ae_visualize_streamlines.py | 111 +++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 scripts_python/ae_visualize_streamlines.py diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py new file mode 100644 index 00000000..a2cb8881 --- /dev/null +++ b/scripts_python/ae_visualize_streamlines.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging +import os.path + +import torch + +from scilpy.io.utils import add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_reference_arg +from scilpy.io.streamlines import load_tractogram_with_reference +from dipy.io.streamline import save_tractogram + +from dwi_ml.io_utils import add_logging_arg, add_arg_existing_experiment_path, add_memory_args +from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from dwi_ml.testing.testers import TesterOneInput +from dwi_ml.testing.visu_loss import \ + prepare_args_visu_loss, pick_a_few, run_visu_save_colored_displacement +from dwi_ml.models.projects.ae_models import ModelAE +from dwi_ml.testing.testers import Tester +from dwi_ml.testing.utils import add_args_testing_subj_hdf5 + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + #add_args_testing_subj_hdf5(p) + + p.add_argument('in_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + p.add_argument('out_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_logging_arg(p) + return p + + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Loggers + sub_logger_level = args.logging.upper() + if sub_logger_level == 'DEBUG': + sub_logger_level = 'INFO' + logging.getLogger().setLevel(level=args.logging) + + # Verify output names + # Check experiment_path exists and best_model folder exists + #assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, args.out_tractogram) + + # Device + device = (torch.device('cuda') if torch.cuda.is_available() and + args.use_gpu else None) + + # 1. Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_logger_level) + #model.set_context('training') + # 2. Compute loss + #tester = TesterOneInput(args.experiment_path, model, args.batch_size, device) + #tester = Tester(args.experiment_path, model, args.batch_size, device) + #sft = tester.load_and_format_data(args.subj_id, args.hdf5_file, args.subset) + + sft = load_tractogram_with_reference(p, args, args.in_tractogram) + sft.to_vox() + sft.to_corner() + bundle = sft.streamlines[0:5000] + + logging.info("Running model to compute loss") + + new_sft = sft.from_sft(bundle, sft) + save_tractogram(new_sft, 'orig_5000.trk') + + with torch.no_grad(): + streamlines = [ + torch.as_tensor(s, dtype=torch.float32, device=device) + for s in bundle] + tmp_outputs = model(streamlines) + #latent = model.encode(streamlines) + + streamlines_output = [tmp_outputs[i, :, :].transpose(0,1).cpu().numpy() for i in range(len(bundle))] + + #print(streamlines_output[0].shape) + new_sft = sft.from_sft(streamlines_output, sft) + save_tractogram(new_sft, args.out_tractogram) + + #latent_output = [s.cpu().numpy() for s in latent] + + + + #outputs, losses = tester.run_model_on_sft( + # sft, uncompress_loss=args.uncompress_loss, + # force_compress_loss=args.force_compress_loss, + # weight_with_angle=args.weight_with_angle) + +if __name__ == '__main__': + main() From a341891052614bf2563f730c7d2a9d7b150101dd Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Fri, 13 Sep 2024 15:14:28 -0400 Subject: [PATCH 04/24] fix pep8 --- dwi_ml/models/projects/ae_models.py | 4 +- scripts_python/ae_visualize_streamlines.py | 50 +++++++++++----------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 7acaaa40..66986da2 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -108,7 +108,7 @@ def params_for_checkpoint(self): will be used to re-create the model when starting an experiment from checkpoint. You should be able to re-create an instance of your model with those params.""" - #p = super().params_for_checkpoint() + # p = super().params_for_checkpoint() p = {'kernel_size': self.kernel_size, 'latent_space_dims': self.latent_space_dims, 'experiment_name': self.experiment_name} @@ -120,7 +120,6 @@ def _load_params(cls, model_dir): p['kernel_size'] = 3 p['latent_space_dims'] = 32 return p - def forward(self, input_streamlines: List[torch.tensor], @@ -141,7 +140,6 @@ def forward(self, `get_tracking_directions()`. """ - x = self.decode(self.encode(input_streamlines)) return x diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index a2cb8881..1d2e0f3f 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -2,22 +2,20 @@ # -*- coding: utf-8 -*- import argparse import logging -import os.path import torch -from scilpy.io.utils import add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_reference_arg +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg) from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram -from dwi_ml.io_utils import add_logging_arg, add_arg_existing_experiment_path, add_memory_args -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.testing.testers import TesterOneInput -from dwi_ml.testing.visu_loss import \ - prepare_args_visu_loss, pick_a_few, run_visu_save_colored_displacement +from dwi_ml.io_utils import (add_logging_arg, + add_arg_existing_experiment_path, + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE -from dwi_ml.testing.testers import Tester -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 + def _build_arg_parser(): p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, @@ -25,7 +23,7 @@ def _build_arg_parser(): # Mandatory # Should only be False for debugging tests. add_arg_existing_experiment_path(p) - #add_args_testing_subj_hdf5(p) + # Add_args_testing_subj_hdf5(p) p.add_argument('in_tractogram', help="If set, saves the tractogram with the loss per point " @@ -58,7 +56,7 @@ def main(): # Verify output names # Check experiment_path exists and best_model folder exists - #assert_inputs_exist(p, args.hdf5_file) + # Assert_inputs_exist(p, args.hdf5_file) assert_outputs_exist(p, args, args.out_tractogram) # Device @@ -69,11 +67,16 @@ def main(): logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( args.experiment_path + '/best_model', log_level=sub_logger_level) - #model.set_context('training') + # model.set_context('training') # 2. Compute loss - #tester = TesterOneInput(args.experiment_path, model, args.batch_size, device) - #tester = Tester(args.experiment_path, model, args.batch_size, device) - #sft = tester.load_and_format_data(args.subj_id, args.hdf5_file, args.subset) + # tester = TesterOneInput(args.experiment_path, + # model, + # args.batch_size, + # device) + # tester = Tester(args.experiment_path, model, args.batch_size, device) + # sft = tester.load_and_format_data(args.subj_id, + # args.hdf5_file, + # args.subset) sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() @@ -90,22 +93,21 @@ def main(): torch.as_tensor(s, dtype=torch.float32, device=device) for s in bundle] tmp_outputs = model(streamlines) - #latent = model.encode(streamlines) - - streamlines_output = [tmp_outputs[i, :, :].transpose(0,1).cpu().numpy() for i in range(len(bundle))] - - #print(streamlines_output[0].shape) - new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram) + # latent = model.encode(streamlines) - #latent_output = [s.cpu().numpy() for s in latent] + streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] + # print(streamlines_output[0].shape) + new_sft = sft.from_sft(streamlines_output, sft) + save_tractogram(new_sft, args.out_tractogram) + # latent_output = [s.cpu().numpy() for s in latent] - #outputs, losses = tester.run_model_on_sft( + # outputs, losses = tester.run_model_on_sft( # sft, uncompress_loss=args.uncompress_loss, # force_compress_loss=args.force_compress_loss, # weight_with_angle=args.weight_with_angle) + if __name__ == '__main__': main() From 72789f3fa7a04a18157f2264bad9a36f74164ed3 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Fri, 13 Sep 2024 15:21:07 -0400 Subject: [PATCH 05/24] answer em comments from nov 2023 --- dwi_ml/models/projects/ae_models.py | 3 --- dwi_ml/training/batch_loaders.py | 4 ++-- dwi_ml/training/trainers.py | 37 ++++++++++++----------------- scripts_python/ae_train_model.py | 4 ++-- 4 files changed, 19 insertions(+), 29 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 66986da2..0818b6ad 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -98,9 +98,6 @@ def pre_pad(m): torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) ) - self.forward_uses_streamlines = True - self.loss_uses_streamlines = True - @property def params_for_checkpoint(self): """All parameters necessary to create again the same model. Will be diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index d82bed39..9ac3afaf 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -58,7 +58,7 @@ logger = logging.getLogger('batch_loader_logger') -class DWIMLAbstractBatchLoader: +class DWIMLStreamlinesBatchLoader: def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, streamline_group_name: str, rng: int, split_ratio: float = 0., @@ -355,7 +355,7 @@ def load_batch_connectivity_matrices( connectivity_nb_blocs, connectivity_labels) -class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader): +class DWIMLStreamlinesBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): """ Loads: input = one volume group diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index a86442db..fa65cee1 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1034,31 +1034,24 @@ def run_one_batch(self, data): # Possibly add noise to inputs here. logger.debug('*** Computing forward propagation') - if self.model.forward_uses_streamlines: - # Now possibly add noise to streamlines (training / valid) - streamlines_f = self.batch_loader.add_noise_streamlines_forward( - streamlines_f, self.device) - - # Possibly computing directions twice (during forward and loss) - # but ok, shouldn't be too heavy. Easier to deal with multiple - # projects' requirements by sending whole streamlines rather - # than only directions. - model_outputs = self.model(streamlines_f) - del streamlines_f - else: - del streamlines_f - model_outputs = self.model() + + # Now possibly add noise to streamlines (training / valid) + streamlines_f = self.batch_loader.add_noise_streamlines_forward( + streamlines_f, self.device) + + # Possibly computing directions twice (during forward and loss) + # but ok, shouldn't be too heavy. Easier to deal with multiple + # projects' requirements by sending whole streamlines rather + # than only directions. + model_outputs = self.model(streamlines_f) + del streamlines_f logger.debug('*** Computing loss') - if self.model.loss_uses_streamlines: - targets = self.batch_loader.add_noise_streamlines_loss( - targets, self.device) + targets = self.batch_loader.add_noise_streamlines_loss( + targets, self.device) - results = self.model.compute_loss(model_outputs, targets, - average_results=True) - else: - results = self.model.compute_loss(model_outputs, - average_results=True) + results = self.model.compute_loss(model_outputs, targets, + average_results=True) if self.use_gpu: log_gpu_memory_usage(logger) diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 9b58e77d..6bcd6046 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -25,7 +25,7 @@ from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) -from dwi_ml.training.batch_loaders import DWIMLAbstractBatchLoader +from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ @@ -82,7 +82,7 @@ def init_from_args(args, sub_loggers_level): batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) # Preparing the batch loader. with Timer("\nPreparing batch loader...", newline=True, color='pink'): - batch_loader = DWIMLAbstractBatchLoader( + batch_loader = DWIMLStreamlinesBatchLoader( dataset=dataset, model=model, streamline_group_name=args.streamline_group_name, # OTHER From 02cab7bb2be3a93c25f0679bcbd83388ff2fe077 Mon Sep 17 00:00:00 2001 From: bora2502 Date: Tue, 17 Sep 2024 10:06:12 -0400 Subject: [PATCH 06/24] fix naming class --- dwi_ml/training/batch_loaders.py | 2 +- dwi_ml/training/trainers.py | 8 ++++---- scripts_python/ae_train_model.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 9ac3afaf..61f13640 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -355,7 +355,7 @@ def load_batch_connectivity_matrices( connectivity_nb_blocs, connectivity_labels) -class DWIMLStreamlinesBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): +class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): """ Loads: input = one volume group diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index fa65cee1..2aae19ac 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -17,7 +17,7 @@ from dwi_ml.models.main_models import (MainModelAbstract, ModelWithDirectionGetter) from dwi_ml.training.batch_loaders import ( - DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput) + DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput) from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.utils.gradient_norm import compute_gradient_norm from dwi_ml.training.utils.monitoring import ( @@ -53,7 +53,7 @@ class DWIMLAbstractTrainer: def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, learning_rates: Union[List, float] = None, weight_decay: float = 0.01, optimizer: str = 'Adam', max_epochs: int = 10, @@ -78,7 +78,7 @@ def __init__(self, batch_sampler: DWIMLBatchIDSampler Instantiated class used for sampling batches. Data in batch_sampler.dataset must be already loaded. - batch_loader: DWIMLAbstractBatchLoader + batch_loader: DWIMLStreamlinesBatchLoader Instantiated class with a load_batch method able to load data associated to sampled batch ids. Data in batch_sampler.dataset must be already loaded. @@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict: def init_from_checkpoint( cls, model: MainModelAbstract, experiments_path, experiment_name, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, checkpoint_state: dict, new_patience, new_max_epochs, log_level): """ Loads checkpoint information (parameters and states) to instantiate diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 6bcd6046..a291fec4 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -14,12 +14,12 @@ import comet_ml import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, add_verbose_arg from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg, add_memory_args +from dwi_ml.io_utils import add_memory_args from dwi_ml.models.projects.ae_models import ModelAE from dwi_ml.training.trainers import DWIMLAbstractTrainer from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, @@ -38,17 +38,17 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - training_group = add_training_args(p) + #training_group = add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") add_memory_args(p, add_lazy_options=True, add_rng=True) - add_logging_arg(p) + add_verbose_arg(p) # Additional arg for projects - training_group.add_argument( - '--clip_grad', type=float, default=None, - help="Value to which the gradient norms to avoid exploding gradients." - "\nDefault = None (not clipping).") + #training_group.add_argument( + # '--clip_grad', type=float, default=None, + # help="Value to which the gradient norms to avoid exploding gradients." + # "\nDefault = None (not clipping).") return p From 77918758d8fdbe2a3c29e2be1351798e09599669 Mon Sep 17 00:00:00 2001 From: bora2502 Date: Tue, 17 Sep 2024 13:55:00 -0400 Subject: [PATCH 07/24] fix script --- scripts_python/ae_train_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index a291fec4..20d416b1 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -25,6 +25,7 @@ from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) +from dwi_ml.training.utils.trainer import add_training_args from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) @@ -39,6 +40,7 @@ def prepare_arg_parser(): add_args_batch_sampler(p) add_args_batch_loader(p) #training_group = add_training_args(p) + add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") add_memory_args(p, add_lazy_options=True, add_rng=True) @@ -110,7 +112,7 @@ def init_from_args(args, sub_loggers_level): from_checkpoint=False, clip_grad=args.clip_grad, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=args.logging) + log_level=sub_loggers_level) logging.info("Trainer params : " + format_dict_to_str(trainer.params_for_checkpoint)) @@ -123,11 +125,10 @@ def main(): # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, # but we will set trainer to user-defined level. - sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' - logging.getLogger().setLevel(level=logging.INFO) + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Check that all files exist assert_inputs_exist(p, [args.hdf5_file]) From f4701be2c284b95ea991ecd95d433eae43a4993a Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Wed, 18 Sep 2024 10:29:25 -0400 Subject: [PATCH 08/24] fix viz --- scripts_python/ae_visualize_streamlines.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 1d2e0f3f..ef942d96 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -7,13 +7,12 @@ from scilpy.io.utils import (add_overwrite_arg, assert_outputs_exist, - add_reference_arg) + add_reference_arg, + add_verbose_arg) from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram - -from dwi_ml.io_utils import (add_logging_arg, - add_arg_existing_experiment_path, - add_memory_args) +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) from dwi_ml.models.projects.ae_models import ModelAE @@ -40,7 +39,7 @@ def _build_arg_parser(): p.add_argument('--pick_at_random', action='store_true') add_reference_arg(p) add_overwrite_arg(p) - add_logging_arg(p) + add_verbose_arg(p) return p From c4bd181c16a2b80bcb768b4ac1e2aa4eee63da5b Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Wed, 18 Sep 2024 10:44:24 -0400 Subject: [PATCH 09/24] quick fix ae_vis_streamline --- scripts_python/ae_visualize_streamlines.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index ef942d96..31526453 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -47,11 +47,12 @@ def main(): p = _build_arg_parser() args = p.parse_args() - # Loggers - sub_logger_level = args.logging.upper() - if sub_logger_level == 'DEBUG': - sub_logger_level = 'INFO' - logging.getLogger().setLevel(level=args.logging) + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Verify output names # Check experiment_path exists and best_model folder exists @@ -65,7 +66,7 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_logger_level) + args.experiment_path + '/best_model', log_level=sub_loggers_level) # model.set_context('training') # 2. Compute loss # tester = TesterOneInput(args.experiment_path, From 448043b17b33480c5022887ba31b788496846897 Mon Sep 17 00:00:00 2001 From: Antoine Theberge Date: Thu, 19 Sep 2024 07:54:19 -0400 Subject: [PATCH 10/24] WIP: transformer ae --- command_ae.sh | 22 +++++ dwi_ml/models/projects/ae_models.py | 101 ++++++++++++++++----- scripts_python/ae_visualize_streamlines.py | 4 +- 3 files changed, 103 insertions(+), 24 deletions(-) create mode 100755 command_ae.sh diff --git a/command_ae.sh b/command_ae.sh new file mode 100755 index 00000000..17f3e3de --- /dev/null +++ b/command_ae.sh @@ -0,0 +1,22 @@ +experiments=experiments +experiment_name=fibercup_september24 + +rm -rf $experiments/$experiment_name + +ae_train_model.py $experiments \ + $experiment_name \ + fibercup.hdf5 \ + target \ + -v INFO \ + --batch_size_training 80 \ + --batch_size_units nb_streamlines \ + --nb_subjects_per_batch 1 \ + --learning_rate 0.001 \ + --weight_decay 0.05 \ + --optimizer Adam \ + --max_epochs 1000 \ + --max_batches_per_epoch_training 20 \ + --comet_workspace dwi_ml \ + --comet_project ae-fibercup \ + --patience 100 \ + --use_gpu diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 0818b6ad..c4703406 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -1,13 +1,48 @@ # -*- coding: utf-8 -*- import logging +import math + from typing import List import torch +from torch import nn +from torch import Tensor from torch.nn import functional as F from dwi_ml.models.main_models import MainModelAbstract +class PositionalEncoding(nn.Module): + """ Modified from + https://pytorch.org/tutorials/beginner/transformer_tutorial.htm://pytorch.org/tutorials/beginner/transformer_tutorial.html # noqa E504 + """ + + def __init__( + self, d_model: int, dropout: float = 0.1, max_len: int = 5000 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) + * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x: Tensor) -> Tensor: + """ + Arguments: + x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` + """ + x = x.permute(1, 0, 2) + x = x + self.pe[:x.size(0)] + x = self.dropout(x) + x = x.permute(1, 0, 2) + return x + + class ModelAE(MainModelAbstract): """ Recurrent tracking model. @@ -27,23 +62,45 @@ def __init__(self, kernel_size, latent_space_dims, log_level=logging.root.level): super().__init__(experiment_name, step_size, compress_lines, log_level) - self.kernel_size = kernel_size - self.latent_space_dims = latent_space_dims + # Embedding size, could be defined by the user ? + self.embedding_size = 32 + # Embedding layer + self.embedding = nn.Sequential( + *(nn.Linear(3, self.embedding_size), + nn.ReLU())) - self.pad = torch.nn.ReflectionPad1d(1) + # Positional encoding layer + self.pos_encoding = PositionalEncoding( + self.embedding_size, max_len=(256)) + # Transformer encoder layer + layer = nn.TransformerEncoderLayer( + self.embedding_size, 4, batch_first=True) - def pre_pad(m): - return torch.nn.Sequential(self.pad, m) + # Transformer encoder + self.encoder = nn.TransformerEncoder(layer, 2) + self.decoder = nn.TransformerEncoder(layer, 2) + + self.reconstruction_loss = torch.nn.MSELoss() + + self.pad = torch.nn.ReflectionPad1d(1) + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims self.fc1 = torch.nn.Linear(8192, self.latent_space_dims) # 8192 = 1024*8 self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + self.fc3 = torch.nn.Linear(self.embedding_size, 3) + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) + """ Encode convolutions """ self.encod_conv1 = pre_pad( - torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + torch.nn.Conv1d(self.embedding_size, 32, + self.kernel_size, stride=2, padding=0) ) self.encod_conv2 = pre_pad( torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) @@ -95,7 +152,8 @@ def pre_pad(m): scale_factor=2, mode="linear", align_corners=False ) self.decod_conv6 = pre_pad( - torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) + torch.nn.Conv1d(32, 32, + self.kernel_size, stride=1, padding=0) ) @property @@ -143,6 +201,11 @@ def forward(self, def encode(self, x): # x: list of tensors x = torch.stack(x) + + x = self.embedding(x) * math.sqrt(self.embedding_size) + x = self.pos_encoding(x) + x = self.encoder(x) + x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) @@ -162,6 +225,7 @@ def encode(self, x): return fc1 def decode(self, z): + fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] @@ -178,24 +242,17 @@ def decode(self, z): h10 = self.upsampl5(h9) h11 = self.decod_conv6(h10) - return h11 + h11 = h11.permute(0, 2, 1) + + h12 = self.decoder(h11) + + x = self.fc3(h12) + + return x.permute(0, 2, 1) def compute_loss(self, model_outputs, targets, average_results=True): - print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - print(targets[0, :, 0:5]) - print(model_outputs[0, :, 0:5]) - reconstruction_loss = torch.nn.MSELoss(reduction="sum") - mse = reconstruction_loss(model_outputs, targets) - - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) + mse = self.reconstruction_loss(model_outputs, targets) return mse, 1 diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 31526453..37e95352 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -66,7 +66,7 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) + args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) # model.set_context('training') # 2. Compute loss # tester = TesterOneInput(args.experiment_path, @@ -81,7 +81,7 @@ def main(): sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines[0:5000] + bundle = sft.streamlines logging.info("Running model to compute loss") From 7bf9fe308ae36cd8fe0f852d4557fd1c38ad1179 Mon Sep 17 00:00:00 2001 From: Antoine Theberge Date: Thu, 19 Sep 2024 07:55:36 -0400 Subject: [PATCH 11/24] Revert "WIP: transformer ae" This reverts commit 448043b17b33480c5022887ba31b788496846897. --- command_ae.sh | 22 ----- dwi_ml/models/projects/ae_models.py | 101 +++++---------------- scripts_python/ae_visualize_streamlines.py | 4 +- 3 files changed, 24 insertions(+), 103 deletions(-) delete mode 100755 command_ae.sh diff --git a/command_ae.sh b/command_ae.sh deleted file mode 100755 index 17f3e3de..00000000 --- a/command_ae.sh +++ /dev/null @@ -1,22 +0,0 @@ -experiments=experiments -experiment_name=fibercup_september24 - -rm -rf $experiments/$experiment_name - -ae_train_model.py $experiments \ - $experiment_name \ - fibercup.hdf5 \ - target \ - -v INFO \ - --batch_size_training 80 \ - --batch_size_units nb_streamlines \ - --nb_subjects_per_batch 1 \ - --learning_rate 0.001 \ - --weight_decay 0.05 \ - --optimizer Adam \ - --max_epochs 1000 \ - --max_batches_per_epoch_training 20 \ - --comet_workspace dwi_ml \ - --comet_project ae-fibercup \ - --patience 100 \ - --use_gpu diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index c4703406..0818b6ad 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -1,48 +1,13 @@ # -*- coding: utf-8 -*- import logging -import math - from typing import List import torch -from torch import nn -from torch import Tensor from torch.nn import functional as F from dwi_ml.models.main_models import MainModelAbstract -class PositionalEncoding(nn.Module): - """ Modified from - https://pytorch.org/tutorials/beginner/transformer_tutorial.htm://pytorch.org/tutorials/beginner/transformer_tutorial.html # noqa E504 - """ - - def __init__( - self, d_model: int, dropout: float = 0.1, max_len: int = 5000 - ): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) - * (-math.log(10000.0) / d_model)) - pe = torch.zeros(max_len, 1, d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) - - def forward(self, x: Tensor) -> Tensor: - """ - Arguments: - x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` - """ - x = x.permute(1, 0, 2) - x = x + self.pe[:x.size(0)] - x = self.dropout(x) - x = x.permute(1, 0, 2) - return x - - class ModelAE(MainModelAbstract): """ Recurrent tracking model. @@ -62,45 +27,23 @@ def __init__(self, kernel_size, latent_space_dims, log_level=logging.root.level): super().__init__(experiment_name, step_size, compress_lines, log_level) - # Embedding size, could be defined by the user ? - self.embedding_size = 32 - # Embedding layer - self.embedding = nn.Sequential( - *(nn.Linear(3, self.embedding_size), - nn.ReLU())) - - # Positional encoding layer - self.pos_encoding = PositionalEncoding( - self.embedding_size, max_len=(256)) - # Transformer encoder layer - layer = nn.TransformerEncoderLayer( - self.embedding_size, 4, batch_first=True) - - # Transformer encoder - self.encoder = nn.TransformerEncoder(layer, 2) - self.decoder = nn.TransformerEncoder(layer, 2) - - self.reconstruction_loss = torch.nn.MSELoss() - - self.pad = torch.nn.ReflectionPad1d(1) self.kernel_size = kernel_size self.latent_space_dims = latent_space_dims - self.fc1 = torch.nn.Linear(8192, - self.latent_space_dims) # 8192 = 1024*8 - self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) - - self.fc3 = torch.nn.Linear(self.embedding_size, 3) + self.pad = torch.nn.ReflectionPad1d(1) def pre_pad(m): return torch.nn.Sequential(self.pad, m) + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + """ Encode convolutions """ self.encod_conv1 = pre_pad( - torch.nn.Conv1d(self.embedding_size, 32, - self.kernel_size, stride=2, padding=0) + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) ) self.encod_conv2 = pre_pad( torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) @@ -152,8 +95,7 @@ def pre_pad(m): scale_factor=2, mode="linear", align_corners=False ) self.decod_conv6 = pre_pad( - torch.nn.Conv1d(32, 32, - self.kernel_size, stride=1, padding=0) + torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) ) @property @@ -201,11 +143,6 @@ def forward(self, def encode(self, x): # x: list of tensors x = torch.stack(x) - - x = self.embedding(x) * math.sqrt(self.embedding_size) - x = self.pos_encoding(x) - x = self.encoder(x) - x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) @@ -225,7 +162,6 @@ def encode(self, x): return fc1 def decode(self, z): - fc = self.fc2(z) fc_reshape = fc.view( -1, self.encoder_out_size[0], self.encoder_out_size[1] @@ -242,17 +178,24 @@ def decode(self, z): h10 = self.upsampl5(h9) h11 = self.decod_conv6(h10) - h11 = h11.permute(0, 2, 1) - - h12 = self.decoder(h11) - - x = self.fc3(h12) - - return x.permute(0, 2, 1) + return h11 def compute_loss(self, model_outputs, targets, average_results=True): + print("COMPARISON\n") targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - mse = self.reconstruction_loss(model_outputs, targets) + print(targets[0, :, 0:5]) + print(model_outputs[0, :, 0:5]) + reconstruction_loss = torch.nn.MSELoss(reduction="sum") + mse = reconstruction_loss(model_outputs, targets) + + # loss_function_vae + # See Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) + # kld = torch.sum(kld_element).__mul__(-0.5) return mse, 1 diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 37e95352..31526453 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -66,7 +66,7 @@ def main(): # 1. Load model logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level).to(device) + args.experiment_path + '/best_model', log_level=sub_loggers_level) # model.set_context('training') # 2. Compute loss # tester = TesterOneInput(args.experiment_path, @@ -81,7 +81,7 @@ def main(): sft = load_tractogram_with_reference(p, args, args.in_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines + bundle = sft.streamlines[0:5000] logging.info("Running model to compute loss") From 72be6fe012d17ce562e70b31d5413d7351b0e4b7 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 19 Sep 2024 13:15:35 -0400 Subject: [PATCH 12/24] set bbox to false when saving trk --- scripts_python/ae_visualize_streamlines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py index 31526453..93e5735d 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_visualize_streamlines.py @@ -99,7 +99,7 @@ def main(): # print(streamlines_output[0].shape) new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram) + save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) # latent_output = [s.cpu().numpy() for s in latent] From 4f04fca75b56090d11dc90dbf31fb1333f9cd57b Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 26 Sep 2024 10:35:38 -0400 Subject: [PATCH 13/24] add jeremi comments --- dwi_ml/models/projects/ae_models.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 0818b6ad..fd5338eb 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -32,6 +32,8 @@ def __init__(self, kernel_size, latent_space_dims, self.pad = torch.nn.ReflectionPad1d(1) + self.post_encoding_hooks = [] + def pre_pad(m): return torch.nn.Sequential(self.pad, m) @@ -136,14 +138,20 @@ def forward(self, Output data, ready to be passed to either `compute_loss()` or `get_tracking_directions()`. """ + encoded = self.encode(input_streamlines) + + for hook in self.post_encoding_hooks: + hook(encoded) - x = self.decode(self.encode(input_streamlines)) - return x + model_outputs = self.decode(encoded) + return model_outputs def encode(self, x): - # x: list of tensors - x = torch.stack(x) - x = torch.swapaxes(x, 1, 2) + # X input shape is (batch_size, nb_points, 3) + if isinstance(x, list): + x = torch.stack(x) + + x = torch.swapaxes(x, 1, 2) # input of the network should be (N, 3, nb_points) h1 = F.relu(self.encod_conv1(x)) h2 = F.relu(self.encod_conv2(h1)) @@ -199,3 +207,6 @@ def compute_loss(self, model_outputs, targets, average_results=True): # kld = torch.sum(kld_element).__mul__(-0.5) return mse, 1 + + def register_hook_post_encoding(self, hook): + self.post_encoding_hooks.append(hook) \ No newline at end of file From 29d690f82b7778e9eaf02aa779070008b1c3e7b5 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 26 Sep 2024 13:06:25 -0400 Subject: [PATCH 14/24] Revert "add jeremi comments" This reverts commit 4f04fca75b56090d11dc90dbf31fb1333f9cd57b. --- dwi_ml/models/projects/ae_models.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index fd5338eb..0818b6ad 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -32,8 +32,6 @@ def __init__(self, kernel_size, latent_space_dims, self.pad = torch.nn.ReflectionPad1d(1) - self.post_encoding_hooks = [] - def pre_pad(m): return torch.nn.Sequential(self.pad, m) @@ -138,20 +136,14 @@ def forward(self, Output data, ready to be passed to either `compute_loss()` or `get_tracking_directions()`. """ - encoded = self.encode(input_streamlines) - - for hook in self.post_encoding_hooks: - hook(encoded) - model_outputs = self.decode(encoded) - return model_outputs + x = self.decode(self.encode(input_streamlines)) + return x def encode(self, x): - # X input shape is (batch_size, nb_points, 3) - if isinstance(x, list): - x = torch.stack(x) - - x = torch.swapaxes(x, 1, 2) # input of the network should be (N, 3, nb_points) + # x: list of tensors + x = torch.stack(x) + x = torch.swapaxes(x, 1, 2) h1 = F.relu(self.encod_conv1(x)) h2 = F.relu(self.encod_conv2(h1)) @@ -207,6 +199,3 @@ def compute_loss(self, model_outputs, targets, average_results=True): # kld = torch.sum(kld_element).__mul__(-0.5) return mse, 1 - - def register_hook_post_encoding(self, hook): - self.post_encoding_hooks.append(hook) \ No newline at end of file From 8511db60e054e2d2b6e3abbe875e6718365d6fc3 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Thu, 26 Sep 2024 13:10:45 -0400 Subject: [PATCH 15/24] make it a little bit prettier waiting for PR244 to be merged --- dwi_ml/models/projects/ae_models.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 0818b6ad..d836fd0f 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -181,21 +181,9 @@ def decode(self, z): return h11 def compute_loss(self, model_outputs, targets, average_results=True): - print("COMPARISON\n") + targets = torch.stack(targets) targets = torch.swapaxes(targets, 1, 2) - print(targets[0, :, 0:5]) - print(model_outputs[0, :, 0:5]) reconstruction_loss = torch.nn.MSELoss(reduction="sum") mse = reconstruction_loss(model_outputs, targets) - - # loss_function_vae - # See Appendix B from VAE paper: - # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - # https://arxiv.org/abs/1312.6114 - # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - # kld = -0.5 * torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) - # kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) - # kld = torch.sum(kld_element).__mul__(-0.5) - return mse, 1 From 517a530945d9af12aeea6881b324e8ac0954f167 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Mon, 30 Sep 2024 11:36:21 -0400 Subject: [PATCH 16/24] rename vis streamline to autoencode tractogram, add save tractogram on the fly --- dwi_ml/data/hdf5/hdf5_creation.py | 20 +---- dwi_ml/data/processing/streamlines/utils.py | 52 +++++++++++++ ...amlines.py => ae_autoencode_tractogram.py} | 74 ++++++++----------- scripts_python/ae_train_model.py | 17 ++--- 4 files changed, 89 insertions(+), 74 deletions(-) create mode 100644 dwi_ml/data/processing/streamlines/utils.py rename scripts_python/{ae_visualize_streamlines.py => ae_autoencode_tractogram.py} (56%) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index 02e88d45..fbb9d419 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -634,7 +634,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, raise ValueError( "The data_per_streamline key '{}' was not found in " "the sft. Check your tractogram file.".format(dps_key)) - + logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key)) streamlines_group.create_dataset('dps_' + dps_key, data=sft.data_per_streamline[dps_key]) @@ -719,17 +719,6 @@ def _process_one_streamline_group( # with Path. save_tractogram(final_sft, str(output_fname)) -<<<<<<< HEAD - if remove_invalid: - # Removing invalid streamlines - logging.debug(' *Total: {:,.0f} streamlines. Now removing ' - 'invalid streamlines.'.format(len(final_sft))) - final_sft.remove_invalid_streamlines() - logging.info(" Final number of streamlines: {:,.0f}." - .format(len(final_sft))) - -======= ->>>>>>> master conn_matrix = None conn_info = None if 'connectivity_matrix' in self.groups_config[group]: @@ -773,18 +762,11 @@ def _load_and_process_sft(self, tractogram_file, header): "We do not support file's type: {}. We only support .trk " "and .tck files.".format(tractogram_file)) if file_extension == '.trk': -<<<<<<< HEAD - if header: - if not is_header_compatible(str(tractogram_file), header): - raise ValueError("Streamlines group is not compatible with " - "volume groups\n ({})".format(tractogram_file)) -======= if header and not is_header_compatible(str(tractogram_file), header): raise ValueError("Streamlines group is not compatible " "with volume groups\n ({})" .format(tractogram_file)) ->>>>>>> master # overriding given header. header = 'same' diff --git a/dwi_ml/data/processing/streamlines/utils.py b/dwi_ml/data/processing/streamlines/utils.py new file mode 100644 index 00000000..d191ad68 --- /dev/null +++ b/dwi_ml/data/processing/streamlines/utils.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- + +import numpy as np + +from dipy.tracking.streamline import set_number_of_points + +import torch + + +def _autoencode_streamlines(model, sft, batch_size, normalize, device): + """ + Autoencode streamlines using the given model. + + Parameters + ---------- + model : torch.nn.Module + The model to use for autoencoding. + batch_size : int + The batch size to use for encoding. + bundle : list of np.ndarray + The streamlines to encode. + normalize : bool + Whether to normalize the streamlines before encoding. + sft : StatefulTractogram + The stateful tractogram containing the streamlines. + device : torch.device + The device to use for encoding. + + Returns + ------- + generator + A generator that yields the encoded streamlines. + """ + + batches = range(0, len(sft.streamlines), batch_size) + for i, b in enumerate(batches): + with torch.no_grad(): + s = np.asarray( + set_number_of_points( + sft.streamlines[i * batch_size:(i+1) * batch_size], + 256)) + if normalize: + s /= sft.dimensions + + streamlines = torch.as_tensor( + s, dtype=torch.float32, device=device) + tmp_outputs = model(streamlines).cpu().numpy() + + scaling = sft.dimensions if normalize else 1.0 + streamlines_output = tmp_outputs.transpose((0, 2, 1)) * scaling + for strml in streamlines_output: + yield strml, strml[0] diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_autoencode_tractogram.py similarity index 56% rename from scripts_python/ae_visualize_streamlines.py rename to scripts_python/ae_autoencode_tractogram.py index 93e5735d..97932bfc 100644 --- a/scripts_python/ae_visualize_streamlines.py +++ b/scripts_python/ae_autoencode_tractogram.py @@ -3,6 +3,10 @@ import argparse import logging +import nibabel as nb +from nibabel.streamlines import detect_format +import numpy as np + import torch from scilpy.io.utils import (add_overwrite_arg, @@ -12,18 +16,16 @@ from scilpy.io.streamlines import load_tractogram_with_reference from dipy.io.streamline import save_tractogram from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) + add_memory_args) +from dwi_ml.data.processing.streamlines.utils import _autoencode_streamlines from dwi_ml.models.projects.ae_models import ModelAE def _build_arg_parser(): p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, description=__doc__) - # Mandatory - # Should only be False for debugging tests. - add_arg_existing_experiment_path(p) - # Add_args_testing_subj_hdf5(p) + add_arg_existing_experiment_path(p) p.add_argument('in_tractogram', help="If set, saves the tractogram with the loss per point " "as a data per point (color)") @@ -33,10 +35,11 @@ def _build_arg_parser(): "as a data per point (color)") # Options - p.add_argument('--batch_size', type=int) + p.add_argument('--batch_size', type=int, default=5000) + p.add_argument('--normalize', action='store_true', + help="If set, normalize the input data " + "before running the model.") add_memory_args(p) - - p.add_argument('--pick_at_random', action='store_true') add_reference_arg(p) add_overwrite_arg(p) add_verbose_arg(p) @@ -67,46 +70,31 @@ def main(): logging.debug("Loading model.") model = ModelAE.load_model_from_params_and_state( args.experiment_path + '/best_model', log_level=sub_loggers_level) - # model.set_context('training') - # 2. Compute loss - # tester = TesterOneInput(args.experiment_path, - # model, - # args.batch_size, - # device) - # tester = Tester(args.experiment_path, model, args.batch_size, device) - # sft = tester.load_and_format_data(args.subj_id, - # args.hdf5_file, - # args.subset) + # 2. Load tractogram sft = load_tractogram_with_reference(p, args, args.in_tractogram) + tracts_format = detect_format(args.out_tractogram) sft.to_vox() sft.to_corner() - bundle = sft.streamlines[0:5000] - - logging.info("Running model to compute loss") - - new_sft = sft.from_sft(bundle, sft) - save_tractogram(new_sft, 'orig_5000.trk') - - with torch.no_grad(): - streamlines = [ - torch.as_tensor(s, dtype=torch.float32, device=device) - for s in bundle] - tmp_outputs = model(streamlines) - # latent = model.encode(streamlines) - - streamlines_output = [tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy() for i in range(len(bundle))] - - # print(streamlines_output[0].shape) - new_sft = sft.from_sft(streamlines_output, sft) - save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False) - - # latent_output = [s.cpu().numpy() for s in latent] - # outputs, losses = tester.run_model_on_sft( - # sft, uncompress_loss=args.uncompress_loss, - # force_compress_loss=args.force_compress_loss, - # weight_with_angle=args.weight_with_angle) + logging.info("Running Model") + # Need a nifti image to lazy-save a tractogram + fake_ref = nb.Nifti1Image(np.zeros(sft.dimensions), sft.affine) + + save_tractogram(_autoencode_streamlines(model, + sft, + args.batch_size, + sft, + device), + tracts_format=tracts_format, + ref_img=fake_ref, + total_nb_seeds=len(sft.streamlines), + out_tractogram=args.out_tractogram, + min_length=0, + max_length=999, + compress=False, + save_seeds=False, + verbose=args.verbose) if __name__ == '__main__': diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 20d416b1..05d1e3f1 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -11,10 +11,11 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. -import comet_ml +import comet_ml # noqa F401 import torch -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, add_verbose_arg +from scilpy.io.utils import (assert_inputs_exist, assert_outputs_exist, + add_verbose_arg) from dwi_ml.data.dataset.utils import prepare_multisubjectdataset from dwi_ml.experiment_utils.prints import format_dict_to_str @@ -25,12 +26,11 @@ from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader) -from dwi_ml.training.utils.trainer import add_training_args +from dwi_ml.training.utils.trainer import (add_training_args, run_experiment, + format_lr) from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader from dwi_ml.training.utils.experiment import ( add_mandatory_args_experiment_and_hdf5_path) -from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ - format_lr def prepare_arg_parser(): @@ -39,19 +39,12 @@ def prepare_arg_parser(): add_mandatory_args_experiment_and_hdf5_path(p) add_args_batch_sampler(p) add_args_batch_loader(p) - #training_group = add_training_args(p) add_training_args(p) p.add_argument('streamline_group_name', help="Name of the group in hdf5") add_memory_args(p, add_lazy_options=True, add_rng=True) add_verbose_arg(p) - # Additional arg for projects - #training_group.add_argument( - # '--clip_grad', type=float, default=None, - # help="Value to which the gradient norms to avoid exploding gradients." - # "\nDefault = None (not clipping).") - return p From 327b6598bc84e1ff13f64bd007fbf0a50b0b2009 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Mon, 30 Sep 2024 11:37:13 -0400 Subject: [PATCH 17/24] change permission script --- scripts_python/ae_autoencode_tractogram.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts_python/ae_autoencode_tractogram.py diff --git a/scripts_python/ae_autoencode_tractogram.py b/scripts_python/ae_autoencode_tractogram.py old mode 100644 new mode 100755 From 3711e4da6760ff3f26042ba34333144e1b18e3e4 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 11:30:46 -0400 Subject: [PATCH 18/24] remove scripts wait for Antoine script --- dwi_ml/data/processing/streamlines/utils.py | 52 ---------- scripts_python/ae_autoencode_tractogram.py | 101 -------------------- 2 files changed, 153 deletions(-) delete mode 100644 dwi_ml/data/processing/streamlines/utils.py delete mode 100755 scripts_python/ae_autoencode_tractogram.py diff --git a/dwi_ml/data/processing/streamlines/utils.py b/dwi_ml/data/processing/streamlines/utils.py deleted file mode 100644 index d191ad68..00000000 --- a/dwi_ml/data/processing/streamlines/utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- - -import numpy as np - -from dipy.tracking.streamline import set_number_of_points - -import torch - - -def _autoencode_streamlines(model, sft, batch_size, normalize, device): - """ - Autoencode streamlines using the given model. - - Parameters - ---------- - model : torch.nn.Module - The model to use for autoencoding. - batch_size : int - The batch size to use for encoding. - bundle : list of np.ndarray - The streamlines to encode. - normalize : bool - Whether to normalize the streamlines before encoding. - sft : StatefulTractogram - The stateful tractogram containing the streamlines. - device : torch.device - The device to use for encoding. - - Returns - ------- - generator - A generator that yields the encoded streamlines. - """ - - batches = range(0, len(sft.streamlines), batch_size) - for i, b in enumerate(batches): - with torch.no_grad(): - s = np.asarray( - set_number_of_points( - sft.streamlines[i * batch_size:(i+1) * batch_size], - 256)) - if normalize: - s /= sft.dimensions - - streamlines = torch.as_tensor( - s, dtype=torch.float32, device=device) - tmp_outputs = model(streamlines).cpu().numpy() - - scaling = sft.dimensions if normalize else 1.0 - streamlines_output = tmp_outputs.transpose((0, 2, 1)) * scaling - for strml in streamlines_output: - yield strml, strml[0] diff --git a/scripts_python/ae_autoencode_tractogram.py b/scripts_python/ae_autoencode_tractogram.py deleted file mode 100755 index 97932bfc..00000000 --- a/scripts_python/ae_autoencode_tractogram.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import argparse -import logging - -import nibabel as nb -from nibabel.streamlines import detect_format -import numpy as np - -import torch - -from scilpy.io.utils import (add_overwrite_arg, - assert_outputs_exist, - add_reference_arg, - add_verbose_arg) -from scilpy.io.streamlines import load_tractogram_with_reference -from dipy.io.streamline import save_tractogram -from dwi_ml.io_utils import (add_arg_existing_experiment_path, - add_memory_args) -from dwi_ml.data.processing.streamlines.utils import _autoencode_streamlines -from dwi_ml.models.projects.ae_models import ModelAE - - -def _build_arg_parser(): - p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, - description=__doc__) - - add_arg_existing_experiment_path(p) - p.add_argument('in_tractogram', - help="If set, saves the tractogram with the loss per point " - "as a data per point (color)") - - p.add_argument('out_tractogram', - help="If set, saves the tractogram with the loss per point " - "as a data per point (color)") - - # Options - p.add_argument('--batch_size', type=int, default=5000) - p.add_argument('--normalize', action='store_true', - help="If set, normalize the input data " - "before running the model.") - add_memory_args(p) - add_reference_arg(p) - add_overwrite_arg(p) - add_verbose_arg(p) - return p - - -def main(): - p = _build_arg_parser() - args = p.parse_args() - - # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, - # but we will set trainer to user-defined level. - sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' - - # General logging (ex, scilpy: Warning) - logging.getLogger().setLevel(level=logging.WARNING) - - # Verify output names - # Check experiment_path exists and best_model folder exists - # Assert_inputs_exist(p, args.hdf5_file) - assert_outputs_exist(p, args, args.out_tractogram) - - # Device - device = (torch.device('cuda') if torch.cuda.is_available() and - args.use_gpu else None) - - # 1. Load model - logging.debug("Loading model.") - model = ModelAE.load_model_from_params_and_state( - args.experiment_path + '/best_model', log_level=sub_loggers_level) - - # 2. Load tractogram - sft = load_tractogram_with_reference(p, args, args.in_tractogram) - tracts_format = detect_format(args.out_tractogram) - sft.to_vox() - sft.to_corner() - - logging.info("Running Model") - # Need a nifti image to lazy-save a tractogram - fake_ref = nb.Nifti1Image(np.zeros(sft.dimensions), sft.affine) - - save_tractogram(_autoencode_streamlines(model, - sft, - args.batch_size, - sft, - device), - tracts_format=tracts_format, - ref_img=fake_ref, - total_nb_seeds=len(sft.streamlines), - out_tractogram=args.out_tractogram, - min_length=0, - max_length=999, - compress=False, - save_seeds=False, - verbose=args.verbose) - - -if __name__ == '__main__': - main() From 7521801ddeea02d36d489ccc5daff95881532bc1 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 11:33:49 -0400 Subject: [PATCH 19/24] fix hdf5 --- dwi_ml/data/hdf5/hdf5_creation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index fbb9d419..cd789e42 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -652,7 +652,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, def _process_one_streamline_group( self, subj_dir: Path, group: str, subj_id: str, - header: nib.Nifti1Header, remove_invalid=False): + header: nib.Nifti1Header): """ Loads and processes a group of tractograms and merges all streamlines together. @@ -669,8 +669,6 @@ def _process_one_streamline_group( Reference used to load and send the streamlines in voxel space and to create final merged SFT. If the file is a .trk, 'same' is used instead. - remove_invalid : bool - If True, invalid streamlines will be removed Returns ------- From 97d1169dae1c4f68b1790ed92f2ede5dd96897e7 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 14:44:53 -0400 Subject: [PATCH 20/24] answer em comments --- dwi_ml/models/main_models.py | 1 + dwi_ml/models/projects/ae_models.py | 35 +++++++---------------------- dwi_ml/training/batch_loaders.py | 9 +++++--- dwi_ml/training/trainers.py | 6 ----- scripts_python/ae_train_model.py | 7 +++--- 5 files changed, 19 insertions(+), 39 deletions(-) diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index cf3af707..cdb4e79d 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -128,6 +128,7 @@ def params_for_checkpoint(self): 'experiment_name': self.experiment_name, 'step_size': self.step_size, 'compress_lines': self.compress_lines, + 'nb_points': self.nb_points, } @property diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index d836fd0f..5a82e84a 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -18,17 +18,18 @@ class ModelAE(MainModelAbstract): deterministic (3D vectors) or probabilistic (based on probability distribution parameters). """ - def __init__(self, kernel_size, latent_space_dims, + def __init__(self, experiment_name: str, - # Target preprocessing params for the batch loader + tracker - step_size: float = None, - compress_lines: float = False, # Other log_level=logging.root.level): - super().__init__(experiment_name, step_size, compress_lines, log_level) + super().__init__(experiment_name, + step_size=None, + nb_points=None, + compress_lines=None, + log_level=log_level) - self.kernel_size = kernel_size - self.latent_space_dims = latent_space_dims + self.kernel_size = 3 + self.latent_space_dims = 32 self.pad = torch.nn.ReflectionPad1d(1) @@ -98,26 +99,6 @@ def pre_pad(m): torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) ) - @property - def params_for_checkpoint(self): - """All parameters necessary to create again the same model. Will be - used in the trainer, when saving the checkpoint state. Params here - will be used to re-create the model when starting an experiment from - checkpoint. You should be able to re-create an instance of your - model with those params.""" - # p = super().params_for_checkpoint() - p = {'kernel_size': self.kernel_size, - 'latent_space_dims': self.latent_space_dims, - 'experiment_name': self.experiment_name} - return p - - @classmethod - def _load_params(cls, model_dir): - p = super()._load_params(model_dir) - p['kernel_size'] = 3 - p['latent_space_dims'] = 32 - return p - def forward(self, input_streamlines: List[torch.tensor], ): diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 43f16136..623371a9 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -197,11 +197,14 @@ def _data_augmentation_sft(self, sft): self.context_subset.compress == self.model.compress_lines: logger.debug("Compression rate is the same as when creating " "the hdf5 dataset. Not compressing again.") - elif self.model.step_size is not None and \ - self.model.compress_lines is not None: + elif self.model.nb_points is not None and self.model.nb_points == self.context_subset.nb_points: + logging.debug("Number of points per streamline is the same" + " as when creating the hdf5. Not resampling again.") + else: logger.debug("Resample streamlines using: \n" + "- step_size: {}\n".format(self.model.step_size) + - "- compress_lines: {}".format(self.model.compress_lines)) + "- compress_lines: {}".format(self.model.compress_lines) + + "- nb_points: {}".format(self.model.nb_points)) sft = resample_or_compress(sft, self.model.step_size, self.model.nb_points, self.model.compress_lines) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 2aae19ac..15621360 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -1022,15 +1022,9 @@ def run_one_batch(self, data): targets = [s.to(self.device, non_blocking=True, dtype=torch.float) for s in targets] - # Getting the inputs points from the volumes. # Uses the model's method, with the batch_loader's data. # Possibly skipping the last point if not useful. streamlines_f = targets - if isinstance(self.model, ModelWithDirectionGetter) and \ - not self.model.direction_getter.add_eos: - # No EOS = We don't use the last coord because it does not have an - # associated target direction. - streamlines_f = [s[:-1, :] for s in streamlines_f] # Possibly add noise to inputs here. logger.debug('*** Computing forward propagation') diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 05d1e3f1..e32f9448 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -63,8 +63,9 @@ def init_from_args(args, sub_loggers_level): # INPUTS: verifying args model = ModelAE( experiment_name=args.experiment_name, - step_size=None, compress_lines=None, - kernel_size=3, latent_space_dims=32, + step_size=None, + nb_points=None, + compress_lines=None, log_level=sub_loggers_level) logging.info("AEmodel final parameters:" + @@ -131,7 +132,7 @@ def main(): if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, "checkpoint")): raise FileExistsError("This experiment already exists. Delete or use " - "script l2t_resume_training_from_checkpoint.py.") + "script ae_resume_training_from_checkpoint.py.") trainer = init_from_args(args, sub_loggers_level) From 765ad56062304ce16c3135ad92d434cf33428bfe Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 14:47:57 -0400 Subject: [PATCH 21/24] fix init modelAE --- scripts_python/ae_train_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index e32f9448..6bf9aec9 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -63,9 +63,6 @@ def init_from_args(args, sub_loggers_level): # INPUTS: verifying args model = ModelAE( experiment_name=args.experiment_name, - step_size=None, - nb_points=None, - compress_lines=None, log_level=sub_loggers_level) logging.info("AEmodel final parameters:" + From de40049a5b155e102963c59532a0b99572e69a15 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 15:07:44 -0400 Subject: [PATCH 22/24] add unused nb_points for others models --- dwi_ml/models/projects/ae_models.py | 9 ++++++--- dwi_ml/models/projects/learn2track_model.py | 4 +++- dwi_ml/models/projects/transformer_models.py | 8 +++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py index 5a82e84a..eff22243 100644 --- a/dwi_ml/models/projects/ae_models.py +++ b/dwi_ml/models/projects/ae_models.py @@ -20,12 +20,15 @@ class ModelAE(MainModelAbstract): """ def __init__(self, experiment_name: str, + step_size: float = None, + nb_points: int = None, + compress_lines: float = False, # Other log_level=logging.root.level): super().__init__(experiment_name, - step_size=None, - nb_points=None, - compress_lines=None, + step_size=step_size, + nb_points=nb_points, + compress_lines=compress_lines, log_level=log_level) self.kernel_size = 3 diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index 3961eacd..9ba8074c 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -102,7 +102,8 @@ def __init__(self, experiment_name, neighborhood_type: Optional[str] = None, neighborhood_radius: Optional[int] = None, neighborhood_resolution: Optional[float] = None, - log_level=logging.root.level): + log_level=logging.root.level, + nb_points: Optional[int] = None): """ Params ------ @@ -133,6 +134,7 @@ def __init__(self, experiment_name, """ super().__init__( experiment_name=experiment_name, step_size=step_size, + nb_points=nb_points, compress_lines=compress_lines, log_level=log_level, # For modelWithNeighborhood neighborhood_type=neighborhood_type, diff --git a/dwi_ml/models/projects/transformer_models.py b/dwi_ml/models/projects/transformer_models.py index e10609d4..a93a5c8e 100644 --- a/dwi_ml/models/projects/transformer_models.py +++ b/dwi_ml/models/projects/transformer_models.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import logging -from typing import Union, List, Tuple, Optional +from typing import Union, List, Optional from dipy.data import get_sphere import numpy as np @@ -137,7 +137,8 @@ def __init__(self, neighborhood_type: Optional[str] = None, neighborhood_radius: Optional[int] = None, neighborhood_resolution: Optional[float] = None, - log_level=logging.root.level): + log_level=logging.root.level, + nb_points: Optional[int] = None): """ Note about embedding size: In the original model + SrcOnly model: defines d_model. @@ -185,6 +186,7 @@ def __init__(self, super().__init__( # MainAbstract experiment_name=experiment_name, step_size=step_size, + nb_points=nb_points, compress_lines=compress_lines, log_level=log_level, # Neighborhood neighborhood_type=neighborhood_type, @@ -610,7 +612,7 @@ def _prepare_data(self, inputs, _): def _run_embeddings(self, inputs, use_padding, batch_max_len): return self._run_input_embedding(inputs, use_padding, batch_max_len) - + def _run_position_encoding(self, inputs): inputs = self.position_encoding_layer(inputs) inputs = self.dropout(inputs) From e57153b2afc4e5813b3e388fef4a315a5b988818 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 15:37:47 -0400 Subject: [PATCH 23/24] fix condition resampling and compressing --- dwi_ml/data/processing/streamlines/data_augmentation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py index 18428cb9..cd46a185 100644 --- a/dwi_ml/data/processing/streamlines/data_augmentation.py +++ b/dwi_ml/data/processing/streamlines/data_augmentation.py @@ -17,15 +17,15 @@ def resample_or_compress(sft, step_size_mm: float = None, nb_points: int = None, compress: float = None, remove_invalid: bool = False): - if step_size_mm is not None: + if step_size_mm: # Note. No matter the chosen space, resampling is done in mm. logging.debug(" Resampling (step size): {}mm".format(step_size_mm)) sft = resample_streamlines_step_size(sft, step_size=step_size_mm) - elif nb_points is not None: + elif nb_points: logging.debug(" Resampling: " + "{} points per streamline".format(nb_points)) sft = resample_streamlines_num_points(sft, nb_points) - elif compress is not None: + elif compress: logging.debug(" Compressing: {}".format(compress)) sft = compress_sft(sft, compress) From fc55338f5e988185d057a7d8a78c2e430b14fc16 Mon Sep 17 00:00:00 2001 From: arnaudbore Date: Tue, 1 Oct 2024 16:32:34 -0400 Subject: [PATCH 24/24] fix pep8 --- scripts_python/ae_train_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py index 6bf9aec9..e5846235 100755 --- a/scripts_python/ae_train_model.py +++ b/scripts_python/ae_train_model.py @@ -11,7 +11,7 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. -import comet_ml # noqa F401 +import comet_ml # noqa F401 import torch from scilpy.io.utils import (assert_inputs_exist, assert_outputs_exist,