Skip to content

Commit 652d3ea

Browse files
authored
[SYCL][Matrix tests] Missing general double type case in initialization and ref compute (#12800)
1 parent a261ac1 commit 652d3ea

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

sycl/test-e2e/Matrix/common.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
6464
acc += make_fp32(va[i]) * make_fp32(vb[i]);
6565
else if constexpr (std::is_same_v<Ta, float> &&
6666
std::is_same_v<Tc, float> ||
67-
std::is_integral_v<Ta> && std::is_integral_v<Tc>)
67+
std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
68+
(std::is_same_v<Ta, double> &&
69+
std::is_same_v<Tc, double>))
6870
acc += va[i] * vb[i];
6971
else if constexpr (std::is_same_v<Ta, sycl::half> &&
7072
std::is_same_v<Tc, float>)
@@ -127,7 +129,8 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
127129

128130
for (unsigned int i = 0; i < rows; i++) {
129131
for (unsigned int j = 0; j < cols; j++) {
130-
if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float>) {
132+
if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float> ||
133+
std::is_same_v<T, double>) {
131134
src[i * cols + j] = T(fdistr(dev));
132135
} else if constexpr (std::is_same_v<T, int8_t> ||
133136
std::is_same_v<T, int32_t>) {

0 commit comments

Comments
 (0)