99 changes: 56 additions & 43 deletions binned_cdf/binned_logit_cdf.py
@@ -223,35 +223,39 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
-
-        # Prepend singleton dimensions for sample_shape to bin_edges, bin_widths, and probs.
-        # For all of them, the resulting shape will be: (*sample_shape, *batch_shape, num_bins)
-        bin_edges_left = self.bin_edges[..., :-1]  # shape: (*batch_shape, num_bins)
-        bin_edges_right = self.bin_edges[..., 1:]  # shape: (*batch_shape, num_bins)
-        bin_edges_left = bin_edges_left.view((1,) * num_sample_dims + bin_edges_left.shape)
-        bin_edges_right = bin_edges_right.view((1,) * num_sample_dims + bin_edges_right.shape)
-        bin_widths = self.bin_widths.view((1,) * num_sample_dims + self.bin_widths.shape)
-        probs = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
-
-        # Add bin dimension to value for broadcasting.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
-
-        # Check which bin each value falls into. Result shape: (*sample_shape, *batch_shape, num_bins).
-        in_bin = ((value_expanded >= bin_edges_left) & (value_expanded < bin_edges_right)).to(self.logits.dtype)
-
-        # Handle right edge case (include bound_up in last bin).
-        at_right_edge = torch.isclose(
-            value_expanded, torch.tensor(self.bound_up, dtype=self.logits.dtype, device=self.logits.device)
-        )
-        in_bin[..., -1] = torch.max(in_bin[..., -1], at_right_edge[..., -1])
-
-        # PDF = (probability mass / bin width) for the containing bin.
-        pdf_per_bin = probs / bin_widths  # shape: (*sample_shape, *batch_shape, num_bins)
-
-        # Sum over bins is the same as selecting the bin, as there is only one bin active per value.
-        return torch.sum(in_bin * pdf_per_bin, dim=-1)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
+
+        # Use binary search to find the bin each value belongs to. With right=True, torch.searchsorted
+        # returns the number of edges <= value, so subtracting 1 maps each value in the half-open
+        # bin [edge[i], edge[i+1]) to bin index i.
+        bin_indices = torch.searchsorted(self.bin_edges, value, right=True) - 1  # shape: (*sample_shape, *batch_shape)
+
+        # Clamp to the valid range [0, num_bins - 1] to handle edge cases:
+        # - values below bound_low would give bin_idx = -1
+        # - values at bound_up would give bin_idx = num_bins
+        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
+
+        # Gather the bin widths and probabilities for the selected bins.
+        # Since bin_widths has shape (num_bins,), we can index it directly.
+        bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)
+
+        # bin_probs has shape (*batch_shape, num_bins), so we gather along the last dimension:
+        # prepend sample dimensions to bin_probs and expand to match the shape of bin_indices.
+        num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
+        bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        bin_probs_for_gather = bin_probs_for_gather.expand(
+            *bin_indices.shape, -1
+        )  # shape: (*sample_shape, *batch_shape, num_bins)
+
+        # Gather the selected bin probabilities.
+        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
+        bin_probs_selected = bin_probs_selected.squeeze(-1)
+
+        # Compute PDF = probability mass / bin width.
+        return bin_probs_selected / bin_widths_selected
 
     def cdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute cumulative distribution function at given values.
@@ -269,25 +273,34 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
 
-        # Prepend singleton dimensions for sample_shape to bin_centers.
-        # bin_centers: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        bin_centers_expanded = self.bin_centers.view((1,) * num_sample_dims + self.bin_centers.shape)
+        # Use binary search to find how many bin centers are <= value:
+        # torch.searchsorted with right=True returns the number of elements <= value.
+        num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)
 
-        # Prepend singleton dimensions for sample_shape to probs.
-        # probs: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        # Clamp to the valid range [0, num_bins].
+        num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)
 
-        # Add the bin dimension to the input which is used for comparing with the bin centers.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        # Compute the cumulative sum of bin probabilities.
+        # Prepend 0 for the case where no bins are active.
+        num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
+        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
+        cumsum_probs = torch.cat(
+            [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
+            dim=-1,
+        )  # shape: (*batch_shape, num_bins + 1)
 
-        # Mask for bins with centers <= value.
-        mask = bin_centers_expanded <= value_expanded  # shape: (*sample_shape, *batch_shape, num_bins)
+        # Expand cumsum_probs to match the sample dimensions, then gather the CDF value per element.
+        cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
+        cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
+        num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
+        cdf_values = cdf_values.squeeze(-1)
 
-        # Sum the bins for this value by their weighted "activation"=probability.
-        return torch.sum(mask * probs_expanded, dim=-1)
+        return cdf_values
 
     def icdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute the inverse CDF, i.e., the quantile function, at the given values.