Skip to content

Commit

Permalink
Update citation-handling to work with nested step size controllers
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 10, 2025
1 parent ff301da commit 50c9ff1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 16 deletions.
50 changes: 34 additions & 16 deletions diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,22 +571,21 @@ def _auto_dt0(dt0):


@citation_rules.append
def _clip_controller(terms, stepsize_controller):
if type(stepsize_controller) is ClipStepSizeController:
if stepsize_controller.store_rejected_steps is not None and is_sde(terms):
return r"""
% You are adaptively solving an SDE whilst revisiting rejected time points. This is a
% subtle point required for the correctness of adaptive noncommutative SDE solves, as
% found in:
""" + _parse_reference(ClipStepSizeController)
def _stepsize_controller(stepsize_controller, terms=None):
out = _stepsize_controller_impl(terms, stepsize_controller)
if len(out) == 0:
return None
else:
return "\n\n".join(x.strip() for x in out)


@citation_rules.append
def _pid_controller(stepsize_controller, terms=None):
def _stepsize_controller_impl(terms, stepsize_controller) -> set[str]:
out = set()
if type(stepsize_controller) is PIDController:
if is_sde(terms):
return r"""
% The use of PI and PI controllers to adapt step sizes for SDEs are from:
out.add(
r"""
% The use of adaptive step size controllers for SDEs are from:
@article{burrage2004adaptive,
title={Adaptive stepsize based on control theory for stochastic
differential equations},
Expand All @@ -612,14 +611,16 @@ def _pid_controller(stepsize_controller, terms=None):
pages={791–-812},
}
"""
)
else:
no_p = stepsize_controller.pcoeff == 0
no_d = stepsize_controller.dcoeff == 0
_no_tracer(no_p, "stepsize_controller.pcoeff")
_no_tracer(no_d, "stepsize_controller.dcoeff")
if no_d:
if no_p:
return r"""
out.add(
r"""
% The use of an I-controller to adapt step sizes is from Section II.4 of:
@book{hairer2008solving-i,
address={Berlin},
Expand All @@ -631,8 +632,10 @@ def _pid_controller(stepsize_controller, terms=None):
year={2008}
}
"""
)
else:
return r"""
out.add(
r"""
% The use of a PI-controller to adapt step sizes is from Section IV.2 of:
@book{hairer2002solving-ii,
address={Berlin},
Expand All @@ -653,9 +656,11 @@ def _pid_controller(stepsize_controller, terms=None):
pages={281--310}
}
"""
)
else:
return r"""
% The use of a PID controller to adapt step sizes is from:
out.add(
r"""
% The use of a PID-controller to adapt step sizes is from:
@article{soderlind2003digital,
title={{D}igital {F}ilters in {A}daptive {T}ime-{S}tepping,
author={Gustaf S{\"o}derlind},
Expand All @@ -666,3 +671,16 @@ def _pid_controller(stepsize_controller, terms=None):
pages={1--26}
}
"""
)
elif type(stepsize_controller) is ClipStepSizeController:
out.update(_stepsize_controller_impl(terms, stepsize_controller.controller))
if stepsize_controller.store_rejected_steps is not None and is_sde(terms):
out.add(
r"""
% You are adaptively solving an SDE whilst revisiting rejected time points. This is a
% subtle point required for the correctness of adaptive noncommutative SDE solves, as
% found in:
"""
+ _parse_reference(ClipStepSizeController)
)
return out
34 changes: 34 additions & 0 deletions test/test_citation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import diffrax as dfx


def test_adaptive_sde(capfd, getkey):
topline = "AUTOGENERATED REFERENCES PRODUCED USING `diffrax.citation(...)`"
sde = "You are solving an SDE, and may wish to cite the textbook"
adaptive = "The use of adaptive step size controllers for SDEs are from"
reject_buffer = (
"You are adaptively solving an SDE whilst revisiting rejected time points"
)

bm = dfx.VirtualBrownianTree(0, 1, 1e-3, (), getkey())
terms = dfx.ControlTerm(lambda t, y, args: -y, bm)
stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-3)
capfd.readouterr()
dfx.citation(terms=terms, stepsize_controller=stepsize_controller)
out = capfd.readouterr().out
assert topline in out
assert sde in out
assert adaptive in out
assert reject_buffer not in out

capfd.readouterr()
dfx.citation(
terms=terms,
stepsize_controller=dfx.ClipStepSizeController(
stepsize_controller, store_rejected_steps=1
),
)
out = capfd.readouterr().out
assert topline in out
assert sde in out
assert adaptive in out
assert reject_buffer in out

0 comments on commit 50c9ff1

Please sign in to comment.