
Added spline test for VBT, fixed a bug in levy_tree_transpose, and added references to VBT #342

Merged · 4 commits · Dec 24, 2023

Conversation

andyElking (Contributor)

@patrick-kidger This includes the updates you suggested. For the timing tests, however, I made a branch identical to this one, but with an additional Jupyter notebook:
Timing tests

I'm not really sure where you wanted me to put the timing code, so I put it in a separate branch, but I can also add it here.

Comment on lines 415 to 429
??? cite "Reference"
Based on section 6.1 of
```bibtex
@phdthesis{foster2020a,
publisher = {University of Oxford},
school = {University of Oxford},
title = {Numerical approximations for stochastic differential equations},
author = {Foster, James M.},
year = {2020}
}

In particular see Theorem 6.1.6.
```


Owner

This should be formatted as:

??? cite "References"

    ```bibtex
    ...
    ```

i.e. with a new line, with an indent, and with only a single blank line between the reference section and the arguments section.

Contributor Author

Oops, not sure what happened with the indent there. Will fix

@@ -1,9 +1,9 @@
import typing
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Literal, Optional, TYPE_CHECKING, TypeAlias, Union
Owner

TypeAlias should be imported from typing_extensions for now, as we don't yet require Python 3.10.

Contributor Author

Aha, I see. Didn't know it was new in 3.10.

@@ -52,6 +52,8 @@ class Real(AbstractDtype):
BufferDenseInfos = dict[str, PyTree[eqxi.MaybeBuffer[Shaped[Array, "times ..."]]]]
sentinel: Any = eqxi.doc_repr(object(), "sentinel")

_LA: TypeAlias = Literal["", "space-time"]
Owner

This shouldn't start with an underscore if it's imported into another module. I'd suggest calling it just LevyArea.

Contributor Author

Good point, yes

@pytest.mark.parametrize("use_levy", (False, True))
def test_conditional_statistics(levy_area, use_levy):
def conditional_statistics(
levy_area: _LA, use_levy: bool, tol=2**-6, spacing=2**-6, spline: _Spline = "sqrt"
Owner

Let's remove the default arguments here as I don't think they're ever used.

This is a general principle -- prefer not to use default arguments where possible, in particular in internal APIs, as they're a common source of surprising behaviour.
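As a generic illustration of that principle (not from this PR), the classic Python pitfall where a default argument produces surprising behaviour:

```python
# Classic surprise: the default list is created once, at function
# definition time, and then shared across every subsequent call.
def append_bad(x, acc=[]):
    acc.append(x)
    return acc

# The usual fix: use None as a sentinel and build a fresh list per call.
def append_good(x, acc=None):
    if acc is None:
        acc = []
    acc.append(x)
    return acc
```

Calling `append_bad(1)` and then `append_bad(2)` returns `[1, 2]`, because both calls mutate the same shared default list.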

Contributor Author

I know, thanks for the reminder. I guess old habits die hard.

# Get >80 randomly selected points; not too close to avoid discretisation error.
t0 = 0.3
t1 = 8.7
ts = jr.uniform(sample_key, shape=(100,), minval=t0, maxval=t1)
# ts = jnp.array([1.0, 3.0, 6.0, 7.0])
Owner

?

Contributor Author

Oops, I tried to catch all the random debugging comments, seems this one escaped.

assert jnp.all(pvals_w1 > 0.1 / pvals_w1.shape[0])
if levy_area == "space-time" and use_levy:
assert jnp.all(pvals_w2 > 0.1 / pvals_w2.shape[0])
assert jnp.all(pvals_h > 0.1 / pvals_h.shape[0])
Owner

else:
    assert len(pvals_w2) == 0
    assert len(pvals_h) == 0

?

Contributor Author

Yeah, might as well.

continue
prev_ti = ti
ts.append(ti)
ts = jnp.stack(ts)
assert len(ts) > 80
assert len(ts) > min(0.2 * (8.0 / spacing), 75) # for spacing = 2**-5, this is 51
Owner

maybe just provide a lower bound as an explicit argument to this function, rather than using a heuristic like this?
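A hypothetical sketch of that refactor, with the bound passed explicitly by the caller (the function and parameter names are illustrative, not from the PR):

```python
# Hypothetical refactor: the caller states the expected minimum
# number of sample points, rather than the function deriving it
# from a spacing heuristic internally.
def check_enough_points(ts, min_num_points):
    if len(ts) <= min_num_points:
        raise AssertionError(
            f"only {len(ts)} points, expected more than {min_num_points}"
        )
```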

Contributor Author

Good idea.


for spline in splines:
pvals_w1, pvals_w2, pvals_h = conditional_statistics(
levy_area, use_levy=True, tol=2**-5, spacing=2**-6, spline=spline
Owner

Can we parameterise this test by use_levy=False/True?
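A sketch of what that parameterisation might look like (the test body is elided; the stacked decorators are the point, generating the cross product of cases):

```python
import pytest

# Stacked parametrize decorators produce the cross product of cases,
# so this generates one test per (use_levy, spline) pair.
@pytest.mark.parametrize("use_levy", (False, True))
@pytest.mark.parametrize("spline", ("quad", "sqrt", "zero"))
def test_spline_statistics(use_levy, spline):
    ...  # body elided in this sketch
```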

Contributor Author

I think it might be a bit pointless, but sure, it's not that expensive.

else:
# make sure that for incorrect splines at least one p-value is
# below 0.01 (subject to multiple-testing correction) and the
# average p-value is below 0.2.
Owner

I think the average p-value should probably be a lot smaller than this?

Contributor Author

This is a bit of a subtle issue. Normally, yes, it is a lot smaller, except in one case: when levy_area="space-time", the pvals_w1 are actually quite good even with spline="zero", due to the randomness that w_r receives from h_su.

What I'm saying is that the variance pval_w1 sees is the variance of the Brownian parabola, but without accounting for the conditioning on H, i.e.
$$
\mathbb{E}\big[ \big( \mathbb{E}[W_{s,r} \mid W_{s,u}, H_{s,u}] \big)^2 - \big( \mathbb{E}[W_{s,r} \mid W_{s,u}] \big)^2 \big],
$$
which is actually a very good approximation of the true variance. In pval_w2 the influence of H is accounted for in the mean, which is subtracted, so the variance can only come from x1 (which is zero when spline="zero"). I hope this makes sense.

So what I will do is keep this more permissive upper bound on pvals_w1 when levy_area="space-time", and add a stricter upper bound in all the other cases.
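As a sketch, the two-tier bound described above might look like this (the helper name and threshold values are hypothetical, chosen only to illustrate the branching):

```python
import numpy as np

# Hypothetical two-tier check: pvals_w1 under space-time Levy area
# keeps the permissive bound; every other case gets a stricter one.
# Thresholds are passed explicitly (no defaults, per the review above).
def mean_pval_below_bound(pvals, levy_area, permissive, strict):
    bound = permissive if levy_area == "space-time" else strict
    return float(np.mean(pvals)) < bound
```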

@patrick-kidger (Owner)

For the timing code, I'd suggest adding it as a standalone .py file inside benchmarks/.

@andyElking (Contributor Author)

andyElking commented Dec 22, 2023

For the timing code, I'd suggest adding it as a standalone .py file inside benchmarks/.

Honestly, I wasn't even aware the benchmarks/ folder existed until now. I added it, and these are the results (on my laptop):
New Shallow BM: 2.406
New Shallow STLA: 4.481
New Deep BM: 6.489
New Deep STLA: 12.068
Old Shallow: 3.123
Old Deep: 7.904

@patrick-kidger (Owner) left a comment

Minor comments (really teaching points) on the new benchmark/test, but otherwise LGTM. :)

import jax.random as jr
import jax.tree_util as jtu
from diffrax import VirtualBrownianTree
from diffrax._brownian.base import AbstractBrownianPath
Owner

This is accessible as diffrax.AbstractBrownianPath.

Contributor Author

I think that was done automatically by the IDE, and I hadn't noticed.

from diffrax import VirtualBrownianTree
from diffrax._brownian.base import AbstractBrownianPath
from diffrax._custom_types import RealScalarLike
from diffrax._misc import default_floating_dtype, is_tuple_of_ints, split_by_tree
Owner

I think we can skip the is_tuple_of_ints and split_by_tree imports given what we initialise OldVBT(..., shape=(100,)).

default_floating_dtype is now available from lineax.internal with the latest Lineax release.

(I'm just trying to minimise private imports here.)

Contributor Author

Yes, that makes sense

tree = tree_cls(t0=t0, t1=t1, tol=tol, shape=(100,), key=key, levy_area=levy_area)

def f():
return jax.block_until_ready(vec_eval(tree, ts))
Owner

I think there should be a jax.jit here as vec_eval involves a jax.vmap, which is a JAX operation happening outside of vmap.

Also, for timing benchmarks like this it is better to do many repeats and take a minimum. (Noise can only increase things from the best possible time, so you want a minimum, not a mean.)

Thus I think you want something like:

@jax.jit
def run(ts):
    return jax.vmap(lambda _t: tree.evaluate(_t, use_levy=True))(ts)

return min(timeit.repeat(lambda: jax.block_until_ready(run(ts)), number=1, repeat=100))

Contributor Author

Oh I see, so any JAX operation should be wrapped inside a jit?

Owner

Yup! At least, if you want them to be fast.

So for typical neural network training, I often won't bother putting a JIT when __init__ialising a model (which only happens once), but will definitely put one around the entirety of the forward pass.
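For instance, a minimal JAX sketch of that pattern (the model and shapes here are made up for illustration):

```python
import jax
import jax.numpy as jnp

# One-off setup (here, "initialising" toy parameters): no JIT needed.
params = jnp.ones((4, 4))

# The hot path, run many times: wrap the whole forward pass in jax.jit.
@jax.jit
def forward(params, x):
    return jnp.tanh(x @ params)

x = jnp.ones((8, 4))
out = jax.block_until_ready(forward(params, x))
```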

@pytest.mark.parametrize("spline", ("quad", "sqrt", "zero"))
def test_spline(levy_area: LevyArea, use_levy, spline):
if levy_area == "space-time" and spline == "quad":
pytest.skip("Quad spline is not implemented for space-time Levy area")
Owner

FYI, a skipped test is usually used for something that is currently broken, but which we just don't want to worry about right now. In particular, it still shows up in the output test log each time.
This means that skips should usually only be used in the short term, not the long term.

Anyway, here you probably want this:

def _levy_area_spline():
    for levy_area in ("", "space-time"):
        for spline in ("quad", "sqrt", "zero"):
            if levy_area == "space-time" and spline == "quad":
                continue
            yield levy_area, spline

@pytest.mark.parametrize("levy_area,spline", _levy_area_spline())

which won't even generate the test to be skipped.

Contributor Author

Thanks, that's good to know.

@andyElking (Contributor Author)

Minor comments (really teaching points) on the new benchmark/test, but otherwise LGTM. :)

Thanks! These all make sense. I won't be able to make these edits today or tomorrow, however, so is it okay if you fix these in your new_pr_branch and merge, while you're at it? Otherwise I can do it on Monday 😊.

@patrick-kidger (Owner)

Haha, no worries! I'll let you make the changes after Christmas.

Merry Christmas!

@andyElking (Contributor Author)

andyElking commented Dec 23, 2023 via email

@andyElking (Contributor Author)

andyElking commented Dec 24, 2023

I managed to do it today after all 😊.

Should I squash them all together, or do you intend to reorganise the commits yourself anyway?

@patrick-kidger patrick-kidger merged commit f84e731 into patrick-kidger:new_pr_branch Dec 24, 2023
0 of 4 checks passed
@patrick-kidger (Owner)

Alright, LGTM! I've squashed these together and made a few tweaks, and you can see the result back in #337. Let's continue the discussion over there :)
