Skip to content

Commit 8d52c58

Browse files
authored
Merge pull request #181 from EmmaRenauld/fix_space
Refactoring while searching for error
2 parents 64706f6 + 796ac7e commit 8d52c58

File tree

13 files changed

+234
-233
lines changed

13 files changed

+234
-233
lines changed

.github/workflows/test_package.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
# max-parallel: 6
1212
matrix:
1313
os: [ubuntu-latest]
14-
python-version: [3.10.11]
14+
python-version: [3.10.12]
1515
requires: ['latest']
1616

1717
steps:

dwi_ml/tracking/projects/__init__.py

Whitespace-only changes.

dwi_ml/tracking/projects/utils.py

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# -*- coding: utf-8 -*-
2+
import logging
3+
import os
4+
5+
from dipy.io.stateful_tractogram import (Space, Origin, set_sft_logger_level,
6+
StatefulTractogram)
7+
from dipy.io.streamline import save_tractogram
8+
import nibabel as nib
9+
import numpy as np
10+
11+
from scilpy.tracking.seed import SeedGenerator
12+
13+
from dwi_ml.experiment_utils.timer import Timer
14+
from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args
15+
from dwi_ml.testing.utils import add_args_testing_subj_hdf5
16+
from dwi_ml.tracking.tracking_mask import TrackingMask
17+
from dwi_ml.tracking.tracker import DWIMLAbstractTracker
18+
19+
20+
ALWAYS_VOX_SPACE = Space.VOX
21+
ALWAYS_CORNER = Origin('corner')
22+
23+
# A decision should be made as if we should keep the last point (out of the
24+
# tracking mask). Currently keeping this as in Dipy, i.e. True. Could be
25+
# an option for the user.
26+
APPEND_LAST_POINT = True # See here: https://github.com/dipy/dipy/discussions/2764
27+
28+
29+
def add_tracking_options(p):
30+
31+
add_arg_existing_experiment_path(p)
32+
add_args_testing_subj_hdf5(p, ask_input_group=True)
33+
34+
p.add_argument('out_tractogram',
35+
help='Tractogram output file (must be .trk or .tck).')
36+
p.add_argument('seeding_mask_group',
37+
help="Seeding mask's volume group in the hdf5.")
38+
39+
track_g = p.add_argument_group(' Tracking options')
40+
track_g.add_argument('--algo', choices=['det', 'prob'], default='det',
41+
help="Tracking algorithm (det or prob). Must be "
42+
"implemented in the chosen model. [det]")
43+
track_g.add_argument('--step_size', type=float,
44+
help='Step size in mm. Default: using the step size '
45+
'saved in the model parameters.')
46+
track_g.add_argument('--track_forward_only', action='store_true',
47+
help="If set, tracks in one direction only (forward) "
48+
"given the initial \nseed. The direction is "
49+
"randomly drawn from the ODF.")
50+
track_g.add_argument('--mask_interp', default='nearest',
51+
choices=['nearest', 'trilinear'],
52+
help="Mask interpolation: nearest-neighbor or "
53+
"trilinear. [%(default)s]")
54+
track_g.add_argument('--data_interp', default='trilinear',
55+
choices=['nearest', 'trilinear'],
56+
help="Input data interpolation: nearest-neighbor or "
57+
"trilinear. [%(default)s]")
58+
59+
stop_g = p.add_argument_group("Stopping criteria")
60+
stop_g.add_argument('--min_length', type=float, default=10.,
61+
metavar='m',
62+
help='Minimum length of a streamline in mm. '
63+
'[%(default)s]')
64+
stop_g.add_argument('--max_length', type=float, default=300.,
65+
metavar='M',
66+
help='Maximum length of a streamline in mm. '
67+
'[%(default)s]')
68+
stop_g.add_argument('--tracking_mask_group',
69+
help="Tracking mask's volume group in the hdf5.")
70+
stop_g.add_argument('--theta', metavar='t', type=float,
71+
default=90,
72+
help="The tracking direction at each step being "
73+
"defined by the model, \ntheta arg can't define "
74+
"allowed directions in the tracking field.\n"
75+
"Rather, this new equivalent angle, is used as "
76+
"\na stopping criterion during propagation: "
77+
"tracking \nis stopped when a direction is more "
78+
"than an angle t from preceding direction")
79+
stop_g.add_argument('--eos_stop', metavar='prob',
80+
help="Stopping criterion if a EOS value was learned "
81+
"during training. \nCan either be a probability "
82+
"(default 0.5) or the string 'max', which will "
83+
"\nstop the propagation if the EOS class's "
84+
"probability is the class with maximal "
85+
"probability, no mather its value.")
86+
87+
r_g = p.add_argument_group(' Random seeding options')
88+
r_g.add_argument('--rng_seed', type=int,
89+
help='Initial value for the random number generator. '
90+
'[%(default)s]')
91+
r_g.add_argument('--skip', type=int, default=0,
92+
help="Skip the first N random numbers. \n"
93+
"Useful if you want to create new streamlines to "
94+
"add to \na previously created tractogram with a "
95+
"fixed --rng_seed.\nEx: If tractogram_1 was created "
96+
"with -nt 1,000,000, \nyou can create tractogram_2 "
97+
"with \n--skip 1,000,000.")
98+
99+
# Memory options:
100+
m_g = add_memory_args(p, add_lazy_options=True,
101+
add_multiprocessing_option=True,
102+
add_rng=True)
103+
m_g.add_argument('--simultaneous_tracking', type=int, default=1,
104+
help='Track n streamlines at the same time. Intended for '
105+
'GPU usage. Default = 1 (no simultaneous tracking).')
106+
107+
return track_g
108+
109+
110+
def prepare_seed_generator(parser, args, hdf_handle):
111+
"""
112+
Prepares a SeedGenerator from scilpy's library. Returns also some header
113+
information to allow verifications.
114+
"""
115+
seeding_group = hdf_handle[args.subj_id][args.seeding_mask_group]
116+
seed_data = np.array(seeding_group['data'], dtype=np.float32)
117+
seed_res = np.array(seeding_group.attrs['voxres'], dtype=np.float32)
118+
affine = np.array(seeding_group.attrs['affine'], dtype=np.float32)
119+
ref = nib.Nifti1Image(seed_data, affine)
120+
121+
seed_generator = SeedGenerator(seed_data, seed_res, space=ALWAYS_VOX_SPACE,
122+
origin=ALWAYS_CORNER)
123+
124+
if len(seed_generator.seeds_vox) == 0:
125+
parser.error('Seed mask "{}" does not have any voxel with value > 0.'
126+
.format(args.in_seed))
127+
128+
if args.npv:
129+
# Note. Not really nb seed per voxel, just in average.
130+
nbr_seeds = len(seed_generator.seeds_vox) * args.npv
131+
elif args.nt:
132+
nbr_seeds = args.nt
133+
else:
134+
# Setting npv = 1.
135+
nbr_seeds = len(seed_generator.seeds_vox)
136+
137+
seed_header = nib.Nifti1Image(seed_data, affine).header
138+
139+
return seed_generator, nbr_seeds, seed_header, ref
140+
141+
142+
def prepare_tracking_mask(hdf_handle, tracking_mask_group, subj_id, mask_interp):
143+
"""
144+
Prepare the tracking mask as a DataVolume from scilpy's library. Returns
145+
also some header information to allow verifications.
146+
"""
147+
if subj_id not in hdf_handle:
148+
raise KeyError("Subject {} not found in {}. Possible subjects are: {}"
149+
.format(subj_id, hdf_handle, list(hdf_handle.keys())))
150+
if tracking_mask_group not in hdf_handle[subj_id]:
151+
raise KeyError("HDF group '{}' not found for subject {} in hdf file {}"
152+
.format(tracking_mask_group, subj_id, hdf_handle))
153+
tm_group = hdf_handle[subj_id][tracking_mask_group]
154+
mask_data = np.array(tm_group['data'], dtype=np.float64).squeeze()
155+
# mask_res = np.array(tm_group.attrs['voxres'], dtype=np.float32)
156+
affine = np.array(tm_group.attrs['affine'], dtype=np.float32)
157+
ref = nib.Nifti1Image(mask_data, affine)
158+
159+
mask = TrackingMask(mask_data.shape, mask_data, mask_interp)
160+
161+
return mask, ref
162+
163+
164+
def track_and_save(tracker: DWIMLAbstractTracker, args, ref):
165+
if args.save_seeds:
166+
name, ext = os.path.splitext(args.out_tractogram)
167+
if ext != '.trk':
168+
raise ValueError("Cannot save seeds! (data per streamline not "
169+
"saved with extension {}). Please change out "
170+
"filename to .trk".format(ext))
171+
172+
with Timer("\nTracking...", newline=True, color='blue'):
173+
streamlines, seeds = tracker.track()
174+
175+
logging.debug("Tracked {} streamlines (out of {} seeds). Now saving..."
176+
.format(len(streamlines), tracker.nbr_seeds))
177+
178+
if len(streamlines) == 0:
179+
logging.warning("No streamlines created! Not saving tractogram!")
180+
return
181+
182+
# save seeds if args.save_seeds is given
183+
# Seeds must be saved in voxel space (ok!), but origin: center, if we want
184+
# to use scripts such as scil_compute_seed_density_map.
185+
if args.save_seeds:
186+
print("Saving seeds in data_per_streamline.")
187+
seeds = [np.asarray(seed) - 0.5 for seed in seeds] # to_center
188+
data_per_streamline = {'seeds': seeds}
189+
else:
190+
data_per_streamline = {}
191+
192+
# Silencing SFT's logger if our logging is in DEBUG mode, because it
193+
# typically produces a lot of outputs!
194+
set_sft_logger_level('WARNING')
195+
196+
logging.info("Saving resulting tractogram to {}"
197+
.format(args.out_tractogram))
198+
sft = StatefulTractogram(streamlines, ref, space=ALWAYS_VOX_SPACE,
199+
origin=ALWAYS_CORNER,
200+
data_per_streamline=data_per_streamline)
201+
save_tractogram(sft, args.out_tractogram, bbox_valid_check=False)

