From 46a00b936050505581b51bfe7d6cbf4adadcd49b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 15 Mar 2024 10:14:57 -0700 Subject: [PATCH] disable optimization --- jax/_src/pjit.py | 66 +++++++++++++++++++++++++---------------------- tests/api_test.py | 9 +++++++ 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a3eeb61f1403..1a3d948dea6b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1808,34 +1808,37 @@ def _pjit_partial_eval(trace, *in_tracers, def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) - # Input-to-output forwarding: compute which outputs are just forwarded inputs. - num_out_primals = len(known_jaxpr.out_avals) - num_residuals - in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) - # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. - in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) - in_fwd = [fwd if is_unspecified(os) else None for os, fwd in - zip(keep_where(out_shardings, known_outs), in_fwd_primal) - ] + in_fwd_res - del in_fwd_primal, in_fwd_res - # Prune jaxpr outputs and out_shardings by removing the input-forwards. - keep = [f is None for f in in_fwd] - known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings - known_out_shardings = keep_where(known_out_shardings, keep) - # Update num_out_primals to reflect pruning. - kept_primals, kept_res = split_list(keep, [num_out_primals]) - num_out_primals = sum(f is None for f in kept_primals) - del keep, kept_primals, kept_res - - # Output-to-output forwarding: compute which residuals are just primal outputs - out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals]) - idx_map = {id(v): i for i, v in enumerate(out_vars)} - out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars] - # Prune jaxpr outputs and out_shardings by removing forwarded residuals. - keep = [f is None for f in out_fwd] - known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) - known_out_shardings = keep_where(known_out_shardings, keep) - del keep + + # TODO(mattjj): un-disable this optimization after we have more tests + # # Input-to-output forwarding: compute which outputs are just forwarded inputs. + # num_out_primals = len(known_jaxpr.out_avals) - num_residuals + # in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) + # # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. + # in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) + # in_fwd = [fwd if is_unspecified(os) else None for os, fwd in + # zip(keep_where(out_shardings, known_outs), in_fwd_primal) + # ] + in_fwd_res + # del in_fwd_primal, in_fwd_res + # # Prune jaxpr outputs and out_shardings by removing the input-forwards. + # keep = [f is None for f in in_fwd] + # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) + # known_out_shardings = keep_where(known_out_shardings, keep) + # # Update num_out_primals to reflect pruning. + # kept_primals, kept_res = split_list(keep, [num_out_primals]) + # num_out_primals = sum(f is None for f in kept_primals) + # del keep, kept_primals, kept_res + + # TODO(mattjj): un-disable this optimization after we have more tests + # # Output-to-output forwarding: compute which residuals are just primal outputs + # out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals]) + # idx_map = {id(v): i for i, v in enumerate(out_vars)} + # out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars] + # # Prune jaxpr outputs and out_shardings by removing forwarded residuals. + # keep = [f is None for f in out_fwd] + # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) + # known_out_shardings = keep_where(known_out_shardings, keep) + # del keep known_params = dict( jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins), @@ -1847,10 +1850,11 @@ def keep_where(l, should_keep): # Bind known things to pjit_p. known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()] all_known_outs = pjit_p.bind(*known_inputs, **known_params) - # Add back in the output fwds. - all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs) - # Add back in the input fwds. - all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) + # TODO(mattjj): un-disable this optimization after we have more tests + # # Add back in the output fwds. + # all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs) + # # Add back in the input fwds. + # all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) known_out_vals, residual_vals = \ split_list(all_known_outs, [len(all_known_outs) - num_residuals]) diff --git a/tests/api_test.py b/tests/api_test.py index 9b177e61aade..6db67205534d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4565,6 +4565,15 @@ def foo(self): gc.collect() assert a() is None + def test_forwarding_bug(self): + # Test for issue #20267. + def f(x): + @jax.jit + def inner(a, x): + return a, jnp.exp(x) + return inner(0., x)[0] + jax.grad(f)(1.) # don't crash + class RematTest(jtu.JaxTestCase):