ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file .../csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu #19
Comments
Hey David, thank you so much for your interest in using FlashFFTConv. I am wondering, what GPU card are you running this on (A100, H100, etc.)? We only tested on A100 and H100, and I suspect the issue comes from using a card that does not have as much total shared memory as the A100. We can see how to extend this if needed. Also, what size of FFTConv are you currently using, i.e. 32K? 16K?
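For reference, the failing call requests 135168 bytes (132 KB) of dynamic shared memory, which A100/H100-class cards can grant but most consumer cards (roughly 99 KB opt-in limit) cannot. Below is a rough sketch of how to check what your card allows, using ctypes to call the CUDA runtime directly; the helper is illustrative and not part of FlashFFTConv, the attribute value 97 is taken from cuda_runtime_api.h, and the library name may need adjusting on your system:

```python
import ctypes

# Illustrative helper, not part of FlashFFTConv: query how much dynamic shared
# memory a kernel may opt in to on the current GPU, via the CUDA runtime.
libcudart = ctypes.CDLL("libcudart.so")  # adjust the name/path for your platform

CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97  # from cuda_runtime_api.h

def max_optin_shared_memory(device: int = 0) -> int:
    value = ctypes.c_int(0)
    err = libcudart.cudaDeviceGetAttribute(
        ctypes.byref(value), CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device
    )
    if err != 0:
        raise RuntimeError(f"cudaDeviceGetAttribute failed with error code {err}")
    return value.value

print("Max opt-in dynamic shared memory per block:", max_optin_shared_memory(), "bytes")
print("The failing cudaFuncSetAttribute call requests: 135168 bytes")
```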
Hi Hermann,

> I am wondering, what GPU card are you running this on (A100, H100, etc.)? Also, what size of FFTConv are you currently using, i.e. 32K? 16K?

Funnily, the length that causes the problem is 16384 - the shortest one! The other lengths do not raise that error.
David
Hi Hermann, I tried it out on an A100-40GB, but unfortunately, I keep getting errors related to the package :/
Hi David, do you mind sharing what version of PyTorch and CUDA you are using, so that I can try to reproduce your error on my end and see what the issue could be? I suspect this could be from the version of PyTorch or CUDA. We tested on PyTorch 2.0 and ...
I am testing on PyTorch '2.2.0a0+81ea7a4', with CUDA and toolkit version 12.3. Do you think this might be causing the problem?
By the way, I created a small benchmark to pinpoint the errors:
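A minimal sketch of what such a benchmark could look like, assuming the API from the README (FlashFFTConv(seqlen, dtype=...), input u of shape (B, H, L) in the same dtype as the module, filter k of shape (H, L) in fp32); the use_32_butterfly keyword, batch size, and model dimension are assumptions:

```python
import torch
from flashfftconv import FlashFFTConv

B, H = 4, 64  # assumed batch size and model dimension
device = "cuda"

for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
    for dtype in [torch.float16, torch.bfloat16]:
        for use_32_butterfly in [True, False]:
            try:
                conv = FlashFFTConv(seqlen, dtype=dtype,
                                    use_32_butterfly=use_32_butterfly).to(device)
                u = torch.randn(B, H, seqlen, dtype=dtype, device=device)
                k = torch.randn(H, seqlen, dtype=torch.float32, device=device)
                out = conv(u, k)
                status = "OK" if torch.isfinite(out).all() else "non-finite output"
            except Exception as exc:
                status = f"FAILED: {exc}"
            print(f"Conv layer: FlashFFTConv(), seq_len = {seqlen}, "
                  f"dtype = {dtype}, use_32_butterfly = {use_32_butterfly}: {status}")
```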
I am afraid the results are the same. It only works up to sequence length 2048. EDIT: I verified with ...
Unfortunately, I get the same results. I tried with both ...
FINAL EDIT: I found out where the problem lies. After trying on an A100-40GB, an A6000 Ada, and an A100-80GB, I noticed that all sequence lengths work only on the A100-80GB, and only with torch.bfloat16. For long sequences, the A100-80GB does not raise an error, but the loss becomes infinite.
I hope these insights help bring some clarity to what's missing :) From my side, I'll continue using A100-80GBs in the meantime. Thank you! David
Seeing the same issue here on an RTX 4090: ERROR: CUDA RT call "cudaFuncSetAttribute(&monarch_conv_cuda_32_32_32_kernel<32, 8, 32768, 2, 16, false, 2, 8, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 135168)" in line 969 of file /tmp/pip-req-build-j90uf05x/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu failed with invalid argument (1). Running the example code in the README (with some obvious fixes):
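Roughly the README usage, reconstructed as a sketch (the exact shapes and the dtype choice are assumptions):

```python
import torch
from flashfftconv import FlashFFTConv

B, H, L = 4, 768, 32768  # assumed batch size, model dimension, and FFT size
conv = FlashFFTConv(L, dtype=torch.bfloat16).cuda()

u = torch.randn(B, H, L, dtype=torch.bfloat16, device="cuda")  # same dtype as conv
k = torch.randn(H, L, dtype=torch.float32, device="cuda")      # kernel in fp32

out = conv(u, k)  # on the RTX 4090 this is where the cudaFuncSetAttribute error appears
```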
(sssm) ➜ spectral_ssm git:(main) ✗ nvcc --version
Running the unit test shared above, I see the first error here: Conv layer: FlashFFTConv(), seq_len = 8192, dtype = torch.float16, use_32_butterfly = True
Hi Dan & Hermann,
I am trying to run some experiments with FlashFFTConv, but I am afraid I keep encountering the error quoted in the issue title above.
For debugging, I am running the following:
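A reconstruction of the call being debugged, as a sketch (the shapes and sizes are assumptions; the construction follows the description below):

```python
import torch
from flashfftconv import FlashFFTConv

seqlen, batch, d_model = 16384, 2, 64  # assumed sizes
fftconv_fn = FlashFFTConv(seqlen, dtype=torch.float16, use_32_butterfly=True).cuda()

u = torch.randn(batch, d_model, seqlen, dtype=torch.float16, device="cuda")
k = torch.randn(d_model, seqlen, dtype=torch.float32, device="cuda")

out = fftconv_fn(u, k)  # raises the cudaFuncSetAttribute error from the title
```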
where fftconv_fn is a FlashFFTConv element with use_32_butterfly=True. Both torch.float16 and torch.bfloat16 lead to the same error. Any help on how to solve this issue would be much appreciated!
David