Skip to content

Commit f96b601

Browse files
committed
Update tile and distribute to enable workgroup reordering
1 parent 1d9f5b0 commit f96b601

File tree

7 files changed

+168
-27
lines changed

7 files changed

+168
-27
lines changed

compiler/src/iree/compiler/Codegen/Common/Passes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <limits>
1616

17+
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
1718
#include "iree/compiler/Codegen/Common/PassUtils.h"
1819
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1920
#include "iree/compiler/Codegen/Utils/Utils.h"
@@ -94,6 +95,11 @@ createTileAndDistributeToWorkgroupsPass(
9495
int32_t maxWorkgroupParallelDims,
9596
linalg::DistributionMethod distributionMethod);
9697

98+
// Pass to tile and distribute using scf.forall with workgroup reordering.
99+
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
100+
createTileAndDistributeToWorkgroupsWithReordering(
101+
ReorderWorkgroupsStrategy strategy);
102+
97103
//----------------------------------------------------------------------------//
98104
// CodeGen Common Patterns
99105
//----------------------------------------------------------------------------//

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,10 @@ def TileAndDistributeToWorkgroupsUsingForallOpPass :
632632
"scf::SCFDialect",
633633
"tensor::TensorDialect",
634634
];
635+
let options = [
636+
Option<"strategy", "strategy", "std::string", /*default=*/"",
637+
"Workgroup reordering strategy, one of: '' (none), 'transpose'">,
638+
];
635639
}
636640

637641
def TileLargeTensorsPass :

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
78
#include "iree/compiler/Codegen/Common/Passes.h"
89
#include "iree/compiler/Codegen/Common/Transforms.h"
910
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
@@ -33,8 +34,34 @@ namespace {
3334
struct TileAndDistributeToWorkgroupsUsingForallOpPass final
3435
: public impl::TileAndDistributeToWorkgroupsUsingForallOpPassBase<
3536
TileAndDistributeToWorkgroupsUsingForallOpPass> {
37+
TileAndDistributeToWorkgroupsUsingForallOpPass(
38+
ReorderWorkgroupsStrategy strategy)
39+
: reorderingStrategy(strategy) {}
40+
3641
using Base::Base;
3742
void runOnOperation() override;
43+
44+
LogicalResult initializeOptions(
45+
StringRef options,
46+
function_ref<LogicalResult(const Twine &)> errorHandler) override {
47+
if (failed(Pass::initializeOptions(options, errorHandler))) {
48+
return failure();
49+
}
50+
auto selectedStrategy =
51+
llvm::StringSwitch<FailureOr<ReorderWorkgroupsStrategy>>(strategy)
52+
.Case("", ReorderWorkgroupsStrategy::None)
53+
.Case("transpose", ReorderWorkgroupsStrategy::Transpose)
54+
.Default(failure());
55+
if (failed(selectedStrategy))
56+
return failure();
57+
58+
reorderingStrategy = *selectedStrategy;
59+
return success();
60+
}
61+
62+
private:
63+
ReorderWorkgroupsStrategy reorderingStrategy =
64+
ReorderWorkgroupsStrategy::None;
3865
};
3966

4067
} // namespace
@@ -190,6 +217,28 @@ pruneDroppedLoops(ArrayRef<Attribute> inputs,
190217
return prunedAttrs;
191218
}
192219

