Skip to content

Commit

Permalink
[mutable-arrays] move MutableArray, add eager, improve tests, fix bug
Browse files Browse the repository at this point in the history
1. move MutableArray to core.py, and some handlers to their respective files
2. fix a bug in the aliasing setup (it was simply broken before; test coverage is now better)
3. add eager support by enabling get_p, swap_p, and addupdate_p impls
4. improve tests slightly
  • Loading branch information
mattjj committed Mar 1, 2024
1 parent 2761f26 commit db28f5e
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 63 deletions.
24 changes: 24 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import numpy as np

import jax._src
from jax._src import dtypes
from jax._src import config
from jax._src import effects
Expand Down Expand Up @@ -1912,6 +1913,29 @@ def __str__(self) -> str:
AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]


class MutableArray:
  """Holds an abstract value (`_aval`) together with a buffer (`_buf`).

  Indexing reads and writes are forwarded to the `_getitem` / `_setitem`
  methods of whatever `get_aval(self)` returns.
  """
  _aval: ShapedArray
  _buf: jax.Array

  def __init__(self, aval, buf):
    self._aval = aval
    self._buf = buf

  @property
  def aval(self):
    """The abstract value supplied at construction."""
    return self._aval

  @property
  def shape(self):
    """Shape reported by the stored abstract value."""
    return self._aval.shape

  @property
  def dtype(self):
    """Dtype reported by the stored abstract value."""
    return self._aval.dtype

  def __getitem__(self, idx):
    return get_aval(self)._getitem(self, idx)

  def __setitem__(self, idx, x):
    return get_aval(self)._setitem(self, idx, x)


# Map MutableArray instances to their stored abstract value.
pytype_aval_mappings[MutableArray] = lambda x: x._aval

# Primitive bound by `mutable_array` below.
mutable_array_p = Primitive('mutable_array')


def mutable_array(init_val):
  """Bind `mutable_array_p` to `init_val` and return the result."""
  return mutable_array_p.bind(init_val)

@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
  """Impl rule for `mutable_array_p`: wrap `init_val` in a `MutableArray`.

  The wrapper's abstract value is an `AbstractRef` around the shaped aval of
  `init_val`; `init_val` itself is stored as the buffer.
  """
  shaped_aval = raise_to_shaped(get_aval(init_val))
  ref_aval = jax._src.state.types.AbstractRef(shaped_aval)
  return MutableArray(ref_aval, init_val)


class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2730,6 +2730,13 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
return prim.bind(*subfuns, *args, **bind_params)


def _error_staging_mutable_array_p(trace, x):
  """Staging rule for `core.mutable_array_p` that unconditionally raises."""
  del trace, x  # Unused: this rule never stages anything out.
  message = (
      "mutable_array constructor can't be staged out, and in particular can't "
      "be used under a jax.jit or jax.lax.scan")
  raise Exception(message)


# Install the always-raising rule as the custom staging rule for the primitive.
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p


# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498
@lu.transformation
Expand Down
23 changes: 13 additions & 10 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def _shard_darray(x, sharding):
return shard_arg(x._data, sharding)
shard_arg_handlers[core.DArray] = _shard_darray

def _shard_mutable_array(x, sharding):
  """Shard a `core.MutableArray` by calling `shard_arg` on its `_buf`."""
  buf = x._buf
  return shard_arg(buf, sharding)


# Register the handler so `core.MutableArray` arguments can be sharded.
shard_arg_handlers[core.MutableArray] = _shard_mutable_array

def batched_device_put(aval: core.ShapedArray,
sharding: jax.sharding.Sharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
Expand Down Expand Up @@ -1778,17 +1782,16 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
@weakref_lru_cache
def _discharge_refs(
jaxpr: core.ClosedJaxpr
) -> tuple[core.ClosedJaxpr, None | Sequence[int | None], None | Sequence[int | None]]:
) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
from jax._src.state.discharge import discharge_state
out_mut = [None] * len(jaxpr.out_avals) + [
i for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)]
count = it.count()
inout_aliases = tuple(next(count) if isinstance(a, AbstractRef) else None
for a in jaxpr.in_avals)
jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
assert len(inout_aliases) == len(jaxpr.in_avals)
assert len(out_mut) == len(jaxpr.out_avals)
return jaxpr, inout_aliases, out_mut
new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end
inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
if isinstance(a, AbstractRef)}
outin_map = {j: i for i, j in inout_map.items()}
inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
return new_jaxpr, inout_aliases, out_mut


