Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust the indices in _pjit_partial_eval to account for removed primals. #20270

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,12 +1824,27 @@ def keep_where(l, should_keep):
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]

# Rarely, residuals can be primal outputs that are themselves forwarded inputs.
# In this case, set in_fwd for them.
for i in range(len(out_fwd)):
j = out_fwd[i]
if j is not None and in_fwd[j] is not None:
in_fwd[i] = in_fwd[j]
out_fwd[i] = None

# Prune jaxpr outputs and out_shardings by removing forwards.
keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
known_out_shardings = keep_where(known_out_shardings, keep)
del keep, num_out_primals

# Adjust the indices in out_fwd to account for the removed outputs.
kept = {}
for i in range(len(keep)):
if keep[i]:
kept[i] = len(kept)
out_fwd = [None if i is None else kept[i] for i in out_fwd]
del keep, kept, num_out_primals

known_params = dict(
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
Expand Down
9 changes: 9 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,15 @@ def testAutodiff(self, mesh, resources):
)
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)

def testAutodiffPrimals(self):
# Test for issue #20267.
def f(x):
@jax.jit
def inner(a, x):
return a, jnp.exp(x)
return inner(0., x)[0]
jax.grad(f)(1.)

@jtu.with_mesh([('x', 2), ('y', 1)])
def testAutodiffCache(self):
f = pjit(
Expand Down