diff --git a/include/dlaf/multiplication/general.h b/include/dlaf/multiplication/general.h index 3d99a8cdf2..c84b3decc2 100644 --- a/include/dlaf/multiplication/general.h +++ b/include/dlaf/multiplication/general.h @@ -82,8 +82,13 @@ void generalMatrix(comm::CommunicatorPipeline& row_ comm::CommunicatorPipeline& col_task_chain, const T alpha, MatrixRef& mat_a, MatrixRef& mat_b, const T beta, MatrixRef& mat_c) { - DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_a), mat_c, mat_b); - DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_b), mat_c, mat_b); + DLAF_ASSERT(equal_process_grid(mat_a, row_task_chain), mat_a, row_task_chain); + DLAF_ASSERT(equal_process_grid(mat_b, row_task_chain), mat_b, row_task_chain); + DLAF_ASSERT(equal_process_grid(mat_c, row_task_chain), mat_c, row_task_chain); + + DLAF_ASSERT(equal_process_grid(mat_a, col_task_chain), mat_a, col_task_chain); + DLAF_ASSERT(equal_process_grid(mat_b, col_task_chain), mat_b, col_task_chain); + DLAF_ASSERT(equal_process_grid(mat_c, col_task_chain), mat_c, col_task_chain); DLAF_ASSERT_HEAVY(matrix::multipliable(mat_a, mat_b, mat_c, blas::Op::NoTrans, blas::Op::NoTrans), mat_a, mat_b, mat_c);