Skip to content

Commit 6634b2b

Browse files
authored
[CombToAIG] Add a lowering for Add/Sub (#7968)
This implements a pattern to lower AddOp to a naive ripple-carry adder. Pattern for sub is also added since we can easily compute subtraction from addition. This PR also adds a test-only option to `additional-legal-ops` to test complicated lowering pattern.
1 parent b7a0513 commit 6634b2b

File tree

5 files changed

+176
-2
lines changed

5 files changed

+176
-2
lines changed

include/circt/Conversion/CombToAIG.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#define CIRCT_CONVERSION_COMBTOAIG_H
1111

1212
#include "circt/Support/LLVM.h"
13+
#include "llvm/ADT/SmallVector.h"
1314
#include <memory>
15+
#include <string>
1416

1517
namespace circt {
1618

include/circt/Conversion/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,11 @@ def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> {
813813
"circt::comb::CombDialect",
814814
"circt::aig::AIGDialect",
815815
];
816+
817+
let options = [
818+
ListOption<"additionalLegalOps", "additional-legal-ops", "std::string",
819+
"Specify additional legal ops for testing">,
820+
];
816821
}
817822

818823
//===----------------------------------------------------------------------===//

integration_test/circt-synth/comb-lowering-lec.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,17 @@ hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i3
1313

1414
hw.output %0, %1, %2, %3 : i32, i32, i32, i32
1515
}
16+
17+
// RUN: circt-lec %t.mlir %s -c1=add -c2=add --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ADD
18+
// COMB_ADD: c1 == c2
19+
hw.module @add(in %arg0: i4, in %arg1: i4, in %arg2: i4, out add: i4) {
20+
%0 = comb.add %arg0, %arg1, %arg2 : i4
21+
hw.output %0 : i4
22+
}
23+
24+
// RUN: circt-lec %t.mlir %s -c1=sub -c2=sub --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_SUB
25+
// COMB_SUB: c1 == c2
26+
hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
27+
%0 = comb.sub %lhs, %rhs : i4
28+
hw.output %0 : i4
29+
}

lib/Conversion/CombToAIG/CombToAIG.cpp

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,39 @@ namespace circt {
2525
using namespace circt;
2626
using namespace comb;
2727

28+
//===----------------------------------------------------------------------===//
29+
// Utility Functions
30+
//===----------------------------------------------------------------------===//
31+
32+
// Extract individual bits from a value
33+
static SmallVector<Value> extractBits(ConversionPatternRewriter &rewriter,
34+
Value val) {
35+
assert(val.getType().isInteger() && "expected integer");
36+
auto width = val.getType().getIntOrFloatBitWidth();
37+
SmallVector<Value> bits;
38+
bits.reserve(width);
39+
40+
// Check if we can reuse concat operands
41+
if (auto concat = val.getDefiningOp<comb::ConcatOp>()) {
42+
if (concat.getNumOperands() == width &&
43+
llvm::all_of(concat.getOperandTypes(), [](Type type) {
44+
return type.getIntOrFloatBitWidth() == 1;
45+
})) {
46+
// Reverse the operands to match the bit order
47+
bits.append(std::make_reverse_iterator(concat.getOperands().end()),
48+
std::make_reverse_iterator(concat.getOperands().begin()));
49+
return bits;
50+
}
51+
}
52+
53+
// Extract individual bits
54+
for (int64_t i = 0; i < width; ++i)
55+
bits.push_back(
56+
rewriter.createOrFold<comb::ExtractOp>(val.getLoc(), val, i, 1));
57+
58+
return bits;
59+
}
60+
2861
//===----------------------------------------------------------------------===//
2962
// Conversion patterns
3063
//===----------------------------------------------------------------------===//
@@ -169,6 +202,87 @@ struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
169202
}
170203
};
171204

