fix: fix secret-distribute-generic with multiple ops and secret loop-carried vars

Re-runs secretness analysis after each split

If there is a secret loop-carried variable, the secretness analysis must be re-run after each split so that the secretness state for the loop-carried variable stays correct. Otherwise, after one operation is split out of the generic, the original generic is replaced and the secretness lattice no longer marks the loop-carried values (the values yielded from the loop) as secret.

PiperOrigin-RevId: 721009497
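
For illustration, a minimal C++ sketch of the re-run this commit adds. The only API it relies on is MLIR's DataFlowSolver::initializeAndRun, which the diff below calls directly at the end of matchAndRewrite; the free-standing helper and its name are assumptions made for the sketch, not HEIR's actual code.

// Sketch only: refresh the dataflow solver's lattice after a rewrite has
// replaced ops. Assumes the solver was already loaded with the secretness
// analysis (and its prerequisite analyses) when the pass was constructed.
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical helper; the diff below calls initializeAndRun directly.
mlir::LogicalResult rerunSecretnessAfterSplit(mlir::DataFlowSolver &solver,
                                              mlir::Operation *enclosingFunc) {
  // After one op is split out, the original secret.generic is replaced and
  // the values yielded by the loop are fresh SSA values with no lattice
  // state, so stale analysis results would treat the loop-carried variable
  // as non-secret. Re-running over the enclosing func recomputes them.
  return solver.initializeAndRun(enclosingFunc);
}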
asraa authored and copybara-github committed Jan 29, 2025
1 parent 8f0bff7 commit 595fee3
Showing 3 changed files with 53 additions and 8 deletions.
15 changes: 10 additions & 5 deletions lib/Dialect/Secret/Transforms/DistributeGeneric.cpp
@@ -529,6 +529,7 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
+    auto top = op->getParentOfType<func::FuncOp>();
     Block *body = op.getBody();
     unsigned numOps = body->getOperations().size();
     assert(numOps > 0 &&
@@ -569,15 +570,19 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
       LLVM_DEBUG(opToDistribute->emitRemark()
                  << "Distributing through region holding op isolated in its "
                     "own generic\n");
-      return distributeThroughRegionHoldingOp(op, *opToDistribute, rewriter);
-    }
-
-    if (first) {
+      LogicalResult result =
+          distributeThroughRegionHoldingOp(op, *opToDistribute, rewriter);
+      if (failed(result)) {
+        return failure();
+      }
+    } else if (first) {
       splitGenericAfterFirstOp(op, rewriter);
     } else {
       splitGenericBeforeOp(op, *opToDistribute, rewriter);
     }
-    return success();
+
+    // Rerun secretness analysis
+    return solver->initializeAndRun(top);
   }
 
  private:
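The hunk above uses a solver member and walks up to the enclosing func::FuncOp, but neither the member's origin nor the analyses loaded into the solver appear in this diff. A minimal sketch of how such a pattern is typically wired to a solver follows; the constructor shape, the loaded analyses, and the SecretnessAnalysis name are assumptions for illustration, not necessarily HEIR's exact setup.

// Sketch: the pattern holds a pointer to a DataFlowSolver (presumably owned
// by the enclosing pass) so that matchAndRewrite, shown in the hunk above,
// can end with solver->initializeAndRun(top).
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// secret::GenericOp is HEIR's secret.generic op; its header is omitted here.
struct SplitGenericSketch : public OpRewritePattern<secret::GenericOp> {
  SplitGenericSketch(DataFlowSolver *solver, MLIRContext *context)
      : OpRewritePattern<secret::GenericOp>(context), solver(solver) {}

  // Body as in the hunk above: split one op out, then re-run the solver.
  LogicalResult matchAndRewrite(secret::GenericOp op,
                                PatternRewriter &rewriter) const override;

  DataFlowSolver *solver;
};

// Pass-level setup (sketch): load and run the analyses once before applying
// the pattern. DeadCodeAnalysis and SparseConstantPropagation are the usual
// prerequisites for sparse forward analyses; the HEIR-specific secretness
// analysis is assumed and shown commented out.
void setUpSolver(DataFlowSolver &solver, func::FuncOp func) {
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::SparseConstantPropagation>();
  // solver.load<SecretnessAnalysis>();  // HEIR analysis (assumed name)
  (void)solver.initializeAndRun(func);
}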
@@ -0,0 +1,40 @@
// RUN: heir-opt --secret-distribute-generic %s | FileCheck %s

// This tests a generic with multiple operations in its body, including a loop
// with secret loop-carried variables. For this test, the secretness analysis
// must have been re-run after each operation is split to ensure that the
// secretness state for a loop-carried variable is correct.

// CHECK-LABEL: test
// CHECK-SAME: %[[arg0:.*]]: !secret.secret<tensor<1x1024xf32>>) -> !secret.secret<tensor<1x1024xf32>> {
module {
func.func @test(%arg0: !secret.secret<tensor<1x1024xf32>>) -> !secret.secret<tensor<1x1024xf32>> {
// CHECK-DAG: %[[cst:.*]] = arith.constant
// CHECK-DAG: %[[v0:.*]] = secret.conceal %[[cst]] : tensor<1x1024xf32>
// CHECK-NEXT: %[[v1:.*]] = affine.for %[[i:.*]] = 0 to 1023 iter_args(%[[arg2:.*]] = %[[v0]]) -> (!secret.secret<tensor<1x1024xf32>>) {
// CHECK-NEXT: %[[v3:.*]] = secret.generic ins(%[[arg0]], %[[arg2]] : !secret.secret<tensor<1x1024xf32>>, !secret.secret<tensor<1x1024xf32>>) {
// CHECK-NEXT: ^body(%[[input0:.*]]: tensor<1x1024xf32>, %[[input1:.*]]: tensor<1x1024xf32>):
// CHECK-NEXT: %[[v4:.*]] = arith.addf %[[input1]], %[[input0]] : tensor<1x1024xf32>
// CHECK-NEXT: secret.yield %[[v4]] : tensor<1x1024xf32>
// CHECK-NEXT: } -> !secret.secret<tensor<1x1024xf32>>
// CHECK-NEXT: affine.yield %[[v3]] : !secret.secret<tensor<1x1024xf32>>
// CHECK-NEXT: }
// CHECK-NEXT: %[[v2:.*]] = secret.generic ins(%[[v1]] : !secret.secret<tensor<1x1024xf32>>) {
// CHECK-NEXT: ^body(%[[input0:.*]]: tensor<1x1024xf32>):
// CHECK-NEXT: %[[v3:.*]] = arith.addf %[[input0]], %[[input0]] : tensor<1x1024xf32>
// CHECK-NEXT: secret.yield %[[v3]] : tensor<1x1024xf32>
// CHECK-NEXT: } -> !secret.secret<tensor<1x1024xf32>>
// CHECK-NEXT: return %[[v2]] : !secret.secret<tensor<1x1024xf32>>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<1x1024xf32>
%0 = secret.generic ins(%arg0 : !secret.secret<tensor<1x1024xf32>>) {
^body(%input0: tensor<1x1024xf32>):
%1 = affine.for %arg1 = 0 to 1023 iter_args(%arg2 = %cst_3) -> (tensor<1x1024xf32>) {
%9 = arith.addf %arg2, %input0 : tensor<1x1024xf32>
affine.yield %9 : tensor<1x1024xf32>
}
%3 = arith.addf %1, %1 : tensor<1x1024xf32>
secret.yield %3 : tensor<1x1024xf32>
} -> !secret.secret<tensor<1x1024xf32>>
return %0 : !secret.secret<tensor<1x1024xf32>>
}
}
6 changes: 3 additions & 3 deletions tests/Transforms/tosa_to_boolean_tfhe/hello_world_small.mlir
@@ -7,11 +7,11 @@ module attributes {tf_saved_model.semantics} {
 
   func.func @main(%arg0: tensor<1x1xi8> {iree.identifier = "serving_default_dense_input:0", tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi8> {iree.identifier = "StatefulPartitionedCall:0", tf_saved_model.index_path = ["dense_2"]}) attributes {tf_saved_model.exported_names = ["serving_default"]} {
     %0 = "tosa.const"() {value = dense<429> : tensor<1xi32>} : () -> tensor<1xi32>
-    %1 = "tosa.const"() {value = dense<[[-39, 59, 39]]> : tensor<1x3xi8>} : () -> tensor<1x3xi8>
-    %2 = "tosa.const"() {value = dense<[-729, 1954, 610]> : tensor<3xi32>} : () -> tensor<3xi32>
+    %1 = "tosa.const"() {value = dense<[[-39, 0, 0]]> : tensor<1x3xi8>} : () -> tensor<1x3xi8>
+    %2 = "tosa.const"() {value = dense<[0, 1954, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
     %3 = "tosa.const"() {value = dense<"0xF41AED091921F424E0"> : tensor<3x3xi8>} : () -> tensor<3x3xi8>
     %4 = "tosa.const"() {value = dense<[0, 0, -5438]> : tensor<3xi32>} : () -> tensor<3xi32>
-    %5 = "tosa.const"() {value = dense<[[-9], [-54], [57]]> : tensor<3x1xi8>} : () -> tensor<3x1xi8>
+    %5 = "tosa.const"() {value = dense<[[-9], [-54], [2]]> : tensor<3x1xi8>} : () -> tensor<3x1xi8>
     %6 = "tosa.fully_connected"(%arg0, %5, %4) {quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>} : (tensor<1x1xi8>, tensor<3x1xi8>, tensor<3xi32>) -> tensor<1x3xi32>
     %7 = "tosa.rescale"(%6) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2039655736>, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>} : (tensor<1x3xi32>) -> tensor<1x3xi8>
     %8 = "tosa.clamp"(%7) {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} : (tensor<1x3xi8>) -> tensor<1x3xi8>
