Skip to content

Commit

Permalink
bmad doesn't use marray.
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk committed Aug 10, 2022
1 parent 0283942 commit 42e2b17
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ struct joint_matrix<
"For the matrix_use::a case, matrix_layout::row_major must "
"be used for Bitwise MAD");
};
int32_t wi_marray[1];
int32_t data;
};

template <matrix_layout Layout>
Expand All @@ -181,7 +181,7 @@ struct joint_matrix<
"For the matrix_use::b case, matrix_layout::col_major must "
"be used for Bitwise MAD");
};
int32_t wi_marray[1];
int32_t data;
};
#undef __SYCL_JOINT_MATRIX_OVERLOAD

Expand Down Expand Up @@ -418,10 +418,10 @@ struct joint_matrix_load_impl<
matrix::precision::b1>::value) {
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
if constexpr (NumRows == 8 && NumCols == 128) {
__bmma_m8n8k128_ld_a_b1(res.wi_marray, tileptr, stride,
__bmma_m8n8k128_ld_a_b1(&res.data, tileptr, stride,
get_layout_id<Layout>());
} else if constexpr (NumRows == 128 && NumCols == 8) {
__bmma_m8n8k128_ld_b_b1(res.wi_marray, tileptr, stride,
__bmma_m8n8k128_ld_b_b1(&res.data, tileptr, stride,
get_layout_id<Layout>());
}
}
Expand Down Expand Up @@ -803,14 +803,14 @@ struct joint_matrix_bmad_impl<
sycl::bit_and<sycl::ext::oneapi::experimental::matrix::
precision::b1>>::value) {
__bmma_m8n8k128_mma_and_popc_b1(
reinterpret_cast<int32_t *>(&D.wi_marray), A.wi_marray, B.wi_marray,
reinterpret_cast<int32_t *>(&D.wi_marray), &A.data, &B.data,
reinterpret_cast<int32_t *>(&C.wi_marray), 1);
} else if constexpr (std::is_same<
BinaryOperation,
sycl::bit_xor<sycl::ext::oneapi::experimental::
matrix::precision::b1>>::value) {
__bmma_m8n8k128_mma_xor_popc_b1(
reinterpret_cast<int32_t *>(&D.wi_marray), A.wi_marray, B.wi_marray,
reinterpret_cast<int32_t *>(&D.wi_marray), &A.data, &B.data,
reinterpret_cast<int32_t *>(&C.wi_marray), 1);
}
return D;
Expand Down

0 comments on commit 42e2b17

Please sign in to comment.