dwi_ml/tracking/propagation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _verify_stopping_criteria(n_last_pos, lines, mask=None, max_nbr_pts=None):
199199
# continue.
200200
still_on = ~stopping
201201

202-
out_of_mask = ~mask.is_in_mask(n_last_pos[still_on]).cpu().numpy()
202+
out_of_mask = ~mask.is_vox_corner_in_mask(n_last_pos[still_on]).cpu().numpy()
203203
if sum(out_of_mask) > 0:
204204
logger.debug("{} streamlines stopping out of mask."
205205
.format(sum(out_of_mask)))

dwi_ml/tracking/tracker.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import traceback
88
from typing import List
99

10-
from dipy.io.stateful_tractogram import Space, Origin
1110
from dipy.tracking.streamlinespeed import compress_streamlines
1211
import numpy as np
1312
import torch
13+
from dwi_ml.tracking.utils import prepare_step_size_vox
1414
from torch import Tensor
1515
from tqdm.contrib.logging import tqdm_logging_redirect
1616

@@ -23,7 +23,6 @@
2323
MainModelOneInput
2424
from dwi_ml.tracking.propagation import propagate_multiple_lines
2525
from dwi_ml.tracking.tracking_mask import TrackingMask
26-
from dwi_ml.tracking.utils import prepare_step_size_vox
2726

