Skip to content

Commit

Permalink
towards landable soln
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Jan 9, 2025
1 parent 31c04a3 commit dc8c92e
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@

#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/IRMapping.h"

#define DEBUG_TYPE "iree-amdaie-linalg-function-outlining"

Expand All @@ -23,7 +21,8 @@ namespace {
/// Utility to outline the linalg compute op.
static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,
linalg::LinalgOp computeOp,
const std::string &funcName) {
const std::string &funcName,
bool noAliasFinalArg) {
// Form outlined FunctionType.
for (const auto &operand : computeOp->getOperands()) {
// Function signatures where the memrefs have layouts (strides / offsets)
Expand Down Expand Up @@ -65,6 +64,18 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,
// arguments.
Operation *clonedComputeOp = rewriter.clone(*computeOp, operandMap);

if (noAliasFinalArg) {
auto args = func.getArguments();
auto it = std::find_if(args.rbegin(), args.rend(), [](BlockArgument arg) {
return isa<MemRefType>(arg.getType());
});
if (it != args.rend()) {
int index = args.size() - std::distance(args.rbegin(), it) - 1;
auto noAliasAttrName = LLVM::LLVMDialect::getNoAliasAttrName();
func.setArgAttr(index, noAliasAttrName, rewriter.getUnitAttr());
}
}

// Create terminator op returning the cloned compute op's results.
rewriter.setInsertionPointToEnd(funcBody);
rewriter.create<func::ReturnOp>(clonedComputeOp->getLoc(), ValueRange({}));
Expand Down Expand Up @@ -143,7 +154,7 @@ class AMDAIELinalgFunctionOutliningPass
}

FailureOr<func::FuncOp> maybeFunc =
outline(rewriter, moduleOp, computeOp, funcName);
outline(rewriter, moduleOp, computeOp, funcName, noAliasFinalArg);

if (succeeded(maybeFunc)) {
computeOpToOutlinedFuncMap[computeOp] = maybeFunc.value();
Expand Down Expand Up @@ -175,7 +186,6 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() {
if (failed(maybeFunc)) return WalkResult::interrupt();
func::FuncOp func = maybeFunc.value();


rewriter.setInsertionPoint(computeOp);
rewriter.create<func::CallOp>(computeOp.getLoc(), func,
computeOp->getOperands());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ LogicalResult AIEDeviceBuilder::coreFuncCallOpToAIE(
StringRef fnName = oldCallOp.getCallee();
auto fnDecl = dyn_cast_if_present<func::FuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, fnName));

assert(fnDecl && "expected function declaration");
// Check the mapper to see if we've already created a new function declaration
// with the new function type. If not, create the same. We need to create a
Expand All @@ -433,26 +434,19 @@ LogicalResult AIEDeviceBuilder::coreFuncCallOpToAIE(
SymbolTable::Visibility::Private);
newFnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true));

// Add the 'noalias' attribute to all argument attributes, if the type is
// memref:
auto noAliasAttrName = LLVM::LLVMDialect::getNoAliasAttrName();
// The read-only attribute:
auto readOnlyAttrName = LLVM::LLVMDialect::getReadonlyAttrName();
for (int i = 0; i < newArgs.size(); ++i) {
if (isa<MemRefType>(newArgs[i].getType())) {
newFnDecl.setArgAttr(i, noAliasAttrName, rewriter.getUnitAttr());
}
}
(void)readOnlyAttrName;
(void)noAliasAttrName;
fnDecl.getBody().cloneInto(&(newFnDecl.getBody()), mapper);
if (ArrayAttr oldAttrs = fnDecl.getAllArgAttrs()) {
newFnDecl.setAllArgAttrs(oldAttrs);
}

mapper.map(fnDecl.getOperation(), newFnDecl.getOperation());
fnDecl = newFnDecl;
}
// Fetch the new function declaration and create the new func.call op.
auto newFnDecl = cast<func::FuncOp>(mapper.lookupOrDefault(fnDecl));
rewriter.create<func::CallOp>(oldCallOp->getLoc(), newFnDecl, newArgs);
toBeErased.push_back(oldCallOp);

