Merged
13 changes: 11 additions & 2 deletions binned_cdf/binned_logit_cdf.py
@@ -1,4 +1,5 @@
 import math
+from typing import Literal

 import torch
 from torch.distributions import Distribution, constraints
@@ -24,6 +25,7 @@ def __init__(
         bound_low: float = -1e3,
         bound_up: float = 1e3,
         log_spacing: bool = False,
+        bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid",
         validate_args: bool | None = None,
     ) -> None:
         """Initializer.
@@ -33,11 +35,14 @@ def __init__(
             bound_low: Lower bound of the distribution support, needs to be finite.
             bound_up: Upper bound of the distribution support, needs to be finite.
             log_spacing: Whether to use logarithmic (base 2) spacing for the bins instead of linear spacing.
+            bin_normalization_method: How to normalize bin probabilities, either "sigmoid" or "softmax". With
+                "sigmoid", each bin is activated independently; with "softmax", the bin activations influence each other.
             validate_args: Whether to validate arguments. Carried over to keep the interface consistent with the base class.
         """
         self.logits = logits
         self.bound_low = bound_low
         self.bound_up = bound_up
+        self.bin_normalization_method = bin_normalization_method
         self.log_spacing = log_spacing

         # Create bin structure (same for all batch dimensions).
@@ -143,8 +148,12 @@ def num_edges(self) -> int:
     @property
     def bin_probs(self) -> torch.Tensor:
         """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins)."""
-        raw_probs = torch.sigmoid(self.logits)  # shape: (*batch_shape, num_bins)
-        return raw_probs / raw_probs.sum(dim=-1, keepdim=True)
+        if self.bin_normalization_method == "sigmoid":
+            raw_probs = torch.sigmoid(self.logits)  # shape: (*batch_shape, num_bins)
+            bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
+        else:
+            bin_probs = torch.softmax(self.logits, dim=-1)  # shape: (*batch_shape, num_bins)
+        return bin_probs

     @property
     def mean(self) -> torch.Tensor:
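For reference, here is a minimal standalone sketch (not part of the PR) contrasting the two normalization paths that the bin_probs property above switches between; the tensor values are illustrative only.

import torch

logits = torch.tensor([[-1.0, 0.0, 2.0]])

# "sigmoid": squash each logit independently, then renormalize so the bins sum to 1.
raw_probs = torch.sigmoid(logits)
sigmoid_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True)
print(sigmoid_probs)  # ~[[0.163, 0.303, 0.534]]

# "softmax": the bins compete directly through the shared normalizer.
softmax_probs = torch.softmax(logits, dim=-1)
print(softmax_probs)  # ~[[0.042, 0.114, 0.844]]

# Both variants yield valid probability vectors, but softmax is more peaked for
# the same logits because it exponentiates rather than saturates.
assert torch.allclose(sigmoid_probs.sum(dim=-1), torch.ones(1))
assert torch.allclose(softmax_probs.sum(dim=-1), torch.ones(1))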
32 changes: 25 additions & 7 deletions tests/test_binned_logit_cdf.py
@@ -1,4 +1,5 @@
 import math
+from typing import Literal

 import matplotlib.pyplot as plt
 import numpy as np
@@ -14,6 +15,7 @@
 @pytest.mark.parametrize("batch_size", [None, 1, 8])
 @pytest.mark.parametrize("num_bins", [1, 2, 7, 1000])  # 2 is an edge case for log-spacing
 @pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"])
+@pytest.mark.parametrize("bin_normalization_method", ["sigmoid", "softmax"], ids=["sigmoid", "softmax"])
 @pytest.mark.parametrize("bound_low,bound_up", [(-5, 5), (0, 5), (-5, 0)])
 @pytest.mark.parametrize(
     "use_cuda",
@@ -26,6 +28,7 @@ def test_basic_properties(
     batch_size: int | None,
     num_bins: int,
     log_spacing: bool,
+    bin_normalization_method: Literal["sigmoid", "softmax"],
     bound_low: int,
     bound_up: int,
     use_cuda: bool,
@@ -37,17 +40,25 @@

     if log_spacing and not math.isclose(-bound_low, bound_up):
         with pytest.raises(ValueError, match="log_spacing requires symmetric bounds"):
-            BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing)
+            BinnedLogitCDF(
+                logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method
+            )
         return
     if log_spacing and bound_up <= 0:
         with pytest.raises(ValueError, match="log_spacing requires positive upper bound"):
-            BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing)
+            BinnedLogitCDF(
+                logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method
+            )
         return
     if log_spacing and num_bins % 2 != 0:
         with pytest.raises(ValueError, match="log_spacing requires even number of bins"):
-            BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing)
+            BinnedLogitCDF(
+                logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method
+            )
         return
-    dist = BinnedLogitCDF(logits, bound_low, bound_up, log_spacing=log_spacing)
+    dist = BinnedLogitCDF(
+        logits, bound_low, bound_up, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method
+    )

     # Test that tensors are on the correct device.
     assert dist.logits.device == device
@@ -162,6 +173,7 @@ def test_expand(
 @pytest.mark.parametrize("batch_size", [None, 1, 8])
 @pytest.mark.parametrize("num_bins", [2, 200])  # 2 is an edge case for log-spacing
 @pytest.mark.parametrize("log_spacing", [False, True], ids=["linear_spacing", "log_spacing"])
+@pytest.mark.parametrize("bin_normalization_method", ["sigmoid", "softmax"], ids=["sigmoid", "softmax"])
 @pytest.mark.parametrize(
     "use_cuda",
     [
@@ -173,6 +185,7 @@ def test_prob_random_logits(
     batch_size: int | None,
     num_bins: int,
     log_spacing: bool,
+    bin_normalization_method: Literal["sigmoid", "softmax"],
     use_cuda: bool,
 ):
     """Test probability evaluation with random logits at the bounds."""
@@ -181,7 +194,7 @@
     device = torch.device("cuda:0" if use_cuda else "cpu")
     logits = torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins)
     logits = logits.to(device)
-    dist = BinnedLogitCDF(logits, log_spacing=log_spacing)
+    dist = BinnedLogitCDF(logits, log_spacing=log_spacing, bin_normalization_method=bin_normalization_method)

     # Define expected shapes based on batch_size. The bins go into the sample shape.
     bin_centers = dist.bin_centers
@@ -210,6 +223,7 @@ def test_prob_random_logits(


 @pytest.mark.parametrize("logit_scale", [1e-3, 1, 1e3, 1e9])
+@pytest.mark.parametrize("bin_normalization_method", ["sigmoid", "softmax"], ids=["sigmoid", "softmax"])
 @pytest.mark.parametrize("batch_size", [None, 1, 8])
 @pytest.mark.parametrize(
     "use_cuda,plot",
@@ -222,6 +236,7 @@ def test_prob_random_logits(
 def test_cdf_random_logits(
     logit_scale: float,
     batch_size: int | None,
+    bin_normalization_method: Literal["sigmoid", "softmax"],
     use_cuda: bool,
     plot: bool,
     bound_low: float = -10,
@@ -233,7 +248,7 @@
     logits = logit_scale * torch.randn((num_bins,)) if batch_size is None else torch.randn(batch_size, num_bins)
     logits = logits.to(device)

-    dist = BinnedLogitCDF(logits, bound_low, bound_up)
+    dist = BinnedLogitCDF(logits, bound_low, bound_up, bin_normalization_method=bin_normalization_method)

     # Evaluate the CDF at the bounds.
     cdf_low = dist.cdf(torch.tensor(bound_low))
@@ -253,7 +268,10 @@
         plt.title(f"CDF for random logits scaled by {logit_scale}")
         plt.legend()
         plt.grid(True, alpha=0.3)
-        plt.savefig(f"tests/results/cdf_random_logits_scale-{logit_scale}.png", bbox_inches="tight")
+        plt.savefig(
+            f"tests/results/cdf_random_logits_scale-{logit_scale}_normalization-{bin_normalization_method}.png",
+            bbox_inches="tight",
+        )


 @pytest.mark.parametrize("batch_size", [None, 1, 8, 16])
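As a quick end-to-end illustration of the new keyword, here is a hypothetical usage sketch; it assumes only the constructor signature visible in this diff and an import path matching the file location, both of which may differ in the actual package.

import torch

from binned_cdf.binned_logit_cdf import BinnedLogitCDF  # assumed import path

# Batch of 8 distributions over 100 bins on [-10, 10].
logits = torch.randn(8, 100)
dist = BinnedLogitCDF(
    logits,
    bound_low=-10.0,
    bound_up=10.0,
    bin_normalization_method="softmax",  # new in this PR; defaults to "sigmoid"
)

bin_probs = dist.bin_probs  # shape (8, 100), each row sums to 1
cdf_at_zero = dist.cdf(torch.tensor(0.0))  # one CDF value per batch element
print(bin_probs.sum(dim=-1))  # all ones
print(cdf_at_zero)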