Skip to content

Commit

Permalink
prototype 'remat' in hlo metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Feb 16, 2024
1 parent cfcae37 commit 08dfb11
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
7 changes: 4 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,9 @@ def transposed(*args_flat):
pe.PartialVal.known(next(ins_iter))
for aval, lin in zip(jaxpr.in_avals, in_lin)]
assert next(ins_iter, None) is None
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)
with source_info_util.extend_name_stack('rematted_computation'):
lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False)

# Transpose the linear jaxpr (which only has linear inputs).
out_cts_iter = iter(out_cts_flat)
Expand Down Expand Up @@ -696,7 +697,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

return api.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)
return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)

def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
Expand Down
13 changes: 13 additions & 0 deletions tests/name_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax import config
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src import ad_checkpoint

config.parse_flags_with_absl()

Expand Down Expand Up @@ -267,6 +268,18 @@ def g(y):
self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text)
self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text)

def test_remat_appears_in_hlo(self):
@ad_checkpoint.remat
def f(x):
return jnp.sin(x)

hlo_text = _get_hlo(f)(2.)
hlo_text_grad = _get_hlo(jax.grad(f))(2.)
self.assertNotIn('rematted_computation', hlo_text)
self.assertNotIn('remat', hlo_text)
self.assertIn('checkpoint', hlo_text)
self.assertIn('rematted_computation', hlo_text_grad)


class NameStackControlFlowTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 08dfb11

Please sign in to comment.