diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 5a5a3b67..c10b37eb 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -73,17 +73,22 @@ def fill_forward( def linear_rescale(t0, t, t1) -> Array: - """Calculates (t - t0) / (t1 - t0), assuming t0 <= t <= t1. + """Calculates (t - t0) / (t1 - t0). Specially handles the edge case t0 == t1: - zero is returned; - gradients through all three arguments are zero. + - output conditionally clipped to be in [0,1] to protect + from floating point errors. """ cond = t0 == t1 numerator = cast(Array, jnp.where(cond, 0, t - t0)) denominator = cast(Array, jnp.where(cond, 1, t1 - t0)) - return numerator / denominator + out = numerator / denominator + positive_between = (t0 < t1) & (t0 <= t) & (t <= t1) + negative_between = (t1 < t0) & (t <= t0) & (t1 <= t) + return jnp.where(positive_between | negative_between, jnp.clip(out, 0, 1), out) def adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> RealScalarLike: