Fix a few lines too long
EmmaRenauld committed Feb 28, 2024
1 parent de0d20d commit 2e350cd
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions dwi_ml/training/with_generation/trainer.py
@@ -25,8 +25,8 @@
 - Connectivity fit:
     Percentage of streamlines ending in a block of the volume indeed connected
     in the validation subject. Real connectivity matrices must be saved in the
-    hdf5. Right now, volumes are simply split into blocs (the same way as in the
-    hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on
+    hdf5. Right now, volumes are simply split into blocs (the same way as in
+    the hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on
     anatomical ROIs. It has the advantage that it does not rely on the quality
     of segmentation. It has the drawback that a generated streamline ending
     very close to the "true" streamline, but in another block, if the
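
As context for this hunk: the bloc split described in the docstring amounts to quantizing each streamline endpoint's position into a regular grid, so that a streamline's two endpoints "connect" a pair of blocs. A minimal sketch of that binning, assuming endpoints in voxel coordinates; the function name and signature are illustrative, not dwi_ml's API:

    import numpy as np

    def endpoint_to_bloc(endpoint, volume_size, nb_blocs=(10, 10, 10)):
        # Quantize a 3D endpoint (voxel coordinates) into a flat bloc
        # index, splitting each axis into nb_blocs equal divisions.
        pos = np.asarray(endpoint, dtype=float)
        size = np.asarray(volume_size, dtype=float)
        blocs = np.asarray(nb_blocs)
        idx = np.floor(pos / size * blocs).astype(int)
        idx = np.minimum(idx, blocs - 1)  # endpoint exactly on the border
        return (idx[0] * blocs[1] + idx[1]) * blocs[2] + idx[2]

    # A 100x100x100 volume split 10x10x10 gives 1000 blocs, as in the
    # docstring's example.
    bloc = endpoint_to_bloc([55.2, 3.9, 99.0], (100, 100, 100))

As the docstring notes, two endpoints that are close but fall on opposite sides of a bloc border land in different blocs; that is the price of skipping anatomical ROIs.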
@@ -175,9 +175,10 @@ def validate_one_batch(self, data, epoch):
             logger.debug("Additional tracking-like generation validation "
                          "from batch.")
             (gen_n, mean_final_dist, mean_clipped_final_dist,
-             percent_IS_very_good, percent_IS_acceptable, percent_IS_very_far,
-             diverging_pnt, connectivity) = self.validation_generation_one_batch(
-                 data, compute_all_scores=True)
+             percent_IS_very_good, percent_IS_acceptable,
+             percent_IS_very_far, diverging_pnt, connectivity) = \
+                self.validation_generation_one_batch(
+                    data, compute_all_scores=True)

             self.tracking_very_good_IS_monitor.update(
                 percent_IS_very_good, weight=gen_n)
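
The weight=gen_n arguments matter here: each batch contributes to the epoch's average proportionally to the number of streamlines generated from it. A minimal sketch of a weighted running-average monitor under that assumption (illustrative only; dwi_ml's monitor classes may differ):

    class WeightedAverageMonitor:
        # Accumulates a weighted mean of updates, one value per epoch.
        def __init__(self):
            self.average_per_epoch = []
            self._total = 0.0
            self._weight = 0.0

        def update(self, value, weight=1.0):
            self._total += value * weight
            self._weight += weight

        def end_epoch(self):
            # Weighted mean of all updates received this epoch.
            if self._weight > 0:
                self.average_per_epoch.append(self._total / self._weight)
            self._total, self._weight = 0.0, 0.0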
@@ -196,8 +197,9 @@ def validate_one_batch(self, data, epoch):
             self.tracking_connectivity_score_monitor.update(
                 connectivity, weight=gen_n)
         elif len(self.tracking_mean_final_distance_monitor.average_per_epoch) == 0:
-            logger.info("Skipping tracking-like generation validation from "
-                        "batch. No values yet: adding fake initial values.")
+            logger.info("Skipping tracking-like generation validation "
+                        "from batch. No values yet: adding fake initial "
+                        "values.")
             # Fake values at the beginning
             # Bad IS = 100%
             self.tracking_very_good_IS_monitor.update(100.0)
@@ -216,8 +218,8 @@ def validate_one_batch(self, data, epoch):

             self.tracking_connectivity_score_monitor.update(1)
         else:
-            logger.info("Skipping tracking-like generation validation from "
-                        "batch. Copying previous epoch's values.")
+            logger.info("Skipping tracking-like generation validation "
+                        "from batch. Copying previous epoch's values.")
             # Copy previous value
             for monitor in [self.tracking_very_good_IS_monitor,
                             self.tracking_acceptable_IS_monitor,
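
Seen together, the three branches of validate_one_batch form one pattern: run the expensive generation validation only on selected epochs, seed the monitors once with pessimistic placeholders (100% invalid streamlines, worst divergence), then carry the previous epoch's values forward so the validation curves stay continuous. A condensed sketch of that control flow, with hypothetical names:

    def update_generation_monitors(monitors, run_now, compute_scores):
        if run_now:
            # Real tracking-like generation validation for this epoch.
            for monitor, (value, weight) in zip(monitors, compute_scores()):
                monitor.update(value, weight=weight)
        elif len(monitors[0].average_per_epoch) == 0:
            # No validation yet: seed plots with pessimistic fake values.
            for monitor in monitors:
                monitor.update(100.0)
        else:
            # Skipped epoch: repeat the previous epoch's value.
            for monitor in monitors:
                monitor.update(monitor.average_per_epoch[-1])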
@@ -238,7 +240,8 @@ def validation_generation_one_batch(self, data, compute_all_scores=False):

         # Possibly sending again to GPU even if done in the local loss
         # computation, but easier with current implementation.
-        real_lines = [line.to(self.device, non_blocking=True, dtype=torch.float)
+        real_lines = [line.to(self.device, non_blocking=True,
+                              dtype=torch.float)
                       for line in real_lines]
         last_pos = torch.vstack([line[-1, :] for line in real_lines])

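A note on the .to() call being rewrapped here: non_blocking=True only makes the host-to-GPU copy asynchronous when the source tensor lives in pinned (page-locked) memory; otherwise PyTorch silently falls back to a synchronous copy. This is standard PyTorch behaviour, not specific to dwi_ml:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    line = torch.rand(128, 3)
    if device.type == 'cuda':
        line = line.pin_memory()  # enables a truly asynchronous copy
    line = line.to(device, non_blocking=True, dtype=torch.float)
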
@@ -267,7 +270,8 @@ def validation_generation_one_batch(self, data, compute_all_scores=False):
             final_dist_clipped = torch.mean(final_dist_clipped)

             # 2. Connectivity scores, if available (else None)
-            connectivity_score = self._compare_connectivity(lines, ids_per_subj)
+            connectivity_score = self._compare_connectivity(lines,
+                                                            ids_per_subj)

             # 3. "IS ratio", i.e. percentage of streamlines ending inside a
             # predefined radius.
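
The "IS ratio" of point 3 reduces to thresholding the endpoint distances already computed for point 1: the percentage of generated streamlines whose final point lands outside a given radius of the real streamline's final point. A minimal sketch, assuming both endpoint sets are stacked as (N, 3) tensors (names are illustrative):

    import torch

    def percent_invalid_streamlines(gen_last_pos, real_last_pos, radius):
        # Euclidean distance between each generated endpoint and its
        # matching real endpoint.
        final_dist = torch.linalg.norm(gen_last_pos - real_last_pos, dim=1)
        # Percentage of streamlines ending *outside* the accepted radius.
        return 100.0 * (final_dist > radius).float().mean()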
@@ -280,9 +284,9 @@ def validation_generation_one_batch(self, data, compute_all_scores=False):
             final_dist = torch.mean(final_dist)

             # 4. Verify point where streamline starts diverging.
-            # abs(100 - score): 0 = good. 100 = bad (either abs(100) -> diverged
-            # at first point or abs(-100) = diverged after twice the expected
-            # length.
+            # abs(100 - score): 0 = good. 100 = bad (either
+            # abs(100) -> diverged at first point or
+            # abs(-100) = diverged after twice the expected length.
             total_point = 0
             for line, real_line in zip(lines, real_lines):
                 expected_nb = len(real_line)
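
Unpacking the rewrapped comment: the raw score is the diverging point's position expressed as a percentage of the real streamline's length, and the reported value is abs(100 - score). Diverging exactly at the expected end gives 0 (best); diverging at the very first point (score 0) or only after twice the expected length (score 200) both give 100 (worst). A worked sketch under that reading; this is an illustration, not dwi_ml's exact implementation:

    def divergence_score(diverging_idx, expected_nb):
        # Diverging point as a percentage of the real streamline's
        # length, capped at twice the expected length.
        pct = 100.0 * min(diverging_idx, 2 * expected_nb) / expected_nb
        return abs(100.0 - pct)  # 0 = best, 100 = worst

    assert divergence_score(100, 100) == 0.0    # diverged right at the end
    assert divergence_score(0, 100) == 100.0    # diverged immediately
    assert divergence_score(200, 100) == 100.0  # twice the expected length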
@@ -321,7 +325,8 @@ def _compare_connectivity(self, lines, ids_per_subj):
         # done in the loop for each subject.
         (connectivity_matrices, volume_sizes,
          connectivity_nb_blocs, connectivity_labels) = \
-            self.batch_loader.load_batch_connectivity_matrices(ids_per_subj)
+            self.batch_loader.load_batch_connectivity_matrices(
+                ids_per_subj)

         score = 0.0
         for i, subj in enumerate(ids_per_subj.keys()):
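
With the reference matrices loaded, the per-subject comparison can be read as: bin each generated streamline's two endpoints into blocs (as sketched earlier), count connections per bloc pair, and measure how many of those connections do not exist in the subject's real matrix. A rough sketch of that comparison; the helper below is hypothetical, not dwi_ml's exact scoring code:

    import numpy as np

    def connectivity_mismatch(gen_counts, real_matrix):
        # gen_counts[i, j]: generated streamlines joining blocs i and j.
        # real_matrix: boolean reference connectivity from the hdf5.
        total = gen_counts.sum()
        if total == 0:
            return None  # no generated connections to score
        wrong = gen_counts[~(real_matrix > 0)].sum()
        return float(wrong) / float(total)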
@@ -365,7 +370,8 @@ def _compare_connectivity(self, lines, ids_per_subj):
             score = None
         return score

-    def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj):
+    def propagate_multiple_lines(self, lines: List[torch.Tensor],
+                                 ids_per_subj):
         """
         Tractography propagation of 'lines'.
         """
@@ -396,7 +402,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):
         # accept multiple masks or manage it differently.
         final_lines = []
         for subj_idx, line_idx in ids_per_subj.items():
-            with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle:
+            with h5py.File(self.batch_loader.dataset.hdf5_file, 'r'
+                           ) as hdf_handle:
                 subj_id = self.batch_loader.context_subset.subjects[subj_idx]
                 logging.debug("Loading subj {} ({})'s tracking mask."
                               .format(subj_idx, subj_id))
@@ -408,7 +415,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):
             final_lines.extend(propagate_multiple_lines(
                 lines[line_idx], update_memory_after_removing_lines,
                 get_dirs_at_last_pos, theta=theta,
-                step_size=self.model.step_size, verify_opposite_direction=False,
+                step_size=self.model.step_size,
+                verify_opposite_direction=False,
                 mask=tracking_mask, max_nbr_pts=max_nbr_pts,
                 append_last_point=False, normalize_directions=True))
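
For orientation, the propagate_multiple_lines called here is a free function (distinct from the method of the same name above) that grows all lines in parallel: each step queries the model for a direction at every line's last position, scales it by step_size, and retires lines that exceed the angle threshold theta, leave the tracking mask, or reach max_nbr_pts. A stripped-down, single-line sketch of that loop; get_dir and mask.is_inside are stand-ins, not the real API:

    import torch

    def propagate_one_line(line, get_dir, mask, step_size, max_nbr_pts):
        # line: (n, 3) tensor of points already on the streamline.
        while len(line) < max_nbr_pts:
            direction = get_dir(line[-1])       # model's next direction
            if direction is None:               # model decided to stop
                break
            new_pos = line[-1] + step_size * direction
            if not mask.is_inside(new_pos):     # left the tracking mask
                break
            line = torch.vstack([line, new_pos.unsqueeze(0)])
        return line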
