Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 99 additions & 16 deletions ynnpack/kernels/dot/arm64_sme.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,88 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
const void* B_k1 = B_k2;
const void* A_k1 = A_k2;
ptrdiff_t k1 = K1;
while (k1 > 0) {
auto a = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b_0 =
// Here we pipeline two values of k1 at a time. We prime the pump
// here outside the loop by loading the first set of inputs.
// Load k1 = 0
auto a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b0_0 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 0);
auto b0_1 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 1);
auto b0_2 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 2);
auto b0_3 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 3);
while (k1 > dot_factor * 2) {
// Load k1 % 2 = 1
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
auto a1 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b1_0 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 0);
auto b_1 =
auto b1_1 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 1);
auto b_2 =
auto b1_2 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 2);
auto b_3 =
auto b1_3 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 3);
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a, b_0);
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a, b_1);
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a, b_2);
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a, b_3);

k1 -= dot_factor;
// Compute k1 % 2 = 0
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0_0);
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a0, b0_1);
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a0, b0_2);
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a0, b0_3);

// Load k1 % 2 = 0
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0_0 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 0);
b0_1 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 1);
b0_2 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 2);
b0_3 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 3);

// Compute k1 % 2 = 1
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a1, b1_0);
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a1, b1_1);
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a1, b1_2);
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a1, b1_3);

k1 -= dot_factor * 2;
}
if (k1 > dot_factor) {
// Compute k1 % 2 = 0
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0_0);
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a0, b0_1);
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a0, b0_2);
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a0, b0_3);

// Load odd tail case, but into the k1 % 2 = 0 values, so we don't
// need a special case for odd tails.
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0_0 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 0);
b0_1 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 1);
b0_2 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 2);
b0_3 =
svld1_vnum(n_mask_ab, reinterpret_cast<const TAB*>(B_k1), 3);

k1 -= dot_factor;
}
// Compute k1 % 2 = 0, or an odd tail.
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0_0);
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a0, b0_1);
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a0, b0_2);
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a0, b0_3);

k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
A_k2 = offset_bytes(A_k2, A_stride_k2);
Expand Down Expand Up @@ -178,15 +241,35 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
const void* B_k1 = B_k2;
const void* A_k1 = A_k2;
ptrdiff_t k1 = K1;
while (k1 > 0) {
auto a = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a, b);
auto a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b0 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
while (k1 > dot_factor * 2) {
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
auto a1 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b1 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0);

k1 -= dot_factor;
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a1, b1);

k1 -= dot_factor * 2;
}
if (k1 > dot_factor) {
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0);

B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));

k1 -= dot_factor;
}
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0);

k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
A_k2 = offset_bytes(A_k2, A_stride_k2);
Expand Down
85 changes: 72 additions & 13 deletions ynnpack/kernels/dot/arm64_sme2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,58 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
const void* B_k1 = B_k2;
const void* A_k1 = A_k2;
ptrdiff_t k1 = K1;
while (k1 > 0) {
auto a = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b = svld1_x4(n_count_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a, svget4(b, 0));
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a, svget4(b, 1));
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a, svget4(b, 2));
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a, svget4(b, 3));
// Here we pipeline two values of k1 at a time. We prime the pump
// here outside the loop by loading the first set of inputs.
// Load k1 = 0
auto a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b0 = svld1_x4(n_count_ab, reinterpret_cast<const TAB*>(B_k1));
while (k1 > dot_factor * 2) {
// Load k1 % 2 = 1
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
auto a1 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b1 = svld1_x4(n_count_ab, reinterpret_cast<const TAB*>(B_k1));

k1 -= dot_factor;
// Compute k1 % 2 = 0
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, svget4(b0, 0));
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a0, svget4(b0, 1));
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a0, svget4(b0, 2));
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a0, svget4(b0, 3));

// Load k1 % 2 = 0
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0 = svld1_x4(n_count_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a1, svget4(b1, 0));
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a1, svget4(b1, 1));
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a1, svget4(b1, 2));
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a1, svget4(b1, 3));

k1 -= 2 * dot_factor;
}
if (k1 > dot_factor) {
// Compute k1 % 2 = 0
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, svget4(b0, 0));
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a0, svget4(b0, 1));
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a0, svget4(b0, 2));
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a0, svget4(b0, 3));

// Load odd tail case, but into the k1 % 2 = 0 values, so we don't
// need a special case for odd tails.
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0 = svld1_x4(n_count_ab, reinterpret_cast<const TAB*>(B_k1));

k1 -= dot_factor;
}
// Compute k1 % 2 = 0, or an odd tail.
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, svget4(b0, 0));
svmopa</*tile=*/1>(m_mask_ab, n_mask_ab, a0, svget4(b0, 1));
svmopa</*tile=*/2>(m_mask_ab, n_mask_ab, a0, svget4(b0, 2));
svmopa</*tile=*/3>(m_mask_ab, n_mask_ab, a0, svget4(b0, 3));

k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
A_k2 = offset_bytes(A_k2, A_stride_k2);
Expand Down Expand Up @@ -154,15 +194,34 @@ __arm_new("za") __arm_locally_streaming void dot_impl(
const void* B_k1 = B_k2;
const void* A_k1 = A_k2;
ptrdiff_t k1 = K1;
while (k1 > 0) {
auto a = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a, b);
auto a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b0 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
while (k1 > dot_factor * 2) {
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
auto a1 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
auto b1 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0);

k1 -= dot_factor;
B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a1, b1);

k1 -= 2 * dot_factor;
}
if (k1 > dot_factor) {
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0);

B_k1 = offset_bytes(B_k1, B_stride_k1 * dot_factor);
A_k1 = offset_bytes(A_k1, A_stride_m);
a0 = svld1(m_mask_ab, reinterpret_cast<const TAB*>(A_k1));
b0 = svld1(n_mask_ab, reinterpret_cast<const TAB*>(B_k1));

k1 -= dot_factor;
}
svmopa</*tile=*/0>(m_mask_ab, n_mask_ab, a0, b0);
k2 -= 1;
B_k2 = offset_bytes(B_k2, B_stride_k2);
A_k2 = offset_bytes(A_k2, A_stride_k2);
Expand Down
Loading