Skip to content

Commit

Permalink
corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Jan 27, 2025
1 parent 656989e commit 0d446ad
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 25 deletions.
2 changes: 1 addition & 1 deletion diffrax/_step_size_controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


_ControllerState = TypeVar("_ControllerState")
_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike])
_Dt0 = TypeVar("_Dt0", bound=Optional[RealScalarLike])


class AbstractStepSizeController(eqx.Module, Generic[_ControllerState, _Dt0]):
Expand Down
21 changes: 11 additions & 10 deletions diffrax/_step_size_controller/jump_step_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .._misc import static_select, upcast_or_raise
from .._solution import RESULTS
from .._term import AbstractTerm
from .base import AbstractAdaptiveStepSizeController, AbstractStepSizeController
from .base import AbstractStepSizeController


_ControllerState = TypeVar("_ControllerState")
Expand Down Expand Up @@ -77,7 +77,7 @@ def _clip_ts(
return _t1, jump_at_t1


def find_idx_with_hint(t: RealScalarLike, ts: Optional[Array], hint: IntScalarLike):
def _find_idx_with_hint(t: RealScalarLike, ts: Optional[Array], hint: IntScalarLike):
# Find index of first element of ts greater than t
# using linear search starting from hint.
if ts is None:
Expand Down Expand Up @@ -186,7 +186,7 @@ class JumpStepWrapper(
# }
# ```

controller: AbstractAdaptiveStepSizeController[_ControllerState, _Dt0]
controller: AbstractStepSizeController[_ControllerState, _Dt0]
step_ts: Optional[Real[Array, " steps"]]
jump_ts: Optional[Real[Array, " jumps"]]
rejected_step_buffer_len: Optional[int] = eqx.field(static=True)
Expand Down Expand Up @@ -358,12 +358,13 @@ def adapt_step_size(

# This is just a logging utility for testing purposes
if self.callback_on_reject is not None:
jax.debug.callback(self.callback_on_reject, keep_step, t1)
# jax.debug.callback(self.callback_on_reject, keep_step, t1)
jax.experimental.io_callback(self.callback_on_reject, None, keep_step, t1) # pyright: ignore

# For step ts and jump ts find the index of the first element in jump_ts/step_ts
# greater than next_t0. We use the hint i_step/i_jump to speed up the search.
i_step = find_idx_with_hint(next_t0, st.step_ts, i_step)
i_jump = find_idx_with_hint(next_t0, st.jump_ts, i_jump)
i_step = _find_idx_with_hint(next_t0, st.step_ts, i_step)
i_jump = _find_idx_with_hint(next_t0, st.jump_ts, i_jump)

if self.rejected_step_buffer_len is not None:
rejected_buffer = st.rejected_buffer
Expand Down Expand Up @@ -414,11 +415,11 @@ def adapt_step_size(
)

if TYPE_CHECKING: # if i don't seperate this out pyright complains
assert isinstance(next_t0, RealScalarLike)
else:
assert isinstance(
next_t0, RealScalarLike
next_t0, get_args(RealScalarLike)
), f"type(next_t0) = {type(next_t0)}"
else:
isinstance(next_t0, get_args(RealScalarLike))

# Clip the step to the next element of jump_ts or step_ts or
# rejected_buffer. Important to do jump_ts last because otherwise
Expand All @@ -439,8 +440,8 @@ def adapt_step_size(
# OR of the two.
jump_at_next_t1 = jnp.where(
next_t1 == original_next_t1,
jump_at_original_next_t1,
jump_at_next_t1 | jump_at_original_next_t1,
jump_at_next_t1,
)

# Here made_jump signifies whether there is a jump at t1. What the solver
Expand Down
7 changes: 2 additions & 5 deletions diffrax/_step_size_controller/pid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@
# We use a metaclass for backwards compatibility. When a user calls
# PIDController(... step_ts=s, jump_ts=j) this should return a
# JumpStepWrapper(PIDController(...), s, j).
module_meta = type(eqx.Module)


class PIDMeta(module_meta):
class _PIDMeta(type(eqx.Module)):
def __call__(cls, *args, **kwargs):
step_ts = kwargs.pop("step_ts", None)
jump_ts = kwargs.pop("jump_ts", None)
Expand Down Expand Up @@ -124,7 +121,7 @@ def __repr__(self):
# in Soderlind and Wang 2006.
class PIDController(
AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]],
metaclass=PIDMeta,
metaclass=_PIDMeta,
):
r"""Adapts the step size to produce a solution accurate to a given tolerance.
The tolerance is calculated as `atol + rtol * y` for the evolving solution `y`.
Expand Down
18 changes: 9 additions & 9 deletions docs/api/stepsize_controller.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ The list of step size controllers is as follows. The most common cases are fixed
For more details about the convergence of adaptive solutions to SDEs, please refer to

```bibtex
@misc{foster2024convergenceadaptiveapproximationsstochastic,
title={On the convergence of adaptive approximations for stochastic differential equations},
author={James Foster and Andraž Jelinčič},
year={2024},
eprint={2311.14201},
archivePrefix={arXiv},
primaryClass={math.NA},
url={https://arxiv.org/abs/2311.14201},
}
@misc{foster2024convergenceadaptiveapproximationsstochastic,
title={On the convergence of adaptive approximations for stochastic differential equations},
author={James Foster and Andraž Jelinčič},
year={2024},
eprint={2311.14201},
archivePrefix={arXiv},
primaryClass={math.NA},
url={https://arxiv.org/abs/2311.14201},
}
```


Expand Down
43 changes: 43 additions & 0 deletions test/test_adaptive_stepsize_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def diffusion_vf(t, y, args):
def callback_fun(keep_step, t1):
if not keep_step:
rejected_ts_list.append(t1)
return None

stepsize_controller = diffrax.JumpStepWrapper(
pid_controller,
Expand Down Expand Up @@ -246,3 +247,45 @@ def test_pid_meta():
assert all(pid2.step_ts == ts)
assert all(pid3.step_ts == ts)
assert all(pid3.jump_ts == ts)


def test_nested_jump_step_wrappers():
pid = diffrax.PIDController(rtol=0, atol=1.0)
wrap1 = diffrax.JumpStepWrapper(pid, jump_ts=[3.0, 13.0], step_ts=[23.0])
wrap2 = diffrax.JumpStepWrapper(wrap1, step_ts=[2.0, 13.0], jump_ts=[23.0])
func = lambda terms, t, y, args: -y
terms = diffrax.ODETerm(lambda t, y, args: -y)
_, state = wrap2.init(terms, -1.0, 0.0, 0.0, 4.0, None, func, 5)

# test 1
_, next_t0, next_t1, made_jump, state, _ = wrap2.adapt_step_size(
0.0, 1.0, 0.0, 0.0, None, 0.0, 5, state
)
assert next_t1 == 2
_, next_t0, next_t1, made_jump, state, _ = wrap2.adapt_step_size(
next_t0, next_t1, 0.0, 0.0, None, 0.0, 5, state
)
assert jnp.isclose(next_t0, 2)
assert not made_jump

# test 2
_, next_t0, next_t1, made_jump, state, _ = wrap2.adapt_step_size(
10.0, 11.0, 0.0, 0.0, None, 0.0, 5, state
)
assert next_t1 == 13
_, next_t0, next_t1, made_jump, state, _ = wrap2.adapt_step_size(
next_t0, next_t1, 0.0, 0.0, None, 0.0, 5, state
)
assert jnp.isclose(next_t0, 13)
assert made_jump

# test 3
_, next_t0, next_t1, made_jump, state, _ = wrap2.adapt_step_size(
20.0, 21.0, 0.0, 0.0, None, 0.0, 5, state
)
assert next_t1 == 23
_, next_t0, next_t1, made_jump, state, _ = wrap2.adapt_step_size(
next_t0, next_t1, 0.0, 0.0, None, 0.0, 5, state
)
assert jnp.isclose(next_t0, 23)
assert made_jump

0 comments on commit 0d446ad

Please sign in to comment.