Skip to content

Added spline test for VBT, fixed a bug in levy_tree_transpose, and added references to VBT #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import cast, Literal, Optional, Union
from typing import cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -9,7 +9,7 @@
import jax.tree_util as jtu
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import levy_tree_transpose, LevyVal, RealScalarLike
from .._custom_types import _LA, levy_tree_transpose, LevyVal, RealScalarLike
from .._misc import (
default_floating_dtype,
force_bitcast_convert_type,
Expand Down Expand Up @@ -42,14 +42,14 @@ class UnsafeBrownianPath(AbstractBrownianPath):
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: Literal["", "space-time"] = eqx.field(static=True)
levy_area: _LA = eqx.field(static=True)
key: PRNGKeyArray

def __init__(
self,
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: PRNGKeyArray,
levy_area: Literal["", "space-time"] = "",
levy_area: _LA = "",
):
self.shape = (
jax.ShapeDtypeStruct(shape, default_floating_dtype())
Expand Down
77 changes: 63 additions & 14 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from .._custom_types import (
_LA,
BoolScalarLike,
IntScalarLike,
levy_tree_transpose,
Expand All @@ -36,21 +37,27 @@
# year={2021},
# school={University of Oxford},
# }
#

# We define
# H_{s,t} = 1/(t-s) ( \int_s^t ( W_u - (u-s)/(t-s) W_{s,t} ) du ).
# bhh_t = t * H_{0,t}
# For more details see Definition 4.2.1 and Theorem 6.1.4 of
#
# Foster, J. M. (2020). Numerical approximations for stochastic
# differential equations [PhD thesis]. University of Oxford.

# For more details see section 6.1 of
# @phdthesis{foster2020a,
# publisher = {University of Oxford},
# school = {University of Oxford},
# title = {Numerical approximations for stochastic differential equations},
# author = {Foster, James M.},
# year = {2020}
# }
# For more about space-time Levy area see Definition 4.2.1.
# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
# For the general interpolation rule for space-time Levy area see Theorem 6.1.4.

FloatDouble: TypeAlias = tuple[Float[Array, " *shape"], Float[Array, " *shape"]]
FloatTriple: TypeAlias = tuple[
Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"]
]
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]


class _State(eqx.Module):
Expand Down Expand Up @@ -130,17 +137,20 @@ class VirtualBrownianTree(AbstractBrownianPath):
t1: RealScalarLike
tol: RealScalarLike
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: Literal["", "space-time"] = eqx.field(static=True)
levy_area: _LA = eqx.field(static=True)
key: PyTree[PRNGKeyArray]
_spline: _Spline = eqx.field(static=True)

@eqxi.doc_remove_args("_spline")
def __init__(
self,
t0: RealScalarLike,
t1: RealScalarLike,
tol: RealScalarLike,
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: PRNGKeyArray,
levy_area: Literal["", "space-time"] = "",
levy_area: _LA = "",
_spline: _Spline = "sqrt",
):
(t0, t1) = eqx.error_if((t0, t1), t0 >= t1, "t0 must be strictly less than t1")
self.t0 = t0
Expand All @@ -154,6 +164,7 @@ def __init__(
f"levy_area must be one of '', 'space-time', but got {levy_area}."
)
self.levy_area = levy_area
self._spline = _spline
self.shape = (
jax.ShapeDtypeStruct(shape, default_floating_dtype())
if is_tuple_of_ints(shape)
Expand Down Expand Up @@ -321,20 +332,43 @@ def _body_fun(_state: _State):

# BM only case
if self.levy_area == "":
z = jr.normal(final_state.key, shape, dtype)
w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z
w_r = w_s + w_sr
w_mean = w_s + sr / su * w_su
if self._spline == "sqrt":
z = jr.normal(final_state.key, shape, dtype)
bb = jnp.sqrt(sr * ru / su) * z
elif self._spline == "quad":
z = jr.normal(final_state.key, shape, dtype)
bb = (sr * ru / su) * z
elif self._spline == "zero":
bb = jnp.zeros(shape, dtype)
else:
assert False
w_r = w_mean + bb
return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)

elif self.levy_area == "space-time":
# This is based on Theorem 6.1.4 of Foster's thesis (see above).

assert final_state.bhh_s_u_su is not None
bhh_s, bhh_u, bhh_su = final_state.bhh_s_u_su
sr3 = jnp.power(sr, 3)
ru3 = jnp.power(ru, 3)
su3 = jnp.power(su, 3)
key1, key2 = jr.split(final_state.key, 2)
x1 = jr.normal(key1, shape, dtype)
x2 = jr.normal(key2, shape, dtype)

# Here "quad" spline doesn't really exist, but we can still
# compare "sqrt" and "zero" splines.
if self._spline == "sqrt":
key1, key2 = jr.split(final_state.key, 2)
x1 = jr.normal(key1, shape, dtype)
x2 = jr.normal(key2, shape, dtype)
elif self._spline == "zero":
x1 = jnp.zeros(shape, dtype)
x2 = jnp.zeros(shape, dtype)
else:
raise ValueError(
f"When levy_area='space-time', only 'sqrt' and"
f" 'zero' splines are permitted, got {self._spline}."
)

sr_ru_half = jnp.sqrt(sr * ru)
d = jnp.sqrt(sr3 + ru3)
Expand Down Expand Up @@ -378,6 +412,21 @@ def _brownian_arch(
there for the sake of a future extension with "space-time-time" Levy area
and should be None for now.

??? cite "Reference"
Based on section 6.1 of
```bibtex
@phdthesis{foster2020a,
publisher = {University of Oxford},
school = {University of Oxford},
title = {Numerical approximations for stochastic differential equations},
author = {Foster, James M.},
year = {2020}
}
```

In particular see Theorem 6.1.6.


Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be formatted as:

??? cite "References"

    ```bibtex
    ...
    ```

i.e. with a new line, with an indent, and with only a single blank line between the reference section and the arguments section.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, not sure what happened with the indent there. Will fix

**Arguments:**

- `_state`: The state of the Brownian tree
Expand Down
29 changes: 8 additions & 21 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import typing
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Literal, Optional, TYPE_CHECKING, TypeAlias, Union
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TypeAlias should be imported from typing_extensions for now, as we don't yet require Python 3.10.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, I see. Didn't know it was new in 3.10.


import equinox as eqx
import equinox.internal as eqxi
import jax.numpy as jnp
import jax
import jax.tree_util as jtu
import numpy as np
from jaxtyping import AbstractDtype, Array, ArrayLike, Bool, Float, Int, PyTree, Shaped
Expand Down Expand Up @@ -52,6 +52,8 @@ class Real(AbstractDtype):
BufferDenseInfos = dict[str, PyTree[eqxi.MaybeBuffer[Shaped[Array, "times ..."]]]]
sentinel: Any = eqxi.doc_repr(object(), "sentinel")

_LA: TypeAlias = Literal["", "space-time"]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't start with an underscore if it's imported into another module. I'd suggest calling it just LevyArea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, yes



class LevyVal(eqx.Module):
dt: PyTree
Expand All @@ -62,7 +64,7 @@ class LevyVal(eqx.Module):
bar_K: Optional[PyTree]


def levy_tree_transpose(tree_shape, levy_area, tree):
def levy_tree_transpose(tree_shape, levy_area: _LA, tree: PyTree):
"""Helper that takes a PyTree of LevyVals and transposes
into a LevyVal of PyTrees.

Expand All @@ -77,26 +79,11 @@ def levy_tree_transpose(tree_shape, levy_area, tree):

A `LevyVal` of PyTrees.
"""
if levy_area == "space-time":
hh_default_val = jnp.zeros(())
kk_default_val = None
elif levy_area == "":
hh_default_val = None
kk_default_val = None
else:
assert False
inner_tree = jtu.tree_leaves(tree, is_leaf=lambda x: isinstance(x, LevyVal))[0]
inner_tree_shape = jax.tree_structure(inner_tree)
return jtu.tree_transpose(
outer_treedef=jtu.tree_structure(tree_shape),
inner_treedef=jtu.tree_structure(
LevyVal(
dt=0.0,
W=jnp.zeros(()),
H=hh_default_val,
bar_H=None,
K=kk_default_val,
bar_K=None,
)
),
inner_treedef=inner_tree_shape,
pytree_to_transpose=tree,
)

Expand Down
82 changes: 69 additions & 13 deletions test/test_brownian.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from typing import Literal, TypeAlias

import diffrax
import equinox as eqx
Expand All @@ -10,6 +11,8 @@
import scipy.stats as stats


_LA: TypeAlias = Literal["", "space-time"]
_Spline: TypeAlias = Literal["quad", "sqrt", "zero"]
_vals = {
int: [0, 2],
float: [0.0, 2.0],
Expand All @@ -29,7 +32,7 @@ def _make_struct(shape, dtype):
)
@pytest.mark.parametrize("levy_area", ["", "space-time"])
@pytest.mark.parametrize("use_levy", (False, True))
def test_shape_and_dtype(ctr, levy_area, use_levy, getkey):
def test_shape_and_dtype(ctr, levy_area: _LA, use_levy, getkey):
t0 = 0
t1 = 2

Expand Down Expand Up @@ -111,7 +114,7 @@ def is_tuple_of_ints(obj):
)
@pytest.mark.parametrize("levy_area", ["", "space-time"])
@pytest.mark.parametrize("use_levy", (False, True))
def test_statistics(ctr, levy_area, use_levy):
def test_statistics(ctr, levy_area: _LA, use_levy):
# Deterministic key for this test; not using getkey()
key = jr.PRNGKey(5678)
keys = jr.split(key, 10000)
Expand Down Expand Up @@ -146,33 +149,35 @@ def _eval(key):
assert pval > 0.1


@pytest.mark.parametrize("levy_area", ["", "space-time"])
@pytest.mark.parametrize("use_levy", (False, True))
def test_conditional_statistics(levy_area, use_levy):
def conditional_statistics(
levy_area: _LA, use_levy: bool, tol=2**-6, spacing=2**-6, spline: _Spline = "sqrt"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the default arguments here as I don't think they're ever used.

This is a general principle -- prefer not to use default arguments where possible, in particular in internal APIs, as they're a common source of surprising behaviour.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, thanks for the reminder. I guess old habits die hard.

):
key = jr.PRNGKey(5678)
bm_key, sample_key, permute_key = jr.split(key, 3)

# Get >80 randomly selected points; not too close to avoid discretisation error.
t0 = 0.3
t1 = 8.7
ts = jr.uniform(sample_key, shape=(100,), minval=t0, maxval=t1)
# ts = jnp.array([1.0, 3.0, 6.0, 7.0])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I tried to catch all the random debugging comments, seems this one escaped.

sorted_ts = jnp.sort(ts)
ts = []
prev_ti = sorted_ts[0]
ts = [prev_ti]
ts.append(prev_ti)
for ti in sorted_ts[1:]:
if ti < prev_ti + 2**-10:
if ti < prev_ti + spacing:
continue
prev_ti = ti
ts.append(ti)
ts = jnp.stack(ts)
assert len(ts) > 80
assert len(ts) > min(0.2 * (8.0 / spacing), 75) # for spacing = 2**-5, this is 51
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just provide a lower bound as an explicit argument to this function, rather than using a heuristic like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

ts = jr.permutation(permute_key, ts)

# Get some random paths
bm_keys = jr.split(bm_key, 10000)

path = jax.vmap(
lambda k: diffrax.VirtualBrownianTree(
t0=t0, t1=t1, shape=(), tol=2**-12, key=k, levy_area=levy_area
t0=t0, t1=t1, shape=(), tol=tol, key=k, levy_area=levy_area, _spline=spline
)
)(bm_keys)

Expand All @@ -183,6 +188,10 @@ def test_conditional_statistics(levy_area, use_levy):
out.append((ti, vals))
out = sorted(out, key=lambda x: x[0])

pvals_w1 = []
pvals_w2 = []
pvals_h = []

# Test their conditional statistics
for i in range(1, len(ts) - 1):
s, bm_s = out[i - 1]
Expand Down Expand Up @@ -211,7 +220,7 @@ def test_conditional_statistics(levy_area, use_levy):
_, pval_w1 = stats.kstest(normalised_w1, stats.norm.cdf)
# Raise if the failure is statistically significant at 10%, subject to
# multiple-testing correction.
assert pval_w1 > 0.1 / (len(ts) - 2)
pvals_w1.append(pval_w1)

if levy_area == "space-time" and use_levy:
assert h_s is not None
Expand All @@ -234,7 +243,7 @@ def test_conditional_statistics(levy_area, use_levy):
w_std2 = 2 * (a + b) / su
normalised_w2 = (w_r - w_mean2) / w_std2
_, pval_w2 = stats.kstest(normalised_w2, stats.norm.cdf)
assert pval_w2 > 0.1 / (len(ts) - 2)
pvals_w2.append(pval_w2)

# Check h_r|(w_s, w_u, h_s, h_u)
h_mean = (
Expand All @@ -247,7 +256,54 @@ def test_conditional_statistics(levy_area, use_levy):
h_std = math.sqrt(h_var)
normalised_hh = (h_r - h_mean) / h_std
_, pval_h = stats.kstest(normalised_hh, stats.norm.cdf)
assert pval_h > 0.1 / (len(ts) - 2)
pvals_h.append(pval_h)
return jnp.array(pvals_w1), jnp.array(pvals_w2), jnp.array(pvals_h)


@pytest.mark.parametrize("levy_area", ["", "space-time"])
@pytest.mark.parametrize("use_levy", (False, True))
def test_conditional_statistics(levy_area: _LA, use_levy):
    """Check the conditional statistics produced with the "sqrt" spline.

    Calls `conditional_statistics` and requires every returned p-value to
    exceed 0.1 divided by the number of p-values in that array (a
    multiple-testing correction).

    **Arguments:**

    - `levy_area`: which Levy area the Brownian tree is built with.
    - `use_levy`: forwarded to `conditional_statistics`.
    """
    pvals_w1, pvals_w2, pvals_h = conditional_statistics(
        levy_area, use_levy, tol=2**-8, spacing=2**-10, spline="sqrt"
    )
    assert jnp.all(pvals_w1 > 0.1 / pvals_w1.shape[0])
    if levy_area == "space-time" and use_levy:
        assert jnp.all(pvals_w2 > 0.1 / pvals_w2.shape[0])
        assert jnp.all(pvals_h > 0.1 / pvals_h.shape[0])
    else:
        # The Levy-area p-values are expected to be collected only in the
        # space-time case with `use_levy=True`; anywhere else a non-empty
        # array would mean the helper ran checks it should not have.
        assert len(pvals_w2) == 0
        assert len(pvals_h) == 0
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:
    assert len(pvals_w2) == 0
    assert len(pvals_h) == 0

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, might as well.



@pytest.mark.parametrize("levy_area", ["", "space-time"])
def test_spline(levy_area: _LA):
if levy_area == "":
splines = ("quad", "sqrt", "zero")
elif levy_area == "space-time":
splines = ("sqrt", "zero")
else:
assert False

for spline in splines:
pvals_w1, pvals_w2, pvals_h = conditional_statistics(
levy_area, use_levy=True, tol=2**-5, spacing=2**-6, spline=spline
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we parameterise this test by use_levy=False/True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be a bit pointless, but sure, it's not that expensive.

)

if spline == "sqrt":
# For the correct spline, make sure that all p-values are above
# 0.1 (subject to multiple-testing correction) and the average
# p-value is above 0.3.
def pred(pvals):
return jnp.min(pvals) > 0.1 / pvals.shape[0] and jnp.mean(pvals) > 0.3

else:
# make sure that for incorrect splines at least one p-value is
# below 0.01 (subject to multiple-testing correction) and the
# average p-value is below 0.2.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the average p-value should probably be a lot smaller than this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a subtle issue. Normally, yes, it is a lot smaller, except in one case: when levy_area="space-time", the pvals_w1 are actually quite good even with spline="zero", due to the randomness that w_r receives from h_su.

What I'm saying is that the variance that pval_w1 sees is the variance of the Brownian parabola, but without accounting for the conditioning on H, i.e.
$$
\mathbb{E}\left[ \left( \mathbb{E}[ W_{s,r} \mid W_{s,u}, H_{s,u} ] \right)^2 - \left( \mathbb{E}[ W_{s,r} \mid W_{s,u} ] \right)^2 \right]
$$
This is actually a very good approximation of the actual variance. In pval_w2 the influence of H is accounted for in the mean, which is subtracted, and hence the variance can only come from x1 (which is zero when spline="zero"). I hope this makes sense.

So what I will do is keep this more permissive upper bound for pvals_w1 when levy_area="space-time", and add a stricter upper bound in all the other cases.

def pred(pvals):
return jnp.min(pvals) < 0.01 / pvals.shape[0] and jnp.mean(pvals) < 0.2

assert pred(pvals_w1)
if levy_area == "space-time":
assert pred(pvals_w2)
assert pred(pvals_h)


def test_levy_area_reverse_time():
Expand Down