Merge pull request #20044 from mattjj:mutable-arrays
PiperOrigin-RevId: 611866507
jax authors committed Mar 1, 2024
2 parents 04f6bfa + 3a403f2 commit 28f84eb
Showing 6 changed files with 86 additions and 75 deletions.
24 changes: 24 additions & 0 deletions jax/_src/core.py
@@ -1912,6 +1912,30 @@ def __str__(self) -> str:
AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]


class MutableArray:
_aval: ShapedArray
_buf: Array
def __init__(self, aval, buf):
self._aval = aval
self._buf = buf
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
pytype_aval_mappings[MutableArray] = lambda x: x._aval

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
mutable_array_p = Primitive('mutable_array')

@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
from jax._src.state.types import AbstractRef # type: ignore[import]
aval = raise_to_shaped(get_aval(init_val))
return MutableArray(AbstractRef(aval), init_val)


class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
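
The new MutableArray wrapper pairs an AbstractRef aval with a backing buffer, and mutable_array builds one eagerly through mutable_array_p. A minimal eager-mode sketch of the resulting API, mirroring test_basic in the test changes below (it assumes this commit is checked out and imports the private jax._src.core module, as the tests do):

import jax.numpy as jnp
from jax._src import core

x_mut = core.mutable_array(jnp.zeros(3))   # eager impl returns a MutableArray
x_mut[...] += 1.                           # __setitem__ dispatches through the AbstractRef aval
x_mut[0] += 1
x_mut[1] += 5
print(x_mut[...])                          # [2. 6. 1.]
print(x_mut.shape, x_mut.dtype)            # (3,) float32
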
7 changes: 7 additions & 0 deletions jax/_src/interpreters/partial_eval.py
@@ -2730,6 +2730,13 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
return prim.bind(*subfuns, *args, **bind_params)


def _error_staging_mutable_array_p(trace, x):
raise Exception(
"mutable_array constructor can't be staged out, and in particular can't "
"be used under a jax.jit or jax.lax.scan")
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p


# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498
@lu.transformation
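
The custom staging rule above turns any attempt to trace mutable_array_p into an error, so mutable arrays can only be created eagerly and then passed into jit. A small sketch of the guard in action, mirroring test_staging_error in the test changes below:

import jax
import jax.numpy as jnp
from jax._src import core

core.mutable_array(jnp.zeros(3))             # fine: created eagerly

try:
  jax.jit(core.mutable_array)(jnp.zeros(3))  # traced: hits the staging rule
except Exception as e:
  print(e)  # mutable_array constructor can't be staged out, and in particular
            # can't be used under a jax.jit or jax.lax.scan
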
23 changes: 13 additions & 10 deletions jax/_src/interpreters/pxla.py
@@ -160,6 +160,10 @@ def _shard_darray(x, sharding):
return shard_arg(x._data, sharding)
shard_arg_handlers[core.DArray] = _shard_darray

def _shard_mutable_array(x, sharding):
return shard_arg(x._buf, sharding)
shard_arg_handlers[core.MutableArray] = _shard_mutable_array

def batched_device_put(aval: core.ShapedArray,
sharding: jax.sharding.Sharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
@@ -1778,17 +1782,16 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
@weakref_lru_cache
def _discharge_refs(
jaxpr: core.ClosedJaxpr
) -> tuple[core.ClosedJaxpr, None | Sequence[int | None], None | Sequence[int | None]]:
) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
from jax._src.state.discharge import discharge_state
out_mut = [None] * len(jaxpr.out_avals) + [
i for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)]
count = it.count()
inout_aliases = tuple(next(count) if isinstance(a, AbstractRef) else None
for a in jaxpr.in_avals)
jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
assert len(inout_aliases) == len(jaxpr.in_avals)
assert len(out_mut) == len(jaxpr.out_avals)
return jaxpr, inout_aliases, out_mut
new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end
inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
if isinstance(a, AbstractRef)}
outin_map = {j: i for i, j in inout_map.items()}
inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
return new_jaxpr, inout_aliases, out_mut


@dataclasses.dataclass(frozen=True)
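
The rewritten _discharge_refs derives both alias tables from one input-to-output map: discharging appends one fresh output per Ref input after the original outputs, inout_aliases[i] names the appended output that aliases input i, and out_mut[j] is the inverse. A standalone sketch of just that bookkeeping (no JAX imports; _alias_maps and its arguments are hypothetical names used for illustration):

import itertools as it

def _alias_maps(in_is_ref, num_orig_outputs):
  # new outputs are appended to the end, as in the change above
  count = it.count(num_orig_outputs)
  inout_map = {i: next(count) for i, is_ref in enumerate(in_is_ref) if is_ref}
  outin_map = {j: i for i, j in inout_map.items()}
  num_outputs = num_orig_outputs + len(inout_map)
  inout_aliases = tuple(map(inout_map.get, range(len(in_is_ref))))
  out_mut = tuple(map(outin_map.get, range(num_outputs)))
  return inout_aliases, out_mut

# inputs (Ref, plain, Ref) and two ordinary outputs:
print(_alias_maps([True, False, True], 2))   # ((2, None, 3), (None, None, 0, 2))
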
5 changes: 3 additions & 2 deletions jax/_src/interpreters/xla.py
@@ -22,7 +22,6 @@
import functools
from functools import partial
import itertools as it
import operator
from typing import Any, Callable, Protocol, Union

import numpy as np
@@ -166,6 +165,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.DArray] = identity
canonicalize_dtype_handlers[core.MutableArray] = identity

def abstractify(x) -> Any:
typ = type(x)
@@ -196,7 +196,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:


pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.DArray] = lambda x: x._aval
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
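
Both tables extended above are plain per-type registries keyed on the exact Python type of a host value; registering core.MutableArray is what lets it pass through dtype canonicalization and map to its AbstractRef aval. A simplified, self-contained sketch of that lookup pattern (DemoMutableArray and lookup_aval are hypothetical stand-ins, not the real abstractify):

from typing import Any, Callable

pytype_aval_mappings: dict[Any, Callable[[Any], Any]] = {}

class DemoMutableArray:
  def __init__(self, aval, buf):
    self._aval, self._buf = aval, buf

# mirror of the new registry entry above
pytype_aval_mappings[DemoMutableArray] = lambda x: x._aval

def lookup_aval(x):
  handler = pytype_aval_mappings.get(type(x))
  if handler is None:
    raise TypeError(f"no abstract-value handler for {type(x)}")
  return handler(x)

print(lookup_aval(DemoMutableArray("Ref{float32[3]}", buf=None)))  # Ref{float32[3]}
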
19 changes: 4 additions & 15 deletions jax/_src/state/primitives.py
@@ -22,6 +22,7 @@

from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import tree_util
from jax._src.interpreters import ad
@@ -53,11 +54,7 @@
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")

def _get_impl(ref: AbstractRef, *args: Any, tree):
del ref, args, tree
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
get_p.def_impl(partial(dispatch.apply_primitive, get_p))

Indexer = tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
@@ -113,11 +110,7 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")

def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
del ref, value, idx, tree
raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl)
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))

def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array,
_function_name: str = "ref_swap") -> Array:
@@ -143,11 +136,7 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra
# ```
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True

def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
del ref, value, args, tree
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))

def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
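
Replacing the raising impls with partial(dispatch.apply_primitive, ...) gives get_p, swap_p, and addupdate_p ordinary eager semantics: outside of tracing they are compiled and executed like any other primitive, which is what makes indexing a core.mutable_array work without jit. A short sketch of that eager path, mirroring the un-jitted branch of test_basic below:

import jax.numpy as jnp
from jax._src import core

x_mut = core.mutable_array(jnp.ones(3))
y = x_mut[...]          # ref_get -> get_p, now runs eagerly via apply_primitive
x_mut[...] += 1.        # ref_set/ref_swap -> swap_p, also eager
print(y)                # [1. 1. 1.]
print(x_mut[...])       # [2. 2. 2.]
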
83 changes: 35 additions & 48 deletions tests/state_test.py
@@ -55,18 +55,6 @@

class StatePrimitivesTest(jtu.JaxTestCase):

def test_cant_eval_get_primitive(self):
with self.assertRaises(ValueError):
get_p.bind(jnp.ones(5), tree=None)

def test_cant_eval_swap_primitive(self):
with self.assertRaises(ValueError):
swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)

def test_cant_eval_addupdate_primitive(self):
with self.assertRaises(ValueError):
addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)

def test_get_abstract_aval_must_take_in_refs(self):
ref_aval = core.ShapedArray((), jnp.float32)
def f(x_ref):
@@ -1508,55 +1496,54 @@ def _body(ref):
jtu.check_grads(f, (0.5,), order=3)


class MutableArray:
_aval: core.ShapedArray
_buf: jax.Array
def __init__(self, aval, buf):
self._aval = aval
self._buf = buf
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
mutable_array_p = core.Primitive('mutable_array')

@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
aval = core.raise_to_shaped(core.get_aval(init_val))
return MutableArray(AbstractRef(aval), init_val)

def _error_on_staging(trace, x):
raise Exception
pe.custom_staging_rules[mutable_array_p] = _error_on_staging

from jax._src.interpreters import xla
from jax._src.interpreters import pxla
xla.canonicalize_dtype_handlers[MutableArray] = lambda x: x
xla.pytype_aval_mappings[MutableArray] = lambda x: x._aval
pxla.shard_arg_handlers[MutableArray] = lambda x, s: pxla.shard_arg(x._buf, s)
core.pytype_aval_mappings[MutableArray] = lambda x: x._aval

class MutableArrayTest(jtu.JaxTestCase):

def test_basic(self):
read = jax.jit(lambda x_ref: x_ref[...])

@jax.jit
@parameterized.parameters([True, False])
def test_basic(self, jit):
def f(x_mut):
x_mut[...] += 1.
x_mut[0] += 1
x_mut[1] += 5

x_mut = mutable_array(jnp.zeros(3))
if jit:
f = jax.jit(f)

x_mut = core.mutable_array(jnp.zeros(3))
f(x_mut)

self.assertAllClose(read(x_mut), jnp.array([2., 6., 1.]), check_dtypes=False)
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
check_dtypes=False)

jaxpr = jax.make_jaxpr(f)(x_mut)
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))

def test_staging_error(self):
x = jnp.zeros(3)
with self.assertRaises(Exception):
jax.jit(core.mutable_array)(x)

@parameterized.parameters([True, False])
def test_multiple_inputs_and_outputs(self, jit):
def f(x_mut, y, z_mut, w):
x_mut[...] += 1
z_mut[...] += 1
return x_mut[...] + y + z_mut[...] + w, y + w

if jit:
f = jax.jit(f)

x_mut = core.mutable_array(jnp.zeros((1, 3)))
y = jnp.ones((2, 3))
z_mut = core.mutable_array(jnp.zeros((2, 3)))
w = jnp.ones((2, 1))

out1, out2 = f(x_mut, y, z_mut, w)

self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
self.assertAllClose(out2, y + w, check_dtypes=False)


if CAN_USE_HYPOTHESIS:

