
Expected the batch dimensions of a (((384,),)) and the batch dimensions of b (()) to be the same #826

Closed
tfogal opened this issue Jul 22, 2024 · 0 comments · Fixed by #838
Assignees: kshitij12345
Labels: dynamo, high priority, nemo (Issues needed to support NVIDIA NeMo models), program-coverage (Requests for model and program coverage)

Comments

@tfogal (Collaborator) commented on Jul 22, 2024

🚀 Model / language coverage

Running thunder.examine fails on one of the NeVA subgraphs (as split up by Dynamo):

Expected the batch dimensions of a (((384,),)) and the batch dimensions of b (()) to be the same

This can be reproduced by defining:

def thunder_backend(gm, args):
    gm.real_recompile()
    from thunder.examine import examine
    try:
        examine(gm, *args)
    except Exception as e:
        print(f"Hit problem with examine:\n{e}")
        print(f"gm: {gm.print_readable()}")
    # Don't really use Thunder; just return the original graph.
    return gm

and then compiling the model (in main) with model.model = torch.compile(backend=thunder_backend)(model.model).
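
The failing op appears to be the matmul of the f16[384, 2, 5120] activation with the transposed f32[15360, 5120] weight in the captured subgraph below. A standalone sketch that seems to hit the same check without going through NeMo (shapes copied from the graph; a single float32 dtype is used so it runs on CPU, and linear_like is a hypothetical stand-in for megatron's LinearWithFrozenWeight.forward):

import torch
from thunder.examine import examine

def linear_like(hidden_states, weight):
    # output = torch.matmul(input, weight.t()) as in megatron's
    # LinearWithFrozenWeight.forward; hidden_states is 3-D (batch dims (384,)
    # from thunder's point of view) while weight.t() is 2-D (batch dims ()).
    return torch.matmul(hidden_states, weight.t())

hidden_states = torch.randn(384, 2, 5120)   # f16 in the original graph
weight = torch.randn(15360, 5120)           # f32 in the original graph
examine(linear_like, hidden_states, weight)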

Pitch

This is for #343.

Alternatives / Potential work-arounds

Minimal Repro

Hit problem with examine:
Expected the batch dimensions of a (((384,),)) and the batch dimensions of b (()) to be the same
class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_query_key_value_parameters_weight_: "f32[15360, 5120]", L_hidden_states_: "f16[384, 2, 5120]"):
        l_self_modules_query_key_value_parameters_weight_ = L_self_modules_query_key_value_parameters_weight_
        l_hidden_states_ = L_hidden_states_
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/mappings.py:442 in copy_to_tensor_model_parallel_region, code: return _CopyToModelParallelRegion.apply(input_)
        function_ctx = torch.autograd.function.FunctionCtx()
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py:382 in linear_with_frozen_weight, code: return LinearWithFrozenWeight.apply(*args)
        function_ctx_1 = torch.autograd.function.FunctionCtx()
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py:288 in forward, code: output = torch.matmul(input, weight.t())
        t: "f32[5120, 15360]" = l_self_modules_query_key_value_parameters_weight_.t();  l_self_modules_query_key_value_parameters_weight_ = None
        output: "f16[384, 2, 15360]" = torch.matmul(l_hidden_states_, t);  l_hidden_states_ = t = None
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:426 in forward, code: mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        mixed_x_layer: "f16[384, 2, 40, 384]" = output.view(384, 2, 40, 384);  output = None
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/utils.py:34 in split_tensor_along_last_dim, code: tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        split = torch.functional.split(mixed_x_layer, 128, dim = 3);  mixed_x_layer = None
        chunk: "f16[384, 2, 40, 128]" = split[0]
        chunk_1: "f16[384, 2, 40, 128]" = split[1]
        chunk_2: "f16[384, 2, 40, 128]" = split[2];  split = None
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/utils.py:37 in <genexpr>, code: return tuple(chunk.contiguous() for chunk in tensor_list)
        query_layer: "f16[384, 2, 40, 128]" = chunk.contiguous();  chunk = None
        key_layer: "f16[384, 2, 40, 128]" = chunk_1.contiguous();  chunk_1 = None
        value_layer: "f16[384, 2, 40, 128]" = chunk_2.contiguous();  chunk_2 = None
        return (query_layer, key_layer, value_layer)
        
gm: class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_query_key_value_parameters_weight_: "f32[15360, 5120]", L_hidden_states_: "f16[384, 2, 5120]"):
        l_self_modules_query_key_value_parameters_weight_ = L_self_modules_query_key_value_parameters_weight_
        l_hidden_states_ = L_hidden_states_
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/mappings.py:442 in copy_to_tensor_model_parallel_region, code: return _CopyToModelParallelRegion.apply(input_)
        function_ctx = torch.autograd.function.FunctionCtx()
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py:382 in linear_with_frozen_weight, code: return LinearWithFrozenWeight.apply(*args)
        function_ctx_1 = torch.autograd.function.FunctionCtx()
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/layers.py:288 in forward, code: output = torch.matmul(input, weight.t())
        t: "f32[5120, 15360]" = l_self_modules_query_key_value_parameters_weight_.t();  l_self_modules_query_key_value_parameters_weight_ = None
        output: "f16[384, 2, 15360]" = torch.matmul(l_hidden_states_, t);  l_hidden_states_ = t = None
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:426 in forward, code: mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        mixed_x_layer: "f16[384, 2, 40, 384]" = output.view(384, 2, 40, 384);  output = None
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/utils.py:34 in split_tensor_along_last_dim, code: tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        split = torch.functional.split(mixed_x_layer, 128, dim = 3);  mixed_x_layer = None
        chunk: "f16[384, 2, 40, 128]" = split[0]
        chunk_1: "f16[384, 2, 40, 128]" = split[1]
        chunk_2: "f16[384, 2, 40, 128]" = split[2];  split = None
        
         # File: /home/tfogal/env/lib/python3.10/site-packages/megatron/core/tensor_parallel/utils.py:37 in <genexpr>, code: return tuple(chunk.contiguous() for chunk in tensor_list)
        query_layer: "f16[384, 2, 40, 128]" = chunk.contiguous();  chunk = None
        key_layer: "f16[384, 2, 40, 128]" = chunk_1.contiguous();  chunk_1 = None
        value_layer: "f16[384, 2, 40, 128]" = chunk_2.contiguous();  chunk_2 = None
        return (query_layer, key_layer, value_layer)
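
For context, eager PyTorch broadcasts the 2-D operand of matmul across the leading batch dimension, so this subgraph runs fine outside of examine; the complaint above seems to come from Thunder's stricter batch-dimension check for matmul. A quick sanity check of the eager behaviour (shapes from the graph above; float32 so it runs on CPU):

import torch

a = torch.randn(384, 2, 5120)   # activation: leading batch dim 384
w = torch.randn(15360, 5120)    # 2-D weight: no batch dims
out = torch.matmul(a, w.t())    # eager PyTorch broadcasts the 2-D operand
print(out.shape)                # torch.Size([384, 2, 15360])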

cc @apaz-cli @tfogal

@tfogal added the high priority, nemo, and program-coverage labels on Jul 22, 2024
@tfogal self-assigned this on Jul 22, 2024
@tfogal removed their assignment on Jul 22, 2024
@t-vi added the dynamo label on Jul 23, 2024
@kshitij12345 self-assigned this on Jul 23, 2024
@t-vi closed this as completed in #838 on Jul 23, 2024