Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix illegal memory accesses in multistage Mma's for k=0 #1593

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad;
++stage, --gemm_k_iterations) {
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -559,8 +559,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -725,8 +725,8 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
18 changes: 9 additions & 9 deletions examples/45_dual_gemm/threadblock/dual_mma_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,9 @@ class DualMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B0.clear_mask(gemm_k_iterations == 0);
- iterator_B1.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B0.clear_mask(gemm_k_iterations <= 0);
+ iterator_B1.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -555,9 +555,9 @@ class DualMmaMultistage :
++this->warp_tile_iterator_B0_;
++this->warp_tile_iterator_B1_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B0.clear_mask(gemm_k_iterations == 0);
- iterator_B1.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B0.clear_mask(gemm_k_iterations <= 0);
+ iterator_B1.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -730,9 +730,9 @@ class DualMmaMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B0.clear_mask(gemm_k_iterations == 0);
- iterator_B1.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B0.clear_mask(gemm_k_iterations <= 0);
+ iterator_B1.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/ell_mma_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ class EllMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -456,8 +456,8 @@ class EllMmaMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

if (is_A_sparse){
iterator_A.ell_add_mask(ell_iterator.get_blocksize());
Expand Down Expand Up @@ -608,8 +608,8 @@ class EllMmaMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/mma_blas3_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ class MmaBlas3Multistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -519,8 +519,8 @@ class MmaBlas3Multistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -661,8 +661,8 @@ class MmaBlas3Multistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ class MmaLayernormMainloopFusionMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -692,9 +692,9 @@ class MmaLayernormMainloopFusionMultistage :
++this->warp_tile_iterator_A_gamma_beta_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -824,9 +824,9 @@ class MmaLayernormMainloopFusionMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_gamma_beta.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/mma_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ class MmaMultistage :
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {

// Disable global fetching if done with global fetch iterations
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -588,8 +588,8 @@ class MmaMultistage :

// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// The last warp-tile also converts the shared memory fragments used by
Expand Down Expand Up @@ -620,8 +620,8 @@ class MmaMultistage :
PipeState pipe_state;

// Disable global fetching if done with global fetch iterations
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
Expand Down
24 changes: 12 additions & 12 deletions include/cutlass/gemm/threadblock/mma_planar_complex_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,10 @@ class MmaPlanarComplexMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A_real.clear_mask(gemm_k_iterations == 0);
- iterator_A_imag.clear_mask(gemm_k_iterations == 0);
- iterator_B_real.clear_mask(gemm_k_iterations == 0);
- iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+ iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_imag.clear_mask(gemm_k_iterations <= 0);

iterator_A_real.set_iteration_index(0);
iterator_A_imag.set_iteration_index(0);
Expand Down Expand Up @@ -503,10 +503,10 @@ class MmaPlanarComplexMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A_real.clear_mask(gemm_k_iterations == 0);
- iterator_A_imag.clear_mask(gemm_k_iterations == 0);
- iterator_B_real.clear_mask(gemm_k_iterations == 0);
- iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+ iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_imag.clear_mask(gemm_k_iterations <= 0);

// Start issuing the first group of the next stage outside of the mainloop
copy_tiles_and_advance(iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag);
Expand Down Expand Up @@ -611,10 +611,10 @@ class MmaPlanarComplexMultistage :
}

--gemm_k_iterations;
- iterator_A_real.clear_mask(gemm_k_iterations == 0);
- iterator_A_imag.clear_mask(gemm_k_iterations == 0);
- iterator_B_real.clear_mask(gemm_k_iterations == 0);
- iterator_B_imag.clear_mask(gemm_k_iterations == 0);
+ iterator_A_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_A_imag.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_real.clear_mask(gemm_k_iterations <= 0);
+ iterator_B_imag.clear_mask(gemm_k_iterations <= 0);
}

warp_mma_planar_complex(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ class MmaSoftmaxMainloopFusionMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -581,8 +581,8 @@ class MmaSoftmaxMainloopFusionMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

// Start issuing the first group of the next stage outside of the mainloop
copy_tiles_and_advance(iterator_A, iterator_B);
Expand Down Expand Up @@ -708,8 +708,8 @@ class MmaSoftmaxMainloopFusionMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
18 changes: 9 additions & 9 deletions include/cutlass/gemm/threadblock/mma_sparse_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ class SparseMmaMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_E.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
+ iterator_E.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -499,9 +499,9 @@ class SparseMmaMultistage :
++this->warp_tile_iterator_B_;
++this->warp_tile_iterator_E_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_E.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
+ iterator_E.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -634,9 +634,9 @@ class SparseMmaMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
- iterator_E.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
+ iterator_E.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ class MmaWithReductionMultistage :
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
Expand Down Expand Up @@ -403,8 +403,8 @@ class MmaWithReductionMultistage :
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;

- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);

int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
Expand Down Expand Up @@ -513,8 +513,8 @@ class MmaWithReductionMultistage :
}

--gemm_k_iterations;
- iterator_A.clear_mask(gemm_k_iterations == 0);
- iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_A.clear_mask(gemm_k_iterations <= 0);
+ iterator_B.clear_mask(gemm_k_iterations <= 0);
}

// Do any conversions feeding the first stage at the end of the loop so
Expand Down
Loading