Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][NF] Auto-encoders - streamlines - FINTA #220

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions dwi_ml/data/hdf5/hdf5_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,

def _process_one_streamline_group(
self, subj_dir: Path, group: str, subj_id: str,
header: nib.Nifti1Header):
header: nib.Nifti1Header, remove_invalid=False):
"""
Loads and processes a group of tractograms and merges all streamlines
together.
Expand Down Expand Up @@ -681,12 +681,13 @@ def _process_one_streamline_group(
# with Path.
save_tractogram(final_sft, str(output_fname))

# Removing invalid streamlines
logging.debug(' *Total: {:,.0f} streamlines. Now removing '
'invalid streamlines.'.format(len(final_sft)))
final_sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(final_sft)))
if remove_invalid:
# Removing invalid streamlines
logging.debug(' *Total: {:,.0f} streamlines. Now removing '
'invalid streamlines.'.format(len(final_sft)))
final_sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(final_sft)))

conn_matrix = None
conn_info = None
Expand Down Expand Up @@ -741,10 +742,10 @@ def _load_and_process_sft(self, tractogram_file, tractogram_name, header):
"We do not support file's type: {}. We only support .trk "
"and .tck files.".format(tractogram_file))
if file_extension == '.trk':
if not is_header_compatible(str(tractogram_file), header):
raise ValueError("Streamlines group is not compatible with "
"volume groups\n ({})"
.format(tractogram_file))
if header:
if not is_header_compatible(str(tractogram_file), header):
raise ValueError("Streamlines group is not compatible with "
"volume groups\n ({})".format(tractogram_file))
# overriding given header.
header = 'same'

Expand Down
1 change: 0 additions & 1 deletion dwi_ml/data/processing/streamlines/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dipy.io.stateful_tractogram import StatefulTractogram
from nibabel.streamlines.tractogram import (PerArrayDict, PerArraySequenceDict)
import numpy as np

from scilpy.tracking.tools import resample_streamlines_step_size
from scilpy.utils.streamlines import compress_sft

Expand Down
3 changes: 1 addition & 2 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from dwi_ml.data.processing.space.neighborhood import prepare_neighborhood_vectors
from dwi_ml.experiment_utils.prints import format_dict_to_str
from dwi_ml.io_utils import add_resample_or_compress_arg
from dwi_ml.models.direction_getter_models import keys_to_direction_getters, \
AbstractDirectionGetterModel
from dwi_ml.models.direction_getter_models import keys_to_direction_getters
from dwi_ml.models.embeddings import keys_to_embeddings, NNEmbedding, NoEmbedding
from dwi_ml.models.utils.direction_getters import add_direction_getter_args

Expand Down
182 changes: 182 additions & 0 deletions dwi_ml/models/projects/ae_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import torch
from torch.nn import functional as F

from dwi_ml.models.main_models import MainModelAbstract


class ModelAE(MainModelAbstract):
    """
    Convolutional auto-encoder for streamlines.

    The encoder is a stack of six 1D convolutions (3 -> 32 -> 64 -> 128 ->
    256 -> 512 -> 1024 channels; the first five use a stride of 2) followed
    by a fully connected layer projecting to the latent space. The decoder
    mirrors it: a fully connected layer, then six 1D convolutions interleaved
    with five linear upsamplings (x2), back to 3 channels.

    The fully connected layers are hard-coded for an encoder output of
    1024 channels x 8 positions (8192 values). With kernel_size=3 this
    corresponds to streamlines of 256 points; other combinations will fail
    in the fully connected layers.
    """
    def __init__(self, kernel_size, latent_space_dims,
                 experiment_name: str,
                 # Target preprocessing params for the batch loader + tracker
                 step_size: float = None,
                 compress_lines: float = False,
                 # Other
                 log_level=logging.root.level):
        """
        Parameters
        ----------
        kernel_size: int
            Kernel size of every 1D convolution (encoder and decoder).
        latent_space_dims: int
            Number of dimensions of the latent space.
        experiment_name: str
            Forwarded to the parent class.
        step_size: float
            Forwarded to the parent class.
        compress_lines: float
            Forwarded to the parent class.
        log_level:
            Forwarded to the parent class.
        """
        super().__init__(experiment_name, step_size, compress_lines, log_level)

        self.kernel_size = kernel_size
        self.latent_space_dims = latent_space_dims

        # Shape (channels, length) of the encoder's convolutional output.
        # Initialized here so that decode() can be used on latent vectors
        # without a prior call to encode(). Must agree with the size of the
        # fully connected layers below: 8192 = 1024 * 8.
        self.encoder_out_size = (1024, 8)
        flat_size = self.encoder_out_size[0] * self.encoder_out_size[1]

        # A reflection padding of 1 on each side is applied before every
        # convolution (the convolutions themselves use padding=0).
        self.pad = torch.nn.ReflectionPad1d(1)

        def padded_conv(in_channels, out_channels, stride):
            conv = torch.nn.Conv1d(in_channels, out_channels,
                                   self.kernel_size, stride=stride,
                                   padding=0)
            return torch.nn.Sequential(self.pad, conv)

        def upsample():
            return torch.nn.Upsample(scale_factor=2, mode="linear",
                                     align_corners=False)

        self.fc1 = torch.nn.Linear(flat_size, self.latent_space_dims)
        self.fc2 = torch.nn.Linear(self.latent_space_dims, flat_size)

        # Encode convolutions
        self.encod_conv1 = padded_conv(3, 32, stride=2)
        self.encod_conv2 = padded_conv(32, 64, stride=2)
        self.encod_conv3 = padded_conv(64, 128, stride=2)
        self.encod_conv4 = padded_conv(128, 256, stride=2)
        self.encod_conv5 = padded_conv(256, 512, stride=2)
        self.encod_conv6 = padded_conv(512, 1024, stride=1)

        # Decode convolutions
        self.decod_conv1 = padded_conv(1024, 512, stride=1)
        self.upsampl1 = upsample()
        self.decod_conv2 = padded_conv(512, 256, stride=1)
        self.upsampl2 = upsample()
        self.decod_conv3 = padded_conv(256, 128, stride=1)
        self.upsampl3 = upsample()
        self.decod_conv4 = padded_conv(128, 64, stride=1)
        self.upsampl4 = upsample()
        self.decod_conv5 = padded_conv(64, 32, stride=1)
        self.upsampl5 = upsample()
        self.decod_conv6 = padded_conv(32, 3, stride=1)

        # Flags read by the trainer to decide what to send to forward() and
        # to compute_loss().
        self.forward_uses_streamlines = True
        self.loss_uses_streamlines = True

    def forward(self,
                input_streamlines: List[torch.Tensor],
                ):
        """Run the auto-encoder on a batch of streamlines.

        Parameters
        ----------
        input_streamlines: List[torch.Tensor],
            Batch of streamlines. They are stacked into a single tensor, so
            all streamlines must have the same number of points. The last
            two axes are then swapped so that the coordinates (3) become the
            channels expected by the first convolution.

        Returns
        -------
        model_outputs : Tensor
            Reconstructed streamlines, channels first, ready to be passed to
            `compute_loss()`.
        """
        input_streamlines = torch.stack(input_streamlines)
        input_streamlines = torch.swapaxes(input_streamlines, 1, 2)

        return self.decode(self.encode(input_streamlines))

    def encode(self, x):
        """
        Encode a batch of streamlines (channels first) into latent vectors.

        Also updates self.encoder_out_size with the (channels, length) shape
        of the last convolution's output, used by decode() to reshape.
        """
        h1 = F.relu(self.encod_conv1(x))
        h2 = F.relu(self.encod_conv2(h1))
        h3 = F.relu(self.encod_conv3(h2))
        h4 = F.relu(self.encod_conv4(h3))
        h5 = F.relu(self.encod_conv5(h4))
        # No activation on the last convolution.
        h6 = self.encod_conv6(h5)

        self.encoder_out_size = (h6.shape[1], h6.shape[2])

        # Flatten
        h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1])

        return self.fc1(h7)

    def decode(self, z):
        """
        Decode latent vectors into streamlines (channels first).

        The output of the fully connected layer is reshaped to
        self.encoder_out_size, then goes through the convolutions; the length
        is doubled by each of the five upsampling layers.
        """
        fc = self.fc2(z)
        fc_reshape = fc.view(
            -1, self.encoder_out_size[0], self.encoder_out_size[1]
        )
        h1 = F.relu(self.decod_conv1(fc_reshape))
        h2 = self.upsampl1(h1)
        h3 = F.relu(self.decod_conv2(h2))
        h4 = self.upsampl2(h3)
        h5 = F.relu(self.decod_conv3(h4))
        h6 = self.upsampl3(h5)
        h7 = F.relu(self.decod_conv4(h6))
        h8 = self.upsampl4(h7)
        h9 = F.relu(self.decod_conv5(h8))
        h10 = self.upsampl5(h9)
        # No activation on the last convolution.
        h11 = self.decod_conv6(h10)

        return h11

    def compute_loss(self, model_outputs, targets, average_results=True):
        """
        Compute the reconstruction loss (mean squared error).

        Parameters
        ----------
        model_outputs: Tensor
            Output of forward(), channels first.
        targets: List[torch.Tensor]
            The target streamlines. Stacked and transposed the same way as
            the inputs in forward().
        average_results: bool
            Currently unused: the MSE is always averaged over all elements.

        Returns
        -------
        mse: Tensor
            The mean squared error between outputs and targets.
        n: int
            Always 1.
            NOTE(review): presumably the weight / number of points expected
            by the trainer alongside the loss; confirm against the caller.
        """
        targets = torch.stack(targets)
        targets = torch.swapaxes(targets, 1, 2)

        # Debugging help: compare the beginning of the first streamline with
        # its reconstruction. Tensors are only formatted if DEBUG is enabled.
        logging.debug("Comparison, targets vs model outputs:\n%s\n%s",
                      targets[0, :, 0:5], model_outputs[0, :, 0:5])

        mse = F.mse_loss(model_outputs, targets, reduction="mean")

        # Note. For a future variational version (loss_function_vae), see
        # Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # kld = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

        return mse, 1
7 changes: 6 additions & 1 deletion dwi_ml/training/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,11 @@ def _data_augmentation_sft(self, sft):
self.context_subset.compress == self.model.compress_lines:
logger.debug("Compression rate is the same as when creating "
"the hdf5 dataset. Not compressing again.")
else:
elif self.model.step_size is not None and \
self.model.compress_lines is not None:
logger.debug("Resample streamlines using: \n" +
"- step_size: {}\n".format(self.model.step_size) +
"- compress_lines: {}".format(self.model.compress_lines))
sft = resample_or_compress(sft, self.model.step_size,
self.model.compress_lines)

Expand Down Expand Up @@ -306,6 +310,7 @@ def load_batch_streamlines(
sft.to_vox()
sft.to_corner()
batch_streamlines.extend(sft.streamlines)

batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]

return batch_streamlines, final_s_ids_per_subj
Expand Down
53 changes: 52 additions & 1 deletion dwi_ml/training/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,58 @@ def run_one_batch(self, data):
the list of all streamlines (if we concatenate all sfts'
streamlines)
"""
raise NotImplementedError
# Data interpolation has not been done yet. GPU computations are done
# here in the main thread.
targets, ids_per_subj = data

