
PP Tracer doesn't work with fused_rmsnorm #1108

Open
wconstab opened this issue May 3, 2024 · 2 comments
Comments

wconstab commented May 3, 2024

Currently we have to work around this by using regular rmsnorm for PP to be enabled:

torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm

Full trace: https://gist.github.com/wconstab/3b68edda6bd30c2414403e91734ccc87
cc @kwen2501 @lessw2020

wconstab commented May 3, 2024

Is it safe to skip the `if` and just call `.contiguous()` all the time? Maybe that is a no-op in the case where x is already contiguous?
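A quick check of that assumption (a standalone sketch, not the torchtitan code itself): `.contiguous()` returns the input tensor unchanged when it is already contiguous, so calling it unconditionally only pays for a copy in the non-contiguous case.

    import torch

    dy = torch.randn(4, 8)                 # already contiguous
    print(dy.contiguous() is dy)           # True: no copy, same tensor object
    print(dy.stride(-1))                   # 1

    dy_t = dy.t()                          # non-contiguous view
    print(dy_t.stride(-1))                 # 8, would take the `if` branch
    print(dy_t.contiguous().stride(-1))    # 1 after the copy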

wconstab commented May 3, 2024

Some attempts to fix this:
(1) gets rid of the conditionals on dynamic shapes, which gets me past the first tracing errors:
pytorch/torchtitan#300
(2) does a hack for computing sm_count from device(0), which is unsafe; we might be able to make a tracer-friendly version of this somehow (sketched below)?
pytorch/torchtitan#301
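For reference, my reading of the sm_count hack in (2), sketched here rather than copied from pytorch/torchtitan#301: query the SM count from a hard-coded device index so it is a plain Python int at trace time, instead of something derived from the (fake) input tensor.

    import torch

    # unsafe: always asks device 0, so tracing never has to touch the input's
    # device, but it is wrong if the kernel actually runs on a different GPU
    sm_count = torch.cuda.get_device_properties(0).multi_processor_count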

After these I still hit a stride issue for the non-conditional usages of stride:
(3)

    File "/data/users/whc/pytorch/torch/_dynamo/variables/tensor.py", line 322, in var_getattr
      unimplemented(f"Illegal getattr invocation {name} in strict mode")
    File "/data/users/whc/pytorch/torch/_dynamo/exc.py", line 212, in unimplemented
      raise Unsupported(msg)
  torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode

  from user code:
     File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 428, in forward
      h = layer(h, freqs_cis)
    File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/whc/torchtitan/torchtitan/models/llama/model.py", line 317, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
      return forward_call(*args, **kwargs)
    File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 61, in forward
      return self.fused_rms_norm_fn(
    File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 316, in fused_rms_norm_fn
      return TritonFusedRMSNorm.apply(
    File "/data/users/whc/torchtitan/torchtitan/models/norms.py", line 294, in backward
      dy.stride(0),

  Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Finally, I tried setting export strict=False in pippy _IR.py. This fixes the dy.stride(0) issue, but then I still crash with a data-ptr access during tracing.
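For context, a minimal standalone illustration of the strict vs. non-strict difference (this uses torch.export directly and is not the pippy _IR.py change itself): non-strict export runs the module under the Python interpreter with FakeTensors, so a stride() call in user code is evaluated eagerly, while strict export goes through Dynamo, which is where the getattr restriction above comes from.

    import torch
    import torch.nn as nn

    class StrideCheck(nn.Module):
        def forward(self, x):
            # same pattern as fused_rmsnorm's `if dy.stride(-1) != 1:`
            if x.stride(-1) != 1:
                x = x.contiguous()
            return x * 2

    m, args = StrideCheck(), (torch.randn(4, 8),)

    # non-strict: stride() is evaluated eagerly on the fake tensor and tracing proceeds
    ep = torch.export.export(m, args, strict=False)

    # strict: Dynamo traces the forward and may hit the
    # "Illegal getattr invocation stride in strict mode" error seen above
    try:
        torch.export.export(m, args, strict=True)
    except Exception as e:
        print(type(e).__name__, e)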

(4)

    File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
      spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
    File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in <genexpr>
      spec_key = tuple(arg.specialization_key() for arg in args if not arg.param.do_not_specialize)
    File "/home/whc/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/triton/runtime/jit.py", line 174, in specialization_key
      return (self.value.data_ptr() % JITFunction.divisibility == 0, )
    File "/data/users/whc/pytorch/torch/export/_safeguard.py", line 43, in __torch_function__
      return func(*args, **kwargs)
  RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ

I will try to wrap it in a custom op (?)

Can we just 'allow in graph' the whole call to fused_rmsnorm? Why aren't we doing that already? 🤔
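Rough sketch of the opaque-custom-op route the error message suggests, using the torch.library.custom_op API; the op name and the eager stand-in body here are illustrative, not the actual torchtitan kernel. The idea is that export only ever sees the registered fake implementation, so the Triton launcher never receives a FakeTensor and never calls data_ptr().

    import torch

    @torch.library.custom_op("titan_sketch::fused_rms_norm", mutates_args=())
    def fused_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # the real op would launch the Triton kernel; plain eager math stands in here
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    @fused_rms_norm.register_fake
    def _(x, weight, eps):
        # all the tracer sees: shape/dtype propagation only, no kernel, no data_ptr
        return torch.empty_like(x)

As I understand it, allow_in_graph only affects the Dynamo front end, and fake-tensor tracing further down the stack can still reach into the kernel, which is presumably why the error message points at custom ops instead.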
