Skip to content

Commit

Permalink
[HWToSMT] ArrayCreateOp and ArrayGetOp support (#7666)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Nov 29, 2024
1 parent ad1f68e commit 9c80a00
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 2 deletions.
24 changes: 24 additions & 0 deletions integration_test/circt-lec/hw.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,27 @@ hw.module @onePlusTwoNonSSA(out out: i2) {
// RUN: circt-lec %s -c1=onePlusTwo -c2=onePlusTwoNonSSA --shared-libs=%libz3 | FileCheck %s --check-prefix=HW_MODULE_GRAPH
// HW_MODULE_GRAPH: c1 == c2

// array_create + array_get test
// RUN: circt-lec %s -c1=MultibitMux -c2=MultibitMux2 --shared-libs=%libz3 | FileCheck %s --check-prefix=ARRAY_GET
// ARRAY_GET: c1 == c2

hw.module @MultibitMux(in %a_0 : i1, in %a_1 : i1, in %sel : i1, out b : i1) {
%0 = hw.array_create %a_1, %a_0 : i1
%1 = hw.array_get %0[%sel] : !hw.array<2xi1>, i1
hw.output %1 : i1
}

hw.module @MultibitMux2(in %a_0 : i1, in %a_1 : i1, in %sel : i1, out b : i1) {
%0 = comb.mux bin %sel, %a_1, %a_0 : i1
hw.output %0 : i1
}

// array_get out-of-bounds must not be equivalent
// RUN: circt-lec %s -c1=ArrayOOB -c2=ArrayOOB --shared-libs=%libz3 | FileCheck %s --check-prefix=ARRAY_OOB
// ARRAY_OOB: c1 != c2

hw.module @ArrayOOB(in %a : !hw.array<3xi1>, out b : i1) {
%0 = hw.constant 3 : i2
%1 = hw.array_get %a[%0] : !hw.array<3xi1>, i2
hw.output %1 : i1
}
67 changes: 65 additions & 2 deletions lib/Conversion/HWToSMT/HWToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,61 @@ struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
}
};

/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
/// smt::ArrayStoreOp for each operand.
struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
using OpConversionPattern<ArrayCreateOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type arrTy = typeConverter->convertType(op.getType());
if (!arrTy)
return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");

unsigned width = adaptor.getInputs().size();

Value arr = rewriter.create<smt::DeclareFunOp>(loc, arrTy);
for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
Value idx = rewriter.create<smt::BVConstantOp>(loc, width - i - 1,
llvm::Log2_64_Ceil(width));
arr = rewriter.create<smt::ArrayStoreOp>(loc, arr, idx, el);
}

rewriter.replaceOp(op, arr);
return success();
}
};

/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
using OpConversionPattern<ArrayGetOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
unsigned numElements =
cast<hw::ArrayType>(op.getInput().getType()).getNumElements();

Type type = typeConverter->convertType(op.getType());
if (!type)
return rewriter.notifyMatchFailure(op.getLoc(),
"unsupported array element type");

Value oobVal = rewriter.create<smt::DeclareFunOp>(loc, type);
Value numElementsVal = rewriter.create<smt::BVConstantOp>(
loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
Value inBounds = rewriter.create<smt::BVCmpOp>(
loc, smt::BVCmpPredicate::ule, adaptor.getIndex(), numElementsVal);
Value indexed = rewriter.create<smt::ArraySelectOp>(loc, adaptor.getInput(),
adaptor.getIndex());
rewriter.replaceOpWithNewOp<smt::IteOp>(op, inBounds, indexed, oobVal);
return success();
}
};

/// Remove redundant (seq::FromClock and seq::ToClock) ops.
template <typename OpTy>
struct ReplaceWithInput : OpConversionPattern<OpTy> {
Expand Down Expand Up @@ -139,6 +194,14 @@ void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
return smt::BitVectorType::get(type.getContext(), 1);
});
converter.addConversion([&](ArrayType type) -> std::optional<Type> {
auto rangeType = converter.convertType(type.getElementType());
if (!rangeType)
return {};
auto domainType = smt::BitVectorType::get(
type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
return smt::ArrayType::get(type.getContext(), domainType, rangeType);
});

// Default target materialization to convert from illegal types to legal
// types, e.g., at the boundary of an inlined child block.
Expand Down Expand Up @@ -222,8 +285,8 @@ void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
ReplaceWithInput<seq::FromClockOp>>(converter,
patterns.getContext());
ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
ArrayGetOpConversion>(converter, patterns.getContext());
}

void ConvertHWToSMTPass::runOnOperation() {
Expand Down

0 comments on commit 9c80a00

Please sign in to comment.