diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py
index 7da6519..b1c476e 100644
--- a/binned_cdf/binned_logit_cdf.py
+++ b/binned_cdf/binned_logit_cdf.py
@@ -223,35 +223,41 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
-
-        # Prepend singleton dimensions for sample_shape to bin_edges, bin_widths, and probs.
-        # For all of them, the resulting shape will be: (*sample_shape, *batch_shape, num_bins)
-        bin_edges_left = self.bin_edges[..., :-1]  # shape: (*batch_shape, num_bins)
-        bin_edges_right = self.bin_edges[..., 1:]  # shape: (*batch_shape, num_bins)
-        bin_edges_left = bin_edges_left.view((1,) * num_sample_dims + bin_edges_left.shape)
-        bin_edges_right = bin_edges_right.view((1,) * num_sample_dims + bin_edges_right.shape)
-        bin_widths = self.bin_widths.view((1,) * num_sample_dims + self.bin_widths.shape)
-        probs = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
-
-        # Add bin dimension to value for broadcasting.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
-
-        # Check which bin each value falls into. Result shape: (*sample_shape, *batch_shape, num_bins).
-        in_bin = ((value_expanded >= bin_edges_left) & (value_expanded < bin_edges_right)).to(self.logits.dtype)
-
-        # Handle right edge case (include bound_up in last bin).
-        at_right_edge = torch.isclose(
-            value_expanded, torch.tensor(self.bound_up, dtype=self.logits.dtype, device=self.logits.device)
-        )
-        in_bin[..., -1] = torch.max(in_bin[..., -1], at_right_edge[..., -1])
-
-        # PDF = (probability mass / bin width) for the containing bin.
-        pdf_per_bin = probs / bin_widths  # shape: (*sample_shape, *batch_shape, num_bins)
-
-        # Sum over bins is the same as selecting the bin, as there is only one bin active per value.
-        return torch.sum(in_bin * pdf_per_bin, dim=-1)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
+
+        # Use binary search to find which bin each value belongs to. torch.searchsorted returns the index at
+        # which value would be inserted to keep the sequence sorted. Since bins are half-open, [edge[i], edge[i+1]),
+        # we search with right=True (a value equal to edge[i] lands in bin i) and subtract 1 to get the bin index.
+        bin_indices = torch.searchsorted(self.bin_edges, value, right=True) - 1  # shape: (*sample_shape, *batch_shape)
+
+        # Clamp to the valid range [0, num_bins - 1] to handle edge cases:
+        # - values below bound_low would give bin_idx = -1
+        # - values at bound_up would give bin_idx = num_bins
+        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
+
+        # Gather the bin widths and probabilities for the selected bins.
+        # For bin_widths of shape (num_bins,) we can index directly.
+        bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)
+
+        # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
+        # Add sample dimensions to bin_probs and expand to match the shape of bin_indices.
+        num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
+        bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        bin_probs_for_gather = bin_probs_for_gather.expand(
+            *bin_indices.shape, -1
+        )  # shape: (*sample_shape, *batch_shape, num_bins)
+
+        # Gather the selected bin probabilities.
+        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
+        bin_probs_selected = bin_probs_selected.squeeze(-1)
+
+        # Compute PDF = probability mass / bin width; values outside [bound_low, bound_up] have zero density.
+        pdf = bin_probs_selected / bin_widths_selected
+        in_support = (value >= self.bound_low) & (value <= self.bound_up)
+        return torch.where(in_support, pdf, torch.zeros_like(pdf))
 
     def cdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute cumulative distribution function at given values.
@@ -269,25 +275,34 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
 
-        # Prepend singleton dimensions for sample_shape to bin_centers.
-        # bin_centers: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        bin_centers_expanded = self.bin_centers.view((1,) * num_sample_dims + self.bin_centers.shape)
+        # Use binary search to find how many bin centers are <= value.
+        # torch.searchsorted with right=True gives the number of elements <= value.
+        num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)
 
-        # Prepend singleton dimensions for sample_shape to probs.
-        # probs: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        # Clamp to the valid range [0, num_bins].
+        num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)
 
-        # Add the bin dimension to the input which is used for comparing with the bin centers.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        # Compute the cumulative sum of bin probabilities.
+        # Prepend a 0 for the case where no bins are active.
+        num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
+        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
+        cumsum_probs = torch.cat(
+            [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
+            dim=-1,
+        )  # shape: (*batch_shape, num_bins + 1)
 
-        # Mask for bins with centers <= value.
-        mask = bin_centers_expanded <= value_expanded  # shape: (*sample_shape, *batch_shape, num_bins)
+        # Expand cumsum_probs to match the sample dimensions, then gather the CDF value for each input.
+        cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
+        cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
+        num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
+        cdf_values = cdf_values.squeeze(-1)
 
-        # Sum the bins for this value by their weighted "activation"=probability.
-        return torch.sum(mask * probs_expanded, dim=-1)
+        return cdf_values
 
     def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.
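For reference, here is a minimal standalone sketch of the `searchsorted`-plus-`gather` lookup that the new `prob` performs. It assumes 1-D bin edges and widths shared across the batch and a batched `bin_probs`, as the diff's comments describe; all tensors are made-up illustration data, not values from the repo.

```python
import torch

# Illustrative setup: 4 unit-width bins shared across a batch of 2 distributions.
bin_edges = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])  # shape: (num_bins + 1,)
bin_widths = bin_edges[1:] - bin_edges[:-1]           # shape: (num_bins,)
bin_probs = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                          [0.4, 0.3, 0.2, 0.1]])      # shape: (batch, num_bins)
num_bins = bin_widths.numel()

value = torch.tensor([[0.0, 3.9],
                      [1.0, 4.0]])                    # shape: (sample, batch)

# right=True puts a value equal to edge[i] into bin i, matching [edge[i], edge[i+1]).
bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1
# Clamp so the upper bound falls into the last bin instead of index num_bins.
bin_indices = torch.clamp(bin_indices, 0, num_bins - 1)

# Expand bin_probs over the sample dimension, then pick each containing bin's mass.
probs = bin_probs.expand(*bin_indices.shape, -1)      # shape: (sample, batch, num_bins)
mass = torch.gather(probs, dim=-1, index=bin_indices.unsqueeze(-1)).squeeze(-1)

print(bin_indices)                     # -> [[0, 3], [1, 3]]
print(mass / bin_widths[bin_indices])  # -> [[0.1, 0.1], [0.2, 0.1]]
```

With the default `right=False`, the value 1.0 above would land in bin 0 rather than bin 1, which is why the half-open bin convention stated in the patch's comments needs `right=True`.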
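Similarly, a sketch of the cumsum-plus-gather path in the new `cdf`: a zero is prepended to the cumulative bin masses so an input below every bin center gathers a CDF of 0. Batch dimensions are omitted here for brevity, which is why plain indexing stands in for `torch.gather`; again, the numbers are illustrative only.

```python
import torch

bin_centers = torch.tensor([0.5, 1.5, 2.5, 3.5])  # shape: (num_bins,)
bin_probs = torch.tensor([0.1, 0.2, 0.3, 0.4])    # shape: (num_bins,)

value = torch.tensor([0.0, 0.5, 2.0, 5.0])

# Number of bin centers <= value; right=True counts exact ties as active.
num_bins_active = torch.searchsorted(bin_centers, value, right=True)

# Cumulative masses with a leading zero: entry k is the total mass of the first k bins,
# so index 0 covers the "no bins active" case and index num_bins yields 1.
cumsum_probs = torch.cat([torch.zeros(1), torch.cumsum(bin_probs, dim=-1)])

print(num_bins_active)                # -> [0, 1, 2, 4]
print(cumsum_probs[num_bins_active])  # -> [0.0, 0.1, 0.3, 1.0]
```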