Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions binned_cdf/binned_logit_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def _create_bins(

# Create positive side: 0, internal edges, bound_up.
if half_bins == 1:
# Special case: only boundary edges.
positive_edges = torch.tensor([0.0, bound_up])
# Special case where we only use the boundary edges.
positive_edges = torch.tensor([bound_up])
else:
# Create half_bins - 1 internal edges between 0 and bound_up.
internal_positive = torch.logspace(
Expand Down Expand Up @@ -141,25 +141,22 @@ def num_edges(self) -> int:
return self.bin_edges.shape[0]

@property
def probs(self) -> torch.Tensor:
def bin_probs(self) -> torch.Tensor:
"""Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins)
return raw_probs / raw_probs.sum(dim=-1, keepdim=True)

@property
def mean(self) -> torch.Tensor:
"""Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_size,)."""
bin_probs = self.probs
weighted_centers = bin_probs * self.bin_centers # shape: (*batch_shape, num_bins)
"""Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,)."""
weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins)
return torch.sum(weighted_centers, dim=-1)

@property
def variance(self) -> torch.Tensor:
"""Compute variance of the distribution, of shape (*batch_shape,)."""
bin_probs = self.probs

# E[X^2] = weighted squared bin centers.
weighted_centers_sq = bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins)
weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins)
second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,)

# Var = E[X^2] - E[X]^2
Expand Down Expand Up @@ -189,7 +186,16 @@ def expand(
)

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""Compute log probability density at given values."""
"""Compute log probability density at given values.

Args:
value: Values at which to compute the log PDF.
Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

Returns:
Log PDF values corresponding to the input values.
Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
"""
return torch.log(self.prob(value) + 1e-8) # small epsilon for stability

def prob(self, value: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -218,7 +224,7 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
bin_edges_left = bin_edges_left.view((1,) * num_sample_dims + bin_edges_left.shape)
bin_edges_right = bin_edges_right.view((1,) * num_sample_dims + bin_edges_right.shape)
bin_widths = self.bin_widths.view((1,) * num_sample_dims + self.bin_widths.shape)
probs = self.probs.view((1,) * num_sample_dims + self.probs.shape)
probs = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)

# Add bin dimension to value for broadcasting.
value_expanded = value.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1)
Expand Down Expand Up @@ -263,7 +269,7 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:

# Prepend singleton dimensions for sample_shape to probs.
# probs: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
probs_expanded = self.probs.view((1,) * num_sample_dims + self.probs.shape)
probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)

# Add the bin dimension to the input which is used for comparing with the bin centers.
value_expanded = value.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1)
Expand Down Expand Up @@ -294,7 +300,7 @@ def icdf(self, value: torch.Tensor) -> torch.Tensor:
cdf_edges = torch.cat(
[
torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device),
torch.cumsum(self.probs, dim=-1), # shape: (*batch_shape, num_bins)
torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins)
],
dim=-1,
) # shape: (*batch_shape, num_bins + 1)
Expand Down Expand Up @@ -348,7 +354,7 @@ def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size)
sample_shape: Shape of the samples to draw.

Returns:
Samples of shape [sample_shape + batch_shape, num_bins].
Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution.
"""
shape = torch.Size(sample_shape) + self.batch_shape
uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
Expand All @@ -362,7 +368,7 @@ def entropy(self) -> torch.Tensor:
Note:
Here, we are doing an approximation by treating each bin as a uniform distribution over its width.
"""
bin_probs = self.probs
bin_probs = self.bin_probs

# Get the PDF values at bin centers.
pdf_values = bin_probs / self.bin_widths # shape: (*batch_shape, num_bins)
Expand Down
117 changes: 114 additions & 3 deletions tests/test_binned_logit_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_basic_properties(
with pytest.raises(ValueError, match="log_spacing requires even number of bins"):
BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing)
return
dist = BinnedLogitCDF(logits, bound_low, bound_up)
dist = BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing)

