From f8a3c0366b364fbbbd0fa52bd65f364b5d307148 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Fri, 18 Oct 2024 10:21:53 -0700 Subject: [PATCH] [pallas] run_scoped now supports partial discharge. PiperOrigin-RevId: 687347284 --- jax/_src/pallas/primitives.py | 31 ++++++++++++++++++++----------- tests/pallas/tpu_pallas_test.py | 22 ++++++++++++++++++++++ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 9e446917b896..97655b8dff4b 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -857,17 +857,22 @@ def _run_scoped_abstract_eval(*args, jaxpr): return [v.aval for v in jaxpr.outvars], nonlocal_effects -def _run_scoped_discharge_rule(in_avals, - out_avals, - *args_flat, - jaxpr, - **_): +def _run_scoped_discharge_rule( + should_discharge, + in_avals, + out_avals, + *args_flat, + jaxpr, + **_): del out_avals num_consts = len(args_flat) jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) num_return_values = len(jaxpr_noconst.outvars) + should_discharge = should_discharge + [ + isinstance(var.aval, state.AbstractRef) for var in jaxpr.invars + ] discharged_body, new_consts = state_discharge.discharge_state( - jaxpr_noconst, []) + jaxpr_noconst, [], should_discharge=should_discharge) if new_consts: raise NotImplementedError( "Cannot handle new consts created by state discharge.") @@ -886,13 +891,11 @@ def _run_scoped_discharge_rule(in_avals, updates = [ ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef) else None for aval in in_avals] - assert len(ref_outputs) == len( - body_avals), f'{len(body_avals)}, != {len(ref_outputs)}' assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}' return updates, return_values -state_discharge.register_discharge_rule(run_scoped_p)( +state_discharge.register_partial_discharge_rule(run_scoped_p)( _run_scoped_discharge_rule) @@ -900,9 +903,15 @@ def _run_scoped_discharge_rule(in_avals, def _run_scoped_lowering_rule(ctx, *args, jaxpr): # This lowering rule gets triggered when run_scoped is not discharged. # In this case there are no stateful effects to handle. + should_discharge = [ + isinstance(aval, state.AbstractRef) for aval in ctx.avals_in + ] + def _lower_fun(*lower_fun_args): - updates, out = _run_scoped_discharge_rule([], [], *lower_fun_args, - jaxpr=jaxpr) + updates, out = _run_scoped_discharge_rule( + should_discharge, + [], [], *lower_fun_args, + jaxpr=jaxpr) assert len(updates) == 0, 'Cannot lower run_scoped with effects.' return out return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 3bab5bc88373..49dd127b76fe 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -32,6 +32,7 @@ from jax._src.lib import xla_extension from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax._src.state import utils as state_utils +from jax._src.state import discharge as state_discharge from jax.experimental import mesh_utils from jax.experimental import mosaic from jax.experimental import pallas as pl @@ -860,6 +861,27 @@ def inner_body(z_ref): )() np.testing.assert_allclose(o, 4 * np.ones_like(o)) + def test_run_scoped_partial_discharge(self): + def f(a_ref, b_ref): + def scope(): + a_ref[...] = jnp.ones(4, jnp.float32) + b_ref[...] = jnp.ones(4, jnp.float32) + return [] + pl.run_scoped(scope) + return [] + + aref = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32'))) + in_avals = [aref, aref] + stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), + in_avals) + discharged_jaxpr, _ = state_discharge.discharge_state( + stateful_jaxpr, consts=(), should_discharge=[False, True]) + self.assertLen(discharged_jaxpr.invars, 2) + self.assertLen(discharged_jaxpr.outvars, 1) + self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef) + self.assertIsInstance(discharged_jaxpr.invars[1].aval, jax.core.ShapedArray) + self.assertEqual(discharged_jaxpr.effects, {state.WriteEffect(0)}) + def test_can_allocate_semaphore(self): def kernel(y_ref): def body(sem1):