205+
struct CombAddOpConversion : OpConversionPattern<AddOp> {
206+
using OpConversionPattern<AddOp>::OpConversionPattern;
207+
LogicalResult
208+
matchAndRewrite(AddOp op, OpAdaptor adaptor,
209+
ConversionPatternRewriter &rewriter) const override {
210+
auto inputs = adaptor.getInputs();
211+
// Lower only when there are two inputs.
212+
// Variadic operands must be lowered in a different pattern.
213+
if (inputs.size() != 2)
214+
return failure();
215+
216+
auto width = op.getType().getIntOrFloatBitWidth();
217+
// Skip a zero width value.
218+
if (width == 0) {
219+
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
220+
return success();
221+
}
222+
223+
// Implement a naive Ripple-carry full adder.
224+
Value carry;
225+
226+
auto aBits = extractBits(rewriter, inputs[0]);
227+
auto bBits = extractBits(rewriter, inputs[1]);
228+
SmallVector<Value> results;
229+
results.resize(width);
230+
for (int64_t i = 0; i < width; ++i) {
231+
SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
232+
if (carry)
233+
xorOperands.push_back(carry);
234+
235+
// sum[i] = xor(carry[i-1], a[i], b[i])
236+
// NOTE: The result is stored in reverse order.
237+
results[width - i - 1] =
238+
rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);
239+
240+
// If this is the last bit, we are done.
241+
if (i == width - 1) {
242+
break;
243+
}
244+
245+
// carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
246+
Value nextCarry = rewriter.create<comb::AndOp>(
247+
op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
248+
if (!carry) {
249+
// This is the first bit, so the carry is the next carry.
250+
carry = nextCarry;
251+
continue;
252+
}
253+
254+
auto aXnorB = rewriter.create<comb::XorOp>(
255+
op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
256+
auto andOp = rewriter.create<comb::AndOp>(
257+
op.getLoc(), ValueRange{carry, aXnorB}, true);
258+
carry = rewriter.create<comb::OrOp>(op.getLoc(),
259+
ValueRange{andOp, nextCarry}, true);
260+
}
261+
262+
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
263+
return success();
264+
}
265+
};
266+
267+
struct CombSubOpConversion : OpConversionPattern<SubOp> {
268+
using OpConversionPattern<SubOp>::OpConversionPattern;
269+
LogicalResult
270+
matchAndRewrite(SubOp op, OpAdaptor adaptor,
271+
ConversionPatternRewriter &rewriter) const override {
272+
auto lhs = op.getLhs();
273+
auto rhs = op.getRhs();
274+
// Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
275+
// sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
276+
// => add(lhs, ~rhs, 1)
277+
auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
278+
/*invert=*/true);
279+
auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
280+
rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
281+
true);
282+
return success();
283+
}
284+
};
285+
172286
} // namespace
173287

174288
//===----------------------------------------------------------------------===//
@@ -179,6 +293,8 @@ namespace {
179293
struct ConvertCombToAIGPass
180294
: public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
181295
void runOnOperation() override;
296+
using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
297+
using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
182298
};
183299
} // namespace
184300

@@ -187,18 +303,26 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
187303
// Bitwise Logical Ops
188304
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
189305
CombMuxOpConversion,
306+
// Arithmetic Ops
307+
CombAddOpConversion, CombSubOpConversion,
190308
// Variadic ops that must be lowered to binary operations
191-
CombLowerVariadicOp<XorOp>>(patterns.getContext());
309+
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
310+
patterns.getContext());
192311
}
193312

194313
void ConvertCombToAIGPass::runOnOperation() {
195314
ConversionTarget target(getContext());
196315
target.addIllegalDialect<comb::CombDialect>();
197316
// Keep data movement operations like Extract, Concat and Replicate.
198317
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
199-
hw::BitcastOp>();
318+
hw::BitcastOp, hw::ConstantOp>();
200319
target.addLegalDialect<aig::AIGDialect>();
201320

321+
// This is a test only option to add logical ops.
322+
if (!additionalLegalOps.empty())
323+
for (const auto &opName : additionalLegalOps)
324+
target.addLegalOp(OperationName(opName, &getContext()));
325+
202326
RewritePatternSet patterns(&getContext());
203327
populateCombToAIGConversionPatterns(patterns);
204328

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux}))" | FileCheck %s
2+
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add}))" | FileCheck %s --check-prefix=ALLOW_ADD
3+
4+
5+
// CHECK-LABEL: @add
6+
hw.module @add(in %lhs: i2, in %rhs: i2, out out: i2) {
7+
// CHECK: %[[lhs0:.*]] = comb.extract %lhs from 0 : (i2) -> i1
8+
// CHECK-NEXT: %[[lhs1:.*]] = comb.extract %lhs from 1 : (i2) -> i1
9+
// CHECK-NEXT: %[[rhs0:.*]] = comb.extract %rhs from 0 : (i2) -> i1
10+
// CHECK-NEXT: %[[rhs1:.*]] = comb.extract %rhs from 1 : (i2) -> i1
11+
// CHECK-NEXT: %[[sum0:.*]] = comb.xor bin %[[lhs0]], %[[rhs0]] : i1
12+
// CHECK-NEXT: %[[carry0:.*]] = comb.and bin %[[lhs0]], %[[rhs0]] : i1
13+
// CHECK-NEXT: %[[sum1:.*]] = comb.xor bin %[[lhs1]], %[[rhs1]], %[[carry0]] : i1
14+
// CHECK-NEXT: %[[concat:.*]] = comb.concat %[[sum1]], %[[sum0]] : i1, i1
15+
// CHECK-NEXT: hw.output %[[concat]] : i2
16+
%0 = comb.add %lhs, %rhs : i2
17+
hw.output %0 : i2
18+
}
19+
20+
// CHECK-LABEL: @sub
21+
// ALLOW_ADD-LABEL: @sub
22+
// ALLOW_ADD-NEXT: %[[NOT_RHS:.+]] = aig.and_inv not %rhs
23+
// ALLOW_ADD-NEXT: %[[CONST:.+]] = hw.constant 1 : i4
24+
// ALLOW_ADD-NEXT: %[[ADD:.+]] = comb.add bin %lhs, %[[NOT_RHS]], %[[CONST]]
25+
// ALLOW_ADD-NEXT: hw.output %[[ADD]]
26+
hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) {
27+
%0 = comb.sub %lhs, %rhs : i4
28+
hw.output %0 : i4
29+
}

0 commit comments

Comments
 (0)