Add KDE bandwidth selectors using biased or unbiased cross-validation #2384

Open · wants to merge 17 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### New features
- Add optimized simultaneous ECDF confidence bands ([2368](https://github.com/arviz-devs/arviz/pull/2368))
- Add support for setting groups with `idata[group]` ([2374](https://github.com/arviz-devs/arviz/pull/2374))
- Add cross-validation-based KDE bandwidth selection methods ([2384](https://github.com/arviz-devs/arviz/pull/2384))

### Maintenance and fixes

4 changes: 2 additions & 2 deletions arviz/plots/kdeplot.py
@@ -58,8 +58,8 @@ def plot_kde(
bw : float or str, optional
If numeric, indicates the bandwidth and must be positive.
If str, indicates the method to estimate the bandwidth and must be
one of "scott", "silverman", "isj" or "experimental" when ``is_circular`` is False
and "taylor" (for now) when ``is_circular`` is True.
one of "scott", "silverman", "isj", "experimental", "ucv", or "bcv" when ``is_circular`` is
False and "taylor" (for now) when ``is_circular`` is True.
Defaults to "default" which means "experimental" when variable is not circular
and "taylor" when it is.
adaptive : bool, default False
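With this change the selectors are reachable from the plotting API through the existing `bw` argument. A minimal usage sketch (the bimodal sample and figure handling are illustrative, assuming this branch is installed):

```python
import matplotlib.pyplot as plt
import numpy as np

import arviz as az

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 1, 500), rng.normal(10, 1, 500)])

# "ucv"/"bcv" are only valid for non-circular data, like the other linear methods
az.plot_kde(x, bw="ucv")
plt.show()
```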
135 changes: 122 additions & 13 deletions arviz/stats/density_utils.py
@@ -4,8 +4,8 @@

import numpy as np
from scipy.fftpack import fft
-from scipy.optimize import brentq
-from scipy.signal import convolve, convolve2d
from scipy.optimize import brentq, minimize_scalar
from scipy.signal import convolve, convolve2d, correlate
from scipy.signal.windows import gaussian
from scipy.sparse import coo_matrix
from scipy.special import ive # pylint: disable=no-name-in-module
@@ -34,7 +34,102 @@ def _bw_silverman(x, x_std=None, **kwargs): # pylint: disable=unused-argument
return bw


-def _bw_isj(x, grid_counts=None, x_std=None, x_range=None):
def _bw_oversmoothed(x, x_std=None, **kwargs): # pylint: disable=unused-argument
"""Oversmoothed bandwidth estimation."""
if x_std is None:
x_std = np.std(x)
bw = 1.144 * x_std * len(x) ** (-0.2)
return bw


def _bw_cv(x, unbiased=True, bin_width=None, grid_counts=None, x_std=None, **kwargs): # pylint: disable=unused-argument
"""Cross-validation bandwidth estimation."""
if x_std is None:
x_std = np.std(x)

if bin_width is None or grid_counts is None:
x_min = x.min()
x_max = x.max()
grid_len = 256
grid_min = x_min - 0.5 * x_std
grid_max = x_max + 0.5 * x_std
grid_counts, _, grid_edges = histogram(x, grid_len, (grid_min, grid_max))
bin_width = grid_edges[1] - grid_edges[0]

x_len = len(x)
grid_counts_comb, ks = _prepare_cv_score_inputs(grid_counts, x_len)

bw_max = _bw_oversmoothed(x, x_std=x_std)
bw_min = bin_width / (2 * np.pi)
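    # Assumed rationale (not stated in this diff): the oversmoothed bandwidth is an
    # upper bound on sensible bandwidths, while the lower bracket ties the narrowest
    # kernel to the grid spacing so the binned score remains a faithful proxy for
    # the exact pairwise score.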

def _compute_score(bw):
return _compute_cv_score(bw, x_len, bin_width, unbiased, grid_counts_comb, ks)

result = minimize_scalar(_compute_score, bounds=(bw_min, bw_max), method="bounded")
if not result.success:
warnings.warn("Optimizing the bandwidth using cross-validation did not converge.")
bw_opt = result.x

return bw_opt


def _prepare_cv_score_inputs(grid_counts, x_len):
grid_len = len(grid_counts)
# entry j is the sum over i of grid_counts[i] * grid_counts[i + j]
    grid_counts_comb = correlate(grid_counts, grid_counts, mode="full")[-grid_len:]
    # correct the lag-0 entry: remove the n self-pairs, then halve because each
    # within-bin pair is counted twice by the autocorrelation
    grid_counts_comb[0] = 0.5 * (grid_counts_comb[0] - x_len)
ks = np.arange(0, grid_len)
return grid_counts_comb, ks
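The pair-count identity behind `_prepare_cv_score_inputs` can be checked directly against a brute-force count. A minimal sketch (assumes this branch is importable; the padded histogram range mirrors the grid that `_bw_cv` builds):

```python
import numpy as np
from scipy.special import comb

from arviz.stats.density_utils import _prepare_cv_score_inputs

rng = np.random.default_rng(1)
x = rng.normal(size=500)
# pad the range so the outermost bins are empty, as in _bw_cv's grid
grid_counts, _ = np.histogram(x, bins=64, range=(x.min() - 1.0, x.max() + 1.0))

grid_counts_comb, ks = _prepare_cv_score_inputs(grid_counts, len(x))

# entry 0: unordered pairs sharing a bin; entry j: pairs exactly j bins apart
expected = np.array(
    [
        comb(grid_counts, 2).sum()
        if j == 0
        else (grid_counts[:-j] * grid_counts[j:]).sum()
        for j in ks
    ]
)
assert np.allclose(grid_counts_comb, expected)
```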


def _compute_cv_score(bw, x_len, bin_width, unbiased, grid_counts_comb, ks): # pylint: disable=too-many-positional-arguments
deltas = ks * (bin_width / bw)
if unbiased:
summand = np.exp(-0.25 * deltas**2) - np.sqrt(8) * np.exp(-0.5 * deltas**2)
else:
summand = (deltas**4 - 12 * deltas**2 + 12) * np.exp(-0.25 * deltas**2) / 64
score = (0.5 + np.inner(grid_counts_comb, summand) / x_len) / (
x_len * bw * np.sqrt(np.pi)
)
return score
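For reference, the scores being minimized (read off this implementation and the explicit version in the new tests; Gaussian kernel, `n` points, bandwidth `h`):

```latex
\mathrm{UCV}(h) = \frac{1}{2nh\sqrt{\pi}}
  + \frac{1}{n^{2}h\sqrt{\pi}} \sum_{i<j}
    \left( e^{-\Delta_{ij}^{2}/4} - \sqrt{8}\, e^{-\Delta_{ij}^{2}/2} \right),
\qquad \Delta_{ij} = \frac{x_i - x_j}{h}

\mathrm{BCV}(h) = \frac{1}{2nh\sqrt{\pi}}
  + \frac{1}{64\, n^{2}h\sqrt{\pi}} \sum_{i<j}
    \left( \Delta_{ij}^{4} - 12\Delta_{ij}^{2} + 12 \right) e^{-\Delta_{ij}^{2}/4}
```

The binned variant replaces the O(n²) pairwise sum with a sum over bin separations weighted by `grid_counts_comb`.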


def _bw_ucv(x, **kwargs):
"""Unbiased cross-validation bandwidth estimation.

    This method selects the bandwidth that minimizes an unbiased estimator of the mean
    integrated squared error of the kernel density estimate, as explained in [1]_. This
    implementation has been modified to operate on binned data, which is more efficient.

References
----------
.. [1] Multivariate Density Estimation: Theory, Practice, and Visualization.
D. Scott.
Wiley, 2015.
Section 6.5.1.3
"""
return _bw_cv(x, unbiased=True, **kwargs)


def _bw_bcv(x, **kwargs):
"""Biased cross-validation bandwidth estimation.

    This method selects the bandwidth that minimizes an estimator of the asymptotic mean
    integrated squared error of the kernel density estimate, as explained in [1]_. This
    implementation has been modified to operate on binned data, which is more efficient.

References
----------
.. [1] Multivariate Density Estimation: Theory, Practice, and Visualization.
D. Scott.
Wiley, 2015.
Section 6.5.1.3
"""
return _bw_cv(x, unbiased=False, **kwargs)
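As a quick smoke test of the selectors (a sketch, not part of the diff; exact values depend on the draw), all three rules should land in the same ballpark for a normal sample:

```python
import numpy as np

from arviz.stats.density_utils import _bw_bcv, _bw_scott, _bw_ucv

rng = np.random.default_rng(0)
x = rng.normal(size=2_000)

# test_bw_cv_normal below checks agreement with Scott's rule to within 20%
print("scott:", _bw_scott(x))
print("ucv:  ", _bw_ucv(x))
print("bcv:  ", _bw_bcv(x))
```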


def _bw_isj(x, grid_counts=None, x_std=None, x_range=None, **kwargs): # pylint: disable=unused-argument
"""Improved Sheather-Jones bandwidth estimation.

Improved Sheather and Jones method as explained in [1]_. This method is used internally by the
@@ -76,7 +171,7 @@ def _bw_isj(x, grid_counts=None, x_std=None, x_range=None):
return h


-def _bw_experimental(x, grid_counts=None, x_std=None, x_range=None):
def _bw_experimental(x, grid_counts=None, x_std=None, x_range=None, **kwargs): # pylint: disable=unused-argument
"""Experimental bandwidth estimator."""
bw_silverman = _bw_silverman(x, x_std=x_std)
bw_isj = _bw_isj(x, grid_counts=grid_counts, x_range=x_range)
@@ -111,10 +206,12 @@ def _bw_taylor(x):
"silverman": _bw_silverman,
"isj": _bw_isj,
"experimental": _bw_experimental,
"ucv": _bw_ucv,
"bcv": _bw_bcv,
}


-def _get_bw(x, bw, grid_counts=None, x_std=None, x_range=None):
def _get_bw(x, bw, grid_counts=None, bin_width=None, x_std=None, x_range=None): # pylint: disable=too-many-positional-arguments
> **@aloctavodia** (Contributor) commented on Sep 23, 2024:
> maybe we can disable "too-many..." globally. It's popping up in many places.

"""Compute bandwidth for a given data `x` and `bw`.

Also checks `bw` is correctly specified.
@@ -155,7 +252,7 @@ def _get_bw(x, bw, grid_counts=None, x_std=None, x_range=None):
)

bw_fun = _BW_METHODS_LINEAR[bw_lower]
-        bw = bw_fun(x, grid_counts=grid_counts, x_std=x_std, x_range=x_range)
bw = bw_fun(x, grid_counts=grid_counts, bin_width=bin_width, x_std=x_std, x_range=x_range)
else:
raise ValueError(
"Unrecognized `bw` argument.\n"
@@ -321,7 +418,7 @@ def _check_custom_lims(custom_lims, x_min, x_max):

def _get_grid(
x_min, x_max, x_std, extend_fct, grid_len, custom_lims, extend=True, bound_correction=False
-):
): # pylint: disable=too-many-positional-arguments
"""Compute the grid that bins the data used to estimate the density function.

Parameters
@@ -450,6 +547,17 @@ def kde(x, circular=False, **kwargs):
>>> grid, pdf = kde(rvs, bound_correction=False, custom_lims=(0, 11))
>>> plt.plot(grid, pdf)

Density estimation for well-separated modes with bandwidth chosen using unbiased
cross-validation

.. plot::
:context: close-figs

>>> rvs = np.concatenate([np.random.normal(0, 1, 500), np.random.normal(30, 1, 500)])
>>> grid, pdf = kde(rvs, bw='ucv')
>>> plt.plot(grid, pdf)


Default density estimation for circular data

.. plot::
@@ -499,7 +607,7 @@
return kde_fun(x, **kwargs)


-def _kde_linear(
def _kde_linear( # pylint: disable=too-many-positional-arguments
x,
bw="experimental",
adaptive=False,
@@ -525,7 +633,7 @@ def _kde_linear(
bw: int, float or str, optional
If numeric, indicates the bandwidth and must be positive.
If str, indicates the method to estimate the bandwidth and must be one of "scott",
"silverman", "isj" or "experimental". Defaults to "experimental".
"silverman", "isj", "experimental", "ucv", or "bcv". Defaults to "experimental".
adaptive: boolean, optional
Indicates if the bandwidth is adaptive or not.
It is the recommended approach when there are multiple modes with different spread.
@@ -580,9 +688,10 @@ def _kde_linear(
x_min, x_max, x_std, extend_fct, grid_len, custom_lims, extend, bound_correction
)
grid_counts, _, grid_edges = histogram(x, grid_len, (grid_min, grid_max))
bin_width = grid_edges[1] - grid_edges[0]

# Bandwidth estimation
-    bw = bw_fct * _get_bw(x, bw, grid_counts, x_std, x_range)
bw = bw_fct * _get_bw(x, bw, grid_counts, bin_width, x_std, x_range)

# Density estimation
if adaptive:
@@ -599,7 +708,7 @@ def _kde_linear(
return grid, pdf


-def _kde_circular(
def _kde_circular( # pylint: disable=too-many-positional-arguments
x,
bw="taylor",
bw_fct=1,
@@ -689,7 +798,7 @@ def _kde_circular(


# pylint: disable=unused-argument
-def _kde_convolution(x, bw, grid_edges, grid_counts, grid_len, bound_correction, **kwargs):
def _kde_convolution(x, bw, grid_edges, grid_counts, grid_len, bound_correction, **kwargs): # pylint: disable=too-many-positional-arguments
"""Kernel density with convolution.

One dimensional Gaussian kernel density estimation via convolution of the binned relative
@@ -723,7 +832,7 @@ def _kde_convolution(x, bw, grid_edges, grid_counts, grid_len, bound_correction,
return grid, pdf


-def _kde_adaptive(x, bw, grid_edges, grid_counts, grid_len, bound_correction, **kwargs):
def _kde_adaptive(x, bw, grid_edges, grid_counts, grid_len, bound_correction, **kwargs): # pylint: disable=too-many-positional-arguments
"""Compute Adaptive Kernel Density Estimation.

One dimensional adaptive Gaussian kernel density estimation. The implementation uses the binning
78 changes: 78 additions & 0 deletions arviz/tests/base_tests/test_stats_density_utils.py
@@ -0,0 +1,78 @@
import numpy as np
import pytest

from ...data import load_arviz_data
from ...stats.density_utils import (
_prepare_cv_score_inputs,
_compute_cv_score,
_bw_cv,
_bw_oversmoothed,
_bw_scott,
histogram,
)


def compute_cv_score_explicit(bw, x, unbiased):
"""Explicit computation of the CV score for a 1D dataset."""
n = len(x)
score = 0.0
for i in range(n):
for j in range(i + 1, n):
delta = (x[i] - x[j]) / bw
if unbiased:
score += np.exp(-0.25 * delta**2) - np.sqrt(8) * np.exp(-0.5 * delta**2)
else:
score += (delta**4 - 12 * delta**2 + 12) * np.exp(-0.25 * delta**2)
if not unbiased:
score /= 64
score = 0.5 / n / bw / np.sqrt(np.pi) + score / n**2 / bw / np.sqrt(np.pi)
return score


def test_histogram():
school = load_arviz_data("non_centered_eight").posterior["mu"].values
k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]))
k_dens_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=True)
k_count_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=False)
assert np.allclose(k_count_az, k_count_np)
assert np.allclose(k_dens_az, k_dens_np)


@pytest.mark.parametrize("unbiased", [True, False])
@pytest.mark.parametrize("bw", [0.1, 0.5, 2.0])
@pytest.mark.parametrize("n", [100, 1_000])
def test_compute_cv_score(bw, unbiased, n, seed=42):
"""Test that the histogram-based CV score matches the explicit CV score."""
rng = np.random.default_rng(seed)
x = rng.normal(size=n)
x_std = x.std()
grid_counts, grid_edges = np.histogram(
x, bins=100, range=(x.min() - 0.5 * x_std, x.max() + 0.5 * x_std)
)
bin_width = grid_edges[1] - grid_edges[0]
grid = grid_edges[:-1] + 0.5 * bin_width

# if data is discretized to regularly-spaced bins, then explicit CV score should match
# the histogram-based CV score
x_discrete = np.repeat(grid, grid_counts)
rng.shuffle(x_discrete)
score_inputs = _prepare_cv_score_inputs(grid_counts, n)
score = _compute_cv_score(bw, n, bin_width, unbiased, *score_inputs)
score_explicit = compute_cv_score_explicit(bw, x_discrete, unbiased)
assert np.isclose(score, score_explicit)


@pytest.mark.parametrize("unbiased", [True, False])
def test_bw_cv_normal(unbiased, seed=42, bins=512, n=100_000):
"""Test that for normal target, selected CV bandwidth converges to known optimum."""
rng = np.random.default_rng(seed)
x = rng.normal(size=n)
x_std = x.std()
grid_counts, grid_edges = np.histogram(
x, bins=bins, range=(x.min() - 0.5 * x_std, x.max() + 0.5 * x_std)
)
bin_width = grid_edges[1] - grid_edges[0]
bw = _bw_cv(x, unbiased=unbiased, bin_width=bin_width, grid_counts=grid_counts)
assert bw > bin_width / (2 * np.pi)
assert bw < _bw_oversmoothed(x)
assert np.isclose(bw, _bw_scott(x), rtol=0.2)
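For context on the 20% tolerance: for a standard normal target with a Gaussian kernel, the AMISE-optimal bandwidth has a closed form that Scott's factor approximates, so the cross-validated bandwidth converging to it within `rtol=0.2` is a reasonable check:

```latex
h_{\mathrm{AMISE}} = \left( \frac{4}{3n} \right)^{1/5} \sigma \approx 1.06\, \sigma\, n^{-1/5}
```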
10 changes: 0 additions & 10 deletions arviz/tests/base_tests/test_stats_utils.py
@@ -8,7 +8,6 @@
from scipy.stats import circstd

from ...data import from_dict, load_arviz_data
-from ...stats.density_utils import histogram
from ...stats.stats_utils import (
ELPDData,
_angle,
@@ -343,15 +342,6 @@ def test_variance_bad_data():
assert not np.allclose(stats_variance_2d(data), np.var(data, ddof=1))


-def test_histogram():
-    school = load_arviz_data("non_centered_eight").posterior["mu"].values
-    k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]))
-    k_dens_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=True)
-    k_count_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=False)
-    assert np.allclose(k_count_az, k_count_np)
-    assert np.allclose(k_dens_az, k_dens_np)


def test_sqrt():
x = np.random.rand(100)
y = np.random.rand(100)