From 68d625a564a78fdab35e1d919b9e57f0061c0db9 Mon Sep 17 00:00:00 2001 From: Philip Wijesinghe Date: Wed, 5 Nov 2025 11:46:58 +0000 Subject: [PATCH 1/2] fix float error in prev_dt step calculation that led to an infinite loop When: dt is clipped to dtmin, and we wish to continue solver (force_dtmin=True) Calculating if a step should be kept from: prev_dt = t1 - t0 (next_t1 = next_t0 + dt (in previous step)) keep_step = keep_step | (prev_dt <= self.dtmin) can result in float error for high t0 where prev_dt is never <= self.dtmin, and further steps are never accepted -> infinite loop Fix: add a keep_next_step: bool flag to controller_state, and track when we are, and continue to be, at dtmin --- diffrax/_step_size_controller/pid.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 1092f184..56c19f61 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,8 +81,8 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error) -_PidState = tuple[RealScalarLike, RealScalarLike] +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step) +_PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] # We use a metaclass for backwards compatibility. When a user calls @@ -388,6 +388,7 @@ def init( return t1, ( jnp.array(1.0, dtype=real_dtype), jnp.array(1.0, dtype=real_dtype), + False, ) def adapt_step_size( @@ -469,6 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, + keep_next_step, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -489,9 +491,9 @@ def _scale(_y0, _y1_candidate, _y_error): scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 - # Automatically keep the step if we're at dtmin. + # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | (prev_dt <= self.dtmin) + keep_step = keep_step | keep_next_step # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -545,6 +547,9 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) + # flag next step to be kept if dtmin is reached + # or if it was reached previously and dt is unchanged + keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1)) dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -554,7 +559,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error + controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, next_t1, False, controller_state, result From 81990570438612e0eecb98178f0ff283c2cdb5d0 Mon Sep 17 00:00:00 2001 From: Philip Wijesinghe Date: Thu, 6 Nov 2025 09:08:57 +0000 Subject: [PATCH 2/2] avoids accumulation of float precision errors in dt this solution makes sure that dt is reset to the desired dtmin value if the previous step was at dtmin and dt is unchanged (factor=1) if we do not reset dt then the recalculation of prev_dt = t1 - t0 will keep accumulating float precision errors with potential to drift away from the desired dtmin until a step that warrant a relaxation of step size (factor>1) these errors are likely to be minor, but i believe this is the intended behaviour --- diffrax/_step_size_controller/pid.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 56c19f61..c0ccfbec 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,7 +81,7 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, keep_next_step) +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, at_dtmin) _PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] @@ -470,7 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, - keep_next_step, + at_dtmin, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -493,7 +493,7 @@ def _scale(_y0, _y1_candidate, _y_error): keep_step = scaled_error < 1 # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | keep_next_step + keep_step = keep_step | at_dtmin # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -547,9 +547,11 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) - # flag next step to be kept if dtmin is reached - # or if it was reached previously and dt is unchanged - keep_next_step = (dt <= self.dtmin) | (keep_next_step & (factor == 1)) + # if we are already at dtmin and dt is unchanged (factor == 1), + # reset dt to dtmin to avoid accumulating float precision errors + dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt) + # this flags the next loop to accept step + at_dtmin = dt <= self.dtmin dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -559,7 +561,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error, keep_next_step + controller_state = inv_scaled_error, prev_inv_scaled_error, at_dtmin # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, next_t1, False, controller_state, result