Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions spine/ana/metric/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ClusterAna(AnaBase):

def __init__(self, obj_type=None, use_objects=False, per_object=True,
per_shape=True, metrics=('pur', 'eff', 'ari'),
label_key='clust_label_adapt', label_col=None, **kwargs):
label_key='clust_label_adapt', label_col=None, time_window=None, **kwargs):
"""Initialize the analysis script.

Parameters
Expand All @@ -58,6 +58,10 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
using the raw reconstruction output
label_col : str, optional
Column name in the label tensor specifying the aggregation label_col
time_window : List[float], optional
Time window within which to include objects. If provided, must be a list
of two values [t_min, t_max]. Objects outside this window will be excluded.
For reconstructed objects, filtering is based on their matched truth objects' times.
**kwargs : dict, optional
Additional arguments to pass to :class:`AnaBase`
"""
Expand All @@ -71,10 +75,14 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
assert per_object or not use_objects, (
"If evaluating clustering standalone (not per object), cannot "
"use objects to evaluate it.")
assert time_window is None or len(time_window) == 2, (
"Time window must be specified as a list of two values [t_min, t_max].")

# Initialize the parent class
super().__init__(obj_type, 'both', **kwargs)

# Store the time window
self.time_window = time_window

# If the clustering is not done per object, fix target
if not per_object:
Expand Down Expand Up @@ -109,6 +117,8 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
keys[f'{obj}_clusts'] = True
if obj != 'interaction':
keys[f'{obj}_shapes'] = True
# Need truth objects for time filtering
keys[f'truth_{obj}s'] = True

else:
keys['points'] = True
Expand Down Expand Up @@ -154,15 +164,58 @@ def process(self, data):
labels = data[self.label_key][:, label_col]
if obj_type != 'interaction':
shapes = data[self.label_key][:, SHAPE_COL]
num_truth = len(np.unique(labels[labels > -1]))

# Get truth objects for time filtering
truth_objects = data[f'truth_{obj_type}s']
truth_times = np.array([obj.t for obj in truth_objects])

# Create a mapping from truth object index to label
truth_to_label = {}
for i, obj in enumerate(truth_objects):
if (obj.index_adapt < len(labels)).all():
truth_to_label[i] = labels[obj.index_adapt]

# Filter truth objects by time window if specified
if self.time_window is not None:
valid_truth_mask = (truth_times >= self.time_window[0]) & (truth_times <= self.time_window[1])
valid_truth_indices = np.where(valid_truth_mask)[0]

# Create a mask for valid labels
valid_labels = np.zeros_like(labels, dtype=bool)
for idx in valid_truth_indices:
if idx in truth_to_label:
valid_labels[truth_objects[idx].index_adapt] = True

num_truth = len(np.unique(labels[labels > -1]))
# Set invalid labels to -1
labels[~valid_labels] = -1
else:
valid_truth_indices = np.arange(len(truth_objects))
num_truth = len(np.unique(labels[labels > -1]))


else:
# Rebuild the labels
num_points = len(data['points'])
labels = -np.ones(num_points)
num_truth = len(data[f'truth_{obj_type}s'])
for i, obj in enumerate(data[f'truth_{obj_type}s']):

# First pass: collect truth objects and their times
truth_objects = data[f'truth_{obj_type}s']
truth_times = np.array([obj.t for obj in truth_objects])

# Filter truth objects by time window if specified
if self.time_window is not None:
valid_truth_mask = (truth_times >= self.time_window[0]) & (truth_times <= self.time_window[1])
valid_truth_indices = np.where(valid_truth_mask)[0]
else:
valid_truth_indices = np.arange(len(truth_objects))

# Build labels only for valid truth objects
for i, idx in enumerate(valid_truth_indices):
obj = truth_objects[idx]
labels[obj.index_adapt] = i

num_truth = len(truth_objects)

# Build the cluster predictions for this object type
preds = -np.ones(num_points)
Expand All @@ -178,11 +231,38 @@ def process(self, data):

else:
# Use clusters from the object indexes
num_reco = len(data[f'reco_{obj_type}s'])
for i, obj in enumerate(data[f'reco_{obj_type}s']):
preds[obj.index] = i
reco_objects = data[f'reco_{obj_type}s']
truth_objects = data[f'truth_{obj_type}s']

# Filter reconstructed objects based on their matched truth objects' times
valid_reco_indices = []
for i, reco_obj in enumerate(reco_objects):
# Skip if no matches
if not len(reco_obj.match_ids):
continue

# Get the matched truth object
truth_idx = reco_obj.match_ids[0]
if truth_idx >= len(truth_objects):
continue

truth_obj = truth_objects[truth_idx]

# Apply time window filter if specified
if self.time_window is not None:
if truth_obj.t < self.time_window[0] or truth_obj.t > self.time_window[1]:
continue

valid_reco_indices.append(i)

# Build predictions only for valid reconstructed objects
for i, reco_idx in enumerate(valid_reco_indices):
reco_obj = reco_objects[reco_idx]
preds[reco_obj.index] = i
if obj_type != 'interaction':
shapes[obj.index] = obj.shape
shapes[reco_obj.index] = reco_obj.shape

num_reco = len(valid_reco_indices)

else:
num_reco = len(data['clusts'])
Expand Down
22 changes: 21 additions & 1 deletion spine/data/out/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class InteractionBase:
flash_total_pe: float = -1.
flash_hypo_pe: float = -1.
topology: str = None

# Fixed-length attributes
_fixed_length_attrs = (
('vertex', 3), ('particle_counts', len(PID_LABELS) - 1),
Expand Down Expand Up @@ -341,10 +340,12 @@ class TruthInteraction(Neutrino, InteractionBase, TruthBase):
"""
nu_id: int = -1
reco_vertex: np.ndarray = None
t: float = -np.inf

# Fixed-length attributes
_fixed_length_attrs = (
('reco_vertex', 3),
('t', 1),
*Neutrino._fixed_length_attrs,
*InteractionBase._fixed_length_attrs
)
Expand Down Expand Up @@ -384,6 +385,25 @@ def __str__(self):
"""
return 'Truth' + super().__str__()

@property
def t(self):
"""Time of the interaction.

Returns
-------
float
Time of the interaction
"""
min_time = np.inf
for part in self.particles:
if part.t < min_time and part.is_valid:
min_time = part.t
return min_time

@t.setter
def t(self, time):
pass

def attach_neutrino(self, neutrino):
"""Attach neutrino generator information to this interaction.

Expand Down
Loading