
Use CUTLASS for both trans_a and trans_b on Ampere #14

Merged · 9 commits · Jul 29, 2024

Conversation

@dfyz (Contributor) commented Jun 24, 2024

Update: this PR originally couldn't be merged as is (see the end of the description), but I have since implemented a workaround, so it should be safe to merge.

This PR is based on @imoneoi's awesome work here. The problem sorting code (which is only needed for variable-size k, i.e., in the backward pass) was copied from their PR verbatim. Everything else is conceptually structured pretty much the same; the only difference is that I tried to keep the changes very minimal. Notably:

  • the original PR got rid of hardcoded threadblock/warp/instruction shapes, I left them as is (they can be changed in a separate PR, if necessary)
  • the original PR got rid of the cuBLAS fallback for Hopper, I left it intact (I think that it should be easy to make Hopper use CUTLASS too, but again, this can be done in a separate PR)

The only caveat is that vanilla CUTLASS has a bug I described previously when handling k=0 in a grouped GEMM. I added a test that reliably crashes the backward pass with a `CUDA error: an illegal memory access was encountered`. The dimensions of the GEMM are larger than what is used in the other tests, but making them smaller unfortunately makes the crash go away, since the bug leads to an OOB read, which can go undetected without using compute-sanitizer.

I'm currently not sure how to work around this. The best option would be upstream CUTLASS landing my PR fixing the bug (and/or implementing an alternative fix), but I'm really not sure when this will happen. Looking at their open PRs, it might take weeks or even months for a PR to be merged. :/

@dfyz mentioned this pull request Jun 24, 2024
@tgale96 (Owner) commented Jun 28, 2024

Hi! Thanks for the PR! I will have a detailed look at it, but I am curious if you can just "squeeze" k=0 problems from the batch for the backward pass? If k=0, the output just needs to be zero'd, so you can skip the compute and avoid the bug?


// NOTE: We support transposition through the 'trans_b' flag.
TORCH_CHECK(a.is_contiguous());
TORCH_CHECK(b.is_contiguous());
TORCH_CHECK(c.is_contiguous());


// NOTE: Use cuBLAS for SM90 until CUTLASS supports SM90-optimized grouped-gemm.
#if !defined(GROUPED_GEMM_DEVICE_CAPABILITY) || GROUPED_GEMM_DEVICE_CAPABILITY != 80
@tgale96 (Owner)

It looks like nothing in the CUTLASS path is Ampere (SM80) specific, so we could update these ifdefs?

Maybe we could keep the library using the cuBLAS path for now, and have a define that will enable the CUTLASS path so that some of the larger users can take it for a spin before we merge it to main?

@dfyz (Contributor Author) Jul 1, 2024

I don't suppose there's anything Ampere-specific, but I unfortunately don't have any H100s or similar to test on (hopefully this will change soon). Perhaps we can update the ifdefs in a separate PR later?

Regarding the define: that sounds like a great idea, but I wasn't sure if you were proposing to introduce a define for non-Ampere (Hopper, etc.) only, or a define to enable the full CUTLASS path in general (so that you need to enable it explicitly even on Ampere). To be safe, I implemented the latter in 2285bb4: users should set the GROUPED_GEMM_FULL_CUTLASS=1 environment variable when running setup.py if they want to enable CUTLASS for the backward pass (i.e., the transposed cases).

If you have something different in mind, I can easily change it. :)

@tgale96 (Owner)

The GROUPED_GEMM_FULL_CUTLASS mode SGTM! Could we replace the other ifdef guards with this as well?

@dfyz (Contributor Author)

We could try, but I'd be very careful to not break anything accidentally, since testing that the correct code-path is chosen for all configurations is non-trivial. I currently just build the library in different configurations, run a simple test under nsys profile, and check what kernels are actually launched, which is quite error-prone.

Let me try to write down the current behavior of the library very explicitly to make sure I'm not missing anything:

  1. When TORCH_CUDA_ARCH_LIST is set, we never use CUTLASS, at least after 18f7597. I think that this configuration is used for the PyPI builds, meaning you never use CUTLASS if you just pip install grouped_gemm.
  2. When TORCH_CUDA_ARCH_LIST is unset and we build for non-Ampere, CUTLASS is again never used.
  3. When TORCH_CUDA_ARCH_LIST is unset and we build for Ampere, CUTLASS is used for the forward pass only (i.e., !trans_a && !trans_b).

The current iteration of my PR modifies case #3 like this:

  • [...] Additionally, iff GROUPED_GEMM_FULL_CUTLASS is set, CUTLASS is also used for the backward pass (i.e., trans_a || trans_b).

If we want to get rid of the other ifdefs and simplify things, I'd suggest the following:

  • Drop the GROUPED_GEMM_DEVICE_CAPABILITY define/environment variable. I think it's redundant when we also have TORCH_CUDA_ARCH_LIST.
  • Extend my PR to gate both forward and backward CUTLASS paths behind #ifdef GROUPED_GEMM_FULL_CUTLASS (or probably just GROUPED_GEMM_CUTLASS, since ...FULL... is now implied).

This will technically break case #2, because you now have to specify an additional environment variable (GROUPED_GEMM_CUTLASS) when building manually to enable CUTLASS. I don't think this is a huge problem, because I believe most users just use pip install grouped_gemm (case #1), which will continue to use cuBLAS for everything.

So, something like 18b2e8e. I haven't tested this extensively, so I'm not pushing this to the PR branch just yet. Just want to make sure you're okay with this approach.
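
For reference, the single-define scheme described above would roughly boil down to one ifdef at the dispatch point. A minimal sketch (the function names and signature here are placeholders, not the library's actual internals):

```cpp
#include <torch/extension.h>

// Hypothetical declarations standing in for the real cuBLAS/CUTLASS backends.
void CublasGroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor c,
                       torch::Tensor batch_sizes, bool trans_a, bool trans_b);
void CutlassGroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor c,
                        torch::Tensor batch_sizes, bool trans_a, bool trans_b);

// setup.py passes -DGROUPED_GEMM_CUTLASS to nvcc when the GROUPED_GEMM_CUTLASS=1
// environment variable is set at build time, so a single ifdef selects the
// backend for both the forward and the backward (transposed) passes.
void GroupedGemm(torch::Tensor a, torch::Tensor b, torch::Tensor c,
                 torch::Tensor batch_sizes, bool trans_a, bool trans_b) {
#if defined(GROUPED_GEMM_CUTLASS)
  CutlassGroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
#else
  CublasGroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
#endif
}
```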

@dfyz (Contributor Author)

So, something like 18b2e8e.

Okay, I just pushed exactly this to the PR branch. As shown here, the current logic is quite error-prone: it's really not obvious that the define that enables CUTLASS doesn't actually always enable CUTLASS. And I think that having just one CUTLASS define for both fwd and bwd also makes it harder to misconfigure things.

To summarize: setup.py is now configured via the GROUPED_GEMM_CUTLASS environment variable:

  • 1 (use CUTLASS everywhere)
  • any other value or not set at all (use cuBLAS everywhere)

Also, I checked that unit tests work just fine on H100 (with both GROUPED_GEMM_CUTLASS=1 and GROUPED_GEMM_CUTLASS=0).

Does this sound good?

template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(indices.size());
@tgale96 (Owner)

nit: you could define copy = data, and then do the permutation from copy directly into data and skip the memcpy?

Same amount of data movement, just a little cleaner, I think?

@dfyz (Contributor Author)

Whoops, thanks, I copied this function from @imoneoi without thinking twice and didn't notice the extra memcpy(). Removed it in 5b64e00.
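
For reference, a minimal sketch of the memcpy-free version along the lines of the suggestion above (the indexing convention is an assumption, not taken from the actual diff):

```cpp
#include <cstddef>
#include <vector>

// Copy the data once, then permute from the copy straight into the original
// buffer, so no trailing memcpy is needed.
// NOTE: `indices[i]` is assumed to name the source element that should end up
// in slot `i`; the convention in the real code may be the inverse.
template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
  std::vector<T> copy(data, data + indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    data[i] = copy[indices[i]];
  }
}
```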

@tgale96 (Owner) commented Jun 28, 2024

Sorry for the delay on the review :) This got lost in my email it seems.

@dfyz (Contributor Author) commented Jul 1, 2024

If k=0, the output just needs to be zero'd, so you can skip the compute and avoid the bug?

I would really prefer to have this fixed on the CUTLASS side (given that the fix is essentially a one-liner), but they also suggest skipping the offending problems. In 398709f, I implemented a workaround along these lines, but I really hope this will get fixed in CUTLASS one day, because the workaround just complicates the code for no good reason. :/
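
For reference, a minimal sketch of what skipping the offending problems might look like (this is not the code from 398709f; the helper name and out-parameter are hypothetical):

```cpp
#include <cstdint>
#include <vector>

#include <cutlass/gemm_coord.h>

// Problems with k == 0 contribute nothing to their output, so they can be
// dropped from the grouped GEMM launch entirely, sidestepping the upstream
// CUTLASS bug. The caller is expected to zero-fill the output slices of the
// skipped problems (e.g., with cudaMemsetAsync) instead of running the GEMM.
std::vector<cutlass::gemm::GemmCoord> FilterEmptyProblems(
    const std::vector<cutlass::gemm::GemmCoord>& problems,
    std::vector<int64_t>* kept_indices) {
  std::vector<cutlass::gemm::GemmCoord> kept;
  for (int64_t i = 0; i < static_cast<int64_t>(problems.size()); ++i) {
    if (problems[i].k() == 0) {
      continue;  // skipped; its output is all zeros
    }
    kept.push_back(problems[i]);
    kept_indices->push_back(i);
  }
  return kept;
}
```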

Sorry for the delay on the review

No worries at all! Thanks for taking the time to review it.

@tgale96 (Owner) commented Jul 3, 2024

@mvpatel2000 would you be willing to test this on H100 for us?

@dfyz (Contributor Author) commented Jul 3, 2024

would you be willing to test this on H100 for us?

I actually got my hands on an H100, so I will be able to at least run the unit tests tomorrow. Any other testing would be much appreciated, though!

@mvpatel2000 (Collaborator)

Encountered the following error when running the same model code with the env flag set on:

[rank7]: │ /usr/lib/python3/dist-packages/torch/_tensor.py:525 in backward              │
[rank7]: │                                                                              │
[rank7]: │    522 │   │   │   │   create_graph=create_graph,                            │
[rank7]: │    523 │   │   │   │   inputs=inputs,                                        │
[rank7]: │    524 │   │   │   )                                                         │
[rank7]: │ ❱  525 │   │   torch.autograd.backward(                                      │
[rank7]: │    526 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs │
[rank7]: │    527 │   │   )                                                             │
[rank7]: │    528                                                                       │
[rank7]: │                                                                              │
[rank7]: │ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:267 in backward    │
[rank7]: │                                                                              │
[rank7]: │   264 │   # The reason we repeat the same comment below is that              │
[rank7]: │   265 │   # some Python versions print out the first line of a multi-line fu │
[rank7]: │   266 │   # calls in the traceback and some print out the last line          │
[rank7]: │ ❱ 267 │   _engine_run_backward(                                              │
[rank7]: │   268 │   │   tensors,                                                       │
[rank7]: │   269 │   │   grad_tensors_,                                                 │
[rank7]: │   270 │   │   retain_graph,                                                  │
[rank7]: │                                                                              │
[rank7]: │ /usr/lib/python3/dist-packages/torch/autograd/graph.py:744 in                │
[rank7]: │ _engine_run_backward                                                         │
[rank7]: │                                                                              │
[rank7]: │   741 │   if attach_logging_hooks:                                           │
[rank7]: │   742 │   │   unregister_hooks = _register_logging_hooks_on_whole_graph(t_ou │
[rank7]: │   743 │   try:                                                               │
[rank7]: │ ❱ 744 │   │   return Variable._execution_engine.run_backward(  # Calls into  │
[rank7]: │   745 │   │   │   t_outputs, *args, **kwargs                                 │
[rank7]: │   746 │   │   )  # Calls into the C++ engine to run the backward pass        │
[rank7]: │   747 │   finally:                                                           │
[rank7]: │                                                                              │
[rank7]: │ /usr/lib/python3/dist-packages/torch/autograd/function.py:301 in apply       │
[rank7]: │                                                                              │
[rank7]: │   298 │   │   │   │   "of them."                                             │
[rank7]: │   299 │   │   │   )                                                          │
[rank7]: │   300 │   │   user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_f │
[rank7]: │ ❱ 301 │   │   return user_fn(self, *args)                                    │
[rank7]: │   302 │                                                                      │
[rank7]: │   303 │   def apply_jvp(self, *args):                                        │
[rank7]: │   304 │   │   r"""                                                           │
[rank7]: │                                                                              │
[rank7]: │ /usr/lib/python3/dist-packages/torch/cuda/amp/autocast_mode.py:142 in        │
[rank7]: │ decorate_bwd                                                                 │
[rank7]: │                                                                              │
[rank7]: │   139 │   @functools.wraps(bwd)                                              │
[rank7]: │   140 │   def decorate_bwd(*args, **kwargs):                                 │
[rank7]: │   141 │   │   with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0 │
[rank7]: │ ❱ 142 │   │   │   return bwd(*args, **kwargs)                                │
[rank7]: │   143 │                                                                      │
[rank7]: │   144 │   return decorate_bwd                                                │
[rank7]: │   145                                                                        │
[rank7]: │                                                                              │
[rank7]: │ /usr/lib/python3/dist-packages/megablocks/layers/glu.py:146 in backward      │
[rank7]: │                                                                              │
[rank7]: │   143 │   │   │   │   op=activation_fn, out_shape=ctx.w1_out_shape, out_dtyp │
[rank7]: │   144 │   │                                                                  │
[rank7]: │   145 │   │   # Compute dw2 with recomputed activation_fn output.            │
[rank7]: │ ❱ 146 │   │   dw2 = gg.backend.gmm(                                          │
[rank7]: │   147 │   │   │   activation_fn_out, ddsd_out, batch_sizes, trans_a=True)    │
[rank7]: │   148 │   │                                                                  │
[rank7]: │   149 │   │   # Compute dactivation_fn_out.                                  │
[rank7]: │                                                                              │
[rank7]: │ /usr/lib/python3/dist-packages/grouped_gemm/backend.py:27 in gmm             │
[rank7]: │                                                                              │
[rank7]: │   24 def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):       │
[rank7]: │   25 │   if c is None:                                                       │
[rank7]: │   26 │   │   c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)       │
[rank7]: │ ❱ 27 │   backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)                 │
[rank7]: │   28 │   return c                                                            │
[rank7]: │   29                                                                         │
[rank7]: │   30                                                                         │
[rank7]: ╰──────────────────────────────────────────────────────────────────────────────╯
[rank7]: IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

@dfyz (Contributor Author) commented Jul 4, 2024

Encountered the following running same model code with env flag set on

That was actually my bug in selecting the cuBLAS fallback, I missed that this #ifdef should have been ... || GROUPED_GEMM_DEVICE_CAPABILITY != 80, just as in a similar ifdef below. The current version of setup.py doesn't actually enable CUTLASS when TORCH_CUDA_ARCH_LIST is set, which I suppose happened in your case.

To sidestep this problem entirely, I pushed a streamlined version in 3be87fb that only has two simple modes irrespective of what TORCH_CUDA_ARCH_LIST is set to:

  • GROUPED_GEMM_CUTLASS=1 pip install . uses CUTLASS for both forward and backward passes
  • GROUPED_GEMM_CUTLASS=0 pip install . (or just not setting GROUPED_GEMM_CUTLASS) uses cuBLAS for both forward and backward passes

Notice that I changed GROUPED_GEMM_FULL_CUTLASS to just GROUPED_GEMM_CUTLASS to simplify things.

@mvpatel2000 (Collaborator)

@dfyz the latest version seems slower with CUTLASS; does this match your benchmarks?

Here is a forward pass with a MoE FFN using GLU, with 4 experts per GPU and uniform routing, where we see it is ~1.5x slower. Unfortunately, I don't have time to profile any deeper at the moment :(

Baseline: [image]

New Kernel: [image]

@dfyz (Contributor Author) commented Jul 5, 2024

@mvpatel2000 I only ran benchmarks on an A100, but the H100 situation is not unexpected. Basically, as I said in the description of the PR, I initially wanted to make the minimal amount of changes to use CUTLASS everywhere. The motivation was not to make the kernels run faster, but to avoid GPU<->CPU synchronization points when training MoE models, which is only possible with CUTLASS (I will introduce these changes in a separate PR). In particular, this meant I didn't pay much attention to kernel performance on H100 when @tgale96 asked me to also enable CUTLASS on SM 9.0.

Since this PR is pretty far from minimal already, I guess that ship has sailed, though. :) So I dug a little further.

Here's a small benchmark I used to roughly estimate the performance of a full MoE training run using Mixtral 8x7B matrix sizes:

import torch
import grouped_gemm as gg


if __name__ == '__main__':
    # Mixtral 8x7B sizes.
    M = 16384
    K = 4096
    N = 14336
    E = 8
    x = torch.rand(M, K, dtype=torch.bfloat16, device='cuda')
    w = torch.rand(E, K, N, dtype=torch.bfloat16, device='cuda')

    x.requires_grad_(True)
    w.requires_grad_(True)

    batch_sizes = torch.tensor([M//E]*E)

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        for _ in range(30):
            out = gg.ops.gmm(x, w, batch_sizes)
            out.sum().backward()  # backward() returns None; no need to keep the result

    torch.cuda.synchronize()
    prof.export_chrome_trace('gmm_trace.json')

On A100, after I just pushed 223cd55 to get rid of the weird hardcoded shapes (as @imoneoi did in their branch), CUTLASS seems to be on par with cuBLAS.

| | Time per iteration, ms | Perfetto trace |
| --- | --- | --- |
| cuBLAS | 23.8 | cleaned_a100_cublas_trace.json |
| CUTLASS, before 223cd55 | 31.6 | cleaned_a100_cutlass_old_trace.json |
| CUTLASS, after 223cd55 | 23.9 | cleaned_a100_cutlass_new_trace.json |

On H100, the performance is indeed much worse (even worse than ~1.5x), and 223cd55 doesn't help much.

| | Time per iteration, ms | Perfetto trace |
| --- | --- | --- |
| cuBLAS | 8.5 | cleaned_h100_cublas_trace.json |
| CUTLASS, before 223cd55 | 15.0 | cleaned_h100_cutlass_old_trace.json |
| CUTLASS, after 223cd55 | 13.8 | cleaned_h100_cutlass_new_trace.json |

Apparently, as @imoneoi mentioned, we'd need a CUTLASS GEMM optimized specifically for SM 9.0. The problem is that, according to this table and the official example, that requires using the CUTLASS 3.x API instead of CUTLASS 2.x. This is non-trivial: in addition to changing the code, it requires bumping the CUTLASS version used in third_party.

@tgale96 What do you think of merging this PR as is (i.e., with "naive" H100 support), and then making H100 go brrr in a later PR (I'm willing to do this)? Since using CUTLASS now requires an explicit opt-in, I think this should be safe.

Alternatively, we can drop CUTLASS H100 support for now (keeping the current behavior), and then add H100 properly in a separate PR.

@tgale96 (Owner) commented Jul 8, 2024

I think it's fine to merge this with the ifdef guards: it'll be functional on everything and off by default. Do you mind updating the README to explain the build option?

It's pretty annoying that CUTLASS completely changed the API for H100, but not unexpected :)

@tgale96 (Owner) commented Jul 8, 2024

And thank you @mvpatel2000 for testing this on H100!

@dfyz (Contributor Author) commented Jul 9, 2024

Do you mind updating the README to explain the build option?

Just updated! I also added an "Upcoming features" section inspired by Megatron Core. These are the features for which I plan to submit additional PRs in the future. If you think this section doesn't belong in the README, I can remove it.

(In general, feel free to modify the README however you see fit.)

Also, should I bump the library version in setup.py?

setup.py Outdated
@@ -4,24 +4,14 @@
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

if os.environ.get("TORCH_CUDA_ARCH_LIST"):
@tgale96 (Owner)

I think we needed this for some platforms - is there a reason to delete this?

@tgale96 (Owner) commented Jul 16, 2024

Hi! Sorry for the delay. I think everything looks great except there is some logic being removed from setup.py that I believe we need to support multi-platform builds / builds on platforms where no GPU is present.

@tgale96 (Owner) commented Jul 16, 2024

And yes, I think you can bump the version in setup.py. Thanks!

dfyz added 2 commits July 16, 2024 16:57
When `TORCH_CUDA_ARCH_LIST` is not set explicitly, use the compute
capability of the current GPU to build against to avoid redundant
compilations.
@dfyz (Contributor Author) commented Jul 16, 2024

I think everything looks great except there is some logic being removed from setup.py that I believe we need to support multi-platform builds / builds on platforms where no GPU is present.

Ah right, sorry. I think it's the other way round, though: in the current iteration of this PR, multi-platform builds / builds without a GPU should work fine, but when a GPU is present and TORCH_CUDA_ARCH_LIST is unset, we build the library for more architectures than needed.

AFAIU, device_capability (set by the code I'd removed) was used for two different things:

  1. Host-side choice between cuBLAS and CUTLASS based on -DGROUPED_GEMM_DEVICE_CAPABILITY=... (and hence on device_capability). We got rid of GROUPED_GEMM_DEVICE_CAPABILITY, so I thought the device_capability-related stuff was also safe to remove.
  2. Selecting the compute capabilities to build for. This is what I missed: when TORCH_CUDA_ARCH_LIST was unset, the device_capability was also used to prune the number of compute capabilities based on the current GPU model (by setting nvcc_flags).

I just restored the old logic (without the -DGROUPED_GEMM_DEVICE_CAPABILITY=... part). Thanks for noticing!

@dfyz (Contributor Author) commented Jul 16, 2024

And yes, I think you can bump the version in setup.py.

Bumped to 0.1.5.

@tgale96 merged commit 66c7195 into tgale96:main Jul 29, 2024
@tgale96 (Owner) commented Jul 29, 2024

Hi Ivan - sorry again for the delay! I just merged the PR. Thanks so much for the new feature! Very excited to play with it.

@tmm1 commented Jan 22, 2025

then making H100 go brrr in a later PR (I'm willing to do this)?

hey @dfyz, thanks for all your work here! curious if you did any more investigation into CUTLASS 3.x or H100 perf?

This is non-trivial: in addition to changing the code, it requires bumping the CUTLASS version used in third_party.

I tried bumping third_party/cutlass to v3.7.0 and things still build because of their backwards-compatibility mode. But the performance didn't change at all, so we probably need to use the new APIs.

@dfyz (Contributor Author) commented Jan 22, 2025

Hi @tmm1, proper H100 support is still on my to-do list, but I haven't got around to it yet. I think that using the new API is required to utilize TMA and WGMMA, so yeah, it's not surprising that the performance didn't change. The first steps should be really straightforward:

  1. Keep the current code for anything below Hopper, because why fix what's not broken. :) Thanks for confirming the current version works fine with CUTLASS 3.7, by the way.
  2. For Hopper specifically, write something similar to this example. It uses FP8, but I don't see anything inherently FP8-specific there.

This should definitely improve performance. One thing that's bothering me, though, is that we currently use a hardcoded set of (default) shapes everywhere. Thankfully, for Ampere the default shapes were enough to get cuBLAS-like performance for the matrix sizes I was interested in; however, I'm not sure this will hold for Hopper.

So, ideally, we should generate multiple template instantiations with different tile/cluster sizes, test them all, then choose the most performant one. I really hope this won't be needed for Mixtral-like matrix sizes, though (I'm mostly interested in those). :)
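
For what it's worth, the "test them all" part doesn't have to be fancy. A rough sketch of how one could time a handful of pre-instantiated variants and pick the fastest (all names here are hypothetical; this is not existing grouped_gemm code):

```cpp
#include <cuda_runtime.h>

#include <cstddef>
#include <functional>
#include <vector>

// Each variant wraps one template instantiation (e.g., a particular
// tile/cluster shape) and launches it on the given stream.
using GroupedGemmFn = std::function<void(cudaStream_t)>;

// Time every variant once with CUDA events (after a warm-up run) and
// return the index of the fastest one.
size_t PickFastestVariant(const std::vector<GroupedGemmFn>& variants,
                          cudaStream_t stream) {
  size_t best = 0;
  float best_ms = 1e30f;
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  for (size_t i = 0; i < variants.size(); ++i) {
    variants[i](stream);  // warm-up run
    cudaEventRecord(start, stream);
    variants[i](stream);
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    if (ms < best_ms) {
      best_ms = ms;
      best = i;
    }
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return best;
}
```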

@tmm1 commented Jan 23, 2025

2. For Hopper specifically, write something similar to this example. It uses FP8, but I don't see anything inherently FP8-specific there.

Makes sense. I started trying to integrate it in https://github.com/tgale96/grouped_gemm/compare/tmm1:tmm1/cutlass-upgrade but there are a number of changes required for the new API.