return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,22 @@ def AMDAIELinalgFunctionOutlining :
"Replace all outlined functions with a function that does nothing, "
"i.e. it just returns. Useful for measuring the performance of data "
"movement to/from the device -- by doing zero compute, all time is spent "
"moving data to/from the AIE cores.">
"moving data to/from the AIE cores.">,
Option<"noAliasFinalArg", "no-alias-final-arg", "bool", /*default=*/"true",
"A developer only option. When 'true' (the default), "
"the final memref argument of the outlined function "
"will have the 'llvm.noalias' attribute attached to "
"it. The motivation for having this attribute is that "
"sometimes the matmul code generated in llvm's opt is "
"much (2x) faster. The motivation for adding it manually "
"without any analysis is that llvm/peano cannot always "
"infer that this attribute can safely be attached, "
"because (I suppose) the analysis of all call sites, "
"i.e. checking that the final argument is not aliased "
"to any other arguments is too complicated/expensive. "
"The addition of this attribute in this pass is exposed "
"as an option, as there is no guarantee that it is valid "
"-- the final argument could in theory alias another argument. ">
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// CHECK-LABEL: func.func private @generic_matmul_0_outlined
// CHECK-SAME: (%[[LHS:.*]]: memref<4x8xbf16>,
// CHECK-SAME: %[[RHS:.*]]: memref<8x4xbf16>,
// CHECK-SAME: %[[OUT:.*]]: memref<4x4xf32>) {
// CHECK-SAME: %[[OUT:.*]]: memref<4x4xf32> {llvm.noalias}) {
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUT]] :
Expand Down Expand Up @@ -76,8 +76,8 @@ func.func @repeated_identical_matmul(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>,
// Test demonstrating different kind of matmul operations being mapped to a
// unique corresponding outlined function.

// CHECK-DAG: func.func private @[[MATMUL_K6:.*]]({{.*}}memref<4x6xbf16>, {{.*}}memref<6x4xbf16>, {{.*}}memref<4x4xf32>)
// CHECK-DAG: func.func private @[[MATMUL_K4:.*]]({{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xf32>)
// CHECK-DAG: func.func private @[[MATMUL_K6:.*]]({{.*}}memref<4x6xbf16>, {{.*}}memref<6x4xbf16>, {{.*}}memref<4x4xf32> {llvm.noalias})
// CHECK-DAG: func.func private @[[MATMUL_K4:.*]]({{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xf32> {llvm.noalias})
// CHECK-NOT: func.func private
// CHECK: func.func @distinct_matmul_shapes(
// CHECK-SAME: %[[A0:.*]]: memref<4x4xbf16>, %[[B0:.*]]: memref<4x4xbf16>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Currently this is the default:
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{empty-functions=false no-alias-final-arg=true})" --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK_NOTEMPTY_NOALIAS

// CHECK_NOTEMPTY_NOALIAS: func.func private @
// CHECK_NOTEMPTY_NOALIAS-SAME: memref<4xbf16>,
// CHECK_NOTEMPTY_NOALIAS-SAME: memref<bf16> {llvm.noalias}) {
// CHECK_NOTEMPTY_NOALIAS: linalg.generic
// CHECK_NOTEMPTY_NOALIAS: return
// CHECK_NOTEMPTY_NOALIAS: func.func @reduction

// A run to check the option empty-functions=true:
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{empty-functions=true no-alias-final-arg=true})" --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK_EMPTY_NOALIAS

// CHECK_EMPTY_NOALIAS: func.func private @
// CHECK_EMPTY_NOALIAS-SAME: memref<4xbf16>,
// CHECK_EMPTY_NOALIAS-SAME: memref<bf16> {llvm.noalias}) {
// CHECK_EMPTY_NOALIAS-NOT: linalg.generic
// CHECK_EMPTY_NOALIAS: return
// CHECK_EMPTY_NOALIAS: func.func @reduction

// A run to check the option no-alias-final-arg=false:
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{empty-functions=false no-alias-final-arg=false})" --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK_NOTEMPTY_ALIAS

// CHECK_NOTEMPTY_ALIAS: func.func private @
// CHECK_NOTEMPTY_ALIAS-SAME: memref<4xbf16>,
// CHECK_NOTEMPTY_ALIAS-SAME: memref<bf16>) {
// CHECK_NOTEMPTY_ALIAS: linalg.generic
// CHECK_NOTEMPTY_ALIAS: return
// CHECK_NOTEMPTY_ALIAS: func.func @reduction

// Test input shared by every lit invocation in this file.
// It wraps a single linalg.generic in an amdaie.core region; the FileCheck
// expectations above look for a private function with signature
// (memref<4xbf16>, memref<bf16>) emitted ahead of @reduction, with or without
// `llvm.noalias` on the final memref depending on the pass options.
func.func @reduction(%A: memref<4xbf16>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
// The core is attached to the tile built from the (%c2, %c2) operands.
%tile = amdaie.tile(%c2, %c2)
%1 = amdaie.core(%tile, in : [], out : []) {
// One "reduction" iterator: the input is indexed by d0, while the rank-0
// output uses the empty map (d0) -> (). The region simply yields %in.
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
iterator_types = ["reduction"]
} ins(%A: memref<4xbf16>) outs(%B : memref<bf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
}
amdaie.end
}
return
}



