diff --git a/lib/Dialect/Verif/Transforms/LowerContracts.cpp b/lib/Dialect/Verif/Transforms/LowerContracts.cpp index 28246ec34942..8f975717cab0 100644 --- a/lib/Dialect/Verif/Transforms/LowerContracts.cpp +++ b/lib/Dialect/Verif/Transforms/LowerContracts.cpp @@ -54,13 +54,17 @@ void cloneFanIn(OpBuilder &builder, Operation *opToClone, IRMapping &mapping, return; seen.insert(opToClone); + // Ensure all operands have been mapped for (auto operand : opToClone->getOperands()) { if (mapping.contains(operand)) continue; + auto *definingOp = operand.getDefiningOp(); if (definingOp) { + // Recurse and clone defining op cloneFanIn(builder, definingOp, mapping, seen); } else { + // Create symbolic values for arguments auto sym = builder.create(operand.getLoc(), operand.getType()); mapping.map(operand, sym); @@ -68,6 +72,7 @@ void cloneFanIn(OpBuilder &builder, Operation *opToClone, IRMapping &mapping, } Operation *clonedOp; + // Replace ensure/require ops, otherwise clone if (isa(opToClone)) { clonedOp = replaceContractOp( builder, dyn_cast(*opToClone), mapping); @@ -77,49 +82,52 @@ void cloneFanIn(OpBuilder &builder, Operation *opToClone, IRMapping &mapping, } else { clonedOp = builder.clone(*opToClone, mapping); } + + // Add mappings for results for (auto [x, y] : llvm::zip(opToClone->getResults(), clonedOp->getResults())) { mapping.map(x, y); } } -SmallVector collectContracts(HWModuleOp hwModule) { - SmallVector contracts; - hwModule.walk([&](ContractOp op) { contracts.push_back(op); }); - return contracts; -} - LogicalResult runOnHWModule(HWModuleOp hwModule, ModuleOp mlirModule) { - OpBuilder mlirModuleBuilder(mlirModule); mlirModuleBuilder.setInsertionPointAfter(hwModule); - SmallVector contracts = collectContracts(hwModule); + // Collect contract ops + SmallVector contracts; + hwModule.walk([&](ContractOp op) { contracts.push_back(op); }); for (unsigned i = 0; i < contracts.size(); i++) { auto contract = contracts[i]; + // Create verif.formal op auto name = mlirModuleBuilder.getStringAttr(hwModule.getNameAttr().getValue() + "_CheckContract_" + std::to_string(i)); auto formalOp = mlirModuleBuilder.create( contract.getLoc(), name, mlirModuleBuilder.getDictionaryAttr({})); + // Fill in verif.formal body OpBuilder formalBuilder(formalOp); formalBuilder.createBlock(&formalOp.getBody()); IRMapping mapping; DenseSet seen; + + // Clone fan in cone for contract operands for (auto operand : contract.getOperands()) { auto *definingOp = operand.getDefiningOp(); cloneFanIn(formalBuilder, definingOp, mapping, seen); } + // Map results by looking up the input mappings for (auto [result, input] : llvm::zip(contract.getResults(), contract.getInputs())) { mapping.map(result, mapping.lookup(input)); } + // Clone body of the contract for (auto &op : contract.getBody().front().getOperations()) { cloneFanIn(formalBuilder, &op, mapping, seen); }