# Test that tensors are on the correct device.
assert dist.logits.device == device
Expand Down Expand Up @@ -76,14 +76,14 @@ def test_basic_properties(
assert "BinnedLogitCDF" in repr_str

# Test that probabilities are valid. They should be normalized, and sum to 1.
probs = dist.probs
probs = dist.bin_probs
assert probs.device == device
assert torch.all(probs >= 0)
assert torch.all(probs <= 1)
assert torch.allclose(probs.sum(dim=-1), torch.ones(dist.batch_shape, device=device))

# The probabilities should also be deterministic.
probs2 = dist.probs
probs2 = dist.bin_probs
assert torch.allclose(probs, probs2)

# Test that mean and variance have the correct shape and are finite.
Expand All @@ -98,6 +98,117 @@ def test_basic_properties(
assert torch.all(torch.isfinite(var))


@pytest.mark.parametrize(
    "batch_size,new_batch_shape",
    [
        (None, [4, 5]),  # () can expand to any shape
        (1, [1, 5]),  # (1,) can expand by adding dimensions or keeping 1
        (1, [3, 1]),  # (1,) can also expand with 1 in last position
        (8, [2, 8]),  # (8,) can expand by adding leading dimensions
        (8, [3, 2, 8]),  # (8,) can expand with multiple leading dimensions
    ],
)
@pytest.mark.parametrize("num_bins", [2, 200])  # 2 is an edge case for log-spacing
@pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"])
@pytest.mark.parametrize(
    "use_cuda",
    [
        pytest.param(False, id="cpu"),
        pytest.param(True, marks=needs_cuda, id="cuda"),
    ],
)
def test_expand(
    batch_size: int | None,
    new_batch_shape: list[int],
    num_bins: int,
    log_spacing: bool,
    use_cuda: bool,
):
    """Check that `expand` produces a new distribution with the requested batch shape."""
    device = torch.device("cuda:0" if use_cuda else "cpu")
    if batch_size is None:
        logits = torch.randn((num_bins,)).to(device)
    else:
        logits = torch.randn(batch_size, num_bins).to(device)
    dist = BinnedLogitCDF(logits, log_spacing=log_spacing)

    expanded_dist = dist.expand(new_batch_shape)

    # `expand` must hand back a fresh object, not the original distribution.
    assert expanded_dist is not dist, "Expanded distribution should be a new instance"

    # Every tensor of the expanded distribution must stay on the original device.
    assert expanded_dist.logits.device == device, f"Expected device {device}, got {expanded_dist.logits.device}"
    for expanded_tensor in (expanded_dist.bin_edges, expanded_dist.bin_centers, expanded_dist.bin_widths):
        assert expanded_tensor.device == device

    # The batch shape must be exactly the one requested.
    assert expanded_dist.batch_shape == torch.Size(new_batch_shape), (
        f"Expected batch_shape {torch.Size(new_batch_shape)}, got {expanded_dist.batch_shape}"
    )

    # Logits must be broadcast out to (*new_batch_shape, num_bins).
    expected_logits_shape = torch.Size([*new_batch_shape, num_bins])
    assert expanded_dist.logits.shape == expected_logits_shape, (
        f"Expected logits shape {expected_logits_shape}, got {expanded_dist.logits.shape}"
    )

    # Properties that do not depend on the batch shape must be untouched.
    assert expanded_dist.event_shape == torch.Size([]), "event_shape should remain empty (scalar)"
    assert expanded_dist.num_bins == num_bins, "num_bins should be unchanged"
    assert expanded_dist.bin_edges.shape == dist.bin_edges.shape, "bin_edges shape should be unchanged"
    assert expanded_dist.bin_centers.shape == dist.bin_centers.shape, "bin_centers shape should be unchanged"
    assert expanded_dist.bin_widths.shape == dist.bin_widths.shape, "bin_widths shape should be unchanged"


@pytest.mark.parametrize("batch_size", [None, 1, 8])
@pytest.mark.parametrize("num_bins", [2, 200])  # 2 is an edge case for log-spacing
@pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"])
@pytest.mark.parametrize(
    "use_cuda",
    [
        pytest.param(False, id="cpu"),
        pytest.param(True, marks=needs_cuda, id="cuda"),
    ],
)
def test_prob_random_logits(
    batch_size: int | None,
    num_bins: int,
    log_spacing: bool,
    use_cuda: bool,
):
    """Test probability evaluation with random logits at the bin centers and at the bounds."""
    torch.manual_seed(42)  # fixed seed so failures are reproducible

    device = torch.device("cuda:0" if use_cuda else "cpu")
    logits = torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins)
    logits = logits.to(device)
    dist = BinnedLogitCDF(logits, log_spacing=log_spacing)

    # Define expected shapes based on batch_size. The bins go into the sample shape.
    # Annotate once up front so the branch below needs no `type: ignore[no-redef]`.
    expected_probs_shape: tuple[int, ...]
    bin_centers = dist.bin_centers
    if batch_size is not None:
        # Expand to (num_bins, batch_size) for batched distributions.
        bin_centers = bin_centers.unsqueeze(1).expand(num_bins, batch_size)
        expected_probs_shape = (num_bins, batch_size)
    else:
        # Keep as (num_bins,) for non-batched distributions.
        expected_probs_shape = (num_bins,)

    # Test probability computation at bin centers.
    probs_at_centers = dist.log_prob(bin_centers)
    assert probs_at_centers.device == device
    assert torch.all(torch.isfinite(probs_at_centers)), "log_prob at bin centers should be finite"
    assert probs_at_centers.shape == expected_probs_shape

    # Test probability at bounds - should be finite but may be low
    expected_scalar_shape = torch.Size([]) if batch_size is None else torch.Size([batch_size])
    prob_at_low = dist.log_prob(torch.tensor(dist.bound_low, device=device))
    prob_at_up = dist.log_prob(torch.tensor(dist.bound_up, device=device))
    assert torch.all(torch.isfinite(prob_at_low)), f"log_prob at lower bound should be finite: {prob_at_low}"
    assert torch.all(torch.isfinite(prob_at_up)), f"log_prob at upper bound should be finite: {prob_at_up}"
    assert prob_at_low.shape == expected_scalar_shape
    assert prob_at_up.shape == expected_scalar_shape


@pytest.mark.parametrize("logit_scale", [1e-3, 1, 1e3, 1e9])
@pytest.mark.parametrize("batch_size", [None, 1, 8])
@pytest.mark.parametrize(
Expand Down
Loading