Skip to content

Commit

Permalink
[COMMON] Add workgroups reordering to distribute using forall
Browse files Browse the repository at this point in the history
Adds an option to reorder workgroups. If set to `transpose`, it swaps the
workgroup mapping attributes x and y.
  • Loading branch information
pashu123 committed Jan 11, 2025
1 parent 9f93691 commit aa73d40
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ createTileAndDistributeToWorkgroupsPass(
int32_t maxWorkgroupParallelDims,
linalg::DistributionMethod distributionMethod);

// Pass to tile and distribute to workgroups using `scf.forall`, with optional
// workgroup reordering. Passing `reorderWorkgroupsWithTranspose = true`
// corresponds to the pass option `strategy=transpose` and swaps the x and y
// workgroup mapping dimensions; `false` performs no reordering.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createTileAndDistributeToWorkgroupsWithReordering(
    bool reorderWorkgroupsWithTranspose);

//----------------------------------------------------------------------------//
// CodeGen Common Patterns
//----------------------------------------------------------------------------//
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,10 @@ def TileAndDistributeToWorkgroupsUsingForallOpPass :
"scf::SCFDialect",
"tensor::TensorDialect",
];
let options = [
Option<"strategy", "strategy", "std::string", /*default=*/"",
"Workgroup reordering strategy, one of: '' (none), 'transpose'">,
];
}

def TileLargeTensorsPass :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,31 @@ namespace {
// Pass that tiles and distributes to workgroups using scf.forall. The
// `strategy` option (or the bool constructor) selects whether the x and y
// workgroup mapping dimensions are swapped after distribution.
struct TileAndDistributeToWorkgroupsUsingForallOpPass final
    : public impl::TileAndDistributeToWorkgroupsUsingForallOpPassBase<
          TileAndDistributeToWorkgroupsUsingForallOpPass> {
  // Programmatic construction: directly selects workgroup transposition,
  // bypassing textual option parsing.
  TileAndDistributeToWorkgroupsUsingForallOpPass(bool transposeWorkgroups)
      : transposeWorkgroups(transposeWorkgroups) {}

  using Base::Base;
  void runOnOperation() override;

  // Parses the textual `strategy` option into `transposeWorkgroups`.
  // Accepted values: "" (no reordering) and "transpose". Any other value is
  // reported through `errorHandler` and fails initialization.
  LogicalResult initializeOptions(
      StringRef options,
      function_ref<LogicalResult(const Twine &)> errorHandler) override {
    if (failed(Pass::initializeOptions(options, errorHandler))) {
      return failure();
    }
    auto selectedStrategy = llvm::StringSwitch<FailureOr<bool>>(strategy)
                                .Case("", false)
                                .Case("transpose", true)
                                .Default(failure());
    if (failed(selectedStrategy)) {
      // Emit a diagnostic instead of failing silently so the user can see
      // which option value was rejected.
      return errorHandler("invalid workgroup reordering strategy: '" +
                          Twine(strategy) + "', expected '' or 'transpose'");
    }
    transposeWorkgroups = *selectedStrategy;
    return success();
  }

private:
  // Whether to swap the x and y workgroup mapping dimensions.
  bool transposeWorkgroups = false;
};

} // namespace
Expand Down Expand Up @@ -190,6 +213,28 @@ pruneDroppedLoops(ArrayRef<Attribute> inputs,
return prunedAttrs;
}

// Checks whether all loop bounds and steps of `forallOp` fold to compile-time
// constants. This is a requirement if the reordering strategy is set to
// `transpose`.
static bool checkStaticLoopBounds(scf::ForallOp forallOp) {
  SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
  SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
  SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep();

  // Note: no index is needed, so use zip instead of enumerate.
  for (auto [lb, ub, step] : llvm::zip(mixedLbs, mixedUbs, mixedSteps)) {
    // A bound is static exactly when it folds to a constant integer.
    if (!getConstantIntValue(lb) || !getConstantIntValue(ub) ||
        !getConstantIntValue(step)) {
      return false;
    }
  }
  return true;
}

/// Find dimensions of the loop that are unit-trip count and drop them from the
/// distributed dimensions.
static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
Expand Down Expand Up @@ -516,6 +561,20 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
// TODO: run producer and consumer fusion in one worklist.
fuseProducersOfSlices(rewriter, newFusionOpportunities,
tileAndFuseOptions, newLoop);
forallOp = newLoop;
}

// Reorder the workgroups if the strategy is set to `transpose`.
// This just transposes the first two dimensions of the workgroup i.e., the
// #iree.codegen.workgroup_id_x and #iree.codegen.workgroup_id_y.
// Only reorders if the loop bounds are static.
if (transposeWorkgroups) {
SmallVector<Attribute> mappingAttrs(forallOp.getMappingAttr().getValue());
int64_t mappingSize = mappingAttrs.size();
if (checkStaticLoopBounds(forallOp) && mappingAttrs.size() >= 2) {
std::swap(mappingAttrs[mappingSize - 1], mappingAttrs[mappingSize - 2]);
forallOp.setMappingAttr(ArrayAttr::get(context, mappingAttrs));
}
}
}

