Added spline test for VBT, fixed a bug in levy_tree_transpose, and added references to VBT #342
…ded references in VBT
diffrax/_brownian/tree.py
Outdated
??? cite "Reference"
Based on section 6.1 of
```bibtex
@phdthesis{foster2020a,
    publisher = {University of Oxford},
    school = {University of Oxford},
    title = {Numerical approximations for stochastic differential equations},
    author = {Foster, James M.},
    year = {2020}
}

In particular see Theorem 6.1.6.
```
This should be formatted as:
??? cite "References"

    ```bibtex
    ...
    ```
i.e. with a new line, with an indent, and with only a single blank line between the reference section and the arguments section.
Oops, not sure what happened with the indent there. Will fix
@@ -1,9 +1,9 @@
import typing
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Literal, Optional, TYPE_CHECKING, TypeAlias, Union
`TypeAlias` should be imported from `typing_extensions` for now, as we don't yet require Python 3.10.
Aha, I see. Didn't know it was new in 3.10.
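One common way to handle this is a version-gated import. The sketch below is illustrative only and is not taken from the PR: it assumes `typing_extensions` is installed on Python versions before 3.10, and the helper `describe` is a hypothetical function added purely to show the alias in use.

```python
import sys
from typing import Literal, get_args

# TypeAlias only exists in the standard library from Python 3.10 onwards;
# on older interpreters we fall back to the typing_extensions backport.
if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias

# Public alias (no leading underscore), so other modules can import it.
LevyArea: TypeAlias = Literal["", "space-time"]


def describe(levy_area: LevyArea) -> str:
    """Hypothetical helper: validate the literal at runtime and describe it."""
    if levy_area not in get_args(LevyArea):
        raise ValueError(f"Unknown levy_area: {levy_area!r}")
    return "no Levy area" if levy_area == "" else f"{levy_area} Levy area"
```

Importing unconditionally from `typing_extensions`, as suggested in the review, is the simpler option when that package is already a dependency.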
diffrax/_custom_types.py
Outdated
@@ -52,6 +52,8 @@ class Real(AbstractDtype):
BufferDenseInfos = dict[str, PyTree[eqxi.MaybeBuffer[Shaped[Array, "times ..."]]]]
sentinel: Any = eqxi.doc_repr(object(), "sentinel")

_LA: TypeAlias = Literal["", "space-time"]
This shouldn't start with an underscore if it's imported into another module. I'd suggest calling it just `LevyArea`.
Good point, yes
test/test_brownian.py
Outdated
@pytest.mark.parametrize("use_levy", (False, True))
def test_conditional_statistics(levy_area, use_levy):
def conditional_statistics(
    levy_area: _LA, use_levy: bool, tol=2**-6, spacing=2**-6, spline: _Spline = "sqrt"
Let's remove the default arguments here as I don't think they're ever used.
This is a general principle -- prefer not to use default arguments where possible, in particular in internal APIs, as they're a common source of surprising behaviour.
I know, thanks for the reminder. I guess old habits die hard.
test/test_brownian.py
Outdated
# Get >80 randomly selected points; not too close to avoid discretisation error.
t0 = 0.3
t1 = 8.7
ts = jr.uniform(sample_key, shape=(100,), minval=t0, maxval=t1)
# ts = jnp.array([1.0, 3.0, 6.0, 7.0])
?
Oops, I tried to catch all the random debugging comments, seems this one escaped.
assert jnp.all(pvals_w1 > 0.1 / pvals_w1.shape[0])
if levy_area == "space-time" and use_levy:
    assert jnp.all(pvals_w2 > 0.1 / pvals_w2.shape[0])
    assert jnp.all(pvals_h > 0.1 / pvals_h.shape[0])
else:
    assert len(pvals_w2) == 0
    assert len(pvals_h) == 0
?
Yeah, might as well.
test/test_brownian.py
Outdated
continue
prev_ti = ti
ts.append(ti)
ts = jnp.stack(ts)
assert len(ts) > 80
assert len(ts) > min(0.2 * (8.0 / spacing), 75) # for spacing = 2**-5, this is 51
maybe just provide a lower bound as an explicit argument to this function, rather than using a heuristic like this?
Good idea.
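A minimal sketch of what an explicit lower bound could look like. This is not the code from the PR: the function name `select_spaced_points` and the parameter `min_num_points` are hypothetical, and the standard-library `random` module stands in for `jr.uniform` so that the snippet is self-contained. All arguments are keyword-only with no defaults, in line with the earlier comment about avoiding default arguments in internal helpers.

```python
import random


def select_spaced_points(*, seed, t0, t1, num_candidates, spacing, min_num_points):
    """Draw sorted uniform samples in [t0, t1] and drop any that fall within
    `spacing` of the previously kept point."""
    rng = random.Random(seed)
    candidates = sorted(rng.uniform(t0, t1) for _ in range(num_candidates))
    ts = []
    prev_ti = None
    for ti in candidates:
        if prev_ti is not None and ti - prev_ti < spacing:
            continue  # too close to the last kept point
        prev_ti = ti
        ts.append(ti)
    # The caller states how many points it needs, rather than the helper
    # deriving a threshold from `spacing` via a heuristic.
    assert len(ts) >= min_num_points, f"only {len(ts)} points survived"
    return ts


ts = select_spaced_points(
    seed=0, t0=0.3, t1=8.7, num_candidates=100, spacing=2**-6, min_num_points=10
)
```

Each call site then documents its own requirement, so a change to `spacing` cannot silently loosen the check.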
test/test_brownian.py
Outdated
for spline in splines:
    pvals_w1, pvals_w2, pvals_h = conditional_statistics(
        levy_area, use_levy=True, tol=2**-5, spacing=2**-6, spline=spline
Can we parameterise this test by `use_levy=False/True`?
I think it might be a bit pointless, but sure, it's not that expensive.
test/test_brownian.py
Outdated
else:
    # make sure that for incorrect splines at least one p-value is
    # below 0.01 (subject to multiple-testing correction) and the
    # average p-value is below 0.2.
I think the average p-value should probably be a lot smaller than this?
This is a bit of a subtle issue. Normally yes, it is a lot smaller, except in one case. When `levy_area="space-time"`, the `pvals_w1` are actually quite good even with `spline="zero"`, due to the randomness that w_r receives from h_su.
What I'm saying is that the variance that `pvals_w1` sees is the variance of the Brownian parabola, but without accounting for the conditioning on H, i.e.
$$
\mathbb{E}\left[ \left( \mathbb{E}[W_{s,r} \mid W_{s,u}, H_{s,u}] \right)^2 - \left( \mathbb{E}[W_{s,r} \mid W_{s,u}] \right)^2 \right]
$$
which is actually a very good approximation of the actual variance. In `pvals_w2` the influence of H is accounted for in the mean, which is subtracted, and hence the variance can only come from x1 (which is zero when `spline="zero"`). I hope this makes sense.
So what I will do is keep this more permissive upper bound for `pvals_w1` when `levy_area="space-time"`, and add a stricter upper bound in all the other cases.
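For reference, the size of this term can be worked out explicitly. This is a sketch under assumptions that are not stated in the thread: $h = u - s$, the space-time Lévy area satisfies $H_{s,u} \sim \mathcal{N}(0, h/12)$ independently of $W_{s,u}$, and the conditional means take the usual linear and parabola forms.

$$
\mathbb{E}[W_{s,r} \mid W_{s,u}] = \frac{r-s}{h} W_{s,u},
\qquad
\mathbb{E}[W_{s,r} \mid W_{s,u}, H_{s,u}] = \frac{r-s}{h} W_{s,u} + \frac{6 (r-s)(u-r)}{h^2} H_{s,u}
$$

Since $W_{s,u}$ and $H_{s,u}$ are independent with mean zero, the cross term vanishes and

$$
\mathbb{E}\left[ \left( \mathbb{E}[W_{s,r} \mid W_{s,u}, H_{s,u}] \right)^2 - \left( \mathbb{E}[W_{s,r} \mid W_{s,u}] \right)^2 \right]
= \frac{36 (r-s)^2 (u-r)^2}{h^4} \cdot \frac{h}{12}
= \frac{3 (r-s)^2 (u-r)^2}{h^3}
$$

For comparison, the Brownian bridge variance is $\operatorname{Var}(W_{s,r} \mid W_{s,u}) = (r-s)(u-r)/h$, so the ratio of the two is $3 (r-s)(u-r)/h^2$, which peaks at $3/4$ at the midpoint $r = (s+u)/2$.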
For the timing code, I'd suggest adding it as a standalone |
Honestly, I wasn't even aware the |
Minor comments (really teaching points) on the new benchmark/test, but otherwise LGTM. :)
benchmarks/brownian_tree_times.py
Outdated
import jax.random as jr
import jax.tree_util as jtu
from diffrax import VirtualBrownianTree
from diffrax._brownian.base import AbstractBrownianPath
This is accessible as `diffrax.AbstractBrownianPath`.
I think that was done automatically by the IDE, and I didn't notice.
benchmarks/brownian_tree_times.py
Outdated
from diffrax import VirtualBrownianTree
from diffrax._brownian.base import AbstractBrownianPath
from diffrax._custom_types import RealScalarLike
from diffrax._misc import default_floating_dtype, is_tuple_of_ints, split_by_tree
I think we can skip the `is_tuple_of_ints` and `split_by_tree` imports given that we initialise `OldVBT(..., shape=(100,))`.
`default_floating_dtype` is now available from `lineax.internal` with the latest Lineax release.
(I'm just trying to minimise private imports here.)
Yes, that makes sense
benchmarks/brownian_tree_times.py
Outdated
tree = tree_cls(t0=t0, t1=t1, tol=tol, shape=(100,), key=key, levy_area=levy_area)

def f():
    return jax.block_until_ready(vec_eval(tree, ts))
I think there should be a `jax.jit` here, as `vec_eval` involves a `jax.vmap`, which is a JAX operation happening outside of JIT.
Also, for timing benchmarks like this it is better to do many repeats and take the minimum. (Noise can only increase the time above the best possible one, so you want a minimum, not a mean.)
Thus I think you want something like:
@jax.jit
def run(_ts):
    return jax.vmap(lambda _t: tree.evaluate(_t, use_levy=True))(_ts)

return min(timeit.repeat(lambda: jax.block_until_ready(run(ts)), number=1, repeat=100))
Oh I see, so any JAX operation should be wrapped inside a jit?
Yup! At least, if you want them to be fast.
So for typical neural network training, I often won't bother putting a JIT when `__init__`ialising a model (which only happens once), but will definitely put one around the entirety of the forward pass.
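The benchmarking pattern discussed in this thread can be sketched in a self-contained form. The function `evaluate` below is a hypothetical stand-in for `tree.evaluate`, chosen so the snippet does not depend on a particular Diffrax version; the explicit warm-up call is an addition of this sketch, used so that compilation is not part of any timed repeat.

```python
import timeit

import jax
import jax.numpy as jnp


def evaluate(t):
    # Hypothetical stand-in for `tree.evaluate(t, use_levy=True)`.
    return jnp.sin(t) * jnp.sqrt(t)


@jax.jit
def run(ts):
    # The vmap is traced and compiled once, then reused on every call.
    return jax.vmap(evaluate)(ts)


ts = jnp.linspace(0.3, 8.7, 100)

# Warm-up call: triggers compilation before any measurement is taken.
jax.block_until_ready(run(ts))

# block_until_ready forces the asynchronous computation to finish inside the
# timed region. Noise only ever adds time, so report the minimum.
times = timeit.repeat(lambda: jax.block_until_ready(run(ts)), number=1, repeat=100)
best = min(times)
print(f"best of {len(times)} runs: {best:.2e} s")
```

With `repeat=100` and the minimum taken, a compile inside the first repeat would not change the reported figure anyway, but warming up keeps the individual entries in `times` comparable.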
@pytest.mark.parametrize("spline", ("quad", "sqrt", "zero"))
def test_spline(levy_area: LevyArea, use_levy, spline):
    if levy_area == "space-time" and spline == "quad":
        pytest.skip("Quad spline is not implemented for space-time Levy area")
FYI, a skipped test is usually used to refer to something that is currently broken, but which we just don't want to worry about right now. In particular it still shows up in the test log output each time.
This means that skips should usually only be used in the short-term, not in the long-term.
Anyway, here you probably want this:
def _levy_area_spline():
    for levy_area in ("", "space-time"):
        for spline in ("quad", "sqrt", "zero"):
            if levy_area == "space-time" and spline == "quad":
                continue
            yield levy_area, spline

@pytest.mark.parametrize("levy_area,spline", _levy_area_spline())
which won't even generate the test to be skipped.
Thanks, that's good to know.
Thanks! These all make sense. I won't be able to make these edits today or tomorrow, however, so is it okay if you fix these in your new_pr_branch and merge if you're already at it? Otherwise I can do it on Monday 😊.
Haha, no worries! I'll let you make the changes after Christmas. Merry Christmas!
Merry Christmas to you too!
I managed to do it today after all 😊. Should I squash them all together, or do you intend to reorganise the commits yourself anyway?
f84e731
into
patrick-kidger:new_pr_branch
Alright, LGTM! I've squashed these together and made a few tweaks, and you can see the result back in #337. Let's continue the discussion over there :)
@patrick-kidger This includes the updates you suggested. For the timing tests, however, I made a branch identical to this one, but with an additional Jupyter notebook:
Timing tests
I'm not really sure where you wanted me to put the timing code, so I put it in a separate branch, but I can also add it here.