From b0b88d87d33162a42fefaeccd5dbb528101f62d9 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 22 Feb 2024 15:43:20 -0800 Subject: [PATCH] [attrs] add linearize and vjp support --- jax/_src/api.py | 5 +- jax/experimental/attrs.py | 92 +++++++++++++++++++-- tests/attrs_test.py | 168 +++++++++++++++++++++++++++++++++++++- 3 files changed, 256 insertions(+), 9 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index c3c3b94381b0..3c4480bfb158 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2084,9 +2084,8 @@ def linearize(fun: Callable, *primals, has_aux: bool = False jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree) else: jaxtree_fun, out_tree = flatten_fun_nokwargs(f, in_tree) - out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize(jaxtree_fun, - *primals_flat, - has_aux=has_aux) + out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize( + jaxtree_fun, *primals_flat, has_aux=has_aux) if has_aux: out_tree, aux_tree = out_tree() else: diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 9838aa935351..1506b33d21bb 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -22,8 +22,12 @@ from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad from jax._src.interpreters import partial_eval as pe -from jax._src.tree_util import tree_flatten, tree_unflatten -from jax._src.util import unzip2 +from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, + treedef_tuple) +from jax._src.util import unzip2, safe_map, safe_zip, split_list + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip JaxVal = Any @@ -84,8 +88,8 @@ def _setattr_staging(trace, tracer, *, obj, attr): def jvp(f, primals, tangents, attr_tangents): attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents) attr_primals = tuple(jax_getattr(o, a) for o, a in attrs) - primals_flat, in_tree = tree_flatten((attr_primals, primals)) - tangents_flat, in_tree_ = tree_flatten((attr_tangents, tangents)) + primals_flat, in_tree = tree_flatten((attr_primals, *primals)) + tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents)) if in_tree != in_tree_: raise Exception f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), in_tree) out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped( @@ -95,7 +99,7 @@ def jvp(f, primals, tangents, attr_tangents): return out_primals, out_tangents, tangent_attrs_out @lu.transformation -def _set_attrs(attrs, attr_vals, args): +def _set_attrs(attrs, attr_vals, *args): for (o, a), x in zip(attrs, attr_vals): jax_setattr(o, a, x) yield (yield args, {}) @@ -134,3 +138,81 @@ def _setattr_jvp(trace, tracer, *, obj, attr): trace.main.attrs_tracked.append((obj, attr)) setattr(obj, attr, tracer) ad.JVPTrace.process_setattr = _setattr_jvp + + +def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): + attr_primals = [jax_getattr(o, a) for o, a in attrs] + attr_avals = [core.raise_to_shaped(core.get_aval(p)) for p in attr_primals] + primals_flat, in_tree = tree_flatten(primals) + tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) + f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree) + primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( + f_, *attr_primals, *primals_flat) + f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), + attrs, attrs_out) + return tree_unflatten(out_tree(), primal_out), f_lin + +def _linearize(traceable: lu.WrappedFun, *primals): + jvpfun, attrs = _split_attrs(_jvp(traceable)) + in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) + + tuple(pe.PartialVal.unknown(core.get_aval(p).at_least_vspace()) + for p in primals)) + _, in_tree = tree_flatten((primals, primals)) + jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) + jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) + out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \ + tree_unflatten(out_tree(), out_pvals) + out_primals_consts = [pval.get_known() for pval in out_primals_pvals] + return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], + jaxpr, consts, attrs()) + +@lu.transformation_with_aux +def _split_attrs(*args, **kwargs): + primals, tangents, tangent_attrs = yield args, kwargs + attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) + yield (primals, tangents, tangent_attr_vals), attrs + +def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): + in_tree, out_tree = io_tree + def f_lin(*tangents, attr_tangents): + if set(attr_tangents) - set(in_attrs): raise Exception + tangents_, in_tree_ = tree_flatten(tangents) + assert in_tree == in_tree_ + attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval)) + for a, aval in zip(in_attrs, attr_avals)] + out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_) + out_ = iter(out) + out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals] + assert next(out_, None) is None + tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)]) + out_ct = tree_unflatten(out_tree, tangents_out) + return out_ct, dict(zip(out_attrs, attr_tangents_out)) + return f_lin + + +def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): + attr_primals = [jax_getattr(o, a) for o, a in attrs] + primals_flat, in_tree = tree_flatten(primals) + tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) + f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree) + primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( + f_, *attr_primals, *primals_flat) + attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).at_least_vspace() + for o, a in attrs_out] + f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), + attrs, attrs_out) + return tree_unflatten(out_tree(), primal_out), f_vjp + +def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): + in_tree, out_tree = io_tree + dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): + out_cts, out_tree_ = tree_flatten(out_ct) + assert out_tree == out_tree_ + attr_cts = [attr_cotangents.get(a, ad.Zero(aval)) + for a, aval in zip(out_attrs, attr_avals)] + out = ad.backward_pass(jaxpr, (), (), consts, dummies, (*out_cts, *attr_cts)) + in_attr_bars, arg_cts = split_list(out, [len(in_attrs)]) + args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) + return args_ct, dict(zip(in_attrs, in_attr_bars)) + return f_vjp diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 339af904c832..d78ba77988f0 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -38,8 +38,10 @@ @dataclass class Thing: x: float + __hash__ = object.__hash__ + __eq__ = object.__eq__ -attrs.register(Thing) +attrs.register(Thing) # enables passing as arg into jitted function class AttrsTest(jtu.JaxTestCase): @@ -366,6 +368,170 @@ def g_ref(x, x_dot, y, y_dot): self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False) self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False) +class AttrsLinTest(jtu.JaxTestCase): + + @parameterized.parameters([True, False]) + def test_attr_output(self, jit): + thing = Thing(1.0) + + def f(x, _): + y = jnp.sin(x) + jax_setattr(thing, 'x', y) + + if jit: + f = jax.jit(f) + + out, f_lin = attrs.linearize(f, 3.0, 4.0) + self.assertIsNone(out) + self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False) + + out_dot, attr_tangents = f_lin(1.0, 2.0, attr_tangents={}) + self.assertIsNone(out_dot) + self.assertAllClose(thing.x, jnp.sin(3.0)) # didn't change + self.assertLen(attr_tangents, 1) + self.assertAllClose(attr_tangents[(thing, 'x')], jnp.cos(3.0), + check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_attr_input(self, jit): + thing = Thing(1.0) + + def f(): + x = jax_getattr(thing, 'x') + return jnp.sin(x) + + if jit: + f = jax.jit(f) + + out, f_lin = attrs.linearize(f, attrs=[(thing, 'x')]) + self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False) + + out_dot, attr_tangents = f_lin(attr_tangents={(thing, 'x'): 2.0}) + self.assertAllClose(out_dot, 2. * jnp.cos(1.0), check_dtypes=False) + self.assertLen(attr_tangents, 1) + self.assertAllClose(attr_tangents[(thing, 'x')], 2.0, check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_attr_inout(self, jit): + thing1 = Thing(1.0) + thing2 = Thing(2.0) + + def f(x, y): + z = jax_getattr(thing1, 'x') + w = jax_getattr(thing2, 'x') + out = jnp.sin(x * y * z * w) + jax_setattr(thing1, 'x', out) + jax_setattr(thing2, 'x', 2 * out) + return 3 * out, 4 * out + + if jit: + f = jax.jit(f) + + def f_ref(x, y, z, w): + out = jnp.sin(x * y * z * w) + return (3 * out, 4 * out), (out, 2 * out) + + out, f_lin = attrs.linearize(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')]) + expected = (3 * jnp.sin(1. * 2. * 3. * 4.), + 4 * jnp.sin(1. * 2. * 3. * 4.)) + self.assertAllClose(out, expected, check_dtypes=False) + self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.)) + self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.)) + + (out_ref, state_out_ref), f_lin_ref = jax.linearize(f_ref, 3., 4., 1., 2.) + self.assertAllClose(out, out_ref, check_dtypes=False) + self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False) + + out_dot, attr_tangents = f_lin(1., 2., + attr_tangents={(thing1, 'x'): 5., + (thing2, 'x'): 6.}) + self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.)) + self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.)) + (out_dot_ref, state_dot_ref) = f_lin_ref(1., 2., 5., 6.) + self.assertAllClose(out_dot, out_dot_ref, check_dtypes=False) + self.assertLen(attr_tangents, 2) + self.assertAllClose(attr_tangents[(thing1, 'x')], state_dot_ref[0], + check_dtypes=False) + self.assertAllClose(attr_tangents[(thing2, 'x')], state_dot_ref[1], + check_dtypes=False) + +class AttrsVJPTest(jtu.JaxTestCase): + + @parameterized.parameters([True, False]) + def test_attr_input(self, jit): + thing = Thing(1.0) + + def f(): + x = jax_getattr(thing, 'x') + return jnp.sin(x) + + if jit: + f = jax.jit(f) + + out, f_vjp = attrs.vjp(f, attrs=[(thing, 'x')]) + self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False) + + arg_cts, attr_cotangents = f_vjp(1.0) + self.assertEqual(arg_cts, ()) + self.assertLen(attr_cotangents, 1) + self.assertAllClose(attr_cotangents[(thing, 'x')], jnp.cos(1.0), + check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_attr_output(self, jit): + thing = Thing(1.0) + + def f(x, _): + y = jnp.sin(x) + jax_setattr(thing, 'x', y) + + if jit: + f = jax.jit(f) + + out, f_vjp = attrs.vjp(f, 3.0, 4.0) + self.assertIsNone(out) + self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False) + + arg_cts, attr_cotangents = f_vjp(None, attr_cotangents={(thing, 'x'): 2.0}) + self.assertAllClose(arg_cts, (2 * jnp.cos(3.0), 0.), check_dtypes=False) + self.assertLen(attr_cotangents, 0) + + @parameterized.parameters([True, False]) + def test_attr_inout(self, jit): + thing1 = Thing(1.0) + thing2 = Thing(2.0) + + def f(x, y): + z = jax_getattr(thing1, 'x') + w = jax_getattr(thing2, 'x') + out = jnp.sin(x * y * z * w) + jax_setattr(thing1, 'x', out) + jax_setattr(thing2, 'x', 2 * out) + return 3 * out, 4 * out + + if jit: + f = jax.jit(f) + + def f_ref(x, y, z, w): + out = jnp.sin(x * y * z * w) + return (3 * out, 4 * out), (out, 2 * out) + + out, f_vjp = attrs.vjp(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')]) + (out_ref, state_out_ref), f_vjp_ref = jax.vjp(f_ref, 3., 4., 1., 2.) + self.assertAllClose(out, out_ref, check_dtypes=False) + self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False) + + in_bar, attr_cotangents = f_vjp((1., 2.), + attr_cotangents={(thing1, 'x'): 5., + (thing2, 'x'): 6.}) + in_bar_ref_ = f_vjp_ref(((1., 2.), (5., 6.))) + in_bar_ref, attr_cotangents_ref = in_bar_ref_[:2], in_bar_ref_[2:] + self.assertAllClose(in_bar, in_bar_ref, check_dtypes=False) + self.assertLen(attr_cotangents, 2) + self.assertAllClose(attr_cotangents[(thing1, 'x')], attr_cotangents_ref[0], + check_dtypes=False) + self.assertAllClose(attr_cotangents[(thing2, 'x')], attr_cotangents_ref[1], + check_dtypes=False) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())