Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust the indices in _pjit_partial_eval to account for removed primals. #20270

Closed
wants to merge 1 commit into from

Conversation

nshepperd
Copy link

idx_map indexes the residuals that are also primals (ie. duplicated in the output of known_jaxpr), which are then deduplicated. But non-residual primal outputs of known_jaxpr can also be removed on account of being forwarded inputs, which alters the previously calculated primal indices. We need to adjust them to account for this.

Fixes #20267.

Copy link

google-cla bot commented Mar 15, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@nshepperd nshepperd force-pushed the fix-pjit branch 2 times, most recently from 101f738 to 23d60f5 Compare March 15, 2024 15:05
@yashk2810 yashk2810 requested a review from mattjj March 15, 2024 15:06
jax/_src/pjit.py Outdated Show resolved Hide resolved
@nshepperd
Copy link
Author

Some of the api_tests failed. Looks like because it's possible for residual outputs of known_jaxpr to also be primal outputs that are themselves forwarded inputs, which I didn't account for. Fixed by setting in_fwd instead of out_fwd for those.

@mattjj
Copy link
Collaborator

mattjj commented Mar 15, 2024

Thanks so much for looking at this! I'm in meetings for the next 90 mins but would like to review this, since it's scary to me that there was a bug here in the first place, so we should be sure to get it right this time.

I'd love to write more exhaustive tests too, covering all the cases...

@mattjj
Copy link
Collaborator

mattjj commented Mar 15, 2024

@nshepperd thanks so much for diving in to attack this bug!

However, I think we should merge #20273 instead of this PR. That PR does two things:

  1. Rewrites the logic to try to be more obviously correct. That may be a subjective thing, but I think refactoring it to do the two forwarding optimizations one after another, rather than trying to do them at the same time as before, makes it easier to see they won't interfere with each other.
  2. Disables the optimization entirely until we can add more tests. It was a mistake for me to merge this change without exhaustive tests in the first place. So I want to disable it until we can thoroughly test.

What do you think?

@nshepperd
Copy link
Author

Hehe, I guess you merged it. LGTM though. Good test cases!

@nshepperd nshepperd closed this Mar 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

grad(jit(f)) can fail when an output which is a forwarded input precedes an output which is also a residual
3 participants