Skip to content

Commit

Permalink
[EVT] Add support for Row/Col broadcast PtrArray (#2033)
Browse files (browse the repository at this point in the history)
* Add group support to EVT row/col broadcast.

* Small modifications.

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
  • Loading branch information
jwfromm and hwu36 authored Feb 2, 2025
1 parent 6f55278 commit affd1b6
Showing 1 changed file with 39 additions and 14 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -972,14 +972,20 @@ compute_row_broadcast_stages() {
template<
int Stages,
class CtaTileShapeMNK,
class ElementInput,
class ElementCompute = ElementInput,
class ElementInput_,
class ElementCompute = cute::remove_pointer_t<ElementInput_>,
class StrideMNL_ = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<ElementInput>,
int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90RowBroadcast {
using StrideMNL = StrideMNL_;
// Get base element input type.
using ElementInput = cute::remove_pointer_t<ElementInput_>;
// Check if input is an array of pointers.
static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
using PtrRowType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;

static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining");

static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<1>(StrideMNL{}))>, bool>; // row vector or scalar broadcast
Expand All @@ -991,7 +997,7 @@ struct Sm90RowBroadcast {
};

struct Arguments {
ElementInput const* ptr_row = nullptr;
PtrRowType ptr_row = nullptr;
ElementInput null_default = ElementInput(0);
StrideMNL dRow = {};
};
Expand Down Expand Up @@ -1036,7 +1042,7 @@ struct Sm90RowBroadcast {
is_zero_ = params.null_default == ElementCompute(0);
}
// Dynamic non-batched scalar broadcast
else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
is_zero_ = params.ptr_row[0] == ElementInput(0);
}
}
Expand Down Expand Up @@ -1183,7 +1189,13 @@ struct Sm90RowBroadcast {

auto layout_M = make_layout(M, repeat_like(M, _0{}));
auto layout_L = make_layout(L, get<2>(params.dRow));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L));
ElementInput const* ptr_row;
if constexpr(IsArrayOfPointers) {
ptr_row = params.ptr_row[l];
} else {
ptr_row = params.ptr_row;
}
Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
Expand Down Expand Up @@ -1220,14 +1232,20 @@ struct Sm90RowBroadcast {
template<
int Stages,
class CtaTileShapeMNK,
class ElementInput,
class ElementCompute = ElementInput,
class ElementInput_,
class ElementCompute = cute::remove_pointer_t<ElementInput_>,
class StrideMNL_ = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<ElementInput>,
int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90ColBroadcast {
using StrideMNL = StrideMNL_;
// Get base element input type.
using ElementInput = cute::remove_pointer_t<ElementInput_>;
// Check if input is an array of pointers.
static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
using PtrColType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;

static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining");

static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<0>(StrideMNL{}))>, bool>; // Column vector or scalar broadcast
Expand All @@ -1238,13 +1256,13 @@ struct Sm90ColBroadcast {
struct SharedStorage { };

struct Arguments {
ElementInput const* ptr_col = nullptr;
PtrColType ptr_col = nullptr;
ElementInput null_default = ElementInput(0);
StrideMNL dCol = {};
};

struct Params {
ElementInput const* ptr_col = nullptr;
PtrColType ptr_col = nullptr;
ElementCompute null_default = ElementCompute(0);
StrideMNL dCol = {};
};
Expand Down Expand Up @@ -1301,7 +1319,7 @@ struct Sm90ColBroadcast {
is_zero_ = params.null_default == ElementCompute(0);
}
// Dynamic non-batched scalar broadcast
else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
is_zero_ = params.ptr_col[0] == ElementInput(0);
}
}
Expand Down Expand Up @@ -1398,6 +1416,7 @@ struct Sm90ColBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE {
auto shape_M = get<0>(args.problem_shape_mnkl);
if constexpr (IsDynamicBroadcast) {
Expand All @@ -1416,11 +1435,17 @@ struct Sm90ColBroadcast {

auto layout_N = make_layout(N, repeat_like(N, _0{}));
auto layout_L = make_layout(L, get<2>(params.dCol));
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L));
ElementInput const* ptr_col;
if constexpr(IsArrayOfPointers) {
ptr_col = params.ptr_col[l];
} else {
ptr_col = params.ptr_col;
}
Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L));
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L));
Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L));
Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like<ElementCompute>(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Expand Down

0 comments on commit affd1b6

Please sign in to comment.