220+
// Checks whether we have static dimension for all the loop bounds and steps.
221+
// This is a requirement if the reordering strategy is set to `transpose`.
222+
static bool checkStaticLoopBounds(scf::ForallOp forallOp) {
223+
224+
SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
225+
SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
226+
SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep();
227+
228+
for (auto [index, lb, ub, step] :
229+
llvm::enumerate(mixedLbs, mixedUbs, mixedSteps)) {
230+
231+
std::optional<int64_t> lbVal = getConstantIntValue(lb);
232+
std::optional<int64_t> ubVal = getConstantIntValue(ub);
233+
std::optional<int64_t> stepVal = getConstantIntValue(step);
234+
235+
if (!(lbVal && ubVal && stepVal)) {
236+
return false;
237+
}
238+
}
239+
return true;
240+
}
241+
193242
/// Find dimensions of the loop that are unit-trip count and drop them from the
194243
/// distributed dimensions.
195244
static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
@@ -516,6 +565,20 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
516565
// TODO: run producer and consumer fusion in one worklist.
517566
fuseProducersOfSlices(rewriter, newFusionOpportunities,
518567
tileAndFuseOptions, newLoop);
568+
forallOp = newLoop;
569+
}
570+
571+
// Reorder the workgroups if the strategy is set to `transpose`.
572+
// This just transposes the first two dimensions of the workgroup i.e., the
573+
// #iree.codegen.workgroup_id_x and #iree.codegen.workgroup_id_y.
574+
// Only reorders if the loop bounds are static.
575+
if (reorderingStrategy == ReorderWorkgroupsStrategy::Transpose) {
576+
SmallVector<Attribute> mappingAttrs(forallOp.getMappingAttr().getValue());
577+
int64_t mappingSize = mappingAttrs.size();
578+
if (checkStaticLoopBounds(forallOp) && mappingAttrs.size() >= 2) {
579+
std::swap(mappingAttrs[mappingSize - 1], mappingAttrs[mappingSize - 2]);
580+
forallOp.setMappingAttr(ArrayAttr::get(context, mappingAttrs));
581+
}
519582
}
520583
}
521584

@@ -538,4 +601,10 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
538601

539602
return;
540603
}
604+
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
605+
createTileAndDistributeToWorkgroupsWithReordering(
606+
ReorderWorkgroupsStrategy strategy) {
607+
return std::make_unique<TileAndDistributeToWorkgroupsUsingForallOpPass>(
608+
strategy);
609+
}
541610
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups-using-forall-op, cse))" --mlir-print-local-scope --split-input-file %s | FileCheck %s
2+
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups-using-forall-op{strategy=transpose}, cse))" --mlir-print-local-scope --split-input-file %s | FileCheck %s --check-prefix=REORDER
23

