[WIP] Visualize the latent space from the autoencoder #245

Open · wants to merge 17 commits into base: master
8 changes: 5 additions & 3 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -133,6 +133,7 @@ class HDF5Creator:
See the doc for an example of config file.
https://dwi-ml.readthedocs.io/en/latest/config_file.html
"""

def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
@@ -634,8 +635,9 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))

logging.debug(
" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])

@@ -652,7 +654,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,

def _process_one_streamline_group(
self, subj_dir: Path, group: str, subj_id: str,
-header: nib.Nifti1Header):
+header: nib.Nifti1Header, remove_invalid=False):
"""
Loads and processes a group of tractograms and merges all streamlines
together.
201 changes: 201 additions & 0 deletions dwi_ml/models/projects/ae_models.py
@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import torch
from torch.nn import functional as F

from dwi_ml.models.main_models import MainModelAbstract


class ModelAE(MainModelAbstract):
"""
Recurrent tracking model.

Composed of an embedding for the imaging data's input + for the previous
direction's input, an RNN model to process the sequences, and a direction
getter model to convert the RNN outputs to the right structure, e.g.
deterministic (3D vectors) or probabilistic (based on probability
distribution parameters).
"""
def __init__(self, kernel_size, latent_space_dims,
experiment_name: str,
# Target preprocessing params for the batch loader + tracker
step_size: float = None,
compress_lines: float = False,
# Other
log_level=logging.root.level):
super().__init__(experiment_name, step_size, compress_lines, log_level)

self.kernel_size = kernel_size
self.latent_space_dims = latent_space_dims

self.pad = torch.nn.ReflectionPad1d(1)
self.post_encoding_hooks = []

def pre_pad(m):
return torch.nn.Sequential(self.pad, m)

self.fc1 = torch.nn.Linear(8192,
self.latent_space_dims) # 8192 = 1024*8
self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192)
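# Note: 8192 = 1024 * 8 assumes the encoder output is (1024 channels,
# 8 points), i.e. inputs resampled to a fixed 256 points before
# encoding (five stride-2 convolutions: 256 / 2**5 = 8).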

"""
Encode convolutions
"""
self.encod_conv1 = pre_pad(
torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0)
)
self.encod_conv2 = pre_pad(
torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0)
)
self.encod_conv3 = pre_pad(
torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0)
)
self.encod_conv4 = pre_pad(
torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0)
)
self.encod_conv5 = pre_pad(
torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0)
)
self.encod_conv6 = pre_pad(
torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0)
)

"""
Decode convolutions
"""
self.decod_conv1 = pre_pad(
torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0)
)
self.upsampl1 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv2 = pre_pad(
torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0)
)
self.upsampl2 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv3 = pre_pad(
torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0)
)
self.upsampl3 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv4 = pre_pad(
torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0)
)
self.upsampl4 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv5 = pre_pad(
torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0)
)
self.upsampl5 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv6 = pre_pad(
torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0)
)

@property
def params_for_checkpoint(self):
"""All parameters necessary to create again the same model. Will be
used in the trainer, when saving the checkpoint state. Params here
will be used to re-create the model when starting an experiment from
checkpoint. You should be able to re-create an instance of your
model with those params."""
# p = super().params_for_checkpoint()
p = {'kernel_size': self.kernel_size,
'latent_space_dims': self.latent_space_dims,
'experiment_name': self.experiment_name}
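# Illustrative sanity check: ModelAE(**p) should re-create an
# equivalent model (step_size / compress_lines fall back to their
# defaults).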
return p

@classmethod
def _load_params(cls, model_dir):
p = super()._load_params(model_dir)
p['kernel_size'] = 3
p['latent_space_dims'] = 32
return p

def forward(self,
input_streamlines: List[torch.tensor],
):
"""Run the model on a batch of sequences.

Parameters
----------
input_streamlines: List[torch.tensor],
    Batch of streamlines, each of shape (nb_points, 3). All
    streamlines must share the same number of points, since encode()
    stacks them into a single tensor.

Returns
-------
model_outputs : torch.Tensor
    The reconstructed streamlines, of shape (batch_size, 3, nb_points),
    ready to be passed to `compute_loss()`.
"""

encoded = self.encode(input_streamlines)

for hook in self.post_encoding_hooks:
hook(encoded)

x = self.decode(encoded)
return x

def encode(self, x):
# X input shape is (batch_size, nb_points, 3)
if isinstance(x, list):
x = torch.stack(x)

# The network expects inputs of shape (N, 3, nb_points).
x = torch.swapaxes(x, 1, 2)

h1 = F.relu(self.encod_conv1(x))
h2 = F.relu(self.encod_conv2(h1))
h3 = F.relu(self.encod_conv3(h2))
h4 = F.relu(self.encod_conv4(h3))
h5 = F.relu(self.encod_conv5(h4))
h6 = self.encod_conv6(h5)

self.encoder_out_size = (h6.shape[1], h6.shape[2])

# Flatten
h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1])

fc1 = self.fc1(h7)

return fc1

def decode(self, z):
fc = self.fc2(z)
fc_reshape = fc.view(
-1, self.encoder_out_size[0], self.encoder_out_size[1]
)
h1 = F.relu(self.decod_conv1(fc_reshape))
h2 = self.upsampl1(h1)
h3 = F.relu(self.decod_conv2(h2))
h4 = self.upsampl2(h3)
h5 = F.relu(self.decod_conv3(h4))
h6 = self.upsampl3(h5)
h7 = F.relu(self.decod_conv4(h6))
h8 = self.upsampl4(h7)
h9 = F.relu(self.decod_conv5(h8))
h10 = self.upsampl5(h9)
h11 = self.decod_conv6(h10)

return h11

def compute_loss(self, model_outputs, targets, average_results=True):
targets = torch.stack(targets)
targets = torch.swapaxes(targets, 1, 2)
reconstruction_loss = torch.nn.MSELoss()
mse = reconstruction_loss(model_outputs, targets)
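# MSELoss with its default reduction already averages over all
# elements, so n = 1 is returned as the number of averaged values
# (assumption: the trainer's (loss, n) convention).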

return mse, 1

def register_hook_post_encoding(self, hook):
self.post_encoding_hooks.append(hook)
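
A minimal usage sketch for reviewers (not part of the diff): registering a post-encoding hook to collect latent vectors for visualization. The 256-point streamline length, the variable names, and the t-SNE mention are illustrative assumptions.

```python
import torch

from dwi_ml.models.projects.ae_models import ModelAE

# kernel_size=3 and latent_space_dims=32 mirror the defaults hardcoded
# in _load_params() above.
model = ModelAE(kernel_size=3, latent_space_dims=32,
                experiment_name='ae_latent_viz')
model.eval()

# Collect latent vectors through the hook added by this PR.
latents = []
model.register_hook_post_encoding(
    lambda encoded: latents.append(encoded.detach().cpu()))

# Toy batch: 10 streamlines of 256 points each (fixed length assumed;
# encode() stacks the list into one (N, nb_points, 3) tensor).
streamlines = [torch.rand(256, 3) for _ in range(10)]
with torch.no_grad():
    reconstructions = model(streamlines)  # shape (10, 3, 256)

all_latents = torch.cat(latents)  # shape (10, 32), ready for e.g. t-SNE
```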

11 changes: 8 additions & 3 deletions dwi_ml/training/batch_loaders.py
@@ -58,7 +58,7 @@
logger = logging.getLogger('batch_loader_logger')


-class DWIMLAbstractBatchLoader:
+class DWIMLStreamlinesBatchLoader:
def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
streamline_group_name: str, rng: int,
split_ratio: float = 0.,
@@ -197,7 +197,11 @@ def _data_augmentation_sft(self, sft):
self.context_subset.compress == self.model.compress_lines:
logger.debug("Compression rate is the same as when creating "
"the hdf5 dataset. Not compressing again.")
-else:
+elif self.model.step_size is not None and \
+        self.model.compress_lines is not None:
+    logger.debug("Resample streamlines using: \n" +
+                 "- step_size: {}\n".format(self.model.step_size) +
+                 "- compress_lines: {}".format(self.model.compress_lines))
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)
@@ -314,6 +318,7 @@ def load_batch_streamlines(
sft.to_vox()
sft.to_corner()
batch_streamlines.extend(sft.streamlines)

batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]

return batch_streamlines, final_s_ids_per_subj
@@ -351,7 +356,7 @@ def load_batch_connectivity_matrices(
connectivity_nb_blocs, connectivity_labels)


-class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader):
+class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader):
"""
Loads:
input = one volume group
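For context, a sketch of how the renamed loader could pair with the autoencoder. The autoencoder trains on streamlines only, so it uses `DWIMLStreamlinesBatchLoader` rather than `DWIMLBatchLoaderOneInput`; the dataset object and any constructor arguments not visible in this diff are hypothetical or elided.

```python
# Hypothetical wiring (a sketch, not part of this PR).
from dwi_ml.training.batch_loaders import DWIMLStreamlinesBatchLoader

batch_loader = DWIMLStreamlinesBatchLoader(
    dataset=dataset,    # a loaded MultiSubjectDataset
    model=model,        # e.g. the ModelAE defined above
    streamline_group_name='streamlines',  # group name from the hdf5 config
    rng=1234)
```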
54 changes: 49 additions & 5 deletions dwi_ml/training/trainers.py
@@ -17,7 +17,7 @@
from dwi_ml.models.main_models import (MainModelAbstract,
ModelWithDirectionGetter)
from dwi_ml.training.batch_loaders import (
-    DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput)
+    DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput)
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.utils.gradient_norm import compute_gradient_norm
from dwi_ml.training.utils.monitoring import (
@@ -53,7 +53,7 @@ class DWIMLAbstractTrainer:
def __init__(self,
model: MainModelAbstract, experiments_path: str,
experiment_name: str, batch_sampler: DWIMLBatchIDSampler,
-             batch_loader: DWIMLAbstractBatchLoader,
+             batch_loader: DWIMLStreamlinesBatchLoader,
learning_rates: Union[List, float] = None,
weight_decay: float = 0.01,
optimizer: str = 'Adam', max_epochs: int = 10,
@@ -78,7 +78,7 @@ def __init__(self,
batch_sampler: DWIMLBatchIDSampler
Instantiated class used for sampling batches.
Data in batch_sampler.dataset must be already loaded.
-batch_loader: DWIMLAbstractBatchLoader
+batch_loader: DWIMLStreamlinesBatchLoader
Instantiated class with a load_batch method able to load data
associated to sampled batch ids. Data in batch_sampler.dataset must
be already loaded.
@@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict:
def init_from_checkpoint(
cls, model: MainModelAbstract, experiments_path, experiment_name,
batch_sampler: DWIMLBatchIDSampler,
-    batch_loader: DWIMLAbstractBatchLoader,
+    batch_loader: DWIMLStreamlinesBatchLoader,
checkpoint_state: dict, new_patience, new_max_epochs, log_level):
"""
Loads checkpoint information (parameters and states) to instantiate
@@ -1013,7 +1013,51 @@ def run_one_batch(self, data):
Any other data returned when computing loss. Not used in the
trainer, but could be useful anywhere else.
"""
-raise NotImplementedError
# Data interpolation has not been done yet. GPU computations are done
# here in the main thread.
targets, ids_per_subj = data

# Dataloader always works on CPU. Sending to right device.
# (model is already moved).
targets = [s.to(self.device, non_blocking=True, dtype=torch.float)
for s in targets]

# Getting the inputs points from the volumes.
# Uses the model's method, with the batch_loader's data.
# Possibly skipping the last point if not useful.
streamlines_f = targets
if isinstance(self.model, ModelWithDirectionGetter) and \
not self.model.direction_getter.add_eos:
# No EOS = We don't use the last coord because it does not have an
# associated target direction.
streamlines_f = [s[:-1, :] for s in streamlines_f]

# Possibly add noise to inputs here.
logger.debug('*** Computing forward propagation')

# Now possibly add noise to streamlines (training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)

# Possibly computing directions twice (during forward and loss)
# but ok, shouldn't be too heavy. Easier to deal with multiple
# projects' requirements by sending whole streamlines rather
# than only directions.
model_outputs = self.model(streamlines_f)
del streamlines_f

logger.debug('*** Computing loss')
targets = self.batch_loader.add_noise_streamlines_loss(
targets, self.device)

results = self.model.compute_loss(model_outputs, targets,
average_results=True)

if self.use_gpu:
log_gpu_memory_usage(logger)

# Results contain the mean loss (already averaged by compute_loss)
# and n, the number of values used for the average.
return results

def fix_parameters(self):
"""