diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py
index 3221d007..02e88d45 100644
--- a/dwi_ml/data/hdf5/hdf5_creation.py
+++ b/dwi_ml/data/hdf5/hdf5_creation.py
@@ -652,7 +652,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
 
     def _process_one_streamline_group(
             self, subj_dir: Path, group: str, subj_id: str,
-            header: nib.Nifti1Header):
+            header: nib.Nifti1Header, remove_invalid=False):
         """
         Loads and processes a group of tractograms and merges all
         streamlines together.
@@ -719,6 +719,17 @@ def _process_one_streamline_group(
             # with Path.
             save_tractogram(final_sft, str(output_fname))
 
+        if remove_invalid:
+            # Removing invalid streamlines
+            logging.debug('    *Total: {:,.0f} streamlines. Now removing '
+                          'invalid streamlines.'.format(len(final_sft)))
+            final_sft.remove_invalid_streamlines()
+            logging.info("    Final number of streamlines: {:,.0f}."
+                         .format(len(final_sft)))
+
         conn_matrix = None
         conn_info = None
         if 'connectivity_matrix' in self.groups_config[group]:
@@ -762,11 +773,18 @@ def _load_and_process_sft(self, tractogram_file, header):
                 "We do not support file's type: {}. We only support .trk "
                 "and .tck files.".format(tractogram_file))
         if file_extension == '.trk':
-            if header and not is_header_compatible(str(tractogram_file),
-                                                   header):
-                raise ValueError("Streamlines group is not compatible "
-                                 "with volume groups\n ({})"
-                                 .format(tractogram_file))
+            if header:
+                if not is_header_compatible(str(tractogram_file), header):
+                    raise ValueError("Streamlines group is not compatible "
+                                     "with volume groups\n ({})"
+                                     .format(tractogram_file))
 
             # overriding given header.
             header = 'same'
diff --git a/dwi_ml/models/projects/ae_models.py b/dwi_ml/models/projects/ae_models.py
new file mode 100644
index 00000000..d836fd0f
--- /dev/null
+++ b/dwi_ml/models/projects/ae_models.py
@@ -0,0 +1,189 @@
+# -*- coding: utf-8 -*-
+import logging
+from typing import List
+
+import torch
+from torch.nn import functional as F
+
+from dwi_ml.models.main_models import MainModelAbstract
+
+
+class ModelAE(MainModelAbstract):
+    """
+    Convolutional autoencoder for streamlines.
+
+    Composed of a 1D-convolutional encoder that compresses a batch of
+    streamlines (sequences of 3D coordinates) into a latent vector, and a
+    symmetric convolutional decoder that reconstructs the streamlines from
+    that vector. The loss is the MSE between the input streamlines and
+    their reconstruction.
+    """
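+
+    # Note on the expected input length (inferred from the architecture,
+    # not enforced in the code): each of the five stride-2 encoding
+    # convolutions halves the sequence length, and fc1 expects
+    # 8192 = 1024 channels * 8 time steps, so streamlines should be
+    # resampled to 256 points (256 / 2**5 = 8) before entering the model.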
+ """ + def __init__(self, kernel_size, latent_space_dims, + experiment_name: str, + # Target preprocessing params for the batch loader + tracker + step_size: float = None, + compress_lines: float = False, + # Other + log_level=logging.root.level): + super().__init__(experiment_name, step_size, compress_lines, log_level) + + self.kernel_size = kernel_size + self.latent_space_dims = latent_space_dims + + self.pad = torch.nn.ReflectionPad1d(1) + + def pre_pad(m): + return torch.nn.Sequential(self.pad, m) + + self.fc1 = torch.nn.Linear(8192, + self.latent_space_dims) # 8192 = 1024*8 + self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192) + + """ + Encode convolutions + """ + self.encod_conv1 = pre_pad( + torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv2 = pre_pad( + torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv3 = pre_pad( + torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv4 = pre_pad( + torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv5 = pre_pad( + torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0) + ) + self.encod_conv6 = pre_pad( + torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0) + ) + + """ + Decode convolutions + """ + self.decod_conv1 = pre_pad( + torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0) + ) + self.upsampl1 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv2 = pre_pad( + torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0) + ) + self.upsampl2 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv3 = pre_pad( + torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0) + ) + self.upsampl3 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv4 = pre_pad( + torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0) + ) + self.upsampl4 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv5 = pre_pad( + torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0) + ) + self.upsampl5 = torch.nn.Upsample( + scale_factor=2, mode="linear", align_corners=False + ) + self.decod_conv6 = pre_pad( + torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0) + ) + + @property + def params_for_checkpoint(self): + """All parameters necessary to create again the same model. Will be + used in the trainer, when saving the checkpoint state. Params here + will be used to re-create the model when starting an experiment from + checkpoint. You should be able to re-create an instance of your + model with those params.""" + # p = super().params_for_checkpoint() + p = {'kernel_size': self.kernel_size, + 'latent_space_dims': self.latent_space_dims, + 'experiment_name': self.experiment_name} + return p + + @classmethod + def _load_params(cls, model_dir): + p = super()._load_params(model_dir) + p['kernel_size'] = 3 + p['latent_space_dims'] = 32 + return p + + def forward(self, + input_streamlines: List[torch.tensor], + ): + """Run the model on a batch of sequences. + + Parameters + ---------- + input_streamlines: List[torch.tensor], + Batch of streamlines. Only used if previous directions are added to + the model. Used to compute directions; its last point will not be + used. 
diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py
index 743830bb..43f16136 100644
--- a/dwi_ml/training/batch_loaders.py
+++ b/dwi_ml/training/batch_loaders.py
@@ -58,7 +58,7 @@
 logger = logging.getLogger('batch_loader_logger')
 
 
-class DWIMLAbstractBatchLoader:
+class DWIMLStreamlinesBatchLoader:
     def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
                  streamline_group_name: str, rng: int,
                  split_ratio: float = 0.,
@@ -197,7 +197,11 @@ def _data_augmentation_sft(self, sft):
                 self.context_subset.compress == self.model.compress_lines:
             logger.debug("Compression rate is the same as when creating "
                          "the hdf5 dataset. Not compressing again.")
-        else:
+        # step_size and compress_lines are mutually exclusive; resample or
+        # compress as soon as either one is set.
+        elif self.model.step_size is not None or \
+                self.model.compress_lines is not None:
+            logger.debug("Resample streamlines using: \n" +
+                         "- step_size: {}\n".format(self.model.step_size) +
+                         "- compress_lines: {}".format(
+                             self.model.compress_lines))
             sft = resample_or_compress(sft, self.model.step_size,
                                        self.model.nb_points,
                                        self.model.compress_lines)
@@ -314,6 +318,7 @@ def load_batch_streamlines(
             sft.to_vox()
             sft.to_corner()
             batch_streamlines.extend(sft.streamlines)
+        batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]
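+        # (Conversion to tensors happens here rather than in the model so
+        # that the trainer can move the whole batch to the right device
+        # with s.to(device) before the forward call.)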
Not compressing again.") - else: + elif self.model.step_size is not None and \ + self.model.compress_lines is not None: + logger.debug("Resample streamlines using: \n" + + "- step_size: {}\n".format(self.model.step_size) + + "- compress_lines: {}".format(self.model.compress_lines)) sft = resample_or_compress(sft, self.model.step_size, self.model.nb_points, self.model.compress_lines) @@ -314,6 +318,7 @@ def load_batch_streamlines( sft.to_vox() sft.to_corner() batch_streamlines.extend(sft.streamlines) + batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines] return batch_streamlines, final_s_ids_per_subj @@ -351,7 +356,7 @@ def load_batch_connectivity_matrices( connectivity_nb_blocs, connectivity_labels) -class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader): +class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader): """ Loads: input = one volume group diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 4355ad0f..2aae19ac 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -17,7 +17,7 @@ from dwi_ml.models.main_models import (MainModelAbstract, ModelWithDirectionGetter) from dwi_ml.training.batch_loaders import ( - DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput) + DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput) from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.utils.gradient_norm import compute_gradient_norm from dwi_ml.training.utils.monitoring import ( @@ -53,7 +53,7 @@ class DWIMLAbstractTrainer: def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, learning_rates: Union[List, float] = None, weight_decay: float = 0.01, optimizer: str = 'Adam', max_epochs: int = 10, @@ -78,7 +78,7 @@ def __init__(self, batch_sampler: DWIMLBatchIDSampler Instantiated class used for sampling batches. Data in batch_sampler.dataset must be already loaded. - batch_loader: DWIMLAbstractBatchLoader + batch_loader: DWIMLStreamlinesBatchLoader Instantiated class with a load_batch method able to load data associated to sampled batch ids. Data in batch_sampler.dataset must be already loaded. @@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict: def init_from_checkpoint( cls, model: MainModelAbstract, experiments_path, experiment_name, batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLAbstractBatchLoader, + batch_loader: DWIMLStreamlinesBatchLoader, checkpoint_state: dict, new_patience, new_max_epochs, log_level): """ Loads checkpoint information (parameters and states) to instantiate @@ -1013,7 +1013,51 @@ def run_one_batch(self, data): Any other data returned when computing loss. Not used in the trainer, but could be useful anywhere else. """ - raise NotImplementedError + # Data interpolation has not been done yet. GPU computations are done + # here in the main thread. + targets, ids_per_subj = data + + # Dataloader always works on CPU. Sending to right device. + # (model is already moved). + targets = [s.to(self.device, non_blocking=True, dtype=torch.float) + for s in targets] + + # Getting the inputs points from the volumes. + # Uses the model's method, with the batch_loader's data. + # Possibly skipping the last point if not useful. 
+
+        # Possibly add noise to inputs here.
+        logger.debug('*** Computing forward propagation')
+
+        # Now possibly add noise to streamlines (training / valid)
+        streamlines_f = self.batch_loader.add_noise_streamlines_forward(
+            streamlines_f, self.device)
+
+        # Possibly computing directions twice (during forward and loss)
+        # but ok, shouldn't be too heavy. Easier to deal with multiple
+        # projects' requirements by sending whole streamlines rather
+        # than only directions.
+        model_outputs = self.model(streamlines_f)
+        del streamlines_f
+
+        logger.debug('*** Computing loss')
+        targets = self.batch_loader.add_noise_streamlines_loss(
+            targets, self.device)
+
+        results = self.model.compute_loss(model_outputs, targets,
+                                          average_results=True)
+
+        if self.use_gpu:
+            log_gpu_memory_usage(logger)
+
+        # Returning the results as given by the model's compute_loss:
+        # typically (mean loss tensor, nb of values averaged).
+        return results
 
     def fix_parameters(self):
         """
diff --git a/scripts_python/ae_train_model.py b/scripts_python/ae_train_model.py
new file mode 100755
index 00000000..20d416b1
--- /dev/null
+++ b/scripts_python/ae_train_model.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Train an autoencoder model on streamlines.
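+
+Example call (hypothetical paths and group name; the positional arguments
+come from add_mandatory_args_experiment_and_hdf5_path and from this script):
+
+    ae_train_model.py experiments/ my_ae_experiment dataset.hdf5 \
+        streamlines --max_epochs 50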
+ # "\nDefault = None (not clipping).") + + return p + + +def init_from_args(args, sub_loggers_level): + torch.manual_seed(args.rng) # Set torch seed + + # Prepare the dataset + dataset = prepare_multisubjectdataset(args, load_testing=False, + log_level=sub_loggers_level) + + # Preparing the model + # (Direction getter) + # (Nb features) + # Final model + with Timer("\n\nPreparing model", newline=True, color='yellow'): + # INPUTS: verifying args + model = ModelAE( + experiment_name=args.experiment_name, + step_size=None, compress_lines=None, + kernel_size=3, latent_space_dims=32, + log_level=sub_loggers_level) + + logging.info("AEmodel final parameters:" + + format_dict_to_str(model.params_for_checkpoint)) + + logging.info("Computed parameters:" + + format_dict_to_str(model.computed_params_for_display)) + + # Preparing the batch samplers + batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) + # Preparing the batch loader. + with Timer("\nPreparing batch loader...", newline=True, color='pink'): + batch_loader = DWIMLStreamlinesBatchLoader( + dataset=dataset, model=model, + streamline_group_name=args.streamline_group_name, + # OTHER + rng=args.rng, log_level=sub_loggers_level) + + logging.info("Loader user-defined parameters: " + + format_dict_to_str(batch_loader.params_for_checkpoint)) + + # Instantiate trainer + with Timer("\n\nPreparing trainer", newline=True, color='red'): + lr = format_lr(args.learning_rate) + trainer = DWIMLAbstractTrainer( + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, + # COMET + comet_project=args.comet_project, + comet_workspace=args.comet_workspace, + # TRAINING + learning_rates=lr, weight_decay=args.weight_decay, + optimizer=args.optimizer, max_epochs=args.max_epochs, + max_batches_per_epoch_training=args.max_batches_per_epoch_training, + max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, + patience=args.patience, patience_delta=args.patience_delta, + from_checkpoint=False, clip_grad=args.clip_grad, + # MEMORY + nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, + log_level=sub_loggers_level) + logging.info("Trainer params : " + + format_dict_to_str(trainer.params_for_checkpoint)) + + return trainer + + +def main(): + p = prepare_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Check that all files exist + assert_inputs_exist(p, [args.hdf5_file]) + assert_outputs_exist(p, args, args.experiments_path) + + # Verify if a checkpoint has been saved. Else create an experiment. + if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError("This experiment already exists. 
Delete or use " + "script l2t_resume_training_from_checkpoint.py.") + + trainer = init_from_args(args, sub_loggers_level) + + run_experiment(trainer) + + +if __name__ == '__main__': + main() diff --git a/scripts_python/ae_visualize_streamlines.py b/scripts_python/ae_visualize_streamlines.py new file mode 100644 index 00000000..93e5735d --- /dev/null +++ b/scripts_python/ae_visualize_streamlines.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +import logging + +import torch + +from scilpy.io.utils import (add_overwrite_arg, + assert_outputs_exist, + add_reference_arg, + add_verbose_arg) +from scilpy.io.streamlines import load_tractogram_with_reference +from dipy.io.streamline import save_tractogram +from dwi_ml.io_utils import (add_arg_existing_experiment_path, + add_memory_args) +from dwi_ml.models.projects.ae_models import ModelAE + + +def _build_arg_parser(): + p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, + description=__doc__) + # Mandatory + # Should only be False for debugging tests. + add_arg_existing_experiment_path(p) + # Add_args_testing_subj_hdf5(p) + + p.add_argument('in_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + p.add_argument('out_tractogram', + help="If set, saves the tractogram with the loss per point " + "as a data per point (color)") + + # Options + p.add_argument('--batch_size', type=int) + add_memory_args(p) + + p.add_argument('--pick_at_random', action='store_true') + add_reference_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + return p + + +def main(): + p = _build_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO' + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) + + # Verify output names + # Check experiment_path exists and best_model folder exists + # Assert_inputs_exist(p, args.hdf5_file) + assert_outputs_exist(p, args, args.out_tractogram) + + # Device + device = (torch.device('cuda') if torch.cuda.is_available() and + args.use_gpu else None) + + # 1. Load model + logging.debug("Loading model.") + model = ModelAE.load_model_from_params_and_state( + args.experiment_path + '/best_model', log_level=sub_loggers_level) + # model.set_context('training') + # 2. 
+"""
+import argparse
+import logging
+
+import torch
+
+from scilpy.io.utils import (add_overwrite_arg,
+                             assert_outputs_exist,
+                             add_reference_arg,
+                             add_verbose_arg)
+from scilpy.io.streamlines import load_tractogram_with_reference
+from dipy.io.streamline import save_tractogram
+from dwi_ml.io_utils import (add_arg_existing_experiment_path,
+                             add_memory_args)
+from dwi_ml.models.projects.ae_models import ModelAE
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
+                                description=__doc__)
+    # Mandatory
+    add_arg_existing_experiment_path(p)
+
+    p.add_argument('in_tractogram',
+                   help="Input tractogram; its streamlines will be passed "
+                        "through the autoencoder.")
+
+    p.add_argument('out_tractogram',
+                   help="Output tractogram, containing the reconstructed "
+                        "streamlines.")
+
+    # Options
+    p.add_argument('--batch_size', type=int)
+    add_memory_args(p)
+
+    p.add_argument('--pick_at_random', action='store_true')
+    add_reference_arg(p)
+    add_overwrite_arg(p)
+    add_verbose_arg(p)
+    return p
+
+
+def main():
+    p = _build_arg_parser()
+    args = p.parse_args()
+
+    # Setting log level to INFO maximum for sub-loggers, else it becomes
+    # ugly, but we will set trainer to user-defined level.
+    sub_loggers_level = args.verbose if args.verbose != 'DEBUG' else 'INFO'
+
+    # General logging (ex, scilpy: Warning)
+    logging.getLogger().setLevel(level=logging.WARNING)
+
+    # Verify output names.
+    assert_outputs_exist(p, args, args.out_tractogram)
+
+    # Device
+    device = (torch.device('cuda') if torch.cuda.is_available() and
+              args.use_gpu else torch.device('cpu'))
+
+    # 1. Load model
+    logging.debug("Loading model.")
+    model = ModelAE.load_model_from_params_and_state(
+        args.experiment_path + '/best_model', log_level=sub_loggers_level)
+
+    # 2. Load tractogram and keep a subset of streamlines.
+    sft = load_tractogram_with_reference(p, args, args.in_tractogram)
+    sft.to_vox()
+    sft.to_corner()
+    bundle = sft.streamlines[0:5000]
+
+    logging.info("Running model to compute loss")
+
+    # Saving the (subset of) input streamlines, for visual comparison.
+    new_sft = sft.from_sft(bundle, sft)
+    save_tractogram(new_sft, 'orig_5000.trk')
+
+    # 3. Run the model and save the reconstruction.
+    with torch.no_grad():
+        streamlines = [
+            torch.as_tensor(s, dtype=torch.float32, device=device)
+            for s in bundle]
+        tmp_outputs = model(streamlines)
+
+        # Model outputs have shape [nb_streamlines, 3, nb_points]; going
+        # back to per-streamline [nb_points, 3] arrays.
+        streamlines_output = [
+            tmp_outputs[i, :, :].transpose(0, 1).cpu().numpy()
+            for i in range(len(bundle))]
+
+    new_sft = sft.from_sft(streamlines_output, sft)
+    save_tractogram(new_sft, args.out_tractogram, bbox_valid_check=False)
+
+
+if __name__ == '__main__':
+    main()