diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 522a3b779bb3..9c112970aea7 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -772,8 +772,7 @@ def append(x, d): results.append(Zero(ct.aval)) else: if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct)) - # TODO(mattjj): don't skip check with extended dtype tangent types - and not dtypes.issubdtype(a_.dtype, dtypes.extended)): + and not _temporary_dtype_exception(a, a_)): msg = ("Custom VJP bwd rule must produce an output with the same " "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " @@ -783,6 +782,14 @@ def append(x, d): results.append(ct) yield results +# TODO(mattjj): remove both these exceptions to cotangent compatibility check +def _temporary_dtype_exception(a, a_) -> bool: + if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): + return (a.shape == a_.shape and + (dtypes.issubdtype(a_.dtype, dtypes.extended) or + dtypes.issubdtype(a.dtype, dtypes.np.inexact))) + return False + class CustomVJPCallPrimitive(core.CallPrimitive): initial_style: core.Primitive