From e105d330521957b9ce1ce1f1308415e92af6e8cb Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Mon, 17 Nov 2025 08:49:08 +0100 Subject: [PATCH 1/5] Tests for log-prob --- binned_cdf/binned_logit_cdf.py | 32 +++++++++++------- tests/test_binned_logit_cdf.py | 62 ++++++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py index 9788f56..376a79a 100644 --- a/binned_cdf/binned_logit_cdf.py +++ b/binned_cdf/binned_logit_cdf.py @@ -98,8 +98,8 @@ def _create_bins( # Create positive side: 0, internal edges, bound_up. if half_bins == 1: - # Special case: only boundary edges. - positive_edges = torch.tensor([0.0, bound_up]) + # Special case where we only use the boundary edges. + positive_edges = torch.tensor([bound_up]) else: # Create half_bins - 1 internal edges between 0 and bound_up. internal_positive = torch.logspace( @@ -141,7 +141,7 @@ def num_edges(self) -> int: return self.bin_edges.shape[0] @property - def probs(self) -> torch.Tensor: + def bin_probs(self) -> torch.Tensor: """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).""" raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) return raw_probs / raw_probs.sum(dim=-1, keepdim=True) @@ -149,17 +149,14 @@ def probs(self) -> torch.Tensor: @property def mean(self) -> torch.Tensor: """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_size,).""" - bin_probs = self.probs - weighted_centers = bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) + weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) return torch.sum(weighted_centers, dim=-1) @property def variance(self) -> torch.Tensor: """Compute variance of the distribution, of shape (*batch_shape,).""" - bin_probs = self.probs - # E[X^2] = weighted squared bin centers. - weighted_centers_sq = bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins) + weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins) second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,) # Var = E[X^2] - E[X]^2 @@ -189,7 +186,16 @@ def expand( ) def log_prob(self, value: torch.Tensor) -> torch.Tensor: - """Compute log probability density at given values.""" + """Compute log probability density at given values. + + Args: + value: Values at which to compute the log PDF. + Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. + + Returns: + Log PDF values corresponding to the input values. + Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). + """ return torch.log(self.prob(value) + 1e-8) # small epsilon for stability def prob(self, value: torch.Tensor) -> torch.Tensor: @@ -218,7 +224,7 @@ def prob(self, value: torch.Tensor) -> torch.Tensor: bin_edges_left = bin_edges_left.view((1,) * num_sample_dims + bin_edges_left.shape) bin_edges_right = bin_edges_right.view((1,) * num_sample_dims + bin_edges_right.shape) bin_widths = self.bin_widths.view((1,) * num_sample_dims + self.bin_widths.shape) - probs = self.probs.view((1,) * num_sample_dims + self.probs.shape) + probs = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape) # Add bin dimension to value for broadcasting. value_expanded = value.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1) @@ -263,7 +269,7 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor: # Prepend singleton dimensions for sample_shape to probs. # probs: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins) - probs_expanded = self.probs.view((1,) * num_sample_dims + self.probs.shape) + probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape) # Add the bin dimension to the input which is used for comparing with the bin centers. value_expanded = value.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1) @@ -294,7 +300,7 @@ def icdf(self, value: torch.Tensor) -> torch.Tensor: cdf_edges = torch.cat( [ torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), - torch.cumsum(self.probs, dim=-1), # shape: (*batch_shape, num_bins) + torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins) ], dim=-1, ) # shape: (*batch_shape, num_bins + 1) @@ -362,7 +368,7 @@ def entropy(self) -> torch.Tensor: Note: Here, we are doing an approximation by treating each bin as a uniform distribution over its width. """ - bin_probs = self.probs + bin_probs = self.bin_probs # Get the PDF values at bin centers. pdf_values = bin_probs / self.bin_widths # shape: (*batch_shape, num_bins) diff --git a/tests/test_binned_logit_cdf.py b/tests/test_binned_logit_cdf.py index c1ff1a0..a98f576 100644 --- a/tests/test_binned_logit_cdf.py +++ b/tests/test_binned_logit_cdf.py @@ -47,7 +47,7 @@ def test_basic_properties( with pytest.raises(ValueError, match="log_spacing requires even number of bins"): BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing) return - dist = BinnedLogitCDF(logits, bound_low, bound_up) + dist = BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing) # Test that tensors are on the correct device. assert dist.logits.device == device @@ -76,14 +76,14 @@ def test_basic_properties( assert "BinnedLogitCDF" in repr_str # Test that probabilities are valid. They should be normalized, and sum to 1. - probs = dist.probs + probs = dist.bin_probs assert probs.device == device assert torch.all(probs >= 0) assert torch.all(probs <= 1) assert torch.allclose(probs.sum(dim=-1), torch.ones(dist.batch_shape, device=device)) # The probabilities should also be deterministic. - probs2 = dist.probs + probs2 = dist.bin_probs assert torch.allclose(probs, probs2) # Test that mean and variance have the correct shape and are finite. @@ -98,6 +98,62 @@ def test_basic_properties( assert torch.all(torch.isfinite(var)) +@pytest.mark.parametrize("batch_size", [None, 1, 8]) +@pytest.mark.parametrize("num_bins", [2, 100, 1000]) # 2 is an edge case for log-spacing +@pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"]) +@pytest.mark.parametrize( + "use_cuda", + [ + pytest.param(False, id="cpu"), + pytest.param(True, marks=needs_cuda, id="cuda"), + ], +) +def test_prob_random_logits( + batch_size: int | None, + num_bins: int, + log_spacing: bool, + use_cuda: bool, +): + """Test probability evaluation with random logits at the bounds.""" + device = torch.device("cuda:0" if use_cuda else "cpu") + logits = torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins) + logits = logits.to(device) + dist = BinnedLogitCDF(logits, log_spacing=log_spacing) # default bounds + + # Define expected shapes based on batch_size. The bins go into the sample shape. + bin_centers = dist.bin_centers + if batch_size is not None: + # Expand to (num_bins, batch_size) for batched distributions. + bin_centers = bin_centers.unsqueeze(1).expand(num_bins, batch_size) + expected_probs_shape = (num_bins, batch_size) + else: + # Keep as (num_bins,) for non-batched distributions. + expected_probs_shape = (num_bins,) + + # Test probability computation at bin centers. + probs_at_centers = dist.log_prob(bin_centers) + assert probs_at_centers.device == device + assert torch.all(torch.isfinite(probs_at_centers)), "log_prob at bin centers should be finite" + assert probs_at_centers.shape == expected_probs_shape + + # Test probability at bounds - should be finite but may be low + expected_scalar_shape = torch.Size([]) if batch_size is None else torch.Size([batch_size]) + prob_at_low = dist.log_prob(torch.tensor(dist.bound_low, device=device)) + prob_at_up = dist.log_prob(torch.tensor(dist.bound_up, device=device)) + assert torch.all(torch.isfinite(prob_at_low)), f"log_prob at lower bound should be finite: {prob_at_low}" + assert torch.all(torch.isfinite(prob_at_up)), f"log_prob at upper bound should be finite: {prob_at_up}" + assert prob_at_low.shape == expected_scalar_shape + assert prob_at_up.shape == expected_scalar_shape + + # Test probability outside bounds should be very small (log_prob very negative) + # if dist.bound_low > -1e6: # Only test if lower bound is finite enough + # prob_below = dist.log_prob(torch.tensor(dist.bound_low - 10.0, device=device)) + # assert torch.all(prob_below < -5), "log_prob below lower bound should be very negative" + # if dist.bound_up < 1e6: # Only test if upper bound is finite enough + # prob_above = dist.log_prob(torch.tensor(dist.bound_up + 10.0, device=device)) + # assert torch.all(prob_above < -5), "log_prob above upper bound should be very negative" + + @pytest.mark.parametrize("logit_scale", [1e-3, 1, 1e3, 1e9]) @pytest.mark.parametrize("batch_size", [None, 1, 8]) @pytest.mark.parametrize( From 087a0792544dde860b1ceec6d7626b1da76583fa Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Mon, 17 Nov 2025 08:49:58 +0100 Subject: [PATCH 2/5] Removed comment --- tests/test_binned_logit_cdf.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_binned_logit_cdf.py b/tests/test_binned_logit_cdf.py index a98f576..4b45dbd 100644 --- a/tests/test_binned_logit_cdf.py +++ b/tests/test_binned_logit_cdf.py @@ -145,14 +145,6 @@ def test_prob_random_logits( assert prob_at_low.shape == expected_scalar_shape assert prob_at_up.shape == expected_scalar_shape - # Test probability outside bounds should be very small (log_prob very negative) - # if dist.bound_low > -1e6: # Only test if lower bound is finite enough - # prob_below = dist.log_prob(torch.tensor(dist.bound_low - 10.0, device=device)) - # assert torch.all(prob_below < -5), "log_prob below lower bound should be very negative" - # if dist.bound_up < 1e6: # Only test if upper bound is finite enough - # prob_above = dist.log_prob(torch.tensor(dist.bound_up + 10.0, device=device)) - # assert torch.all(prob_above < -5), "log_prob above upper bound should be very negative" - @pytest.mark.parametrize("logit_scale", [1e-3, 1, 1e3, 1e9]) @pytest.mark.parametrize("batch_size", [None, 1, 8]) From c32547cfe4a48672744ec0d415e72d495969abda Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Mon, 17 Nov 2025 08:55:11 +0100 Subject: [PATCH 3/5] Consistent shape comments --- binned_cdf/binned_logit_cdf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py index 376a79a..292d3fc 100644 --- a/binned_cdf/binned_logit_cdf.py +++ b/binned_cdf/binned_logit_cdf.py @@ -148,7 +148,7 @@ def bin_probs(self) -> torch.Tensor: @property def mean(self) -> torch.Tensor: - """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_size,).""" + """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,).""" weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) return torch.sum(weighted_centers, dim=-1) @@ -354,7 +354,7 @@ def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) sample_shape: Shape of the samples to draw. Returns: - Samples of shape [sample_shape + batch_shape, num_bins]. + Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution. """ shape = torch.Size(sample_shape) + self.batch_shape uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) From fb64ea71a4e995fd2961fda7eee13fc8e4b8c4d3 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Mon, 17 Nov 2025 09:01:24 +0100 Subject: [PATCH 4/5] Added test for expand() --- tests/test_binned_logit_cdf.py | 67 +++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/tests/test_binned_logit_cdf.py b/tests/test_binned_logit_cdf.py index 4b45dbd..1294513 100644 --- a/tests/test_binned_logit_cdf.py +++ b/tests/test_binned_logit_cdf.py @@ -98,8 +98,69 @@ def test_basic_properties( assert torch.all(torch.isfinite(var)) +@pytest.mark.parametrize( + "batch_size,new_batch_shape", + [ + (None, [4, 5]), # () can expand to any shape + (1, [1, 5]), # (1,) can expand by adding dimensions or keeping 1 + (1, [3, 1]), # (1,) can also expand with 1 in last position + (8, [2, 8]), # (8,) can expand by adding leading dimensions + (8, [3, 2, 8]), # (8,) can expand with multiple leading dimensions + ], +) +@pytest.mark.parametrize("num_bins", [2, 200]) # 2 is an edge case for log-spacing +@pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"]) +@pytest.mark.parametrize( + "use_cuda", + [ + pytest.param(False, id="cpu"), + pytest.param(True, marks=needs_cuda, id="cuda"), + ], +) +def test_expand( + batch_size: int | None, + new_batch_shape: list[int], + num_bins: int, + log_spacing: bool, + use_cuda: bool, +): + device = torch.device("cuda:0" if use_cuda else "cpu") + logits = torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins) + logits = logits.to(device) + dist = BinnedLogitCDF(logits, log_spacing=log_spacing) + + expanded_dist = dist.expand(new_batch_shape) + + # Assert that expanded_dist is a different object (not the same instance). + assert expanded_dist is not dist, "Expanded distribution should be a new instance" + + # Assert that the expanded distribution is on the same device. + assert expanded_dist.logits.device == device, f"Expected device {device}, got {expanded_dist.logits.device}" + assert expanded_dist.bin_edges.device == device + assert expanded_dist.bin_centers.device == device + assert expanded_dist.bin_widths.device == device + + # Assert that the batch shape is correct. + assert expanded_dist.batch_shape == torch.Size(new_batch_shape), ( + f"Expected batch_shape {torch.Size(new_batch_shape)}, got {expanded_dist.batch_shape}" + ) + + # Assert that the logits have the correct shape: (*new_batch_shape, num_bins). + expected_logits_shape = torch.Size([*new_batch_shape, num_bins]) + assert expanded_dist.logits.shape == expected_logits_shape, ( + f"Expected logits shape {expected_logits_shape}, got {expanded_dist.logits.shape}" + ) + + # Verify properties that should remain unchanged. + assert expanded_dist.event_shape == torch.Size([]), "event_shape should remain empty (scalar)" + assert expanded_dist.num_bins == num_bins, "num_bins should be unchanged" + assert expanded_dist.bin_edges.shape == dist.bin_edges.shape, "bin_edges shape should be unchanged" + assert expanded_dist.bin_centers.shape == dist.bin_centers.shape, "bin_centers shape should be unchanged" + assert expanded_dist.bin_widths.shape == dist.bin_widths.shape, "bin_widths shape should be unchanged" + + @pytest.mark.parametrize("batch_size", [None, 1, 8]) -@pytest.mark.parametrize("num_bins", [2, 100, 1000]) # 2 is an edge case for log-spacing +@pytest.mark.parametrize("num_bins", [2, 200]) # 2 is an edge case for log-spacing @pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"]) @pytest.mark.parametrize( "use_cuda", @@ -115,10 +176,12 @@ def test_prob_random_logits( use_cuda: bool, ): """Test probability evaluation with random logits at the bounds.""" + torch.manual_seed(42) + device = torch.device("cuda:0" if use_cuda else "cpu") logits = torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins) logits = logits.to(device) - dist = BinnedLogitCDF(logits, log_spacing=log_spacing) # default bounds + dist = BinnedLogitCDF(logits, log_spacing=log_spacing) # Define expected shapes based on batch_size. The bins go into the sample shape. bin_centers = dist.bin_centers From 1d15e2778a738e24afd55b84bbfa8b7f1788b660 Mon Sep 17 00:00:00 2001 From: Fabio Muratore Date: Mon, 17 Nov 2025 09:23:08 +0100 Subject: [PATCH 5/5] Please mypy --- tests/test_binned_logit_cdf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_binned_logit_cdf.py b/tests/test_binned_logit_cdf.py index 1294513..e1fa1d5 100644 --- a/tests/test_binned_logit_cdf.py +++ b/tests/test_binned_logit_cdf.py @@ -188,10 +188,10 @@ def test_prob_random_logits( if batch_size is not None: # Expand to (num_bins, batch_size) for batched distributions. bin_centers = bin_centers.unsqueeze(1).expand(num_bins, batch_size) - expected_probs_shape = (num_bins, batch_size) + expected_probs_shape: tuple[int, ...] = (num_bins, batch_size) else: # Keep as (num_bins,) for non-batched distributions. - expected_probs_shape = (num_bins,) + expected_probs_shape: tuple[int, ...] = (num_bins,) # type: ignore[no-redef] # Test probability computation at bin centers. probs_at_centers = dist.log_prob(bin_centers)