Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
160c71e
Initial commit
hariharans29 Dec 1, 2025
a134ea0
More changes
hariharans29 Dec 1, 2025
a44f708
More changes
hariharans29 Dec 1, 2025
3a3ccf7
Fix builds
hariharans29 Dec 1, 2025
212dbf1
Fix builds 2
hariharans29 Dec 1, 2025
fceae09
Threaded
hariharans29 Dec 1, 2025
3793d70
Fix x64 builds
hariharans29 Dec 1, 2025
481a7f6
Experiment
hariharans29 Dec 2, 2025
8993a0a
Experiment revert
hariharans29 Dec 2, 2025
d765c1a
Refactor
hariharans29 Dec 2, 2025
a428d50
More changes
hariharans29 Dec 2, 2025
d53dd15
a
hariharans29 Dec 3, 2025
8f12c51
Try
hariharans29 Dec 3, 2025
67b6801
More changes
hariharans29 Dec 3, 2025
01b43fb
Relax padding
hariharans29 Dec 4, 2025
ea83394
Vanilla NEON Depthwise
hariharans29 Dec 4, 2025
dd94a3b
Fix indexing
hariharans29 Dec 4, 2025
ffd291a
Add benchmark
hariharans29 Dec 4, 2025
92fb604
Add lambda
hariharans29 Dec 4, 2025
d15bb93
Rework
hariharans29 Dec 4, 2025
119ec9a
Update onnxruntime/test/mlas/bench/bench_sconv.cpp
hariharans29 Dec 4, 2025
d0fc143
Fix
hariharans29 Dec 4, 2025
2820a84
Remove Winograd implementation
hariharans29 Dec 5, 2025
0ffb811
Update onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp
hariharans29 Dec 5, 2025
59e2b2d
Update onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp
hariharans29 Dec 5, 2025
e34c930
Update onnxruntime/core/mlas/lib/convolve.cpp
hariharans29 Dec 5, 2025
027e742
Update onnxruntime/core/mlas/inc/mlas.h
hariharans29 Dec 5, 2025
f93ed67
Update onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp
hariharans29 Dec 5, 2025
f5c1b81
Update onnxruntime/core/mlas/lib/convolve.cpp
hariharans29 Dec 5, 2025
f15e554
Benchmark updates
hariharans29 Dec 5, 2025
bb324b5
Merge remote-tracking branch 'origin/main' into hari/expt_conv
hariharans29 Dec 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
)

set(mlas_platform_preprocess_srcs
Expand Down Expand Up @@ -306,9 +307,9 @@ endfunction()

function (setup_arm_neon_nchwc)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/sconv.h
${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
${MLAS_SRC_DIR}/spool_kernel_neon.cpp
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp
)
list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
Expand Down Expand Up @@ -460,6 +461,7 @@ else()
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
)

# Conditionally add the SVE implementation if compiler supports it
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ enum MLAS_CONV_ALGORITHM {
MlasConvAlgorithmGemmDirect,
MlasConvAlgorithmExpandThenGemm,
MlasConvAlgorithmExpandThenGemmSegmented,
#if defined(MLAS_TARGET_WASM_SCALAR)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
MlasConvAlgorithmDepthwise,
#endif
};
Expand Down
136 changes: 129 additions & 7 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,90 @@ Return Value:
}
}

#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

void
MlasDepthwiseThreaded(
void* Context,
ptrdiff_t Index
)

/*++

Routine Description:

This routine is invoked from a worker thread to execute a segment of a
convolution operation.

If using this, the entire convolution operation is parallelized on the
(batch size * group count) parameter and this routine has logic to
perform a specific thread's shard of the entire Convolution operation.

Arguments:

Context - Supplies the pointer to the context for the threaded operation.

Index - Supplies the current index of the threaded operation.

Return Value:

None.

--*/

{

MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;

const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;

const size_t GroupCount = Parameters->GroupCount;
const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;

const size_t TargetThreadCount = WorkBlock->TargetThreadCount;

const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;

size_t BatchGroupStart;
size_t BatchGroupEnd;

if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
} else {
BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
}

