
Commit 1a5e044

Merge pull request #33 from francois-drielsma/develop
Various bug fixes and functionality extensions
2 parents: c8e4c09 + fcc7deb; commit: 1a5e044

29 files changed: +707 −176 lines

bin/run.py

Lines changed: 8 additions & 4 deletions
@@ -25,7 +25,8 @@
 from spine.main import run
 
 
-def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir, weight_prefix, weight_path):
+def main(config, source, source_list, output, n, nskip, detect_anomaly,
+         log_dir, weight_prefix, weight_path):
     """Main driver for training/validation/inference/analysis.
 
     Performs these basic functions:
@@ -53,7 +54,8 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir,
     weight_prefix : str
         Path to the directory for storing the training weights
     weight_path : str
-        Path string to a weight file or pattern for multiple weight files to load the model weights
+        Path string to a weight file or pattern for multiple weight files
+        to load the model weights
     """
     # Try to find the configuration file using the absolute path or under
     # the 'config' directory of the parent SPINE repository
@@ -112,7 +114,8 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir,
 
     if weight_prefix is not None:
         if not 'train' in cfg['base']:
-            raise KeyError('--weight_prefix flag provided: must specify `train` in the `base` block.')
+            raise KeyError("--weight_prefix flag provided: must specify "
+                           "`train` in the `base` block.")
         cfg['base']['train']['weight_prefix'] = weight_prefix
 
     if weight_path is not None:
@@ -182,4 +185,5 @@ def main(config, source, source_list, output, n, nskip, detect_anomaly, log_dir,
 
 # Execute the main function
 main(args.config, args.source, args.source_list, args.output, args.n,
-     args.nskip, args.detect_anomaly, args.log_dir, args.weight_prefix, args.weight_path)
+     args.nskip, args.detect_anomaly, args.log_dir, args.weight_prefix,
+     args.weight_path)
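
The flag-handling hunks above all follow one pattern: a CLI flag, when provided, overrides the parsed YAML configuration by mutating the config dictionary before it is handed to `run`. A minimal standalone sketch of that pattern, with a made-up `cfg` layout rather than SPINE's actual schema:

    # Sketch of the CLI-over-config override pattern used in bin/run.py.
    # The `cfg` contents and flag value below are illustrative only.
    cfg = {'base': {'train': {}}}
    weight_prefix = 'weights/snapshot'  # hypothetical value of --weight_prefix

    if weight_prefix is not None:
        if 'train' not in cfg['base']:
            raise KeyError("--weight_prefix flag provided: must specify "
                           "`train` in the `base` block.")
        cfg['base']['train']['weight_prefix'] = weight_prefix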

spine/ana/base.py

Lines changed: 77 additions & 2 deletions
@@ -44,8 +44,15 @@ class AnaBase(ABC):
     # Valid run modes
     _run_modes = ('reco', 'truth', 'both', 'all')
 
-    def __init__(self, obj_type=None, run_mode=None, append=False,
-                 overwrite=False, log_dir=None, prefix=None):
+    # List of known point modes for true particles and their corresponding keys
+    _point_modes = (
+        ('points', 'points_label'),
+        ('points_adapt', 'points'),
+        ('points_g4', 'points_g4')
+    )
+
+    def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
+                 append=False, overwrite=False, log_dir=None, prefix=None):
         """Initialize default analysis script object properties.
 
         Parameters
@@ -56,6 +63,10 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
             If specified, tells whether the analysis script must run on
             reconstructed ('reco'), true ('true') or both objects
             ('both' or 'all')
+        truth_point_mode : str, optional
+            If specified, tells which attribute of the :class:`TruthFragment`,
+            :class:`TruthParticle` or :class:`TruthInteraction` object to use
+            to fetch its point coordinates
         append : bool, default False
             If True, appends existing CSV files instead of creating new ones
         overwrite : bool, default False
@@ -114,6 +125,14 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
         # Update underlying keys, if needed
         self.update_keys({k: True for k in self.obj_keys})
 
+        # If a truth point mode is specified, store it
+        if truth_point_mode is not None:
+            assert truth_point_mode in self.point_modes, (
+                "The `truth_point_mode` argument must be one of "
+                f"{self.point_modes.keys()}. Got `{truth_point_mode}` instead.")
+            self.truth_point_mode = truth_point_mode
+            self.truth_index_mode = truth_point_mode.replace('points', 'index')
+
         # Store the append flag
         self.append_file = append
         self.overwrite_file = overwrite
@@ -167,6 +186,18 @@ def keys(self, keys):
         """
         self._keys = tuple(keys.items())
 
+    @property
+    def point_modes(self):
+        """Dictionary which makes the correspondence between the name of a
+        true object point attribute and the underlying point tensor it
+        points to.
+
+        Returns
+        -------
+        Dict[str, str]
+            Dictionary of (attribute, key) mappings for point coordinates
+        """
+        return dict(self._point_modes)
+
     def update_keys(self, update_dict):
         """Update the underlying set of keys and their necessity in place.
 
@@ -249,6 +280,50 @@ def __call__(self, data, entry=None):
         # Run the analysis script
         return self.process(data_filter)
 
+    def get_index(self, obj):
+        """Get a certain pre-defined index attribute of an object.
+
+        The :class:`TruthFragment`, :class:`TruthParticle` and
+        :class:`TruthInteraction` object indexes are obtained using the
+        `truth_index_mode` attribute of the class.
+
+        Parameters
+        ----------
+        obj : Union[FragmentBase, ParticleBase, InteractionBase]
+            Fragment, Particle or Interaction object
+
+        Returns
+        -------
+        np.ndarray
+            (N) Object index
+        """
+        if not obj.is_truth:
+            return obj.index
+        else:
+            return getattr(obj, self.truth_index_mode)
+
+    def get_points(self, obj):
+        """Get a certain pre-defined point attribute of an object.
+
+        The :class:`TruthFragment`, :class:`TruthParticle` and
+        :class:`TruthInteraction` object points are obtained using the
+        `truth_point_mode` attribute of the class.
+
+        Parameters
+        ----------
+        obj : Union[FragmentBase, ParticleBase, InteractionBase]
+            Fragment, Particle or Interaction object
+
+        Returns
+        -------
+        np.ndarray
+            (N, 3) Point coordinates
+        """
+        if not obj.is_truth:
+            return obj.points
+        else:
+            return getattr(obj, self.truth_point_mode)
+
     @abstractmethod
     def process(self, data):
         """Place-holder method to be defined in each analysis script.
spine/ana/diag/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,8 +3,10 @@
 This submodule is used to run basic diagnostics analyses such as:
 - Track dE/dx profile
 - Track energy reconstruction
+- Track completeness
 - Shower start dE/dx
 - ...
 '''
 
 from .shower import *
+from .track import *

spine/ana/diag/track.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
+"""Module to evaluate diagnostic metrics on tracks."""
+
+import numpy as np
+from scipy.spatial.distance import cdist
+
+from spine.ana.base import AnaBase
+
+from spine.utils.globals import TRACK_SHP
+from spine.utils.numba_local import principal_components
+
+
+__all__ = ['TrackCompletenessAna']
+
+
+class TrackCompletenessAna(AnaBase):
+    """This analysis script identifies gaps in tracks and measures the
+    cumulative length of these gaps relative to the track length.
+
+    This is a useful diagnostic tool to evaluate the space-point efficiency
+    on tracks (a good standard candle, as a track should have no gaps at
+    all in a perfectly efficient detector).
+    """
+
+    # Name of the analysis script (as specified in the configuration)
+    name = 'track_completeness'
+
+    def __init__(self, time_window=None, run_mode='both',
+                 truth_point_mode='points', **kwargs):
+        """Initialize the analysis script.
+
+        Parameters
+        ----------
+        time_window : List[float]
+            Time window within which to include particles (only works for
+            `truth`)
+        **kwargs : dict, optional
+            Additional arguments to pass to :class:`AnaBase`
+        """
+        # Initialize the parent class
+        super().__init__('particle', run_mode, truth_point_mode, **kwargs)
+
+        # Store the time window
+        self.time_window = time_window
+        assert time_window is None or len(time_window) == 2, (
+            "Time window must be specified as an array of two values.")
+        assert time_window is None or run_mode == 'truth', (
+            "Time of reconstructed particle is unknown.")
+
+        # Make sure the metadata is provided (rasterization needed)
+        self.update_keys({'meta': True})
+
+        # Initialize the CSV writer(s) you want
+        for prefix in self.prefixes:
+            self.initialize_writer(prefix)
+
+    def process(self, data):
+        """Evaluate track completeness for tracks in one entry.
+
+        Parameters
+        ----------
+        data : dict
+            Dictionary of data products
+        """
+        # Fetch the pixel size in this image (assume cubic cells)
+        pixel_size = data['meta'].size[0]
+
+        # Loop over the types of particle data products
+        for key in self.obj_keys:
+            # Fetch the prefix ('reco' or 'truth')
+            prefix = key.split('_')[0]
+
+            # Loop over particle objects
+            for part in data[key]:
+                # Check that the particle is a track
+                if part.shape != TRACK_SHP:
+                    continue
+
+                # If needed, check on the particle time
+                if self.time_window is not None:
+                    if part.t < self.time_window[0] or part.t > self.time_window[1]:
+                        continue
+
+                # Initialize the particle dictionary
+                comp_dict = {'particle_id': part.id}
+
+                # Fetch the particle point coordinates
+                points = self.get_points(part)
+
+                # Find start/end points, collapse onto track cluster
+                start = points[np.argmin(cdist([part.start_point], points))]
+                end = points[np.argmin(cdist([part.end_point], points))]
+
+                # Add the direction of the track
+                vec = end - start
+                length = np.linalg.norm(vec)
+                if length:
+                    vec /= length
+
+                comp_dict['size'] = len(points)
+                comp_dict['length'] = length
+                comp_dict.update(
+                        {'dir_x': vec[0], 'dir_y': vec[1], 'dir_z': vec[2]})
+
+                # Chunk out the track along gaps, estimate gap length
+                chunk_labels = self.cluster_track_chunks(
+                        points, start, end, pixel_size)
+                gaps = self.sequential_cluster_distances(
+                        points, chunk_labels, start)
+
+                # Subtract minimum gap distance due to rasterization
+                min_gap = pixel_size/np.max(np.abs(vec))
+                gaps -= min_gap
+
+                # Store gap information
+                comp_dict['num_gaps'] = len(gaps)
+                comp_dict['gap_length'] = np.sum(gaps)
+                comp_dict['gap_frac'] = np.sum(gaps)/length
+
+                # Append the dictionary to the CSV log
+                self.append(prefix, **comp_dict)
+
+    @staticmethod
+    def cluster_track_chunks(points, start_point, end_point, pixel_size):
+        """Find points where the track is broken and divide the track into
+        self-contained chunks which are Linf-connected (Moore neighbors).
+
+        Parameters
+        ----------
+        points : np.ndarray
+            (N, 3) List of track cluster point coordinates
+        start_point : np.ndarray
+            (3) Start point of the track cluster
+        end_point : np.ndarray
+            (3) End point of the track cluster
+        pixel_size : float
+            Dimension of one pixel, used to identify what is big enough to
+            constitute a break
+
+        Returns
+        -------
+        np.ndarray
+            (N) Track chunk labels
+        """
+        # Project and cluster on the projected axis
+        direction = (end_point-start_point)/np.linalg.norm(end_point-start_point)
+        projs = np.dot(points - start_point, direction)
+        perm = np.argsort(projs)
+        seps = projs[perm][1:] - projs[perm][:-1]
+        breaks = np.where(seps > pixel_size*1.1)[0] + 1
+        cluster_labels = np.empty(len(projs), dtype=int)
+        for i, index in enumerate(np.split(np.arange(len(projs)), breaks)):
+            cluster_labels[perm[index]] = i
+
+        return cluster_labels
+
+    @staticmethod
+    def sequential_cluster_distances(points, labels, start_point):
+        """Order clusters by distance from a starting point and compute
+        the distances between successive clusters.
+
+        Parameters
+        ----------
+        points : np.ndarray
+            (N, 3) List of track cluster point coordinates
+        labels : np.ndarray
+            (N) Track chunk labels
+        start_point : np.ndarray
+            (3) Start point of the track cluster
+
+        Returns
+        -------
+        np.ndarray
+            (C - 1) Distances between successive chunks
+        """
+        # If there's only one cluster, nothing to do here
+        unique_labels = np.unique(labels)
+        if len(unique_labels) < 2:
+            return np.empty(0, dtype=float)
+
+        # Order clusters
+        start_dist = cdist([start_point], points).flatten()
+        start_clust_dist = np.empty(len(unique_labels))
+        for i, c in enumerate(unique_labels):
+            start_clust_dist[i] = np.min(start_dist[labels == c])
+        ordered_labels = unique_labels[np.argsort(start_clust_dist)]
+
+        # Compute the intercluster distances
+        n_gaps = len(ordered_labels) - 1
+        dists = np.empty(n_gaps, dtype=float)
+        for i in range(n_gaps):
+            points_i = points[labels == ordered_labels[i]]
+            points_j = points[labels == ordered_labels[i + 1]]
+            dists[i] = np.min(cdist(points_i, points_j))
+
+        return dists
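
The chunking step is easy to sanity-check on synthetic input. The toy below replays the projection/sort/split logic of `cluster_track_chunks` on a made-up straight track of ten points with a single four-pixel gap:

    import numpy as np

    pixel_size = 1.0
    # Axis-aligned track with points at x = 0..4 and 8..12: one clear gap
    xs = np.concatenate([np.arange(0, 5), np.arange(8, 13)]).astype(float)
    points = np.stack([xs, np.zeros_like(xs), np.zeros_like(xs)], axis=1)
    start_point, end_point = points[0], points[-1]

    # Same projection/sort/split logic as cluster_track_chunks
    direction = (end_point - start_point)/np.linalg.norm(end_point - start_point)
    projs = np.dot(points - start_point, direction)
    perm = np.argsort(projs)
    seps = projs[perm][1:] - projs[perm][:-1]
    breaks = np.where(seps > pixel_size*1.1)[0] + 1
    labels = np.empty(len(projs), dtype=int)
    for i, index in enumerate(np.split(np.arange(len(projs)), breaks)):
        labels[perm[index]] = i

    print(labels)  # [0 0 0 0 0 1 1 1 1 1]: two chunks on either side of the gap

On this input `sequential_cluster_distances` reports a single gap of 4 units; `process` then subtracts the rasterization minimum (`min_gap = pixel_size/np.max(np.abs(vec))`, which is 1.0 for an axis-aligned track), leaving an effective gap length of 3.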

spine/ana/factories.py

Lines changed: 2 additions & 2 deletions
@@ -2,11 +2,11 @@
 
 from spine.utils.factory import module_dict, instantiate
 
-from . import metric, script
+from . import diag, metric, script
 
 # Build a dictionary of available analysis modules
 ANA_DICT = {}
-for module in [metric, script]:
+for module in [diag, metric, script]:
     ANA_DICT.update(**module_dict(module))
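This two-line change is all that is needed to expose the new diagnostic because registration is name-driven: each analysis class advertises a `name` attribute ('track_completeness' above) that keys it in ANA_DICT. A rough sketch of what a `module_dict`-style registry could look like; the actual `spine.utils.factory` implementation may differ:

    # Hypothetical name -> class registry; not the actual spine.utils.factory code.
    def module_dict(module):
        """Map each exported class's `name` attribute to the class itself."""
        registry = {}
        for attr in getattr(module, '__all__', []):
            obj = getattr(module, attr)
            if hasattr(obj, 'name'):
                registry[obj.name] = obj
        return registry

    # With `diag` included, ANA_DICT['track_completeness'] resolves to
    # TrackCompletenessAna, which `instantiate` can then build from the config.
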
spine/ana/script/save.py

Lines changed: 2 additions & 2 deletions
@@ -139,9 +139,7 @@ def process(self, data):
             prefix, obj_type = key.split('_')
             other = other_prefix[prefix]
             attrs = self.attrs[key]
-            attrs_other = self.attrs[f'{other}_{obj_type}']
             lengths = self.lengths
-            lengths_other = self.lengths
             if (self.match_mode is None or
                 self.match_mode == f'{other}_to_{prefix}'):
                 # If there are no matches, save objects by themselves
@@ -153,6 +151,8 @@ def process(self, data):
                 # match on a single row
                 match_suffix = f'{prefix[0]}2{other[0]}'
                 match_key = f'{obj_type[:-1]}_matches_{match_suffix}'
+                attrs_other = self.attrs[f'{other}_{obj_type}']
+                lengths_other = self.lengths  # TODO
                 for idx, (obj_i, obj_j) in enumerate(data[match_key]):
                     src_dict = obj_i.scalar_dict(attrs, lengths)
                     if obj_j is not None: