Skip to content

Commit af1b80d

Browse files
Merge pull request #67 from francois-drielsma/develop
Sanitized indexing, add template-based chi2 attributes
2 parents 623c1da + 06aa2d2 commit af1b80d

17 files changed

+267
-110
lines changed

spine/build/fragment.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,23 @@ def _build_truth(self, label_tensor, points_label, depositions_label,
220220
ref_tensor[index_ref, PART_COL], return_counts=True)
221221
part_id = int(part_ids[np.argmax(counts)])
222222
if part_id > -1:
223+
# Load the MC particle information
223224
assert part_id < len(particles), (
224225
"Invalid particle ID found in fragment labels.")
225-
fragment = TruthFragment(**particles[part_id].as_dict())
226+
particle = particles[part_id]
227+
fragment = TruthFragment(**particle.as_dict())
228+
229+
# Override the indexes of the fragment but preserve them
230+
fragment.orig_id = part_id
231+
fragment.orig_group_id = particle.group_id
232+
fragment.orig_parent_id = particle.parent_id
233+
fragment.orig_children_id = particle.children_id
234+
226235
fragment.id = i
236+
fragment.group_id = i
237+
fragment.parent_id = i
238+
fragment.children_id = np.empty(
239+
0, dtype=fragment.orig_children_id.dtype)
227240

228241
# Fill long-form attributes
229242
if truth_only:
@@ -243,7 +256,7 @@ def _build_truth(self, label_tensor, points_label, depositions_label,
243256
index_g4 = np.where(
244257
label_g4_tensor[:, CLUST_COL] == frag_id)[0]
245258
fragment.index_g4 = index_g4
246-
fragment.points_g4 = poins_g4[index_g4]
259+
fragment.points_g4 = points_g4[index_g4]
247260
fragment.depositions_g4 = depositions_g4[index_g4]
248261

249262
else:

spine/build/manager.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class BuildManager:
3434
('label_adapt_tensor', ('clust_label_adapt',)),
3535
('label_g4_tensor', ('clust_label_g4',)),
3636
('depositions_q_label', ('charge_label',)),
37+
('graph_label', ('graph_label',)),
3738
('sources', ('sources_adapt', 'sources')),
3839
('sources_label', ('sources_label',)),
3940
('particles', ('particles',)),

spine/build/particle.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from scipy.special import softmax
66

77
from spine.data.out import RecoParticle, TruthParticle
8+
89
from spine.utils.globals import COORD_COLS, VALUE_COL, GROUP_COL, TRACK_SHP
10+
from spine.utils.gnn.network import filter_invalid_nodes
911

1012
from .base import BuilderBase
1113

@@ -39,7 +41,8 @@ class ParticleBuilder(BuilderBase):
3941

4042
# Necessary/optional data products to build a truth object
4143
_build_truth_keys = (
42-
('particles', False), ('truth_fragments', False),
44+
('particles', False), ('graph_label', False),
45+
('truth_fragments', False),
4346
*BuilderBase._build_truth_keys
4447
)
4548

@@ -165,7 +168,8 @@ def _build_truth(self, particles, label_tensor, points_label,
165168
depositions_label, depositions_q_label=None,
166169
label_adapt_tensor=None, points=None, depositions=None,
167170
label_g4_tensor=None, points_g4=None, depositions_g4=None,
168-
sources_label=None, sources=None, truth_fragments=None):
171+
sources_label=None, sources=None, graph_label=None,
172+
truth_fragments=None):
169173
"""Builds :class:`TruthParticle` objects from the full chain output.
170174
171175
Parameters
@@ -199,6 +203,8 @@ def _build_truth(self, particles, label_tensor, points_label,
199203
(N', 2) Tensor which contains the label module/tpc information
200204
sources : np.ndarray, optional
201205
(N, 2) Tensor which contains the module/tpc information
206+
graph_label : np.ndarray, optional
207+
(E, 2) Parentage relations in the set of particles
202208
truth_fragments : List[TruthFragment], optional
203209
(F) List of true fragments
204210
@@ -222,9 +228,17 @@ def _build_truth(self, particles, label_tensor, points_label,
222228
assert particle.id == group_id, (
223229
"The ordering of the true particles is wrong.")
224230

225-
# Override the index of the particle but preserve it
231+
# Override the indexes of the particle and its group, but preserve them
226232
particle.orig_id = group_id
233+
particle.orig_group_id = group_id
234+
particle.orig_parent_id = particle.parent_id
235+
particle.orig_children_id = particle.children_id
236+
227237
particle.id = i
238+
particle.group_id = i
239+
particle.parent_id = i
240+
particle.children_id = np.empty(
241+
0, dtype=particle.orig_children_id.dtype)
228242

229243
# Update the deposited energy attribute by summing that of all
230244
# particles in the group (LArCV definition != SPINE definition)
@@ -268,6 +282,23 @@ def _build_truth(self, particles, label_tensor, points_label,
268282
# Append
269283
truth_particles.append(particle)
270284

285+
# If the parentage relations of non-empty particles are available,
286+
# use them to assign parent/children IDs in the new particle set
287+
if graph_label is not None:
288+
# Narrow down the list of edges to those connecting visible particles
289+
inval = set(np.unique(graph_label)).difference(set(valid_group_ids))
290+
if len(inval) > 0:
291+
graph_label = filter_invalid_nodes(graph_label, tuple(inval))
292+
293+
# Use the remaining edges to build parentage relations
294+
mapping = {group_id: i for i, group_id in enumerate(valid_group_ids)}
295+
for (source, target) in graph_label:
296+
parent = truth_particles[mapping[source]]
297+
child = truth_particles[mapping[target]]
298+
299+
child.parent_id = parent.id
300+
parent.children_id = np.append(parent.children_id, child.id)
301+
271302
return truth_particles
272303

273304
def load_reco(self, data):

spine/data/out/fragment.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,13 @@ class TruthFragment(Particle, FragmentBase, TruthBase):
131131
Attributes
132132
----------
133133
orig_interaction_id : int
134-
Unaltered index of the interaction in the original MC paricle list
134+
Unaltered index of the interaction in the original MC particle list
135+
orig_parent_id : int
136+
Unaltered index of the particle parent in the original MC particle list
137+
orig_group_id : int
138+
Unaltered index of the particle group in the original MC particle list
139+
orig_children_id : np.ndarray
140+
Unaltered list of the particle children in the original MC particle list
135141
children_counts : np.ndarray
136142
(P) Number of truth child fragments of each shape
137143
reco_length : float
@@ -143,6 +149,9 @@ class TruthFragment(Particle, FragmentBase, TruthBase):
143149
to track objects)
144150
"""
145151
orig_interaction_id: int = -1
152+
orig_parent_id: int = -1
153+
orig_group_id: int = -1
154+
orig_children_id: np.ndarray = -1
146155
children_counts: np.ndarray = None
147156
reco_length: float = -1.
148157
reco_start_dir: np.ndarray = None
@@ -157,7 +166,7 @@ class TruthFragment(Particle, FragmentBase, TruthBase):
157166

158167
# Variable-length attributes
159168
_var_length_attrs = (
160-
('children_counts', np.int32),
169+
('orig_children_id', np.int64), ('children_counts', np.int32),
161170
*TruthBase._var_length_attrs,
162171
*Particle._var_length_attrs
163172
)

spine/data/out/particle.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,13 @@ class ParticleBase:
3535
Semantic type (shower (0), track (1), Michel (2), delta (3),
3636
low energy scatter (4)) of this particle
3737
pid : int
38-
Particle spcies (Photon (0), Electron (1), Muon (2), Charged Pion (3),
39-
Proton (4)) of this particle
38+
Particle species (Photon (0), Electron (1), Muon (2), Charged Pion (3),
39+
Proton (4), Kaon (5)) of this particle
40+
chi2_pid : int
41+
Particle species as predicted by the chi2 template method (Muon (2),
42+
Charged Pion (3), Proton (4), Kaon (5)) of this particle
43+
chi2_per_pid : np.ndarray
44+
(P) Array of chi2 values associated with each particle class
4045
pdg_code : int
4146
PDG code corresponding to the PID number
4247
is_primary : bool
@@ -61,11 +66,11 @@ class ParticleBase:
6166
csda_ke : float
6267
Kinetic energy reconstructed from the particle range in MeV
6368
csda_ke_per_pid : np.ndarray
64-
Same as `csda_ke` but for every available track PID hypothesis
69+
(P) Same as `csda_ke` but for every available track PID hypothesis
6570
mcs_ke : float
6671
Kinetic energy reconstructed using the MCS method in MeV
6772
mcs_ke_per_pid : np.ndarray
68-
Same as `mcs_ke` but for every available track PID hypothesis
73+
(P) Same as `mcs_ke` but for every available track PID hypothesis
6974
momentum : np.ndarray
7075
3-momentum of the particle at the production point in MeV/c
7176
p : float
@@ -80,6 +85,8 @@ class ParticleBase:
8085
interaction_id: int = -1
8186
shape: int = -1
8287
pid: int = -1
88+
chi2_pid: int = -1
89+
chi2_per_pid: np.ndarray = None
8390
pdg_code: int = -1
8491
is_primary: bool = False
8592
length: float = -1.
@@ -102,6 +109,7 @@ class ParticleBase:
102109
_fixed_length_attrs = (
103110
('start_point', 3), ('end_point', 3), ('start_dir', 3),
104111
('end_dir', 3), ('momentum', 3),
112+
('chi2_per_pid', len(PID_LABELS) - 1),
105113
('csda_ke_per_pid', len(PID_LABELS) - 1),
106114
('mcs_ke_per_pid', len(PID_LABELS) - 1)
107115
)
@@ -197,7 +205,7 @@ class RecoParticle(ParticleBase, RecoBase):
197205
Attributes
198206
----------
199207
pid_scores : np.ndarray
200-
(P) Array of softmax scores associated with each of particle class
208+
(P) Array of softmax scores associated with each particle class
201209
primary_scores : np.ndarray
202210
(2) Array of softmax scores associated with secondary and primary
203211
ppn_ids : np.ndarray
@@ -411,7 +419,13 @@ class TruthParticle(Particle, ParticleBase, TruthBase):
411419
Attributes
412420
----------
413421
orig_interaction_id : int
414-
Unaltered index of the interaction in the original MC paricle list
422+
Unaltered index of the interaction in the original MC particle list
423+
orig_parent_id : int
424+
Unaltered index of the particle parent in the original MC particle list
425+
orig_group_id : int
426+
Unaltered index of the particle group in the original MC particle list
427+
orig_children_id : np.ndarray
428+
Unaltered list of the particle children in the original MC particle list
415429
children_counts : np.ndarray
416430
(P) Number of truth child particles of each shape
417431
reco_length : float
@@ -427,6 +441,9 @@ class TruthParticle(Particle, ParticleBase, TruthBase):
427441
Best-guess reconstructed momentum of the particle
428442
"""
429443
orig_interaction_id: int = -1
444+
orig_parent_id: int = -1
445+
orig_group_id: int = -1
446+
orig_children_id: np.ndarray = -1
430447
children_counts: np.ndarray = None
431448
reco_length: float = -1.
432449
reco_start_dir: np.ndarray = None
@@ -443,7 +460,7 @@ class TruthParticle(Particle, ParticleBase, TruthBase):
443460

444461
# Variable-length attributes
445462
_var_length_attrs = (
446-
('children_counts', np.int32),
463+
('orig_children_id', np.int64), ('children_counts', np.int32),
447464
*TruthBase._var_length_attrs,
448465
*ParticleBase._var_length_attrs,
449466
*Particle._var_length_attrs

spine/data/particle.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ class Particle(PosDataBase):
2727
Index in the original MCTruth array from whence it came
2828
mcst_index : int
2929
Index in the original MCTrack/MCShower array from whence it came
30-
gen_id : int
31-
Index of the particle at the generator level
3230
group_id : int
3331
Index of the group the particle belongs to
3432
interaction_id : int
@@ -110,14 +108,13 @@ class Particle(PosDataBase):
110108
id: int = -1
111109
mct_index: int = -1
112110
mcst_index: int = -1
113-
gen_id: int = -1
114111
group_id: int = -1
115112
interaction_id: int = -1
116113
nu_id: int = -1
117114
interaction_primary: int = -1
118115
group_primary: int = -1
119116
parent_id: int = -1
120-
children_id: int = None
117+
children_id: np.ndarray = None
121118
track_id: int = -1
122119
parent_track_id: int = -1
123120
ancestor_track_id: int = -1
@@ -248,7 +245,7 @@ def from_larcv(cls, particle):
248245
for prefix in ('', 'parent_', 'ancestor_'):
249246
for key in ('track_id', 'pdg_code', 'creation_process', 't'):
250247
obj_dict[prefix+key] = getattr(particle, prefix+key)()
251-
for key in ('id', 'gen_id', 'group_id', 'interaction_id', 'parent_id',
248+
for key in ('id', 'group_id', 'interaction_id', 'parent_id',
252249
'mct_index', 'mcst_index', 'num_voxels', 'shape',
253250
'energy_init', 'energy_deposit', 'distance_travel'):
254251
if not hasattr(particle, key):

spine/io/collate.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,18 @@ def __call__(self, batch):
169169
has_batch_col=True, coord_cols=coord_cols)
170170

171171
elif isinstance(ref_obj, tuple) and len(ref_obj) == 2:
172-
# Case where an index and an offset is provided per entry.
172+
# Case where an index and a count are provided per entry.
173+
# Start by computing the necessary node ID offsets to apply
174+
total_counts = [sample[key][1] for sample in batch]
175+
offsets = np.zeros(len(total_counts), dtype=int)
176+
offsets[1:] = np.cumsum(total_counts)[:-1]
177+
173178
# Stack the indexes, do not add a batch column
174-
tensor = np.concatenate(
175-
[sample[key][0] for sample in batch], axis=1)
179+
tensor_list = []
180+
for i, sample in enumerate(batch):
181+
tensor_list.append(sample[key][0] + offsets[i])
182+
tensor = np.concatenate(tensor_list, axis=1)
176183
counts = [sample[key][0].shape[-1] for sample in batch]
177-
offsets = [sample[key][1] for sample in batch]
178184

179185
if len(tensor.shape) == 1:
180186
data[key] = IndexBatch(tensor, counts, offsets)

spine/io/parse/cluster.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Contains the following parsers:
44
- :class:`Cluster2DParser`
55
- :class:`Cluster3DParser`
6+
- :class:`Cluster3DAggregateParser`
7+
- :class:`Cluster3DChargeRescaledParser`
68
"""
79

810
from warnings import warn

0 commit comments

Comments
 (0)