const size_t FilterCount = Parameters->FilterCount;
const size_t OutputSize = Parameters->OutputSize;
const size_t K = Parameters->K;

const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
const size_t OutputGroupSize = FilterCount * OutputSize;
const size_t FilterGroupSize = FilterCount * K;

for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
size_t group = bg % GroupCount;

const float* input = WorkBlock->Input + bg * InputGroupSize;
const float* filter = WorkBlock->Filter + group * FilterGroupSize;
float* output = WorkBlock->Output + bg * OutputGroupSize;
const float* bias = WorkBlock->Bias;
if (bias != nullptr) {
bias += group * FilterCount;
}

float* WorkingBuffer = WorkBlock->WorkingBuffer;

MlasConvDepthwiseFloat_CHW(Parameters, input, filter, output, WorkingBuffer);
MlasActivation(Parameters->Activation, output, bias, FilterCount, OutputSize, OutputSize);
}
}

#endif

inline
bool
MlasConvTryMultithread(
Expand Down Expand Up @@ -985,7 +1069,7 @@ Return Value:
return;
}

#if defined(MLAS_TARGET_WASM_SCALAR)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

if (Algorithm == MlasConvAlgorithmDepthwise) {
// Fill the Working Buffer with Zero for use by the depthwise kernel.
Expand Down Expand Up @@ -1019,6 +1103,35 @@ Return Value:
return;
}


#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1) || (GroupCount > 1))) {
const size_t BatchGroupCount = BatchCount * GroupCount;

ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);

if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
}

MLAS_CONV_WORK_BLOCK WorkBlock;

WorkBlock.Parameters = Parameters;
WorkBlock.Input = Input;
WorkBlock.Filter = Filter;
WorkBlock.Bias = Bias;
WorkBlock.WorkingBuffer = WorkingBuffer;
WorkBlock.Output = Output;
WorkBlock.TargetThreadCount = TargetThreadCount;

MlasExecuteThreaded(MlasDepthwiseThreaded, &WorkBlock, TargetThreadCount, ThreadPool);

return;
}

#endif

//
// Iterate over each batch and group.
//
Expand Down Expand Up @@ -1082,7 +1195,7 @@ Return Value:
break;
}

#if defined(MLAS_TARGET_WASM_SCALAR)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

case MlasConvAlgorithmDepthwise:
{
Expand Down Expand Up @@ -1337,17 +1450,26 @@ Return Value:

} else {

#if defined(MLAS_TARGET_WASM_SCALAR)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)

// Scalar direct conv for depthwise convolution.
// Currently only support 3x3 kernel with padding <=1 and dilations = 1.
// Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
// Currently this only supports 3x3 kernels with padding <= 1 and dilations = 1;
// on ARM64 it is further restricted to strides = 1.
// TODO: support more general depthwise convolution.

// On ARM64, only support stride = 1 for depthwise conv.
#if defined(MLAS_TARGET_ARM64)
bool depthwise_conv_stride_support_check = Parameters->StrideShape[0] == 1 && Parameters->StrideShape[1] == 1;
#else
bool depthwise_conv_stride_support_check = true;
#endif

if (Dimensions == 2
&& Parameters->FilterCount == 1 && Parameters->InputChannels == 1
&& Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3
&& Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1
&& Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1
&& depthwise_conv_stride_support_check
&& Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) {

*WorkingBufferSize = Parameters->InputShape[1] + 2;
Expand Down Expand Up @@ -1411,8 +1533,8 @@ Return Value:

if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {

size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
Parameters->FilterCount * Parameters->OutputSize,
size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
Parameters->FilterCount * Parameters->OutputSize,
static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
TargetThreadCount = MaximumThreadCount;
if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1601,7 +1601,8 @@ MlasFp32FromBits(
#pragma warning(pop)
#endif

#if defined(MLAS_TARGET_WASM_SCALAR)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)


void
MLASCALL
Expand Down
Loading
Loading