Skip to content

Commit

Permalink
[pallas] run_scoped now supports partial discharge.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687347284
  • Loading branch information
cperivol authored and Google-ML-Automation committed Oct 18, 2024
1 parent ade480f commit f8a3c03
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
31 changes: 20 additions & 11 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,17 +857,22 @@ def _run_scoped_abstract_eval(*args, jaxpr):
return [v.aval for v in jaxpr.outvars], nonlocal_effects


def _run_scoped_discharge_rule(in_avals,
out_avals,
*args_flat,
jaxpr,
**_):
def _run_scoped_discharge_rule(
should_discharge,
in_avals,
out_avals,
*args_flat,
jaxpr,
**_):
del out_avals
num_consts = len(args_flat)
jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
num_return_values = len(jaxpr_noconst.outvars)
should_discharge = should_discharge + [
isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars
]
discharged_body, new_consts = state_discharge.discharge_state(
jaxpr_noconst, [])
jaxpr_noconst, [], should_discharge=should_discharge)
if new_consts:
raise NotImplementedError(
"Cannot handle new consts created by state discharge.")
Expand All @@ -886,23 +891,27 @@ def _run_scoped_discharge_rule(in_avals,
updates = [
ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef)
else None for aval in in_avals]
assert len(ref_outputs) == len(
body_avals), f'{len(body_avals)}, != {len(ref_outputs)}'
assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
return updates, return_values


state_discharge.register_discharge_rule(run_scoped_p)(
state_discharge.register_partial_discharge_rule(run_scoped_p)(
_run_scoped_discharge_rule)


@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr):
# This lowering rule gets triggered when run_scoped is not discharged.
# In this case there are no stateful effects to handle.
should_discharge = [
isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
]

def _lower_fun(*lower_fun_args):
updates, out = _run_scoped_discharge_rule([], [], *lower_fun_args,
jaxpr=jaxpr)
updates, out = _run_scoped_discharge_rule(
should_discharge,
[], [], *lower_fun_args,
jaxpr=jaxpr)
assert len(updates) == 0, 'Cannot lower run_scoped with effects.'
return out
return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)
22 changes: 22 additions & 0 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from jax._src.lib import xla_extension
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax._src.state import utils as state_utils
from jax._src.state import discharge as state_discharge
from jax.experimental import mesh_utils
from jax.experimental import mosaic
from jax.experimental import pallas as pl
Expand Down Expand Up @@ -860,6 +861,27 @@ def inner_body(z_ref):
)()
np.testing.assert_allclose(o, 4 * np.ones_like(o))

def test_run_scoped_partial_discharge(self):
def f(a_ref, b_ref):
def scope():
a_ref[...] = jnp.ones(4, jnp.float32)
b_ref[...] = jnp.ones(4, jnp.float32)
return []
pl.run_scoped(scope)
return []

aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
in_avals = [aref, aref]
stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = state_discharge.discharge_state(
stateful_jaxpr, consts=(), should_discharge=[False, True])
self.assertLen(discharged_jaxpr.invars, 2)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef)
self.assertIsInstance(discharged_jaxpr.invars[1].aval, jax.core.ShapedArray)
self.assertEqual(discharged_jaxpr.effects, {state.WriteEffect(0)})

def test_can_allocate_semaphore(self):
def kernel(y_ref):
def body(sem1):
Expand Down

0 comments on commit f8a3c03

Please sign in to comment.