Fix out kwarg shape check with ngroups swapped (#4)
Yard1 committed May 31, 2024
1 parent 03bf1f8 commit e5da6e4
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions csrc/flash_attn/flash_api.cpp
@@ -726,7 +726,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
+    const int seqlen_q_og = seqlen_q;
     int num_heads = sizes[2];
+    const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];
 
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
@@ -784,8 +786,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+        CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
+        if (head_size_og % 8 != 0) {
+            out = torch::empty_like(q_padded);
+        } else if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2);
+        }
     } else {
         out = torch::empty_like(q_padded);
     }
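Context for the change (not part of the commit itself): when seqlen_q == 1 and the query has more heads than the KV cache, mha_fwd_kvcache swaps the query-length and head-group dimensions so the kernel sees seqlen_q = ngroups and num_heads = num_heads_k. A caller-supplied `out` tensor still arrives in the original batch_size x seqlen_q x num_heads x head_size layout, which is why the shape check now uses seqlen_q_og and num_heads_og and why `out` is re-viewed with reshape + transpose. Below is a minimal standalone libtorch sketch of that view manipulation; the concrete sizes, and the assumption that seqlen_q and num_heads hold the swapped values (ngroups, num_heads_k) at that point, are illustrative rather than taken verbatim from the source.

// Illustrative sketch only: the reshape + transpose applied to `out`
// when the seqlen_q/ngroups swap is active. Sizes are made up.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t batch_size   = 2;
    const int64_t seqlen_q_og  = 1;    // single decode token
    const int64_t num_heads_og = 32;   // query heads
    const int64_t num_heads_k  = 8;    // KV heads (GQA)
    const int64_t head_size_og = 128;
    const int64_t ngroups      = num_heads_og / num_heads_k;

    // Layout the caller passes in: batch x seqlen_q x num_heads x head_size.
    auto out = torch::zeros({batch_size, seqlen_q_og, num_heads_og, head_size_og});

    // Re-view `out` the same way q is viewed after the swap:
    // batch x ngroups x num_heads_k x head_size.
    auto out_swapped =
        out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);

    std::cout << out_swapped.sizes() << std::endl;  // prints [2, 4, 8, 128]
    return 0;
}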
