Skip to content

Commit

Permalink
Add an extension mechanism to run_state that allows:
Browse files Browse the repository at this point in the history
* Uninitialized values
* Custom ref aval construction

This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.

PiperOrigin-RevId: 687287590
  • Loading branch information
sharadmv authored and Google-ML-Automation committed Oct 19, 2024
1 parent 884f1dc commit 89ca5d4
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 48 deletions.
9 changes: 9 additions & 0 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import source_info_util
from jax._src import state
from jax._src import traceback_util
from jax._src import pjit
from jax._src import xla_bridge as xb
Expand All @@ -69,6 +70,7 @@
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.state import types as state_types
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
Expand Down Expand Up @@ -2555,6 +2557,13 @@ def _sds_aval_mapping(x):
weak_type=x.weak_type)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping

def _sdstruct_ref_type(x: ShapeDtypeStruct) -> tuple[state.AbstractRef, basearray.Array]:
# Just initialize it with zeros as a reasonable starting point
uninitialized = lax_internal.full(x.shape, 0, x.dtype)
return state.AbstractRef(core.ShapedArray(x.shape, x.dtype)), uninitialized
state_types._ref_type_aval_mappings[ShapeDtypeStruct] = _sdstruct_ref_type



@api_boundary
def eval_shape(fun: Callable, *args, **kwargs):
Expand Down
Loading

0 comments on commit 89ca5d4

Please sign in to comment.