Skip to content

Commit

Permalink
maybe fix forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 15, 2024
1 parent cdafb8f commit 8e8e297
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists, flatten, unflatten, subs_list2)
merge_lists, flatten, unflatten, subs_list)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -1798,8 +1798,8 @@ def _pjit_partial_eval(trace, *in_tracers,

known_ins = tuple(pv.is_known() for pv in in_pvals)
unknown_ins = tuple(not k for k in known_ins)
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits(
jaxpr, unknown_ins, instantiate=False)
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
unknown_outs = tuple(unknown_outs)
known_outs = tuple(not uk for uk in unknown_outs)
num_residuals = len(res_avals)
Expand All @@ -1808,28 +1808,34 @@ def _pjit_partial_eval(trace, *in_tracers,
def keep_where(l, should_keep):
return tuple(x for x, keep in zip(l, should_keep) if keep)

# Compute which outputs are just forwarded inputs.
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr)

in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
] + in_fwd_res
del in_fwd_primal, in_fwd_res
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
keep = [f is None for f in in_fwd]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
known_out_shardings = keep_where(known_out_shardings, keep)
# Update num_out_primals to reflect pruning.
kept_primals, kept_res = split_list(keep, [num_out_primals])
num_out_primals = sum(f is None for f in kept_primals)
del keep, kept_primals, kept_res

# Compute which residuals are just primal outputs.
# Output-to-output forwarding: compute which residuals are just primal outputs
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]

# Prune jaxpr outputs and out_shardings by removing forwards.
keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
# Prune jaxpr outputs and out_shardings by removing forwarded residuals.
keep = [f is None for f in out_fwd]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
known_out_shardings = keep_where(known_out_shardings, keep)
del keep, num_out_primals
del keep

known_params = dict(
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
Expand All @@ -1841,16 +1847,18 @@ def keep_where(l, should_keep):
# Bind known things to pjit_p.
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
all_known_outs = subs_list2(in_fwd, out_fwd, known_inputs, all_known_outs,
all_known_outs)
# Add back in the output fwds.
all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
# Add back in the input fwds.
all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)

known_out_vals, residual_vals = \
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
residual_tracers = map(trace.new_instantiated_const, residual_vals)

# The convention of partial_eval_jaxpr_nounits is to place residual binders
# at the front of the jaxpr produced, so we move them to the back since both
# the jaxpr equation built below and the pjit transpose rule assume a
# The convention of partial_eval_jaxpr_nounits is to place residual binders at
# the front of the jaxpr produced, so we move them to the back since both the
# jaxpr equation built below and the pjit transpose rule assume a
# residual-inputs-last convention.
unknown_jaxpr = pe.move_binders_to_back(
unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))
Expand Down

0 comments on commit 8e8e297

Please sign in to comment.