diff --git a/binned_cdf/binned_logit_cdf.py b/binned_cdf/binned_logit_cdf.py index 292d3fc..7da6519 100644 --- a/binned_cdf/binned_logit_cdf.py +++ b/binned_cdf/binned_logit_cdf.py @@ -1,4 +1,5 @@ import math +from typing import Literal import torch from torch.distributions import Distribution, constraints @@ -24,6 +25,7 @@ def __init__( bound_low: float = -1e3, bound_up: float = 1e3, log_spacing: bool = False, + bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid", validate_args: bool | None = None, ) -> None: """Initializer. @@ -33,11 +35,14 @@ def __init__( bound_low: Lower bound of the distribution support, needs to be finite. bound_up: Upper bound of the distribution support, needs to be finite. log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used. + bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid", + each bin is independently activated, while with "softmax", the bins activations influence each other. validate_args: Whether to validate arguments. Carried over to keep the interface with the base class. """ self.logits = logits self.bound_low = bound_low self.bound_up = bound_up + self.bin_normalization_method = bin_normalization_method self.log_spacing = log_spacing # Create bin structure (same for all batch dimensions). @@ -143,8 +148,12 @@ def num_edges(self) -> int: @property def bin_probs(self) -> torch.Tensor: """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).""" - raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) - return raw_probs / raw_probs.sum(dim=-1, keepdim=True) + if self.bin_normalization_method == "sigmoid": + raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) + bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True) + else: + bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins) + return bin_probs @property def mean(self) -> torch.Tensor: diff --git a/tests/test_binned_logit_cdf.py b/tests/test_binned_logit_cdf.py index e1fa1d5..a42e0bd 100644 --- a/tests/test_binned_logit_cdf.py +++ b/tests/test_binned_logit_cdf.py @@ -1,4 +1,5 @@ import math +from typing import Literal import matplotlib.pyplot as plt import numpy as np @@ -14,6 +15,7 @@ @pytest.mark.parametrize("batch_size", [None, 1, 8]) @pytest.mark.parametrize("num_bins", [1, 2, 7, 1000]) # 2 is an edge case for log-spacing @pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"]) +@pytest.mark.parametrize("bin_normalization_method", ["sigmoid", "softmax"], ids=["sigmoid", "softmax"]) @pytest.mark.parametrize("bound_low,bound_up", [(-5, 5), (0, 5), (-5, 0)]) @pytest.mark.parametrize( "use_cuda", @@ -26,6 +28,7 @@ def test_basic_properties( batch_size: int | None, num_bins: int, log_spacing: bool, + bin_normalization_method: Literal["sigmoid", "softmax"], bound_low: int, bound_up: int, use_cuda: bool, @@ -37,17 +40,25 @@ def test_basic_properties( if log_spacing and not math.isclose(-bound_low, bound_up): with pytest.raises(ValueError, match="log_spacing requires symmetric bounds"): - BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing) + BinnedLogitCDF( + logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method + ) return if log_spacing and bound_up <= 0: with pytest.raises(ValueError, match="log_spacing requires positive upper bound"): - BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing) + BinnedLogitCDF( + logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method + ) return if log_spacing and num_bins % 2 != 0: with pytest.raises(ValueError, match="log_spacing requires even number of bins"): - BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing) + BinnedLogitCDF( + logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method + ) return - dist = BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing) + dist = BinnedLogitCDF( + logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method + ) # Test that tensors are on the correct device. assert dist.logits.device == device @@ -162,6 +173,7 @@ def test_expand( @pytest.mark.parametrize("batch_size", [None, 1, 8]) @pytest.mark.parametrize("num_bins", [2, 200]) # 2 is an edge case for log-spacing @pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"]) +@pytest.mark.parametrize("bin_normalization_method", ["sigmoid", "softmax"], ids=["sigmoid", "softmax"]) @pytest.mark.parametrize( "use_cuda", [ @@ -173,6 +185,7 @@ def test_prob_random_logits( batch_size: int | None, num_bins: int, log_spacing: bool, + bin_normalization_method: Literal["sigmoid", "softmax"], use_cuda: bool, ): """Test probability evaluation with random logits at the bounds.""" @@ -181,7 +194,7 @@ def test_prob_random_logits( device = torch.device("cuda:0" if use_cuda else "cpu") logits = torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins) logits = logits.to(device) - dist = BinnedLogitCDF(logits, log_spacing=log_spacing) + dist = BinnedLogitCDF(logits, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method) # Define expected shapes based on batch_size. The bins go into the sample shape. bin_centers = dist.bin_centers @@ -210,6 +223,7 @@ def test_prob_random_logits( @pytest.mark.parametrize("logit_scale", [1e-3, 1, 1e3, 1e9]) +@pytest.mark.parametrize("bin_normalization_method", ["sigmoid", "softmax"], ids=["sigmoid", "softmax"]) @pytest.mark.parametrize("batch_size", [None, 1, 8]) @pytest.mark.parametrize( "use_cuda,plot", @@ -222,6 +236,7 @@ def test_prob_random_logits( def test_cdf_random_logits( logit_scale: float, batch_size: int | None, + bin_normalization_method: Literal["sigmoid", "softmax"], use_cuda: bool, plot: bool, bound_low: float = -10, @@ -233,7 +248,7 @@ def test_cdf_random_logits( logits = logit_scale * torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins) logits = logits.to(device) - dist = BinnedLogitCDF(logits, bound_low, bound_up) + dist = BinnedLogitCDF(logits, bound_low, bound_up, bin_normalization_method=bin_normalization_method) # Evaluate the CDF at the bounds. cdf_low = dist.cdf(torch.tensor(bound_low)) @@ -253,7 +268,10 @@ def test_cdf_random_logits( plt.title(f"CDF for random logits scaled by {logit_scale}") plt.legend() plt.grid(True, alpha=0.3) - plt.savefig(f"tests/results/cdf_random_logits_scale-{logit_scale}.png", bbox_inches="tight") + plt.savefig( + f"tests/results/cdf_random_logits_scale-{logit_scale}_normalization-{bin_normalization_method}.png", + bbox_inches="tight", + ) @pytest.mark.parametrize("batch_size", [None, 1, 8, 16])