2827
logger = logging.getLogger('tracker_logger')
2928

@@ -180,11 +179,6 @@ def __init__(self, dataset: MultisubjectSubset, subj_idx: int,
180179
self.model.eval()
181180
self.grad_context = torch.no_grad()
182181

183-
# Space and origin
184-
# torch trilinear interpolation uses origin='corner', space=vox.
185-
self.origin = Origin('corner')
186-
self.space = Space.VOX
187-
188182
# Nb points
189183
if self.min_nbr_pts <= 0:
190184
logger.warning("Minimum number of points cannot be 0. Changed to "
@@ -374,15 +368,13 @@ def _cpu_tracking(self, chunk_id):
374368
streamline = np.array(line, dtype='float32')
375369

376370
if self.compression_th and self.compression_th > 0:
377-
# Compressing. Threshold is in mm. Verifying space.
378-
if self.space == Space.VOX:
379-
# Equivalent of sft.to_voxmm:
380-
streamline *= self.seed_generator.voxres
381-
compress_streamlines(streamline, self.compression_th)
382-
# Equivalent of sft.to_vox:
383-
streamline /= self.seed_generator.voxres
384-
else:
385-
compress_streamlines(streamline, self.compression_th)
371+
# Compressing. Threshold is in mm. Considering that we work
372+
# in vox space, changing:
373+
# Equivalent of sft.to_voxmm:
374+
streamline *= self.seed_generator.voxres
375+
compress_streamlines(streamline, self.compression_th)
376+
# Equivalent of sft.to_vox:
377+
streamline /= self.seed_generator.voxres
386378

387379
streamlines.append(streamline)
388380

@@ -438,14 +430,14 @@ def _get_multiple_lines_both_directions(self, seeds: List[np.ndarray]):
438430

439431
logger.debug("Starting forward")
440432
self.prepare_forward(seeds)
441-
lines = self.propagate_multiple_lines(lines)
433+
lines = self._propagate_multiple_lines(lines)
442434

443435
if not self.track_forward_only:
444436
logger.debug("Starting backward")
445437
lines, rej_idx = self.prepare_backward(lines)
446438
if rej_idx is not None and len(rej_idx) > 0:
447439
seeds = [s for i, s in enumerate(seeds) if i not in rej_idx]
448-
lines = self.propagate_multiple_lines(lines)
440+
lines = self._propagate_multiple_lines(lines)
449441

450442
# Clean streamlines
451443
# Max is already checked as stopping criteria.
@@ -456,7 +448,7 @@ def _get_multiple_lines_both_directions(self, seeds: List[np.ndarray]):
456448

457449
return clean_lines, clean_seeds
458450

459-
def propagate_multiple_lines(self, lines: List[Tensor]):
451+
def _propagate_multiple_lines(self, lines: List[Tensor]):
460452
return propagate_multiple_lines(
461453
lines, self.update_memory_after_removing_lines,
462454
self.get_next_dirs, self.theta, self.step_size,

dwi_ml/tracking/tracking_mask.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_value_at_vox_corner_coordinate(self, xyz, interpolation):
4848
else:
4949
return torch_trilinear_interpolation(self.data, xyz)
5050

51-
def is_in_mask(self, xyz):
51+
def is_vox_corner_in_mask(self, xyz):
5252
# Clipping to bound.
5353
xyz = torch.maximum(xyz, self.lower_bound)
5454
xyz = torch.minimum(xyz, self.higher_bound - eps)

0 commit comments

Comments
 (0)