Skip to content

Commit fce996d

Browse files
committed
Expose MatmulParams::cluster_dims in python frontend
This introduces a new type `MatmulParams::ClusterDims` and wraps it in the Python frontend. This is essentially just a triple of `int64_t`s, but I named the elements x, y, and z to make it easy to bind.
1 parent a470821 commit fce996d

File tree

4 files changed

+48
-11
lines changed

4 files changed

+48
-11
lines changed

csrc/python_frontend/python_bindings.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,13 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
647647
nvfuser, "MatmulTileRasterizationOrder")
648648
.value("column_major", MatmulParams::TileRasterizationOrder::ColumnMajor)
649649
.value("row_major", MatmulParams::TileRasterizationOrder::RowMajor);
650+
651+
DEFINECLASS(MatmulParams::ClusterDims)
652+
.PARAM(MatmulParams::ClusterDims, x)
653+
.PARAM(MatmulParams::ClusterDims, y)
654+
.PARAM(MatmulParams::ClusterDims, z)
655+
.TOSTRINGMETHOD(MatmulParams::ClusterDims);
656+
650657
py::enum_<MmaMacroEncode::Arch>(nvfuser, "MmaMacroArch")
651658
.value("no_mma", MmaMacroEncode::Arch::NoMma)
652659
.value("volta", MmaMacroEncode::Arch::Volta)
@@ -747,6 +754,7 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
747754
.PARAM(MatmulParams, promote_prologue_smem_reuse)
748755
.PARAM(MatmulParams, splitk_factor)
749756
.PARAM(MatmulParams, cta_order)
757+
.PARAM(MatmulParams, cluster_dims)
750758
.PARAM(MatmulParams, mma_macro);
751759

752760
#undef PARAM

csrc/scheduler/hopper_multi_matmul.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,13 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
152152
//! Specifies the CGA dimensions by setting "cluster_dims" as fusion-managed
153153
//! data
154154
void setCGADims() const {
155-
if (params_->cluster_dims != std::tuple<int, int, int>{1, 1, 1}) {
156-
fusion_->manage("cluster_dims", params_->cluster_dims);
155+
if (params_->cluster_dims != MatmulParams::ClusterDims{1, 1, 1}) {
156+
fusion_->manage(
157+
"cluster_dims",
158+
std::tuple<int64_t, int64_t, int64_t>{
159+
params_->cluster_dims.x,
160+
params_->cluster_dims.y,
161+
params_->cluster_dims.z});
157162
}
158163
}
159164

csrc/scheduler/matmul_heuristic.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,32 @@ class MatmulParams : public HeuristicParams {
193193

194194
//! This is the CGA size on Hopper+ devices. This parameter is ignored on
195195
//! Ampere and Turing.
196-
std::tuple<int64_t, int64_t, int64_t> cluster_dims = {1, 1, 1};
196+
struct ClusterDims {
  //! Extent of the cluster along each dimension. All default to 1.
  int64_t x = 1;
  int64_t y = 1;
  int64_t z = 1;

  //! Component-wise equality: true only if x, y and z all match.
  bool operator==(const ClusterDims& other) const {
    return x == other.x && y == other.y && z == other.z;
  }

  //! Negation of operator==.
  bool operator!=(const ClusterDims& other) const {
    return !(*this == other);
  }

  //! Returns the three components separated by single spaces, e.g. "2 1 1".
  std::string toString() const {
    std::stringstream ss;
    ss << x << " " << y << " " << z;
    return ss.str();
  }

  //! Packs x, y and z into a single word (x shifted left by 32 bits, y by 16
  //! bits, z unshifted) and hashes that word.
  //!
  //! All three components must be inside the std::hash call. Previously the
  //! call's closing parenthesis was placed before z, so z was OR-ed onto the
  //! hash result rather than contributing to the hashed value.
  size_t hash() const {
    return std::hash<size_t>{}(
        (static_cast<size_t>(x) << 32) | (static_cast<size_t>(y) << 16) |
        static_cast<size_t>(z));
  }
} cluster_dims;
197222

198223
std::string toString() const override {
199224
std::stringstream ss;
@@ -216,8 +241,7 @@ class MatmulParams : public HeuristicParams {
216241
: "column-major")
217242
<< "\n"
218243
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n"
219-
<< "Cluster dimensions: " << std::get<0>(cluster_dims) << " "
220-
<< std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n"
244+
<< "Cluster dimensions: " << cluster_dims.toString() << "\n"
221245
<< "Use shared memory epilogue: " << use_smem_epilogue << "\n"
222246
<< "Promote re-use of prologue shared memory: "
223247
<< promote_prologue_smem_reuse << "\n"

csrc/scheduler/matmul_heuristic_plugin.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) {
139139
setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile);
140140
setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile);
141141
setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro));
142-
config->cluster_dims[0] = std::get<0>(mparams->cluster_dims);
143-
config->cluster_dims[1] = std::get<1>(mparams->cluster_dims);
144-
config->cluster_dims[2] = std::get<2>(mparams->cluster_dims);
142+
config->cluster_dims[0] = mparams->cluster_dims.x;
143+
config->cluster_dims[1] = mparams->cluster_dims.y;
144+
config->cluster_dims[2] = mparams->cluster_dims.z;
145145
config->splitk_factor = mparams->splitk_factor;
146146
config->grid_swizzle_factor = mparams->grid_swizzle_factor;
147147
config->cta_order =
@@ -164,9 +164,9 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) {
164164
};
165165
setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile);
166166
setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile);
167-
std::get<0>(mparams->cluster_dims) = config->cluster_dims[0];
168-
std::get<1>(mparams->cluster_dims) = config->cluster_dims[1];
169-
std::get<2>(mparams->cluster_dims) = config->cluster_dims[2];
167+
mparams->cluster_dims.x = config->cluster_dims[0];
168+
mparams->cluster_dims.y = config->cluster_dims[1];
169+
mparams->cluster_dims.z = config->cluster_dims[2];
170170
mparams->circular_buffer_options.smem_circular_buffer_stage =
171171
config->load_stages;
172172
mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap =

0 commit comments

Comments
 (0)