Use CUTLASS for both `trans_a` and `trans_b` on Ampere (#14)
Hi! Thanks for the PR! I will have a detailed look at it, but I am curious if you can just "squeeze" [...]
csrc/grouped_gemm.cu (outdated)

```cpp
// NOTE: We support transposition through the 'trans_b' flag.
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(b.is_contiguous());
TORCH_CHECK(c.is_contiguous());

// NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
#if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
```
It looks like nothing in the CUTLASS path is Ampere (SM80) specific, so we could update these ifdefs?
Maybe we could keep the library using the cuBLAS path for now, and have a define that will enable the CUTLASS path so that some of the larger users can take it for a spin before we merge it to main?
I don't suppose there's anything Ampere-specific, but I unfortunately don't have any H100s or similar to test on (hopefully this will change soon). Perhaps we can update the ifdefs in a separate PR later?

Regarding the define: that sounds like a great idea, but I wasn't sure if you were proposing to introduce a define for non-Ampere (Hopper, etc.) only, or a define to enable the full CUTLASS path in general (so that you need to enable it explicitly even on Ampere). To be safe, I implemented the latter in 2285bb4: users should set the `GROUPED_GEMM_FULL_CUTLASS=1` environment variable when running `setup.py` if they want to enable CUTLASS for the backward pass (a.k.a. the transposed cases).
If you have something different in mind, I can easily change it. :)
The GROUPED_GEMM_FULL_CUTLASS mode SGTM! Could we replace the other ifdef guards with this as well?
We could try, but I'd be very careful not to break anything accidentally, since testing that the correct code path is chosen for all configurations is non-trivial. I currently just build the library in different configurations, run a simple test under `nsys profile`, and check which kernels are actually launched, which is quite error-prone.
Let me try to write down the current behavior of the library very explicitly to make sure I'm not missing anything:
- When `TORCH_CUDA_ARCH_LIST` is set, we never use CUTLASS, at least after 18f7597. I think this configuration is used for the PyPI builds, meaning you never use CUTLASS if you just `pip install grouped_gemm`.
- When `TORCH_CUDA_ARCH_LIST` is unset and we build for non-Ampere, CUTLASS is again never used.
- When `TORCH_CUDA_ARCH_LIST` is unset and we build for Ampere, CUTLASS is used for the forward pass only (i.e., `!trans_a && !trans_b`).
The current iteration of my PR modifies case #3 like this:

- [...] Additionally, iff `GROUPED_GEMM_FULL_CUTLASS` is set, CUTLASS is also used for the backward pass (i.e., `trans_a || trans_b`).
If we want to get rid of the other ifdefs and simplify things, I'd suggest the following:
- Drop the `GROUPED_GEMM_DEVICE_CAPABILITY` define/environment variable. I think it's redundant when we also have `TORCH_CUDA_ARCH_LIST`.
- Extend my PR to gate both the forward and backward CUTLASS paths behind `#ifdef GROUPED_GEMM_FULL_CUTLASS` (or probably just `GROUPED_GEMM_CUTLASS`, since `...FULL...` is now implied); see the sketch after this list.
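Concretely (and purely as an illustration, not the library's actual internals; the backend functions below are stand-ins for the real cuBLAS/CUTLASS implementations), the gating could end up looking something like this:

```cpp
#include <cstdio>

// Stand-ins for the library's real backends, just to keep the sketch self-contained.
void CublasGroupedGemm(bool trans_a, bool trans_b) {
  std::printf("cuBLAS grouped GEMM (trans_a=%d, trans_b=%d)\n", trans_a, trans_b);
}
void CutlassGroupedGemm(bool trans_a, bool trans_b) {
  std::printf("CUTLASS grouped GEMM (trans_a=%d, trans_b=%d)\n", trans_a, trans_b);
}

// A single define selects the backend for every (trans_a, trans_b) combination,
// so the forward and backward passes can no longer be gated differently.
void GroupedGemm(bool trans_a, bool trans_b) {
#ifdef GROUPED_GEMM_CUTLASS
  CutlassGroupedGemm(trans_a, trans_b);  // built with GROUPED_GEMM_CUTLASS=1
#else
  CublasGroupedGemm(trans_a, trans_b);   // default build
#endif
}

int main() {
  GroupedGemm(/*trans_a=*/false, /*trans_b=*/false);  // forward pass
  GroupedGemm(/*trans_a=*/true,  /*trans_b=*/false);  // transposed case (backward pass)
  GroupedGemm(/*trans_a=*/false, /*trans_b=*/true);   // transposed case (backward pass)
}
```

Compiling the sketch with `-DGROUPED_GEMM_CUTLASS` flips every call to the CUTLASS stand-in; without it, everything stays on the cuBLAS stand-in.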
This will technically break case #2, because you now have to specify an additional environment variable (`GROUPED_GEMM_CUTLASS`) when building manually to enable CUTLASS. I don't think this is a huge problem, because I believe most users just `pip install grouped_gemm` (case #1), which will continue to use cuBLAS for everything.
So, something like 18b2e8e. I haven't tested this extensively, so I'm not pushing this to the PR branch just yet. Just want to make sure you're okay with this approach.
> So, something like 18b2e8e.
Okay, I just pushed exactly this to the PR branch. As shown here, the current logic is quite error-prone: it's really not obvious that the define that enables CUTLASS doesn't actually always enable CUTLASS. And I think that having just one CUTLASS define for both fwd and bwd also makes it harder to misconfigure things.
To summarize: `setup.py` is now configured via the `GROUPED_GEMM_CUTLASS` environment variable, which can only be set to:

- `1` (use CUTLASS everywhere)
- any other value, or not set at all (use cuBLAS everywhere)
Also, I checked that the unit tests work just fine on H100 (with both `GROUPED_GEMM_CUTLASS=1` and `GROUPED_GEMM_CUTLASS=0`).
Does this sound good?
csrc/grouped_gemm.cu (outdated)

```cpp
template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
  // For now, simply create a copy of the data and then copy over to the original.
  std::vector<T> copy(indices.size());
```
nit: you could define copy = data, and then do the permutation from copy directly into data and skip the memcpy?
Same amount of data movement, just a little cleaner, I think?
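Something like this, roughly (assuming element `i` of the reordered array should come from position `indices[i]`; the convention in the actual file may be the inverse):

```cpp
#include <cstddef>
#include <vector>

template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
  // Copy-construct from the original data once...
  std::vector<T> copy(data, data + indices.size());
  // ...then permute straight back into `data`, with no extra memcpy at the end.
  for (size_t i = 0; i < indices.size(); ++i) {
    data[i] = copy[indices[i]];
  }
}
```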
Sorry for the delay on the review :) This got lost in my email, it seems.
I would really prefer to have it fixed on the CUTLASS side (given that the fix is essentially a one-liner), but they also suggest skipping the offending problems. In 398709f, I implemented a workaround along these lines, but I really hope this gets fixed in CUTLASS one day, because the workaround just complicates the code for no good reason. :/
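Conceptually, the workaround boils down to something like the sketch below; this is not the actual 398709f code, and `GemmProblem` is just a stand-in for CUTLASS's per-group problem descriptor:

```cpp
#include <cstdint>
#include <vector>

// Toy stand-in for a per-group problem size (m, n, k).
struct GemmProblem {
  int64_t m, n, k;
};

// Drop zero-sized GEMMs before handing the problem list to CUTLASS: a group
// with k == 0 is exactly the case that triggers the upstream bug, and it
// contributes nothing to the result anyway (the caller still has to make sure
// the corresponding output block is zero-filled).
static std::vector<GemmProblem> SkipEmptyProblems(const std::vector<GemmProblem>& problems) {
  std::vector<GemmProblem> kept;
  kept.reserve(problems.size());
  for (const auto& p : problems) {
    if (p.k > 0) {
      kept.push_back(p);
    }
  }
  return kept;
}
```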
No worries at all! Thanks for taking the time to review it.
@mvpatel2000 would you be willing to test this on H100 for us?
I actually got my hands on an H100, so I will be able to at least run the unit tests tomorrow. Any other testing would be much appreciated, though!
Encountered the following when running the same model code with the env flag set on: [...]
That was actually my bug in selecting the cuBLAS fallback; I missed that [...]. To sidestep this problem entirely, I pushed a streamlined version in 3be87fb that only has two simple modes, irrespective of what [...].
Notice that I changed [...]
@dfyz the latest version seems slower with CUTLASS; does this match your benchmarks? Here is a forward pass with an MoE FFN using GLU, with 4 experts per GPU and uniform routing, where we see it is ~1.5x slower. Unfortunately, I don't have time to profile any deeper at the moment :(
@mvpatel2000 I only ran benchmarks on an A100, but the H100 situation is not unexpected. Basically, as I said in the description of the PR, I initially wanted to make the minimal amount of changes needed to use CUTLASS everywhere. The motivation was not to make the kernels run faster, but to avoid GPU<->CPU synchronization points when training MoE models, which is only possible with CUTLASS (I will introduce these changes in a separate PR). In particular, this meant I didn't pay much attention to kernel performance on H100 when @tgale96 asked me to also enable CUTLASS on SM 9.0. Since this PR is pretty far from minimal already, I guess that ship has sailed, though. :)

So I dug a little further. Here's a small benchmark I used to roughly estimate the performance of a full MoE training run using Mixtral 8x22B matrix sizes:

```python
import torch

import grouped_gemm as gg

if __name__ == '__main__':
    # Mixtral 8x7B sizes.
    M = 16384
    K = 4096
    N = 14336
    E = 8

    x = torch.rand(M, K, dtype=torch.bfloat16, device='cuda')
    w = torch.rand(E, K, N, dtype=torch.bfloat16, device='cuda')
    x.requires_grad_(True)
    w.requires_grad_(True)
    batch_sizes = torch.tensor([M//E]*E)

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        for _ in range(30):
            out = gg.ops.gmm(x, w, batch_sizes)
            grad = out.sum().backward()
        torch.cuda.synchronize()

    prof.export_chrome_trace(f'gmm_trace.json')
```

On A100, after I just pushed 223cd55 to get rid of the weird hardcoded shapes (as @imoneoi did in his branch), CUTLASS seems to be on par with cuBLAS.
On H100, the performance is indeed much worse (even worse than ~1.5x), and 223cd55 doesn't help much.
Apparently, as @imoneoi mentioned, we'd need a CUTLASS GEMM optimized specifically for SM 9.0. The problem is that, according to this table and the official example, this requires using the CUTLASS 3.x API instead of CUTLASS 2.x. This is non-trivial: in addition to changing the code, it requires bumping the CUTLASS version used in `third_party/cutlass`.

@tgale96 What do you think of merging this PR as is (i.e., with "naive" H100 support), and then making H100 go brrr in a later PR (I'm willing to do this)? Since using CUTLASS now requires an explicit opt-in, I think this should be safe. Alternatively, we can drop CUTLASS H100 support for now (keeping the current behavior), and then add H100 support properly in a separate PR.
I think it's fine to merge this with the ifdef guards - it'll be functional on everything and off by default. Do you mind updating the README to explain the build option? It's pretty annoying that CUTLASS completely changed the API for H100, but not unexpected :)
And thank you @mvpatel2000 for testing this on H100!
Just updated! I also added an "Upcoming features" section inspired by Megatron Core. These are the features for which I plan to submit additional PRs in the future. If you think this section doesn't belong in the README, I can remove it. (In general, feel free to modify the README however you see fit.) Also, should I bump the library version in `setup.py`?
setup.py (outdated)

```
@@ -4,24 +4,14 @@
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

if os.environ.get("TORCH_CUDA_ARCH_LIST"):
```
I think we needed this for some platforms - is there a reason to delete this?
Hi! Sorry for the delay. I think everything looks great, except there is some logic being removed from `setup.py`.
And yes, I think you can bump the version in `setup.py`.
> When `TORCH_CUDA_ARCH_LIST` is not set explicitly, use the compute capability of the current GPU to build against to avoid redundant compilations.
Ah right, sorry. I think it's the other way round, though: in the current iteration of this PR, multi-platform builds / builds without a GPU should work fine, but when a GPU is present and [...], AFAIU [...].

I just restored the old logic (without the [...]).
Bumped to [...].
Hi Ivan - sorry again for the delay! I just merged the PR. Thanks so much for the new feature! Very excited to play with it.
hey @dfyz, thanks for all your work here! curious if you did any more investigation into CUTLASS 3.x or H100 perf?
I tried bumping third_party/cutlass to v3.7.0 and things still build because of their backwards-compat mode. But the performance didn't change at all, so we probably need to use the new APIs.
Hi @tmm1, proper H100 support is still on my to-do list, but I haven't got around to it yet. I think that using the new API is required to utilize TMA and WGMMA, so yeah, it's not surprising that the performance didn't change. The first steps should be really straightforward: [...]

This should definitely improve performance; however, one thing that's bothering me is that we currently use a hardcoded set of (default) shapes everywhere. Thankfully, for Ampere the default shapes were enough to get cuBLAS-like performance for the matrix sizes I was interested in. However, I'm not sure this will hold for Hopper. So, ideally, we should generate multiple template instantiations with different tile/cluster sizes, test them all, then choose the most performant one. I really hope this won't be needed for Mixtral-like matrix sizes, though (I'm mostly interested in those). :)
Makes sense. I started trying to integrate it in https://github.com/tgale96/grouped_gemm/compare/tmm1:tmm1/cutlass-upgrade but there are a number of changes required for the new API.
~~Note that this PR can't be merged as is; see the end of the description.~~ I implemented a workaround, so this PR should be safe to merge.

This PR is based on imoneoi@'s awesome work here. The problem sorting code (which is only needed for variable-size `k`, i.e., in the backward pass) was copied from their PR verbatim. Everything else is structured pretty much the same conceptually; the only difference is that I tried to make the changes very minimal. Notably: [...]

The only caveat is that vanilla CUTLASS has a bug I described previously when handling `k=0` in a grouped GEMM. I added a test that reliably crashes the backward pass with a `CUDA error: an illegal memory access was encountered` error. The dimensions of the GEMM are larger than what is used in the other tests, but making them smaller unfortunately makes the crash go away, since the bug leads to an OOB read, which can go undetected without using `compute-sanitizer`.

I'm currently not sure how to work around this. The best option would be upstream CUTLASS landing my PR fixing the bug (and/or implementing an alternative fix), but I'm really not sure when this will happen. Looking at their opened PRs, I see that it might take weeks, or even months, for a PR to be merged. :/