Skip to content

Commit

Permalink
ENH: add clear cache opt to model
Browse files Browse the repository at this point in the history
  • Loading branch information
AntoineTheb committed May 27, 2024
1 parent 05a3f0f commit aaf48f2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
8 changes: 5 additions & 3 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ def forward(self, inputs, target_streamlines: List[torch.tensor]):

class MainModelOneInput(MainModelAbstract):
def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
subj_idx, input_group_idx, prepare_mask=False):
subj_idx, input_group_idx, prepare_mask=False,
clear_cache=True):
"""
These params are passed by either the batch loader or the propagator,
which manage the data.
Expand Down Expand Up @@ -491,10 +492,11 @@ def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
if isinstance(self, ModelWithNeighborhood):
# Adding neighborhood.
subj_x_data, coords_torch = interpolate_volume_in_neighborhood(
data_tensor, flat_subj_x_coords, self.neighborhood_vectors)
data_tensor, flat_subj_x_coords, self.neighborhood_vectors,
clear_cache=clear_cache)
else:
subj_x_data, coords_torch = interpolate_volume_in_neighborhood(
data_tensor, flat_subj_x_coords, None)
data_tensor, flat_subj_x_coords, None, clear_cache=clear_cache)

# Split the flattened signal back to streamlines
lengths = [len(s) for s in streamlines]
Expand Down
3 changes: 2 additions & 1 deletion dwi_ml/training/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def __init__(self, input_group_name, **kw):
.format(input_group_name,
self.dataset.volume_groups))
self.input_group_idx = idx
self.clear_cache = True

@property
def params_for_checkpoint(self):
Expand Down Expand Up @@ -404,7 +405,7 @@ def load_batch_inputs(self, batch_streamlines: List[torch.tensor],
# before adding streamline to batch.
subbatch_x_data = self.model.prepare_batch_one_input(
streamlines, self.context_subset, subj,
self.input_group_idx)
self.input_group_idx, clear_cache=self.clear_cache)

batch_x_data.extend(subbatch_x_data)

Expand Down

0 comments on commit aaf48f2

Please sign in to comment.