Skip to content

Commit

Permalink
temporarily relax the cotangent dtype check introduced in #19009
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615883208
  • Loading branch information
mattjj authored and jax authors committed Mar 14, 2024
1 parent 993abb1 commit 6f38f27
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,7 @@ def append(x, d):
results.append(Zero(ct.aval))
else:
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
# TODO(mattjj): don't skip check with extended dtype tangent types
and not dtypes.issubdtype(a_.dtype, dtypes.extended)):
and not _temporary_dtype_exception(a, a_)):
msg = ("Custom VJP bwd rule must produce an output with the same "
"shape/dtypes as the args tuple of the primal function, but at "
f"output{keystr(kp)} the bwd rule produced an output of "
Expand All @@ -783,6 +782,14 @@ def append(x, d):
results.append(ct)
yield results

# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_dtype_exception(a, a_) -> bool:
if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray):
return (a.shape == a_.shape and
(dtypes.issubdtype(a_.dtype, dtypes.extended) or
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
return False


class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive
Expand Down

0 comments on commit 6f38f27

Please sign in to comment.