Skip to content

Commit

Permalink
fix < 128
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
  • Loading branch information
LucasWilkinson committed Feb 10, 2025
1 parent 8fd2846 commit db87722
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster

constexpr int ScaleMsPerTile = 128; // TODO Fix < 128
constexpr int ScaleMsPerTile = 128;
constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;

using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ struct CollectiveMma<
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
Tensor mScaleA_mkl = get<2>(load_inputs);
Tensor mScaleB_nkl = get<3>(load_inputs);
auto scales_m = get<0>(mScaleA_mkl.shape());
auto scales_m = size<0>(mScaleA_mkl);

Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());

Expand All @@ -396,7 +396,7 @@ struct CollectiveMma<

// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
Expand Down Expand Up @@ -440,7 +440,8 @@ struct CollectiveMma<
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
#pragma unroll
for (int i = 0; i < size(tApA_ScaleA); ++i) {
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) <
std::min(scales_m, (m_coord + 1) * ScaleMsPerTile);
}

// Mainloop
Expand Down

0 comments on commit db87722

Please sign in to comment.