-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accept Hopper matmuls and update default heuristic #3579
Changes from 33 commits
cb13e25
cd2d1e1
1692b5d
c8097b9
3751734
076d56a
700df1f
89f4887
15200fe
bba9c88
e5def4c
9a691c6
80c1232
fbde1e2
e311250
7305778
3f7b6a6
d8f80e2
6c17823
486e4d9
5c6f504
c9f1805
0f1ad25
7ce8938
da8f27f
7c3cd60
f11516d
0252ad4
9e0118d
8f4b796
10489a9
c77f6e4
645bc74
a5b79ac
3c2b0fb
69ac0e5
c71d6b7
0290b39
3c4170e
f386536
c56d845
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -49,25 +49,44 @@ using ProblemShape = std::array<int64_t, 4>; | |||||
inline std::optional<MmaMacro> getMmaOp( | ||||||
const int dev_version, | ||||||
const ProblemShape& problem) { | ||||||
using MacroType = MmaMacro; | ||||||
const int64_t n_extent = problem[(size_t)MatmulDimRole::N]; | ||||||
|
||||||
// NOTE: A temp condition | ||||||
const ProblemShape::value_type n_extend = problem[(size_t)MatmulDimRole::N]; | ||||||
const bool use_small_n = ((n_extend % 8) == 0) && ((n_extend % 16) != 0); | ||||||
MmaMacroEncode macro_encode{MmaMacroEncode::Arch::NoMma, 16, 8, 16}; | ||||||
|
||||||
switch (dev_version) { | ||||||
case 75: | ||||||
return (use_small_n) ? MacroType::Turing_16_8_16 | ||||||
: MacroType::Turing_16_16_16; | ||||||
macro_encode.arch = MmaMacroEncode::Arch::Turing; | ||||||
if ((n_extent % 16) == 0) { | ||||||
macro_encode.n = 16; | ||||||
} | ||||||
break; | ||||||
case 80: | ||||||
case 86: | ||||||
case 89: | ||||||
case 90: // NOTE: temp use ampere matmul for hopper | ||||||
return (use_small_n) ? MacroType::Ampere_16_8_16 | ||||||
: MacroType::Ampere_16_16_16; | ||||||
macro_encode.arch = MmaMacroEncode::Arch::Ampere; | ||||||
if ((n_extent % 16) == 0) { | ||||||
macro_encode.n = 16; | ||||||
} | ||||||
break; | ||||||
case 90: | ||||||
macro_encode.arch = MmaMacroEncode::Arch::Hopper; | ||||||
macro_encode.m = 64; | ||||||
// Find the largest instruction tile that divides the problem size and is | ||||||
// a power of two | ||||||
macro_encode.n = 64; | ||||||
// TODO: enable instructions smaller than 64_64_16 | ||||||
while (macro_encode.n > 64) { | ||||||
if (n_extent % macro_encode.n != 0) { | ||||||
macro_encode.n /= 2; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently this only chooses powers of two. For small problems I think we could choose one of the other sizes. For example if |
||||||
} else { | ||||||
break; | ||||||
} | ||||||
} | ||||||
break; | ||||||
default: | ||||||
return std::nullopt; | ||||||
} | ||||||
return macro_encode; | ||||||
} | ||||||
|
||||||
//! Find the number of circular buffer stages for shared memory operands, so | ||||||
|
@@ -93,9 +112,9 @@ void limitCircularBufferingSmemOperands( | |||||
mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages; | ||||||
} | ||||||
|
||||||
//! A wrapper for core heuristics initialization. | ||||||
//! We should have already set mparams->mma_macro before calling this function. | ||||||
inline bool initCoreHeuristics( | ||||||
namespace { | ||||||
|
||||||
bool fillDefaultAmpereHeuristic( | ||||||
MatmulParams* mparams, | ||||||
const ProblemShape& problem_shape, | ||||||
const mma_utils::TensorRolesMap& tensor_roles, | ||||||
|
@@ -170,6 +189,7 @@ inline bool initCoreHeuristics( | |||||
} | ||||||
return min_size_bytes; | ||||||
}; | ||||||
// Use cp.async on Ampere if possible | ||||||
mparams->async_gmem_load_operands = isCpAsyncOperandLoadSupported( | ||||||
mparams, | ||||||
std::min( | ||||||
|
@@ -186,6 +206,171 @@ inline bool initCoreHeuristics( | |||||
return true; | ||||||
} | ||||||
|
||||||
bool fillDefaultHopperHeuristic( | ||||||
MatmulParams* mparams, | ||||||
const ProblemShape& problem_shape, | ||||||
const mma_utils::TensorRolesMap& tensor_roles, | ||||||
const size_t num_problems) { | ||||||
const auto device_prop = at::cuda::getCurrentDeviceProperties(); | ||||||
|
||||||
const GemmTile instruction_tile = getMmaOpShape(mparams->mma_macro); | ||||||
GemmTile warp_tile = {-1, -1, -1}; | ||||||
GemmTile cta_tile = {-1, -1, -1}; | ||||||
|
||||||
using DimType = decltype(GemmTile::m); | ||||||
|
||||||
// We typically use larger macros on Hopper. By default we will set the | ||||||
// warp tile equal to the macro and increase the CTA tile until we hit | ||||||
// a limit. The limits are given by the maximum number of threads per CTA. | ||||||
|
||||||
// TODO: it might be advantageous in some cases to issue multiple wgmma | ||||||
// instructions per warp group | ||||||
warp_tile = instruction_tile; | ||||||
|
||||||
// The MmaOp output is a 32-bit float which requires one register per value | ||||||
|
||||||
// The Hopper register file is 256KiB. We reduce this by a factor of 1/2 to | ||||||
// account for overhead, since not all of the registers will hold MMA | ||||||
// outputs. | ||||||
const size_t max_registers_per_sm = device_prop->regsPerMultiprocessor / 2L; | ||||||
jacobhinkle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
const size_t regs_per_warp_group = warp_tile.m * warp_tile.n * num_problems; | ||||||
jacobhinkle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
const auto ratiosValid = [&](const DimType m_ratio, const DimType n_ratio) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: snake_case for lambda functions.
Suggested change
jacobhinkle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
DimType cta_m = warp_tile.m * m_ratio; | ||||||
DimType cta_n = warp_tile.n * n_ratio; | ||||||
DimType num_warp_groups = m_ratio * n_ratio; | ||||||
return | ||||||
// We store one float per CTA tile element for each matmul problem we | ||||||
// compute | ||||||
num_warp_groups * regs_per_warp_group < max_registers_per_sm | ||||||
// TMA box dimensions must be less than or equal to 256 | ||||||
&& cta_m <= 256 && | ||||||
cta_n <= 256 | ||||||
// Each warp group is 128 threads. We can only have a maximum of 1024 | ||||||
// threads per SM, or 8 warp groups. | ||||||
&& num_warp_groups <= 8 && | ||||||
// Don't extend the CTA tile beyond the problem size | ||||||
cta_m <= problem_shape[(size_t)MatmulDimRole::M] && | ||||||
cta_n <= problem_shape[(size_t)MatmulDimRole::N]; | ||||||
}; | ||||||
|
||||||
DimType m_ratio = 1; | ||||||
DimType n_ratio = 1; | ||||||
|
||||||
bool increased = true; | ||||||
while (increased) { | ||||||
DimType cta_m = warp_tile.m * m_ratio; | ||||||
DimType cta_n = warp_tile.n * n_ratio; | ||||||
increased = false; | ||||||
|
||||||
const auto tryIncreaseM = [&]() { | ||||||
jacobhinkle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
if (ratiosValid(m_ratio * 2, n_ratio)) { | ||||||
m_ratio *= 2; | ||||||
increased = true; | ||||||
} | ||||||
return increased; | ||||||
}; | ||||||
const auto tryIncreaseN = [&]() { | ||||||
jacobhinkle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
if (ratiosValid(m_ratio, n_ratio * 2)) { | ||||||
n_ratio *= 2; | ||||||
increased = true; | ||||||
} | ||||||
return increased; | ||||||
}; | ||||||
|
||||||
if (cta_m < cta_n) { | ||||||
// Try to increase smaller tile dimension first since square tiles are | ||||||
// optimal for reducing operand load redundancy | ||||||
if (tryIncreaseM()) { | ||||||
continue; | ||||||
} | ||||||
tryIncreaseN(); | ||||||
} else { | ||||||
if (tryIncreaseN()) { | ||||||
continue; | ||||||
} | ||||||
tryIncreaseM(); | ||||||
} | ||||||
} | ||||||
|
||||||
cta_tile = {warp_tile.m * m_ratio, warp_tile.n * n_ratio, warp_tile.k}; | ||||||
|
||||||
mparams->tile_sizes = {cta_tile, warp_tile}; | ||||||
|
||||||
// stages and async mem copy | ||||||
mparams->circular_buffer_options.smem_circular_buffer_stage = 8; | ||||||
|
||||||
// TODO: We should take the main loop structure into account here to get a | ||||||
// more accurate estimate in case of horizontal fusion | ||||||
int64_t operand_smem_per_stage = | ||||||
num_problems * 2 * (cta_tile.m + cta_tile.n) * cta_tile.k; | ||||||
// We leave a bit of space for semaphores | ||||||
int64_t max_operand_smem = device_prop->sharedMemPerBlock - (1L << 7); | ||||||
|
||||||
while (mparams->circular_buffer_options.smem_circular_buffer_stage * | ||||||
operand_smem_per_stage > | ||||||
max_operand_smem) { | ||||||
mparams->circular_buffer_options.smem_circular_buffer_stage--; | ||||||
} | ||||||
|
||||||
mparams->circular_buffer_options.circular_buffer_smem_write = | ||||||
mparams->circular_buffer_options.smem_circular_buffer_stage > 1; | ||||||
|
||||||
// Always use TMA on Hopper | ||||||
mparams->async_gmem_load_operands = true; | ||||||
|
||||||
// See here for more information: | ||||||
// https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/ | ||||||
|
||||||
// We count the number of tiles in each dimension to determine the | ||||||
// rasterization order. The fast rasterization axis is the shortest axis, to | ||||||
// encourage L2 hits by looping over the same rows or cols more frequently. | ||||||
int64_t Mtiles = ceilDiv(problem_shape[(size_t)MatmulDimRole::M], cta_tile.m); | ||||||
int64_t Ntiles = ceilDiv(problem_shape[(size_t)MatmulDimRole::N], cta_tile.n); | ||||||
|
||||||
mparams->cta_order = Ntiles >= Mtiles | ||||||
? MatmulParams::TileRasterizationOrder::ColumnMajor | ||||||
: MatmulParams::TileRasterizationOrder::RowMajor; | ||||||
|
||||||
// We also swizzle the tiles as much as possible up to 4 tiles. Like choosing | ||||||
// the rasterization order, this is used to increase L2 locality | ||||||
mparams->grid_swizzle_factor = 4L; | ||||||
while (Mtiles % mparams->grid_swizzle_factor != 0 || | ||||||
Ntiles % mparams->grid_swizzle_factor != 0) { | ||||||
// Decrease the swizzle factor if it would result in nondivisible splits, | ||||||
// since this would unnecessarily increase the grid size. | ||||||
mparams->grid_swizzle_factor /= 2L; | ||||||
} | ||||||
// TODO: grid swizzling is currently disabled on Hopper since we cannot | ||||||
// properly inline when we swizzle unmapped loop broadcasts | ||||||
mparams->grid_swizzle_factor = 1L; | ||||||
|
||||||
// TODO: Finally, we set the CGA size | ||||||
|
||||||
return true; | ||||||
} | ||||||
|
||||||
} // namespace | ||||||
|
||||||
//! A wrapper for core heuristics initialization. | ||||||
//! We should have already set mparams->mma_macro before calling this function. | ||||||
inline bool initCoreHeuristics( | ||||||
MatmulParams* mparams, | ||||||
const ProblemShape& problem_shape, | ||||||
const mma_utils::TensorRolesMap& tensor_roles, | ||||||
const size_t num_problems) { | ||||||
if (isHopper(mparams->mma_macro)) { | ||||||
return fillDefaultHopperHeuristic( | ||||||
mparams, problem_shape, tensor_roles, num_problems); | ||||||
} else if (isAmpere(mparams->mma_macro) || isTuring(mparams->mma_macro)) { | ||||||
return fillDefaultAmpereHeuristic( | ||||||
mparams, problem_shape, tensor_roles, num_problems); | ||||||
} | ||||||
// Unsupported arch | ||||||
return false; | ||||||
} | ||||||
|
||||||
//! A helper for getting problem shape from fusion and runtime info. | ||||||
//! | ||||||
//! For a given domain, try to find the size by evaluating the extent of an | ||||||
|
@@ -790,7 +975,15 @@ std::unique_ptr<MatmulParams> getMatmulHeuristics( | |||||
mma_utils::generateSharedMemoryEpilogueHeuristics( | ||||||
mparams->tile_sizes, | ||||||
mparams->circular_buffer_options.smem_circular_buffer_stage, | ||||||
tensor_roles); | ||||||
tensor_roles, | ||||||
/*ignore_occupancy_drop=*/true); | ||||||
if (isHopper(mparams->mma_macro)) { | ||||||
// Always promote smem reuse for Hopper. This is needed because we use TMA | ||||||
// which has higher alignment requirements, so it's important that we place | ||||||
// our TMA buffers at an offset that's a multiple of 64 (like 0) if | ||||||
// possible. | ||||||
mparams->promote_prologue_smem_reuse = true; | ||||||
} | ||||||
|
||||||
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { | ||||||
debug() << mparams->toString() << std::endl; | ||||||
|
@@ -842,13 +1035,25 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) { | |||||
{ | ||||||
for (const mma_utils::MatmulPattern& pattern : patterns) { | ||||||
Expr* op = pattern.output->definition(); | ||||||
if (device_prop->major >= 9 && op->isA<ReductionOp>()) { | ||||||
bool found_reduction = false; | ||||||
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) { | ||||||
if (found_reduction && | ||||||
!pattern.output->axis((int64_t)dim)->isReduction()) { | ||||||
return "Mul+Sum patterns can only be translated to MmaOp " | ||||||
"on Hopper if the reduction dim is innermost"; | ||||||
if (device_prop->major >= 9) { | ||||||
for (TensorView* operand : {pattern.A, pattern.B}) { | ||||||
if (!operand->isFusionInput() && | ||||||
(operand->definition() == nullptr || | ||||||
!operand->definition()->isA<LoadStoreOp>() || | ||||||
!operand->definition()->input(0)->isFusionInput() || | ||||||
operand->hasRoot())) { | ||||||
return "Operand " + operand->toString() + | ||||||
" must be a fusion input or non-permuting LoadStoreOp of an input on Hopper"; | ||||||
} | ||||||
} | ||||||
if (op->isA<ReductionOp>()) { | ||||||
bool found_reduction = false; | ||||||
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) { | ||||||
if (found_reduction && | ||||||
!pattern.output->axis((int64_t)dim)->isReduction()) { | ||||||
return "Mul+Sum patterns can only be translated to MmaOp " | ||||||
"on Hopper if the reduction dim is innermost"; | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
@@ -922,7 +1127,14 @@ std::string getMatmulRunTimeRejectReason( | |||||
Fusion* fusion, | ||||||
HeuristicDataCache* data_cache, | ||||||
SchedulerRuntimeInfo& runtime_info) { | ||||||
// TODO: add proper set of checks | ||||||
const auto device_prop = at::cuda::getCurrentDeviceProperties(); | ||||||
|
||||||
if (device_prop->major >= 9 && | ||||||
runtime_info.getIndexType() != DataType::Int32) { | ||||||
// See https://github.com/NVIDIA/Fuser/issues/3595 | ||||||
return "Hopper matmul is not yet supported with problem sizes requiring 64-bit indexing"; | ||||||
} | ||||||
|
||||||
return ""; | ||||||
} | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1848,40 +1848,43 @@ class MatmulTranslator : public OptInDispatch { | |
// logical domains in input and weight already). Then we form an MmaOp and | ||
// optionally add the bias tensor followed by a cast back to the input | ||
// dtype. | ||
int64_t a_dims = (int64_t)pattern_.A->getLogicalDomain().size(); | ||
int64_t b_dims = (int64_t)pattern_.B->getLogicalDomain().size(); | ||
NVF_ERROR( | ||
pattern_.A->nDims() > 1 && pattern_.B->nDims() > 1, | ||
"Cannot translate LinearOp with 1D input"); | ||
a_dims > 1 && b_dims > 1, "Cannot translate LinearOp with 1D input"); | ||
NVF_ERROR( | ||
pattern_.B->nDims() == 2, | ||
"Cannot translate LinearOp without 2D weight tensor"); | ||
b_dims == 2, "Cannot translate LinearOp without 2D weight tensor"); | ||
if (avoid_intermediates_) { | ||
MmaOp::AxisMapping axis_mapping; | ||
int64_t out_dim = pattern_.A->nDims() + 1L; | ||
int64_t out_dim = a_dims + 1L; | ||
axis_mapping.a_axes.reserve(out_dim); | ||
for (int64_t d : c10::irange(out_dim - 2L)) { | ||
axis_mapping.a_axes.push_back(d); | ||
} | ||
axis_mapping.a_axes.reserve(out_dim); | ||
for (size_t d : c10::irange(out_dim - 2)) { | ||
Comment on lines
-1862
to
-1865
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this was just due to a busted merge. |
||
axis_mapping.a_axes.push_back((int64_t)d); | ||
} | ||
axis_mapping.a_axes.push_back(-1); // missing N dimension | ||
axis_mapping.a_axes.push_back(pattern_.A->nDims() - 1); // K dimension | ||
axis_mapping.a_axes.push_back(a_dims - 1L); // K dimension | ||
|
||
axis_mapping.b_axes.reserve(out_dim); | ||
axis_mapping.b_axes.resize(out_dim, -1); | ||
axis_mapping.b_axes[out_dim - 2] = 0; // N | ||
axis_mapping.b_axes[out_dim - 1] = 1; // K | ||
|
||
int64_t num_M_dims = 1 + pattern_.A->nDims() - pattern_.B->nDims(); | ||
int64_t num_M_dims = 1 + a_dims - b_dims; | ||
|
||
// Add loop broadcasts to A and B to mimic logical broadcasts for | ||
// simpler scheduling | ||
pattern_.A->broadcast(-2); // There's always a single N dimension | ||
// Note that since operands can be shared among multiple patterns, we | ||
// should avoid modifying the operand twice. This is why we first check | ||
// for loop broadcasts. | ||
if (pattern_.A->domain()->additionalIDs().empty()) { | ||
pattern_.A->broadcast(-2); // There's always a single N dimension | ||
} | ||
|
||
for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) { | ||
// Broadcast B for every M dimension in A | ||
pattern_.B->broadcast(0); | ||
if (pattern_.B->domain()->additionalIDs().empty()) { | ||
for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) { | ||
// Broadcast B for every M dimension in A | ||
pattern_.B->broadcast(0); | ||
} | ||
} | ||
|
||
fms = fusedMultiplySum( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the grid size is not divisible by the cluster size then we get a launch error, so we should default to not use cluster dims unless explicitly handled by a heuristic.