From db28f5e751b257eb57294d23813b3abda4b1dd92 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 1 Mar 2024 11:07:45 -0800 Subject: [PATCH] [mutable-arrays] move MutableArray, add eager, improve tests, fix bug 1. move MutableArray to core.py, and some handlers to their respective files 2. fix a bug in aliasing setup (it was just broken before, now better test coverage) 3. add eager support by enabling get_p, swap_p, and addupdate_p impls 4. improve tests slightly --- jax/_src/core.py | 24 +++++++++ jax/_src/interpreters/partial_eval.py | 7 +++ jax/_src/interpreters/pxla.py | 23 +++++---- jax/_src/interpreters/xla.py | 5 +- jax/_src/state/primitives.py | 19 ++----- tests/state_test.py | 71 +++++++++++++-------------- 6 files changed, 86 insertions(+), 63 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index b02de88fce8d..10493a91c13d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -35,6 +35,7 @@ import numpy as np +import jax._src from jax._src import dtypes from jax._src import config from jax._src import effects @@ -1912,6 +1913,29 @@ def __str__(self) -> str: AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx] +class MutableArray: + _aval: ShapedArray + _buf: jax.Array + def __init__(self, aval, buf): + self._aval = aval + self._buf = buf + aval = property(lambda self: self._aval) + shape = property(lambda self: self._aval.shape) + dtype = property(lambda self: self._aval.dtype) + def __getitem__(self, idx): return get_aval(self)._getitem(self, idx) + def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x) +pytype_aval_mappings[MutableArray] = lambda x: x._aval + +def mutable_array(init_val): + return mutable_array_p.bind(init_val) +mutable_array_p = Primitive('mutable_array') + +@mutable_array_p.def_impl +def _mutable_array_impl(init_val): + aval = raise_to_shaped(get_aval(init_val)) + return MutableArray(jax._src.state.types.AbstractRef(aval), init_val) + + class AbstractToken(AbstractValue): def join(self, other): if isinstance(other, AbstractToken): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 64f768d75bb9..ebdc01300064 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2730,6 +2730,13 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): return prim.bind(*subfuns, *args, **bind_params) +def _error_staging_mutable_array_p(trace, x): + raise Exception( + "mutable_array constructor can't be staged out, and in particular can't " + "be used under a jax.jit or jax.lax.scan") +custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p + + # TODO(mattjj): the following are deprecated; update callers to _nounits version # See https://github.com/google/jax/pull/9498 @lu.transformation diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 3a03cbec39af..f9cfbfc9ead3 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -160,6 +160,10 @@ def _shard_darray(x, sharding): return shard_arg(x._data, sharding) shard_arg_handlers[core.DArray] = _shard_darray +def _shard_mutable_array(x, sharding): + return shard_arg(x._buf, sharding) +shard_arg_handlers[core.MutableArray] = _shard_mutable_array + def batched_device_put(aval: core.ShapedArray, sharding: jax.sharding.Sharding, xs: Sequence[Any], devices: Sequence[jax.Device], committed: bool = True): @@ -1778,17 +1782,16 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name, @weakref_lru_cache def _discharge_refs( jaxpr: core.ClosedJaxpr -) -> tuple[core.ClosedJaxpr, None | Sequence[int | None], None | Sequence[int | None]]: +) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]: from jax._src.state.discharge import discharge_state - out_mut = [None] * len(jaxpr.out_avals) + [ - i for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)] - count = it.count() - inout_aliases = tuple(next(count) if isinstance(a, AbstractRef) else None - for a in jaxpr.in_avals) - jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts)) - assert len(inout_aliases) == len(jaxpr.in_avals) - assert len(out_mut) == len(jaxpr.out_avals) - return jaxpr, inout_aliases, out_mut + new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts)) + count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end + inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals) + if isinstance(a, AbstractRef)} + outin_map = {j: i for i, j in inout_map.items()} + inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals)))) + out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals)))) + return new_jaxpr, inout_aliases, out_mut @dataclasses.dataclass(frozen=True) diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index f30fc5b5d6fb..3b3c7f74515a 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -22,7 +22,6 @@ import functools from functools import partial import itertools as it -import operator from typing import Any, Callable, Protocol, Union import numpy as np @@ -166,6 +165,7 @@ def _canonicalize_python_scalar_dtype(typ, x): (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types) canonicalize_dtype_handlers[core.Token] = identity canonicalize_dtype_handlers[core.DArray] = identity +canonicalize_dtype_handlers[core.MutableArray] = identity def abstractify(x) -> Any: typ = type(x) @@ -196,7 +196,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {} -pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval') +pytype_aval_mappings[core.DArray] = lambda x: x._aval +pytype_aval_mappings[core.MutableArray] = lambda x: x._aval pytype_aval_mappings[core.Token] = lambda _: core.abstract_token pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar) for t in numpy_scalar_types) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 9c35b9a9697a..4b55792f7388 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -22,6 +22,7 @@ from jax._src import ad_util from jax._src import core +from jax._src import dispatch from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.interpreters import ad @@ -53,11 +54,7 @@ # `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like # a:f32[3] <- x[] get_p = core.Primitive("get") - -def _get_impl(ref: AbstractRef, *args: Any, tree): - del ref, args, tree - raise ValueError("Cannot run stateful primitive.") -get_p.def_impl(_get_impl) +get_p.def_impl(partial(dispatch.apply_primitive, get_p)) Indexer = tuple[Union[int, slice, Array], ...] # or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType) @@ -113,11 +110,7 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: # are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like # x:Ref{f32[3]}[i, j] <- a swap_p = core.Primitive("swap") - -def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree): - del ref, value, idx, tree - raise ValueError("Cannot run stateful primitive.") -swap_p.def_impl(_swap_impl) +swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array, _function_name: str = "ref_swap") -> Array: @@ -143,11 +136,7 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra # ``` addupdate_p = core.Primitive('addupdate') addupdate_p.multiple_results = True - -def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree): - del ref, value, args, tree - raise ValueError("Can't evaluate `addupdate` outside a stateful context.") -addupdate_p.def_impl(_addupdate_impl) +addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p)) def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None: """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" diff --git a/tests/state_test.py b/tests/state_test.py index 931d64641b67..cc5bd9451f0a 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -1508,55 +1508,54 @@ def _body(ref): jtu.check_grads(f, (0.5,), order=3) -class MutableArray: - _aval: core.ShapedArray - _buf: jax.Array - def __init__(self, aval, buf): - self._aval = aval - self._buf = buf - aval = property(lambda self: self._aval) - shape = property(lambda self: self._aval.shape) - dtype = property(lambda self: self._aval.dtype) - -def mutable_array(init_val): - return mutable_array_p.bind(init_val) -mutable_array_p = core.Primitive('mutable_array') - -@mutable_array_p.def_impl -def _mutable_array_impl(init_val): - aval = core.raise_to_shaped(core.get_aval(init_val)) - return MutableArray(AbstractRef(aval), init_val) - -def _error_on_staging(trace, x): - raise Exception -pe.custom_staging_rules[mutable_array_p] = _error_on_staging - -from jax._src.interpreters import xla -from jax._src.interpreters import pxla -xla.canonicalize_dtype_handlers[MutableArray] = lambda x: x -xla.pytype_aval_mappings[MutableArray] = lambda x: x._aval -pxla.shard_arg_handlers[MutableArray] = lambda x, s: pxla.shard_arg(x._buf, s) -core.pytype_aval_mappings[MutableArray] = lambda x: x._aval - class MutableArrayTest(jtu.JaxTestCase): - def test_basic(self): - read = jax.jit(lambda x_ref: x_ref[...]) - - @jax.jit + @parameterized.parameters([True, False]) + def test_basic(self, jit): def f(x_mut): x_mut[...] += 1. x_mut[0] += 1 x_mut[1] += 5 - x_mut = mutable_array(jnp.zeros(3)) + if jit: + f = jax.jit(f) + + x_mut = core.mutable_array(jnp.zeros(3)) f(x_mut) - self.assertAllClose(read(x_mut), jnp.array([2., 6., 1.]), check_dtypes=False) + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) jaxpr = jax.make_jaxpr(f)(x_mut) self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) + def test_staging_error(self): + x = jnp.zeros(3) + with self.assertRaises(Exception): + jax.jit(core.mutable_array)(x) + + @parameterized.parameters([True, False]) + def test_multiple_inputs_and_outputs(self, jit): + def f(x_mut, y, z_mut, w): + x_mut[...] += 1 + z_mut[...] += 1 + return x_mut[...] + y + z_mut[...] + w, y + w + + if jit: + f = jax.jit(f) + + x_mut = core.mutable_array(jnp.zeros((1, 3))) + y = jnp.ones((2, 3)) + z_mut = core.mutable_array(jnp.zeros((2, 3))) + w = jnp.ones((2, 1)) + + out1, out2 = f(x_mut, y, z_mut, w) + + self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False) + self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False) + self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False) + self.assertAllClose(out2, y + w, check_dtypes=False) + if CAN_USE_HYPOTHESIS: