A converter for FXGraph with Torch calls -> FXGraph with Thunder calls #1261
Conversation
Overall looks good, I just have one main comment related to the information being saved in SubgraphInfo. Curious to hear your thoughts on it. Thanks @kiya00!
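For context, here is a minimal sketch of how one might inspect what the Dynamo backend records per subgraph. It assumes ThunderCompiler exposes its recorded SubgraphInfo objects through a subgraph_infos attribute (as in thunder.dynamo at the time of this PR); the tiny model is purely illustrative, and the exact fields of SubgraphInfo should be checked against the dataclass definition in the code.

```python
# Hedged sketch: inspect what the Dynamo backend recorded for each FX subgraph.
# `subgraph_infos` and the contents of SubgraphInfo are assumptions based on
# thunder.dynamo at the time of this PR; verify against the actual dataclass.
import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()
model = torch.compile(torch.nn.Linear(8, 8), backend=backend)
model(torch.randn(2, 8))

for info in backend.subgraph_infos:  # one SubgraphInfo per Dynamo graph
    # Print the field names rather than assuming specific attributes.
    print(type(info).__name__, sorted(vars(info).keys()))
```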
The pull request is in a good state already. Before merging, let's ensure it enables the use case we're interested in. Please add the following patch:

```diff
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index afb9d79f..6a415499 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -532,7 +532,7 @@ class Benchmark_litGPT:
         return model

     def setup_activation_checkpointing(self):
-        if "thunder" in self.compile:
+        if "thunder" in self.compile and not "dynamo" in self.compile:
             # checkpointing is an option to thunder.jit
             return

@@ -571,11 +571,6 @@ class Benchmark_litGPT:
             executors.insert(0, transformer_engine_ex)

-        jit_options = {
-            "enable_saved_for_backward_recomputation": self.checkpoint_activations,
-            "recomputation_policy": None,
-        }
-
         if "dynamo" in self.compile:
             if self.distributed_mode == "fsdp2":
                 print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
@@ -583,13 +578,17 @@
                 dynamo_config.cache_size_limit = 64

-            backend = ThunderCompiler(executors=executors, **jit_options)
+            backend = ThunderCompiler(executors=executors)
             # Because Lightning Fabric is imported in this script it monkey patches the torch.compile function
             # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
             # using __wrapped__ to access the original torch.compile function did not work
             # so we are using the lower level torch._dynamo.optimize function
             model = torch._dynamo.optimize(backend=backend)(model)
         else:
+            jit_options = {
+                "enable_saved_for_backward_recomputation": self.checkpoint_activations,
+            }
             jit_options["fp8_shard_intermediate_activation"] = self.fp8_shard_intermediate_activation
             model = thunder.jit(model, executors=executors, **jit_options)
```

then try for example how …
There's a problem with the benchmarking script at the moment: Dynamo doesn't hand the checkpointed operations to Thunder at all, and PyTorch eager is used for them instead. Could you please find the correct usage of activation checkpointing so that it works with torch.compile, and fix the benchmarking script (#1298)?
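To make the intended use case concrete, here is a minimal, hypothetical sketch (not the benchmark script) of native PyTorch activation checkpointing compiled through Thunder's Dynamo backend. The module, shapes, and use of torch.compile are illustrative only; whether the checkpointed region actually reaches Thunder is exactly what this PR and #1298 are about.

```python
# Illustrative sketch, not the benchmark: a checkpointed block compiled with
# torch.compile using ThunderCompiler as the backend.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from thunder.dynamo import ThunderCompiler

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

    def forward(self, x):
        # Recompute this subgraph in backward instead of saving its activations.
        return checkpoint(lambda t: self.fc2(torch.relu(self.fc1(t))), x, use_reentrant=False)

backend = ThunderCompiler()
model = torch.compile(Block(), backend=backend)
out = model(torch.randn(8, 64, requires_grad=True))
out.sum().backward()
```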
With the patch applied on top of this PR, the benchmark was tested to confirm that native PyTorch checkpointing works (see the attached patch).
Hi @IvanYashchuk @kshitij12345, do you want to take another look? I think it's ready to merge.
Yes, I want to have another look and I'll do that first thing on Monday or before that.
LGTM, thanks @kiya00!
Let's resolve #1337 first and then merge this PR, so that we are confident it brings value for activation checkpointing use cases.
As mentioned in #1370 (comment)
Awesome work on enabling models with activation checkpointing to run with Thunder (#1261 (comment))!
@t-vi, could you please merge this PR?
Thank you @kiya00 @IvanYashchuk @crcrpar @kshitij12345
Dynamo represents the `function` argument of `torch.utils.checkpoint.checkpoint(function, args...)` as an FX graph, and this FX graph contains PyTorch operators. This PR adds a converter that replaces the Torch operators with their Thunder equivalents, so that the graph can be traced inside Thunder's `checkpoint` symbol. With this change, native PyTorch activation checkpointing can be supported.
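To illustrate the idea, here is a simplified sketch of the kind of FX graph rewrite described above. The function name and the `torch_to_thunder` mapping are hypothetical stand-ins; the real converter in this PR lives in thunder.dynamo and handles many more node kinds and edge cases.

```python
# Simplified, hypothetical sketch of converting Torch calls in an FX graph to
# Thunder calls. `torch_to_thunder` is a stand-in for the real operator mapping.
import torch
import torch.fx as fx

def convert_to_thunder_calls(gm: fx.GraphModule, torch_to_thunder: dict) -> fx.GraphModule:
    for node in gm.graph.nodes:
        # Swap the target of call_function nodes that have a Thunder equivalent;
        # everything else (placeholders, outputs, module calls) is left alone.
        if node.op == "call_function" and node.target in torch_to_thunder:
            node.target = torch_to_thunder[node.target]
    gm.graph.lint()
    gm.recompile()
    return gm
```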