Expand All @@ -538,4 +597,10 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {

return;
}
// Factory for the tile-and-distribute pass with workgroup reordering.
// `reorderWorkgroupsWithTranspose` selects the `transpose` strategy.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createTileAndDistributeToWorkgroupsWithReordering(
    bool reorderWorkgroupsWithTranspose) {
  // Construct the pass via its bool constructor rather than option parsing.
  auto pass = std::make_unique<TileAndDistributeToWorkgroupsUsingForallOpPass>(
      reorderWorkgroupsWithTranspose);
  return pass;
}
} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups-using-forall-op, cse))" --mlir-print-local-scope --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups-using-forall-op{strategy=transpose}, cse))" --mlir-print-local-scope --split-input-file %s | FileCheck %s --check-prefix=TRANSPOSE

func.func @matmul_tensors(%0 : tensor<?x?xf32>, %1 : tensor<?x?xf32>, %2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
Expand Down Expand Up @@ -701,3 +702,66 @@ func.func @consumer_fuse_scatter(%arg0: tensor<3x2048x2048xf32>,
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
// CHECK-SAME: ins(%[[SRC]], %[[IND_SLICE]]{{.*}} outs(%[[DEST_SLICE]]
// CHECK: tensor.parallel_insert_slice %[[SCATTER]] into %[[DEST]][0, %[[ID1]], %[[ID2]]]

// -----

// Matmul with fully dynamic shapes: the distributed loop bounds are not
// static, so the `transpose` strategy must leave the workgroup mapping
// order unchanged (see the TRANSPOSE checks below: [y, x] is preserved).
func.func @dont_transpose_dynamic(%0 : tensor<?x?xf32>, %1 : tensor<?x?xf32>, %2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
      ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}

// TRANSPOSE-LABEL: func @dont_transpose_dynamic(
// TRANSPOSE: scf.forall
// TRANSPOSE: [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]

// -----

// Matmul with static 128x128 shapes: bounds are static, so the `transpose`
// strategy swaps the x and y workgroup mappings (TRANSPOSE checks expect
// [x, y] instead of the default [y, x]).
func.func @transpose_static(%0 : tensor<128x128xf32>, %1 : tensor<128x128xf32>, %2 : tensor<128x128xf32>) -> tensor<128x128xf32> {
  %3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
      ins(%0, %1 : tensor<128x128xf32>, tensor<128x128xf32>)
      outs(%2 : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %3 : tensor<128x128xf32>
}

// TRANSPOSE-LABEL: func @transpose_static(
// TRANSPOSE: scf.forall
// TRANSPOSE: [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]

// -----

// 4-D elementwise op distributed over four workgroup dimensions: only the
// innermost two mappings (x and y) are swapped; the z mappings keep their
// order (see the TRANSPOSE mapping check below).
func.func @only_transpose_x_y(%7 : tensor<128x128x128x128xf32>, %8 : tensor<128x128x128x128xf32>) -> tensor<128x128x128x128xf32> {
  %9 = tensor.empty() : tensor<128x128x128x128xf32>
  %10 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%7, %8 : tensor<128x128x128x128xf32>, tensor<128x128x128x128xf32>)
      outs(%9 : tensor<128x128x128x128xf32>)
      attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[2, 64, 64, 64]]>} {
    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
      %11 = arith.addf %arg0, %arg1 : f32
      linalg.yield %11 : f32
    } -> tensor<128x128x128x128xf32>
  return %10 : tensor<128x128x128x128xf32>
}

// TRANSPOSE-LABEL: func @only_transpose_x_y(
// TRANSPOSE: scf.forall
// TRANSPOSE: mapping = [#iree_codegen.workgroup_mapping<z:1>, #iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]

// -----

// In case there are fewer than 2 workgroup_mapping attributes, don't apply the transpose.
// Tiling only the first dimension yields a single workgroup mapping, so
// there is nothing to swap (TRANSPOSE check expects just [x]).
func.func @dont_transpose_less(%0 : tensor<128x128xf32>, %1 : tensor<128x128xf32>, %2 : tensor<128x128xf32>) -> tensor<128x128xf32> {
  %3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 0, 0]]>}
      ins(%0, %1 : tensor<128x128xf32>, tensor<128x128xf32>)
      outs(%2 : tensor<128x128xf32>) -> tensor<128x128xf32>
  return %3 : tensor<128x128xf32>
}

// TRANSPOSE-LABEL: func @dont_transpose_less(
// TRANSPOSE: scf.forall
// TRANSPOSE: [#iree_codegen.workgroup_mapping<x>]

0 comments on commit aa73d40

Please sign in to comment.