Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CK TILE] Block universal gemm lds<->vgpr optimizations #1906

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
524 changes: 220 additions & 304 deletions include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp

Large diffs are not rendered by default.

28 changes: 18 additions & 10 deletions include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ struct GemmPipelineAgBgCrImplBase
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}

template <typename ADramBlockWindowTmp, typename ALdsTensorView>
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;

Expand All @@ -88,17 +89,21 @@ struct GemmPipelineAgBgCrImplBase
auto a_copy_lds_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_lds_gemm_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsLoadTileDistr{});

return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}

template <typename BDramBlockWindowTmp, typename BLdsTensorView>
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&) const
{
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;

Expand All @@ -117,8 +122,11 @@ struct GemmPipelineAgBgCrImplBase
auto b_copy_lds_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_lds_gemm_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsLoadTileDistr{});

return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,17 +341,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);

// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());

// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);

// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);

// Block GEMM
auto block_gemm = BlockGemm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,26 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto& a_lds_block = ab_lds_blocks.at(I0{});
auto& b_lds_block = ab_lds_blocks.at(I1{});

// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};

// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
auto a_windows =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto& a_copy_dram_window = a_windows.at(I0{});
auto& a_copy_lds_window = a_windows.at(I1{});
auto& a_lds_gemm_window = a_windows.at(I2{});

// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
auto b_windows =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
Expand Down Expand Up @@ -484,18 +492,26 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto& a_lds_block = ab_lds_blocks.at(I0{});
auto& b_lds_block = ab_lds_blocks.at(I1{});

// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};

// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
auto a_windows =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto& a_copy_dram_window = a_windows.at(I0{});
auto& a_copy_lds_window = a_windows.at(I1{});
auto& a_lds_gemm_window = a_windows.at(I2{});

// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
auto b_windows =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,25 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());

// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_lds_gemm_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_lds_load_tile_distr);

// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_lds_gemm_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_lds_load_tile_distr);

// Block GEMM
auto block_gemm = BlockGemm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,28 @@ struct GemmPipelineAGmemBGmemCRegV2
{0, 0},
b_copy_dram_window.get_tile_distribution());

// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();

// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode());

// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_lds_gemm_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_lds_load_tile_distr);

// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
auto b_lds_gemm_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_lds_load_tile_distr);

// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
Expand Down