L2T_fix_unpacking_hidden_state #182

Merged
44 changes: 26 additions & 18 deletions dwi_ml/models/projects/learn2track_model.py
@@ -301,22 +301,30 @@ def forward(self, inputs: List[torch.tensor],
        model_outputs = self.direction_getter(rnn_output)
        model_outputs = copy_prev_dir + model_outputs

        # Return the hidden states. Necessary for the generative
        # (tracking) part, done step by step.

        # During training / visualization: not unpacking now; the loss will be
        # computed on the whole tensor at once.
        # But sending as a PackedSequence to be sure that targets
        # will be concatenated in the same order when computing the loss.
        # During tracking (last point only): keeping as one tensor.
        # During backward tracking: the output is ignored anyway; only the
        # hidden state is computed.
        # Unpacking.
        if not self._context == 'tracking':
            # (during tracking: keeping as one single tensor.)
            model_outputs = PackedSequence(model_outputs, batch_sizes)
            model_outputs = unpack_sequence(model_outputs)
            model_outputs = [model_outputs[i] for i in unsorted_indices]

        if return_hidden:
            # Return the hidden states too. Necessary for the generative
            # (tracking) part, done step by step.
            if not self._context == 'tracking':
                # (ex: when preparing backward tracking.
                # Must also re-sort hidden states.)
                if self.rnn_model.rnn_torch_key == 'lstm':
                    # LSTM: For each layer, states are tuples; (h_t, C_t)
                    out_hidden_recurrent_states = [
                        (layer_states[0][:, unsorted_indices, :],
                         layer_states[1][:, unsorted_indices, :]) for
                        layer_states in out_hidden_recurrent_states]
                else:
                    # GRU: For each layer, states are tensors; h_t.
                    out_hidden_recurrent_states = [
                        layer_states[:, unsorted_indices, :] for
                        layer_states in out_hidden_recurrent_states]
            return model_outputs, out_hidden_recurrent_states
        else:
            return model_outputs
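
For reference, a minimal self-contained sketch of the pack / unpack / re-sort pattern used in the hunk above (toy data; flat_outputs is a made-up stand-in for the model output, not a name from the PR). The rebuilt PackedSequence carries no unsorted_indices, so unpack_sequence still returns the sequences in length-sorted order; the explicit re-indexing is what restores the original batch order.

import torch
from torch.nn.utils.rnn import PackedSequence, pack_sequence, unpack_sequence

# Three "streamlines" of different lengths, 2 features per point.
sequences = [torch.randn(2, 2), torch.randn(5, 2), torch.randn(4, 2)]

# pack_sequence sorts by decreasing length and remembers how to undo it.
packed = pack_sequence(sequences, enforce_sorted=False)
batch_sizes = packed.batch_sizes
unsorted_indices = packed.unsorted_indices

# Stand-in for the model's flat output, aligned point-for-point with packed.data.
flat_outputs = packed.data * 2.0

# Same pattern as in forward(): wrap, split back into per-sequence tensors
# (still in length-sorted order), then restore the original batch order.
outputs = unpack_sequence(PackedSequence(flat_outputs, batch_sizes))
outputs = [outputs[i] for i in unsorted_indices]

assert [len(o) for o in outputs] == [2, 5, 4]  # original order recovered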
@@ -380,16 +388,16 @@ def copy_prev_dir(self, dirs, n_prev_dirs):

    def update_hidden_state(self, hidden_recurrent_states, lines_to_keep):
        if self.rnn_model.rnn_torch_key == 'lstm':
            # LSTM: States are tuples; (h_t, C_t)
            # LSTM: For each layer, states are tuples; (h_t, C_t)
            # Each tensor has size [1, nb_streamlines, nb_neurons].
            hidden_recurrent_states = [
                (hidden_states[0][:, lines_to_keep, :],
                 hidden_states[1][:, lines_to_keep, :]) for
                hidden_states in hidden_recurrent_states]
                (layer_states[0][:, lines_to_keep, :],
                 layer_states[1][:, lines_to_keep, :]) for
                layer_states in hidden_recurrent_states]
        else:
            # GRU: States are tensors; h_t.
            # Each tensor has size [1, nb_streamlines, nb_neurons].
            # GRU: For each layer, states are tensors; h_t.
            # Each tensor has size [1, nb_streamlines, nb_neurons].
            hidden_recurrent_states = [
                hidden_states[:, lines_to_keep, :] for
                hidden_states in hidden_recurrent_states]
                layer_states[:, lines_to_keep, :] for
                layer_states in hidden_recurrent_states]
        return hidden_recurrent_states
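
For reference, a standalone sketch of what update_hidden_state does during tracking: when some streamlines stop, their column is dropped from every layer's state so the remaining streamlines keep their own memory. All names and sizes below are made up; only the indexing pattern matches the method above.

import torch

nb_streamlines, nb_neurons = 4, 8
lines_to_keep = torch.tensor([0, 2])  # streamlines still being tracked

# LSTM: one (h_t, C_t) tuple per layer, each of size [1, nb_streamlines, nb_neurons].
lstm_states = [(torch.zeros(1, nb_streamlines, nb_neurons),
                torch.zeros(1, nb_streamlines, nb_neurons)) for _ in range(2)]
lstm_states = [(h[:, lines_to_keep, :], c[:, lines_to_keep, :])
               for h, c in lstm_states]
assert lstm_states[0][0].shape == (1, 2, nb_neurons)

# GRU: one h_t tensor per layer.
gru_states = [torch.zeros(1, nb_streamlines, nb_neurons) for _ in range(2)]
gru_states = [h[:, lines_to_keep, :] for h in gru_states]
assert gru_states[0].shape == (1, 2, nb_neurons)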