-
-
Notifications
You must be signed in to change notification settings - Fork 152
Added spline test for VBT, fixed a bug in levy_tree_transpose, and added references to VBT #342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
import typing | ||
from typing import Any, Optional, TYPE_CHECKING, Union | ||
from typing import Any, Literal, Optional, TYPE_CHECKING, TypeAlias, Union | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aha, I see. Didn't know it was new in 3.10. |
||
|
||
import equinox as eqx | ||
import equinox.internal as eqxi | ||
import jax.numpy as jnp | ||
import jax | ||
import jax.tree_util as jtu | ||
import numpy as np | ||
from jaxtyping import AbstractDtype, Array, ArrayLike, Bool, Float, Int, PyTree, Shaped | ||
|
@@ -52,6 +52,8 @@ class Real(AbstractDtype): | |
BufferDenseInfos = dict[str, PyTree[eqxi.MaybeBuffer[Shaped[Array, "times ..."]]]] | ||
sentinel: Any = eqxi.doc_repr(object(), "sentinel") | ||
|
||
_LA: TypeAlias = Literal["", "space-time"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn't start with an underscore if it's imported into another module. I'd suggest calling it just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, yes |
||
|
||
|
||
class LevyVal(eqx.Module): | ||
dt: PyTree | ||
|
@@ -62,7 +64,7 @@ class LevyVal(eqx.Module): | |
bar_K: Optional[PyTree] | ||
|
||
|
||
def levy_tree_transpose(tree_shape, levy_area, tree): | ||
def levy_tree_transpose(tree_shape, levy_area: _LA, tree: PyTree): | ||
"""Helper that takes a PyTree of LevyVals and transposes | ||
into a LevyVal of PyTrees. | ||
|
||
|
@@ -77,26 +79,11 @@ def levy_tree_transpose(tree_shape, levy_area, tree): | |
|
||
A `LevyVal` of PyTrees. | ||
""" | ||
if levy_area == "space-time": | ||
hh_default_val = jnp.zeros(()) | ||
kk_default_val = None | ||
elif levy_area == "": | ||
hh_default_val = None | ||
kk_default_val = None | ||
else: | ||
assert False | ||
inner_tree = jtu.tree_leaves(tree, is_leaf=lambda x: isinstance(x, LevyVal))[0] | ||
inner_tree_shape = jax.tree_structure(inner_tree) | ||
return jtu.tree_transpose( | ||
outer_treedef=jtu.tree_structure(tree_shape), | ||
inner_treedef=jtu.tree_structure( | ||
LevyVal( | ||
dt=0.0, | ||
W=jnp.zeros(()), | ||
H=hh_default_val, | ||
bar_H=None, | ||
K=kk_default_val, | ||
bar_K=None, | ||
) | ||
), | ||
inner_treedef=inner_tree_shape, | ||
pytree_to_transpose=tree, | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import math | ||
from typing import Literal, TypeAlias | ||
|
||
import diffrax | ||
import equinox as eqx | ||
|
@@ -10,6 +11,8 @@ | |
import scipy.stats as stats | ||
|
||
|
||
_LA: TypeAlias = Literal["", "space-time"] | ||
_Spline: TypeAlias = Literal["quad", "sqrt", "zero"] | ||
_vals = { | ||
int: [0, 2], | ||
float: [0.0, 2.0], | ||
|
@@ -29,7 +32,7 @@ def _make_struct(shape, dtype): | |
) | ||
@pytest.mark.parametrize("levy_area", ["", "space-time"]) | ||
@pytest.mark.parametrize("use_levy", (False, True)) | ||
def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): | ||
def test_shape_and_dtype(ctr, levy_area: _LA, use_levy, getkey): | ||
t0 = 0 | ||
t1 = 2 | ||
|
||
|
@@ -111,7 +114,7 @@ def is_tuple_of_ints(obj): | |
) | ||
@pytest.mark.parametrize("levy_area", ["", "space-time"]) | ||
@pytest.mark.parametrize("use_levy", (False, True)) | ||
def test_statistics(ctr, levy_area, use_levy): | ||
def test_statistics(ctr, levy_area: _LA, use_levy): | ||
# Deterministic key for this test; not using getkey() | ||
key = jr.PRNGKey(5678) | ||
keys = jr.split(key, 10000) | ||
|
@@ -146,33 +149,35 @@ def _eval(key): | |
assert pval > 0.1 | ||
|
||
|
||
@pytest.mark.parametrize("levy_area", ["", "space-time"]) | ||
@pytest.mark.parametrize("use_levy", (False, True)) | ||
def test_conditional_statistics(levy_area, use_levy): | ||
def conditional_statistics( | ||
levy_area: _LA, use_levy: bool, tol=2**-6, spacing=2**-6, spline: _Spline = "sqrt" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove the default arguments here as I don't think they're ever used. This is a general principle -- prefer not to use default arguments where possible, in particular in internal APIs, as they're a common source of surprising behaviour. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know, thanks for the reminder. I guess old habits die hard. |
||
): | ||
key = jr.PRNGKey(5678) | ||
bm_key, sample_key, permute_key = jr.split(key, 3) | ||
|
||
# Get >80 randomly selected points; not too close to avoid discretisation error. | ||
t0 = 0.3 | ||
t1 = 8.7 | ||
ts = jr.uniform(sample_key, shape=(100,), minval=t0, maxval=t1) | ||
# ts = jnp.array([1.0, 3.0, 6.0, 7.0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, I tried to catch all the random debugging comments, seems this one escaped. |
||
sorted_ts = jnp.sort(ts) | ||
ts = [] | ||
prev_ti = sorted_ts[0] | ||
ts = [prev_ti] | ||
ts.append(prev_ti) | ||
for ti in sorted_ts[1:]: | ||
if ti < prev_ti + 2**-10: | ||
if ti < prev_ti + spacing: | ||
continue | ||
prev_ti = ti | ||
ts.append(ti) | ||
ts = jnp.stack(ts) | ||
assert len(ts) > 80 | ||
assert len(ts) > min(0.2 * (8.0 / spacing), 75) # for spacing = 2**-5, this is 51 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe just provide a lower bound as an explicit argument to this function, rather than using a heuristic like this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. |
||
ts = jr.permutation(permute_key, ts) | ||
|
||
# Get some random paths | ||
bm_keys = jr.split(bm_key, 10000) | ||
|
||
path = jax.vmap( | ||
lambda k: diffrax.VirtualBrownianTree( | ||
t0=t0, t1=t1, shape=(), tol=2**-12, key=k, levy_area=levy_area | ||
t0=t0, t1=t1, shape=(), tol=tol, key=k, levy_area=levy_area, _spline=spline | ||
) | ||
)(bm_keys) | ||
|
||
|
@@ -183,6 +188,10 @@ def test_conditional_statistics(levy_area, use_levy): | |
out.append((ti, vals)) | ||
out = sorted(out, key=lambda x: x[0]) | ||
|
||
pvals_w1 = [] | ||
pvals_w2 = [] | ||
pvals_h = [] | ||
|
||
# Test their conditional statistics | ||
for i in range(1, len(ts) - 1): | ||
s, bm_s = out[i - 1] | ||
|
@@ -211,7 +220,7 @@ def test_conditional_statistics(levy_area, use_levy): | |
_, pval_w1 = stats.kstest(normalised_w1, stats.norm.cdf) | ||
# Raise if the failure is statistically significant at 10%, subject to | ||
# multiple-testing correction. | ||
assert pval_w1 > 0.1 / (len(ts) - 2) | ||
pvals_w1.append(pval_w1) | ||
|
||
if levy_area == "space-time" and use_levy: | ||
assert h_s is not None | ||
|
@@ -234,7 +243,7 @@ def test_conditional_statistics(levy_area, use_levy): | |
w_std2 = 2 * (a + b) / su | ||
normalised_w2 = (w_r - w_mean2) / w_std2 | ||
_, pval_w2 = stats.kstest(normalised_w2, stats.norm.cdf) | ||
assert pval_w2 > 0.1 / (len(ts) - 2) | ||
pvals_w2.append(pval_w2) | ||
|
||
# Check h_r|(w_s, w_u, h_s, h_u) | ||
h_mean = ( | ||
|
@@ -247,7 +256,54 @@ def test_conditional_statistics(levy_area, use_levy): | |
h_std = math.sqrt(h_var) | ||
normalised_hh = (h_r - h_mean) / h_std | ||
_, pval_h = stats.kstest(normalised_hh, stats.norm.cdf) | ||
assert pval_h > 0.1 / (len(ts) - 2) | ||
pvals_h.append(pval_h) | ||
return jnp.array(pvals_w1), jnp.array(pvals_w2), jnp.array(pvals_h) | ||
|
||
|
||
@pytest.mark.parametrize("levy_area", ["", "space-time"]) | ||
@pytest.mark.parametrize("use_levy", (False, True)) | ||
def test_conditional_statistics(levy_area: _LA, use_levy): | ||
pvals_w1, pvals_w2, pvals_h = conditional_statistics( | ||
levy_area, use_levy, tol=2**-8, spacing=2**-10, spline="sqrt" | ||
) | ||
assert jnp.all(pvals_w1 > 0.1 / pvals_w1.shape[0]) | ||
if levy_area == "space-time" and use_levy: | ||
assert jnp.all(pvals_w2 > 0.1 / pvals_w2.shape[0]) | ||
assert jnp.all(pvals_h > 0.1 / pvals_h.shape[0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. else:
assert len(pvals_w2) == 0
assert len(pvals_h) == 0 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, might as well. |
||
|
||
|
||
@pytest.mark.parametrize("levy_area", ["", "space-time"]) | ||
def test_spline(levy_area: _LA): | ||
if levy_area == "": | ||
splines = ("quad", "sqrt", "zero") | ||
elif levy_area == "space-time": | ||
splines = ("sqrt", "zero") | ||
else: | ||
assert False | ||
|
||
for spline in splines: | ||
pvals_w1, pvals_w2, pvals_h = conditional_statistics( | ||
levy_area, use_levy=True, tol=2**-5, spacing=2**-6, spline=spline | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we parameterise this test by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be a bit pointless, but sure, it's not that expensive. |
||
) | ||
|
||
if spline == "sqrt": | ||
# For the correct spline, make sure that all p-values are above | ||
# 0.1 (subject to multiple-testing correction) and the average | ||
# p-value is above 0.3. | ||
def pred(pvals): | ||
return jnp.min(pvals) > 0.1 / pvals.shape[0] and jnp.mean(pvals) > 0.3 | ||
|
||
else: | ||
# make sure that for incorrect splines at least one p-value is | ||
# below 0.01 (subject to multiple-testing correction) and the | ||
# average p-value is below 0.2. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the average p-value should probably be a lot smaller than this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit of a subtle issue. So normally yes, it is a lot smaller, exept in one case. When What I'm saying is, that the variance that So what I will do, is to just keep this more permissive upper bound to |
||
def pred(pvals): | ||
return jnp.min(pvals) < 0.01 / pvals.shape[0] and jnp.mean(pvals) < 0.2 | ||
|
||
assert pred(pvals_w1) | ||
if levy_area == "space-time": | ||
assert pred(pvals_w2) | ||
assert pred(pvals_h) | ||
|
||
|
||
def test_levy_area_reverse_time(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be formatted as:
i.e. with a new line, with an indent, and with only a single blank line between the reference section and the arguments section.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, not sure what happened with the indent there. Will fix