@dataclasses.dataclass(frozen=True)
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import functools
from functools import partial
import itertools as it
import operator
from typing import Any, Callable, Protocol, Union

import numpy as np
Expand Down Expand Up @@ -166,6 +165,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.DArray] = identity
canonicalize_dtype_handlers[core.MutableArray] = identity

def abstractify(x) -> Any:
typ = type(x)
Expand Down Expand Up @@ -196,7 +196,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:


pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.DArray] = lambda x: x._aval
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
Expand Down
19 changes: 4 additions & 15 deletions jax/_src/state/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import tree_util
from jax._src.interpreters import ad
Expand Down Expand Up @@ -53,11 +54,7 @@
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")

def _get_impl(ref: AbstractRef, *args: Any, tree):
del ref, args, tree
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
get_p.def_impl(partial(dispatch.apply_primitive, get_p))

Indexer = tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
Expand Down Expand Up @@ -113,11 +110,7 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")

def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
del ref, value, idx, tree
raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl)
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))

def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array,
_function_name: str = "ref_swap") -> Array:
Expand All @@ -143,11 +136,7 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra
# ```
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True

def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
del ref, value, args, tree
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))

def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
Expand Down
71 changes: 35 additions & 36 deletions tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,55 +1508,54 @@ def _body(ref):
jtu.check_grads(f, (0.5,), order=3)


class MutableArray:
_aval: core.ShapedArray
_buf: jax.Array
def __init__(self, aval, buf):
self._aval = aval
self._buf = buf
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
mutable_array_p = core.Primitive('mutable_array')

@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
aval = core.raise_to_shaped(core.get_aval(init_val))
return MutableArray(AbstractRef(aval), init_val)

def _error_on_staging(trace, x):
raise Exception
pe.custom_staging_rules[mutable_array_p] = _error_on_staging

from jax._src.interpreters import xla
from jax._src.interpreters import pxla
xla.canonicalize_dtype_handlers[MutableArray] = lambda x: x
xla.pytype_aval_mappings[MutableArray] = lambda x: x._aval
pxla.shard_arg_handlers[MutableArray] = lambda x, s: pxla.shard_arg(x._buf, s)
core.pytype_aval_mappings[MutableArray] = lambda x: x._aval

class MutableArrayTest(jtu.JaxTestCase):

def test_basic(self):
read = jax.jit(lambda x_ref: x_ref[...])

@jax.jit
@parameterized.parameters([True, False])
def test_basic(self, jit):
def f(x_mut):
x_mut[...] += 1.
x_mut[0] += 1
x_mut[1] += 5

x_mut = mutable_array(jnp.zeros(3))
if jit:
f = jax.jit(f)

x_mut = core.mutable_array(jnp.zeros(3))
f(x_mut)

self.assertAllClose(read(x_mut), jnp.array([2., 6., 1.]), check_dtypes=False)
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
check_dtypes=False)

jaxpr = jax.make_jaxpr(f)(x_mut)
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))

def test_staging_error(self):
  # Calling the mutable_array constructor under jax.jit is expected to raise.
  init = jnp.zeros(3)
  with self.assertRaises(Exception):
    jitted_constructor = jax.jit(core.mutable_array)
    jitted_constructor(init)

@parameterized.parameters([True, False])
def test_multiple_inputs_and_outputs(self, jit):
  # Two mutable arrays interleaved with two ordinary arrays: both mutable
  # arrays are incremented in place and two results are returned.
  def f(x_mut, y, z_mut, w):
    x_mut[...] += 1
    z_mut[...] += 1
    return x_mut[...] + y + z_mut[...] + w, y + w

  fn = jax.jit(f) if jit else f

  x_mut = core.mutable_array(jnp.zeros((1, 3)))
  y = jnp.ones((2, 3))
  z_mut = core.mutable_array(jnp.zeros((2, 3)))
  w = jnp.ones((2, 1))

  out1, out2 = fn(x_mut, y, z_mut, w)

  # The in-place updates must be visible on the mutable arrays afterwards.
  self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
  self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
  # Returned values: 1 + 1 + 1 + 1 broadcast to (2, 3), and y + w.
  self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
  self.assertAllClose(out2, y + w, check_dtypes=False)


if CAN_USE_HYPOTHESIS:

Expand Down

0 comments on commit db28f5e

Please sign in to comment.