Improvements for: Groupwise scaling along M for FP8 gemm #2095

Open · wants to merge 7 commits into base: main
@@ -557,13 +557,13 @@ bool verify(const Options<RasterOrderOptions> &options) {
auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
cute::make_shape(blockscale_m, blockscale_k, options.l),
-cute::make_stride(blockscale_k, 1, blockscale_m * blockscale_k)
+cute::make_stride(1, blockscale_m, blockscale_m * blockscale_k)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
cute::make_shape(blockscale_n, blockscale_k, options.l),
-cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
+cute::make_stride(1, blockscale_n, blockscale_n * blockscale_k)
)
);

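Context for the stride change above: the scale tensors for A and B keep the same shape but switch from a K-major layout (stride blockscale_k along the M mode) to an M-major layout (unit stride along M), so scales for consecutive row groups sit next to each other in memory. A minimal standalone sketch of the two offset formulas, with no CuTe dependency; extents bm and bk are illustrative, not taken from the PR:

// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

// Offset of scale element (m, k, l) under the two layouts in the hunk above.
// Old: stride (blockscale_k, 1, blockscale_m * blockscale_k) -> K-major.
// New: stride (1, blockscale_m, blockscale_m * blockscale_k) -> M-major.
int64_t offset_old(int64_t m, int64_t k, int64_t l, int64_t bm, int64_t bk) {
  return m * bk + k + l * bm * bk;
}
int64_t offset_new(int64_t m, int64_t k, int64_t l, int64_t bm, int64_t bk) {
  return m + k * bm + l * bm * bk;
}

int main() {
  // Walking m with k fixed: the new layout is contiguous along M (stride 1);
  // the old one strides by blockscale_k. bm = 4, bk = 2 here.
  for (int64_t m = 0; m < 4; ++m)
    std::printf("m=%lld: old offset %lld, new offset %lld\n", (long long)m,
                (long long)offset_old(m, 0, 0, 4, 2),
                (long long)offset_new(m, 0, 0, 4, 2));
}

M-major matches how groupwise scaling along M consumes the scales: code walking rows reads adjacent elements rather than striding by blockscale_k.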
@@ -396,14 +396,17 @@ template <typename GroupScaleConfig>
void initialize(const Options<RasterOrderOptions> &options) {

using TileShape = typename GroupScaleConfig::TileShape;
-const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
-const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
+const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
+const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;

+assert(options.m % ScaleGranularityM == 0);
+assert(options.n % ScaleGranularityN == 0);
+
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
-auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
-auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access.
+auto groupscale_m = cute::get<0>(gemm_problem_shape) / ScaleGranularityM;
+auto groupscale_n = cute::get<1>(gemm_problem_shape) / ScaleGranularityN;
auto blockscale_k = cute::get<2>(blockscale_shape);

stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
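The initialize() change above sizes the scale tensors directly from the problem shape (m / ScaleGranularityM) instead of padding up to whole CTA tiles, and the new asserts make that division exact. A small sketch of how the two extents differ when M is granularity-aligned but not tile-aligned; all numbers are illustrative:

// --- illustrative sketch, not part of the diff ---
#include <cassert>
#include <cstdio>

int main() {
  // Illustrative numbers: CTA tile M = 128, ScaleGranularityM = 64, problem m = 320.
  const int tile_m = 128, scale_granularity_m = 64, m = 320;
  const int scale_ms_per_tile = tile_m / scale_granularity_m;             // 2
  // Old extent: number of M-tiles (rounded up) times scales per tile -> padded.
  const int old_extent = ((m + tile_m - 1) / tile_m) * scale_ms_per_tile; // 3 * 2 = 6
  // New extent: exact; legal because the new asserts require m % ScaleGranularityM == 0.
  assert(m % scale_granularity_m == 0);
  const int new_extent = m / scale_granularity_m;                         // 5
  std::printf("old (padded) = %d, new (exact) = %d\n", old_extent, new_extent);
}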
@@ -575,13 +578,17 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile
//
// Compute reference output
//
+const int ScaleGranularityM = get<0>(TileShape_{}) / ScaleMsPerTile;
+const int ScaleGranularityN = get<1>(TileShape_{}) / ScaleNsPerTile;

// Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
+auto groupscale_m = get<0>(gemm_problem_shape) / ScaleGranularityM;
+auto groupscale_n = get<1>(gemm_problem_shape) / ScaleGranularityN;

// Create instantiation for device reference gemm kernel
auto A = cute::make_tensor(tensor_A.host_data(),
@@ -617,14 +624,14 @@ bool verify(const Options<RasterOrderOptions> &options, const int ScaleMsPerTile

auto blockscale_A = cute::make_tensor(blockscale_tensor_A.host_data(),
cute::make_layout(
-cute::make_shape(blockscale_m, ScaleMsPerTile, blockscale_k, options.l),
-cute::make_stride(blockscale_k * ScaleMsPerTile, 1, ScaleMsPerTile, blockscale_m * blockscale_k * ScaleMsPerTile)
+cute::make_shape(groupscale_m, blockscale_k, options.l),
+cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k)
)
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
-cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
-cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
+cute::make_shape(groupscale_n, blockscale_k, options.l),
+cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k)
)
);
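In verify(), the old reference modeled the scales as a 4-mode tensor (tile block, scale-within-tile, k block, batch) with a per-tile interleaved stride; the new code collapses this to the same flat M-major 3-mode layout used in the first file, so the reference and the kernel agree on a single layout. A hedged sketch of the new offset computation; the helper name and the extents in main are illustrative:

// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

// Offset of the A-scale for output row m_idx, scale k-block k_blk, batch l under
// the new flat layout: shape (groupscale_m, blockscale_k, L),
// stride (1, groupscale_m, groupscale_m * blockscale_k).
int64_t scale_a_offset(int64_t m_idx, int64_t k_blk, int64_t l,
                       int64_t scale_granularity_m,
                       int64_t groupscale_m, int64_t blockscale_k) {
  const int64_t g = m_idx / scale_granularity_m;  // scale group along M
  return g + k_blk * groupscale_m + l * groupscale_m * blockscale_k;
}

int main() {
  // groupscale_m = 6 (e.g. m = 384 with granularity 64), blockscale_k = 4, batch 0.
  std::printf("row 130, k-block 2 -> offset %lld\n",
              (long long)scale_a_offset(130, 2, 0, 64, 6, 4));  // 2 + 2*6 = 14
}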

@@ -708,6 +715,31 @@ int run(Options<RasterOrderOptions> &options)
const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;

+bool skip = false;
+
+if (options.m % ScaleGranularityM != 0) {
+std::cout << "Skipping (m size: " << options.m << " not divisible by ScaleGranularityM: " << ScaleGranularityM << "):" << std::endl;
+skip = true;
+}
+
+if (options.n % ScaleGranularityN != 0) {
+std::cout << "Skipping (n size: " << options.n << " not divisible by ScaleGranularityN: " << ScaleGranularityN << "):" << std::endl;
+skip = true;
+}
+
+if (options.k % size<2>(TileShape{}) != 0) {
+std::cout << "Skipping (k size: " << options.k << " not divisible by TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
+skip = true;
+}
+
+if (!skip) std::cout << "Running: " << std::endl;
+std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
+std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
+std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
+std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
+
+if (skip) return -1;
+
initialize<GroupScaleConfig>(options);

// Instantiate CUTLASS kernel depending on templates
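The guard above skips configurations the kernel cannot handle (returning -1) instead of tripping the asserts inside initialize(), while still printing the configuration so the skip is visible; the same configuration lines are removed from their old location further down in this diff. The three checks share one pattern, sketched here as a hypothetical helper that is not part of the PR:

// --- illustrative sketch, not part of the diff ---
#include <cstdio>

// Hypothetical helper factoring out the check-and-report pattern used three times above.
static bool check_divisible(const char* label, int extent, const char* by, int granularity) {
  if (extent % granularity != 0) {
    std::printf("Skipping (%s: %d not divisible by %s: %d)\n", label, extent, by, granularity);
    return false;
  }
  return true;
}

int main() {
  bool ok = check_divisible("m size", 300, "ScaleGranularityM", 64);  // prints and returns false
  return ok ? 0 : -1;
}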
@@ -768,10 +800,6 @@ int run(Options<RasterOrderOptions> &options)
raster = "Along M";
}

std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
@@ -217,15 +217,19 @@ void gett_mainloop(
}
}

-int64_t block_m = m / kBlockM;
-int64_t block_n = n / kBlockN;
-cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
-cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
-
-const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
-const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
-assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
-assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
+const int M = cute::size<0>(mainloop_params.A.layout());
+const int N = cute::size<0>(mainloop_params.B.layout());
+const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA);
+const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB);
+assert(ScaleGranularityM && M % ScaleGranularityM == 0
+&& "ScaleGranularityM must divide M");
+assert(ScaleGranularityN && N % ScaleGranularityN == 0
+&& "ScaleGranularityN must divide N");
+
+cute::Tensor blockscale_A = domain_offset(
+make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l));
+cute::Tensor blockscale_B = domain_offset(
+make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l));

// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
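The reference mainloop above now recovers ScaleGranularityM/N from the global tensor extents (M divided by the number of scale rows) rather than from TileShape, and uses cute::domain_offset to shift the scale tensor's origin to the current (m, n) block; in-block rows then index it by integer division. A scalar mimic of that lookup, standalone and with illustrative sizes:

// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

// ScaleA is conceptually (M / ScaleGranularityM) x K_blocks; domain_offset shifts
// the tensor origin to the current M-block, so in-block rows index it directly.
int64_t scale_row_for(int64_t m_block_origin, int64_t m_b, int64_t scale_granularity_m) {
  const int64_t origin = m_block_origin / scale_granularity_m;  // the domain_offset coord
  return origin + m_b / scale_granularity_m;                    // row read inside the loop
}

int main() {
  // Block starting at global row 256 with ScaleGranularityM = 64:
  // in-block rows 0..63 read scale row 4, rows 64..127 read scale row 5.
  const int64_t samples[] = {0, 63, 64, 127};
  for (int64_t m_b : samples)
    std::printf("m_b=%3lld -> scale row %lld\n", (long long)m_b,
                (long long)scale_row_for(256, m_b, 64));
}

The result equals (m + m_b) / ScaleGranularityM whenever ScaleGranularityM divides the block origin m, as in this example.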
@@ -257,9 +261,12 @@ void gett_mainloop(
}
}

+int m_size = std::min(static_cast<int64_t>(kBlockM), cute::size<0>(mainloop_params.A.layout()) - m);
+int n_size = std::min(static_cast<int64_t>(kBlockN), cute::size<0>(mainloop_params.B.layout()) - n);
+
// do compute
-for (int m_b = 0; m_b < kBlockM; ++m_b) {
-for (int n_b = 0; n_b < kBlockN; ++n_b) {
+for (int m_b = 0; m_b < m_size; ++m_b) {
+for (int n_b = 0; n_b < n_size; ++n_b) {
acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]);
}
}
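m_size and n_size clamp the block loops at the M/N boundary so partial blocks do not compute out-of-range elements; the old loops always ran the full kBlockM x kBlockN. The clamp in isolation, with illustrative sizes:

// --- illustrative sketch, not part of the diff ---
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative: M = 100 rows, reference compute block kBlockM = 64.
  const int64_t M = 100, kBlockM = 64;
  for (int64_t m = 0; m < M; m += kBlockM) {
    const int64_t m_size = std::min(kBlockM, M - m);  // rows actually valid in this block
    std::printf("block at m=%lld covers %lld rows\n", (long long)m, (long long)m_size);
  }
  // Prints: block at m=0 covers 64 rows, block at m=64 covers 36 rows.
}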
@@ -269,9 +276,9 @@
// (b) Zero-out partial temporary (acc_temp),
// (c) Update permanent (accu)
if ((k+1) % kBlockK == 0) {
-for (int m_b = 0; m_b < kBlockM; ++m_b) {
+for (int m_b = 0; m_b < m_size; ++m_b) {
auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
-for (int n_b = 0; n_b < kBlockN; ++n_b) {
+for (int n_b = 0; n_b < n_size; ++n_b) {
auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
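At each kBlockK boundary the partial sum is folded into the accumulator after multiplying by one scale per M-group and one per N-group, each selected by integer division of the in-block coordinate. A scalar walk-through with hypothetical scale values:

// --- illustrative sketch, not part of the diff ---
#include <cstdio>

int main() {
  // Scalar mimic of the blockwise-scaled accumulation; all values illustrative.
  const int scale_granularity_m = 64, scale_granularity_n = 128;
  const float scale_a[2] = {0.5f, 2.0f};   // hypothetical per-group scales along M
  const float scale_b[1] = {4.0f};         // hypothetical per-group scale along N
  float acc = 1.0f;                        // running accumulator
  float acc_temp = 3.0f;                   // partial sum for this k-block
  const int m_b = 96, n_b = 32;            // element coordinates within the block
  acc += acc_temp * scale_a[m_b / scale_granularity_m]   // group 1 -> 2.0
               * scale_b[n_b / scale_granularity_n];     // group 0 -> 4.0
  std::printf("acc = %g\n", acc);          // 1 + 3 * 2.0 * 4.0 = 25
}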