From e5da6e4dcd436a782da8ef73c03cdc95f60e9442 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Fri, 31 May 2024 10:05:15 -0700
Subject: [PATCH] Fix out kwarg shape check with ngroups swapped (#4)

---
 csrc/flash_attn/flash_api.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index e9b044141..8569dc40e 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -726,7 +726,9 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
 
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
+    const int seqlen_q_og = seqlen_q;
     int num_heads = sizes[2];
+    const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];
 
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
@@ -784,8 +786,12 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+        CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
+        if (head_size_og % 8 != 0) {
+            out = torch::empty_like(q_padded);
+        } else if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2);
+        }
     } else {
         out = torch::empty_like(q_padded);
     }