From 881a5183cffd51b91d6e868ff776334a78347964 Mon Sep 17 00:00:00 2001
From: Dillon Sharlet
Date: Sun, 12 Oct 2025 17:38:17 -0700
Subject: [PATCH] Add pipelining to SME kernels

PiperOrigin-RevId: 818417179
---
 ynnpack/kernels/dot/arm64_sme.cc  | 115 +++++++++++++++++++++++++-----
 ynnpack/kernels/dot/arm64_sme2.cc |  85 ++++++++++++++++++----
 2 files changed, 171 insertions(+), 29 deletions(-)

diff --git a/ynnpack/kernels/dot/arm64_sme.cc b/ynnpack/kernels/dot/arm64_sme.cc
index c63298d2fc4..f032b3cdada 100644
--- a/ynnpack/kernels/dot/arm64_sme.cc
+++ b/ynnpack/kernels/dot/arm64_sme.cc
@@ -94,25 +94,88 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
       const void* B_k1 = B_k2;
       const void* A_k1 = A_k2;
       ptrdiff_t k1 = K1;
-      while (k1 > 0) {
-        auto a = svld1(m_mask_ab, reinterpret_cast(A_k1));
-        auto b_0 =
+      // Here we pipeline two values of k1 at a time. We prime the pump
+      // here outside the loop by loading the first set of inputs.
+      // Load k1 = 0
+      auto a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+      auto b0_0 =
+          svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 0);
+      auto b0_1 =
+          svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 1);
+      auto b0_2 =
+          svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 2);
+      auto b0_3 =
+          svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 3);
+      while (k1 > dot_factor * 2) {
+        // Load k1 % 2 = 1
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        auto a1 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        auto b1_0 =
             svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 0);
-        auto b_1 =
+        auto b1_1 =
             svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 1);
-        auto b_2 =
+        auto b1_2 =
             svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 2);
-        auto b_3 =
+        auto b1_3 =
             svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 3);
-        svmopa(m_mask_ab, n_mask_ab, a, b_0);
-        svmopa(m_mask_ab, n_mask_ab, a, b_1);
-        svmopa(m_mask_ab, n_mask_ab, a, b_2);
-        svmopa(m_mask_ab, n_mask_ab, a, b_3);
-        k1 -= dot_factor;
+        // Compute k1 % 2 = 0
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_0);
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_1);
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_2);
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_3);
+
+        // Load k1 % 2 = 0
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0_0 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 0);
+        b0_1 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 1);
+        b0_2 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 2);
+        b0_3 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 3);
+
+        // Compute k1 % 2 = 1
+        svmopa(m_mask_ab, n_mask_ab, a1, b1_0);
+        svmopa(m_mask_ab, n_mask_ab, a1, b1_1);
+        svmopa(m_mask_ab, n_mask_ab, a1, b1_2);
+        svmopa(m_mask_ab, n_mask_ab, a1, b1_3);
+
+        k1 -= dot_factor * 2;
+      }
+      if (k1 > dot_factor) {
+        // Compute k1 % 2 = 0
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_0);
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_1);
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_2);
+        svmopa(m_mask_ab, n_mask_ab, a0, b0_3);
+
+        // Load odd tail case, but into the k1 % 2 = 0 values, so we don't
+        // need a special case for odd tails.
         B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
         A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0_0 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 0);
+        b0_1 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 1);
+        b0_2 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 2);
+        b0_3 =
+            svld1_vnum(n_mask_ab, reinterpret_cast(B_k1), 3);
+
+        k1 -= dot_factor;
       }
+      // Compute k1 % 2 = 0, or an odd tail.
+      svmopa(m_mask_ab, n_mask_ab, a0, b0_0);
+      svmopa(m_mask_ab, n_mask_ab, a0, b0_1);
+      svmopa(m_mask_ab, n_mask_ab, a0, b0_2);
+      svmopa(m_mask_ab, n_mask_ab, a0, b0_3);
+

       k2 -= 1;
       B_k2 = offset_bytes(B_k2, B_stride_k2);
       A_k2 = offset_bytes(A_k2, A_stride_k2);
@@ -178,15 +241,35 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
       const void* B_k1 = B_k2;
       const void* A_k1 = A_k2;
       ptrdiff_t k1 = K1;
-      while (k1 > 0) {
-        auto a = svld1(m_mask_ab, reinterpret_cast(A_k1));
-        auto b = svld1(n_mask_ab, reinterpret_cast(B_k1));
-        svmopa(m_mask_ab, n_mask_ab, a, b);
+      auto a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+      auto b0 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+      while (k1 > dot_factor * 2) {
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        auto a1 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        auto b1 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+        svmopa(m_mask_ab, n_mask_ab, a0, b0);

-        k1 -= dot_factor;
         B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
         A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+        svmopa(m_mask_ab, n_mask_ab, a1, b1);
+
+        k1 -= dot_factor * 2;
+      }
+      if (k1 > dot_factor) {
+        svmopa(m_mask_ab, n_mask_ab, a0, b0);
+
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+
+        k1 -= dot_factor;
       }
+      svmopa(m_mask_ab, n_mask_ab, a0, b0);
+
       k2 -= 1;
       B_k2 = offset_bytes(B_k2, B_stride_k2);
       A_k2 = offset_bytes(A_k2, A_stride_k2);
diff --git a/ynnpack/kernels/dot/arm64_sme2.cc b/ynnpack/kernels/dot/arm64_sme2.cc
index 1a4c342dd89..d10cdf63e85 100644
--- a/ynnpack/kernels/dot/arm64_sme2.cc
+++ b/ynnpack/kernels/dot/arm64_sme2.cc
@@ -76,18 +76,58 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
       const void* B_k1 = B_k2;
       const void* A_k1 = A_k2;
       ptrdiff_t k1 = K1;
-      while (k1 > 0) {
-        auto a = svld1(m_mask_ab, reinterpret_cast(A_k1));
-        auto b = svld1_x4(n_count_ab, reinterpret_cast(B_k1));
-        svmopa(m_mask_ab, n_mask_ab, a, svget4(b, 0));
-        svmopa(m_mask_ab, n_mask_ab, a, svget4(b, 1));
-        svmopa(m_mask_ab, n_mask_ab, a, svget4(b, 2));
-        svmopa(m_mask_ab, n_mask_ab, a, svget4(b, 3));
+      // Here we pipeline two values of k1 at a time. We prime the pump
+      // here outside the loop by loading the first set of inputs.
+      // Load k1 = 0
+      auto a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+      auto b0 = svld1_x4(n_count_ab, reinterpret_cast(B_k1));
+      while (k1 > dot_factor * 2) {
+        // Load k1 % 2 = 1
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        auto a1 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        auto b1 = svld1_x4(n_count_ab, reinterpret_cast(B_k1));

-        k1 -= dot_factor;
+        // Compute k1 % 2 = 0
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 0));
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 1));
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 2));
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 3));
+
+        // Load k1 % 2 = 0
         B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
         A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0 = svld1_x4(n_count_ab, reinterpret_cast(B_k1));
+        svmopa(m_mask_ab, n_mask_ab, a1, svget4(b1, 0));
+        svmopa(m_mask_ab, n_mask_ab, a1, svget4(b1, 1));
+        svmopa(m_mask_ab, n_mask_ab, a1, svget4(b1, 2));
+        svmopa(m_mask_ab, n_mask_ab, a1, svget4(b1, 3));
+
+        k1 -= 2 * dot_factor;
+      }
+      if (k1 > dot_factor) {
+        // Compute k1 % 2 = 0
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 0));
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 1));
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 2));
+        svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 3));
+
+        // Load odd tail case, but into the k1 % 2 = 0 values, so we don't
+        // need a special case for odd tails.
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0 = svld1_x4(n_count_ab, reinterpret_cast(B_k1));
+
+        k1 -= dot_factor;
       }
+      // Compute k1 % 2 = 0, or an odd tail.
+      svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 0));
+      svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 1));
+      svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 2));
+      svmopa(m_mask_ab, n_mask_ab, a0, svget4(b0, 3));
+
       k2 -= 1;
       B_k2 = offset_bytes(B_k2, B_stride_k2);
       A_k2 = offset_bytes(A_k2, A_stride_k2);
@@ -154,15 +194,34 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
       const void* B_k1 = B_k2;
       const void* A_k1 = A_k2;
       ptrdiff_t k1 = K1;
-      while (k1 > 0) {
-        auto a = svld1(m_mask_ab, reinterpret_cast(A_k1));
-        auto b = svld1(n_mask_ab, reinterpret_cast(B_k1));
-        svmopa(m_mask_ab, n_mask_ab, a, b);
+      auto a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+      auto b0 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+      while (k1 > dot_factor * 2) {
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        auto a1 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        auto b1 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+        svmopa(m_mask_ab, n_mask_ab, a0, b0);

-        k1 -= dot_factor;
         B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
         A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+        svmopa(m_mask_ab, n_mask_ab, a1, b1);
+
+        k1 -= 2 * dot_factor;
+      }
+      if (k1 > dot_factor) {
+        svmopa(m_mask_ab, n_mask_ab, a0, b0);
+
+        B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
+        A_k1 = offset_bytes(A_k1, A_stride_m);
+        a0 = svld1(m_mask_ab, reinterpret_cast(A_k1));
+        b0 = svld1(n_mask_ab, reinterpret_cast(B_k1));
+
+        k1 -= dot_factor;
       }
+      svmopa(m_mask_ab, n_mask_ab, a0, b0);
       k2 -= 1;
       B_k2 = offset_bytes(B_k2, B_stride_k2);
       A_k2 = offset_bytes(A_k2, A_stride_k2);
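Note: all four loops above apply the same two-deep software pipeline: prime the first loads outside the loop, overlap the loads for the next k1 step with the svmopa compute for the current one, and fold any odd tail into the "k1 % 2 = 0" registers so the epilogue needs no special case. The scalar sketch below shows the same structure on a plain dot product; pipelined_dot and its signature are hypothetical illustrations, not part of this patch or of ynnpack.

#include <cstddef>

// Dot product of n elements (n >= 1), software-pipelined two deep. The loads
// for the next element are issued before the multiply-accumulate for the
// current one, mirroring the load/compute interleaving in the SME kernels.
float pipelined_dot(const float* a, const float* b, ptrdiff_t n) {
  float acc = 0.0f;
  // Prime the pump: load element 0 outside the loop.
  float a0 = *a;
  float b0 = *b;
  ptrdiff_t k = n;
  while (k > 2) {
    // Load k % 2 = 1.
    float a1 = *++a;
    float b1 = *++b;
    // Compute k % 2 = 0.
    acc += a0 * b0;
    // Load k % 2 = 0.
    a0 = *++a;
    b0 = *++b;
    // Compute k % 2 = 1.
    acc += a1 * b1;
    k -= 2;
  }
  if (k > 1) {
    // Compute the even element, then load the odd tail into the k % 2 = 0
    // slot so the final compute below needs no special case.
    acc += a0 * b0;
    a0 = *++a;
    b0 = *++b;
    k -= 1;
  }
  // Compute k % 2 = 0, or an odd tail.
  acc += a0 * b0;
  return acc;
}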