diff --git a/spine/ana/metric/cluster.py b/spine/ana/metric/cluster.py
index d15fd06a..fba6ce15 100644
--- a/spine/ana/metric/cluster.py
+++ b/spine/ana/metric/cluster.py
@@ -35,7 +35,7 @@ class ClusterAna(AnaBase):
 
     def __init__(self, obj_type=None, use_objects=False, per_object=True,
                  per_shape=True, metrics=('pur', 'eff', 'ari'),
-                 label_key='clust_label_adapt', label_col=None, **kwargs):
+                 label_key='clust_label_adapt', label_col=None, time_window=None, **kwargs):
         """Initialize the analysis script.
 
         Parameters
@@ -58,6 +58,10 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
             using the raw reconstruction output
         label_col : str, optional
             Column name in the label tensor specifying the aggregation label_col
+        time_window : List[float], optional
+            Time window within which to include objects. If provided, must be a list
+            of two values [t_min, t_max]. Objects outside this window will be excluded.
+            For reconstructed objects, filtering is based on their matched truth objects' times.
         **kwargs : dict, optional
             Additional arguments to pass to :class:`AnaBase`
         """
@@ -71,10 +75,14 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
         assert per_object or not use_objects, (
                 "If evaluating clustering standalone (not per object), cannot "
                 "use objects to evaluate it.")
+        assert time_window is None or len(time_window) == 2, (
+                "Time window must be specified as a list of two values [t_min, t_max].")
 
         # Initialize the parent class
         super().__init__(obj_type, 'both', **kwargs)
 
+        # Store the time window
+        self.time_window = time_window
 
         # If the clustering is not done per object, fix target
         if not per_object:
@@ -109,6 +117,8 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
                 keys[f'{obj}_clusts'] = True
                 if obj != 'interaction':
                     keys[f'{obj}_shapes'] = True
+                # Need truth objects for time filtering
+                keys[f'truth_{obj}s'] = True
 
         else:
             keys['points'] = True
@@ -154,15 +164,36 @@ def process(self, data):
                 labels = data[self.label_key][:, label_col]
                 if obj_type != 'interaction':
                     shapes = data[self.label_key][:, SHAPE_COL]
                 num_truth = len(np.unique(labels[labels > -1]))
+
+                # Restrict the labels to truth objects inside the time window
+                if self.time_window is not None:
+                    t_min, t_max = self.time_window
+                    valid_labels = np.zeros_like(labels, dtype=bool)
+                    for obj in data[f'truth_{obj_type}s']:
+                        # Skip objects whose indexes do not fit the labels
+                        if (t_min <= obj.t <= t_max
+                                and (obj.index_adapt < len(labels)).all()):
+                            valid_labels[obj.index_adapt] = True
+
+                    # Mask a copy so that the input label tensor is not edited
+                    labels = np.where(valid_labels, labels, -1)
+                    num_truth = len(np.unique(labels[labels > -1]))
 
             else:
                 # Rebuild the labels
                 num_points = len(data['points'])
                 labels = -np.ones(num_points)
-                num_truth = len(data[f'truth_{obj_type}s'])
-                for i, obj in enumerate(data[f'truth_{obj_type}s']):
+                truth_objects = data[f'truth_{obj_type}s']
+                if self.time_window is not None:
+                    # Only label the truth objects inside the time window
+                    t_min, t_max = self.time_window
+                    truth_objects = [
+                        obj for obj in truth_objects if t_min <= obj.t <= t_max]
+
+                num_truth = len(truth_objects)
+                for i, obj in enumerate(truth_objects):
                     labels[obj.index_adapt] = i
 
             # Build the cluster predictions for this object type
             preds = -np.ones(num_points)
@@ -178,11 +209,24 @@ def process(self, data):
 
                 else:
                     # Use clusters from the object indexes
-                    num_reco = len(data[f'reco_{obj_type}s'])
-                    for i, obj in enumerate(data[f'reco_{obj_type}s']):
+                    reco_objects = data[f'reco_{obj_type}s']
+                    if self.time_window is not None:
+                        # Keep objects whose first matched truth object is
+                        # inside the window; unmatched objects have no time
+                        # to test and are dropped
+                        t_min, t_max = self.time_window
+                        truth_objects = data[f'truth_{obj_type}s']
+                        reco_objects = [
+                            obj for obj in reco_objects
+                            if len(obj.match_ids)
+                            and obj.match_ids[0] < len(truth_objects)
+                            and t_min <= truth_objects[obj.match_ids[0]].t <= t_max]
+
+                    num_reco = len(reco_objects)
+                    for i, obj in enumerate(reco_objects):
                         preds[obj.index] = i
                         if obj_type != 'interaction':
                             shapes[obj.index] = obj.shape
 
             else:
                 num_reco = len(data['clusts'])
diff --git a/spine/data/out/interaction.py b/spine/data/out/interaction.py
index d498e91a..e3f6c1dc 100644
--- a/spine/data/out/interaction.py
+++ b/spine/data/out/interaction.py
@@ -341,10 +341,12 @@ class TruthInteraction(Neutrino, InteractionBase, TruthBase):
     """
     nu_id: int = -1
     reco_vertex: np.ndarray = None
+    t: float = -np.inf
 
     # Fixed-length attributes
     _fixed_length_attrs = (
             ('reco_vertex', 3),
+            ('t', 1),
             *Neutrino._fixed_length_attrs,
             *InteractionBase._fixed_length_attrs
     )
@@ -384,6 +386,26 @@ def __str__(self):
         """
         return 'Truth' + super().__str__()
 
+    @property
+    def t(self):
+        """Time of the earliest valid particle in the interaction.
+
+        Returns
+        -------
+        float
+            Smallest `t` among the valid particles, `np.inf` if there is none
+        """
+        min_time = np.inf
+        for part in self.particles:
+            if part.t < min_time and part.is_valid:
+                min_time = part.t
+        return min_time
+
+    @t.setter
+    def t(self, time):
+        # The time is derived from the particles; assignments are ignored
+        pass
+
     def attach_neutrino(self, neutrino):
         """Attach neutrino generator information to this interaction.
 