diff --git a/compiler/include/byteir/Transforms/Passes.h b/compiler/include/byteir/Transforms/Passes.h
index 03dfe839e..d7215705d 100644
--- a/compiler/include/byteir/Transforms/Passes.h
+++ b/compiler/include/byteir/Transforms/Passes.h
@@ -33,6 +33,7 @@
 #include "byteir/Transforms/LoopUnroll.h"
 #include "byteir/Transforms/MemoryPlanning.h"
 #include "byteir/Transforms/RemoveFuncBody.h"
+#include "byteir/Transforms/ReorderMemrefCopy.h"
 #include "byteir/Transforms/RewriteOpToStdCall.h"
 #include "byteir/Transforms/SetArgShape.h"
 #include "byteir/Transforms/SetSpace.h"
diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index b92f1de90..5bd9877b8 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -312,6 +312,23 @@ def RewriteOpToStdCall : Pass<"rewrite-op-to-std-call", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ReorderMemrefCopy
+//===----------------------------------------------------------------------===//
+
+def ReorderMemrefCopy : Pass<"reorder-memref-copy", "mlir::func::FuncOp"> {
+  let summary = "Reorder byre.copy ops to overlap IO and compute.";
+  let description = [{
+    Moves each `byre.copy` whose target is a block argument to directly after
+    the last preceding operation that uses its source or target buffer (or a
+    view of them).
+  }];
+  let constructor = "mlir::createReorderMemrefCopyPass()";
+  let dependentDialects = [
+    "mlir::memref::MemRefDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // OneShotBufferize
 //===----------------------------------------------------------------------===//
diff --git a/compiler/include/byteir/Transforms/ReorderMemrefCopy.h b/compiler/include/byteir/Transforms/ReorderMemrefCopy.h
new file mode 100644
index 000000000..66489733d
--- /dev/null
+++ b/compiler/include/byteir/Transforms/ReorderMemrefCopy.h
@@ -0,0 +1,33 @@
+//===- ReorderMemrefCopy.h ------------------------------------*--- C++ -*-===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_TRANSFORMS_REORDERMEMREFCOPY_H
+#define BYTEIR_TRANSFORMS_REORDERMEMREFCOPY_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace func {
+class FuncOp;
+} // namespace func
+
+std::unique_ptr<OperationPass<func::FuncOp>> createReorderMemrefCopyPass();
+
+} // namespace mlir
+
+#endif // BYTEIR_TRANSFORMS_REORDERMEMREFCOPY_H
diff --git a/compiler/lib/Pipelines/ByreOpt.cpp b/compiler/lib/Pipelines/ByreOpt.cpp
index 541a490f4..efed94eb6 100644
--- a/compiler/lib/Pipelines/ByreOpt.cpp
+++ b/compiler/lib/Pipelines/ByreOpt.cpp
@@ -60,6 +60,8 @@ void createByreOptPipelineImpl(OpPassManager &pm, const std::string &entryFunc,
   }
   anchoredPM.addPass(createConvertMemrefToByrePass());
   anchoredPM.addPass(createCanonicalizerPass());
+  anchoredPM.addPass(createReorderMemrefCopyPass());
+  anchoredPM.addPass(createCanonicalizerPass());
 
   pm.addNestedPass<func::FuncOp>(createAnchoredPipelinePass(
       ByreDialect::getEntryPointFunctionAttrName(), anchoredPM));
