Perf Regression : SDPA is recomputed in backward #1646

Closed · kshitij12345 opened this issue Jan 15, 2025 · 2 comments · Fixed by #1648
Assignees: t-vi
Labels: high priority, memory use, nemo (Issues needed to support NVIDIA NeMo models), performance

Comments

kshitij12345 (Collaborator) commented Jan 15, 2025

Since #1615, recomputation of saved-for-backward tensors is on by default:

enable_saved_for_backward_recomputation: None | bool = get_compile_option(
    "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
)
if enable_saved_for_backward_recomputation is None or remat_policy:
    enable_saved_for_backward_recomputation = True

This means that we may end up recomputing compute-heavy ops like SDPA, leading to a perf regression.
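
A minimal sketch of how one could opt out while this is being decided, assuming that options read through get_compile_option can be passed as keyword arguments to thunder.jit (the usual way such options are supplied); the shapes, dtype, and device below are only illustrative:

import torch
import thunder

def f(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicitly turn the recomputation pass off instead of relying on the new default.
# Assumes a CUDA device with a working SDPA executor.
jf = thunder.jit(f, enable_saved_for_backward_recomputation=False)

q, k, v = (
    torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
out = jf(q, k, v)
out.sum().backward()

# With the pass disabled, the backward trace should no longer contain a
# recomputed cudnn_sdpa_fwd call.
print(thunder.last_backward_traces(jf)[-1])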

Running python thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-7b-hf --compile dynamo_thunder --micro_batch_size 1 --dump_thunder_traces true

Current Main

Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 1
Global BS: 1
Number of Layers: 32
Number of parameters: 6.74B
Distributed Mode: none
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 350.15 ms
Memory used: 64.17 GB
Tokens/s: 11697.73
Tokens/s/GPU: 11697.73
TFLOP/s: 539.09

With patch to disable recomputing SDPA

Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 1
Global BS: 1
Number of Layers: 32
Number of parameters: 6.74B
Distributed Mode: none
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 341.78 ms
Memory used: 64.18 GB
Tokens/s: 11988.92
Tokens/s/GPU: 11988.92
TFLOP/s: 552.51

I think we should measure recomputation's impact on performance and memory usage across multiple models and then decide, based on the numbers, whether it should be on by default (the gap here is roughly 2.5% in throughput).

Also, we should compile a list of compute-heavy operations for which recomputation might not be the best choice and tag them accordingly.

Patch to disable recomputing for SDPA
diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py
index 409c1672..53ce186a 100644
--- a/thunder/executors/cudnnex.py
+++ b/thunder/executors/cudnnex.py
@@ -79,6 +79,7 @@ import thunder.core.dtypes as dtypes
 from thunder.torch import TensorLike
 from thunder.core.compile_data import get_compile_option
 from thunder.core.proxies import Proxy, TensorProxy
+import thunder
 
 
 from thunder.core.transforms import (
@@ -425,6 +426,7 @@ cudnn_sdpa_fwd = cudnn_ex.register_operator(
     "cudnn_sdpa_fwd",
     meta=_cudnn_sdpa_forward_meta,
     fn=_cudnn_sdpa_fwd_impl,
+    tags=(thunder.prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,)
 )
 
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 0d10f035..743f9e5e 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5206,7 +5206,7 @@ def mse_loss(
 
 # TODO Add annotations
 # NOTE The scale parameter is kwarg-only in PyTorch
-@torchsymbol(torch.nn.functional.scaled_dot_product_attention)
+@torchsymbol(torch.nn.functional.scaled_dot_product_attention, tags=(prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,))
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None):
     for arg_name, arg in zip(("query", "key", "value"), (query, key, value)):
         utils.check(
Sample backward trace for a 1-layer model (note the cudnn_sdpa_fwd call in the trace):
# Constructed by Delete Last Used (took 1 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t177, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  k_2, l_idx_, l_self_buffers_cos_, l_self_buffers_sin_, \
  l_self_modules_lm_head_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_attn_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_proj_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_1_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_2_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_proj_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_norm_1_parameters_weight_, \
  l_self_modules_transformer_modules_h_modules_0_modules_norm_2_parameters_weight_, \
  l_self_modules_transformer_modules_ln_f_parameters_weight_, q_2, rsqrt, \
  rsqrt_1, rsqrt_2, t304, t306, t307, t308, t309, t312, to_3, v_1, x_2, x_4, x_7, \
  x_normed_1, y, = C0
  clear_mutable_collection(C0)
  del C0
  bw_t773 = torch.reshape(t177, (-1, 32000))  # bw_t773: "cuda:0 bf16[4096, 32000]"
    # bw_t773 = ltorch.reshape(t177, (-1, 32000))  # bw_t773: "cuda:0 bf16[4096, 32000]"
      # bw_t773 = prims.reshape(t177, (4096, 32000))  # bw_t773: "cuda:0 bf16[4096, 32000]"
  del t177
  bw_t760 = torch.matmul(bw_t773, l_self_modules_lm_head_parameters_weight_)  # bw_t760: "cuda:0 bf16[4096, 4096]"
    # bw_t760 = ltorch.matmul(bw_t773, l_self_modules_lm_head_parameters_weight_)  # bw_t760: "cuda:0 bf16[4096, 4096]"
      # bw_t760 = prims.matmul(bw_t773, l_self_modules_lm_head_parameters_weight_)  # bw_t760: "cuda:0 bf16[4096, 4096]"
  del l_self_modules_lm_head_parameters_weight_
  bw_t774 = torch.reshape(bw_t760, (1, 4096, 4096))  # bw_t774: "cuda:0 bf16[1, 4096, 4096]"
    # bw_t774 = ltorch.reshape(bw_t760, (1, 4096, 4096))  # bw_t774: "cuda:0 bf16[1, 4096, 4096]"
      # bw_t774 = prims.reshape(bw_t760, (1, 4096, 4096))  # bw_t774: "cuda:0 bf16[1, 4096, 4096]"
  del bw_t760
  bw_t775 = torch.permute(bw_t773, (1, 0))  # bw_t775: "cuda:0 bf16[32000, 4096]"
    # bw_t775 = ltorch.permute(bw_t773, (1, 0))  # bw_t775: "cuda:0 bf16[32000, 4096]"
      # bw_t775 = prims.transpose(bw_t773, (1, 0))  # bw_t775: "cuda:0 bf16[32000, 4096]"
  del bw_t773
  bw_t776 = torch.reshape(x_7, (-1, 4096))  # bw_t776: "cuda:0 bf16[4096, 4096]"
    # bw_t776 = ltorch.reshape(x_7, (-1, 4096))  # bw_t776: "cuda:0 bf16[4096, 4096]"
      # bw_t776 = prims.reshape(x_7, (4096, 4096))  # bw_t776: "cuda:0 bf16[4096, 4096]"
  del x_7
  bw_t761 = torch.matmul(bw_t775, bw_t776)  # bw_t761: "cuda:0 bf16[32000, 4096]"
    # bw_t761 = ltorch.matmul(bw_t775, bw_t776)  # bw_t761: "cuda:0 bf16[32000, 4096]"
      # bw_t761 = prims.matmul(bw_t775, bw_t776)  # bw_t761: "cuda:0 bf16[32000, 4096]"
  del bw_t775, bw_t776
  [bw_t498, bw_t502, bw_t478] = nvFusion0(l_self_modules_transformer_modules_ln_f_parameters_weight_, bw_t774, x_2, t309, rsqrt_2)

  del l_self_modules_transformer_modules_ln_f_parameters_weight_, bw_t774, x_2, t309, rsqrt_2
  bw_t762 = torch.matmul(bw_t502, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_proj_parameters_weight_)  # bw_t762: "cuda:0 bf16[4096, 11008]"
    # bw_t762 = ltorch.matmul(bw_t502, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_proj_parameters_weight_)  # bw_t762: "cuda:0 bf16[4096, 11008]"
      # bw_t762 = prims.matmul(bw_t502, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_proj_parameters_weight_)  # bw_t762: "cuda:0 bf16[4096, 11008]"
  del bw_t502, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_proj_parameters_weight_
  bw_t777 = torch.reshape(bw_t762, (1, 4096, 11008))  # bw_t777: "cuda:0 bf16[1, 4096, 11008]"
    # bw_t777 = ltorch.reshape(bw_t762, (1, 4096, 11008))  # bw_t777: "cuda:0 bf16[1, 4096, 11008]"
      # bw_t777 = prims.reshape(bw_t762, (1, 4096, 11008))  # bw_t777: "cuda:0 bf16[1, 4096, 11008]"
  del bw_t762
  bw_t778 = torch.reshape(bw_t498, (-1, 4096))  # bw_t778: "cuda:0 bf16[4096, 4096]"
    # bw_t778 = ltorch.reshape(bw_t498, (-1, 4096))  # bw_t778: "cuda:0 bf16[4096, 4096]"
      # bw_t778 = prims.reshape(bw_t498, (4096, 4096))  # bw_t778: "cuda:0 bf16[4096, 4096]"
  bw_t779 = torch.permute(bw_t778, (1, 0))  # bw_t779: "cuda:0 bf16[4096, 4096]"
    # bw_t779 = ltorch.permute(bw_t778, (1, 0))  # bw_t779: "cuda:0 bf16[4096, 4096]"
      # bw_t779 = prims.transpose(bw_t778, (1, 0))  # bw_t779: "cuda:0 bf16[4096, 4096]"
  del bw_t778
  bw_t780 = torch.reshape(x_4, (-1, 11008))  # bw_t780: "cuda:0 bf16[4096, 11008]"
    # bw_t780 = ltorch.reshape(x_4, (-1, 11008))  # bw_t780: "cuda:0 bf16[4096, 11008]"
      # bw_t780 = prims.reshape(x_4, (4096, 11008))  # bw_t780: "cuda:0 bf16[4096, 11008]"
  del x_4
  bw_t763 = torch.matmul(bw_t779, bw_t780)  # bw_t763: "cuda:0 bf16[4096, 11008]"
    # bw_t763 = ltorch.matmul(bw_t779, bw_t780)  # bw_t763: "cuda:0 bf16[4096, 11008]"
      # bw_t763 = prims.matmul(bw_t779, bw_t780)  # bw_t763: "cuda:0 bf16[4096, 11008]"
  del bw_t779, bw_t780
  [bw_t512, bw_t529, bw_t530] = nvFusion1(t307, bw_t777, t308)

  del t307, bw_t777, t308
  bw_t764 = torch.matmul(bw_t530, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_2_parameters_weight_)  # bw_t764: "cuda:0 bf16[4096, 4096]"
    # bw_t764 = ltorch.matmul(bw_t530, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_2_parameters_weight_)  # bw_t764: "cuda:0 bf16[4096, 4096]"
      # bw_t764 = prims.matmul(bw_t530, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_2_parameters_weight_)  # bw_t764: "cuda:0 bf16[4096, 4096]"
  del bw_t530, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_2_parameters_weight_
  bw_t781 = torch.reshape(bw_t764, (1, 4096, 4096))  # bw_t781: "cuda:0 bf16[1, 4096, 4096]"
    # bw_t781 = ltorch.reshape(bw_t764, (1, 4096, 4096))  # bw_t781: "cuda:0 bf16[1, 4096, 4096]"
      # bw_t781 = prims.reshape(bw_t764, (1, 4096, 4096))  # bw_t781: "cuda:0 bf16[1, 4096, 4096]"
  del bw_t764
  bw_t782 = torch.reshape(bw_t512, (-1, 11008))  # bw_t782: "cuda:0 bf16[4096, 11008]"
    # bw_t782 = ltorch.reshape(bw_t512, (-1, 11008))  # bw_t782: "cuda:0 bf16[4096, 11008]"
      # bw_t782 = prims.reshape(bw_t512, (4096, 11008))  # bw_t782: "cuda:0 bf16[4096, 11008]"
  del bw_t512
  bw_t783 = torch.permute(bw_t782, (1, 0))  # bw_t783: "cuda:0 bf16[11008, 4096]"
    # bw_t783 = ltorch.permute(bw_t782, (1, 0))  # bw_t783: "cuda:0 bf16[11008, 4096]"
      # bw_t783 = prims.transpose(bw_t782, (1, 0))  # bw_t783: "cuda:0 bf16[11008, 4096]"
  del bw_t782
  bw_t784 = torch.reshape(to_3, (-1, 4096))  # bw_t784: "cuda:0 bf16[4096, 4096]"
    # bw_t784 = ltorch.reshape(to_3, (-1, 4096))  # bw_t784: "cuda:0 bf16[4096, 4096]"
      # bw_t784 = prims.reshape(to_3, (4096, 4096))  # bw_t784: "cuda:0 bf16[4096, 4096]"
  del to_3
  bw_t765 = torch.matmul(bw_t783, bw_t784)  # bw_t765: "cuda:0 bf16[11008, 4096]"
    # bw_t765 = ltorch.matmul(bw_t783, bw_t784)  # bw_t765: "cuda:0 bf16[11008, 4096]"
      # bw_t765 = prims.matmul(bw_t783, bw_t784)  # bw_t765: "cuda:0 bf16[11008, 4096]"
  del bw_t783
  bw_t785 = torch.reshape(bw_t529, (-1, 11008))  # bw_t785: "cuda:0 bf16[4096, 11008]"
    # bw_t785 = ltorch.reshape(bw_t529, (-1, 11008))  # bw_t785: "cuda:0 bf16[4096, 11008]"
      # bw_t785 = prims.reshape(bw_t529, (4096, 11008))  # bw_t785: "cuda:0 bf16[4096, 11008]"
  del bw_t529
  bw_t766 = torch.matmul(bw_t785, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_1_parameters_weight_)  # bw_t766: "cuda:0 bf16[4096, 4096]"
    # bw_t766 = ltorch.matmul(bw_t785, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_1_parameters_weight_)  # bw_t766: "cuda:0 bf16[4096, 4096]"
      # bw_t766 = prims.matmul(bw_t785, l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_1_parameters_weight_)  # bw_t766: "cuda:0 bf16[4096, 4096]"
  del l_self_modules_transformer_modules_h_modules_0_modules_mlp_modules_fc_1_parameters_weight_
  bw_t786 = torch.reshape(bw_t766, (1, 4096, 4096))  # bw_t786: "cuda:0 bf16[1, 4096, 4096]"
    # bw_t786 = ltorch.reshape(bw_t766, (1, 4096, 4096))  # bw_t786: "cuda:0 bf16[1, 4096, 4096]"
      # bw_t786 = prims.reshape(bw_t766, (1, 4096, 4096))  # bw_t786: "cuda:0 bf16[1, 4096, 4096]"
  del bw_t766
  bw_t787 = torch.permute(bw_t785, (1, 0))  # bw_t787: "cuda:0 bf16[11008, 4096]"
    # bw_t787 = ltorch.permute(bw_t785, (1, 0))  # bw_t787: "cuda:0 bf16[11008, 4096]"
      # bw_t787 = prims.transpose(bw_t785, (1, 0))  # bw_t787: "cuda:0 bf16[11008, 4096]"
  del bw_t785
  bw_t767 = torch.matmul(bw_t787, bw_t784)  # bw_t767: "cuda:0 bf16[11008, 4096]"
    # bw_t767 = ltorch.matmul(bw_t787, bw_t784)  # bw_t767: "cuda:0 bf16[11008, 4096]"
      # bw_t767 = prims.matmul(bw_t787, bw_t784)  # bw_t767: "cuda:0 bf16[11008, 4096]"
  del bw_t787, bw_t784
  [bw_t576, bw_t580, bw_t552] = nvFusion2(bw_t786, bw_t781, l_self_modules_transformer_modules_h_modules_0_modules_norm_2_parameters_weight_, t304, t306, rsqrt_1, bw_t498)

  del bw_t786, bw_t781, l_self_modules_transformer_modules_h_modules_0_modules_norm_2_parameters_weight_, t306, rsqrt_1, bw_t498
  bw_t768 = torch.matmul(bw_t580, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_proj_parameters_weight_)  # bw_t768: "cuda:0 bf16[4096, 4096]"
    # bw_t768 = ltorch.matmul(bw_t580, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_proj_parameters_weight_)  # bw_t768: "cuda:0 bf16[4096, 4096]"
      # bw_t768 = prims.matmul(bw_t580, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_proj_parameters_weight_)  # bw_t768: "cuda:0 bf16[4096, 4096]"
  del bw_t580, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_proj_parameters_weight_
  bw_t788 = torch.reshape(bw_t768, (1, 4096, 4096))  # bw_t788: "cuda:0 bf16[1, 4096, 4096]"
    # bw_t788 = ltorch.reshape(bw_t768, (1, 4096, 4096))  # bw_t788: "cuda:0 bf16[1, 4096, 4096]"
      # bw_t788 = prims.reshape(bw_t768, (1, 4096, 4096))  # bw_t788: "cuda:0 bf16[1, 4096, 4096]"
  del bw_t768
  bw_t789 = torch.reshape(bw_t576, (-1, 4096))  # bw_t789: "cuda:0 bf16[4096, 4096]"
    # bw_t789 = ltorch.reshape(bw_t576, (-1, 4096))  # bw_t789: "cuda:0 bf16[4096, 4096]"
      # bw_t789 = prims.reshape(bw_t576, (4096, 4096))  # bw_t789: "cuda:0 bf16[4096, 4096]"
  bw_t790 = torch.permute(bw_t789, (1, 0))  # bw_t790: "cuda:0 bf16[4096, 4096]"
    # bw_t790 = ltorch.permute(bw_t789, (1, 0))  # bw_t790: "cuda:0 bf16[4096, 4096]"
      # bw_t790 = prims.transpose(bw_t789, (1, 0))  # bw_t790: "cuda:0 bf16[4096, 4096]"
  del bw_t789
  bw_t791 = torch.reshape(t312, (-1, 4096))  # bw_t791: "cuda:0 bf16[4096, 4096]"
    # bw_t791 = ltorch.reshape(t312, (-1, 4096))  # bw_t791: "cuda:0 bf16[4096, 4096]"
      # bw_t791 = prims.reshape(t312, (4096, 4096))  # bw_t791: "cuda:0 bf16[4096, 4096]"
  del t312
  bw_t769 = torch.matmul(bw_t790, bw_t791)  # bw_t769: "cuda:0 bf16[4096, 4096]"
    # bw_t769 = ltorch.matmul(bw_t790, bw_t791)  # bw_t769: "cuda:0 bf16[4096, 4096]"
      # bw_t769 = prims.matmul(bw_t790, bw_t791)  # bw_t769: "cuda:0 bf16[4096, 4096]"
  del bw_t790, bw_t791
  bw_t792 = torch_prims_reshape_impl(bw_t788, (1, 4096, 32, 128))  # bw_t792: "cuda:0 bf16[1, 4096, 32, 128]"
  del bw_t788
  bw_t793 = torch.permute(bw_t792, (0, 2, 1, 3))  # bw_t793: "cuda:0 bf16[1, 32, 4096, 128]"
    # bw_t793 = ltorch.permute(bw_t792, (0, 2, 1, 3))  # bw_t793: "cuda:0 bf16[1, 32, 4096, 128]"
      # bw_t793 = prims.transpose(bw_t792, (0, 2, 1, 3))  # bw_t793: "cuda:0 bf16[1, 32, 4096, 128]"
  del bw_t792

  # <eval_with_key>.4:49:           y = torch._C._nn.scaled_dot_product_attention(q_2, k_2, v_1, attn_mask = None, dropout_p = 0.0, scale = 0.08838834764831843, is_causal = True);  q_2 = k_2 = v_1 = None
  (_, bw_t238, bw_t239, bw_t240) = cudnn_sdpa_fwd(q_2, k_2, v_1, None, 0.0, True, scale=0.08838834764831843)
  (bw_t594, bw_t595, bw_t596) = cudnn_sdpa_bwd(bw_t793, q_2, k_2, v_1, None, 0.0, True, y, bw_t238, bw_t239, bw_t240, scale=0.08838834764831843, cat_grad_qkv=False)
  del bw_t793, q_2, k_2, v_1, y, bw_t238, bw_t239, bw_t240
  [bw_t722, bw_t723] = TorchCompile0(bw_t595, bw_t594, l_self_buffers_sin_, l_self_buffers_cos_, bw_t596)

  del bw_t595, bw_t594, l_self_buffers_sin_, l_self_buffers_cos_, bw_t596
  bw_t770 = torch.matmul(bw_t723, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_attn_parameters_weight_)  # bw_t770: "cuda:0 bf16[4096, 4096]"
    # bw_t770 = ltorch.matmul(bw_t723, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_attn_parameters_weight_)  # bw_t770: "cuda:0 bf16[4096, 4096]"
      # bw_t770 = prims.matmul(bw_t723, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_attn_parameters_weight_)  # bw_t770: "cuda:0 bf16[4096, 4096]"
  del bw_t723, l_self_modules_transformer_modules_h_modules_0_modules_attn_modules_attn_parameters_weight_
  bw_t794 = torch.reshape(bw_t770, (1, 4096, 4096))  # bw_t794: "cuda:0 bf16[1, 4096, 4096]"
    # bw_t794 = ltorch.reshape(bw_t770, (1, 4096, 4096))  # bw_t794: "cuda:0 bf16[1, 4096, 4096]"
      # bw_t794 = prims.reshape(bw_t770, (1, 4096, 4096))  # bw_t794: "cuda:0 bf16[1, 4096, 4096]"
  del bw_t770
  bw_t795 = torch.reshape(bw_t722, (-1, 12288))  # bw_t795: "cuda:0 bf16[4096, 12288]"
    # bw_t795 = ltorch.reshape(bw_t722, (-1, 12288))  # bw_t795: "cuda:0 bf16[4096, 12288]"
      # bw_t795 = prims.reshape(bw_t722, (4096, 12288))  # bw_t795: "cuda:0 bf16[4096, 12288]"
  del bw_t722
  bw_t796 = torch.permute(bw_t795, (1, 0))  # bw_t796: "cuda:0 bf16[12288, 4096]"
    # bw_t796 = ltorch.permute(bw_t795, (1, 0))  # bw_t796: "cuda:0 bf16[12288, 4096]"
      # bw_t796 = prims.transpose(bw_t795, (1, 0))  # bw_t796: "cuda:0 bf16[12288, 4096]"
  del bw_t795
  bw_t797 = torch.reshape(x_normed_1, (-1, 4096))  # bw_t797: "cuda:0 bf16[4096, 4096]"
    # bw_t797 = ltorch.reshape(x_normed_1, (-1, 4096))  # bw_t797: "cuda:0 bf16[4096, 4096]"
      # bw_t797 = prims.reshape(x_normed_1, (4096, 4096))  # bw_t797: "cuda:0 bf16[4096, 4096]"
  del x_normed_1
  bw_t771 = torch.matmul(bw_t796, bw_t797)  # bw_t771: "cuda:0 bf16[12288, 4096]"
    # bw_t771 = ltorch.matmul(bw_t796, bw_t797)  # bw_t771: "cuda:0 bf16[12288, 4096]"
      # bw_t771 = prims.matmul(bw_t796, bw_t797)  # bw_t771: "cuda:0 bf16[12288, 4096]"
  del bw_t796, bw_t797
  [bw_t734, bw_t758] = nvFusion3(l_self_modules_transformer_modules_h_modules_0_modules_norm_1_parameters_weight_, bw_t794, t304, rsqrt, bw_t576)

  del l_self_modules_transformer_modules_h_modules_0_modules_norm_1_parameters_weight_, bw_t794, t304, rsqrt, bw_t576
  bw_t798 = torch.torch.ops.aten.embedding_backward(bw_t758, l_idx_, 32000, -1, False, False)  # bw_t798: "cuda:0 bf16[32000, 4096]"
    # bw_t798 = ltorch.embedding_backward(bw_t758, l_idx_, 32000, -1, False, False)  # bw_t798: "cuda:0 bf16[32000, 4096]"
      # bw_t798 = prims.embedding_backward(bw_t758, l_idx_, 32000, -1, False, False)  # bw_t798: "cuda:0 bf16[32000, 4096]"
  del bw_t758, l_idx_
  return (None, None, None, bw_t798, bw_t734, bw_t771, bw_t769, bw_t552, bw_t767, bw_t765, bw_t763, bw_t478, bw_t761)

cc: @t-vi @IvanYashchuk @riccardofelluga

cc @t-vi @lantiga @tfogal

kshitij12345 added the performance, memory use, and nemo (Issues needed to support NVIDIA NeMo models) labels on Jan 15, 2025
t-vi (Collaborator) commented Jan 15, 2025

Note that unless you use the (legacy?) options, only tensors tagged as RECOMPUTE_IN_BACKWARD get recomputed.

Is the regression here from recomputing intermediates in decompositions?

# when we decompose to compute the forward/backward, we mark intermediates as to be recomputed in the backward.
# Typically our decompositions are for things that will then be fused together.
# We could refine this heuristic to exclude "expensive" operations.
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

I would not want to have DONT_RECOMPUTE_IN_BACKWARD for perf because it would override checkpointing.

I basically see two ways:

  • If it is from the above snippet: Given that we have rematerialization, we could remove that heuristic altogether if it does not cause memory regressions elsewhere.
  • We could have a DONT_AUTO_RECOMPUTE_IN_BACKWARD tag for expensive ops (so probably matmul, sdpa, linear, ???) that specifically disables the above heuristic for those ops (rough sketch below).
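
A rough sketch of the second option applied to the heuristic quoted above; DONT_AUTO_RECOMPUTE_IN_BACKWARD and the surrounding details are an assumed shape of the change, not an actual implementation:

# Sketch only: a proposed OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD would make the
# decomposition heuristic skip outputs of expensive ops, while explicit activation
# checkpointing could still mark them for recomputation elsewhere.
for o in new_bsym.flat_proxy_outs:
    if variableify(o) not in swap_map:
        if prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD in (new_bsym.sym.tags or ()):
            # expensive op (matmul, sdpa, linear, ...): keep saving its outputs
            continue
        o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)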

WDYT?

t-vi self-assigned this on Jan 15, 2025
kshitij12345 (Collaborator, Author) commented Jan 15, 2025

> Is the regression here from recomputing intermediates in decompositions?

# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
for o in new_bsym.flat_proxy_outs:
    if variableify(o) not in swap_map:
        # when we decompose to compute the forward/backward, we mark intermediates as to be recomputed in the backward.
        # Typically our decompositions are for things that will then be fused together.
        # We could refine this heuristic to exclude "expensive" operations.
        o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

I did a quick check and one of the outputs of cudnn_sdpa_fwd is not present in the swap map, so it is tagged for recompute:

(t9284, t9285, t9286, t9287) = cudnn_sdpa_fwd(q_95, k_95, v_63, None, 0.0, True, scale=0.08838834764831843)
Output not in Swapmap <TensorProxy(name="t9285", dtype=thunder.dtypes.float32, shape=(1, 32, 4096, 1))>

> WDYT?

I think it would be safer to have DONT_AUTO_RECOMPUTE_IN_BACKWARD for expensive ops, so that we can be sure they won't surprise us by being recomputed in the backward pass.
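
The expensive ops would then carry the proposed tag instead of DONT_RECOMPUTE_IN_BACKWARD; again a sketch only, assuming the tag from the previous comment is added to prims.OpTags:

# Sketch: exclude SDPA from automatic recomputation without overriding
# user-requested checkpointing (compare with the patch above, which uses
# DONT_RECOMPUTE_IN_BACKWARD and would override it).
@torchsymbol(
    torch.nn.functional.scaled_dot_product_attention,
    tags=(prims.OpTags.DONT_AUTO_RECOMPUTE_IN_BACKWARD,),
)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None):
    ...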
