
Fix sampler: batch size must be int
EmmaRenauld committed Feb 29, 2024
1 parent 8d343ba commit 29fbae8
Showing 1 changed file with 26 additions and 17 deletions.
dwi_ml/training/batch_samplers.py
@@ -253,10 +253,15 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]:
             # Choose subjects from which to sample streamlines for the next
             # few cycles.
             if self.nb_subjects_per_batch:
-                # Sampling first from subjects that were not seed a lot yet
+                # Sampling first from subjects that were not seen a lot yet
                 weights = streamlines_per_subj / np.sum(streamlines_per_subj)
 
                 # Choosing only non-empty subjects
+                # NOTE. THIS IS QUESTIONABLE! It means that the last batch of
+                # every epoch is ~1 subject: the one with the most streamlines.
+                # Another option would be to break as soon as at least one
+                # subject is done. With batches that are not too big, we would
+                # still have seen most of the data of unfinished subjects.
                 nb_subjects = min(self.nb_subjects_per_batch,
                                   np.count_nonzero(weights))
                 sampled_subjs = self.np_rng.choice(
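
For context, this hunk's subject sampling draws without replacement, weighted by how many streamlines each subject still has unused. A minimal sketch of the idea (illustrative values; `rng` stands in for the sampler's `np_rng`, not the actual dwi_ml API):

import numpy as np

rng = np.random.default_rng(seed=1234)
streamlines_per_subj = np.array([120, 0, 45, 300])  # unused count per subject

# Subjects with more unused streamlines are drawn more often.
weights = streamlines_per_subj / np.sum(streamlines_per_subj)

# Only non-empty subjects can be drawn, so cap the count accordingly.
nb_subjects = min(2, np.count_nonzero(weights))
sampled_subjs = rng.choice(np.arange(len(weights)), size=nb_subjects,
                           replace=False, p=weights)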
@@ -271,9 +276,9 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]:
 
             # Final subject's batch size could be smaller if no streamlines are
             # left for this subject.
-            max_batch_size_per_subj = self.context_batch_size / nb_subjects
+            max_batch_size_per_subj = int(self.context_batch_size / nb_subjects)
             if self.batch_size_units == 'nb_streamlines':
-                chunk_size = int(max_batch_size_per_subj)
+                chunk_size = max_batch_size_per_subj
             else:
                 chunk_size = self.nb_streamlines_per_chunk or DEFAULT_CHUNK_SIZE

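This hunk is the heart of the commit: in Python 3, `/` always returns a float, even for an exact division, and a float batch size breaks anything downstream that expects an integer count. A small plain-Python illustration (values are made up):

batch_size, nb_subjects = 1000, 4

size = batch_size / nb_subjects       # 250.0: a float, even though exact
# list(range(size))                   # TypeError: 'float' object cannot be
                                      #     interpreted as an integer

size = int(batch_size / nb_subjects)  # 250: what the commit does
# batch_size // nb_subjects           # floor division would also work
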
@@ -296,9 +301,10 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]:
             batch_ids_per_subj = []
             for subj in sampled_subjs:
                 self.logger.debug(" Subj {}".format(subj))
-                sampled_ids = self._sample_streamlines_for_subj(
-                    subj, ids_per_subjs, global_unused_streamlines,
-                    max_batch_size_per_subj, chunk_size)
+                sampled_ids, global_unused_streamlines = \
+                    self._sample_streamlines_for_subj(
+                        subj, ids_per_subjs, global_unused_streamlines,
+                        max_batch_size_per_subj, chunk_size)
 
                 # Append tuple (subj, list_sampled_ids) to the batch
                 if len(sampled_ids) > 0:
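
The call site now unpacks a tuple: the helper returns the updated mask along with the sampled ids. A toy version of the new contract (hypothetical names and signature, not the repository's actual code):

import numpy as np

def sample_for_subj(subj_slice, unused_mask, n):
    # Global ids of this subject's streamlines that are still unused.
    candidates = np.arange(subj_slice.start, subj_slice.stop)
    candidates = candidates[unused_mask[candidates] == 1]
    chosen = candidates[:n]
    unused_mask[chosen] = 0    # in-place: the caller's array is modified
    return chosen.tolist(), unused_mask

mask = np.ones(10, dtype=int)
ids, mask = sample_for_subj(slice(2, 8), mask, n=3)
# ids == [2, 3, 4]; mask[2:5] is now 0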
@@ -331,8 +337,8 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs,
         ------
         subj: int
             The subject's id.
-        ids_per_subjs: dict
-            The list of this subject's streamlines' global ids.
+        ids_per_subjs: dict[slice]
+            This subject's streamlines' global ids (slices).
         global_unused_streamlines: array
             One flag per global streamline id: 0 if already used, else 1.
         max_batch_size_per_subj:
@@ -344,6 +350,8 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs,
         # subject
         subj_slice = ids_per_subjs[subj]
 
+        slice_to_list = list(range(subj_slice.start, subj_slice.stop))
+
         # We will continue iterating on this subject until we
         # break (i.e. when we reach the maximum batch size for this
         # subject)
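
Since `ids_per_subjs` stores each subject's global ids as a contiguous slice, materializing them as an explicit list is the one-liner the hunk adds. With hypothetical bounds:

subj_slice = slice(253, 260)  # hypothetical: this subject owns ids 253..259
slice_to_list = list(range(subj_slice.start, subj_slice.stop))
# [253, 254, 255, 256, 257, 258, 259]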
@@ -364,12 +372,12 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs,
             if len(chunk_rel_ids) == 0:
                 raise ValueError(
                     "Implementation error? Got no streamline for this subject "
-                    "in this batch, but there are streamlines left. \n"
-                    "Possibly means that the allowed batch size does not even "
-                    "allow one streamline per batch.\n Check your batch size "
-                    "choice!")
+                    "in this batch, but there are streamlines left. To be "
+                    "discussed with the implementors.")
 
-            # Mask the sampled streamlines
+            # Mask the sampled streamlines.
+            # Technically done in-place, so we wouldn't need to return, but
+            # returning to be sure.
             global_unused_streamlines[chunk_global_ids] = 0
 
             # Add sub-sampled ids to subject's batch
@@ -383,7 +391,7 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs,
             # Update size and get a new chunk
             total_subj_batch_size += subj_batch_size
 
-        return sampled_ids
+        return sampled_ids, global_unused_streamlines
 
     def _get_a_chunk_of_streamlines(self, subj_slice,
                                     global_unused_streamlines,
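
The new comment above the masking line is worth unpacking: numpy fancy-index assignment mutates the array object itself, so the caller sees the change even without a return; the commit returns the mask anyway as an explicit, defensive choice. A minimal demonstration of the aliasing:

import numpy as np

def mask_ids(mask, ids):
    mask[ids] = 0   # fancy-index assignment mutates the shared array

a = np.ones(5, dtype=int)
mask_ids(a, [1, 3])
# a is now [1, 0, 1, 0, 1], although mask_ids returned nothing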
@@ -443,15 +451,16 @@ def _get_a_chunk_of_streamlines(self, subj_slice,
                 chosen_global_ids)
             tmp_computed_chunk_size = int(np.sum(size_per_streamline))
 
-            # If batch_size has been exceeded, taking a little less streamlines
-            # for this chunk.
             if subj_batch_size + tmp_computed_chunk_size >= max_subj_batch_size:
                 reached_max_heaviness = True
 
+                # If batch_size has been exceeded, take slightly fewer
+                # streamlines for this chunk.
                 if subj_batch_size + tmp_computed_chunk_size > max_subj_batch_size:
                     self.logger.debug(
                         " Chunk_size was {}, but max batch size for this "
-                        "subj is {} (we already had acculumated {})."
+                        "subj is {} (we already had accumulated {}). Taking "
+                        "slightly fewer streamlines."
                         .format(tmp_computed_chunk_size, max_subj_batch_size,
                                 subj_batch_size))

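When the accumulated size would overshoot the subject's maximum, the sampler keeps only part of the chunk. One way to compute how many leading streamlines still fit, sketched with a cumulative sum (illustrative values, not the repository's actual trimming code):

import numpy as np

max_subj_batch_size = 100                         # total size allowed
subj_batch_size = 60                              # already accumulated
size_per_streamline = np.array([15, 20, 10, 25])  # cost of each candidate

# Number of leading streamlines whose cumulative cost fits the budget.
budget = max_subj_batch_size - subj_batch_size
n_keep = int(np.searchsorted(np.cumsum(size_per_streamline), budget,
                             side='right'))
# n_keep == 2: 15 + 20 = 35 <= 40, but 15 + 20 + 10 = 45 > 40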