@@ -64,7 +64,9 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
64
64
acc += make_fp32 (va[i]) * make_fp32 (vb[i]);
65
65
else if constexpr (std::is_same_v<Ta, float > &&
66
66
std::is_same_v<Tc, float > ||
67
- std::is_integral_v<Ta> && std::is_integral_v<Tc>)
67
+ std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
68
+ (std::is_same_v<Ta, double > &&
69
+ std::is_same_v<Tc, double >))
68
70
acc += va[i] * vb[i];
69
71
else if constexpr (std::is_same_v<Ta, sycl::half> &&
70
72
std::is_same_v<Tc, float >)
@@ -127,7 +129,8 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
127
129
128
130
for (unsigned int i = 0 ; i < rows; i++) {
129
131
for (unsigned int j = 0 ; j < cols; j++) {
130
- if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float >) {
132
+ if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float > ||
133
+ std::is_same_v<T, double >) {
131
134
src[i * cols + j] = T (fdistr (dev));
132
135
} else if constexpr (std::is_same_v<T, int8_t > ||
133
136
std::is_same_v<T, int32_t >) {
0 commit comments