diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a5631573707b..a3eeb61f1403 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -76,7 +76,7 @@ from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, weakref_lru_cache, - merge_lists, flatten, unflatten, subs_list2) + merge_lists, flatten, unflatten, subs_list) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -1798,8 +1798,8 @@ def _pjit_partial_eval(trace, *in_tracers, known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits( - jaxpr, unknown_ins, instantiate=False) + known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ + pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) known_outs = tuple(not uk for uk in unknown_outs) num_residuals = len(res_avals) @@ -1808,28 +1808,34 @@ def _pjit_partial_eval(trace, *in_tracers, def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) - # Compute which outputs are just forwarded inputs. + # Input-to-output forwarding: compute which outputs are just forwarded inputs. num_out_primals = len(known_jaxpr.out_avals) - num_residuals - in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr) - + in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) in_fwd = [fwd if is_unspecified(os) else None for os, fwd in zip(keep_where(out_shardings, known_outs), in_fwd_primal) ] + in_fwd_res del in_fwd_primal, in_fwd_res + # Prune jaxpr outputs and out_shardings by removing the input-forwards. 
+ keep = [f is None for f in in_fwd] + known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) + known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings + known_out_shardings = keep_where(known_out_shardings, keep) + # Update num_out_primals to reflect pruning: count the kept primal outputs. + kept_primals, kept_res = split_list(keep, [num_out_primals]) + num_out_primals = sum(kept_primals) + del keep, kept_primals, kept_res - # Compute which residuals are just primal outputs. + # Output-to-output forwarding: compute which residuals are just primal outputs out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals]) idx_map = {id(v): i for i, v in enumerate(out_vars)} out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars] - - # Prune jaxpr outputs and out_shardings by removing forwards. - keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] + # Prune jaxpr outputs and out_shardings by removing forwarded residuals. + keep = [f is None for f in out_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) - known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings known_out_shardings = keep_where(known_out_shardings, keep) - del keep, num_out_primals + del keep known_params = dict( jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins), @@ -1841,16 +1847,18 @@ def keep_where(l, should_keep): # Bind known things to pjit_p. known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()] all_known_outs = pjit_p.bind(*known_inputs, **known_params) - all_known_outs = subs_list2(in_fwd, out_fwd, known_inputs, all_known_outs, - all_known_outs) + # Add back in the output fwds. + all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs) + # Add back in the input fwds. 
+ all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) known_out_vals, residual_vals = \ split_list(all_known_outs, [len(all_known_outs) - num_residuals]) residual_tracers = map(trace.new_instantiated_const, residual_vals) - # The convention of partial_eval_jaxpr_nounits is to place residual binders - # at the front of the jaxpr produced, so we move them to the back since both - # the jaxpr equation built below and the pjit transpose rule assume a + # The convention of partial_eval_jaxpr_nounits is to place residual binders at + # the front of the jaxpr produced, so we move them to the back since both the + # jaxpr equation built below and the pjit transpose rule assume a # residual-inputs-last convention. unknown_jaxpr = pe.move_binders_to_back( unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))