Skip to content

Commit

Permalink
disable optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 15, 2024
1 parent c515f15 commit 46a00b9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 31 deletions.
66 changes: 35 additions & 31 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,34 +1808,37 @@ def _pjit_partial_eval(trace, *in_tracers,
def keep_where(l, should_keep):
return tuple(x for x, keep in zip(l, should_keep) if keep)

# Input-to-output forwarding: compute which outputs are just forwarded inputs.
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
] + in_fwd_res
del in_fwd_primal, in_fwd_res
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
keep = [f is None for f in in_fwd]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
known_out_shardings = keep_where(known_out_shardings, keep)
# Update num_out_primals to reflect pruning.
kept_primals, kept_res = split_list(keep, [num_out_primals])
num_out_primals = sum(f is None for f in kept_primals)
del keep, kept_primals, kept_res

# Output-to-output forwarding: compute which residuals are just primal outputs
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
# Prune jaxpr outputs and out_shardings by removing forwarded residuals.
keep = [f is None for f in out_fwd]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(known_out_shardings, keep)
del keep

# TODO(mattjj): un-disable this optimization after we have more tests
# # Input-to-output forwarding: compute which outputs are just forwarded inputs.
# num_out_primals = len(known_jaxpr.out_avals) - num_residuals
# in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
# # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
# in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
# in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
# zip(keep_where(out_shardings, known_outs), in_fwd_primal)
# ] + in_fwd_res
# del in_fwd_primal, in_fwd_res
# # Prune jaxpr outputs and out_shardings by removing the input-forwards.
# keep = [f is None for f in in_fwd]
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
# known_out_shardings = keep_where(known_out_shardings, keep)
# # Update num_out_primals to reflect pruning.
# kept_primals, kept_res = split_list(keep, [num_out_primals])
# num_out_primals = sum(f is None for f in kept_primals)
# del keep, kept_primals, kept_res

# TODO(mattjj): un-disable this optimization after we have more tests
# # Output-to-output forwarding: compute which residuals are just primal outputs
# out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
# idx_map = {id(v): i for i, v in enumerate(out_vars)}
# out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
# # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
# keep = [f is None for f in out_fwd]
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
# known_out_shardings = keep_where(known_out_shardings, keep)
# del keep

known_params = dict(
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
Expand All @@ -1847,10 +1850,11 @@ def keep_where(l, should_keep):
# Bind known things to pjit_p.
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
# Add back in the output fwds.
all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
# Add back in the input fwds.
all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
# TODO(mattjj): un-disable this optimization after we have more tests
# # Add back in the output fwds.
# all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
# # Add back in the input fwds.
# all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)

known_out_vals, residual_vals = \
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
Expand Down
9 changes: 9 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4565,6 +4565,15 @@ def foo(self):
gc.collect()
assert a() is None

def test_forwarding_bug(self):
# Test for issue #20267.
def f(x):
@jax.jit
def inner(a, x):
return a, jnp.exp(x)
return inner(0., x)[0]
jax.grad(f)(1.) # don't crash


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 46a00b9

Please sign in to comment.