Skip to content

Commit

Permalink
Handle MNK Sm90{Row, Col}Reduction problem shapes (#1803)
Browse files Browse the repository at this point in the history
  • Loading branch information
saagarjha authored Oct 14, 2024
1 parent cc3c29a commit 5366879
Showing 1 changed file with 16 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,8 @@ struct Sm90ScalarReduction {
CudaHostAdapter* cuda_adapter = nullptr) {
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
if constexpr (IsAtomic) {
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar);
if (args.ptr_scalar != nullptr) {
return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter);
Expand Down Expand Up @@ -700,7 +701,9 @@ struct Sm90RowReduction {
reduction_buffer = nullptr;
}
else if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;

auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
Expand Down Expand Up @@ -735,7 +738,8 @@ struct Sm90RowReduction {
}

size_t workspace_size = 0;
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
// Increment by size of reduction buffer
workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
Expand All @@ -750,8 +754,9 @@ struct Sm90RowReduction {
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
if constexpr (IsAtomic) {
auto [M, N, K, L] = problem_shape;
Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow);
if (args.ptr_row != nullptr) {
return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter);
Expand All @@ -761,7 +766,6 @@ struct Sm90RowReduction {
else
#endif
if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
Expand Down Expand Up @@ -1290,7 +1294,9 @@ struct Sm90ColReduction {
reduction_buffer = nullptr;
}
else if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;

auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
Expand Down Expand Up @@ -1325,7 +1331,8 @@ struct Sm90ColReduction {
}

size_t workspace_size = 0;
auto [M, N, K, L] = problem_shape;
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};

// Increment by size of reduction buffer
Expand All @@ -1342,8 +1349,9 @@ struct Sm90ColReduction {
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
#if !defined(CUTLASS_SKIP_REDUCTION_INIT)
auto problem_shape_mnkl = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_mnkl;
if constexpr (IsAtomic) {
auto [M, N, K, L] = problem_shape;
Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol);
if (args.ptr_col != nullptr) {
return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter);
Expand All @@ -1353,7 +1361,6 @@ struct Sm90ColReduction {
else
#endif
if constexpr (FinalReduction) {
auto [M, N, K, L] = problem_shape;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute);
tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment);
Expand Down

0 comments on commit 5366879

Please sign in to comment.