diff --git a/include/dlaf/multiplication/general/impl.h b/include/dlaf/multiplication/general/impl.h index 2fc28c1df1..62d429ca19 100644 --- a/include/dlaf/multiplication/general/impl.h +++ b/include/dlaf/multiplication/general/impl.h @@ -11,6 +11,7 @@ #pragma once #include +#include #include #include #include @@ -37,6 +38,14 @@ void General::callNN(const T alpha, MatrixRef& mat_a, Matri DLAF_ASSERT_HEAVY(matrix::multipliable(mat_a, mat_b, mat_c, blas::Op::NoTrans, blas::Op::NoTrans), mat_a, mat_b, mat_c); + if (mat_a.nrTiles().cols() == 0) { + for (SizeType j = 0; j < mat_c.nrTiles().cols(); ++j) + for (SizeType i = 0; i < mat_c.nrTiles().rows(); ++i) + ex::start_detached(dlaf::internal::whenAllLift(beta, mat_c.readwrite(GlobalTileIndex(i, j))) | + tile::scal(dlaf::internal::Policy())); + return; + } + for (SizeType j = 0; j < mat_c.nrTiles().cols(); ++j) { for (SizeType i = 0; i < mat_c.nrTiles().rows(); ++i) { for (SizeType k = 0; k < mat_a.nrTiles().cols(); ++k) {