Perf Regression: SDPA is recomputed in backward #1646
Note that unless you use the (legacy?) options, only tensors tagged as RECOMPUTE_IN_BACKWARD get recomputed. Is the regression here from recomputing intermediates in decompositions?
lightning-thunder/thunder/core/trace_interpreter.py Lines 188 to 191 in 21d9319
I would not want to have […]. I basically see two ways: […]
WDYT?
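For context on what "recomputing intermediates in decompositions" means here: when SDPA is decomposed into primitive ops, each intermediate is a separate proxy that the recompute pass can tag on its own, even if the fused SDPA output itself is saved. A minimal, self-contained sketch of the standard decomposition (illustrative code, not Thunder's):

```python
import math
import torch

def sdpa_decomposed(q, k, v):
    # Standard SDPA decomposition: each line produces an intermediate
    # that a recompute pass could tag independently.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) * scale  # intermediate
    probs = torch.softmax(scores, dim=-1)       # intermediate
    return probs @ v

q = k = v = torch.randn(1, 8, 16, 64)
out = sdpa_decomposed(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```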
lightning-thunder/thunder/core/trace_interpreter.py Lines 185 to 191 in 21d9319
I did a quick check, and one of the outputs of cudnn is not present in the swap map, so it is tagged for recompute: […]
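To make that concrete, here is a small self-contained model of the rule in the lines above. The names (`swap_map`, the tag string) mirror trace_interpreter.py but are simplified, so treat this as an illustration rather than the real code:

```python
RECOMPUTE_IN_BACKWARD = "recompute_in_backward"

def tag_missing_outputs(outputs, swap_map, tags):
    # Any forward output with no entry in the swap map gets tagged
    # for recompute in backward.
    for out in outputs:
        if out not in swap_map:
            tags.setdefault(out, set()).add(RECOMPUTE_IN_BACKWARD)
    return tags

# cudnn SDPA returns several outputs; if one of them (e.g. the saved
# softmax stats) is missing from the swap map, it is tagged, and the
# whole SDPA forward ends up re-executed in the backward trace.
tags = tag_missing_outputs(
    outputs=["attn_out", "softmax_stats"],
    swap_map={"attn_out": "attn_out_0"},
    tags={},
)
print(tags)  # {'softmax_stats': {'recompute_in_backward'}}
```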
I think it would be safer to have […]
Since #1615, recomputing in backward is on by default.
lightning-thunder/thunder/core/transforms.py Lines 3160 to 3164 in 21d9319
This means that we may end up recomputing compute-heavy ops like SDPA, leading to a perf regression.
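The trade-off is the same one that torch.utils.checkpoint makes explicit: re-running the forward during backward saves activation memory but costs a second forward pass, which is a poor bargain for compute-heavy ops like SDPA. A runnable PyTorch sketch of the two behaviors (standard PyTorch APIs, not Thunder's recompute machinery):

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

q = torch.randn(1, 8, 128, 64, requires_grad=True)
k = torch.randn(1, 8, 128, 64, requires_grad=True)
v = torch.randn(1, 8, 128, 64, requires_grad=True)

# Plain call: backward reuses saved intermediates, no extra forward.
out = F.scaled_dot_product_attention(q, k, v)
out.sum().backward()

# Checkpointed call: SDPA is re-executed during backward to save
# memory -- the same extra forward pass that shows up as
# cudnn_sdpa_fwd in the backward trace below.
q.grad = k.grad = v.grad = None
out_ckpt = checkpoint(F.scaled_dot_product_attention, q, k, v,
                      use_reentrant=False)
out_ckpt.sum().backward()
```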
Running:
python thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-7b-hf --compile dynamo_thunder --micro_batch_size 1 --dump_thunder_traces true
Current Main
With patch to disable recomputing SDPA
I think we should benchmark recomputing for both performance and memory usage across multiple models, and only then decide, based on those numbers, whether it should be on by default.
Also, we should compile a list of compute-heavy operations for which recomputing might not be the best choice and tag them accordingly.
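As a starting point, one hedged sketch of what such a denylist could look like; the helper and most op names here are hypothetical (only cudnn_sdpa_fwd appears in the trace below), not Thunder's API:

```python
# Hypothetical denylist of compute-heavy ops whose recomputation is
# likely a net loss; cheap, memory-bound ops stay eligible.
COMPUTE_HEAVY_OPS = {
    "cudnn_sdpa_fwd",  # from the backward trace below
    "matmul",          # illustrative
    "convolution",     # illustrative
}

def should_recompute(op_name: str) -> bool:
    """Recompute only ops that are cheap relative to saving their output."""
    return op_name not in COMPUTE_HEAVY_OPS

print(should_recompute("cudnn_sdpa_fwd"))  # False
print(should_recompute("add"))             # True
```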
Patch to disable recomputing for SDPA
Sample backward trace for a 1-layer model (note that cudnn_sdpa_fwd appears in the trace)
cc: @t-vi @IvanYashchuk @riccardofelluga
cc @t-vi @lantiga @tfogal