_empty_transform() got multiple values for argument 'device' #825

Closed
tfogal opened this issue Jul 22, 2024 · 1 comment · Fixed by #836
Assignees: kshitij12345
Labels: dynamo, high priority, nemo (Issues needed to support NVIDIA NeMo models.), program-coverage (Requests for model and program coverage)

Comments

tfogal (Collaborator) commented Jul 22, 2024

🚀 Model / language coverage

I'm trying to get a fuller picture of what we need to support NeVA. To that end, I'm using the following torch.compile backend:

def thunder_backend(gm, args):
    gm.real_recompile()
    # Examine may raise an error.
    from thunder.examine import examine
    try:
        examine(gm, *args)
    except Exception as e:
        print(f"Hit problem with examine:\n{e}")
        print(f"gm: {gm.print_readable()}")
    # Don't actually use Thunder; just return the original graph.
    return gm

...
#model.model = thunder.jit(model.model)
model.model = torch.compile(backend=thunder_backend)(model.model)

(thanks Ivan for the great idea!)

And in one of the subgraphs I'm seeing this error:

_empty_transform() got multiple values for argument 'device'

Pitch

This is related to #343.

Alternatives / Potential work-arounds

Printing the graph module gives more information:

[NeMo W 2024-07-22 23:25:54 nemo_logging:349] /usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py:1943: FutureWarning: `torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.
      return node.target(*args, **kwargs)
    
[NeMo W 2024-07-22 23:25:54 nemo_logging:349] /usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py:685: UserWarning: Graph break due to unsupported builtin scaled_upper_triang_masked_softmax_cuda.PyCapsule.forward. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
      torch._dynamo.utils.warn_once(msg)
    
    [... the same FutureWarning from torch/_dynamo/utils.py:1943 repeated three more times ...]
[NeMo W 2024-07-22 23:25:54 nemo_logging:349] <eval_with_key>.31:7: FutureWarning: `torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.
      _cast = torch.cuda.amp.autocast_mode._cast((inputs, 1), torch.float16);  inputs = None
    
gm: class GraphModule(torch.nn.Module):
    def forward(self, L_query_layer_: "f16[384, 2, 40, 128]", L_key_layer_: "f16[384, 2, 40, 128]", L_value_layer_: "f16[384, 2, 40, 128]"):
        l_query_layer_ = L_query_layer_
        l_key_layer_ = L_key_layer_
        l_value_layer_ = L_value_layer_
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:935 in torch_attention, code: query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn')
        query_layer: "f16[80, 384, 128]" = einops_einops_rearrange(l_query_layer_, 'sq b np hn -> (b np) sq hn');  l_query_layer_ = None
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:936 in torch_attention, code: key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk')
        key_layer: "f16[80, 128, 384]" = einops_einops_rearrange(l_key_layer_, 'sk b np hn -> (b np) hn sk');  l_key_layer_ = None
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:937 in torch_attention, code: value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn')
        value_layer: "f16[80, 384, 128]" = einops_einops_rearrange(l_value_layer_, 'sv b np hn -> (b np) sv hn');  l_value_layer_ = None
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:939 in torch_attention, code: matmul_input_buffer = torch.empty(
        matmul_input_buffer: "f16[80, 384, 384]" = torch.empty(80, 384, 384, dtype = torch.float16, device = device(type='cuda', index=0))
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:947 in torch_attention, code: matmul_result = torch.baddbmm(
        matmul_result: "f16[80, 384, 384]" = torch.baddbmm(matmul_input_buffer, query_layer, key_layer, beta = 0.0, alpha = 0.08838834764831843);  matmul_input_buffer = query_layer = key_layer = None
        
         # File: /home/tfogal/dev/nemo/nemo/collections/nlp/modules/common/megatron/attention.py:956 in torch_attention, code: attention_scores = matmul_result.view(b, np, sq, sk)
        attention_scores: "f16[2, 40, 384, 384]" = matmul_result.view(2, 40, 384, 384);  matmul_result = None
        return (attention_scores, value_layer)
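
The only op in this subgraph that takes a device argument is the torch.empty(...) call above, so that is presumably what reaches _empty_transform. One guess at the mechanism, purely illustrative (the signature below is hypothetical, not taken from the Thunder source): if the transform declares the shape as a single leading parameter while the FX node forwards the sizes positionally, the extra sizes spill into device, which then collides with the device keyword:

# Hypothetical signature, for illustration only; Thunder's real _empty_transform
# may look nothing like this.
def _empty_transform(shape, device=None, dtype=None):
    pass

# torch.empty(80, 384, 384, dtype=..., device=...) forwards three positional sizes,
# so the second size binds to `device` and clashes with the keyword:
_empty_transform(80, 384, 384, device="cuda:0", dtype="float16")
# TypeError: _empty_transform() got multiple values for argument 'device'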

Minimal Repro
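
I haven't distilled one yet; below is a sketch of what it might look like, pulled from the subgraph above. The shapes, dtypes, and the direct call to thunder.examine are my assumptions, and I have not yet verified that this hits the same error.

import torch
import thunder
from einops import rearrange
from thunder.examine import examine

def torch_attention(query_layer, key_layer, value_layer):
    # Mirrors the captured torch_attention subgraph from
    # nemo/collections/nlp/modules/common/megatron/attention.py.
    query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn')
    key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk')
    value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn')
    # The torch.empty call with an explicit device kwarg, as in the graph above.
    matmul_input_buffer = torch.empty(
        80, 384, 384, dtype=torch.float16, device=torch.device('cuda', 0)
    )
    matmul_result = torch.baddbmm(
        matmul_input_buffer, query_layer, key_layer,
        beta=0.0, alpha=128 ** -0.5
    )
    attention_scores = matmul_result.view(2, 40, 384, 384)
    return attention_scores, value_layer

args = [
    torch.randn(384, 2, 40, 128, dtype=torch.float16, device='cuda')
    for _ in range(3)
]
examine(torch_attention, *args)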

cc @apaz-cli @tfogal

tfogal added the program-coverage label Jul 22, 2024
tfogal self-assigned this Jul 22, 2024
tfogal (Collaborator, Author) commented Jul 22, 2024

Assigning to me until I can flesh this out with a real reproducer.

tfogal added the nemo label Jul 22, 2024
tfogal removed their assignment Jul 22, 2024
t-vi added the dynamo label Jul 23, 2024
kshitij12345 self-assigned this Jul 23, 2024
t-vi closed this as completed in #836 Jul 23, 2024