From 97c770c095ca7ac46b1cd1caa070f494961e50c1 Mon Sep 17 00:00:00 2001 From: Jake Taylor Date: Fri, 10 Nov 2023 19:04:19 +0100 Subject: [PATCH] [HWLegalizeModules] Lower types-like packed array handling (#5355) (#6402) This PR refactors HWLegalizeModules to something more like existing lower types passes in order to support more complex patterns such as array concatenations. Several new tests are also added. Fix https://github.com/llvm/circt/issues/5355. --- .../SV/Transforms/HWLegalizeModules.cpp | 329 ++++++++++++++---- .../SV/hw-legalize-modules-packed-arrays.mlir | 121 ++++++- 2 files changed, 383 insertions(+), 67 deletions(-) diff --git a/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp b/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp index 025d866250ba..1ac6db5e3d51 100644 --- a/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp +++ b/lib/Dialect/SV/Transforms/HWLegalizeModules.cpp @@ -34,7 +34,13 @@ struct HWLegalizeModulesPass private: void processPostOrder(Block &block); - Operation *tryLoweringArrayGet(hw::ArrayGetOp getOp); + bool tryLoweringPackedArrayOp(Operation &op); + Value lowerLookupToCasez(Operation &op, Value input, Value index, + mlir::Type elementType, + SmallVector caseValues); + bool processUsers(Operation &op, Value value, ArrayRef mapping); + std::optional> + tryExtractIndexAndBitWidth(Value value); /// This is the current hw.module being processed. hw::HWModuleOp thisHWModule; @@ -52,52 +58,218 @@ struct HWLegalizeModulesPass }; } // end anonymous namespace -/// Try to lower a hw.array_get in module that doesn't support packed arrays. -/// This returns a replacement operation if lowering was successful, null -/// otherwise. -Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) { - SmallVector caseValues; - OpBuilder builder(&thisHWModule.getBodyBlock()->front()); - // If the operand is an array_create or aggregate constant, then we can lower - // this into a casez. - if (auto createOp = getOp.getInput().getDefiningOp()) - caseValues = SmallVector(llvm::reverse(createOp.getOperands())); - else if (auto aggregateConstant = - getOp.getInput().getDefiningOp()) { - for (auto elem : llvm::reverse(aggregateConstant.getFields())) { - if (auto intAttr = dyn_cast(elem)) - caseValues.push_back(builder.create( - aggregateConstant.getLoc(), intAttr)); - else - caseValues.push_back(builder.create( - aggregateConstant.getLoc(), getOp.getType(), - elem.cast())); - } - } else { - return nullptr; - } +bool HWLegalizeModulesPass::tryLoweringPackedArrayOp(Operation &op) { + return TypeSwitch(&op) + .Case([&](hw::AggregateConstantOp constOp) { + // Replace individual element uses (if any) with input fields. + SmallVector inputs; + OpBuilder builder(constOp); + for (auto field : llvm::reverse(constOp.getFields())) { + if (auto intAttr = dyn_cast(field)) + inputs.push_back( + builder.create(constOp.getLoc(), intAttr)); + else + inputs.push_back(builder.create( + constOp.getLoc(), constOp.getType(), field.cast())); + } + if (!processUsers(op, constOp.getResult(), inputs)) + return false; + + // Remove original op. + return true; + }) + .Case([&](hw::ArrayConcatOp concatOp) { + // Redirect individual element uses (if any) to the input arguments. + SmallVector> arrays; + for (auto array : llvm::reverse(concatOp.getInputs())) { + auto ty = hw::type_cast(array.getType()); + arrays.emplace_back(array, ty.getNumElements()); + } + for (auto *user : + llvm::make_early_inc_range(concatOp.getResult().getUsers())) { + if (TypeSwitch(user) + .Case([&](hw::ArrayGetOp getOp) { + if (auto indexAndBitWidth = + tryExtractIndexAndBitWidth(getOp.getIndex())) { + auto [indexValue, bitWidth] = *indexAndBitWidth; + // FIXME: More efficient search + for (const auto &[array, size] : arrays) { + if (indexValue >= size) { + indexValue -= size; + continue; + } + OpBuilder builder(getOp); + getOp.getInputMutable().set(array); + getOp.getIndexMutable().set( + builder.createOrFold( + getOp.getLoc(), APInt(bitWidth, indexValue))); + return true; + } + } + + return false; + }) + .Default([](auto op) { return false; })) + continue; + + op.emitError("unsupported packed array expression"); + signalPassFailure(); + } + + // Remove the original op. + return true; + }) + .Case([&](hw::ArrayCreateOp createOp) { + // Replace individual element uses (if any) with input arguments. + SmallVector inputs(llvm::reverse(createOp.getInputs())); + if (!processUsers(op, createOp.getResult(), inputs)) + return false; - // array_get(idx, array_create(a,b,c,d)) ==> casez(idx). - Value index = getOp.getIndex(); + // Remove original op. + return true; + }) + .Case([&](hw::ArrayGetOp getOp) { + // Skip index ops with constant index. + auto index = getOp.getIndex(); + if (auto *definingOp = index.getDefiningOp()) + if (isa(definingOp)) + return false; - // Create the wire for the result of the casez in the hw.module. - auto theWire = builder.create(getOp.getLoc(), getOp.getType(), + // Generate case value element lookups. + auto ty = hw::type_cast(getOp.getInput().getType()); + OpBuilder builder(getOp); + SmallVector caseValues; + for (size_t i = 0, e = ty.getNumElements(); i < e; i++) { + auto loc = op.getLoc(); + auto index = builder.createOrFold( + loc, APInt(llvm::Log2_64_Ceil(e), i)); + auto element = + builder.create(loc, getOp.getInput(), index); + caseValues.push_back(element); + } + + // Transform array index op into casez statement. + auto theWire = lowerLookupToCasez(op, getOp.getInput(), index, + ty.getElementType(), caseValues); + + // Emit the read from the wire, replace uses and clean up. + builder.setInsertionPoint(getOp); + auto readWire = + builder.create(getOp.getLoc(), theWire); + getOp.getResult().replaceAllUsesWith(readWire); + return true; + }) + .Case([&](sv::ArrayIndexInOutOp indexOp) { + // Skip index ops with constant index. + auto index = indexOp.getIndex(); + if (auto *definingOp = index.getDefiningOp()) + if (isa(definingOp)) + return false; + + // Skip index ops with unpacked arrays. + auto inout = indexOp.getInput().getType(); + if (hw::type_isa(inout.getElementType())) + return false; + + // Generate case value element lookups. + auto ty = hw::type_cast(inout.getElementType()); + OpBuilder builder(&op); + SmallVector caseValues; + for (size_t i = 0, e = ty.getNumElements(); i < e; i++) { + auto loc = op.getLoc(); + auto index = builder.createOrFold( + loc, APInt(llvm::Log2_64_Ceil(e), i)); + auto element = builder.create( + loc, indexOp.getInput(), index); + auto readElement = builder.create(loc, element); + caseValues.push_back(readElement); + } + + // Transform array index op into casez statement. + auto theWire = lowerLookupToCasez(op, indexOp.getInput(), index, + ty.getElementType(), caseValues); + + // Replace uses and clean up. + indexOp.getResult().replaceAllUsesWith(theWire); + return true; + }) + .Case([&](sv::PAssignOp assignOp) { + // Transform array assignment into individual assignments for each array + // element. + auto inout = assignOp.getDest().getType(); + auto ty = hw::type_dyn_cast(inout.getElementType()); + if (!ty) + return false; + + OpBuilder builder(assignOp); + for (size_t i = 0, e = ty.getNumElements(); i < e; i++) { + auto loc = op.getLoc(); + auto index = builder.createOrFold( + loc, APInt(llvm::Log2_64_Ceil(e), i)); + auto dstElement = builder.create( + loc, assignOp.getDest(), index); + auto srcElement = + builder.create(loc, assignOp.getSrc(), index); + builder.create(loc, dstElement, srcElement); + } + + // Remove original assignment. + return true; + }) + .Case([&](sv::RegOp regOp) { + // Transform array reg into individual regs for each array element. + auto ty = hw::type_dyn_cast(regOp.getElementType()); + if (!ty) + return false; + + OpBuilder builder(regOp); + auto name = StringAttr::get(regOp.getContext(), "name"); + SmallVector elements; + for (size_t i = 0, e = ty.getNumElements(); i < e; i++) { + auto loc = op.getLoc(); + auto element = builder.create(loc, ty.getElementType()); + if (auto nameAttr = regOp->getAttrOfType(name)) { + element.setNameAttr( + StringAttr::get(regOp.getContext(), nameAttr.getValue())); + } + elements.push_back(element); + } + + // Fix users to refer to individual element regs. + if (!processUsers(op, regOp.getResult(), elements)) + return false; + + // Remove original reg. + return true; + }) + .Default([&](auto op) { return false; }); +} + +Value HWLegalizeModulesPass::lowerLookupToCasez(Operation &op, Value input, + Value index, + mlir::Type elementType, + SmallVector caseValues) { + // Create the wire for the result of the casez in the + // hw.module. + OpBuilder builder(&op); + auto theWire = builder.create(op.getLoc(), elementType, builder.getStringAttr("casez_tmp")); - builder.setInsertionPoint(getOp); + builder.setInsertionPoint(&op); - auto loc = getOp.getInput().getDefiningOp()->getLoc(); - // A casez is a procedural operation, so if we're in a non-procedural region - // we need to inject an always_comb block. - if (!getOp->getParentOp()->hasTrait()) { + auto loc = input.getDefiningOp()->getLoc(); + // A casez is a procedural operation, so if we're in a + // non-procedural region we need to inject an always_comb + // block. + if (!op.getParentOp()->hasTrait()) { auto alwaysComb = builder.create(loc); builder.setInsertionPointToEnd(alwaysComb.getBodyBlock()); } - // If we are missing elements in the array (it is non-power of two), then - // add a default 'X' value. + // If we are missing elements in the array (it is non-power of + // two), then add a default 'X' value. if (1ULL << index.getType().getIntOrFloatBitWidth() != caseValues.size()) { - caseValues.push_back( - builder.create(getOp.getLoc(), getOp.getType())); + caseValues.push_back(builder.create( + op.getLoc(), op.getResult(0).getType())); } APInt caseValue(index.getType().getIntOrFloatBitWidth(), 0); @@ -107,9 +279,10 @@ Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) { builder.create( loc, CaseStmtType::CaseZStmt, index, caseValues.size(), [&](size_t caseIdx) -> std::unique_ptr { - // Use a default pattern for the last value, even if we are complete. - // This avoids tools thinking they need to insert a latch due to - // potentially incomplete case coverage. + // Use a default pattern for the last value, even if we + // are complete. This avoids tools thinking they need to + // insert a latch due to potentially incomplete case + // coverage. bool isDefault = caseIdx == caseValues.size() - 1; Value theValue = caseValues[caseIdx]; std::unique_ptr thePattern; @@ -123,12 +296,52 @@ Operation *HWLegalizeModulesPass::tryLoweringArrayGet(hw::ArrayGetOp getOp) { return thePattern; }); - // Ok, emit the read from the wire to get the value out. - builder.setInsertionPoint(getOp); - auto readWire = builder.create(getOp.getLoc(), theWire); - getOp.getResult().replaceAllUsesWith(readWire); - getOp->erase(); - return readWire; + return theWire; +} + +bool HWLegalizeModulesPass::processUsers(Operation &op, Value value, + ArrayRef mapping) { + for (auto *user : llvm::make_early_inc_range(value.getUsers())) { + if (TypeSwitch(user) + .Case([&](hw::ArrayGetOp getOp) { + if (auto indexAndBitWidth = + tryExtractIndexAndBitWidth(getOp.getIndex())) { + getOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]); + return true; + } + + return false; + }) + .Case([&](sv::ArrayIndexInOutOp indexOp) { + if (auto indexAndBitWidth = + tryExtractIndexAndBitWidth(indexOp.getIndex())) { + indexOp.replaceAllUsesWith(mapping[indexAndBitWidth->first]); + return true; + } + + return false; + }) + .Default([](auto op) { return false; })) { + user->erase(); + continue; + } + + user->emitError("unsupported packed array expression"); + signalPassFailure(); + return false; + } + + return true; +} + +std::optional> +HWLegalizeModulesPass::tryExtractIndexAndBitWidth(Value value) { + if (auto constantOp = dyn_cast(value.getDefiningOp())) { + auto index = constantOp.getValue(); + return std::make_optional( + std::make_pair(index.getZExtValue(), index.getBitWidth())); + } + return std::nullopt; } void HWLegalizeModulesPass::processPostOrder(Block &body) { @@ -154,26 +367,16 @@ void HWLegalizeModulesPass::processPostOrder(Block &body) { } if (options.disallowPackedArrays) { - // Try idioms for lowering array_get operations. - if (auto getOp = dyn_cast(op)) - if (auto *replacement = tryLoweringArrayGet(getOp)) { - it = Block::iterator(replacement); - anythingChanged = true; - continue; - } - - // If this is a dead array, then we can just delete it. This is - // probably left over from get/create lowering. - if (isa(op) && - op.use_empty()) { + // Try supported packed array op lowering. + if (tryLoweringPackedArrayOp(op)) { + it = --Block::iterator(op); op.erase(); + anythingChanged = true; continue; } - // Otherwise, if we aren't allowing multi-dimensional arrays, reject the - // IR as invalid. - // TODO: We should eventually implement a "lower types" like feature in - // this pass. + // Otherwise, if the IR produces a packed array and we aren't allowing + // multi-dimensional arrays, reject the IR as invalid. for (auto value : op.getResults()) { if (value.getType().isa()) { op.emitError("unsupported packed array expression"); diff --git a/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir b/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir index a97145803b51..8bd685a3f5cc 100644 --- a/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir +++ b/test/Dialect/SV/hw-legalize-modules-packed-arrays.mlir @@ -4,15 +4,13 @@ module attributes {circt.loweringOptions = "disallowPackedArrays"} { hw.module @reject_arrays(in %arg0: i8, in %arg1: i8, in %arg2: i8, in %arg3: i8, in %sel: i2, in %clock: i1, out a: !hw.array<4xi8>) { - // This needs full-on "legalize types" for the HW dialect. - %reg = sv.reg : !hw.inout> sv.alwaysff(posedge %clock) { - // expected-error @+1 {{unsupported packed array expression}} %0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i8 sv.passign %reg, %0 : !hw.array<4xi8> } + // This needs full-on "legalize types" for the HW dialect. // expected-error @+1 {{unsupported packed array expression}} %1 = sv.read_inout %reg : !hw.inout> hw.output %1 : !hw.array<4xi8> @@ -52,9 +50,9 @@ hw.module @array_create_get_comb(in %arg0: i8, in %arg1: i8, in %arg2: i8, in %a // CHECK-LABEL: hw.module @array_create_get_default hw.module @array_create_get_default(in %arg0: i8, in %arg1: i8, in %arg2: i8, in %arg3: i8, in %sel: i2) { - // CHECK: %casez_tmp = sv.reg : !hw.inout // CHECK: sv.initial { sv.initial { + // CHECK: %casez_tmp = sv.reg : !hw.inout // CHECK: %x_i8 = sv.constantX : i8 // CHECK: sv.case casez %sel : i2 // CHECK: case b00: { @@ -83,6 +81,42 @@ hw.module @array_create_get_default(in %arg0: i8, in %arg1: i8, in %arg2: i8, in } } +// CHECK-LABEL: hw.module @array_create_concat_get_default +hw.module @array_create_concat_get_default(in %arg0: i8, in %arg1: i8, in %arg2: i8, in %arg3: i8, + in %sel: i2) { + // CHECK: sv.initial { + sv.initial { + // CHECK: %casez_tmp = sv.reg : !hw.inout + // CHECK: %x_i8 = sv.constantX : i8 + // CHECK: sv.case casez %sel : i2 + // CHECK: case b00: { + // CHECK: sv.bpassign %casez_tmp, %arg0 : i8 + // CHECK: } + // CHECK: case b01: { + // CHECK: sv.bpassign %casez_tmp, %arg1 : i8 + // CHECK: } + // CHECK: case b10: { + // CHECK: sv.bpassign %casez_tmp, %arg2 : i8 + // CHECK: } + // CHECK: default: { + // CHECK: sv.bpassign %casez_tmp, %x_i8 : i8 + // CHECK: } + %one_array = hw.array_create %arg2 : i8 + %two_array = hw.array_create %arg1, %arg0 : i8 + %three_array = hw.array_concat %one_array, %two_array : !hw.array<1xi8>, !hw.array<2xi8> + + // CHECK: %0 = sv.read_inout %casez_tmp : !hw.inout + %2 = hw.array_get %three_array[%sel] : !hw.array<3xi8>, i2 + + // CHECK: %1 = comb.icmp eq %0, %arg2 : i8 + // CHECK: sv.if %1 { + %cond = comb.icmp eq %2, %arg2 : i8 + sv.if %cond { + sv.fatal 1 + } + } +} + // CHECK-LABEL: hw.module @array_constant_get_comb hw.module @array_constant_get_comb(in %sel: i2, out a: i8) { // CHECK: %casez_tmp = sv.reg : !hw.inout @@ -109,4 +143,83 @@ hw.module @array_constant_get_comb(in %sel: i2, out a: i8) { hw.output %1 : i8 } +// CHECK-LABEL: hw.module @array_reg_mux_2 +hw.module @array_reg_mux_2(in %clock: i1, in %arg0: i8, in %arg1: i8, in %sel: i1, out a: i8) { + // CHECK: %reg = sv.reg : !hw.inout + // CHECK: %reg_0 = sv.reg name "reg" : !hw.inout + %reg = sv.reg : !hw.inout> + // CHECK: sv.alwaysff(posedge %clock) { + sv.alwaysff(posedge %clock) { + // CHECK: sv.passign %reg, %arg1 : i8 + // CHECK: sv.passign %reg_0, %arg0 : i8 + %0 = hw.array_create %arg0, %arg1 : i8 + sv.passign %reg, %0 : !hw.array<2xi8> + // CHECK: } + } + + // CHECK: %0 = sv.read_inout %reg : !hw.inout + // CHECK: %1 = sv.read_inout %reg_0 : !hw.inout + // CHECK: %casez_tmp = sv.reg : !hw.inout + // CHECK: sv.alwayscomb { + // CHECK: sv.case casez %sel : i1 + // CHECK: case b0: { + // CHECK: sv.bpassign %casez_tmp, %0 : i8 + // CHECK: } + // CHECK: default: { + // CHECK: sv.bpassign %casez_tmp, %1 : i8 + // CHECK: } + // CHECK: } + %1 = sv.array_index_inout %reg[%sel] : !hw.inout>, i1 + // CHECK: %2 = sv.read_inout %casez_tmp : !hw.inout + %2 = sv.read_inout %1 : !hw.inout + // CHECK: hw.output %2 : i8 + hw.output %2 : i8 +} + +// CHECK-LABEL: hw.module @array_reg_mux_4 +hw.module @array_reg_mux_4(in %arg0: i8, in %arg1: i8, in %arg2: i8, + in %arg3: i8, in %sel: i2, in %clock: i1, + out a: i8) { + // CHECK: %reg = sv.reg : !hw.inout + // CHECK: %reg_0 = sv.reg name "reg" : !hw.inout + // CHECK: %reg_1 = sv.reg name "reg" : !hw.inout + // CHECK: %reg_2 = sv.reg name "reg" : !hw.inout + %reg = sv.reg : !hw.inout> + // CHECK: sv.alwaysff(posedge %clock) { + sv.alwaysff(posedge %clock) { + // CHECK: sv.passign %reg, %arg3 : i8 + // CHECK: sv.passign %reg_0, %arg2 : i8 + // CHECK: sv.passign %reg_1, %arg1 : i8 + // CHECK: sv.passign %reg_2, %arg0 : i8 + %0 = hw.array_create %arg0, %arg1, %arg2, %arg3 : i8 + sv.passign %reg, %0 : !hw.array<4xi8> + // CHECK: } + } + // CHECK: %0 = sv.read_inout %reg : !hw.inout + // CHECK: %1 = sv.read_inout %reg_0 : !hw.inout + // CHECK: %2 = sv.read_inout %reg_1 : !hw.inout + // CHECK: %3 = sv.read_inout %reg_2 : !hw.inout + // CHECK: %casez_tmp = sv.reg : !hw.inout + // CHECK: sv.alwayscomb { + // CHECK: sv.case casez %sel : i2 + // CHECK: case b00: { + // CHECK: sv.bpassign %casez_tmp, %0 : i8 + // CHECK: } + // CHECK: case b01: { + // CHECK: sv.bpassign %casez_tmp, %1 : i8 + // CHECK: } + // CHECK: case b10: { + // CHECK: sv.bpassign %casez_tmp, %2 : i8 + // CHECK: } + // CHECK: default: { + // CHECK: sv.bpassign %casez_tmp, %3 : i8 + // CHECK: } + // CHECK: } + %1 = sv.array_index_inout %reg[%sel] : !hw.inout>, i2 + // CHECK: %4 = sv.read_inout %casez_tmp : !hw.inout + %2 = sv.read_inout %1 : !hw.inout + // CHECK: hw.output %4 : i8 + hw.output %2 : i8 +} + } // end builtin.module