Skip to content

Commit

Permalink
In progress commit on branch jump_step_pr.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 9, 2025
1 parent 8d4212c commit effab83
Show file tree
Hide file tree
Showing 11 changed files with 511 additions and 536 deletions.
12 changes: 6 additions & 6 deletions benchmarks/jump_step_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def get_terms(key):
pid_controller = diffrax.PIDController(
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7
)
new_controller = diffrax.JumpStepWrapper(
new_controller = diffrax.ClipStepSizeController(
pid_controller,
step_ts=step_ts,
rejected_step_buffer_len=None,
store_rejected_steps=None,
)
old_controller = OldPIDController(
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7, step_ts=step_ts
Expand Down Expand Up @@ -88,16 +88,16 @@ def time_controller():
print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s")

# How expensive is revisiting rejected steps?
revisiting_controller_short = diffrax.JumpStepWrapper(
revisiting_controller_short = diffrax.ClipStepSizeController(
pid_controller,
step_ts=step_ts,
rejected_step_buffer_len=10,
store_rejected_steps=10,
)

revisiting_controller_long = diffrax.JumpStepWrapper(
revisiting_controller_long = diffrax.ClipStepSizeController(
pid_controller,
step_ts=step_ts,
rejected_step_buffer_len=4096,
store_rejected_steps=4096,
)

time_revisiting_short = do_timing(revisiting_controller_short)
Expand Down
2 changes: 1 addition & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@
from ._step_size_controller import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
AbstractStepSizeController as AbstractStepSizeController,
ClipStepSizeController as ClipStepSizeController,
ConstantStepSize as ConstantStepSize,
JumpStepWrapper as JumpStepWrapper,
PIDController as PIDController,
StepTo as StepTo,
)
Expand Down
19 changes: 15 additions & 4 deletions diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
SRA1,
Tsit5,
)
from ._step_size_controller import PIDController
from ._step_size_controller import ClipStepSizeController, PIDController


def citation(*args, **kwargs):
Expand Down Expand Up @@ -134,7 +134,7 @@ def citation(*args, **kwargs):


_thesis_cite = r"""
phdthesis{kidger2021on,
@phdthesis{kidger2021on,
title={{O}n {N}eural {D}ifferential {E}quations},
author={Patrick Kidger},
year={2021},
Expand Down Expand Up @@ -352,10 +352,10 @@ def _virtual_brownian_tree(terms):
return (
r"""
% You are simulating Brownian motion using a virtual Brownian tree, which was introduced
% in:
% in the following two papers:
"""
+ vbt_ref
+ "\n\n"
+ "\n"
+ single_seed_ref
)

Expand Down Expand Up @@ -570,6 +570,17 @@ def _auto_dt0(dt0):
"""


@citation_rules.append
def _clip_controller(terms, stepsize_controller):
if type(stepsize_controller) is ClipStepSizeController:
if stepsize_controller.store_rejected_steps is not None and is_sde(terms):
return r"""
% You are adaptively solving an SDE whilst revisiting rejected time points. This is a
% subtle point required for the correctness of adaptive noncommutative SDE solves, as
% found in:
""" + _parse_reference(ClipStepSizeController)


@citation_rules.append
def _pid_controller(stepsize_controller, terms=None):
if type(stepsize_controller) is PIDController:
Expand Down
8 changes: 8 additions & 0 deletions diffrax/_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ class RESULTS(optx.RESULTS): # pyright: ignore
event_occurred = (
"Terminating differential equation solve because an event occurred."
)
max_steps_rejected = (
"Maximum number of rejected steps was reached. Consider increasing "
"`diffrax.ClipStepSizeController(store_rejected_steps==...)`."
)
internal_error = (
"An internal error occurred in Diffrax. This is a bug! Please open a GitHub "
"issue with a minimum working example. (<50 lines of code is ideal)"
)


# Backward compatibility
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_step_size_controller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
AbstractStepSizeController as AbstractStepSizeController,
)
from .clip import ClipStepSizeController as ClipStepSizeController
from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo
from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper
from .pid import (
PIDController as PIDController,
)
Loading

0 comments on commit effab83

Please sign in to comment.