Skip to content

Commit 134a40a

Browse files
FloListpatrick-kidger
authored andcommitted
Fix issue #563: time t0 instead of 0 is passed to _check in _assert_term_compatible, so the term is not required to be well-defined at time 0, but rather at time t0.
1 parent 4a308b8 commit 134a40a

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

diffrax/_integrate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def _is_none(x: Any) -> bool:
119119

120120

121121
def _assert_term_compatible(
122+
t: FloatScalarLike,
122123
y: PyTree[ArrayLike],
123124
args: PyTree[Any],
124125
terms: PyTree[AbstractTerm],
@@ -138,7 +139,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
138139
for term, arg, term_contr_kwarg in zip(
139140
term.terms, get_args(_tmp), term_contr_kwargs
140141
):
141-
_assert_term_compatible(yi, args, term, arg, term_contr_kwarg)
142+
_assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg)
142143
else:
143144
raise ValueError(
144145
f"Term {term} is not a MultiTerm but is expected to be."
@@ -166,7 +167,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
166167
elif n_term_args == 2:
167168
vf_type_expected, control_type_expected = term_args
168169
try:
169-
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
170+
vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
170171
except Exception as e:
171172
raise ValueError(f"Error while tracing {term}.vf: " + str(e))
172173
vf_type_compatible = eqx.filter_eval_shape(
@@ -178,7 +179,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
178179
contr = ft.partial(term.contr, **term_contr_kwargs)
179180
# Work around https://github.com/google/jax/issues/21825
180181
try:
181-
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
182+
control_type = eqx.filter_eval_shape(contr, t, t)
182183
except Exception as e:
183184
raise ValueError(f"Error while tracing {term}.contr: " + str(e))
184185
control_type_compatible = eqx.filter_eval_shape(
@@ -1077,6 +1078,7 @@ def _promote(yi):
10771078
if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
10781079
try:
10791080
_assert_term_compatible(
1081+
t0,
10801082
y0,
10811083
args,
10821084
terms,
@@ -1098,6 +1100,7 @@ def _promote(yi):
10981100

10991101
# Error checking for term compatibility
11001102
_assert_term_compatible(
1103+
t0,
11011104
y0,
11021105
args,
11031106
terms,

0 commit comments

Comments
 (0)