diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index c0ab948b41fff..11c32dac1df04 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -113,6 +113,7 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
     ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+    ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
  )

  set(mlas_platform_preprocess_srcs
@@ -306,9 +307,9 @@ endfunction()

function (setup_arm_neon_nchwc)
  target_sources(onnxruntime_mlas PRIVATE
-    ${MLAS_SRC_DIR}/sconv.h
-    ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
-    ${MLAS_SRC_DIR}/spool_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h
+    ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp
  )
  list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
  set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
@@ -460,6 +461,7 @@ else()
        ${MLAS_SRC_DIR}/eltwise_kernel_neon.h
        ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
        ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+        ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
      )

      # Conditionally add the SVE implementation if compiler supports it
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 248c6d74e6cbd..c0d9681153c25 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -840,7 +840,7 @@ enum MLAS_CONV_ALGORITHM {
    MlasConvAlgorithmGemmDirect,
    MlasConvAlgorithmExpandThenGemm,
    MlasConvAlgorithmExpandThenGemmSegmented,
-#if defined(MLAS_TARGET_WASM_SCALAR)
+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
    MlasConvAlgorithmDepthwise,
#endif
};
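On ARM64 the newly exposed `MlasConvAlgorithmDepthwise` branch is only selected for a narrow shape class. As a reading aid, the eligibility test added later in this patch (in `convolve.cpp`) can be summarized as the following predicate; this is a sketch with a hypothetical name that assumes the MLAS headers, not code from the patch:

```cpp
#include "mlasi.h"  // for MLAS_CONV_PARAMETERS (assumed available in-tree)

// Sketch only: mirrors the eligibility check added in convolve.cpp below.
bool IsDepthwise3x3Eligible(const MLAS_CONV_PARAMETERS* p, size_t dims) {
    bool stride_ok = true;
#if defined(MLAS_TARGET_ARM64)
    // ARM64 additionally requires unit strides (see the stride check below).
    stride_ok = p->StrideShape[0] == 1 && p->StrideShape[1] == 1;
#endif
    return dims == 2 && p->FilterCount == 1 && p->InputChannels == 1 &&
           p->KernelShape[0] == 3 && p->KernelShape[1] == 3 &&
           p->Padding[0] <= 1 && p->Padding[1] <= 1 &&
           p->Padding[2] <= 1 && p->Padding[3] <= 1 &&
           stride_ok &&
           p->DilationShape[0] == 1 && p->DilationShape[1] == 1;
}
```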
diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp
index 9518134631f2d..0349ce418d406 100644
--- a/onnxruntime/core/mlas/lib/convolve.cpp
+++ b/onnxruntime/core/mlas/lib/convolve.cpp
@@ -805,6 +805,90 @@ Return Value:
    }
}

+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
+
+void
+MlasDepthwiseThreaded(
+    void* Context,
+    ptrdiff_t Index
+)
+
+/*++
+
+Routine Description:
+
+    This routine is invoked from a worker thread to execute a segment of a
+    convolution operation.
+
+    When this path is taken, the entire convolution operation is parallelized
+    over (batch size * group count), and this routine computes one thread's
+    shard of that work.
+
+Arguments:
+
+    Context - Supplies the pointer to the context for the threaded operation.
+
+    Index - Supplies the current index of the threaded operation.
+
+Return Value:
+
+    None.
+
+--*/
+
+{
+    MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
+
+    const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
+
+    const size_t GroupCount = Parameters->GroupCount;
+    const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
+
+    const size_t TargetThreadCount = WorkBlock->TargetThreadCount;
+
+    const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
+    const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;
+
+    size_t BatchGroupStart;
+    size_t BatchGroupEnd;
+
+    if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
+        BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
+    } else {
+        BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
+    }
+
+    const size_t FilterCount = Parameters->FilterCount;
+    const size_t OutputSize = Parameters->OutputSize;
+    const size_t K = Parameters->K;
+
+    const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
+    const size_t OutputGroupSize = FilterCount * OutputSize;
+    const size_t FilterGroupSize = FilterCount * K;
+
+    for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
+        size_t group = bg % GroupCount;
+
+        const float* input = WorkBlock->Input + bg * InputGroupSize;
+        const float* filter = WorkBlock->Filter + group * FilterGroupSize;
+        float* output = WorkBlock->Output + bg * OutputGroupSize;
+        const float* bias = WorkBlock->Bias;
+        if (bias != nullptr) {
+            bias += group * FilterCount;
+        }
+
+        float* WorkingBuffer = WorkBlock->WorkingBuffer;
+
+        MlasConvDepthwiseFloat_CHW(Parameters, input, filter, output, WorkingBuffer);
+        MlasActivation(Parameters->Activation, output, bias, FilterCount, OutputSize, OutputSize);
+    }
+}
+
+#endif
+
inline
bool
MlasConvTryMultithread(
@@ -985,7 +1069,7 @@ Return Value:
        return;
    }

-#if defined(MLAS_TARGET_WASM_SCALAR)
+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

    if (Algorithm == MlasConvAlgorithmDepthwise) {

        // Fill the Working Buffer with Zero for use by the depthwise kernel.
@@ -1019,6 +1103,35 @@ Return Value:
        return;
    }

+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
+
+    if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1) || (GroupCount > 1))) {
+        const size_t BatchGroupCount = BatchCount * GroupCount;
+
+        ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+
+        if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+            TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
+        }
+
+        MLAS_CONV_WORK_BLOCK WorkBlock;
+
+        WorkBlock.Parameters = Parameters;
+        WorkBlock.Input = Input;
+        WorkBlock.Filter = Filter;
+        WorkBlock.Bias = Bias;
+        WorkBlock.WorkingBuffer = WorkingBuffer;
+        WorkBlock.Output = Output;
+        WorkBlock.TargetThreadCount = TargetThreadCount;
+
+        MlasExecuteThreaded(MlasDepthwiseThreaded, &WorkBlock, TargetThreadCount, ThreadPool);
+
+        return;
+    }
+
+#endif
+
    //
    // Iterate over each batch and group.
    //
@@ -1082,7 +1195,7 @@ Return Value:
            break;
        }

-#if defined(MLAS_TARGET_WASM_SCALAR)
+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

        case MlasConvAlgorithmDepthwise:
        {
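The shard arithmetic in `MlasDepthwiseThreaded` above distributes `BatchGroupCount` items over `TargetThreadCount` workers, giving each of the first `BatchGroupCountExtra` threads one extra item so the load stays balanced. A minimal standalone sketch of the same partition (hypothetical names, not part of the patch):

```cpp
#include <cstddef>
#include <cstdio>

// Sketch: same remainder-balancing partition as MlasDepthwiseThreaded.
void ShardRange(size_t items, size_t threads, size_t index,
                size_t* start, size_t* end) {
    const size_t per_thread = items / threads;
    const size_t extra = items % threads;
    if (index < extra) {
        // The first `extra` threads each take one additional item.
        *start = (per_thread + 1) * index;
        *end = *start + per_thread + 1;
    } else {
        *start = per_thread * index + extra;
        *end = *start + per_thread;
    }
}

int main() {
    // 10 batch*group items across 4 threads -> shards of sizes 3, 3, 2, 2.
    for (size_t i = 0; i < 4; i++) {
        size_t s, e;
        ShardRange(10, 4, i, &s, &e);
        std::printf("thread %zu: [%zu, %zu)\n", i, s, e);
    }
    return 0;
}
```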
@@ -1337,17 +1450,26 @@ Return Value:

    } else {

-#if defined(MLAS_TARGET_WASM_SCALAR)
+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

-        // Scalar direct conv for depthwise convolution.
-        // Currently only support 3x3 kernel with padding <=1 and dilations = 1.
+        // Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
+        // Currently only a 3x3 kernel with padding <= 1 and dilations = 1 is supported;
+        // on ARM64 this is further restricted to strides = 1.
        // TODO: support more general depthwise convolution.
+        #if defined(MLAS_TARGET_ARM64)
+        bool depthwise_conv_stride_support_check =
+            Parameters->StrideShape[0] == 1 && Parameters->StrideShape[1] == 1;
+        #else
+        bool depthwise_conv_stride_support_check = true;
+        #endif
+
        if (Dimensions == 2
            && Parameters->FilterCount == 1 && Parameters->InputChannels == 1
            && Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3
            && Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1
            && Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1
+            && depthwise_conv_stride_support_check
            && Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) {

            *WorkingBufferSize = Parameters->InputShape[1] + 2;
@@ -1411,8 +1533,8 @@ Return Value:

    if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {

-        size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
-                                                      Parameters->FilterCount * Parameters->OutputSize,
+        size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
+                                                      Parameters->FilterCount * Parameters->OutputSize,
                                                      static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
        TargetThreadCount = MaximumThreadCount;

        if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index ad62cccbfb9c7..386570454b2fd 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1601,7 +1601,8 @@ MlasFp32FromBits(
#pragma warning(pop)
#endif

-#if defined(MLAS_TARGET_WASM_SCALAR)
+#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
+
void
MLASCALL
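The widened guard in mlasi.h exposes the depthwise CHW kernel declaration on ARM64 as well. Judging from the `void MLASCALL` context lines in this hunk and the definition in the new file below, the declaration behind the guard is presumably:

```cpp
// Presumed prototype behind the guard above (MLASCALL taken from the
// hunk's context lines; the parameter list matches the definition in
// sconv_nchw_kernel_neon.cpp later in this patch).
void
MLASCALL
MlasConvDepthwiseFloat_CHW(
    const MLAS_CONV_PARAMETERS* Parameters,
    const float* Input,
    const float* Filter,
    float* Output,
    const float* Zeros
    );
```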
diff --git a/onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp
new file mode 100644
index 0000000000000..d3ec05ec92fac
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp
@@ -0,0 +1,297 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sconv_nchw_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the single precision NCHW convolution kernels for ARM NEON.
+
+--*/
+
+#include "mlasi.h"
+#include <arm_neon.h>
+
+MLAS_FORCEINLINE float DepthwiseSampleValue(
+    const float* row,
+    ptrdiff_t col,
+    size_t width
+)
+{
+    if (row == nullptr || col < 0 || col >= static_cast<ptrdiff_t>(width)) {
+        return 0.0f;
+    }
+    return row[col];
+}
+
+MLAS_FORCEINLINE float DepthwiseAccumulateRowScalar(
+    float acc,
+    const float* row,
+    size_t base,
+    float w0,
+    float w1,
+    float w2
+)
+{
+    if (row == nullptr) {
+        return acc;
+    }
+
+    acc += row[base] * w0;
+    acc += row[base + 1] * w1;
+    acc += row[base + 2] * w2;
+    return acc;
+}
+
+MLAS_FORCEINLINE void DepthwiseAccumulateRowVector(
+    float32x4_t& acc,
+    const float* row,
+    size_t base,
+    float w0,
+    float w1,
+    float w2
+)
+{
+    if (row == nullptr) {
+        return;
+    }
+
+    const float* r = row + base;
+    const float32x4_t c0 = vld1q_f32(r);
+    const float32x4_t c1 = vld1q_f32(r + 1);
+    const float32x4_t c2 = vld1q_f32(r + 2);
+
+    acc = vmlaq_n_f32(acc, c0, w0);
+    acc = vmlaq_n_f32(acc, c1, w1);
+    acc = vmlaq_n_f32(acc, c2, w2);
+}
+
+MLAS_FORCEINLINE float DepthwiseComputeEdge(
+    const float* row0,
+    const float* row1,
+    const float* row2,
+    ptrdiff_t iw,
+    size_t width,
+    const float w00,
+    const float w01,
+    const float w02,
+    const float w10,
+    const float w11,
+    const float w12,
+    const float w20,
+    const float w21,
+    const float w22
+)
+{
+    float acc = 0.0f;
+    const ptrdiff_t c0 = iw;
+    const ptrdiff_t c1 = iw + 1;
+    const ptrdiff_t c2 = iw + 2;
+
+    acc += DepthwiseSampleValue(row0, c0, width) * w00;
+    acc += DepthwiseSampleValue(row0, c1, width) * w01;
+    acc += DepthwiseSampleValue(row0, c2, width) * w02;
+    acc += DepthwiseSampleValue(row1, c0, width) * w10;
+    acc += DepthwiseSampleValue(row1, c1, width) * w11;
+    acc += DepthwiseSampleValue(row1, c2, width) * w12;
+    acc += DepthwiseSampleValue(row2, c0, width) * w20;
+    acc += DepthwiseSampleValue(row2, c1, width) * w21;
+    acc += DepthwiseSampleValue(row2, c2, width) * w22;
+
+    return acc;
+}
+
+static void DepthwiseConv3x3Stride1PadLe1Neon(
+    const MLAS_CONV_PARAMETERS* Parameters,
+    const float* Input,
+    const float* Filter,
+    float* Output
+)
+{
+    const size_t H = Parameters->InputShape[0];
+    const size_t W = Parameters->InputShape[1];
+    const size_t out_rows = Parameters->OutputShape[0];
+    const size_t out_cols = Parameters->OutputShape[1];
+
+    const size_t pad_top = Parameters->Padding[0];
+    const size_t pad_left = Parameters->Padding[1];
+    const size_t pad_right = Parameters->Padding[3];
+
+    const float beta = Parameters->Beta;
+    const bool accumulate_output = beta != 0.0f;
+
+    const float w00 = Filter[0];
+    const float w01 = Filter[1];
+    const float w02 = Filter[2];
+    const float w10 = Filter[3];
+    const float w11 = Filter[4];
+    const float w12 = Filter[5];
+    const float w20 = Filter[6];
+    const float w21 = Filter[7];
+    const float w22 = Filter[8];
+
+    for (size_t oh = 0; oh < out_rows; ++oh) {
+        const ptrdiff_t ih = static_cast<ptrdiff_t>(oh) - static_cast<ptrdiff_t>(pad_top);
+
+        const ptrdiff_t row0_index = ih;
+        const ptrdiff_t row1_index = ih + 1;
+        const ptrdiff_t row2_index = ih + 2;
+
+        const float* row0 = nullptr;
+        const float* row1 = nullptr;
+        const float* row2 = nullptr;
+
+        if (row0_index >= 0 && row0_index < static_cast<ptrdiff_t>(H)) {
+            row0 = Input + static_cast<size_t>(row0_index) * W;
+        }
+        if (row1_index >= 0 && row1_index < static_cast<ptrdiff_t>(H)) {
+            row1 = Input + static_cast<size_t>(row1_index) * W;
+        }
+        if (row2_index >= 0 && row2_index < static_cast<ptrdiff_t>(H)) {
+            row2 = Input + static_cast<size_t>(row2_index) * W;
+        }
+
+        float* out_row = Output + oh * out_cols;
+        size_t ow = 0;
+
+        if (pad_left && ow < out_cols) {
+            const ptrdiff_t iw = static_cast<ptrdiff_t>(ow) - static_cast<ptrdiff_t>(pad_left);
+            float acc = DepthwiseComputeEdge(
+                row0, row1, row2, iw, W,
+                w00, w01, w02, w10, w11, w12, w20, w21, w22
+            );
+            if (accumulate_output) {
+                acc += beta * out_row[ow];
+            }
+            out_row[ow++] = acc;
+        }
+
+        size_t interior_cols = 0;
+        if (out_cols > pad_left + pad_right) {
+            interior_cols = out_cols - pad_left - pad_right;
+        }
+
+        size_t processed = 0;
+        while (processed + 4 <= interior_cols) {
+            const ptrdiff_t iw = static_cast<ptrdiff_t>(ow) - static_cast<ptrdiff_t>(pad_left);
+            if ((iw + 5) >= static_cast<ptrdiff_t>(W)) {
+                break;
+            }
+
+            const size_t base = static_cast<size_t>(iw);
+            float32x4_t acc = vdupq_n_f32(0.0f);
+
+            DepthwiseAccumulateRowVector(acc, row0, base, w00, w01, w02);
+            DepthwiseAccumulateRowVector(acc, row1, base, w10, w11, w12);
+            DepthwiseAccumulateRowVector(acc, row2, base, w20, w21, w22);
+
+            if (accumulate_output) {
+                const float32x4_t prev = vld1q_f32(out_row + ow);
+                acc = vmlaq_n_f32(acc, prev, beta);
+            }
+
+            vst1q_f32(out_row + ow, acc);
+            ow += 4;
+            processed += 4;
+        }
+
+        for (; processed < interior_cols; ++processed) {
+            const ptrdiff_t iw = static_cast<ptrdiff_t>(ow) - static_cast<ptrdiff_t>(pad_left);
+            const size_t base = static_cast<size_t>(iw);
+
+            float acc = 0.0f;
+            acc = DepthwiseAccumulateRowScalar(acc, row0, base, w00, w01, w02);
+            acc = DepthwiseAccumulateRowScalar(acc, row1, base, w10, w11, w12);
+            acc = DepthwiseAccumulateRowScalar(acc, row2, base, w20, w21, w22);
+
+            if (accumulate_output) {
+                acc += beta * out_row[ow];
+            }
+            out_row[ow++] = acc;
+        }
+
+        if (pad_right && ow < out_cols) {
+            const ptrdiff_t iw = static_cast<ptrdiff_t>(ow) - static_cast<ptrdiff_t>(pad_left);
+            float acc = DepthwiseComputeEdge(
+                row0, row1, row2, iw, W,
+                w00, w01, w02, w10, w11, w12, w20, w21, w22
+            );
+            if (accumulate_output) {
+                acc += beta * out_row[ow];
+            }
+            out_row[ow++] = acc;
+        }
+    }
+}
+
+static
+void
+MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(
+    const MLAS_CONV_PARAMETERS* Parameters,
+    const float* Input,
+    const float* Filter,
+    float* Output
+    )
+/*++
+
+Routine Description:
+
+    This routine is an inner kernel that computes the convolution of one
+    input channel with one filter channel.
+
+Arguments:
+
+    Parameters - Conv parameters computed from the convolution attributes
+        (padding, strides, dilations, etc.).
+
+    Input - Start of the input channel data. The layout is NCHW, so this
+        points to a single H x W image.
+
+    Filter - The full filter tensor is F x CpG x FH x FW; this points to a
+        single FH x FW filter.
+
+    Output - The full output is N x F x OH x OW; this points to a single
+        OH x OW output image.
+
+--*/
+{
+    DepthwiseConv3x3Stride1PadLe1Neon(Parameters, Input, Filter, Output);
+}
+
+void MlasConvDepthwiseFloat_CHW(
+    const MLAS_CONV_PARAMETERS* Parameters,
+    const float* Input,
+    const float* Filter,
+    float* Output,
+    const float* Zeros
+    )
+/*++
+
+Routine Description:
+
+    This routine is an inner kernel that computes the depthwise convolution
+    of one input channel with one filter channel.
+
+Arguments:
+
+    Parameters - Conv parameters computed from the convolution attributes
+        (padding, strides, dilations, etc.).
+
+    Input - Start of the input channel data. The layout is NCHW, so this
+        points to a single H x W image.
+
+    Filter - The full filter tensor is F x CpG x FH x FW; this points to a
+        single FH x FW filter.
+
+    Output - The full output is N x F x OH x OW; this points to a single
+        OH x OW output image.
+
+    Zeros - Points to a working buffer pre-filled with 0.0f.
+
+Note:
+    No parameter checking is done here since this is an inner kernel; the
+    logic that generates Parameters enforces the constraints.
+
+    Currently only a 2D 3x3 kernel with strides = 1, dilations = 1, and
+    pads <= 1 is supported. More general and special cases may be added
+    later if needed.
+
+--*/
+{
+    MLAS_UNREFERENCED_PARAMETER(Zeros);
+    MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(Parameters, Input, Filter, Output);
+}
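As a correctness reference for the kernel above: a naive scalar 3x3, stride-1 depthwise convolution with zero padding is sketched below. The NEON path is expected to match it, since the vector loop accumulates the nine taps in the same order as the scalar remainder loop. Hypothetical standalone code, not part of the patch:

```cpp
#include <cstddef>
#include <vector>

// Naive reference: 3x3 depthwise conv, stride 1, dilation 1, zero padding.
// pad_top/pad_left <= 1, matching the restrictions stated in this patch.
std::vector<float> RefDepthwise3x3(const std::vector<float>& in, size_t H, size_t W,
                                   const float k[9], size_t pad_top, size_t pad_left,
                                   size_t OH, size_t OW) {
    std::vector<float> out(OH * OW, 0.0f);
    for (size_t oh = 0; oh < OH; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            float acc = 0.0f;
            for (size_t kh = 0; kh < 3; ++kh) {
                for (size_t kw = 0; kw < 3; ++kw) {
                    // Input coordinate; out-of-range samples contribute zero.
                    ptrdiff_t ih = (ptrdiff_t)(oh + kh) - (ptrdiff_t)pad_top;
                    ptrdiff_t iw = (ptrdiff_t)(ow + kw) - (ptrdiff_t)pad_left;
                    if (ih >= 0 && ih < (ptrdiff_t)H && iw >= 0 && iw < (ptrdiff_t)W) {
                        acc += in[(size_t)ih * W + (size_t)iw] * k[kh * 3 + kw];
                    }
                }
            }
            out[oh * OW + ow] = acc;
        }
    }
    return out;
}
```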
diff --git a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
similarity index 99%
rename from onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp
rename to onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
index 4c5f50adb929c..a387c6a07992a 100644
--- a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
@@ -6,18 +6,18 @@ Licensed under the MIT License.

Module Name:

-    sconv_kernel_neon.cpp
+    sconv_nchwc_kernel_neon.cpp

Abstract:

-    This module implements the single precision convolution kernels for ARM NEON.
+    This module implements the single precision NCHWC convolution kernels for ARM NEON.

--*/

#if defined(MLAS_USE_ARM_NEON_NCHWC)

#include "mlasi.h"
-#include "sconv.h"
+#include "sconv_nchwc_kernel_neon.h"

constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;
diff --git a/onnxruntime/core/mlas/lib/sconv.h b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h
similarity index 100%
rename from onnxruntime/core/mlas/lib/sconv.h
rename to onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h
diff --git a/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp b/onnxruntime/core/mlas/lib/spool_nchwc_kernel_neon.cpp
similarity index 100%
rename from onnxruntime/core/mlas/lib/spool_kernel_neon.cpp
rename to onnxruntime/core/mlas/lib/spool_nchwc_kernel_neon.cpp
diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp
index dc37980002978..849911e322214 100644
--- a/onnxruntime/test/mlas/bench/bench_sconv.cpp
+++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp
@@ -326,6 +326,10 @@ static void TeamsModel(benchmark::internal::Benchmark* b) {
  b->Args({2, 1, 1, 12, 12, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1});  // fused Conv_376 => 48x80
  b->Args({2, 1, 1, 12, 72, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1});  // Conv_59 => 24x40
+
+  b->Args({2, 1, 256, 1, 1, 378, 378, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1});  // External customer model
+  b->Args({2, 1, 512, 1, 1, 378, 378, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1});  // External customer model
+  b->Args({2, 1, 960, 1, 1, 378, 378, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1});  // External customer model
}

BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
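The three added benchmark shapes exercise the new ARM64 depthwise path (batch 2, 256/512/960 groups, single-channel 378x378 input, 3x3 kernel, pad 1, stride 1). To run just these cases in isolation, something like the following Google Benchmark driver should work; the target and registration names here are assumptions based on this file, not verified build targets:

```cpp
#include <benchmark/benchmark.h>

// Hypothetical driver: runs only the TeamsModel SCONV_NCHW cases,
// equivalent to passing --benchmark_filter on the command line.
int main() {
    int argc = 2;
    char arg0[] = "bench_sconv";
    char arg1[] = "--benchmark_filter=SCONV_NCHW/TeamsModel";
    char* argv[] = {arg0, arg1, nullptr};
    ::benchmark::Initialize(&argc, argv);
    ::benchmark::RunSpecifiedBenchmarks();
    return 0;
}
```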