@@ -119,6 +119,7 @@ def _is_none(x: Any) -> bool:
119
119
120
120
121
121
def _assert_term_compatible (
122
+ t : FloatScalarLike ,
122
123
y : PyTree [ArrayLike ],
123
124
args : PyTree [Any ],
124
125
terms : PyTree [AbstractTerm ],
@@ -138,7 +139,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
138
139
for term , arg , term_contr_kwarg in zip (
139
140
term .terms , get_args (_tmp ), term_contr_kwargs
140
141
):
141
- _assert_term_compatible (yi , args , term , arg , term_contr_kwarg )
142
+ _assert_term_compatible (t , yi , args , term , arg , term_contr_kwarg )
142
143
else :
143
144
raise ValueError (
144
145
f"Term { term } is not a MultiTerm but is expected to be."
@@ -166,7 +167,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
166
167
elif n_term_args == 2 :
167
168
vf_type_expected , control_type_expected = term_args
168
169
try :
169
- vf_type = eqx .filter_eval_shape (term .vf , 0.0 , yi , args )
170
+ vf_type = eqx .filter_eval_shape (term .vf , t , yi , args )
170
171
except Exception as e :
171
172
raise ValueError (f"Error while tracing { term } .vf: " + str (e ))
172
173
vf_type_compatible = eqx .filter_eval_shape (
@@ -178,7 +179,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
178
179
contr = ft .partial (term .contr , ** term_contr_kwargs )
179
180
# Work around https://github.com/google/jax/issues/21825
180
181
try :
181
- control_type = eqx .filter_eval_shape (contr , 0.0 , 0.0 )
182
+ control_type = eqx .filter_eval_shape (contr , t , t )
182
183
except Exception as e :
183
184
raise ValueError (f"Error while tracing { term } .contr: " + str (e ))
184
185
control_type_compatible = eqx .filter_eval_shape (
@@ -1077,6 +1078,7 @@ def _promote(yi):
1077
1078
if isinstance (solver , (EulerHeun , ItoMilstein , StratonovichMilstein )):
1078
1079
try :
1079
1080
_assert_term_compatible (
1081
+ t0 ,
1080
1082
y0 ,
1081
1083
args ,
1082
1084
terms ,
@@ -1098,6 +1100,7 @@ def _promote(yi):
1098
1100
1099
1101
# Error checking for term compatibility
1100
1102
_assert_term_compatible (
1103
+ t0 ,
1101
1104
y0 ,
1102
1105
args ,
1103
1106
terms ,
0 commit comments