diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index e9b044141..8569dc40e 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -726,7 +726,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
+    const int seqlen_q_og = seqlen_q;
     int num_heads = sizes[2];
+    const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];
 
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
@@ -784,8 +786,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+        CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
+        if (head_size_og % 8 != 0) {
+            out = torch::empty_like(q_padded);
+        } else if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2);
+        }
     } else {
         out = torch::empty_like(q_padded);
     }
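
Context for reviewers: on the seqlenq_ngroups_swapped path (roughly, single-token MQA/GQA decoding where num_heads > num_heads_k), mha_fwd_kvcache overwrites seqlen_q and num_heads with the swapped values before this point, so a user-supplied out_ tensor has to be shape-checked against the original (pre-swap) sizes and then rearranged into the kernel's internal layout. The snippet below is only a minimal illustration of that shape bookkeeping, not code from the library; the concrete sizes (batch 2, 32 query heads, 4 KV heads, head size 128) are made up.

// Minimal sketch (hypothetical sizes, not library code) of the reshape +
// transpose this patch applies to a caller-provided `out` buffer.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t batch_size = 2, num_heads_og = 32, num_heads_k = 4, head_size = 128;
    const int64_t seqlen_q_og = 1;  // single-token decode

    // Caller-visible output buffer: batch x seqlen_q x num_heads x head_size.
    at::Tensor out = torch::zeros({batch_size, seqlen_q_og, num_heads_og, head_size});

    // The swapped path treats the query heads of one KV group as extra query
    // positions: seqlen_q becomes num_heads/num_heads_k, num_heads becomes num_heads_k.
    const int64_t seqlen_q = num_heads_og / num_heads_k;   // ngroups = 8
    const int64_t num_heads = num_heads_k;                 // 4

    // Same rearrangement as the patch applies to `out`.
    at::Tensor out_swapped =
        out.reshape({batch_size, num_heads, seqlen_q, head_size}).transpose(1, 2);
    std::cout << out_swapped.sizes() << "\n";   // [2, 8, 4, 128]

    // Undoing the swap recovers the caller's layout.
    at::Tensor out_restored =
        out_swapped.transpose(1, 2).reshape({batch_size, seqlen_q_og, num_heads_og, head_size});
    std::cout << out_restored.sizes() << "\n";  // [2, 1, 32, 128]
    return 0;
}

As far as I can tell, the function transposes its result back to the caller's layout before returning, so checking against the *_og sizes keeps the user-facing contract of out_ unchanged while the kernel still writes into the swapped view.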