@@ -73,4 +75,4 @@ void mlir::createByreOptPipeline(OpPassManager &pm,
   invokeOpPassPipelineBuilder(createByreOptPipelineImpl, pm,
                               options.entryFunc, options.appendArgTypes,
                               options.disableMemoryPlanning);
-}
\ No newline at end of file
+}
diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt
index f93e65498..67a859ce5 100644
--- a/compiler/lib/Transforms/CMakeLists.txt
+++ b/compiler/lib/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_library(ByteIRTransforms
   LoopUnroll.cpp
   MemoryPlanning.cpp
   RemoveFuncBody.cpp
+  ReorderMemrefCopy.cpp
   RewriteOpToStdCall.cpp
   SetArgShape.cpp
   SetSpace.cpp
@@ -46,4 +47,4 @@ add_mlir_library(ByteIRTransforms
   MLIRPDLInterpDialect
   MLIRSCFDialect
   MLIRTransforms
-)
\ No newline at end of file
+)
diff --git a/compiler/lib/Transforms/ReorderMemrefCopy.cpp b/compiler/lib/Transforms/ReorderMemrefCopy.cpp
new file mode 100644
index 000000000..5b9fd1938
--- /dev/null
+++ b/compiler/lib/Transforms/ReorderMemrefCopy.cpp
@@ -0,0 +1,166 @@
+//===- ReorderMemrefCopy.cpp ----------------------------------*--- C++ -*-===//
+//
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+// Some code comes from TestLoopUnrolling.cpp in LLVM project
+// Original license:
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "byteir/Transforms/ReorderMemrefCopy.h"
+#include "byteir/Dialect/Byre/ByreDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SmallSet.h"
+#include <optional>
+#include <queue>
+
+#include "./PassDetail.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::memref;
+
+namespace {
+
+// Upper bound on the number of aliasing values analysed per copy operand.
+// Beyond it the use scan is considered too expensive and the copy is left
+// where it is.
+constexpr size_t kMaxAliasCount = 100;
+
+struct ReorderMemrefCopyPass
+    : public ReorderMemrefCopyBase<ReorderMemrefCopyPass> {
+  ReorderMemrefCopyPass() : ReorderMemrefCopyBase() {}
+
+  void runOnOperation() override;
+}; // ReorderMemrefCopyPass
+
+// Returns `val` and every value reachable from the root of its view chain
+// through view-like ops, i.e. all values that may refer to the same buffer.
+SmallVector<Value> getAllAlias(Value val) {
+  // Climb through view-like producers to the underlying buffer.
+  Value root = val;
+  while (auto viewOp = root.getDefiningOp<ViewLikeOpInterface>())
+    root = viewOp.getViewSource();
+
+  // Breadth-first walk down every view derived from the root.
+  SmallVector<Value> aliases;
+  std::queue<Value> worklist;
+  worklist.push(root);
+  while (!worklist.empty()) {
+    Value cur = worklist.front();
+    worklist.pop();
+    aliases.push_back(cur);
+    for (Operation *user : cur.getUsers()) {
+      if (!isa<ViewLikeOpInterface>(user))
+        continue;
+      // NB. Just assume viewlike-op has single result.
+      for (Value result : user->getResults())
+        worklist.push(result);
+    }
+  }
+  return aliases;
+}
+
+// Finds the latest operation in `op`'s block that precedes `op` and uses
+// `val` or one of its aliases. A use nested inside a region is attributed to
+// its ancestor operation in that block; uses outside the block are ignored.
+// Returns:
+//   - std::nullopt when the alias set is too large to analyse, in which case
+//     the caller must not reorder `op`;
+//   - nullptr when no such use exists;
+//   - the matching operation otherwise.
+std::optional<Operation *> lastUseBeforeOp(Value val, Operation *op) {
+  SmallVector<Value> aliases = getAllAlias(val);
+  // FIXME. Large alias sets take too long to analyse, so give up on them.
+  if (aliases.size() > kMaxAliasCount)
+    return std::nullopt;
+
+  Block *block = op->getBlock();
+  Operation *latest = nullptr;
+  for (Value alias : aliases) {
+    for (Operation *user : alias.getUsers()) {
+      // `isBeforeInBlock` requires both operations to share a block, so map
+      // the user onto `op`'s block first.
+      Operation *anchor = block->findAncestorOpInBlock(*user);
+      if (!anchor || anchor == op || !anchor->isBeforeInBlock(op))
+        continue;
+      if (!latest || latest->isBeforeInBlock(anchor))
+        latest = anchor;
+    }
+  }
+  return latest;
+}
+
+// Moves `copyOp` up to directly after the last preceding operation that
+// touches its source or target buffer, so the copy can overlap with the
+// computation that follows.
+void hoistCopy(byre::CopyOp copyOp) {
+  Value src = copyOp.getSource();
+  Value dst = copyOp.getTarget();
+  // TODO(chhuang) enable dst which is not argument.
+  if (dst.getDefiningOp())
+    return;
+
+  Operation *op = copyOp.getOperation();
+  std::optional<Operation *> srcLastUse = lastUseBeforeOp(src, op);
+  std::optional<Operation *> dstLastUse = lastUseBeforeOp(dst, op);
+  // An abandoned analysis is not the same as "no earlier use": stay put.
+  if (!srcLastUse || !dstLastUse)
+    return;
+
+  // The copy has to stay behind the later of the two uses.
+  Operation *last = *srcLastUse;
+  if (!last || (*dstLastUse && last->isBeforeInBlock(*dstLastUse)))
+    last = *dstLastUse;
+  if (!last)
+    return;
+
+  // Never hoist above the definition of the source value.
+  if (Operation *srcDef = src.getDefiningOp()) {
+    Operation *defAnchor = op->getBlock()->findAncestorOpInBlock(*srcDef);
+    if (defAnchor && last->isBeforeInBlock(defAnchor))
+      last = defAnchor;
+  }
+
+  if (last->getNextNode() != op)
+    op->moveAfter(last);
+}
+
+void ReorderMemrefCopyPass::runOnOperation() {
+  func::FuncOp func = getOperation();
+
+  // Collect all `byre.copy` up front so that no operation is moved during
+  // the walk.
+  SmallVector<byre::CopyOp> copyOps;
+  func.walk([&](byre::CopyOp copyOp) { copyOps.push_back(copyOp); });
+
+  // Try to reorder candidates.
+  for (byre::CopyOp copyOp : copyOps)
+    hoistCopy(copyOp);
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::createReorderMemrefCopyPass() {
+  return std::make_unique<ReorderMemrefCopyPass>();
+}