34
func.func @matmul_tensors(%0 : tensor<?x?xf32>, %1 : tensor<?x?xf32>, %2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
45
%3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
@@ -672,3 +673,66 @@ func.func @v_shaped_graph(%0: tensor<12xf32>, %1: tensor<12xf32>) -> tensor<12xf
672673
// CHECK-DAG: %[[RIGHT:.+]] = linalg.generic {{.*}} ins(%[[SLICE1]]
673674
// CHECK: linalg.generic {{.*}} ins(%[[LEFT]], %[[RIGHT]]
674675
// CHECK: return %[[RESULT]]
676+
677+
// -----
678+
679+
func.func @dont_transpose_dynamic(%0 : tensor<?x?xf32>, %1 : tensor<?x?xf32>, %2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
680+
%3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
681+
ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>)
682+
outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
683+
return %3 : tensor<?x?xf32>
684+
}
685+
686+
// TRANSPOSE-LABEL: func @dont_transpose_dynamic(
687+
// TRANSPOSE: scf.forall
688+
// TRANSPOSE: [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]
689+
690+
// -----
691+
692+
func.func @transpose_static(%0 : tensor<128x128xf32>, %1 : tensor<128x128xf32>, %2 : tensor<128x128xf32>) -> tensor<128x128xf32> {
693+
%3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0]]>}
694+
ins(%0, %1 : tensor<128x128xf32>, tensor<128x128xf32>)
695+
outs(%2 : tensor<128x128xf32>) -> tensor<128x128xf32>
696+
return %3 : tensor<128x128xf32>
697+
}
698+
699+
// TRANSPOSE-LABEL: func @transpose_static(
700+
// TRANSPOSE: scf.forall
701+
// TRANSPOSE: [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]
702+
703+
// -----
704+
705+
func.func @only_transpose_x_y(%7 : tensor<128x128x128x128xf32>, %8 : tensor<128x128x128x128xf32>) -> tensor<128x128x128x128xf32> {
706+
%9 = tensor.empty() : tensor<128x128x128x128xf32>
707+
%10 = linalg.generic {
708+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
709+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
710+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
711+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
712+
ins(%7, %8 : tensor<128x128x128x128xf32>, tensor<128x128x128x128xf32>)
713+
outs(%9 : tensor<128x128x128x128xf32>)
714+
attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[2, 64, 64, 64]]>} {
715+
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
716+
%11 = arith.addf %arg0, %arg1 : f32
717+
linalg.yield %11 : f32
718+
} -> tensor<128x128x128x128xf32>
719+
return %10 : tensor<128x128x128x128xf32>
720+
}
721+
722+
// TRANSPOSE-LABEL: func @only_transpose_x_y(
723+
// TRANSPOSE: scf.forall
724+
// TRANSPOSE: mapping = [#iree_codegen.workgroup_mapping<z:1>, #iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]
725+
726+
// -----
727+
728+
// Incase of less than 2 workgroup_mapping, don't apply transpose.
729+
func.func @dont_transpose_less(%0 : tensor<128x128xf32>, %1 : tensor<128x128xf32>, %2 : tensor<128x128xf32>) -> tensor<128x128xf32> {
730+
%3 = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 0, 0]]>}
731+
ins(%0, %1 : tensor<128x128xf32>, tensor<128x128xf32>)
732+
outs(%2 : tensor<128x128xf32>) -> tensor<128x128xf32>
733+
return %3 : tensor<128x128xf32>
734+
}
735+
736+
// TRANSPOSE-LABEL: func @dont_transpose_less(
737+
// TRANSPOSE: scf.forall
738+
// TRANSPOSE: [#iree_codegen.workgroup_mapping<x>]

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,11 @@ static void addBufferizePasses(OpPassManager &funcPassManager) {
186186
static void tileAndDistributeToWorkgroup(
187187
OpPassManager &funcPassManager, bool useForall,
188188
std::optional<ConvertToDestinationPassingStylePassOptions>
189-
convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{}) {
189+
convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{},
190+
ReorderWorkgroupsStrategy strategy = ReorderWorkgroupsStrategy::None) {
190191
if (useForall) {
191192
funcPassManager.addPass(
192-
createTileAndDistributeToWorkgroupsUsingForallOpPass());
193+
createTileAndDistributeToWorkgroupsWithReordering(strategy));
193194
} else {
194195
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
195196
kNumMaxParallelDims,
@@ -772,10 +773,11 @@ static void addVectorBufferizePasses(OpPassManager &funcPassManager) {
772773
void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
773774
const GPUPipelineOptions &options,
774775
bool usePadToModelSharedMemcpy) {
775-
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
776-
777776
ReorderWorkgroupsStrategy reorderStrategy =
778777
getReorderWorkgroupsStrategy(options.reorderStrategy);
778+
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
779+
std::nullopt, reorderStrategy);
780+
779781
funcPassManager.addPass(
780782
createReorderWorkgroups(reorderStrategy, canReorderWorkgroups));
781783

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_user_vector_distribute.mlir

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,16 @@ hal.executable public @main_0_dispatch_0 {
3333
// OPT-OUT-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
3434
// OPT-OUT: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
3535
// OPT-OUT: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
36-
// OPT-OUT-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
37-
// OPT-OUT-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
38-
// OPT-OUT-DAG: arith.muli %[[WG_Y]], %{{.+}} : index
39-
// OPT-OUT-DAG: arith.addi %{{.+}}, %[[WG_X]] : index
40-
// OPT-OUT: scf.for
36+
// OPT-OUT: scf.forall
37+
// OPT-OUT: scf.for
38+
// OPT-OUT: } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
4139

4240
// OPT-IN-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
4341
// OPT-IN: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
4442
// OPT-IN: memref.alloc() : memref<128x32xf16, #gpu.address_space<workgroup>>
45-
// OPT-IN-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
46-
// OPT-IN-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
47-
// OPT-IN: scf.for
43+
// OPT-IN: scf.forall
44+
// OPT-IN: scf.for
45+
// OPT-IN: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
4846

4947
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
5048
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
@@ -108,20 +106,16 @@ hal.executable public @main_0_dispatch_0 {
108106
// OPT-OUT-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
109107
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
110108
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
111-
// OPT-OUT-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
112-
// OPT-OUT-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
113-
// OPT-OUT-DAG: arith.muli %[[WG_Y]], %{{.+}} : index
114-
// OPT-OUT-DAG: arith.addi %{{.+}}, %[[WG_X]] : index
115-
// OPT-OUT: scf.for
109+
// OPT-OUT: scf.forall
110+
// OPT-OUT: scf.for
111+
// OPT-OUT: } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
116112

117113
// OPT-IN-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
118114
// OPT-IN: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
119115
// OPT-IN: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
120-
// OPT-IN-DAG: %[[WG_Y:.+]] = hal.interface.workgroup.id[1] : index
121-
// OPT-IN-DAG: %[[WG_X:.+]] = hal.interface.workgroup.id[0] : index
122-
// OPT-IN-DAG: arith.muli %[[WG_Y]], %{{.+}} : index
123-
// OPT-IN-DAG: arith.addi %{{.+}}, %[[WG_X]] : index
124-
// OPT-IN: scf.for
116+
// OPT-IN: scf.forall
117+
// OPT-IN: scf.for
118+
// OPT-IN: } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
125119
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
126120
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
127121
gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <Transpose>> // enable the 'reorderWorkgroups' pass.
@@ -180,9 +174,9 @@ hal.executable public @main_0_dispatch_0 {
180174
// OPT-OUT-LABEL: func.func @main_0_dispatch_0_matmul_transpose_b
181175
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
182176
// OPT-OUT: memref.alloc() : memref<128x36xf16, #gpu.address_space<workgroup>>
183-
// OPT-OUT-DAG: hal.interface.workgroup.id[1] : index
184-
// OPT-OUT-DAG: hal.interface.workgroup.id[0] : index
185-
// OPT-OUT-NEXT: scf.for
177+
// OPT-OUT: scf.forall
178+
// OPT-OUT: scf.for
179+
// OPT-OUT: } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
186180
func.func @main_0_dispatch_0_matmul_transpose_b_2048x10240x1280_f16xf16xf32()
187181
attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64, {
188182
gpu_pipeline_options = #iree_gpu.pipeline_options<reorder_workgroups_strategy = <None>> // Disable the 'reorderWorkgroups' pass.

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ hal.executable private @attention_20x4096x64x4096x64 {
10231023
// Check that we only use alloc for Q, K, and V. No shared memory for S is
10241024
// needed because the intrinsic layout mathes.
10251025
// MEMORY-LABEL: func.func @attention_20x4096x64x4096x64()
1026-
// MEMORY-COUNT-4: memref.alloc
1026+
// MEMORY-COUNT-3: memref.alloc
10271027
// MEMORY-NOT: memref.alloc
10281028

10291029
// -----
@@ -1090,6 +1090,7 @@ hal.executable private @attention_multiple_m_transpose {
10901090

10911091
// Check that we only use alloc for Q, K, and V. No shared memory for S is
10921092
// needed because the intrinsic layout mathes.
1093+
// TODO: With forall distribution it's allocating memory for S.
10931094
// MEMORY-LABEL: func.func @attention_multiple_m_transpose()
10941095
// MEMORY-COUNT-4: memref.alloc
10951096
// MEMORY-NOT: memref.alloc
@@ -1159,7 +1160,7 @@ hal.executable private @attention_mfma_32x32x8 {
11591160
// Check that we only use alloc for Q, K, and V. No shared memory for S is
11601161
// needed because the intrinsic layout mathes.
11611162
// MEMORY-LABEL: func.func @attention_mfma_32x32x8()
1162-
// MEMORY-COUNT-3: memref.alloc
1163+
// MEMORY-COUNT-4: memref.alloc
11631164
// MEMORY-NOT: memref.alloc
11641165

11651166
// -----
@@ -1311,3 +1312,4 @@ module {
13111312

13121313
// MEMORY-LABEL: func.func @attention_gather_k
13131314
// MEMORY-COUNT-3: memref.alloc
1315+
// MEMORY-NOT: memref.alloc

0 commit comments

Comments
 (0)