# Dataloader always works on CPU. Sending to right device.
# (model is already moved).
targets = [s.to(self.device, non_blocking=True, dtype=torch.float)
for s in targets]

# Getting the inputs points from the volumes.
# Uses the model's method, with the batch_loader's data.
# Possibly skipping the last point if not useful.
streamlines_f = targets
if isinstance(self.model, ModelWithDirectionGetter) and \
not self.model.direction_getter.add_eos:
# No EOS = We don't use the last coord because it does not have an
# associated target direction.
streamlines_f = [s[:-1, :] for s in streamlines_f]

# Possibly add noise to inputs here.
logger.debug('*** Computing forward propagation')
if self.model.forward_uses_streamlines:
arnaudbore marked this conversation as resolved.
Show resolved Hide resolved
# Now possibly add noise to streamlines (training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)

# Possibly computing directions twice (during forward and loss)
# but ok, shouldn't be too heavy. Easier to deal with multiple
# projects' requirements by sending whole streamlines rather
# than only directions.
model_outputs = self.model(streamlines_f)
del streamlines_f
else:
del streamlines_f
model_outputs = self.model()

logger.debug('*** Computing loss')
if self.model.loss_uses_streamlines:
arnaudbore marked this conversation as resolved.
Show resolved Hide resolved
targets = self.batch_loader.add_noise_streamlines_loss(
targets, self.device)

results = self.model.compute_loss(model_outputs, targets,
average_results=True)
else:
results = self.model.compute_loss(model_outputs,
average_results=True)

if self.use_gpu:
log_gpu_memory_usage(logger)

# The mean tensor is a single value. Converting to float using item().
return results

def fix_parameters(self):
"""
Expand Down
Loading
Loading