Skip to content

Commit

Permalink
Various --secretize related cleanup
Browse files Browse the repository at this point in the history
* secretize applies to all funcs unless specified
* move secretize tests to correct folder
* remove --secretize from --mlir-to-secret-arithmetic
* switch WrapGeneric to use  WalkPatternRewriteDriver
* update WrapGeneric documentation
* cleanup naive_matmul example with aliases
  • Loading branch information
AlexanderViand-Intel committed Dec 4, 2024
1 parent 7800de5 commit df53b0d
Show file tree
Hide file tree
Showing 42 changed files with 124 additions and 121 deletions.
4 changes: 2 additions & 2 deletions docs/content/en/docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ but eventually we would use an MLIR-based tool to convert an input language to
MLIR like in that file. The program is below:

```mlir
func.func @dot_product(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> i16 {
func.func @dot_product(%arg0: tensor<8xi16> {secret.secret}, %arg1: tensor<8xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
%0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) {
Expand Down Expand Up @@ -400,7 +400,7 @@ Which outputs

```bash
bazel run --noallow_analysis_cache_discard //tools:heir-opt -- \
--secretize=entry-function=box_blur --wrap-generic --canonicalize --cse --full-loop-unroll \
--secretize --wrap-generic --canonicalize --cse --full-loop-unroll \
--insert-rotate --cse --canonicalize --collapse-insertion-chains \
--canonicalize --cse /path/to/heir/tests/simd/box_blur_64x64.mlir
```
9 changes: 2 additions & 7 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
manager.addPass(createCSEPass());
}

void mlirToSecretArithmeticPipelineBuilder(
OpPassManager &pm, const MlirToSecretArithmeticPipelineOptions &options) {
// Secretize inputs
pm.addPass(createSecretize(SecretizeOptions{options.entryFunction}));
void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
pm.addPass(createWrapGeneric());
convertToDataObliviousPipelineBuilder(pm);
pm.addPass(createCanonicalizerPass());
Expand All @@ -91,9 +88,7 @@ void mlirToSecretArithmeticPipelineBuilder(
void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
const RLWEScheme scheme) {
MlirToSecretArithmeticPipelineOptions mlirToSecretArithmeticPipelineOpts{};
mlirToSecretArithmeticPipelineOpts.entryFunction = options.entryFunction;
mlirToSecretArithmeticPipelineBuilder(pm, mlirToSecretArithmeticPipelineOpts);
mlirToSecretArithmeticPipelineBuilder(pm);

// Prepare to lower to RLWE Scheme
pm.addPass(secret::createSecretDistributeGeneric());
Expand Down
10 changes: 1 addition & 9 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ enum RLWEScheme { ckksScheme, bgvScheme };

void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager);

struct MlirToSecretArithmeticPipelineOptions
: public PassPipelineOptions<MlirToSecretArithmeticPipelineOptions> {
PassOptions::Option<std::string> entryFunction{
*this, "entry-function", llvm::cl::desc("Entry function to secretize"),
llvm::cl::init("main")};
};

struct MlirToRLWEPipelineOptions
: public PassPipelineOptions<MlirToRLWEPipelineOptions> {
PassOptions::Option<std::string> entryFunction{
Expand All @@ -44,8 +37,7 @@ void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
RLWEScheme scheme);

void mlirToSecretArithmeticPipelineBuilder(
OpPassManager &pm, const MlirToSecretArithmeticPipelineOptions &options);
void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm);

RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme);

Expand Down
18 changes: 7 additions & 11 deletions lib/Transforms/Secretize/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ def Secretize : Pass<"secretize", "ModuleOp"> {
let summary = "Adds secret argument attributes to entry function";

let description = [{
Adds a secret.secret attribute argument to each argument in the entry
function of an MLIR module. By default, the function is `main`. This may be
overridden with the option -entry-function=top_level_func.
Helper pass that adds a secret.secret attribute argument to each function argument.
By default, the pass applies to all functions in the module.
This may be overridden with the option -function=func_name to apply to a single function only.
}];

let dependentDialects = [
Expand All @@ -18,21 +18,17 @@ def Secretize : Pass<"secretize", "ModuleOp"> {
];

let options = [
Option<"entryFunction", "entry-function", "std::string", "\"main\"", "entry function of the module">
Option<"function", "function", "std::string", "\"\"", "function to add secret annotations to">
];
}

def WrapGeneric : Pass<"wrap-generic", "ModuleOp"> {
let summary = "Wraps regions using secret args in secret.generic bodies";

let description = [{
This pass wraps function regions of `func.func` that use secret arguments in
`secret.generic` bodies.

Secret arguments are annotated using a `secret.secret` argument attribute.
This pass converts these to secret types and then inserts a `secret.generic`
body to hold the functions region. The output type is also converted to a
secret.
This pass converts functions (`func.func`) with `{secret.secret}` annotated arguments
to use `!secret.secret<...>` types and wraps the function body in a `secret.generic` region.
The output type is also converted to `!secret.secret<...>`.

Example input:
```mlir
Expand Down
31 changes: 20 additions & 11 deletions lib/Transforms/Secretize/Secretize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,29 @@ struct Secretize : impl::SecretizeBase<Secretize> {
ModuleOp module = getOperation();
OpBuilder builder(module);

auto mainFunction = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupSymbolIn(module, entryFunction));
if (!mainFunction) {
module.emitError("could not find entry point function");
signalPassFailure();
return;
}

