
update checkpointing support for jit #1560

Open

t-vi wants to merge 22 commits into main from tom/checkpointing-memory
Conversation

@t-vi (Collaborator) commented Dec 17, 2024

I'll add switches and testing of checkpointing, but here are the material code changes.
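
For context, a minimal sketch of the kind of usage this is meant to support, assuming the standard torch.utils.checkpoint API and thunder.jit; the Block module, sizes, and inputs below are made up for illustration and are not part of the PR:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

import thunder


class Block(nn.Module):
    # Hypothetical block used only to illustrate activation checkpointing.
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        # Recompute this block's activations during backward instead of keeping them alive.
        return checkpoint(self.ff, x, use_reentrant=False)


model = Block(256)
jmodel = thunder.jit(model)  # the jitted module should trace through the checkpoint call

x = torch.randn(8, 256, requires_grad=True)
jmodel(x).sum().backward()
```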

@t-vi (Collaborator, Author) commented Dec 17, 2024

There are a number of things still to be fixed.
There is a failure with the memory calculation because we don't update the initial collection proxy.
I'll look into these.


@t-vi (Collaborator, Author) commented Dec 18, 2024

To my mind, the remaining missing bit is the handling of the uniform -> uniform_philox conversion (currently handled by rematerialization) and its interaction with recomputation. I'll look into it. In the meantime, I'd be keen to hear complaints and/or success stories about the memory / performance impact.
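
For anyone wanting to report numbers, a generic way to capture the peak memory of one step (plain PyTorch, not part of this PR; the tiny model and batch are stand-ins so the pattern is runnable, substitute your jitted model and real data):

```python
import torch
import torch.nn as nn

# Stand-in model and batch; replace with the jitted model and a real micro-batch.
model = nn.Linear(1024, 1024).cuda()
batch = torch.randn(64, 1024, device="cuda")

torch.cuda.reset_peak_memory_stats()
loss = model(batch).sum()
loss.backward()
torch.cuda.synchronize()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```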

@t-vi t-vi force-pushed the tom/checkpointing-memory branch from e103484 to 891bf84 Compare December 18, 2024 14:38
@riccardofelluga (Collaborator) commented
I was trying to check memory savings but it looks like the following hangs:

python thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --compile thunder --checkpoint_activations True --low_precision_mode none --micro_batch_size 1

:(

@t-vi mentioned this pull request Dec 19, 2024
@t-vi (Collaborator, Author) commented Dec 19, 2024

So one needs to enable checkpointing of layers with compiler==thunder for this. Even then, the memory profile of the backward is still terrible.
Running @riccardofelluga's benchmark: eager with checkpointing and 8 layers needs 12.59 GB, while thunder with checkpointing and 8 layers needs 38 GB.
This is not explained by the tensors saved for backward, which for thunder are: saved for backward size: 1761.89 MiB, number of tensors: 103.

I think we need to look more closely at the memory over time.
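
One way to look at memory over time is PyTorch's allocator history snapshot; these are private helpers under torch.cuda.memory and may change between releases, so treat this as a sketch. `train_step` is a placeholder for one benchmark iteration:

```python
import torch

# Record allocator events, run a few steps, then dump a snapshot that can be
# loaded in the viewer at https://pytorch.org/memory_viz.
torch.cuda.memory._record_memory_history(max_entries=100_000)

for _ in range(3):
    train_step()  # placeholder for one forward/backward/optimizer step

torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```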

@t-vi (Collaborator, Author) commented Dec 19, 2024

Doing the following:

  • apply del_last_used to every trace in last_backward_traces
  • run the memory examine tool (so awesome @kiya00 !)

we see that we are doing much better at first than after the transform for execution (this is for 4 layers), so we still have reordering that hurts us:

#the following trace uses ~9.84GB memory
# Constructed by Saved for backward remat trace (took 20.46 milliseconds)
#the following trace uses ~12.39GB memory
# Constructed by Transform for execution (took 886 milliseconds)

Note that the difference of 2.55GB is smaller than the difference between thunder and eager (5.82GB).
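
For reference, a sketch of how to pull those traces out for inspection, assuming thunder's last_traces / last_backward_traces helpers behave as documented; the small Linear model here is only an illustration:

```python
import torch
import torch.nn as nn
import thunder

# After a thunder.jit-ed module has run forward and backward, its traces are
# recorded; the last entry of each list is the trace that actually executed.
model = nn.Linear(64, 64)
jmodel = thunder.jit(model)
jmodel(torch.randn(2, 64, requires_grad=True)).sum().backward()

print(thunder.last_traces(jmodel)[-1])           # final executed forward trace
print(thunder.last_backward_traces(jmodel)[-1])  # final executed backward trace
```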

@t-vi t-vi force-pushed the tom/checkpointing-memory branch from ebf420c to b1604ee Compare December 20, 2024 21:50
@t-vi (Collaborator, Author) commented Dec 21, 2024

With the two latest bits, I have that

python thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --compile thunder --checkpoint_activations True --low_precision_mode none --micro_batch_size 1 --n_layer 4

is on par with (even a little below) the same command with --compile eager.
I will be looking at the failing tests and splitting out some bits that can be handled independently.
