diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py index 263cbf7e..bac638db 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/with_generation/trainer.py @@ -25,8 +25,8 @@ - Connectivity fit: Percentage of streamlines ending in a block of the volume indeed connected in the validation subject. Real connectivity matrices must be saved in the - hdf5. Right now, volumes are simply split into blocs (the same way as in the - hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on + hdf5. Right now, volumes are simply split into blocs (the same way as in + the hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on anatomical ROIs. It has the advantage that it does not rely on the quality of segmentation. It had the drawback that a generated streamline ending very close to the "true" streamline, but in another block, if the @@ -175,9 +175,10 @@ def validate_one_batch(self, data, epoch): logger.debug("Additional tracking-like generation validation " "from batch.") (gen_n, mean_final_dist, mean_clipped_final_dist, - percent_IS_very_good, percent_IS_acceptable, percent_IS_very_far, - diverging_pnt, connectivity) = self.validation_generation_one_batch( - data, compute_all_scores=True) + percent_IS_very_good, percent_IS_acceptable, + percent_IS_very_far, diverging_pnt, connectivity) = \ + self.validation_generation_one_batch( + data, compute_all_scores=True) self.tracking_very_good_IS_monitor.update( percent_IS_very_good, weight=gen_n) @@ -196,8 +197,9 @@ def validate_one_batch(self, data, epoch): self.tracking_connectivity_score_monitor.update( connectivity, weight=gen_n) elif len(self.tracking_mean_final_distance_monitor.average_per_epoch) == 0: - logger.info("Skipping tracking-like generation validation from " - "batch. No values yet: adding fake initial values.") + logger.info("Skipping tracking-like generation validation " "from batch. 
No values yet: adding fake initial " "values.") # Fake values at the beginning # Bad IS = 100% self.tracking_very_good_IS_monitor.update(100.0) @@ -216,8 +218,8 @@ def validate_one_batch(self, data, epoch): self.tracking_connectivity_score_monitor.update(1) else: - logger.info("Skipping tracking-like generation validation from " - "batch. Copying previous epoch's values.") + logger.info("Skipping tracking-like generation validation " + "from batch. Copying previous epoch's values.") # Copy previous value for monitor in [self.tracking_very_good_IS_monitor, self.tracking_acceptable_IS_monitor, @@ -238,7 +240,8 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): # Possibly sending again to GPU even if done in the local loss # computation, but easier with current implementation. - real_lines = [line.to(self.device, non_blocking=True, dtype=torch.float) + real_lines = [line.to(self.device, non_blocking=True, + dtype=torch.float) for line in real_lines] last_pos = torch.vstack([line[-1, :] for line in real_lines]) @@ -267,7 +270,8 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): final_dist_clipped = torch.mean(final_dist_clipped) # 2. Connectivity scores, if available (else None) - connectivity_score = self._compare_connectivity(lines, ids_per_subj) + connectivity_score = self._compare_connectivity(lines, + ids_per_subj) # 3. "IS ratio", i.e. percentage of streamlines ending inside a # predefined radius. @@ -280,9 +284,9 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): final_dist = torch.mean(final_dist) # 4. Verify point where streamline starts diverging. - # abs(100 - score): 0 = good. 100 = bad (either abs(100) -> diverged - # at first point or abs(-100) = diverged after twice the expected - # length. + # abs(100 - score): 0 = good. 100 = bad (either + # abs(100) -> diverged at first point or + # abs(-100) = diverged after twice the expected length.
total_point = 0 for line, real_line in zip(lines, real_lines): expected_nb = len(real_line) @@ -321,7 +325,8 @@ def _compare_connectivity(self, lines, ids_per_subj): # done in the loop for each subject. (connectivity_matrices, volume_sizes, connectivity_nb_blocs, connectivity_labels) = \ - self.batch_loader.load_batch_connectivity_matrices(ids_per_subj) + self.batch_loader.load_batch_connectivity_matrices( + ids_per_subj) score = 0.0 for i, subj in enumerate(ids_per_subj.keys()): @@ -365,7 +370,8 @@ def _compare_connectivity(self, lines, ids_per_subj): score = None return score - def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): + def propagate_multiple_lines(self, lines: List[torch.Tensor], + ids_per_subj): """ Tractography propagation of 'lines'. """ @@ -396,7 +402,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): # accept multiple masks or manage it differently. final_lines = [] for subj_idx, line_idx in ids_per_subj.items(): - with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r' + ) as hdf_handle: subj_id = self.batch_loader.context_subset.subjects[subj_idx] logging.debug("Loading subj {} ({})'s tracking mask." .format(subj_idx, subj_id)) @@ -408,7 +415,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): final_lines.extend(propagate_multiple_lines( lines[line_idx], update_memory_after_removing_lines, get_dirs_at_last_pos, theta=theta, - step_size=self.model.step_size, verify_opposite_direction=False, + step_size=self.model.step_size, + verify_opposite_direction=False, mask=tracking_mask, max_nbr_pts=max_nbr_pts, append_last_point=False, normalize_directions=True))