A converter for FXGraph with Torch calls -> FXGraph with Thunder calls #1261

Merged: 25 commits, Nov 4, 2024

Conversation

@kiya00 (Collaborator) commented Oct 4, 2024

Dynamo represents the function argument of torch.utils.checkpoint.checkpoint(function, args...) as an FX graph that contains PyTorch operators. This PR adds a converter that replaces those torch operators with their Thunder equivalents so the graph can be traced inside the Thunder checkpoint symbol. With this change, native PyTorch activation checkpointing can be supported.
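For context, a minimal sketch of the conversion idea (an illustration, not the PR's actual implementation) is to walk the FX graph and swap each torch call_function target for its Thunder counterpart. The torch_to_thunder lookup table below is a hypothetical, illustrative subset; the real converter covers the full set of supported operators.

```python
# Minimal sketch of the converter idea, assuming thunder.torch exposes
# equivalents such as sin/cos; `torch_to_thunder` is a hypothetical,
# illustrative mapping, not the PR's actual lookup table.
import torch
import torch.fx as fx
import thunder.torch as ltorch

torch_to_thunder = {
    torch.sin: ltorch.sin,  # illustrative entries only
    torch.cos: ltorch.cos,
}

def convert_torch_to_thunder(gm: fx.GraphModule) -> fx.GraphModule:
    for node in gm.graph.nodes:
        # Swap torch call targets for their Thunder equivalents in place.
        if node.op == "call_function" and node.target in torch_to_thunder:
            node.target = torch_to_thunder[node.target]
    gm.recompile()
    return gm
```

The rewritten graph module can then be invoked inside the Thunder checkpoint symbol, where the Thunder operators are traceable.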

@kiya00 kiya00 requested a review from IvanYashchuk October 4, 2024 14:16
@kiya00 kiya00 force-pushed the basedon-functional-autograd-checkpoint branch from f1d6a6c to 8800518 on October 4, 2024 16:47
@kiya00 kiya00 changed the base branch from main to functional-autograd-checkpoint October 7, 2024 10:46
@kiya00 kiya00 changed the base branch from functional-autograd-checkpoint to main October 7, 2024 10:46
@kiya00 kiya00 force-pushed the basedon-functional-autograd-checkpoint branch 2 times, most recently from 1d1ab80 to b73c555 on October 11, 2024 14:08
Review comment on thunder/dynamo/splitter.py (outdated, resolved)
@kiya00 kiya00 requested a review from kshitij12345 October 11, 2024 14:31
@kiya00 kiya00 marked this pull request as ready for review October 11, 2024 14:31
@kiya00 kiya00 changed the title [WIP] A converter for FXGraph with Torch calls -> FXGraph with Thunder calls A converter for FXGraph with Torch calls -> FXGraph with Thunder calls Oct 11, 2024
@kshitij12345 (Collaborator) left a comment

Overall looks good; I just have one main comment related to the information being saved in SubgraphInfo. Curious to hear your thoughts on it. Thanks @kiya00!

Review comments on thunder/dynamo/splitter.py and thunder/dynamo/utils.py (outdated, resolved)
@IvanYashchuk (Collaborator) commented Oct 14, 2024

The pull request is in a good state already. Before merging, let's ensure it enables the use case we're interested in. Please add the following patch:

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index afb9d79f..6a415499 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -532,7 +532,7 @@ class Benchmark_litGPT:
         return model
 
     def setup_activation_checkpointing(self):
-        if "thunder" in self.compile:
+        if "thunder" in self.compile and not "dynamo" in self.compile:
             # checkpointing is an option to thunder.jit
             return
 
@@ -571,11 +571,6 @@ class Benchmark_litGPT:
 
                 executors.insert(0, transformer_engine_ex)
 
-            jit_options = {
-                "enable_saved_for_backward_recomputation": self.checkpoint_activations,
-                "recomputation_policy": None,
-            }
-
             if "dynamo" in self.compile:
                 if self.distributed_mode == "fsdp2":
                     print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
@@ -583,13 +578,17 @@ class Benchmark_litGPT:
 
                     dynamo_config.cache_size_limit = 64
 
-                backend = ThunderCompiler(executors=executors, **jit_options)
+                backend = ThunderCompiler(executors=executors)
                 # Because Lightning Fabric is imported in this script it monkey patches the torch.compile function
                 # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
                 # using __wrapped__ to access the original torch.compile function did not work
                 # so we are using the lower level torch._dynamo.optimize function
                 model = torch._dynamo.optimize(backend=backend)(model)
             else:
+                jit_options = {
+                    "enable_saved_for_backward_recomputation": self.checkpoint_activations,
+                }
                 jit_options["fp8_shard_intermediate_activation"] = self.fp8_shard_intermediate_activation
                 model = thunder.jit(model, executors=executors, **jit_options)

then try, for example, how longchat-7b-16k works with activation checkpointing:

TORCH_LOGS="+dynamo" python thunder/benchmarks/benchmark_litgpt.py --model_name longchat-7b-16k --compile dynamo+thunder --n_layers=2 --max_iters=2 --warmup_iters=1 --checkpoint_activations=True

There's currently a problem with the benchmarking script: Dynamo doesn't hand the checkpointed operations to Thunder at all, and PyTorch eager is used instead. Could you please find the correct usage of activation checkpointing so that it works with torch.compile, and fix the benchmarking script (#1298)?
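For reference, torch.compile generally captures the checkpointed region only with the non-reentrant checkpoint implementation (use_reentrant=False). Below is a hedged illustration of that usage; the "eager" backend is a stand-in (the benchmark would use the Thunder backend), and this is not necessarily the script's actual fix.

```python
# Hedged illustration of activation checkpointing in a form torch.compile can
# capture: the non-reentrant path (use_reentrant=False). The "eager" backend is
# a stand-in; the benchmark would use the Thunder backend instead.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        # Recompute this block's activations during the backward pass.
        return checkpoint(lambda t: torch.relu(self.linear(t)), x, use_reentrant=False)

model = torch.compile(Block(), backend="eager")
out = model(torch.randn(2, 16, requires_grad=True))
out.sum().backward()
```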

@kiya00 kiya00 force-pushed the basedon-functional-autograd-checkpoint branch 2 times, most recently from a69a122 to 565ff27 on October 15, 2024 14:09
@kiya00 (Collaborator, Author) commented Oct 15, 2024

With the patch below applied on this PR, the benchmark confirms that native PyTorch checkpointing works.
python thunder/benchmarks/benchmark_litgpt.py --model_name longchat-7b-16k --compile thunder+dynamo --n_layers=2 --max_iters=5 --warmup_iters=3 --checkpoint_activations=True --dump_thunder_traces=True prints out the first and last traces.

patch:
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index afb9d79f..beef7623 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -532,7 +532,7 @@ class Benchmark_litGPT:
         return model
 
     def setup_activation_checkpointing(self):
-        if "thunder" in self.compile:
+        if "thunder" in self.compile and "dynamo" not in self.compile:
             # checkpointing is an option to thunder.jit
             return
 
@@ -571,11 +571,6 @@ class Benchmark_litGPT:
 
                 executors.insert(0, transformer_engine_ex)
 
-            jit_options = {
-                "enable_saved_for_backward_recomputation": self.checkpoint_activations,
-                "recomputation_policy": None,
-            }
-
             if "dynamo" in self.compile:
                 if self.distributed_mode == "fsdp2":
                     print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
@@ -583,13 +578,16 @@ class Benchmark_litGPT:
 
                     dynamo_config.cache_size_limit = 64
 
-                backend = ThunderCompiler(executors=executors, **jit_options)
+                self.backend = ThunderCompiler(executors=executors)
                 # Because Lightning Fabric is imported in this script it monkey patches the torch.compile function
                 # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
                 # using __wrapped__ to access the original torch.compile function did not work
                 # so we are using the lower level torch._dynamo.optimize function
-                model = torch._dynamo.optimize(backend=backend)(model)
+                model = torch._dynamo.optimize(backend=self.backend)(model)
             else:
+                jit_options = {
+                    "enable_saved_for_backward_recomputation": self.checkpoint_activations,
+                }
                 jit_options["fp8_shard_intermediate_activation"] = self.fp8_shard_intermediate_activation
                 model = thunder.jit(model, executors=executors, **jit_options)
 
@@ -819,9 +817,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
 
         print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
         print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
-        print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
-        print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
-        print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")
+        # print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
+        # print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
+        # print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")
 
         if benchmark.dump_memory_snapshot:
             file_name = f"{benchmark.model_name}_{benchmark.compile}_{benchmark.distributed_mode}"
@@ -843,16 +841,28 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
                 for jitted in benchmark.thunder_as_torch_compile_backend.gm_to_thunder.values():
                     fwd_traces.append(thunder.last_traces(jitted))
                     bwd_traces.append(thunder.last_backward_traces(jitted))
-            else:
+            elif "dynamo" not in benchmark.compile:
                 fwd_traces = [thunder.last_traces(benchmark.model)]
                 bwd_traces = [thunder.last_backward_traces(benchmark.model)]
 
-            for i, f_traces in enumerate(fwd_traces, start=1):
-                print(f"##########\n#{i}-th ThunderModule\n##########")
-                print(f_traces[-1])
-            for i, b_traces in enumerate(bwd_traces, start=1):
-                print(f"##########\n#{i}-th ThunderModule\n##########")
-                print(b_traces[-1])
+            if "dynamo" in benchmark.compile:
+                for gid, infos in enumerate(benchmark.backend.subgraph_infos):
+                    for subgid, thunder_fn in enumerate(infos.thunder_compiled_fns):
+                        print(f"##########\n#Graph{gid}-ThunderFn{subgid} first forward trace\n##########")
+                        print(thunder.last_traces(thunder_fn)[0])
+                        print(f"##########\n#Graph{gid}-ThunderFn{subgid} last forward trace\n##########")
+                        print(thunder.last_traces(thunder_fn)[-1])
+                        print(f"##########\n#Graph{gid}-ThunderFn{subgid} first backward trace\n##########")
+                        print(thunder.last_backward_traces(thunder_fn)[0])
+                        print(f"##########\n#Graph{gid}-ThunderFn{subgid} last backward trace\n##########")
+                        print(thunder.last_backward_traces(thunder_fn)[-1])
+            else:
+                for i, f_traces in enumerate(fwd_traces, start=1):
+                    print(f"##########\n#{i}-th ThunderModule\n##########")
+                    print(f_traces[-1])
+                for i, b_traces in enumerate(bwd_traces, start=1):
+                    print(f"##########\n#{i}-th ThunderModule\n##########")
+                    print(b_traces[-1])
 
     if global_rank in [0, None]:
         if return_metrics_as_json:

Review comments on thunder/tests/test_dynamo.py, thunder/dynamo/splitter.py, and thunder/torch/__init__.py (resolved)
@kiya00 (Collaborator, Author) commented Oct 18, 2024

Hi @IvanYashchuk @kshitij12345, do you want to take another look? I think it's ready to merge.

@kiya00 kiya00 force-pushed the basedon-functional-autograd-checkpoint branch from 809865f to 63e426d on October 18, 2024 11:37
@IvanYashchuk (Collaborator) replied:

> Hi @IvanYashchuk @kshitij12345, do you want to take another look? I think it's ready to merge.

Yes, I want to have another look and I'll do that first thing on Monday, or earlier.

@kshitij12345 (Collaborator) left a comment

LGTM, thanks @kiya00!

Review comments on thunder/dynamo/utils.py and thunder/tests/test_dynamo.py (outdated, resolved)
@IvanYashchuk (Collaborator) left a comment

Let's resolve #1337 first and then merge this PR, so that we are confident it brings value for activation checkpointing use cases.

Review comments on thunder/dynamo/splitter.py, thunder/dynamo/utils.py, thunder/torch/__init__.py, and thunder/tests/test_dynamo.py (resolved)
@kiya00 (Collaborator, Author) commented Oct 30, 2024

As mentioned in #1370 (comment), here are some results comparing this PR with #1370 in terms of peak memory usage on a single H100 GPU:

| Model (H100, 1 GPU, 2 layers) | This PR (GB) | ThunderFX with checkpoint fallback to Inductor (GB) |
| --- | --- | --- |
| longchat-13b-16k | 12.41 | 15.99 |
| CodeLlama-34b-hf | 25.03 | 30.09 |
| Gemma-2-27b | OOM | OOM |
| Llama-3-70B | 35.56 | 39.41 |
| Mistral-7B-v0.2 | 21.36 | 28.78 |
| vicuna-7b-v1.5-16k | 12.41 | 15.99 |
cc: @IvanYashchuk

@kiya00 (Collaborator, Author) commented Oct 31, 2024

torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --model_name Mistral-7B-v0.2 --micro_batch_size 3 --compile thunder+dynamo --checkpoint_activations=True --distributed_mode=fsdp --shard_mode zero3 --max_iters=4 --warmup_iters=1 --bucketing_mode=block on 8×H100 (container: 20241028):

| Model | Micro batch size | thunder+dynamo | inductor |
| --- | --- | --- | --- |
| longchat-13b-16k | 3 | 49.09 | 48.59 |
| CodeLlama-34b-hf | 1 | 50.47 | 50.8 |
| Gemma-2-27b | 1 | OOM | OOM |
| Llama-3-70B | 1 | OOM | OOM |
| Mistral-7B-v0.2 | 3 | 63.43 | 64.63 |
| vicuna-7b-v1.5-16k | 5 | 51.67 | 51 |

@IvanYashchuk (Collaborator) left a comment

Awesome work on enabling models with activation checkpointing to run with Thunder (#1261 (comment))!

@IvanYashchuk (Collaborator) commented:
@t-vi, could you please merge this PR?

@IvanYashchuk IvanYashchuk enabled auto-merge (squash) November 1, 2024 14:45
@t-vi (Collaborator) left a comment

@IvanYashchuk IvanYashchuk merged commit 6341444 into main Nov 4, 2024
41 checks passed
@IvanYashchuk IvanYashchuk deleted the basedon-functional-autograd-checkpoint branch November 4, 2024 10:22