From e08841825cd69cb0b239abeb452629312e2ebad5 Mon Sep 17 00:00:00 2001
From: Fabio Muratore
Date: Mon, 5 Jan 2026 16:16:02 +0100
Subject: [PATCH 1/4] WIP gather

---
 binned_cdf/binned_logit_cdf.py | 45 ++++++++++++++++++-----------------------
 1 file changed, 22 insertions(+), 23 deletions(-)

diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py
index 7da6519..a520a44 100644
--- a/binned_cdf/binned_logit_cdf.py
+++ b/binned_cdf/binned_logit_cdf.py
@@ -223,35 +223,34 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
+        # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
+        # index where value would be inserted to maintain sorted order.
+        # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
+        bin_indices = torch.searchsorted(self.bin_edges, value) - 1  # shape: (*sample_shape, *batch_shape)
 
-        # Prepend singleton dimensions for sample_shape to bin_edges, bin_widths, and probs.
-        # For all of them, the resulting shape will be: (*sample_shape, *batch_shape, num_bins)
-        bin_edges_left = self.bin_edges[..., :-1]  # shape: (*batch_shape, num_bins)
-        bin_edges_right = self.bin_edges[..., 1:]  # shape: (*batch_shape, num_bins)
-        bin_edges_left = bin_edges_left.view((1,) * num_sample_dims + bin_edges_left.shape)
-        bin_edges_right = bin_edges_right.view((1,) * num_sample_dims + bin_edges_right.shape)
-        bin_widths = self.bin_widths.view((1,) * num_sample_dims + self.bin_widths.shape)
-        probs = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        # Clamp to valid range [0, num_bins - 1] to handle edge cases:
+        # - values below bound_low would give bin_idx = -1
+        # - values at bound_up would give bin_idx = num_bins
+        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
 
-        # Add bin dimension to value for broadcasting.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        # Gather the bin widths and probabilities for the selected bins.
 
-        # Check which bin each value falls into. Result shape: (*sample_shape, *batch_shape, num_bins).
-        in_bin = ((value_expanded >= bin_edges_left) & (value_expanded < bin_edges_right)).to(self.logits.dtype)
+        # For bin_widths of shape (num_bins,) we can index directly.
+        bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)
 
-        # Handle right edge case (include bound_up in last bin).
-        at_right_edge = torch.isclose(
-            value_expanded, torch.tensor(self.bound_up, dtype=self.logits.dtype, device=self.logits.device)
-        )
-        in_bin[..., -1] = torch.max(in_bin[..., -1], at_right_edge[..., -1])
+        # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
+        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+
+        # Expand bin_probs to match sample dimensions.
+        num_sample_dims = len(value.shape) - len(self.batch_shape)
+        probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
 
-        # PDF = (probability mass / bin width) for the containing bin.
-        pdf_per_bin = probs / bin_widths  # shape: (*sample_shape, *batch_shape, num_bins)
+        # Gather and squeeze the extra dimension.
+        probs_selected = torch.gather(probs_expanded, -1, bin_indices_for_gather)
+        probs_selected = probs_selected.squeeze(-1)
 
-        # Sum over bins is the same as selecting the bin, as there is only one bin active per value.
-        return torch.sum(in_bin * pdf_per_bin, dim=-1)
+        # Compute PDF = probability mass / bin width.
+        return probs_selected / bin_widths_selected
 
     def cdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute cumulative distribution function at given values.
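For reference, a minimal standalone sketch of the searchsorted-based bin lookup that the patch above introduces (the edges and query values are toy data, not the class's real tensors):

    # Toy illustration of the bin lookup; edges/values are made up.
    import torch

    bin_edges = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])  # 4 bins
    values = torch.tensor([-0.1, 0.1, 0.3, 0.6, 1.0])

    # searchsorted returns the insertion index that keeps bin_edges sorted,
    # so subtracting 1 maps each value into the bin [edge[i], edge[i+1]).
    bin_indices = torch.searchsorted(bin_edges, values) - 1

    # Clamp folds out-of-range results back in: values below the lowest edge
    # would give -1, and values at the upper bound would give num_bins.
    bin_indices = torch.clamp(bin_indices, 0, bin_edges.numel() - 2)
    print(bin_indices)  # tensor([0, 0, 1, 2, 3])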
From 653d17b84e05ff5e198871658d9cc09d61000f6b Mon Sep 17 00:00:00 2001
From: Fabio Muratore
Date: Mon, 5 Jan 2026 16:20:09 +0100
Subject: [PATCH 2/4] Tests run with explicit shape construction

---
 binned_cdf/binned_logit_cdf.py | 40 +++++++++++++++++++++++++++++++++-------
 1 file changed, 31 insertions(+), 9 deletions(-)

diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py
index a520a44..b39a525 100644
--- a/binned_cdf/binned_logit_cdf.py
+++ b/binned_cdf/binned_logit_cdf.py
@@ -223,6 +223,12 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
+        # Broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            # Expand value to have the correct batch dimensions.
+            # For example: scalar value with batch_shape=(8,) → value.shape=(8,)
+            value = value.expand(self.batch_shape)
+
         # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
         # index where value would be inserted to maintain sorted order.
         # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index.
@@ -238,16 +244,32 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
         # For bin_widths of shape (num_bins,) we can index directly.
         bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)
 
-        # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
-        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        # For bin_probs of shape (*batch_shape, num_bins), we need to handle batch indexing.
+        # We'll use advanced indexing to select the appropriate bins.
+        # First, we need to create batch indices that match the shape of bin_indices.
 
-        # Expand bin_probs to match sample dimensions.
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
-        probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
-
-        # Gather and squeeze the extra dimension.
-        probs_selected = torch.gather(probs_expanded, -1, bin_indices_for_gather)
-        probs_selected = probs_selected.squeeze(-1)
+        if len(self.batch_shape) == 0:
+            # No batch dimensions - simple case.
+            probs_selected = self.bin_probs[bin_indices]
+        else:
+            # With batch dimensions, we need to construct proper indices.
+            # bin_probs has shape (*batch_shape, num_bins)
+            # bin_indices has shape (*sample_shape, *batch_shape)
+            # We need to index into the last dimension of bin_probs using bin_indices,
+            # while preserving the batch structure.
+
+            # Create batch indices: for each batch dimension, we need indices [0, 1, 2, ...]
+            # that broadcast correctly with the sample shape.
+            batch_indices = []
+            for i, batch_dim in enumerate(self.batch_shape):
+                # Create a shape that's 1 everywhere except at this batch dimension
+                shape = [1] * len(bin_indices.shape)
+                shape[len(bin_indices.shape) - len(self.batch_shape) + i] = batch_dim
+                batch_idx = torch.arange(batch_dim, device=bin_indices.device).view(shape)
+                batch_indices.append(batch_idx)
+
+            # Use advanced indexing: bin_probs[batch_idx_0, batch_idx_1, ..., bin_indices]
+            probs_selected = self.bin_probs[*batch_indices, bin_indices]
 
         # Compute PDF = probability mass / bin width.
         return probs_selected / bin_widths_selected
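The advanced-indexing scheme used above, as a toy sketch with a single batch dimension and made-up shapes (with one batch dimension no `[*batch_indices, ...]` unpacking is needed; the starred-subscript form in the patch itself requires Python 3.11+):

    # Select one entry per (sample, batch) position from a (batch, num_bins) table.
    import torch

    sample_shape, batch_shape, num_bins = (5,), (3,), 4
    bin_probs = torch.rand(*batch_shape, num_bins)                     # (3, 4)
    bin_indices = torch.randint(num_bins, sample_shape + batch_shape)  # (5, 3)

    # Index grid [0, 1, 2] shaped (1, 3) so it broadcasts against (5, 3).
    batch_idx = torch.arange(batch_shape[0]).view(1, batch_shape[0])

    # result[i, j] == bin_probs[j, bin_indices[i, j]] for every sample i, batch j.
    probs_selected = bin_probs[batch_idx, bin_indices]  # shape (5, 3)
    assert probs_selected[2, 1] == bin_probs[1, bin_indices[2, 1]]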
From f2923b4dea899d52ce53294b21b709446cbeba65 Mon Sep 17 00:00:00 2001
From: Fabio Muratore
Date: Mon, 5 Jan 2026 16:47:34 +0100
Subject: [PATCH 3/4] Gather works

---
 binned_cdf/binned_logit_cdf.py | 43 +++++++++++++------------------------------
 1 file changed, 13 insertions(+), 30 deletions(-)

diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py
index b39a525..e2457d3 100644
--- a/binned_cdf/binned_logit_cdf.py
+++ b/binned_cdf/binned_logit_cdf.py
@@ -223,10 +223,8 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
         if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
-            # Expand value to have the correct batch dimensions.
-            # For example: scalar value with batch_shape=(8,) → value.shape=(8,)
             value = value.expand(self.batch_shape)
 
         # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the
@@ -240,39 +238,24 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
         bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
 
         # Gather the bin widths and probabilities for the selected bins.
-
         # For bin_widths of shape (num_bins,) we can index directly.
         bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)
 
-        # For bin_probs of shape (*batch_shape, num_bins), we need to handle batch indexing.
-        # We'll use advanced indexing to select the appropriate bins.
-        # First, we need to create batch indices that match the shape of bin_indices.
+        # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension.
+        # Add sample dimensions to bin_probs and expand to match bin_indices shape.
+        num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
+        bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        bin_probs_for_gather = bin_probs_for_gather.expand(
+            *bin_indices.shape, -1
+        )  # shape: (*sample_shape, *batch_shape, num_bins)
 
-        if len(self.batch_shape) == 0:
-            # No batch dimensions - simple case.
-            probs_selected = self.bin_probs[bin_indices]
-        else:
-            # With batch dimensions, we need to construct proper indices.
-            # bin_probs has shape (*batch_shape, num_bins)
-            # bin_indices has shape (*sample_shape, *batch_shape)
-            # We need to index into the last dimension of bin_probs using bin_indices,
-            # while preserving the batch structure.
-
-            # Create batch indices: for each batch dimension, we need indices [0, 1, 2, ...]
-            # that broadcast correctly with the sample shape.
-            batch_indices = []
-            for i, batch_dim in enumerate(self.batch_shape):
-                # Create a shape that's 1 everywhere except at this batch dimension
-                shape = [1] * len(bin_indices.shape)
-                shape[len(bin_indices.shape) - len(self.batch_shape) + i] = batch_dim
-                batch_idx = torch.arange(batch_dim, device=bin_indices.device).view(shape)
-                batch_indices.append(batch_idx)
-
-            # Use advanced indexing: bin_probs[batch_idx_0, batch_idx_1, ..., bin_indices]
-            probs_selected = self.bin_probs[*batch_indices, bin_indices]
+        # Gather the selected bin probabilities.
+        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
+        bin_probs_selected = bin_probs_selected.squeeze(-1)
 
         # Compute PDF = probability mass / bin width.
-        return probs_selected / bin_widths_selected
+        return bin_probs_selected / bin_widths_selected
 
     def cdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute cumulative distribution function at given values.
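The view → expand → gather pattern that the patch above switches to, in the same toy setting; the last line checks it against the advanced-indexing variant from the previous patch:

    import torch

    sample_shape, batch_shape, num_bins = (5,), (3,), 4
    bin_probs = torch.rand(*batch_shape, num_bins)                     # (3, 4)
    bin_indices = torch.randint(num_bins, sample_shape + batch_shape)  # (5, 3)

    # Prepend a singleton sample dim, then expand (no copy) to (5, 3, 4).
    probs = bin_probs.view((1,) * len(sample_shape) + bin_probs.shape)
    probs = probs.expand(*bin_indices.shape, -1)

    # gather requires index and input to have the same ndim, hence the
    # trailing singleton dim that is squeezed away afterwards.
    selected = torch.gather(probs, dim=-1, index=bin_indices.unsqueeze(-1)).squeeze(-1)
    assert torch.equal(selected, bin_probs[torch.arange(3).view(1, 3), bin_indices])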
From 41136862d2ae9d803c9f4959592c34d922073f1d Mon Sep 17 00:00:00 2001
From: Fabio Muratore
Date: Mon, 5 Jan 2026 17:15:44 +0100
Subject: [PATCH 4/4] Also optimized cdf method

---
 binned_cdf/binned_logit_cdf.py | 37 +++++++++++++++++++++++--------------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py
index e2457d3..b1c476e 100644
--- a/binned_cdf/binned_logit_cdf.py
+++ b/binned_cdf/binned_logit_cdf.py
@@ -273,25 +273,34 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
 
-        # Prepend singleton dimensions for sample_shape to bin_centers.
-        # bin_centers: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        bin_centers_expanded = self.bin_centers.view((1,) * num_sample_dims + self.bin_centers.shape)
+        # Use binary search to find how many bin centers are <= value.
+        # torch.searchsorted with right=True gives us the number of elements <= value.
+        num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)
 
-        # Prepend singleton dimensions for sample_shape to probs.
-        # probs: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        # Clamp to valid range [0, num_bins].
+        num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)
 
-        # Add the bin dimension to the input which is used for comparing with the bin centers.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        # Compute cumulative sum of bin probabilities.
+        # Prepend 0 for the case where no bins are active.
+        num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
+        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
+        cumsum_probs = torch.cat(
+            [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
+            dim=-1,
+        )  # shape: (*batch_shape, num_bins + 1)
 
-        # Mask for bins with centers <= value.
-        mask = bin_centers_expanded <= value_expanded  # shape: (*sample_shape, *batch_shape, num_bins)
+        # Expand cumsum_probs to match sample dimensions and gather.
+        cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
+        cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
+        num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
+        cdf_values = cdf_values.squeeze(-1)
 
-        # Sum the bins for this value by their weighted "activation"=probability.
-        return torch.sum(mask * probs_expanded, dim=-1)
+        return cdf_values
 
     def icdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute the inverse CDF, i.e., the quantile function, at the given values.
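Finally, a compact sketch of the cumsum-plus-lookup CDF evaluation from the last patch, for a single un-batched distribution with made-up bins (the batched case additionally needs the view/expand/gather steps shown above); as in the old masking code, the resulting CDF steps at the bin centers:

    import torch

    bin_centers = torch.tensor([0.125, 0.375, 0.625, 0.875])
    bin_probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    values = torch.tensor([0.0, 0.2, 0.5, 1.0])

    # right=True counts how many bin centers are <= each value.
    num_active = torch.searchsorted(bin_centers, values, right=True)

    # Prepend 0 so that num_active == 0 maps to a CDF of 0.
    cumsum = torch.cat([torch.zeros(1), torch.cumsum(bin_probs, dim=-1)])
    print(cumsum[num_active])  # tensor([0.0000, 0.1000, 0.3000, 1.0000])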