auto secretArgAttr =
StringAttr::get(ctx, secret::SecretDialect::kArgSecretAttrName);
for (unsigned i = 0; i < mainFunction.getNumArguments(); i++) {
if (!isa<secret::SecretType>(mainFunction.getArgument(i).getType())) {
mainFunction.setArgAttr(i, secretArgAttr, UnitAttr::get(ctx));

auto setSecretAttr = [&](func::FuncOp func) {
for (unsigned i = 0; i < func.getNumArguments(); i++) {
if (!isa<secret::SecretType>(func.getArgument(i).getType())) {
func.setArgAttr(i, secretArgAttr, UnitAttr::get(ctx));
}
}
};

if (function.empty()) {
module.walk([&](func::FuncOp func) { setSecretAttr(func); });
} else {
auto mainFunction = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupSymbolIn(module, function));
if (!mainFunction) {
module.emitError("could not find function \"" + function +
"\" to secretize");
signalPassFailure();
return;
}
setSecretAttr(mainFunction);
}
}
};
Expand Down
7 changes: 3 additions & 4 deletions lib/Transforms/Secretize/WrapGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
#include "lib/Transforms/Secretize/Passes.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -92,8 +92,7 @@ struct WrapGeneric : impl::WrapGenericBase<WrapGeneric> {

mlir::RewritePatternSet patterns(context);
patterns.add<WrapWithGeneric>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};

