diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index f397b9ee8ca9..78318a06dfe5 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -630,8 +630,9 @@ def transposed(*args_flat): pe.PartialVal.known(next(ins_iter)) for aval, lin in zip(jaxpr.in_avals, in_lin)] assert next(ins_iter, None) is None - lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits( - lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False) + with source_info_util.extend_name_stack('rematted_computation'): + lin_jaxpr, _, consts = pe.trace_to_jaxpr_nounits( + lu.wrap_init(core.jaxpr_as_fun(jaxpr)), in_pvals, False) # Transpose the linear jaxpr (which only has linear inputs). out_cts_iter = iter(out_cts_flat) @@ -697,7 +698,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool, else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) - return api.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr) + return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = _optimization_barrier(args) diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index 23351659a017..e6ac29e7088b 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -23,6 +23,7 @@ from jax import config from jax._src import test_util as jtu from jax._src.lib import xla_client +from jax._src import ad_checkpoint config.parse_flags_with_absl() @@ -267,6 +268,18 @@ def g(y): self.assertIn('jvp(pjit(f))/pjit(g)/cos', hlo_text) self.assertIn('transpose(jvp(pjit(f)))/pjit(g)/mul', hlo_text) + def test_remat_appears_in_hlo(self): + @ad_checkpoint.remat + def f(x): + return jnp.sin(x) + + hlo_text = _get_hlo(f)(2.) + hlo_text_grad = _get_hlo(jax.grad(f))(2.) + self.assertNotIn('rematted_computation', hlo_text) + self.assertNotIn('remat', hlo_text) + self.assertIn('checkpoint', hlo_text) + self.assertIn('rematted_computation', hlo_text_grad) + class NameStackControlFlowTest(jtu.JaxTestCase):