diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h
index eedcb6376b..265364e077 100644
--- a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h
+++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h
@@ -377,8 +377,8 @@ class CustomMmaMultistage : public CustomMmaBase {
     CUTLASS_PRAGMA_UNROLL
     for (int stage = 0; stage < kNumStagesConcurrentLoad;
          ++stage, --gemm_k_iterations) {
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       smem_iterator_A_.set_iteration_index(0);
@@ -559,8 +559,8 @@ class CustomMmaMultistage : public CustomMmaBase {
     ++this->warp_tile_iterator_A_;
     ++this->warp_tile_iterator_B_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     int smem_write_stage_idx = Base::kStages - 1;
     int smem_read_stage_idx = 0;
@@ -725,8 +725,8 @@ class CustomMmaMultistage : public CustomMmaBase {
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
index 485922ef2e..200101125c 100644
--- a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
+++ b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
@@ -363,9 +363,9 @@ class DualMmaMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B0.clear_mask(gemm_k_iterations == 0);
-      iterator_B1.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B0.clear_mask(gemm_k_iterations <= 0);
+      iterator_B1.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -555,9 +555,9 @@ class DualMmaMultistage :
     ++this->warp_tile_iterator_B0_;
     ++this->warp_tile_iterator_B1_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B0.clear_mask(gemm_k_iterations == 0);
-    iterator_B1.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B0.clear_mask(gemm_k_iterations <= 0);
+    iterator_B1.clear_mask(gemm_k_iterations <= 0);

     int smem_write_stage_idx = Base::kStages - 1;
     int smem_read_stage_idx = 0;
@@ -730,9 +730,9 @@ class DualMmaMultistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B0.clear_mask(gemm_k_iterations == 0);
-      iterator_B1.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B0.clear_mask(gemm_k_iterations <= 0);
+      iterator_B1.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/include/cutlass/gemm/threadblock/ell_mma_multistage.h
index 27f410ccd1..2da5511d3b 100644
--- a/include/cutlass/gemm/threadblock/ell_mma_multistage.h
+++ b/include/cutlass/gemm/threadblock/ell_mma_multistage.h
@@ -332,8 +332,8 @@ class EllMmaMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -456,8 +456,8 @@ class EllMmaMultistage :
     ++this->warp_tile_iterator_A_;
     ++this->warp_tile_iterator_B_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     if (is_A_sparse){
       iterator_A.ell_add_mask(ell_iterator.get_blocksize());
@@ -608,8 +608,8 @@ class EllMmaMultistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/gemm/threadblock/mma_blas3_multistage.h b/include/cutlass/gemm/threadblock/mma_blas3_multistage.h
index 11eb20adbb..94938eef90 100644
--- a/include/cutlass/gemm/threadblock/mma_blas3_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_blas3_multistage.h
@@ -339,8 +339,8 @@ class MmaBlas3Multistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -519,8 +519,8 @@ class MmaBlas3Multistage :
     ++this->warp_tile_iterator_A_;
     ++this->warp_tile_iterator_B_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     int smem_write_stage_idx = Base::kStages - 1;
     int smem_read_stage_idx = 0;
@@ -661,8 +661,8 @@ class MmaBlas3Multistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h b/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h
index 11ad544461..cbce0076b1 100644
--- a/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h
@@ -572,9 +572,9 @@ class MmaLayernormMainloopFusionMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -692,9 +692,9 @@ class MmaLayernormMainloopFusionMultistage :
     ++this->warp_tile_iterator_A_gamma_beta_;
     ++this->warp_tile_iterator_B_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     int smem_write_stage_idx = Base::kStages - 1;
     int smem_read_stage_idx = 0;
@@ -824,9 +824,9 @@ class MmaLayernormMainloopFusionMultistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h
index ef55131707..be2281f46e 100644
--- a/include/cutlass/gemm/threadblock/mma_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_multistage.h
@@ -370,8 +370,8 @@ class MmaMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

       // Disable global fetching if done with global fetch iterations
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -588,8 +588,8 @@ class MmaMultistage :

       // Disable global fetching when done with global fetch iterations
       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // The last warp-tile also converts the shared memory fragments used by
@@ -620,8 +620,8 @@ class MmaMultistage :
     PipeState pipe_state;

     // Disable global fetching if done with global fetch iterations
-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     // Load first warp-tile's A fragment from shared memory
     this->warp_tile_iterator_A_.set_kgroup_index(0);
diff --git a/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h b/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h
index b9deb6320e..0852e8dc89 100644
--- a/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h
@@ -374,10 +374,10 @@ class MmaPlanarComplexMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A_real.clear_mask(gemm_k_iterations == 0);
-      iterator_A_imag.clear_mask(gemm_k_iterations == 0);
-      iterator_B_real.clear_mask(gemm_k_iterations == 0);
-      iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+      iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+      iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+      iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+      iterator_B_imag.clear_mask(gemm_k_iterations <= 0);

       iterator_A_real.set_iteration_index(0);
       iterator_A_imag.set_iteration_index(0);
@@ -503,10 +503,10 @@ class MmaPlanarComplexMultistage :
     ++this->warp_tile_iterator_A_;
     ++this->warp_tile_iterator_B_;

-    iterator_A_real.clear_mask(gemm_k_iterations == 0);
-    iterator_A_imag.clear_mask(gemm_k_iterations == 0);
-    iterator_B_real.clear_mask(gemm_k_iterations == 0);
-    iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+    iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+    iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+    iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+    iterator_B_imag.clear_mask(gemm_k_iterations <= 0);

     // Start issuing the first group of the next stage outside of the mainloop
     copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
@@ -611,10 +611,10 @@ class MmaPlanarComplexMultistage :
       }

       --gemm_k_iterations;
-      iterator_A_real.clear_mask(gemm_k_iterations == 0);
-      iterator_A_imag.clear_mask(gemm_k_iterations == 0);
-      iterator_B_real.clear_mask(gemm_k_iterations == 0);
-      iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+      iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+      iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+      iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+      iterator_B_imag.clear_mask(gemm_k_iterations <= 0);
     }

     warp_mma_planar_complex(
diff --git a/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h b/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h
index bd793fc84f..4c862dd377 100644
--- a/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h
@@ -486,8 +486,8 @@ class MmaSoftmaxMainloopFusionMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -581,8 +581,8 @@ class MmaSoftmaxMainloopFusionMultistage :
     ++this->warp_tile_iterator_A_;
     ++this->warp_tile_iterator_B_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     // Start issuing the first group of the next stage outside of the mainloop
     copy_tiles_and_advance(iterator_A, iterator_B);
@@ -708,8 +708,8 @@ class MmaSoftmaxMainloopFusionMultistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h
index 8113583d69..30177e18ff 100644
--- a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h
@@ -381,9 +381,9 @@ class SparseMmaMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
-      iterator_E.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
+      iterator_E.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -499,9 +499,9 @@ class SparseMmaMultistage :
     ++this->warp_tile_iterator_B_;
     ++this->warp_tile_iterator_E_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
-    iterator_E.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);
+    iterator_E.clear_mask(gemm_k_iterations <= 0);

     int smem_write_stage_idx = Base::kStages - 1;
     int smem_read_stage_idx = 0;
@@ -634,9 +634,9 @@ class SparseMmaMultistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
-      iterator_E.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
+      iterator_E.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h
index fa95dd7d2a..6f2ca69a23 100644
--- a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h
+++ b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h
@@ -310,8 +310,8 @@ class MmaWithReductionMultistage :
     for (int stage = 0; stage < Base::kStages - 1;
          ++stage, --gemm_k_iterations) {

-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);

       iterator_A.set_iteration_index(0);
       this->smem_iterator_A_.set_iteration_index(0);
@@ -403,8 +403,8 @@ class MmaWithReductionMultistage :
     ++this->warp_tile_iterator_A_;
     ++this->warp_tile_iterator_B_;

-    iterator_A.clear_mask(gemm_k_iterations == 0);
-    iterator_B.clear_mask(gemm_k_iterations == 0);
+    iterator_A.clear_mask(gemm_k_iterations <= 0);
+    iterator_B.clear_mask(gemm_k_iterations <= 0);

     int smem_write_stage_idx = Base::kStages - 1;
     int smem_read_stage_idx = 0;
@@ -513,8 +513,8 @@ class MmaWithReductionMultistage :
       }

       --gemm_k_iterations;
-      iterator_A.clear_mask(gemm_k_iterations == 0);
-      iterator_B.clear_mask(gemm_k_iterations == 0);
+      iterator_A.clear_mask(gemm_k_iterations <= 0);
+      iterator_B.clear_mask(gemm_k_iterations <= 0);
     }

     // Do any conversions feeding the first stage at the end of the loop so
diff --git a/test/unit/gemm/device/testbed_grouped.h b/test/unit/gemm/device/testbed_grouped.h
index d8f5d43913..0b3c8c5b1f 100644
--- a/test/unit/gemm/device/testbed_grouped.h
+++ b/test/unit/gemm/device/testbed_grouped.h
@@ -216,6 +216,8 @@ struct TestbedGrouped {
     problem_sizes_host.resize(problem_count);

     for (int32_t i = 0; i < problem_count; ++i) {
+      // Make the last problem a special case where the inner dimension is zero.
+      bool zero_k = i == problem_count - 1;

       cutlass::gemm::GemmCoord problem(
         8 * (rand() % 64) + 24,
@@ -224,6 +226,8 @@ struct TestbedGrouped {

       if (!i) {
         problem = cutlass::gemm::GemmCoord(48, 16, 8);
+      } else if (zero_k) {
+        problem.k() = 0;
       }

       problem_sizes_host.at(i) = problem;
@@ -281,8 +285,16 @@ struct TestbedGrouped {
     std::vector ptr_D_host(problem_count);

     for (int32_t i = 0; i < problem_count; ++i) {
-      ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
-      ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
+      bool zero_k = i == problem_count - 1;
+      if (zero_k) {
+        // For k=0, the input matrices have no elements and should not be accessed.
+        // Set the input pointers to nullptr to catch any unintended accesses.
+        ptr_A_host.at(i) = nullptr;
+        ptr_B_host.at(i) = nullptr;
+      } else {
+        ptr_A_host.at(i) = block_A.get() + offset_A.at(i);
+        ptr_B_host.at(i) = block_B.get() + offset_B.at(i);
+      }
       ptr_C_host.at(i) = block_C.get() + offset_C.at(i);
       ptr_D_host.at(i) = block_D.get() + offset_D.at(i);
     }
@@ -386,12 +398,15 @@ struct TestbedGrouped {
         ElementAccumulator(0)
       );

-      // Ensure that no input or output is entirely zero
-      EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0);
-      EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0);
-      EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0);
-      EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0);
-      EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0);
+      // Ensure that no input or output is entirely zero, except for the last problem with k=0,
+      // which should produce an all-zero output.
+      if (i != problem_count - 1) {
+        EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0);
+        EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0);
+        EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0);
+        EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0);
+        EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0);
+      }

       // Compare against reference
       passed = cutlass::reference::host::TensorEquals(view_D, view_Ref);
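Note on the functional change (not part of the patch itself): replacing the equality test with a non-positive test makes the "disable global fetching" predicate hold for every exhausted value of the k-iteration counter, not just at the single point where it equals zero. This matters for the k = 0 problems added to the grouped-GEMM testbed above, where the counter starts at zero and is driven negative by the multistage prefetch and mainloop decrements. The sketch below is a minimal standalone illustration of the two predicates only; it does not use the real CUTLASS iterators, and it assumes clear_mask(p) means "clear the fetch predicates when p is true" (passing false leaves them unchanged).

    // Standalone sketch: compare the old and new clear_mask() predicates as the
    // k-iteration counter is decremented past zero (the k = 0 case).
    #include <cassert>

    int main() {
      for (int gemm_k_iterations = 2; gemm_k_iterations >= -3; --gemm_k_iterations) {
        bool old_predicate = (gemm_k_iterations == 0);  // fires only at exactly zero
        bool new_predicate = (gemm_k_iterations <= 0);  // fires for zero and below

        if (gemm_k_iterations > 0) {
          // While fetch iterations remain, neither form disables global fetching.
          assert(!old_predicate && !new_predicate);
        } else {
          // Once the counter is exhausted (zero or negative), only '<= 0' keeps
          // requesting that the fetch mask be cleared.
          assert(new_predicate);
        }
      }
      return 0;
    }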