
Fix illegal memory accesses in multistage Mma's for k=0 #1593

Open · wants to merge 1 commit into main

Conversation

@dfyz (Contributor) commented Jun 18, 2024

The forward/backward passes of the MLPs in mixture-of-experts models are a perfect fit for the grouped GEMM implementation in CUTLASS (for example, the grouped_gemm library uses CUTLASS for the forward pass). Unfortunately, when we tried to use CUTLASS for the backward pass, we started randomly getting CUDA illegal memory access errors whenever no tokens were assigned to an expert (which means that k=0 when computing the gradients for its weights).
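For concreteness, here is a hypothetical illustration (the layer dimensions, token counts, and helper name are all made up) of the per-expert problem sizes for the weight-gradient grouped GEMM, where k is the number of tokens routed to an expert:

```cpp
#include <vector>

#include "cutlass/gemm_coord.h"

// Hypothetical per-expert problem sizes for dW_e = X_e^T * dY_e:
// m = hidden size, n = FFN size, k = number of tokens routed to expert e.
std::vector<cutlass::gemm::GemmCoord> make_problem_sizes() {
  const int hidden = 4096, ffn = 14336;  // made-up layer dimensions
  return {
      {hidden, ffn, 2048},  // expert 0: 2048 tokens
      {hidden, ffn, 0},     // expert 1: no tokens assigned -> k == 0
      {hidden, ffn, 512},   // expert 2: 512 tokens
  };
}
```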

I'm not sure if my analysis of the code is correct, but the problem seems to be a missing edge case in multistage MMAs. When k=0, gemm_k_iterations is 0 when we enter the prologue, so the following happens to the A/B iterators:

  • we clear the mask of the iterator once, before processing the residual tile
  • then the iterator moves to the steady state, which means that the access predicates are recomputed, and we have to clear the mask again when we are out of bounds
  • however, gemm_k_iterations is now -1, not 0, so the mask is never cleared again
  • as a result, we access memory we're not supposed to touch

This implies that a possible fix is to clear the mask whenever gemm_k_iterations is <=0, not just ==0. As far as I can see, this only matters for multistage prologues, but just to be safe, this PR replaces every gemm_k_iterations==0 comparison I could find with gemm_k_iterations<=0. I'm also slightly changing the grouped GEMM test so that it includes a problem with k=0 (the test crashes without the fix in this PR).
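To make the sequence above concrete, here is a tiny runnable toy model (plain C++, not the actual CUTLASS source; `ToyIterator` and its methods are invented for illustration) of the predicate logic and of the proposed `<= 0` change:

```cpp
#include <cstdio>

// ToyIterator is an invented stand-in for a CUTLASS predicated tile iterator:
// while loads_enabled is true, it would issue global-memory loads.
struct ToyIterator {
  bool loads_enabled = true;
  void clear_mask(bool enable) {
    if (enable) {
      loads_enabled = false;
    }
  }
  void move_to_steady_state() {
    loads_enabled = true;  // access predicates are recomputed from scratch
  }
};

int main() {
  int gemm_k_iterations = 0;  // k == 0: no main-loop iterations at all
  ToyIterator iterator_A;

  iterator_A.clear_mask(gemm_k_iterations == 0);  // residual tile: mask cleared
  iterator_A.move_to_steady_state();              // predicates recomputed, loads re-enabled
  --gemm_k_iterations;                            // now -1
  iterator_A.clear_mask(gemm_k_iterations == 0);  // false: mask is NOT cleared again
  // With the proposed fix, clear_mask(gemm_k_iterations <= 0) would disable the
  // loads here and prevent the out-of-bounds accesses.

  std::printf("loads still enabled: %s\n", iterator_A.loads_enabled ? "yes (bug)" : "no");
  return 0;
}
```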

P.S. I should also note that cuBLAS appears to handle k=0 GEMMs correctly (no crashes, and the output matrix is filled with zeros).

@dfyz (Contributor, Author) commented Jun 20, 2024

force-pushed

Changed the author of the commit, no functional changes intended.

@dfyz (Contributor, Author) commented Jun 21, 2024

@hwu36 Hi, could you please take a quick look at this PR?

@thakkarV (Collaborator) commented:

@mnicely @ANIKET-SHIVAM

@hwu36 (Collaborator) commented Jun 26, 2024

@jackkosaian

@jackkosaian (Contributor) commented:

Thanks for reporting this, @dfyz. For cases in which one of the modes is zero, we recommend simply not including that problem in the grouped GEMM arguments. Is that a possibility for your application?
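A minimal sketch of this recommendation, assuming the problem sizes are assembled in a host-side std::vector (the container name is assumed; the matching operand pointer and leading-dimension arrays would need the same filtering):

```cpp
#include <algorithm>
#include <vector>

#include "cutlass/gemm_coord.h"

// Drop every problem whose k extent is zero before building the grouped GEMM
// arguments. Real code must also filter ptr_A/B/C/D and lda/ldb/ldc/ldd
// consistently with the surviving problems.
void drop_k0_problems(std::vector<cutlass::gemm::GemmCoord> &problem_sizes) {
  problem_sizes.erase(
      std::remove_if(problem_sizes.begin(), problem_sizes.end(),
                     [](cutlass::gemm::GemmCoord const &p) { return p.k() == 0; }),
      problem_sizes.end());
}
```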

@dfyz (Contributor, Author) commented Jun 26, 2024

@jackkosaian

> Is that a possibility for your application?

That would imply a performance hit for two different reasons:

  1. When k=0 (as opposed to m=0 or n=0), you still have to fill the output matrix with zeroes. So I guess you'd need an additional kernel to run before or after the grouped GEMM to find the problems with k=0 and zero out the corresponding outputs (see the sketch after this list).
  2. As far as I can see, the number of problems is a host-side parameter (as opposed to problem sizes themselves, which reside on the device). This means you have to know this number upfront when scheduling the kernel, which requires a CPU-GPU sync if the problem sizes are calculated dynamically by a previously scheduled kernel (as is usual for MoE models).
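A hedged sketch of the extra kernel mentioned in point 1 (the names, the float output type, the packed m x n layout, and the launch configuration are all assumptions; real code would also need to honor beta and the actual output layout):

```cpp
#include "cutlass/gemm_coord.h"

// Zero the output of every problem whose k extent is 0, given device-resident
// problem sizes and per-problem output pointers. One thread block per problem.
__global__ void zero_k0_outputs(cutlass::gemm::GemmCoord const *problem_sizes,
                                float *const *out_ptrs,
                                int problem_count) {
  int p = blockIdx.x;
  if (p >= problem_count || problem_sizes[p].k() != 0) {
    return;
  }
  int64_t elems = int64_t(problem_sizes[p].m()) * problem_sizes[p].n();
  for (int64_t i = threadIdx.x; i < elems; i += blockDim.x) {
    out_ptrs[p][i] = 0.0f;
  }
}

// Launch example: zero_k0_outputs<<<problem_count, 256, 0, stream>>>(problem_sizes, out_ptrs, problem_count);
```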

1 is probably more of a minor inconvenience, but 2 will really kill the performance in my case.

I can try to come up with some workarounds, but isn't it better long-term to properly handle k=0 in CUTLASS, seeing that it matches the cuBLAS behaviour?

@hwu36 (Collaborator) commented Jun 27, 2024

What is the name of the kernel used by cuBLAS?

@dfyz (Contributor, Author) commented Jun 27, 2024

I was a little unclear here: in the grouped_gemm library, cuBLAS is only used as a fallback that launches multiple regular GEMMs instead of a single grouped GEMM. In this case, when k=0 for any of the regular GEMMs, cuBLAS runs a CUDA memset on the output and doesn't launch any kernels at all.

I see that cuBLAS also has a grouped GEMM implementation in recent CUDA versions, but that appears to just run CUTLASS under the hood, so it also crashes when k=0.

Here's a quick test program I made to illustrate the observed behaviors of cuBLAS (12.5) and CUTLASS (the most recent code from main) when k=0:

| Scenario | Outcome | Comment |
| --- | --- | --- |
| cuBLAS, non-grouped (`./main cublas`) | ✅ computes an all-zero result | runs `[CUDA memset]` and no kernels, according to Nsight Systems |
| CUTLASS, non-grouped (`./main cutlass`) | ❌ crashes | runs `cutlass::Kernel<cutlass::gemm::kernel::Gemm<cutlass::gemm::threadblock::MmaMultistage<cutlass:…` |
| cuBLAS, grouped (`./main cublas_grouped`) | ❌ crashes | runs `cutlass::Kernel2<cutlass_80_tensorop_bf16_s16816gemm_bf16_grouped_128x64_64x3_align8>` |
| CUTLASS, grouped (`./main cutlass_grouped`) | ❌ crashes | runs `cutlass::Kernel<cutlass::gemm::kernel::GemmGrouped<cutlass::gemm::threadblock::MmaMultistage<c…` |

So cuBLAS and CUTLASS actually only differ in behavior for the non-grouped case. Note that the fix I'm proposing should resolve all of the crashing scenarios in the table above.
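For reference, here is a minimal sketch (not the actual test program above; error checking is omitted, and the expected output relies on the memset behavior described earlier) of the non-grouped cuBLAS scenario from the first row:

```cpp
#include <cstdio>
#include <vector>

#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
  const int m = 128, n = 128, k = 0;

  float *A, *B, *C;
  cudaMalloc(&A, sizeof(float) * m);      // dummy buffers; never read when k == 0
  cudaMalloc(&B, sizeof(float) * n);
  cudaMalloc(&C, sizeof(float) * m * n);

  cublasHandle_t handle;
  cublasCreate(&handle);

  // C = alpha * A * B + beta * C with k == 0 and beta == 0: cuBLAS is expected
  // to fill C with zeros without launching a GEMM kernel.
  const float alpha = 1.0f, beta = 0.0f;
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
              &alpha, A, /*lda=*/m, B, /*ldb=*/1, &beta, C, /*ldc=*/m);

  std::vector<float> host(size_t(m) * n);
  cudaMemcpy(host.data(), C, sizeof(float) * m * n, cudaMemcpyDeviceToHost);
  std::printf("C[0] = %f (expected 0)\n", host[0]);

  cublasDestroy(handle);
  cudaFree(A); cudaFree(B); cudaFree(C);
  return 0;
}
```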


This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@dfyz (Contributor, Author) commented Jul 29, 2024

> Otherwise, please respond with a comment indicating any updates.

Well, I still think this is something worth fixing in CUTLASS directly. For now, I have implemented a workaround in grouped_gemm, but it would be great to eventually get rid of it.

@mnicely (Collaborator) commented Jul 30, 2024

@dfyz, cuBLAS will resolve this issue with Grouped GEMM in an upcoming release. I agree it would be good to fix in CUTLASS, but we'll need to revisit when we have more time.

@dfyz (Contributor, Author) commented Jul 30, 2024

@mnicely Thank you, this sounds great! A couple of follow-up questions:

  1. I am a little confused as to how cuBLAS can resolve it if they appear to be using a CUTLASS-based kernel for Grouped GEMM. Does this mean cuBLAS will add a host-side preprocessing step before running CUTLASS, similar to what I used in grouped_gemm?
  2. When you say "revisit when we have more time", do you mean that there is something wrong with the fix I'm proposing in this PR (e.g., I only masked the symptoms of the crash, but the root cause is somewhere else), or that you need time to review the code, run internal perf tests, etc.? I honestly expected this to be a relatively uncontroversial fix for an edge case in CUTLASS 2 tile iterators, but of course I might be missing something deeper (e.g., performance implications?).

@mnicely (Collaborator) commented Jul 31, 2024

Hi @dfyz, it's more about us spending time to review the code and any ripple effects, plus internal verification, testing, and then productization. Combined with other high-priority tasks and bugs, I'm unable to commit to a date. We really appreciate the PR, and your WAR (workaround) may help other customers while we work on this issue. :)


This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@dfyz (Contributor, Author) commented Aug 30, 2024

> Otherwise, please respond with a comment indicating any updates.

An eventual fix for this on the CUTLASS side would be very much appreciated, but I understand that this is a low-priority issue.

> cuBLAS will resolve this issue with Grouped GEMM in an upcoming release

By the way, CUDA 12.6 is out, and this release note looks promising:
[screenshot of the relevant CUDA 12.6 release note]

However, when I run the test program from my earlier comment as ./main cublas_grouped in the nvidia/cuda:12.6.0-devel-ubuntu24.04 image, I still get a crash:

# ./main cublas_grouped
CUDA Error at: main.cu:248
  an illegal memory access was encountered

Is this an unrelated fix, or am I doing something wrong in my test program?
