Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
Expand Down Expand Up @@ -154,6 +155,10 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Traits::ComputeDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;

// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;

using Base = BlockGemmBQuantBase<Problem_>;

using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
Expand Down Expand Up @@ -271,12 +276,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;

template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_>(a_warp_tile_, a_block_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_>(b_warp_tile_, b_block_window);
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_block_window);
// If B datatype were pkint4 it would be converted prior to storing in LDS
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_block_window);
}

// C += A * B
Expand Down Expand Up @@ -397,11 +410,16 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
MakeCBlockTile();
}

template <typename ASmemBlockWindow, typename BSmemBlockWindow>
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}

// C += A * B
Expand Down
66 changes: 49 additions & 17 deletions include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ struct QuantGemmKernel

if constexpr(kQuantType == QuantType::BQuantGrouped)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
Expand Down Expand Up @@ -782,7 +781,9 @@ struct QuantGemmKernel
{
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amd-khushbu Please not this part as our discussion yesterday. With the preshuffle quant our tensor layout should actually only be RowMajor. We could modify that in your PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll revert the comment here to avoid creating any confusion.

"PreshuffleQuant with BQuantGrouped currently only supports "
"ColumnMajor BQ layout");

return MakePreshuffledQuantTensorView<
GemmPipeline::KPerBlockBQ,
Expand All @@ -792,14 +793,35 @@ struct QuantGemmKernel
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});

if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
// For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN]
// Dimensions: [K/QuantGroupK, N/QuantGroupN]
// Strides: [N/QuantGroupN, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK),
integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
// For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK]
// Dimensions: [N/QuantGroupN, K/QuantGroupK]
// Strides: [K/QuantGroupK, 1]
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN),
integer_divide_ceil(kargs.K, QuantGroupSize::kK)),
make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
}
}
else
Expand Down Expand Up @@ -1024,10 +1046,10 @@ struct QuantGemmKernel
}
else if constexpr(kQuantType == QuantType::BQuantGrouped)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
Expand All @@ -1043,13 +1065,23 @@ struct QuantGemmKernel
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
}
else
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
CK_TILE_DEVICE constexpr auto
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);

using YPerTile = number<NPerBlockBQ>;
using XPerTile = number<KPerBlockBQ>;
using YPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<NPerBlockBQ>,
number<KPerBlockBQ>>;
using XPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<KPerBlockBQ>,
number<NPerBlockBQ>>;

auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
make_tuple(YPerTile{}, XPerTile{}),
bq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBQDramTileDistribution<Problem>());
return bq_copy_dram_window;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;

static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
// Support both RowMajor and ColumnMajor layouts for BQ
if constexpr(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetABQGlobalVectorLoadSize<Problem, BQDataType, KPerBlockBQ, NPerBlockBQ>();
}
else
{
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
}

template <typename Problem>
Expand All @@ -52,7 +60,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
WarpTile::at(I2),
Problem::TransposeC>;

static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
if constexpr(PreshuffleQuant)
{
using TileEncodingPattern = tile_distribution_encoding_pattern_bq<
Expand All @@ -62,18 +69,21 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
BQLayout,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
// KPerTile and NPerTile are LOGICAL dimensions (K quant groups and N quant groups)
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
NPerBlockBQ,
KPerBlockBQ,
Problem::QuantGroupSize::kN>;
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::QuantGroupSize::kN,
BQLayout>;

return TileEncodingPattern::make_2d_static_tile_distribution();
}
Expand Down
Loading
Loading