Fix for finding the recomputing symbols in rematerialization #700

Merged (7 commits into main from fix665) Jul 5, 2024
Conversation

kiya00 (Collaborator) commented Jul 3, 2024

To Reproduce

The original bug appears when running with 2 nodes (8 GPUs each); another way to reproduce it is to use one node with 7 GPUs:
torchrun --nnodes=1 --nproc-per-node=7 ../thunder/benchmarks/benchmark_litgpt.py --model_name dolly-v2-3b --compile thunder_inductor_cat_cudnn --distributed_mode fsdp --shard_mode zero2
Setting n_layers=1 gives a shorter trace.

The reason it appears only in such specific settings is that when sharding requires padding, an additional slice operator is inserted (see the padding sketch below).
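
A minimal sketch of that padding arithmetic, following the chunking logic visible in the _shard_param hunk of the patch below; the 50280-row size for lm_head.weight is an assumption inferred from the padded [50281, 2560] tensor in the traces:

# Minimal sketch (not Thunder's code) of how FSDP sharding decides whether a
# parameter's first dimension needs padding (and hence an extra slice on unshard).
def fsdp_padding(dim0: int, world_size: int) -> int:
    chunk_size = (dim0 + world_size - 1) // world_size  # ceil division
    return chunk_size * world_size - dim0                # rows of padding added

# Assuming lm_head.weight has 50280 rows:
print(fsdp_padding(50280, 8))   # 0 -> no padding, no extra slice (1 node x 8 GPUs)
print(fsdp_padding(50280, 16))  # 8 -> padding, extra slice (2 nodes x 8 GPUs)
print(fsdp_padding(50280, 7))   # 1 -> padding, extra slice (1 node x 7 GPUs)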

To reproduce with single-process FSDP, apply this patch:

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index bad6ef74..4199e15a 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -27,7 +27,7 @@ from lightning.fabric.utilities import Throughput
 world_size = int(os.environ.get("WORLD_SIZE", 1))
 local_rank = int(os.environ.get("LOCAL_RANK", 0))
 global_rank = int(os.environ.get("RANK", 0))
-if world_size > 1:
+if world_size >= 1:
     # Avoids the allocator thrashing issue in PyTorch NCCL backend.
     # See https://github.com/Lightning-AI/lightning-thunder/issues/420
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
@@ -179,8 +179,9 @@ class Benchmark_litGPT:
         self.profiler_start = profiler_start
         self.profiler_stop = profiler_stop
 
-        if n_layers is not None:
-            self.config.n_layer = n_layers
+        # if n_layers is not None:
+        #     self.config.n_layer = n_layers
+        self.config.n_layer = 1
 
         # Initialize the model
         t0 = time.perf_counter()
@@ -573,5 +574,5 @@ if __name__ == "__main__":
     CLI(benchmark_main)
 
     # ref: https://github.com/pytorch/pytorch/blob/3af12447/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1110-L1116
-    if world_size > 1:
+    if world_size >= 1:
         torch_dist.destroy_process_group()
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index 6ba5d13c..a6acb644 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -694,6 +694,9 @@ def _shard_param(
         chunk_size = (padded_param_shape[0] + world_size - 1) // world_size
         padded_param_shape[0] = chunk_size * world_size
         _thunder_fsdp_padding_size = padded_param_shape[0] - param.size(0)
+        if name=='lm_head.weight':
+            _thunder_fsdp_padding_size=1
+            padded_param_shape[0] = padded_param_shape[0]+1
         if _thunder_fsdp_padding_size > 0:
             padded_param = torch.empty(padded_param_shape, device=param.device, dtype=param.dtype)
             padded_param[:orig_0dim_size].copy_(param)
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 820359f4..47951a8b 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -275,5 +275,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
     fw_extrace._include_te_fp8_autocast = True
     # We only want the forward function to be called with `te.fp8_autocast` manager.
     bw_extrace._include_te_fp8_autocast = False
+    import pdb;pdb.set_trace()
+    with open("trace",'w') as f:
+        f.write(str(fw_extrace))
+        f.write(str(bw_extrace))
 
     return fw_extrace, bw_extrace

Run torchrun --nnodes=1 --nproc-per-node=1 ../thunder/benchmarks/benchmark_litgpt.py --model_name dolly-v2-3b --compile thunder_inductor_cat_cudnn --distributed_mode fsdp --shard_mode zero2

Analysis

During rematerialization, the symbols that produce the rematerialized_inputs are found by searching a trace built from the combined producer/consumer subsymbols (see below), but the producer and the consumer can contain the same subsymbols:

trace.bound_symbols = (*producer.subsymbols, *consumer.subsymbols)
recomputing_symbols = utils.find_producer_symbols(trace, rematerialized_inputs, cut_inputs)

Here is an example of this from the bug (this is when rematerialization is applied to the joint trace; I think the same subsymbols exist in both consumer and producer because earlier rematerializations in the fusion_pass copied the same recomputing_symbols into both new consumers):

# producer
[t606, t615] = TorchCompile0(t200, t201, t204, t593, t598, t677)
  # t202 = ltorch.mul(t200, t201)  # t202: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
    # t202 = prims.mul(t200, t201)  # t202: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t594 = ltorch.reshape(t593, (1, 2048, 2560))  # t594: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]"
    # t594 = prims.reshape(t593, (1, 2048, 2560))  # t594: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]"
  # t599 = prims.convert_element_type(t594, dtypes.thunder.dtypes.float32)  # t599: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t604 = ltorch.mul(t204, t599)  # t604: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
    # t604 = prims.mul(t204, t599)  # t604: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t605 = ltorch.mul(t202, t599)  # t605: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
    # t605 = prims.mul(t202, t599)  # t605: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t606 = prims.convert_element_type(t605, dtypes.thunder.dtypes.bfloat16)  # t606: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]"
  # t610 = ltorch.mul(t201, t604)  # t610: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
    # t610 = prims.mul(t201, t604)  # t610: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t611 = ltorch.mul(t200, t604)  # t611: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
    # t611 = prims.mul(t200, t604)  # t611: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t614 = ltorch.neg(t610)  # t614: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
    # t614 = prims.neg(t610)  # t614: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t615 = prims.convert_element_type(t610, dtypes.thunder.dtypes.bfloat16)  # t615: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]"
  # t678 = ltorch.neg(t677)  # t678: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 10240]"
    # t678 = prims.neg(t677)  # t678: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 10240]"
  # t1031 = prims.pad(t598, 0.0, ((0, 1, 0), (0, 0, 0)))  # t1031: "thunder.devices.Device(type='cuda:0') bf16[50281, 2560]"

# consumer
[t1034, t603, t609, t612, t616, t679] = nvFusion1(t200, t201, t204, t593, t598, t606, t677)
  # t594 = prims.reshape(t593, (1, 2048, 2560))  # t594: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]"
  # t599 = prims.convert_element_type(t594, dtypes.thunder.dtypes.float32)  # t599: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t604 = prims.mul(t204, t599)  # t604: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t610 = prims.mul(t201, t604)  # t610: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t611 = prims.mul(t200, t604)  # t611: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t614 = prims.neg(t610)  # t614: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t678 = prims.neg(t677)  # t678: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 10240]"
  # t1031 = prims.pad(t598, 0.0, ((0, 1, 0), (0, 0, 0)))  # t1031: "thunder.devices.Device(type='cuda:0') bf16[50281, 2560]"
  # t679 = prims.exp(t678)  # t679: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 10240]"
  # t602 = prims.sum(t599, (0, 1))  # t602: "thunder.devices.Device(type='cuda:0') f32[2560]"
  # t603 = prims.convert_element_type(t602, dtypes.thunder.dtypes.bfloat16)  # t603: "thunder.devices.Device(type='cuda:0') bf16[2560]"
  # t607 = prims.convert_element_type(t606, dtypes.thunder.dtypes.float32)  # t607: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t608 = prims.sum(t607, (0, 1))  # t608: "thunder.devices.Device(type='cuda:0') f32[2560]"
  # t609 = prims.convert_element_type(t608, dtypes.thunder.dtypes.bfloat16)  # t609: "thunder.devices.Device(type='cuda:0') bf16[2560]"
  # t612 = prims.sum(t611, (0, 2))  # t612: "thunder.devices.Device(type='cuda:0') f32[2048]"
  # t616 = prims.sum(t614, (0, 2))  # t616: "thunder.devices.Device(type='cuda:0') f32[2048]"
  # t1032 = prims.convert_element_type(t1031, dtypes.thunder.dtypes.float32)  # t1032: "thunder.devices.Device(type='cuda:0') f32[50281, 2560]"
  # t1033 = prims.div(t1032, 7.0)  # t1033: "thunder.devices.Device(type='cuda:0') f32[50281, 2560]"
  # t1034 = prims.convert_element_type(t1033, dtypes.thunder.dtypes.bfloat16)  # t1034: "thunder.devices.Device(type='cuda:0') bf16[50281, 2560]"
# rematerialized_inputs
("t606")
# cut_inputs
("t200","t201", "t204", "t593","t598","t677", )

# recomputing_symbols
t202 = ltorch.mul(t200, t201)  # t202: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t202 = prims.mul(t200, t201)  # t202: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]",
 t605 = ltorch.mul(t202, t599)  # t605: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]"
  # t605 = prims.mul(t202, t599)  # t605: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]",
 t606 = prims.convert_element_type(t605, dtypes.thunder.dtypes.bfloat16)  # t606: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]",
 t594 = prims.reshape(t593, (1, 2048, 2560))  # t594: "thunder.devices.Device(type='cuda:0') bf16[1, 2048, 2560]",
 t599 = prims.convert_element_type(t594, dtypes.thunder.dtypes.float32)  # t599: "thunder.devices.Device(type='cuda:0') f32[1, 2048, 2560]")

# new_consumer
[t1034, t603, t609, t612, t616, t679] = nvFusion1(t200, t201, t204, t593, t598, t677)
    # t202 = prims.mul(t200, t201)  # t202: "cuda:0 f32[1, 2048, 2560]"
    # t605 = prims.mul(t202, t599)  # t605: "cuda:0 f32[1, 2048, 2560]"
    # t606 = prims.convert_element_type(t605, dtypes.bfloat16)  # t606: "cuda:0 bf16[1, 2048, 2560]"
    # t594 = prims.reshape(t593, (1, 2048, 2560))  # t594: "cuda:0 bf16[1, 2048, 2560]"
    # t599 = prims.convert_element_type(t594, dtypes.float32)  # t599: "cuda:0 f32[1, 2048, 2560]"
    # t594 = prims.reshape(t593, (1, 2048, 2560))  # t594: "cuda:0 bf16[1, 2048, 2560]"
    # t599 = prims.convert_element_type(t594, dtypes.float32)  # t599: "cuda:0 f32[1, 2048, 2560]"
    # t604 = prims.mul(t204, t599)  # t604: "cuda:0 f32[1, 2048, 2560]"
    # t610 = prims.mul(t201, t604)  # t610: "cuda:0 f32[1, 2048, 2560]"
    # t611 = prims.mul(t200, t604)  # t611: "cuda:0 f32[1, 2048, 2560]"
    # t614 = prims.neg(t610)  # t614: "cuda:0 f32[1, 2048, 2560]"
    # t678 = prims.neg(t677)  # t678: "cuda:0 f32[1, 2048, 10240]"
    # t1031 = prims.pad(t598, 0.0, ((0, 1, 0), (0, 0, 0)))  # t1031: "cuda:0 bf16[50281, 2560]"
    # t679 = prims.exp(t678)  # t679: "cuda:0 f32[1, 2048, 10240]"
    # t602 = prims.sum(t599, (0, 1))  # t602: "cuda:0 f32[2560]"
    # t603 = prims.convert_element_type(t602, dtypes.bfloat16)  # t603: "cuda:0 bf16[2560]"
    # t607 = prims.convert_element_type(t606, dtypes.float32)  # t607: "cuda:0 f32[1, 2048, 2560]"
    # t608 = prims.sum(t607, (0, 1))  # t608: "cuda:0 f32[2560]"
    # t609 = prims.convert_element_type(t608, dtypes.bfloat16)  # t609: "cuda:0 bf16[2560]"
    # t612 = prims.sum(t611, (0, 2))  # t612: "cuda:0 f32[2048]"
    # t616 = prims.sum(t614, (0, 2))  # t616: "cuda:0 f32[2048]"
    # t1032 = prims.convert_element_type(t1031, dtypes.float32)  # t1032: "cuda:0 f32[50281, 2560]"
    # t1033 = prims.div(t1032, 7.0)  # t1033: "cuda:0 f32[50281, 2560]"
    # t1034 = prims.convert_element_type(t1033, dtypes.bfloat16)  # t1034: "cuda:0 bf16[50281, 2560]"
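
To make the duplication concrete, here is a toy model of the search (plain tuples, not Thunder's TraceCtx/BoundSymbol API) with the subsymbols above reduced to (output, inputs) pairs:

# Toy stand-in for utils.find_producer_symbols: walk backwards from the
# rematerialized inputs until the cut inputs are reached.
def find_producer_symbols(bound_symbols, wanted, cut_inputs):
    found = []
    queue = list(wanted)
    while queue:
        name = queue.pop()
        if name in cut_inputs:
            continue
        for sym in bound_symbols:
            out, ins = sym
            if out == name and sym not in found:
                found.append(sym)
                queue.extend(ins)
    return found

producer_subsyms = [
    ("t202", ("t200", "t201")),
    ("t594", ("t593",)),
    ("t599", ("t594",)),
    ("t605", ("t202", "t599")),
    ("t606", ("t605",)),
]
consumer_subsyms = [
    ("t594", ("t593",)),  # the same subsymbols also live in the consumer
    ("t599", ("t594",)),
    ("t604", ("t204", "t599")),
]
cut_inputs = {"t200", "t201", "t204", "t593"}

# Searching the concatenated subsymbols returns t594/t599, which the consumer
# already computes; putting the result in front of the consumer's subsymbols
# then defines t594 and t599 twice, mirroring the new_consumer dump above.
recomputing = find_producer_symbols(producer_subsyms + consumer_subsyms, ["t606"], cut_inputs)
print([out for out, _ in recomputing + consumer_subsyms])
# ['t606', 't605', 't599', 't594', 't202', 't594', 't599', 't604']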

Since the rematerialized_inputs come from the inputs of the consumer, we can find the recomputing_symbols from the producer's subsymbols alone. A small sketch of that direction follows.
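
A minimal sketch of that direction, reusing the toy helpers above (the real fix operates on Thunder's bound symbols, not these tuples): build the search space from the producer's subsymbols only, so the result can only ever be drawn from the producer.

recomputing = find_producer_symbols(producer_subsyms, ["t606"], cut_inputs)
print([out for out, _ in recomputing])
# ['t606', 't605', 't599', 't594', 't202'] -- found entirely within the producer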

Fixes #665

IvanYashchuk (Collaborator) commented:
A minimal example reproducing the problem and a test are required here.

@kiya00 kiya00 changed the title from "Remove duplicated names in rematerialization (#665)" to "Fix for finding the recomputing symbols in rematerialization" on Jul 5, 2024
kiya00 (Collaborator, Author) commented Jul 5, 2024

Test results on 2 nodes (8 H100s each):

[viking-prod-283:0]:iter 34: loss 4.6562, iter time: 3934.96ms, t: 2048
[viking-prod-283:0]:iter 35: loss 4.6562, iter time: 3934.41ms, t: 2048
[viking-prod-283:0]:iter 36: loss 4.6562, iter time: 3932.98ms, t: 2048
[viking-prod-283:0]:iter 37: loss 4.6562, iter time: 3930.01ms, t: 2048
[viking-prod-283:0]:iter 38: loss 4.6562, iter time: 3935.24ms, t: 2048
[viking-prod-283:0]:iter 39: loss 4.6875, iter time: 3937.44ms, t: 2048
[viking-prod-283:0]:iter 40: loss 4.6562, iter time: 3930.84ms, t: 2048
[viking-prod-283:0]:iter 41: loss 4.6562, iter time: 3928.77ms, t: 2048
[viking-prod-283:0]:iter 42: loss 4.6562, iter time: 3929.16ms, t: 2048
[viking-prod-283:0]:iter 43: loss 4.6562, iter time: 3930.82ms, t: 2048
[viking-prod-283:0]:iter 44: loss 4.6562, iter time: 3928.82ms, t: 2048
[viking-prod-283:0]:Model name: dolly-v2-3b
[viking-prod-283:0]:Seq Length: 2048
[viking-prod-283:0]:Micro BS: 1
[viking-prod-283:0]:Global BS: 16
[viking-prod-283:0]:Number of Layers: 32
[viking-prod-283:0]:Number of parameters: 0.17B
[viking-prod-283:0]:Distributed Mode: fsdp
[viking-prod-283:0]:Sharding Mode: zero2
[viking-prod-283:0]:Bucketing: none
[viking-prod-283:0]:Compiler: thunder_inductor_cat_cudnn
[viking-prod-283:0]:Low Precision Mode: none
[viking-prod-283:0]:Average iter time: 3934.10 ms
[viking-prod-283:0]:Memory used: 21.66 GB
[viking-prod-283:0]:Tokens/s: 8329.40
[viking-prod-283:0]:Tokens/s/GPU: 520.59
[viking-prod-283:0]:TFLOP/s: 148.97

@kiya00 kiya00 marked this pull request as ready for review July 5, 2024 13:35
@kiya00 kiya00 requested review from mruberry, lantiga and t-vi as code owners July 5, 2024 13:35
@t-vi t-vi enabled auto-merge (squash) July 5, 2024 13:44
@t-vi t-vi left a comment (Collaborator)

@t-vi t-vi merged commit 0d80444 into main Jul 5, 2024
39 checks passed
@t-vi t-vi deleted the fix665 branch July 5, 2024 13:46
Successfully merging this pull request may close these issues.

dolly-v2-3b with thunder_inductor_cat_cudnn fails with KeyError: 't5905'