Skip to content

Commit

Permalink
Implement ModArithType for mod_arith dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Nov 12, 2024
1 parent 840f0f7 commit 07b4ac1
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 7 deletions.
27 changes: 27 additions & 0 deletions lib/Dialect/ModArith/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ cc_library(
hdrs = [
"ModArithDialect.h",
"ModArithOps.h",
"ModArithTypes.h",
],
deps = [
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
Expand All @@ -32,6 +34,7 @@ td_library(
srcs = [
"ModArithDialect.td",
"ModArithOps.td",
"ModArithTypes.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
Expand Down Expand Up @@ -67,6 +70,30 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "types_inc_gen",
tbl_outs = [
(
["-gen-typedef-decls"],
"ModArithTypes.h.inc",
),
(
["-gen-typedef-defs"],
"ModArithTypes.cpp.inc",
),
(
["-gen-typedef-doc"],
"ModArithTypes.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ModArithTypes.td",
deps = [
":dialect_inc_gen",
":td_files",
],
)

gentbl_cc_library(
name = "ops_inc_gen",
tbl_outs = [
Expand Down
48 changes: 41 additions & 7 deletions lib/Dialect/ModArith/IR/ModArithDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@

#include <cassert>

#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithDialect and
// ModArithOps
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithDialect,
// ModArithTypes, ModArithOps
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
// NOLINTEND(misc-include-cleaner)

// Generated definitions
#include "lib/Dialect/ModArith/IR/ModArithDialect.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/ModArith/IR/ModArithTypes.cpp.inc"

#define GET_OP_CLASSES
#include "lib/Dialect/ModArith/IR/ModArithOps.cpp.inc"

Expand All @@ -24,12 +30,32 @@ namespace heir {
namespace mod_arith {

void ModArithDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/ModArith/IR/ModArithTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "lib/Dialect/ModArith/IR/ModArithOps.cpp.inc"
>();
}

/// Ensures that the underlying integer type is wide enough for the coefficient
template <typename OpType>
LogicalResult verifyModArithOpType(OpType op) {
auto type =
llvm::cast<ModArithType>(getElementTypeOrSelf(op.getResult().getType()));
APInt modulus = type.getModulus().getValue();
unsigned bitWidth = modulus.getBitWidth();
unsigned modWidth = modulus.getActiveBits();
if (modWidth > bitWidth - 1)
return op.emitOpError()
<< "underlying type bitwidth must be 1 bit larger than "
<< "the modulus bitwidth, but got " << bitWidth
<< " while modulus requires width " << modWidth << ".";
return success();
}

/// Ensures that the underlying integer type is wide enough for the coefficient
template <typename OpType>
LogicalResult verifyModArithOpMod(OpType op, bool reduce = false) {
Expand Down Expand Up @@ -61,6 +87,14 @@ LogicalResult verifyModArithOpMod(OpType op, bool reduce = false) {
return success();
}

LogicalResult ModReduceOp::verify() {
return verifyModArithOpType<ModReduceOp>(*this);
}

LogicalResult ModAddOp::verify() {
return verifyModArithOpType<ModAddOp>(*this);
}

LogicalResult AddOp::verify() { return verifyModArithOpMod<AddOp>(*this); }

LogicalResult SubOp::verify() { return verifyModArithOpMod<SubOp>(*this); }
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def ModArith_Dialect : Dialect {
}];

let cppNamespace = "::mlir::heir::mod_arith";
let useDefaultTypePrinterParser = 1;

let dependentDialects = [
"arith::ArithDialect",
];
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_

include "lib/Dialect/ModArith/IR/ModArithDialect.td"
include "lib/Dialect/ModArith/IR/ModArithTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
Expand All @@ -15,6 +16,35 @@ class ModArith_Op<string mnemonic, list<Trait> traits = [Pure]> :
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

def ModArith_ModReduceOp : ModArith_Op<"mreduce", [Pure, ElementwiseMappable]> {
let summary = "reduce a signed integer to its congruence modulo equivalent";

let description = [{
}];

let arguments = (ins
SignlessIntegerLike:$input
);
let results = (outs ModArithLike:$output);
let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` type($input) `->` type($output)";
}

class ModArith_ModBinaryOp<string mnemonic, list<Trait> traits = []> :
ModArith_Op<mnemonic, traits # [SameOperandsAndResultType, Pure, ElementwiseMappable]>,
Arguments<(ins ModArithLike:$lhs, ModArithLike:$rhs)>,
Results<(outs ModArithLike:$output)> {
let hasVerifier = 1;
let assemblyFormat ="operands attr-dict `:` type($output)";
}

def ModArith_ModAddOp : ModArith_ModBinaryOp<"madd", [Commutative]> {
let summary = "modular addition operation";
let description = [{
Computes addition modulo a statically known modulus $q$.
}];
}

class ModArith_BinaryOp<string mnemonic, list<Trait> traits = []> :
ModArith_Op<mnemonic, traits # [SameOperandsAndResultType, Pure, ElementwiseMappable]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs, APIntAttr:$modulus)>,
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHTYPES_H_
#define LIB_DIALECT_MODARITH_IR_MODARITHTYPES_H_

#include "lib/Dialect/ModArith/IR/ModArithDialect.h"

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/ModArith/IR/ModArithTypes.h.inc"

#endif // LIB_DIALECT_MODARITH_IR_MODARITHTYPES_H_
48 changes: 48 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithTypes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_
#define LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_

include "lib/Dialect/ModArith/IR/ModArithDialect.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/AttrTypeBase.td"

class ModArith_Type<string name, string typeMnemonic>
: TypeDef<ModArith_Dialect, name> {
let mnemonic = typeMnemonic;
}

def ModArith_ModArith : ModArith_Type<"ModArith", "mod_arith"> {
let summary = "Integer type with modular arithmetic";
let description = [{
`modulus` is the modulus the arithmetic working with.

`modulus` should be specified as, for example, `65537 : i32`.
It is required that the underlying integer type should be larger than
the modulus for a few bits. This requirement eases the handling of
modulus of similar bitwidth as the underlying type.

For example, when `modulus == 2 ** 32 - 1`, the underlying type
for the modulus should be `i64`.

The integer type should typically be `i32`, `i64` or higher.

Note that when the integer type is not specified, `i64` is implicitly
specified.

Examples:
```
!Zp1 = !mod_arith.mod_arith<modulus = 7> // implicitly being i64
!Zp2 = !mod_arith.mod_arith<modulus = 65537 : i32>
!Zp3 = !mod_arith.mod_arith<modulus = 536903681 : i64>
```
}];
let parameters = (ins
"::mlir::IntegerAttr":$modulus
);
let assemblyFormat = "`<` struct(params) `>`";
}

def ModArithLike: TypeOrContainer<ModArith_ModArith, "mod_arith-like">;

#endif // LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_

0 comments on commit 07b4ac1

Please sign in to comment.