Skip to content

Commit

Permalink
DOC: Add cubic_spline to docs
Browse files Browse the repository at this point in the history
Add to docs
Fix extrapolation
Add to change log
Complete coverage
  • Loading branch information
bashtage committed Dec 2, 2024
1 parent 33ca821 commit 5ad860c
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 63 deletions.
8 changes: 8 additions & 0 deletions docsite/docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ For changes since the latest tagged release, please refer to the

---

## 1.1.0 (Unreleased)

** Enhancements **

* Added cubic spline support for cyclic (`cc`) and natural (`cr`). See
`formulaic.materializers.transforms.cubic_spline.cubic_spline` for
more details.

## 1.0.2 (12 July 2024)

**Bugfixes and cleanups:**
Expand Down
157 changes: 143 additions & 14 deletions docsite/docs/guides/splines.ipynb

Large diffs are not rendered by default.

78 changes: 45 additions & 33 deletions formulaic/transforms/cubic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from __future__ import annotations

from functools import partial
from typing import Any, Iterable, cast
from typing import Iterable, cast

import numpy
import pandas
Expand All @@ -47,13 +47,6 @@ class ExtrapolationError(ValueError):
pass


def safe_string_eq(obj: Any, value: str) -> bool:
if isinstance(obj, str):
return obj == value
else:
return False


