From 31bedfac277ba5817fc8117598078c18f77ba2de Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Thu, 12 Dec 2024 16:47:28 -0800 Subject: [PATCH] [core] Move the QIR peephole patterns out of tablegen. Moves all the peephole patterns to explicit patterns and out of tablegen. Fix #476. Signed-off-by: Eric Schweitz --- .../cudaq/Optimizer/CodeGen/CMakeLists.txt | 4 - include/cudaq/Optimizer/CodeGen/Peephole.h | 9 +- include/cudaq/Optimizer/CodeGen/Peephole.td | 177 ------------- lib/Optimizer/CodeGen/CMakeLists.txt | 1 - lib/Optimizer/CodeGen/ConvertToQIR.cpp | 2 + lib/Optimizer/CodeGen/ConvertToQIRProfile.cpp | 2 + lib/Optimizer/CodeGen/PeepholePatterns.inc | 238 ++++++++++++++++++ lib/Optimizer/CodeGen/VerifyQIRProfile.cpp | 2 +- 8 files changed, 245 insertions(+), 190 deletions(-) delete mode 100644 include/cudaq/Optimizer/CodeGen/Peephole.td create mode 100644 lib/Optimizer/CodeGen/PeepholePatterns.inc diff --git a/include/cudaq/Optimizer/CodeGen/CMakeLists.txt b/include/cudaq/Optimizer/CodeGen/CMakeLists.txt index 5c2c15f8d7..c0140aefa4 100644 --- a/include/cudaq/Optimizer/CodeGen/CMakeLists.txt +++ b/include/cudaq/Optimizer/CodeGen/CMakeLists.txt @@ -12,7 +12,3 @@ add_cudaq_dialect_doc(CodeGenDialect codegen) set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name OptCodeGen) add_public_tablegen_target(OptCodeGenPassIncGen) - -set(LLVM_TARGET_DEFINITIONS Peephole.td) -mlir_tablegen(Peephole.inc -gen-rewriters) -add_public_tablegen_target(OptPeepholeIncGen) diff --git a/include/cudaq/Optimizer/CodeGen/Peephole.h b/include/cudaq/Optimizer/CodeGen/Peephole.h index 4fdca9bd02..f5eae54c4b 100644 --- a/include/cudaq/Optimizer/CodeGen/Peephole.h +++ b/include/cudaq/Optimizer/CodeGen/Peephole.h @@ -38,9 +38,8 @@ inline bool isIntToPtrOp(mlir::Value operand) { static constexpr char resultIndexName[] = "result.index"; inline mlir::Value createMeasureCall(mlir::PatternRewriter &builder, - mlir::Location loc, mlir::OpResult result, + mlir::Location loc, mlir::LLVM::CallOp op, mlir::ValueRange args) { - auto op = cast(result.getDefiningOp()); auto ptrTy = cudaq::opt::getResultType(builder.getContext()); if (auto intAttr = dyn_cast_or_null(op->getAttr(resultIndexName))) { @@ -57,7 +56,7 @@ inline mlir::Value createMeasureCall(mlir::PatternRewriter &builder, inline mlir::Value createReadResultCall(mlir::PatternRewriter &builder, mlir::Location loc, - mlir::OpResult result) { + mlir::Value result) { auto i1Ty = mlir::IntegerType::get(builder.getContext(), 1); return builder .create(loc, mlir::TypeRange{i1Ty}, @@ -65,7 +64,3 @@ inline mlir::Value createReadResultCall(mlir::PatternRewriter &builder, mlir::ArrayRef{result}) .getResult(); } - -namespace { -#include "cudaq/Optimizer/CodeGen/Peephole.inc" -} diff --git a/include/cudaq/Optimizer/CodeGen/Peephole.td b/include/cudaq/Optimizer/CodeGen/Peephole.td deleted file mode 100644 index 32f32b5d21..0000000000 --- a/include/cudaq/Optimizer/CodeGen/Peephole.td +++ /dev/null @@ -1,177 +0,0 @@ -/********************************************************** -*- tablegen -*- *** - * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * - * All rights reserved. * - * * - * This source code and the accompanying materials are made available under * - * the terms of the Apache License 2.0 which accompanies this distribution. * - ******************************************************************************/ - -#ifndef NVQPP_OPTIMIZER_CODEGEN_PEEPHOLE -#define NVQPP_OPTIMIZER_CODEGEN_PEEPHOLE - -include "cudaq/Optimizer/Dialect/Quake/QuakeOps.td" -include "mlir/Dialect/LLVMIR/LLVMOps.td" -include "mlir/IR/OpBase.td" -include "mlir/IR/PatternBase.td" - -//===----------------------------------------------------------------------===// - -def InvokeOnXWithOneControl : Constraint>; - -def CreateCallCnot : NativeCodeCall< - "[&]() -> std::size_t {" - " $_builder.create($_loc," - " mlir::TypeRange{}, cudaq::opt::QIRCnot, $0.drop_front(2));" - " return 0; }()">; - -// %1 = address_of @__quantum__qis__x__ctl -// %2 = call @invokewithControlBits %1, %ctrl, %targ -// ───────────────────────────────────────────────── -// %2 = call __quantum__qis__cnot %ctrl, %targ -def XCtrlOneTargetToCNot : Pat< - (LLVM_CallOp $callee, $args, $_, $_), (CreateCallCnot $args), - [(InvokeOnXWithOneControl $callee, $args)]>; - -//===----------------------------------------------------------------------===// - -def NeedsRenaming : Constraint>; - -def CreateAddressOf : NativeCodeCall< - "$_builder.create($_loc, $0.getType()," - " $1.getValue().str() + \"__body\")">; - -// %4 = address_of @__quantum__cis__* -// ──────────────────────────────────────── -// %4 = address_of @__quantum__cis__*__body -def AddrOfCisToBase : Pat< - (LLVM_AddressOfOp:$addr $global), (CreateAddressOf $addr, $global), - [(NeedsRenaming $global)]>; - -//===----------------------------------------------------------------------===// - -// Apply special rule for `mz`. See below. -def FuncNotMeasure : Constraint>; - -def CreateCallOp : NativeCodeCall< - "[&]() -> std::size_t {" - " $_builder.create($_loc, mlir::TypeRange{}," - " mlir::FlatSymbolRefAttr::get($_builder.getContext()," - " $0.getValue().str() + \"__body\"), $1, $2, $3);" - " return 0; }()">; - -// %4 = call @__quantum__cis__* -// ────────────────────────────────── -// %4 = call @__quantum__cis__*__body -def CalleeConv : Pat< - (LLVM_CallOp $callee, $args, $fm, $bw), - (CreateCallOp $callee, $args, $fm, $bw), - [(NeedsRenaming $callee), (FuncNotMeasure:$callee)]>; - -//===----------------------------------------------------------------------===// - -def IsArrayGetElementPtrId : Constraint>; - -def EraseArrayGEPOp : NativeCodeCall< - "$_builder.create($_loc," - " cudaq::opt::getQubitType($_builder.getContext()))">; - -def EraseDeadArrayGEP : Pat< - (LLVM_CallOp:$call $callee, $_, $_, $_), (EraseArrayGEPOp), - [(IsArrayGetElementPtrId $callee), (HasNoUseOf:$call)]>; - -//===----------------------------------------------------------------------===// - -def IsaAllocateCall : Constraint>; - -def EraseArrayAllocateOp : NativeCodeCall< - "$_builder.create($_loc," - " cudaq::opt::getArrayType($_builder.getContext()))">; - -// Replace the call with a dead op to DCE. -// -// %0 = call @allocate ... : ... -> T* -// ─────────────────────────────────── -// %0 = undef : T* -def EraseArrayAlloc : Pat< - (LLVM_CallOp $callee, $_, $_, $_), (EraseArrayAllocateOp), - [(IsaAllocateCall $callee)]>; - -//===----------------------------------------------------------------------===// - -def IsaReleaseCall : Constraint>; - -def EraseArrayReleaseOp : NativeCodeCall<"static_cast(0)">; - -// Remove the release calls. This removes both array allocations as well as -// qubit singletons. -// -// call @release %5 : (!Qubit) -> () -// ───────────────────────────────── -def EraseArrayRelease : Pat< - (LLVM_CallOp $callee, $_, $_, $_), (EraseArrayReleaseOp), - [(IsaReleaseCall $callee)]>; - -//===----------------------------------------------------------------------===// - -def IsaMeasureCall : Constraint>; - -def IsaIntToPtrOperand : Constraint>; - -def CreateMeasureCall : NativeCodeCall< - "createMeasureCall($_builder, $_loc, $0, $1)">; - -// %result = call @__quantum__qis__mz(%qbit) : (!Qubit) -> i1 -// ────────────────────────────────────────────────────────────── -// call @__quantum__qis__mz_body(%qbit, %result) : (Q*, R*) -> () -def MeasureCallConv : Pat< - (LLVM_CallOp:$call $callee, $args, $_, $_), - (CreateMeasureCall $call, $args), - [(IsaMeasureCall:$callee), (IsaIntToPtrOperand $args)]>; - -//===----------------------------------------------------------------------===// - -def IsaMeasureToRegisterCall : Constraint>; - -// %result = call @__quantum__qis__mz__to__register(%qbit, i8) : (!Qubit) -> i1 -// ──────────────────────────────────────────────────────────────────────────── -// call @__quantum__qis__mz_body(%qbit, %result) : (Q*, R*) -> () -def MeasureToRegisterCallConv : Pat< - (LLVM_CallOp:$call $callee, $args, $_, $_), - (CreateMeasureCall $call, $args), - [(IsaMeasureToRegisterCall:$callee), (IsaIntToPtrOperand $args)]>; - -//===----------------------------------------------------------------------===// - -def HasI1PtrType : Constraint>; - -def HasResultType : Constraint>; - -def IsaIntAttr : Constraint()">>; - -def CreateReadResultCall : NativeCodeCall< - "createReadResultCall($_builder, $_loc, $0)">; - -// %1 = llvm.constant 1 -// %2 = llvm.inttoptr %1 : i64 -> Result* -// %3 = llvm.bitcast %2 : Result* -> i1* -// %4 = llvm.load %3 -// ───────────────────────────────────── -// %4 = call @read_result %2 -def LoadMeasureResult : Pat< - (LLVM_LoadOp:$load (LLVM_BitcastOp:$bitcast (LLVM_IntToPtrOp:$cast - (LLVM_ConstantOp $attr))), $_, $_, $_, $_, $_, $_), - (CreateReadResultCall $cast), - [(HasI1PtrType:$bitcast), (HasResultType:$cast), (IsaIntAttr:$attr)]>; - -#endif diff --git a/lib/Optimizer/CodeGen/CMakeLists.txt b/lib/Optimizer/CodeGen/CMakeLists.txt index 5c056e0e11..3739855b31 100644 --- a/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/lib/Optimizer/CodeGen/CMakeLists.txt @@ -37,7 +37,6 @@ add_cudaq_library(OptCodeGen CodeGenOpsIncGen CodeGenTypesIncGen OptCodeGenPassIncGen - OptPeepholeIncGen OptTransformsPassIncGen QuakeDialect diff --git a/lib/Optimizer/CodeGen/ConvertToQIR.cpp b/lib/Optimizer/CodeGen/ConvertToQIR.cpp index 738ee66ea1..1eaba931b3 100644 --- a/lib/Optimizer/CodeGen/ConvertToQIR.cpp +++ b/lib/Optimizer/CodeGen/ConvertToQIR.cpp @@ -45,6 +45,8 @@ namespace cudaq::opt { using namespace mlir; +#include "PeepholePatterns.inc" + /// Greedy pass to match subgraphs in the IR and replace them with codegen ops. /// This step makes converting a DAG of nodes in the conversion step simpler. static LogicalResult fuseSubgraphPatterns(MLIRContext *ctx, ModuleOp module) { diff --git a/lib/Optimizer/CodeGen/ConvertToQIRProfile.cpp b/lib/Optimizer/CodeGen/ConvertToQIRProfile.cpp index 58526dc692..f3aad7c60c 100644 --- a/lib/Optimizer/CodeGen/ConvertToQIRProfile.cpp +++ b/lib/Optimizer/CodeGen/ConvertToQIRProfile.cpp @@ -32,6 +32,8 @@ using namespace mlir; +#include "PeepholePatterns.inc" + /// For a call to `__quantum__rt__qubit_allocate_array`, get the number of /// qubits allocated. static std::size_t getNumQubits(LLVM::CallOp callOp) { diff --git a/lib/Optimizer/CodeGen/PeepholePatterns.inc b/lib/Optimizer/CodeGen/PeepholePatterns.inc new file mode 100644 index 0000000000..ad6cc64fe8 --- /dev/null +++ b/lib/Optimizer/CodeGen/PeepholePatterns.inc @@ -0,0 +1,238 @@ +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +namespace { + +//===----------------------------------------------------------------------===// + +// %1 = address_of @__quantum__qis__x__ctl +// %2 = call @invokewithControlBits %1, %ctrl, %targ +// ───────────────────────────────────────────────── +// %2 = call __quantum__qis__cnot %ctrl, %targ +struct XCtrlOneTargetToCNot : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + auto args = call.getOperands(); + if (!callToInvokeWithXCtrlOneTarget(*callee, args)) + return failure(); + auto *ctx = rewriter.getContext(); + auto funcSymbol = FlatSymbolRefAttr::get(ctx, cudaq::opt::QIRCnot); + rewriter.replaceOpWithNewOp( + call, TypeRange{}, funcSymbol, args.drop_front(2), + call.getFastmathFlagsAttr(), call.getBranchWeightsAttr()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// %4 = address_of @__quantum__cis__* +// ──────────────────────────────────────── +// %4 = address_of @__quantum__cis__*__body +struct AddrOfCisToBase : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::AddressOfOp addr, + PatternRewriter &rewriter) const override { + auto global = addr.getGlobalName(); + if (!needsToBeRenamed(global)) + return failure(); + rewriter.replaceOpWithNewOp(addr, addr.getType(), + global.str() + "__body"); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// This rule does not apply to measurements. +// +// %4 = call @__quantum__cis__* +// ────────────────────────────────── +// %4 = call @__quantum__cis__*__body +struct CalleeConv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + if (!needsToBeRenamed(*callee) || + callee->startswith(cudaq::opt::QIRMeasure)) + return failure(); + auto *ctx = rewriter.getContext(); + auto symbol = FlatSymbolRefAttr::get(ctx, callee->str() + "__body"); + rewriter.replaceOpWithNewOp( + call, TypeRange{}, symbol, call.getOperands(), + call.getFastmathFlagsAttr(), call.getBranchWeightsAttr()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// Manually erase dead calls to QIRArrayGetElementPtr1d. +struct EraseDeadArrayGEP : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + if (*callee != cudaq::opt::QIRArrayGetElementPtr1d) + return failure(); + if (!call->use_empty()) + return failure(); + rewriter.eraseOp(call); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// Replace the call with a dead op to DCE. +// +// %0 = call @allocate ... : ... -> T* +// ─────────────────────────────────── +// %0 = undef : T* +struct EraseArrayAlloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + if (*callee != cudaq::opt::QIRArrayQubitAllocateArray) + return failure(); + auto *ctx = rewriter.getContext(); + rewriter.replaceOpWithNewOp(call, + cudaq::opt::getArrayType(ctx)); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// Remove the release calls. This removes both array allocations as well as +// qubit singletons. +// +// call @release %5 : (!Qubit) -> () +// ───────────────────────────────── +// +struct EraseArrayRelease : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + if (*callee != cudaq::opt::QIRArrayQubitReleaseArray && + *callee != cudaq::opt::QIRArrayQubitReleaseQubit) + return failure(); + rewriter.eraseOp(call); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// %result = call @__quantum__qis__mz(%qbit) : (!Qubit) -> i1 +// ────────────────────────────────────────────────────────────── +// call @__quantum__qis__mz_body(%qbit, %result) : (Q*, R*) -> () +struct MeasureCallConv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + auto args = call.getOperands(); + if (*callee != cudaq::opt::QIRMeasure) + return failure(); + auto inttoptr = args[0].getDefiningOp(); + if (!inttoptr) + return failure(); + rewriter.replaceOp(call, + createMeasureCall(rewriter, call.getLoc(), call, args)); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// %result = call @__quantum__qis__mz__to__register(%qbit, i8) : (!Qubit) -> i1 +// ──────────────────────────────────────────────────────────────────────────── +// call @__quantum__qis__mz_body(%qbit, %result) : (Q*, R*) -> () +struct MeasureToRegisterCallConv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::CallOp call, + PatternRewriter &rewriter) const override { + auto callee = call.getCallee(); + if (!callee) + return failure(); + auto args = call.getOperands(); + if (*callee != cudaq::opt::QIRMeasureToRegister) + return failure(); + auto inttoptr = args[0].getDefiningOp(); + if (!inttoptr) + return failure(); + rewriter.replaceOp(call, + createMeasureCall(rewriter, call.getLoc(), call, args)); + return success(); + } +}; + +//===----------------------------------------------------------------------===// + +// %1 = llvm.constant 1 +// %2 = llvm.inttoptr %1 : i64 -> Result* +// %3 = llvm.bitcast %2 : Result* -> i1* +// %4 = llvm.load %3 +// ───────────────────────────────────── +// %4 = call @read_result %2 +struct LoadMeasureResult : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LLVM::LoadOp load, + PatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + auto bitcast = load.getAddr().getDefiningOp(); + if (!bitcast) + return failure(); + auto inttoptr = bitcast.getArg().getDefiningOp(); + if (!inttoptr) + return failure(); + auto conint = inttoptr.getArg().getDefiningOp(); + if (!conint) + return failure(); + if (bitcast.getType() != + cudaq::opt::factory::getPointerType(IntegerType::get(ctx, 1))) + return failure(); + if (inttoptr.getType() != cudaq::opt::getResultType(ctx)) + return failure(); + if (!isa(conint.getValue())) + return failure(); + + rewriter.replaceOp(load, createReadResultCall(rewriter, load.getLoc(), + inttoptr.getResult())); + return success(); + } +}; + +} // namespace diff --git a/lib/Optimizer/CodeGen/VerifyQIRProfile.cpp b/lib/Optimizer/CodeGen/VerifyQIRProfile.cpp index 02ccc932fa..6adbe833d2 100644 --- a/lib/Optimizer/CodeGen/VerifyQIRProfile.cpp +++ b/lib/Optimizer/CodeGen/VerifyQIRProfile.cpp @@ -9,7 +9,7 @@ #include "PassDetails.h" #include "cudaq/Optimizer/Builder/Intrinsics.h" #include "cudaq/Optimizer/CodeGen/Passes.h" -#include "cudaq/Optimizer/CodeGen/Peephole.h" +#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "cudaq/Todo.h" #include "nlohmann/json.hpp"