diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 1ef0090049..19bbcd9d86 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1289,7 +1289,9 @@ template <> GenOpMix getGenOpMix(Type t, Operation *op) { return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2}, {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}}; + {GenericOps::FloorGop, 2}, + {GenericOps::EstimatedVectorRegisterPressure, + 4 /* Little parallelism in code. */}}; } template <> diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 5d31eb9b8b..30802ee4e2 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType, return 1; } // Gather operation statics - int64_t vectorizedOpNum, scalarOpNum; - double avgVL = VectorMachineSupport::getAvgArchVectorLength( - genOps, elementType, vectorizedOpNum, scalarOpNum); + int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure; + double avgVL = + VectorMachineSupport::getAvgArchVectorLength(genOps, elementType, + vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure); if (avgVL < 1.5) { LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with " << avgVL << " avg VL\n"); return 1; } - LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n"); + LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL + << ", vec op num " << vectorizedOpNum + << ", max reg pressure " + << estimatedMaxVectorRegisterPressure << "\n"); // Define a target max unroll as a function of register pressure. int64_t unrollVL; int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum(); - if (vectorizedOpNum >= vrNum / 2) + if (estimatedMaxVectorRegisterPressure >= vrNum) + unrollVL = 1; + else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum) unrollVL = 2; - else if (vectorizedOpNum >= vrNum / 4) + else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum) unrollVL = 4; else unrollVL = 8; diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 775ee0cc35..c51aab30b4 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5}, {GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2}, {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}}; + {GenericOps::FloorGop, 2}, + {GenericOps::EstimatedVectorRegisterPressure, + 8 /* Little parallelism in code. */}}; totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); @@ -83,7 +85,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, adjustX = create.math.add(roundX, zeroPoint); else adjustX = roundX; - // Saturate + // Saturate: use max into a min. Value saturateX = create.math.clip(adjustX, qMin, qMax); Value res = create.math.cast(quantizedElementType, saturateX); return res; diff --git a/src/Dialect/Mlir/VectorMachineSupport.cpp b/src/Dialect/Mlir/VectorMachineSupport.cpp index f5e6cf897e..14959470f8 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.cpp +++ b/src/Dialect/Mlir/VectorMachineSupport.cpp @@ -78,21 +78,31 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { } /*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps, - Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) { + Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum, + int64_t &maxVectorRegisterPressure) { int64_t size = genOps.size(); if (!hasSimd()) { - vectorizedOpNum = 0; + vectorizedOpNum = maxVectorRegisterPressure = 0; scalarOpNum = size; return 1; } int64_t totProcessedValues = 0.0; - vectorizedOpNum = 0; + vectorizedOpNum = maxVectorRegisterPressure = 0; scalarOpNum = 0; + bool hasRegisterPressure = false; + // Determine which operations support SIMD and accumulate their vector // lengths. for (auto pair : genOps) { GenericOps genOp = pair.first; int64_t num = pair.second; + // Handle other metrics first. + if (genOp == GenericOps::EstimatedVectorRegisterPressure) { + maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num); + hasRegisterPressure = true; + continue; + } + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t vl = getArchVectorLength(genOp, elementType); // If past last value, assume 1; otherwise use actual value. // Accumulate weighted scalar/vectorized num and vl length. @@ -107,6 +117,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { // Compute final values int64_t totNum = vectorizedOpNum + scalarOpNum; scalarOpNum = size - vectorizedOpNum; + if (!hasRegisterPressure) { + // Estimate default register pressure as one per 2 vector operation. + maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1); + } return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0; } @@ -115,13 +129,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { // ============================================================================= int64_t Z16VectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); - // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -137,10 +151,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // Supports only 32 and 64 bit Floats; There is support for extended too // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: /* Supported via compare and select */ case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: /* Use load integer & rounding modes*/ @@ -161,7 +175,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -190,13 +204,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // ============================================================================= int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -212,10 +227,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // Supports only 32 and 64 bit Floats; There is support for extended too // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: @@ -237,7 +252,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -276,13 +291,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // ============================================================================= int64_t NeonVectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -297,10 +313,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( if (isFloat) { // Supports only 32 and 64 bit Floats; if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: @@ -322,7 +338,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -370,10 +386,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) { for (auto pair : mix1) { GenericOps genOp = pair.first; int64_t num = pair.second; - if (u.find(genOp) != u.end()) - u[genOp] += num; // Has this op already, add to it. - else + if (u.find(genOp) != u.end()) { + // Merge the 2 operation counts/metrics. + if (genOp == GenericOps::EstimatedVectorRegisterPressure) { + // For register pressure, pick the max of both. + u[genOp] = std::max(u[genOp], num); + } else { + // For operation count, use the sum of both + u[genOp] += num; + } + } else { + // First time we have this. u[genOp] = num; + } } return u; } diff --git a/src/Dialect/Mlir/VectorMachineSupport.hpp b/src/Dialect/Mlir/VectorMachineSupport.hpp index bcd2ad1a88..0d1104bbad 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.hpp +++ b/src/Dialect/Mlir/VectorMachineSupport.hpp @@ -32,6 +32,10 @@ namespace onnx_mlir { // (e.g. all the compares). enum class GenericOps { + ///////////////////////////////////// + // Generic ops. + ///////////////////////////////////// + AbsGop, ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity. */ CeilDivGop, @@ -62,6 +66,17 @@ enum class GenericOps { TrigArcGop, /* Arc trigonometry ops: asin, acos, atan. */ TrigGop, /* Trigonometry ops: sin, cos, tan. */ TrigHyperbolicGop, /* Hyperbolic trig. */ + + LastGop, /* Marker of the last op. Used to delineate from other metrics. */ + + ///////////////////////////////////// + // Metrics others than operations. + ///////////////////////////////////// + + // Metric that provides an estimate of the maximum number of vector registers + // used in a kernel. If none is provided, we estimate the pressure based on + // the number of operations. + EstimatedVectorRegisterPressure, }; // Describe the mix of Generic operations in a given kernel. Each generic @@ -132,8 +147,12 @@ class VectorMachineSupport { // number of times that generic operation was found. Note that scalar // operation have a vector length of one in the weighted average as they still // contribute one result. + // Max vector register pressure is also reported, either from an explicit + // mention in the genOps, or estimated as one vector register per vector + // operation. static double getAvgArchVectorLength(GenOpMix &genOps, mlir::Type elementType, - int64_t &vectorizedOpNum, int64_t &scalarOpNum); + int64_t &vectorizedOpNum, int64_t &scalarOpNum, + int64_t &maxVectorRegisterPressure); protected: // Virtual functions that do the actual work. Called by the "get" functions. diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir index 61800a518a..12c4cbaa10 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -13,11 +13,11 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 @@ -98,36 +98,36 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: affine.store [[CST_4096_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ // CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> -// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<16xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<16xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> +// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<16xf32> +// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<16xf32> to vector<16xi8> +// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref // CHECK: } @@ -143,11 +143,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 @@ -241,39 +241,39 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: affine.store [[CST_4335_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4328){ +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4320){ // CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<8xf32> -// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<8xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<16xf32> +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_]], [[LOAD_RES_6_MEM_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_43_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = math.floor [[VAR_43_]] : vector<8xf32> -// CHECK: [[VAR_45_:%.+]] = arith.mulf [[VAR_44_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_46_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_45_]] : vector<8xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_46_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = math.floor [[VAR_43_]] : vector<16xf32> +// CHECK: [[VAR_45_:%.+]] = arith.mulf [[VAR_44_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_46_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_45_]] : vector<16xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_46_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_47_]], [[VAR_48_]], [[LOAD_RES_6_MEM_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_47_]], [[VAR_48_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.select [[VAR_50_]], [[VAR_49_]], [[VAR_42_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = vector.splat [[VAR_29_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_]], [[VAR_52_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.select [[VAR_50_]], [[VAR_49_]], [[VAR_42_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_]], [[VAR_52_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi8> +// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4328 to 4335){ +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4320 to 4335){ // CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> // CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_11_]] : f32 diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir index 2bcf1dba86..3d4983815e 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir @@ -18,11 +18,11 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 * -512 + 4096, 512)> // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> @@ -152,37 +152,37 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: affine.store [[CST_4096_]], [[RES_11_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4096){ // CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32>, vector<8xf32> -// CHECK-DAG: [[VAR_35_2_:%.+]] = vector.splat [[VAR_11_]] : vector<8xf32> -// CHECK: [[VAR_36_1_:%.+]] = arith.divf [[VAR_34_1_]], [[VAR_35_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = math.floor [[VAR_36_1_]] : vector<8xf32> -// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_1_]], [[VAR_37_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_2_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_2_]], [[VAR_40_1_]], [[VAR_37_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_44_1_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_37_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_41_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_51_2_:%.+]] = vector.splat [[VAR_29_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_2_]], [[VAR_51_2_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_25_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_35_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_36_1_:%.+]] = arith.divf [[VAR_34_1_]], [[VAR_35_2_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = math.floor [[VAR_36_1_]] : vector<16xf32> +// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_1_]], [[VAR_37_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_39_2_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_2_]], [[VAR_40_1_]], [[VAR_37_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> +// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_44_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_37_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_50_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_41_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_51_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_2_]], [[VAR_51_2_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<16xf32> to vector<16xi8> +// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_25_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<256x16xui8>, memref, memref // CHECK: } @@ -203,11 +203,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 * -542 + 4335, 542)> // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> // CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> @@ -337,40 +337,40 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: affine.store [[CST_4335_]], [[RES_11_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4328){ +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4320){ // CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<8xf32> -// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<8xf32> -// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<8xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> +// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi8> +// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4328 to 4335){ +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4320 to 4335){ // CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index // CHECK: [[VAR_35_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> // CHECK: [[VAR_36_3_:%.+]] = arith.divf [[VAR_35_1_1_]], [[VAR_11_]] : f32