Skip to content

Commit

Permalink
add a shape mismatch check and error to custom_vjp
Browse files Browse the repository at this point in the history
no idea how we lasted so long without this...
  • Loading branch information
mattjj committed Mar 14, 2024
1 parent 6046d7d commit 1326c74
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
42 changes: 33 additions & 9 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
from jax._src.lax import lax
from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
register_pytree_node_class, tree_leaves,
tree_flatten_with_path, keystr)
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
unzip2)


traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -733,6 +735,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
# object, to be replaced with Nones in the final returned result.
zero = object() # non-pytree sentinel to replace Nones in py_cts_in
dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
keypaths, _ = unzip2(tree_flatten_with_path(dummy)[0])
cts_in_flat = []
def append(x, d):
num_leaves = len(tree_flatten(d)[0])
Expand All @@ -747,17 +750,38 @@ def append(x, d):
tree_map(append, py_cts_in, dummy, is_leaf=lambda x: x is None)
except ValueError:
_, in_tree2 = tree_flatten(py_cts_in)
msg = ("Custom VJP rule must produce an output with the same container "
msg = ("Custom VJP bwd rule must produce an output with the same container "
"(pytree) structure as the args tuple of the primal function, "
"and in particular must produce a tuple of length equal to the "
"number of arguments to the primal function, but got VJP output "
"number of arguments to the primal function, but got bwd output "
"structure {} for primal input structure {}.")
raise TypeError(msg.format(in_tree2, in_tree)) from None
# Ignore any None cotangents, and any corresponding to inputs for which the
# type doesn't equal the tangent type (i.e. float0s)
# TODO(mattjj): change this to check if tangent type represents 0dim vspace
yield [Zero(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace()
else ct for a, ct in zip(in_avals, cts_in_flat)]
results = []
for kp, a, ct in zip(keypaths, in_avals, cts_in_flat):
if ct is zero or a != a.at_least_vspace():
results.append(Zero(a.at_least_vspace()))
elif type(ct) is SymbolicZero:
if not core.typecompat(a.at_least_vspace(), a_ := ct.aval):
msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
"that does not match the corresponding input tangent shape/dtype: "
f"the SymbolicZero had shape/dtype {a_.str_short()} while the "
f"corresponding input had shape/dtype {a.str_short()}. "
"Consider just returning a None here instead of a SymbolicZero "
"object.")
raise ValueError(msg)
results.append(Zero(ct.aval))
else:
if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct))
# TODO(mattjj): don't skip check with extended dtype tangent types
and not dtypes.issubdtype(a_.dtype, dtypes.extended)):
msg = ("Custom VJP bwd rule must produce an output with the same "
"shape/dtypes as the args tuple of the primal function, but at "
f"output{keystr(kp)} the bwd rule produced an output of "
f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding "
f"to an input of shape/dtype {a.str_short()}.")
raise ValueError(msg)
results.append(ct)
yield results


class CustomVJPCallPrimitive(core.CallPrimitive):
Expand Down
44 changes: 41 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8244,10 +8244,10 @@ def foo_bwd(_, g):
self.assertRaisesRegex(
TypeError,
re.escape(
"Custom VJP rule must produce an output with the same container "
"Custom VJP bwd rule must produce an output with the same container "
"(pytree) structure as the args tuple of the primal function, "
"and in particular must produce a tuple of length equal to the "
"number of arguments to the primal function, but got VJP output "
"number of arguments to the primal function, but got bwd output "
"structure {} for primal input structure {}.".format(
jax.tree.structure((1, 1)),
jax.tree.structure((1,)))
Expand All @@ -8266,7 +8266,7 @@ def foo_bwd(_, g):
return 2. * g # Should be a tuple

f.defvjp(foo_fwd, foo_bwd)
with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"):
with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"):
api.grad(f)(3.)

def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self):
Expand Down Expand Up @@ -8996,6 +8996,26 @@ def bwd_snd(_, g):
gx, = vjp(x)
self.assertArraysAllClose(gx, zero)

def test_symbolic_zero_custom_vjp_bwd_shape_error(self):
@jax.custom_vjp
def f(x, y, z):
return x, y, z

def fwd(x, y, z):
return f(x.value, y.value, z.value), None

def bwd(_, gs):
x_bar, y_bar, z_bar = gs
return y_bar, x_bar, z_bar # swapped!

f.defvjp(fwd, bwd, symbolic_zeros=True)

with self.assertRaisesRegex(
ValueError,
r'Consider just returning a None here'):
jax.grad(lambda x, y, z: f(x, y, z)[2].sum())(
jnp.ones(1), jnp.ones(2), jnp.ones(3))

@parameterized.named_parameters(
('jit_vmap', True, True),
('jit', True, False),
Expand Down Expand Up @@ -9251,6 +9271,24 @@ def f_bwd(_, z_bar):

jax.grad(f)((1.0, (2.0, None))) # don't crash

def test_bwd_rule_shape_mismatch(self):
@jax.custom_vjp
def foo(x, y):
return x

def foo_fwd(x, y):
return x, None

def foo_bwd(_, g):
return jnp.zeros(3), jnp.zeros(3)

foo.defvjp(foo_fwd, foo_bwd)

with self.assertRaisesRegex(
ValueError,
r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'):
jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4))


def transpose_unary(f, x_example):
def transposed(y):
Expand Down

0 comments on commit 1326c74

Please sign in to comment.