def _find_knots_lower_bounds(x: numpy.ndarray, knots: numpy.ndarray) -> numpy.ndarray:
"""Finds knots lower bounds for given values.
Expand Down Expand Up @@ -456,6 +449,15 @@ def cubic_spline( # pylint: disable=dangerous-default-value # always replaced
not specified this is determined from `x`.
upper_bound: The upper bound for the domain for the B-Spline basis. If
not specified this is determined from `x`.
constraints: Either a 2-d array defining general linear constraints
(that is ``np.dot(constraints, betas)`` is zero, where ``betas`` denotes
the array of *initial* parameters, corresponding to the *initial*
unconstrained design matrix), or the string ``'center'`` indicating
that we should apply a centering constraint (this constraint
will be computed from the input data, remembered and re-used for
prediction from the fitted model). The constraints are absorbed
in the resulting design matrix which means that the model is
actually rewritten in terms of *unconstrained* parameters.
extrapolation: Selects how extrapolation should be performed when values
in `x` extend beyond the lower and upper bounds. Valid values are:
- 'raise': Raises a `ValueError` if there are any values in `x`
Expand Down Expand Up @@ -513,44 +515,54 @@ def cubic_spline( # pylint: disable=dangerous-default-value # always replaced
else:
_state["cyclic"] = cyclic

if "extrapolation" in _state:
extrapolation = SplineExtrapolation(_state["extrapolation"])
else:
extrapolation = SplineExtrapolation(extrapolation)
_state["extrapolation"] = extrapolation.value
extrapolation = SplineExtrapolation(extrapolation)

# Check extrapolations and adjust x if necessary
# SplineExtrapolation.EXTEND is the natural default, so no need to do anything
knots_x = x
below_lower = x < lower_bound
above_upper = x > upper_bound
if extrapolation is SplineExtrapolation.RAISE and numpy.any(
(x < lower_bound) | (x > upper_bound)
below_lower | above_upper
):
raise ExtrapolationError(
"Some field values extend beyond upper and/or lower bounds, which can "
"result in ill-conditioned bases. Pass a value for `extrapolation` to "
"control how extrapolation should be performed."
)
elif extrapolation is SplineExtrapolation.CLIP:
x = numpy.clip(x, lower_bound, upper_bound)
elif extrapolation in (SplineExtrapolation.NA, SplineExtrapolation.ZERO):
fill_value = numpy.nan if extrapolation is SplineExtrapolation.NA else 0.0
x = numpy.where((x >= lower_bound) & (x <= upper_bound), x, fill_value)
elif extrapolation in (
SplineExtrapolation.CLIP,
SplineExtrapolation.ZERO,
SplineExtrapolation.NA,
):
out_of_bounds = below_lower | above_upper
if "knots" not in _state and numpy.any(out_of_bounds):
knots_x = x[~out_of_bounds]
if extrapolation is SplineExtrapolation.CLIP:
x = numpy.clip(x, lower_bound, upper_bound)
elif extrapolation is SplineExtrapolation.NA:
x = numpy.where(~out_of_bounds, x, numpy.nan)

# Prepare knots
if "knots" not in _state:
if df is None and knots is None:
raise ValueError("Must specify either 'df' or 'knots'.")

n_constraints = 0
if constraints is not None:
if safe_string_eq(constraints, "center"):
# Here we collect only number of constraints,
# actual centering constraint will be computed after all_knots
n_constraints = 1
else:
constraints_arr = numpy.atleast_2d(constraints)
if constraints_arr.ndim != 2:
raise ValueError("Constraints must be 2-d array or 1-d vector.")
n_constraints = constraints_arr.shape[0]
centering_constraint = isinstance(constraints, str) and constraints == "center"
if centering_constraint:
# Here we collect only number of constraints,
# actual centering constraint will be computed after all_knots
n_constraints = 1
elif isinstance(constraints, str):
raise ValueError(
"Constraints must be 'center' when not passed as an array."
)
elif constraints is not None:
constraints_arr = numpy.atleast_2d(constraints)
if constraints_arr.ndim != 2:
raise ValueError("Constraints must be 2-d array or 1-d vector.")
n_constraints = constraints_arr.shape[0]

n_inner_knots = None
if df is not None:
Expand All @@ -566,19 +578,18 @@ def cubic_spline( # pylint: disable=dangerous-default-value # always replaced
n_inner_knots += 1
_knots = numpy.array(knots) if knots is not None else None
all_knots = _get_all_sorted_knots(
x,
knots_x,
lower_bound,
upper_bound,
n_inner_knots=n_inner_knots,
inner_knots=_knots,
)
if constraints is not None:
if safe_string_eq(constraints, "center"):
if centering_constraint:
# Now we can compute centering constraints
constraints_arr = _get_centering_constraint_from_matrix(
_get_free_cubic_spline_matrix(x, all_knots, cyclic=cyclic)
)

df_before_constraints = all_knots.size
if cyclic:
df_before_constraints -= 1
Expand All @@ -594,7 +605,8 @@ def cubic_spline( # pylint: disable=dangerous-default-value # always replaced

# Compute cubic splines
cs_mat = _get_cubic_spline_matrix(x, knots, constraints, cyclic=cyclic)

if extrapolation is SplineExtrapolation.ZERO:
cs_mat[below_lower | above_upper] = 0.0
return FactorValues(
{i + 1: cs_mat[:, i] for i in range(cs_mat.shape[1])},
kind="numerical",
Expand Down
44 changes: 28 additions & 16 deletions tests/transforms/test_cubic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def test_crs_errors():
# Too small 'df' for cyclic cubic spline
with pytest.raises(ValueError):
cubic_spline(numpy.arange(50), df=0, _state={})
with pytest.raises(ValueError, match="Constraints must be"):
cubic_spline(numpy.linspace(0, 1, 200), df=3, constraints="unknown", _state={})


def test_crs_with_specific_constraint():
Expand Down Expand Up @@ -252,10 +254,10 @@ def test_crs_compat_with_r(test_data):
out_stateful = func(
cubic_spline_test_x,
df=df,
knots=knots,
lower_bound=lower_bound,
upper_bound=upper_bound,
constraints=constraints,
_state=state,
)
out_stateful_arr = numpy.column_stack(list(out_stateful.values()))
numpy.testing.assert_allclose(out_stateful_arr, out_arr, atol=1e-10)
Expand All @@ -273,12 +275,9 @@ def test_statefulness():
"upper_bound": 0.9,
"constraints": None,
"cyclic": True,
"extrapolation": "raise",
}
# Test separately to avoid exact float comparison
numpy.testing.assert_allclose(knots, [0.1, 0.3, 0.5, 0.7, 0.9])
with pytest.raises(ExtrapolationError):
cubic_spline([-0.1, 1.1], _state=state)


def test_cubic_spline_edges():
Expand Down Expand Up @@ -313,18 +312,6 @@ def test_alternative_extrapolation():
assert not numpy.allclose(extrap[1], res[1])
assert not numpy.allclose(extrap[2], res[2])

res = cubic_spline(
data, df=2, extrapolation="zero", lower_bound=-5.5, upper_bound=5.5, _state={}
)
data_zeroed = numpy.where((data > -5.5) & (data < 5.5), data, 0.0)
direct_res = cubic_spline(
data_zeroed, df=2, lower_bound=-5.5, upper_bound=5.5, _state={}
)
numpy.testing.assert_allclose(res[1], direct_res[1])
numpy.testing.assert_allclose(res[2], direct_res[2])
assert not numpy.allclose(extrap[1], res[1])
assert not numpy.allclose(extrap[2], res[2])

res = cubic_spline(
data, df=2, extrapolation="na", lower_bound=-5.0, upper_bound=5.0, _state={}
)
Expand All @@ -344,3 +331,28 @@ def test_alternative_extrapolation():
upper_bound=5.5,
_state={},
)

lower_bound = -5.5
upper_bound = 5.5
in_bounds = (data >= lower_bound) & (data <= upper_bound)
valid_data = data[in_bounds]
state = {}
res = cubic_spline(
valid_data,
df=2,
extrapolation="zero",
lower_bound=-5.5,
upper_bound=5.5,
_state=state,
)
re_res = cubic_spline(data, extrapolation="zero", _state=state)
for i in res:
numpy.testing.assert_allclose(res[i], re_res[i][in_bounds])
numpy.testing.assert_allclose(
re_res[i][~in_bounds], numpy.zeros((~in_bounds).sum())
)
res2 = cubic_spline(
data, df=2, extrapolation="zero", lower_bound=-5.5, upper_bound=5.5, _state={}
)
for i in res:
numpy.testing.assert_allclose(res2[i], re_res[i])

0 comments on commit 5ad860c

Please sign in to comment.