diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index 4eedf1c232..9126430098 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -19,8 +19,8 @@ #include "src/Support/SmallVectorHelper.hpp" #define DEBUG_TYPE "lowering-to-krnl" -#define DEBUG_FORCE_SHUFFLE_REDUCTION 1 /* hi alex, should be 0 in repo */ -#define REDUCTION_MULTIPLE_OF_VL_ONLY 0 /* 1: conservative, 0 new */ +#define DEBUG_FORCE_SHUFFLE_REDUCTION 0 /* should be 0 in repo */ +#define REDUCTION_MULTIPLE_OF_VL_ONLY 0 /* 0: new improved, 1: old, for debug */ using namespace mlir; @@ -282,7 +282,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, using MDBuilder = MultiDialectBuilder; // hi alex + AffineBuilder>; //===----------------------------------------------------------------------===// // Helper function to perform reduction when an entire tensor is reduced to a diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir index 4e18edd5ec..ec98797e6b 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir @@ -1074,65 +1074,28 @@ func.func private @test_reducemax_int_v13(%arg0 : tensor<128x256x768xi32>) -> te "func.return"(%0) : (tensor<*xi32>) -> () // mlir2FileCheck.py -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 + 1)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 + 2)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func private @test_reducemax_int_v13 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<128x256x768xi32>) -> memref<128x256xi32> { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-2147483648> : vector<4xi32> -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-2147483648> : vector<32xi32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<128x256xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 128, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ -// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<4x4xi32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]#1) -// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#1) -// CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#1) -// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: affine.for [[I_2_:%.+]] = 0 to 768 step 4 { -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[I_2_]]{{.}} : memref<128x256x768xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_2_]], [[I_2_]]{{.}} : memref<128x256x768xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_3_]], [[I_2_]]{{.}} : memref<128x256x768xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_4_]], [[I_2_]]{{.}} : memref<128x256x768xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xi32> -// CHECK-DAG: [[VAR_27_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xi32> -// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xi32> -// CHECK-DAG: [[VAR_29_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xi32> -// CHECK: vector.store [[VAR_26_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: vector.store [[VAR_27_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: vector.store [[VAR_28_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 128, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<1x32xi32> +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<128x256x768xi32>, vector<32xi32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> +// CHECK: [[VAR_8_:%.+]] = arith.maxsi [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<32xi32> +// CHECK: vector.store [[VAR_8_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xi32>, vector<4xi32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [0, 4, 1, 5] : vector<4xi32>, vector<4xi32> -// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [2, 6, 3, 7] : vector<4xi32>, vector<4xi32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = arith.maxsi [[VAR_10_]], [[VAR_9_]] : vector<4xi32> -// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xi32>, vector<4xi32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xi32>, vector<4xi32> -// CHECK: [[VAR_14_:%.+]] = arith.maxsi [[VAR_13_]], [[VAR_12_]] : vector<4xi32> -// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xi32>, vector<4xi32> -// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xi32>, vector<4xi32> -// CHECK: [[VAR_17_:%.+]] = arith.maxsi [[VAR_16_]], [[VAR_15_]] : vector<4xi32> -// CHECK: vector.store [[VAR_17_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<128x256xi32>, vector<4xi32> +// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> +// CHECK: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_1_MEM_1_]] : vector<32xi32> into i32 +// CHECK: krnl.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<128x256xi32> // CHECK: } // CHECK: return [[RES_]] : memref<128x256xi32> // CHECK: }