This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
%lock_5 = amdaie.lock(%tile_0_1(1), 0)
%buffer_3 = amdaie.buffer(%tile_0_2) : memref<2048xi32, 2 : i32>
%lock_6 = amdaie.lock(%tile_0_2(0), 1)
%lock_7 = amdaie.lock(%tile_0_2(1), 0)
%lock_7 = amdaie.lock(%tile_0_2(1), 0)
%0 = amdaie.logicalobjectfifo.from_buffers({%buffer}, {%lock}, {%lock_1}) : memref<4096xi32, 1 : i32> -> !amdaie.logicalobjectfifo<memref<4096xi32, 1 : i32>, 1>
%1 = amdaie.logicalobjectfifo.from_buffers({%buffer_1}, {%lock_2}, {%lock_3}) : memref<4096xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<4096xi32, 2 : i32>, 1>
%channel = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = MM2S)
Expand Down Expand Up @@ -625,11 +625,11 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// -----

// Tests lowering of a circular DMA operation to a DMA chain.
// Checks that a circular DMA operation with an 'outer' repetition which is not
// part of the objectFifo's repetition count (same repetition on each
// Checks that a circular DMA operation with an 'outer' repetition which is not
// part of the objectFifo's repetition count (same repetition on each
// connection), is lowered to a chain of `dma_bd` operations with a lock
// acquire at the beginning of the chain and a lock release at the end. Note
// that this lowering to multiple `dma_bd` operations is needed because
// that this lowering to multiple `dma_bd` operations is needed because
// `stride == 0` is not supported in hardware and/or because there are more
// dimensions needed than supported in `dma_bd`.
// CHECK: aie.device(npu1_4col)
Expand Down Expand Up @@ -695,10 +695,10 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
// -----

// Tests lowering of a circular DMA operation to a DMA chain.
// Checks that a circular DMA operation with an 'inner' repetition (a dimension
// with `stride == 0` after a dimension with `stride != 0`), is lowered to a
// Checks that a circular DMA operation with an 'inner' repetition (a dimension
// with `stride == 0` after a dimension with `stride != 0`), is lowered to a
// chain of `dma_bd` operations with a lock acquire at the beginning of the chain
// and a lock release at the end. Note that this lowering to multiple `dma_bd`
// and a lock release at the end. Note that this lowering to multiple `dma_bd`
// operations is needed because `stride == 0` is not supported in hardware and/or
// because there are more dimensions needed than supported in `dma_bd`.
// CHECK: aie.device(npu1_4col)
Expand Down Expand Up @@ -899,8 +899,9 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}


// CHECK: aie.device
// CHECK: func.func private @ukernel_B(memref<i32, 2 : i32> {llvm.noalias}, index, memref<f32, 2 : i32> {llvm.noalias}, index) attributes {llvm.bareptr = true}
// CHECK: func.func private @ukernel_A(memref<i32, 2 : i32> {llvm.noalias}, index) attributes {llvm.bareptr = true}
// CHECK-DAG: func.func private @f_with_arg_attr(%arg0: memref<i32, 2 : i32> {llvm.noalias})
// CHECK-DAG: func.func private @ukernel_B(memref<i32, 2 : i32>, index, memref<f32, 2 : i32>, index) attributes {llvm.bareptr = true}
// CHECK-DAG: func.func private @ukernel_A(memref<i32, 2 : i32>, index) attributes {llvm.bareptr = true}
// CHECK: %[[TILE_0_2:.*]] = aie.tile(0, 2)
// CHECK: %[[BUFFER_0_2:.*]] = aie.buffer(%[[TILE_0_2]]) {sym_name = "buff_0"} : memref<4096xi32, 2 : i32>
// CHECK: %[[LOCK_0_2:.*]] = aie.lock(%[[TILE_0_2]], 0) {init = 1 : i8, sym_name = "lock_0"}
Expand All @@ -926,6 +927,8 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func private @ukernel_A(memref<i32, 2 : i32>, index) attributes {link_with = "/path/to/ukernel.o", llvm.bareptr = true}
func.func private @ukernel_B(memref<i32, 2 : i32>, index, memref<f32, 2 : i32>, index) attributes {link_with = "/path/to/ukernel.o", llvm.bareptr = true}
func.func private @f_with_arg_attr(%0 : memref<i32, 2 : i32> {llvm.noalias}) attributes {llvm.bareptr = true} { return }

func.func @core_ukernel() {
amdaie.workgroup {
%c0 = arith.constant 0 : index
Expand All @@ -946,6 +949,7 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
%base_buffer0, %offset0, %sizes0:2, %strides0:2 = memref.extract_strided_metadata %4 : memref<64x64xf32, 2 : i32> -> memref<f32, 2 : i32>, index, index, index, index, index
func.call @ukernel_A(%base_buffer, %c0) : (memref<i32, 2 : i32>, index) -> ()
func.call @ukernel_B(%base_buffer, %c0, %base_buffer0, %c0) : (memref<i32, 2 : i32>, index, memref<f32, 2 : i32>, index) -> ()
func.call @f_with_arg_attr(%base_buffer) : (memref<i32, 2 : i32>) -> ()
amdaie.use_lock(%lock, Release(1))
amdaie.use_lock(%lock_2, Release(1))
amdaie.end
Expand Down

0 comments on commit dc8c92e

Please sign in to comment.