Expand Down
9 changes: 5 additions & 4 deletions scripts/jupyter/Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,15 @@
"source": [
"### Abstracted Private Computation\n",
"\n",
"The first step is to mark which inputs in the IR should be treated as private, or `secret` data. HEIR has the `--secretize` pass, which marks all inputs to a given function with the secret annotation. You can use the pass flag `entry-function=$func` to select the function to secretize.\n",
"The first step is to mark which inputs in the IR should be treated as private, or `secret` data.\n",
"You can select the private argument(s) by adding the `{secret.secret}` annotation onto the function arguments.\n",
"\n",
"You can also manually select the private arguments by adding the `{secret.secret}` annotation onto the function arguments."
"HEIR also has the `--secretize` helper pass, which marks all function inputs with the secret annotation. You can use the pass flag `function=$func` to select the function to secretize."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -142,7 +143,7 @@
}
],
"source": [
"%%heir_opt --secretize=entry-function=cmux\n",
"%%heir_opt --secretize\n",
"\n",
"module {\n",
" func.func @cmux(%arg0: i16, %arg1: i16, %arg2: i1) -> i16 {\n",
Expand Down
9 changes: 0 additions & 9 deletions tests/Dialect/Secret/Transforms/secretize/missing.mlir

This file was deleted.

15 changes: 0 additions & 15 deletions tests/Dialect/Secret/Transforms/secretize/multiple.mlir

This file was deleted.

9 changes: 0 additions & 9 deletions tests/Dialect/Secret/Transforms/secretize/named.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion tests/Dialect/TensorExt/Transforms/simd_pack.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --secretize=entry-function=main --wrap-generic \
// RUN: heir-opt --secretize --wrap-generic \
// RUN: --align-tensor-sizes --canonicalize --cse --split-input-file %s | FileCheck %s

module {
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/openfhe/box_blur_64x64.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
func.func @box_blur(%arg0: tensor<4096xi16>) -> tensor<4096xi16> {
func.func @box_blur(%arg0: tensor<4096xi16> {secret.secret}) -> tensor<4096xi16> {
%c4096 = arith.constant 4096 : index
%c64 = arith.constant 64 : index
%0 = affine.for %x = 0 to 64 iter_args(%arg0_x = %arg0) -> (tensor<4096xi16>) {
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/openfhe/dot_product_8.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
func.func @dot_product(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> i16 {
func.func @dot_product(%arg0: tensor<8xi16> {secret.secret}, %arg1: tensor<8xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
%0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) {
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/openfhe/dot_product_8f.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
func.func @dot_product(%arg0: tensor<8xf16>, %arg1: tensor<8xf16>) -> f16 {
func.func @dot_product(%arg0: tensor<8xf16> {secret.secret}, %arg1: tensor<8xf16> {secret.secret}) -> f16 {
%c0 = arith.constant 0 : index
%c0_sf16 = arith.constant 0.1 : f16
%0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_sf16) -> (f16) {
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/openfhe/halevi_shoup_matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#map1 = affine_map<(d0, d1) -> (d0, d1)>

module {
func.func @matmul(%arg0 : tensor<1x16xf32>) -> tensor<1x16xf32> {
func.func @matmul(%arg0 : tensor<1x16xf32> {secret.secret}) -> tensor<1x16xf32> {
%cst = arith.constant dense<"0x5036CB3DED693F3E647E8E3F1C7029BFBEE1923F83F258BE77DDA1BFD7EDB93C309239C0842D833EB0C3163FD1DB2ABF5D971B3F1091853EED900CBE19464BBFE147C3BE584F72BF27116D3DAA05383F32109FBFC956703FADE2DCBEEE8F443E3EF9903F90A07CBF7803133FCDB2E83EDC29103FCC190A3FECD986BFAE4E853DE4A9393E3CB9E83EBA00FA3D0DBFF0BE79DF193F4E00073F29DA10BE1C9693BD12360BBFEEFCFBBD051EBDBEAE500E3F598190BF369C0D3F65AB74BF022E55BE47C021BEBD0E8D3EDDD4C93EBF82CB3F726237BFE1645E3FA4569FBDB0863BBFD15A1FC097BE063E94A451BFE4AA0B3F0C8726BEA2AD41BE1A15F63E1CDB073F40F376BF4D87BDBEA96AA03E8D9F843DBF6FFB3F9C8203BF24B92ABE7F70C5BE733F6C3F7CE7DA3E1F83C13F0284113EF339193F1B19E13CA1F5FA3E6E31C9BEA1078D3E5A0439BF1FD4A7BE3640A63F55B69FBF6D8B66BD4DF072BFC166A93E8D4BDF3FBF4AEABD9FD976BFB809763ECD4C10BFC88265BF6B2BABBEC202C5BE8D53EB3DBE94AABE3C3297BD75BF7FBE1EFEF83E1936893FAE8AA13EAF6CACBD615C2DC0473F593FE32D25BFDF71D0BD382D6CBE2DD392BD9C4FECB84BF853BED6E0493ECDCA91BE387D02BFF615053D7D4EB63F0042113E4C661B3FBF5F023F83868D3F497814BD911CAB3F4BC424BE7A020C3F6509A73E95E499BEEE54DB3EEFFC3CBF695FA93EA695923E1937CB3F553A37BF6EC745BF3DCF823EEC98153FC2D49C3F8523E03D07ED413D13E9853E016DA2BED73A6F3E2D0268BEEBC9613EBEB947BE870B93BE3402CB3EC68B41BE50F054BF161EB23E4FDC1F3EB562A1BDD9A115C0CCA8D6BDD65AFDBED757033F590DF63EC4280CBF8EA7EABE74C317BE4597B5BB576920BF6A4E0F3EEC66B9BE4072EE3F570AFE3D961D7C3E5FEE0ABD1EC09EBFEA77253E532AA23F15A656BE1923163DC24284BE374FFD3DB9F2A3BED185903E6294083F8D0700BFD998223FACE7A1BF79CC633C1039C43D8FB21CBF36CD1B3F99D43DBE85E451BF522F40BE8B94383D76727CBEADBB19BF755B6F3F9B9C1BBE4C08633D195E3E3E90944FBD511F9DBF8E44723F665C36BFCE54193F9ABE19C0ECACB53E88925B3EA19AA4BE1AD4EABEC5DE023FC759C83F37CE383FEB0713BDCBACC6BDCA2E0EBF28F39CBCE8A4C23F8D844E3F2791AF3DF31C79BE5129963F74A657BF3D09BD3DF1F9953EFDA50DBB79B7B8BE4A69D73EF01E2F3D23B418BFD8C9243F6CAA17BE21AE853F3C11793FBE51FABE47B452BE2A0A763E4CA4BDBF7AE739BE6C0AC33F0FDF2E3DBDA0BABC6E8E23BF37B836BF989532BE66736C3EABF141BE63FE853E26F7803F68822ABF7BFA0F3F34128B3EA3B655BFAB2F1C3ECC272F3F2B3A19BF198BD9BEA75C0DBF12C739BDE5F4E1BE1C591EBE"> : tensor<16x16xf32>
%cst_1 = arith.constant dense<[[-0.45141533, -0.0277900472, 0.311195374, 0.18254894, -0.258809537, 0.497506738, 0.00115649134, -0.194445714, 0.158549473, 0.000000e+00, 0.310650676, -0.214976981, -0.023661999, -0.392960966, 6.472870e-01, 0.831665277]]> : tensor<1x16xf32>
%2 = tensor.empty() : tensor<16x16xf32>
Expand Down
Loading

0 comments on commit df53b0d

Please sign in to comment.