77 changes: 58 additions & 19 deletions custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h
@@ -46,7 +46,9 @@ struct SharedStorage {
};

template <int kBlockM_,
-          int kBlockN_,
+          int kBlockN1_,
+          int kBlockN2_,
+          int kBlockN3_,
int kBlockK_,
int kNWarps_,
int kStages_,
@@ -73,16 +75,19 @@ struct Kernel_traits {
static_assert(kNWarps_ == 12 || kNWarps_ == 16);

static constexpr int kBlockM = kBlockM_;
-  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kBlockN1 = kBlockN1_;
+  static constexpr int kBlockN2 = kBlockN2_;
+  static constexpr int kBlockN3 = kBlockN3_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int M = M_;
static constexpr int K = K_;
static constexpr int WeightScaleGroup = WeightScaleGroup_;

-  using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
-
+  using TileShape_MNK1 = Shape<Int<kBlockM>, Int<kBlockN1>, Int<kBlockK>>;
+  using TileShape_MNK2 = Shape<Int<kBlockM>, Int<kBlockN2>, Int<kBlockK>>;
+  using TileShape_MNK3 = Shape<Int<kBlockM>, Int<kBlockN3>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

@@ -91,9 +96,17 @@ struct Kernel_traits {

using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;

-  using TiledMma = decltype(cute::make_tiled_mma(
+  using TiledMma1 = decltype(cute::make_tiled_mma(
       cute::GMMA::
-          rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
+          rs_op_selector<Element, Element, ElementAccum, TileShape_MNK1>(),
+      AtomLayoutMNK{}));
+  using TiledMma2 = decltype(cute::make_tiled_mma(
+      cute::GMMA::
+          rs_op_selector<Element, Element, ElementAccum, TileShape_MNK2>(),
+      AtomLayoutMNK{}));
+  using TiledMma3 = decltype(cute::make_tiled_mma(
+      cute::GMMA::
+          rs_op_selector<Element, Element, ElementAccum, TileShape_MNK3>(),
       AtomLayoutMNK{}));

using SmemLayoutAtomA =
@@ -107,27 +120,53 @@ struct Kernel_traits {
SmemLayoutAtomA{},
make_shape(Int<kBlockM>{}, Int<kBlockK / 2>{}, Int<kStages>{})));

-  using SmemLayoutAtomB =
-      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
-               GMMA::Major::K,
-               Element,
-               decltype(cute::get<1>(TileShape_MNK{})),
-               decltype(cute::get<2>(TileShape_MNK{}))>());
-
-  using SmemLayoutB =
-      decltype(tile_to_shape(SmemLayoutAtomB{},
-                             make_shape(shape<1>(TileShape_MNK{}),
-                                        shape<2>(TileShape_MNK{}),
-                                        Int<kStages>{})));
+  using SmemLayoutAtomB1 =
+      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
+               GMMA::Major::K,
+               Element,
+               decltype(cute::get<1>(TileShape_MNK1{})),
+               decltype(cute::get<2>(TileShape_MNK1{}))>());
+
+  using SmemLayoutB1 =
+      decltype(tile_to_shape(SmemLayoutAtomB1{},
+                             make_shape(shape<1>(TileShape_MNK1{}),
+                                        shape<2>(TileShape_MNK1{}),
+                                        Int<kStages>{})));
+
+  using SmemLayoutAtomB2 =
+      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
+               GMMA::Major::K,
+               Element,
+               decltype(cute::get<1>(TileShape_MNK2{})),
+               decltype(cute::get<2>(TileShape_MNK2{}))>());
+
+  using SmemLayoutB2 =
+      decltype(tile_to_shape(SmemLayoutAtomB2{},
+                             make_shape(shape<1>(TileShape_MNK2{}),
+                                        shape<2>(TileShape_MNK2{}),
+                                        Int<kStages>{})));
+
+  using SmemLayoutAtomB3 =
+      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
+               GMMA::Major::K,
+               Element,
+               decltype(cute::get<1>(TileShape_MNK3{})),
+               decltype(cute::get<2>(TileShape_MNK3{}))>());
+
+  using SmemLayoutB3 =
+      decltype(tile_to_shape(SmemLayoutAtomB3{},
+                             make_shape(shape<1>(TileShape_MNK3{}),
+                                        shape<2>(TileShape_MNK3{}),
+                                        Int<kStages>{})));
using SmemLayoutAtomC =
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K,
ElementOutput,
-               decltype(cute::get<0>(TileShape_MNK{})),
-               decltype(cute::get<1>(TileShape_MNK{}))>());
+               decltype(cute::get<0>(TileShape_MNK1{})),
+               decltype(cute::get<1>(TileShape_MNK1{}))>());

-  using SmemLayoutC =
-      decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
+  using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{},
+                                             select<0, 1>(TileShape_MNK1{})));

using SmemCopyAtomAB = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomC = Copy_Atom<cute::SM90_U32x4_STSM_N, ElementOutput>;
@@ -138,15 +177,15 @@ struct Kernel_traits {
Element,
ElementOutput,
SmemLayoutA,
-          SmemLayoutB,
+          SmemLayoutB1,
SmemLayoutC,
SmemLayoutScale>;

using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;

static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
-  static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
+  static constexpr int kNumThreadsPerRow = kBlockN1 / kNumVecElem;
// static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom =
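The traits change stamps the same MMA/smem recipe out for three N extents instead of one. Everything with a global footprint (the B shared-memory allocation, the TMA descriptor, SmemLayoutC, and the store path) keys off the widest tile, kBlockN1, while kBlockN2 and kBlockN3 are narrower views of the same storage for ragged tails. A minimal sketch of the size relationship this relies on, with illustrative values (the launcher's real configuration is not part of this diff):

// Illustrative values only; the real kernel configuration is chosen by
// the launcher, which is outside this diff.
constexpr int kBlockM  = 128;
constexpr int kBlockN1 = 256;  // widest tile: B smem, TMA, and epilogue are sized for it
constexpr int kBlockN2 = 128;  // narrower fallback tile
constexpr int kBlockN3 = 64;   // narrowest fallback tile
constexpr int kBlockK  = 128;

// B_STEPS = kBlockN1 / CUR_N (used in mainloop_fwd.h below) is only exact
// when the widest tile divides evenly by the narrower ones.
static_assert(kBlockN1 % kBlockN2 == 0 && kBlockN1 % kBlockN3 == 0,
              "narrow tiles must divide the widest tile");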
78 changes: 52 additions & 26 deletions custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h
@@ -32,13 +32,17 @@ template <typename Ktraits>
struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
-  using TileShape_MNK = typename Ktraits::TileShape_MNK;
+  using TileShape_MNK1 = typename Ktraits::TileShape_MNK1;
+  using TileShape_MNK2 = typename Ktraits::TileShape_MNK2;
+  using TileShape_MNK3 = typename Ktraits::TileShape_MNK3;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;

static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
-  static constexpr int kBlockN = Ktraits::kBlockN;
+  static constexpr int kBlockN1 = Ktraits::kBlockN1;
+  static constexpr int kBlockN2 = Ktraits::kBlockN2;
+  static constexpr int kBlockN3 = Ktraits::kBlockN3;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
@@ -50,7 +54,10 @@ struct CollectiveMainloopFwd {
using GmemTiledCopy = cute::SM90_TMA_LOAD;

using SmemLayoutA = typename Ktraits::SmemLayoutA;
-  using SmemLayoutB = typename Ktraits::SmemLayoutB;
+  using SmemLayoutB1 = typename Ktraits::SmemLayoutB1;
+  using SmemLayoutB2 = typename Ktraits::SmemLayoutB2;
+  using SmemLayoutB3 = typename Ktraits::SmemLayoutB3;
+
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutScale = typename Ktraits::SmemLayoutScale;

@@ -76,8 +83,8 @@ struct CollectiveMainloopFwd {
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}),
-      take<0, 2>(SmemLayoutB{}),
-      select<1, 2>(TileShape_MNK{}),
+      take<0, 2>(SmemLayoutB1{}),
+      select<1, 2>(TileShape_MNK1{}),
size<0>(ClusterShape{})));

using TMA_Scale = decltype(make_tma_copy(
@@ -89,7 +96,7 @@ struct CollectiveMainloopFwd {
select<0>(Shape<Int<kBlockM>>{}),
size<0>(ClusterShape{})));

-  static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
+  static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma1{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
@@ -100,7 +107,7 @@ struct CollectiveMainloopFwd {
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(
-      size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
+      size(take<0, 2>(SmemLayoutB1{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesScale = static_cast<uint32_t>(
size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v<float> / 8);

@@ -141,8 +148,8 @@ struct CollectiveMainloopFwd {
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(GmemTiledCopy{},
mB,
-                                    SmemLayoutB{}(_, _, _0{}),
-                                    select<1, 2>(TileShape_MNK{}),
+                                    SmemLayoutB1{}(_, _, _0{}),
+                                    select<1, 2>(TileShape_MNK1{}),
size<0>(ClusterShape{}));
Tensor mScale =
make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale);
@@ -176,7 +183,10 @@ struct CollectiveMainloopFwd {
}
}

-  template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
+  template <int CUR_N,
+            typename SharedStorage,
+            typename FrgTensorO,
+            typename TiledMma>
CUTLASS_DEVICE void store(Params const& mainloop_params,
FrgTensorO& tOrO,
SharedStorage& shared_storage,
@@ -252,7 +262,7 @@ struct CollectiveMainloopFwd {

cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);

-    constexpr int k_copy_times = kBlockN / 16;
+    constexpr int k_copy_times = CUR_N / 16;

#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
@@ -273,15 +283,15 @@ struct CollectiveMainloopFwd {
const int expert_idx =
TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput* store_c = mainloop_params.ptr_C + expert_idx +
-                             bidn * (M * kBlockN) + bidm * kBlockM;
+                             bidn * (M * kBlockN1) + bidm * kBlockM;

-    const int reamin_tokens = tokens - bidn * kBlockN;
+    const int reamin_tokens = tokens - bidn * kBlockN1;

const int col = tidx % 2;

constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
-    constexpr int copy_len = kBlockN * kNumVecElem;
+    constexpr int copy_len = CUR_N * kNumVecElem;
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
@@ -307,7 +317,7 @@ struct CollectiveMainloopFwd {
auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0));

Tensor gB = local_tile(
-        g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
+        g_tensor, select<1, 2>(TileShape_MNK1{}), make_coord(bidn, _));
return gB;
}

Expand All @@ -324,8 +334,8 @@ struct CollectiveMainloopFwd {
const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
-    Tensor sB =
-        make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
+    Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()),
+                            SmemLayoutB1{});
Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()),
SmemLayoutScale{});

@@ -387,7 +397,7 @@ struct CollectiveMainloopFwd {
mB(_, _, bidb).data(),
make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride()));
Tensor gB = local_tile(
-        mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
+        mB_this_expert, select<1, 2>(TileShape_MNK1{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
_0{},
Layout<ClusterShape>{},
Expand Down Expand Up @@ -421,18 +431,26 @@ struct CollectiveMainloopFwd {
}
}

-  template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
+  template <int CUR_N,
+            typename SharedStorage,
+            typename FrgTensorO,
+            typename TiledMma>
CUTLASS_DEVICE void mma(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO& tSrS,
const int tidx) {
+    using sMemBLayout = std::conditional_t<
+        CUR_N == kBlockN1,
+        SmemLayoutB1,
+        std::conditional_t<CUR_N == kBlockN2, SmemLayoutB2, SmemLayoutB3>>;
+
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
-        make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
+        make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
tiled_mma.accumulate_ = GMMA::ScaleOut::One;

auto threadMma = tiled_mma.get_thread_slice(tidx);
@@ -447,6 +465,7 @@ struct CollectiveMainloopFwd {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
+    constexpr int B_STEPS = kBlockN1 / CUR_N;
#pragma unroll
for (int kiter = 0; kiter < kTiles; ++kiter) {
Tensor tSsA =
@@ -455,7 +474,7 @@ struct CollectiveMainloopFwd {
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA,
-                          tSrB(_, _, _, smem_pipe_read.index()),
+                          tSrB(_, _, _, smem_pipe_read.index() * B_STEPS),
tSrS,
smem_tiled_copy_A,
smem_thr_copy_A);
@@ -464,18 +483,25 @@ struct CollectiveMainloopFwd {
}
}

-  template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
+  template <int CUR_N,
+            typename SharedStorage,
+            typename FrgTensorO,
+            typename TiledMma>
CUTLASS_DEVICE void mma_pipeline(Params const& mainloop_params,
TiledMma tiled_mma,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
FrgTensorO& tSrS,
const int tidx) {
+    using sMemBLayout = std::conditional_t<
+        CUR_N == kBlockN1,
+        SmemLayoutB1,
+        std::conditional_t<CUR_N == kBlockN2, SmemLayoutB2, SmemLayoutB3>>;
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
-        make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
+        make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
float2* weight_scale =
reinterpret_cast<float2*>(shared_storage.smem_scale.data()) + tidx / 4;

@@ -501,7 +527,7 @@ struct CollectiveMainloopFwd {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
    };
-
+    constexpr int B_STEPS = kBlockN1 / CUR_N;
__half2 scale1, scale2, scale3, scale4;
float2 scale_cur_k;
#pragma unroll
@@ -516,7 +542,7 @@ struct CollectiveMainloopFwd {
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA1,
-                        tSrB(_, _, _, smem_pipe_read.index()),
+                        tSrB(_, _, _, smem_pipe_read.index() * B_STEPS),
tSrS1,
smem_tiled_copy_A,
smem_thr_copy_A);
@@ -545,7 +571,7 @@ struct CollectiveMainloopFwd {
gemm</*wg_wait=*/0>(tiled_mma,
tSrA,
tSsA2,
-                        tSrB(_, _, _, smem_pipe_read.index()),
+                        tSrB(_, _, _, smem_pipe_read.index() * B_STEPS),
tSrS2,
smem_tiled_copy_A,
smem_thr_copy_A);
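On the consumer side, CUR_N now threads through store, mma, and mma_pipeline: sB reinterprets the single B buffer through SmemLayoutB2 or SmemLayoutB3 when the tile is narrower than kBlockN1, and the pipeline-stage index is rescaled by B_STEPS = kBlockN1 / CUR_N, since one wide stage spans B_STEPS narrow sub-tiles. A standalone sketch of that arithmetic and the caller-side tile choice, under assumed tile sizes (the dispatch itself lives in the kernel that drives this mainloop and is not shown in this diff):

// Standalone sketch of the tile choice and stage-index arithmetic; tile
// sizes are assumed, and the real dispatch code is not part of this diff.
#include <cstdio>
#include <initializer_list>

constexpr int kBlockN1 = 256;  // widest tile; B smem is laid out for this
constexpr int kBlockN2 = 128;  // assumed fallback tiles
constexpr int kBlockN3 = 64;

// Mirrors the choice a caller would make before invoking
// mma<CUR_N>(params, TiledMmaX{}, ...): the narrowest tile that still
// covers the tokens remaining in this N block.
constexpr int pick_cur_n(int remain) {
  return remain > kBlockN2 ? kBlockN1
       : remain > kBlockN3 ? kBlockN2
                           : kBlockN3;
}

int main() {
  for (int remain : {256, 180, 100, 40}) {
    const int cur_n   = pick_cur_n(remain);
    const int b_steps = kBlockN1 / cur_n;  // B_STEPS in mma()/mma_pipeline()
    // Pipeline stage s of the wide buffer begins at index s * b_steps in
    // the CUR_N-wide view of B smem; printed here for stage s = 3.
    std::printf("remain=%3d -> CUR_N=%3d, B_STEPS=%d, stage 3 -> index %d\n",
                remain, cur_n, b_steps, 3 * b_steps);
  }
  return 0;
}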