refactor recomputation to work with tags #1615

Merged: 4 commits merged into main on Jan 8, 2025

Conversation

@t-vi (Collaborator) commented Jan 7, 2025

Another step of #1560

This refactors the recomputation (activation checkpointing):

  • recomputation now works with tags,
  • intermediates created by autograd decompositions are automatically tagged to be recomputed (this matches the behaviour of rematerialize_forward_backward, I think),
  • I needed to disable rematerialize_forward_backward because it ran into "infinite capacity"; however, I think it is no longer needed after this PR (cc @IvanYashchuk),
  • the uniform -> get_and_update_random_state + uniform_philox transform is moved to before the autograd transform,
  • random ops are guarded against being recomputed.

This is expected to be memory/compute neutral (I'll report numbers in a bit). It does not yet add the checkpointing frontend for the jit (including using memory comparable to eager checkpointing); that will be a separate PR.
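
To make the tag mechanism concrete, here is a minimal, framework-agnostic sketch of the idea; the Tag/Node names and helpers below are purely illustrative, not thunder's actual API. The forward trace marks cheap intermediates with a recompute tag, saved-for-backward keeps only the untagged values plus the inputs needed to rebuild the tagged ones, and the backward re-runs the tagged ops before computing gradients. Random ops are deliberately left untagged, since recomputing them would draw different numbers in the backward pass.

# Minimal, framework-agnostic sketch of tag-driven recomputation.
# Tag, Node, and the helper names are illustrative, NOT thunder's actual API.
import math
from enum import Enum, auto


class Tag(Enum):
    RECOMPUTE_IN_BACKWARD = auto()


class Node:
    """One op in a toy forward trace: output name, function, input names, tags."""
    def __init__(self, out, fn, args, tags=()):
        self.out, self.fn, self.args, self.tags = out, fn, args, set(tags)


def run_forward(trace, env):
    for node in trace:
        env[node.out] = node.fn(*(env[a] for a in node.args))
    return env


def saved_for_backward(trace, env, inputs):
    """Save the inputs plus every intermediate NOT tagged for recomputation."""
    saved = {name: env[name] for name in inputs}
    saved.update({n.out: env[n.out] for n in trace if Tag.RECOMPUTE_IN_BACKWARD not in n.tags})
    return saved


def recompute_in_backward(trace, saved):
    """Re-run the tagged ops (in trace order) before computing gradients."""
    env = dict(saved)
    for node in trace:
        if Tag.RECOMPUTE_IN_BACKWARD in node.tags:
            env[node.out] = node.fn(*(env[a] for a in node.args))
    return env


# y = exp(x); z = y * y.  exp is cheap to redo, so y is tagged for recomputation;
# a random op would never be tagged, since re-running it gives different values.
trace = [
    Node("y", math.exp, ("x",), tags=(Tag.RECOMPUTE_IN_BACKWARD,)),
    Node("z", lambda a, b: a * b, ("y", "y")),
]
env = run_forward(trace, {"x": 1.5})
saved = saved_for_backward(trace, env, inputs=("x",))  # keeps x and z, drops y
bw_env = recompute_in_backward(trace, saved)           # rebuilds y from x
assert abs(bw_env["y"] - env["y"]) < 1e-12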

@riccardofelluga could you take a look? (should be no surprises relative to #1560)

@t-vi t-vi requested review from mruberry and lantiga as code owners January 7, 2025 20:43
@t-vi (Collaborator, Author) commented Jan 7, 2025

So, is there a regression in memory use?

python thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --compile thunder --checkpoint_activations True --low_precision_mode none --micro_batch_size 1 --n_layer 4 --max_iters 3 --warmup_iters 2  --dump_thunder_traces True 

gives

  • Main: Average iter time: 1279.73 ms, Memory used: 13.27 GB
  • This PR: Average iter time: 1146.84 ms, Memory used: 13.29 GB

So ~10% faster and not much more memory (not sure where the extra 0.02 GB comes from...).

@t-vi (Collaborator, Author) commented Jan 7, 2025

So I'm not entirely happy with disabling rematerialize_forward_and_backward, but I wonder if the need arises from proxies that are recomputed and thus appear twice in the joint_extrace. (Hello, name collision.)
I'll try that either here or in a follow-up (likely by renaming all forward tensors to fw_... and all backward tensors to bw_..., or somesuch).
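
A rough sketch of that renaming idea, in case it helps the discussion; the trace layout (a list of (output_names, input_names) pairs) and the helper are made up for illustration and are not thunder's real trace API:

# Prefix every trace-local proxy name so forward and backward traces can no
# longer collide when stitched into one joint trace.  Names listed in
# `external_inputs` (e.g. saved-for-backward tensors shared between the two
# traces) keep their original names so cross-trace references stay valid.
def prefix_proxy_names(symbols, prefix, external_inputs):
    """symbols: list of (output_names, input_names) tuples in trace order."""
    rename, renamed = {}, []
    for outs, ins in symbols:
        # Inputs: use the renamed form if produced earlier in this trace,
        # keep external names as-is, and prefix anything else.
        new_ins = tuple(rename.get(n, n if n in external_inputs else prefix + n) for n in ins)
        new_outs = tuple(prefix + n for n in outs)
        rename.update(zip(outs, new_outs))
        renamed.append((new_outs, new_ins))
    return renamed


# "t2" in the forward trace and "t2" in the backward trace become distinct names:
fw = prefix_proxy_names([(("t2",), ("x",))], "fw_", external_inputs={"x"})
bw = prefix_proxy_names([(("t2",), ("fw_t2",))], "bw_", external_inputs={"fw_t2"})
assert fw == [(("fw_t2",), ("x",))] and bw == [(("bw_t2",), ("fw_t2",))]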

@riccardofelluga riccardofelluga self-requested a review January 8, 2025 09:52
@t-vi t-vi mentioned this pull request Jan 8, 2025
@t-vi (Collaborator, Author) commented Jan 8, 2025

Unfortunately, disabling the forward-backward-rematerialization adversely affects memory for Qwen, but happily we recover that when we do better checkpointing:

python thunder/benchmarks/benchmark_litgpt.py --model_name Qwen2.5-7B --compile thunder --checkpoint_activations True --low_precision_mode none --micro_batch_size 1 --n_layer 4 --max_iters 3 --warmup_iters 2 --block_size 4096 

  • Main: Average iter time: 788.88 ms, Memory used: 18.76 GB
  • This PR: Average iter time: 686.02 ms, Memory used: 20.77 GB
  • PR #1616 ("use tagging checkpointing"): Average iter time: 779.80 ms, Memory used: 18.34 GB
  • Eager instead of thunder: Average iter time: 803.95 ms, Memory used: 17.55 GB

@riccardofelluga (Collaborator) left a review comment:

Overall it looks great! Just a couple of nits and clarifications

Review threads (now resolved) on:
  • thunder/executors/torch_autograd.py
  • thunder/core/trace_interpreter.py
  • thunder/core/transforms.py (two threads)
@t-vi disabled auto-merge January 8, 2025 15:35
@t-vi merged commit e536ddc into main January 8, 2025 (38 of 41 checks passed)
@t-vi deleted the tom/recomputation-refactor branch January 8, 2025 15:35
Comment on lines +1900 to +1901
jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False)
jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True)
A collaborator commented:

What is the default value now?

@IvanYashchuk (Collaborator) commented:

> intermediates created by autograd decompositions are automatically tagged to be recomputed (this matches the behaviour of rematerialize_forward_backward, I think)

No, it doesn't match the behavior of rematerialize_forward_backward: what to recompute in fusion regions is decided by a min-cut-based algorithm. This PR introduced a regression in peak memory use (checked for Llama 2 7B): #1621.
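
For readers unfamiliar with the min-cut formulation, the rough idea is: each saveable tensor becomes an edge whose capacity is its memory cost, data-flow edges are uncuttable, the source feeds the forward inputs, and everything the backward needs feeds the sink; the minimum s-t cut then selects the cheapest set of tensors to save, and everything downstream of the cut gets recomputed. Below is a toy illustration with networkx on a made-up two-op graph; it is not thunder's actual implementation, which operates on fusion-region traces and is considerably more involved.

# Toy min-cut save/recompute decision (illustration only, not thunder's code).
import networkx as nx

sizes = {"x": 1, "a": 10, "b": 1}       # memory cost of saving each tensor
dataflow = [("x", "a"), ("a", "b")]     # a = f(x); b = g(a)
needed_by_backward = ["a", "b"]

G = nx.DiGraph()
for t, size in sizes.items():
    # Cutting the t_in -> t_out edge means "save t"; its capacity is t's size.
    G.add_edge(f"{t}_in", f"{t}_out", capacity=size)
for src, dst in dataflow:
    G.add_edge(f"{src}_out", f"{dst}_in")   # no capacity attr == infinite: dataflow can't be cut
G.add_edge("source", "x_in")                # forward inputs are always available
for t in needed_by_backward:
    G.add_edge(f"{t}_out", "sink")          # backward must be able to reach these

cut_value, (reachable, _) = nx.minimum_cut(G, "source", "sink")
saved = {n[:-3] for n in reachable if n.endswith("_in") and n[:-3] + "_out" not in reachable}
print(cut_value, saved)  # -> 1 {'x'}: cheapest plan saves only x and recomputes a and b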

> I needed to disable rematerialize_forward_backward because it ran into "infinite capacity".

The rematerialization code makes assumptions about its input traces in order to function; these assumptions were violated, resulting in the "infinite capacity" error. @riccardofelluga was hitting the same problem when working on the recomputation. The problems are supposed to be fixed by #1367. Riccardo, what's the current status of #1367?

@riccardofelluga (Collaborator) commented:

#1367 was parked due to a change of priorities; we could, though, bring it back here, adapting it to the new saved-for-backward logic of #1615. Then again, it could also be that the infinite-capacity error is reached through a different cause.
