diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index ac6fefdc6a..611beb7b84 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -29,7 +29,7 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *so
   return view;
 }
 
-// Helper function to calculate stride for packed QKV tensor unpacking
+// Helper function to calculate stride in bytes for packed QKV tensor unpacking
 size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype,
                             size_t h, size_t d) {
   size_t stride = 0;
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index a99f4fae90..ac7eba5c87 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -123,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
     size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
-  // For qkv_packed
-  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
-  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
-
-  // For kv_packed
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
-  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
-  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
-
-  // For separate q, k, v
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
   auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
@@ -156,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   nvte_tensor_pack_create(&aux_output_tensors);
 
   TensorWrapper query_workspace_tensor;
-  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
   // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
   size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
@@ -174,37 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
     auto ragged_offset_tensor =
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-      NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
-      nvte_fused_attn_fwd_qkvpacked(
-          qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
-          s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-          ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
-          false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-          softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(),
-          nullptr);
-    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-      nvte_fused_attn_fwd_kvpacked(
-          q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
-          s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-          kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
-          dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
-          dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
-          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-          window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
-    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-      nvte_fused_attn_fwd(
-          q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-          dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
-          ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
-          dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
-          kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout,
-          bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-          query_workspace_tensor.data(), nullptr);
-    } else {
-      NVTE_ERROR("Unsupported QKVLayout.");
-    }
+    nvte_fused_attn_fwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
+        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
+        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
+        ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
+        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+        window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
   }
 
   nvte_tensor_pack_destroy(&aux_output_tensors);
@@ -291,47 +258,57 @@ static void FusedAttnForwardImpl(
 
   /* Call the underlying NVTE API */
   auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
+
+  // Prepare Q, K, V pointers and shapes based on layout
+  // Python passes dummy tensors for unused slots, so we extract from the actual packed data
+  void *q_ptr = q;
+  void *k_ptr = k;
+  void *v_ptr = v;
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
+  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
+
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
-    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
-    nvte_fused_attn_fwd_qkvpacked(
-        qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
-        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-        q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false,
-        false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-        window_size_left, window_size_right, workspace_tensor.data(), stream);
+    // QKV packed in q: [batch*seqlen, 3, heads, dim]
+    // Python passes: q=packed_qkv, k=dummy, v=dummy
+    // Extract K and V pointers from the packed q data
+    NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
+    NVTE_CHECK(qk_head_dim == v_head_dim,
+               "For QKV packed layout, qk_head_dim must equal v_head_dim");
+    size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
+    q_ptr = q;
+    k_ptr = static_cast<void *>(static_cast<uint8_t *>(q) + stride);
+    v_ptr = static_cast<void *>(static_cast<uint8_t *>(q) + 2 * stride);
+    // For packed QKV, all have same shape since they're views into the same packed tensor
+    k_shape = q_shape;
+    v_shape = q_shape;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
-    auto kv_shape =
-        std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
-    nvte_fused_attn_fwd_kvpacked(
-        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
-        s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability,
-        qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-        workspace_tensor.data(), stream);
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
-    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
-    auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto k_tensor = TensorWrapper(k, k_shape, dtype);
-    auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(
-        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-        dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-        k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-        window_size_left, window_size_right, workspace_tensor.data(), stream);
-  } else {
-    NVTE_ERROR("Unsupported qkv_layout.");
+    // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
+    // Python passes: q=query, k=packed_kv, v=dummy
+    // Extract V pointer from the packed k data
+    NVTE_CHECK(qk_head_dim == v_head_dim,
+               "For KV packed layout, qk_head_dim must equal v_head_dim");
+    size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
+    q_ptr = q;
+    k_ptr = k;
+    v_ptr = static_cast<void *>(static_cast<uint8_t *>(k) + stride);
+    // V has same shape as K since they're packed together
+    v_shape = k_shape;
   }
+  // else NVTE_HD_HD_HD: pointers and shapes already correct
+
+  auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
+  auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
+  auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
+
+  nvte_fused_attn_fwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
+      dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
+      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
+      k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
+      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
+      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+      window_size_left, window_size_right, workspace_tensor.data(), stream);
 
   nvte_tensor_pack_destroy(&aux_output_tensors);
 }
@@ -414,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
     bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
     int64_t window_size_right) {
-  // For qkv_packed
-  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
-  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
-  auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
-
-  // For kv_packed
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
   auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
-  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
-  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
-  auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
-
-  // For separate q, k, v
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
   auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
@@ -450,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
 
   TensorWrapper query_workspace_tensor;
 
-  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
   // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
   size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
@@ -471,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
     auto dummy_ragged_offset_tensor =
         TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-      nvte_fused_attn_bwd_qkvpacked(
-          qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-          s_tensor.data(),  // not used for F16
-          s_tensor.data(),  // not used for F16
-          &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
-          dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
-          dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
-          qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-          deterministic, false, query_workspace_tensor.data(), nullptr);
-    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-      nvte_fused_attn_bwd_kvpacked(
-          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-          s_tensor.data(),  // not used for F16
-          s_tensor.data(),  // not used for F16
-          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
-          dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
-          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
-          dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-          window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
-    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-      nvte_fused_attn_bwd(
-          q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-          doutput_tensor.data(),
-          s_tensor.data(),  // not used for F16
-          s_tensor.data(),  // not used for F16
-          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-          dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
-          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
-          dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-          window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
-    } else {
-      NVTE_ERROR("Unsupported qkv_layout.");
-    }
+
+    nvte_fused_attn_bwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+        doutput_tensor.data(),
+        s_tensor.data(),  // not used for F16
+        s_tensor.data(),  // not used for F16
+        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+        dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
+        kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+        window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
   }
 
   nvte_tensor_pack_destroy(&aux_input_tensors);
@@ -552,76 +493,82 @@ static void FusedAttnBackwardImpl(
                                      softmax_aux, rng_state, bias);
 
   /* Call the underly NVTE API */
+  // Prepare Q, K, V pointers and shapes based on layout
+  void *q_ptr = q;
+  void *k_ptr = k;
+  void *v_ptr = v;
+  void *dq_ptr = dq;
+  void *dk_ptr = dk;
+  void *dv_ptr = dv;
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
+  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
+
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
-    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
-    auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
-    if (is_ragged) {
-      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype),
-                      stream);
-    }
-    nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                  s_tensor.data(),  // not used for F16
-                                  s_tensor.data(),  // not used for F16
-                                  &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
-                                  dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
-                                  q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
-                                  dropout_probability, qkv_layout, bias_type, mask_type,
-                                  softmax_type, window_size_left, window_size_right, deterministic,
-                                  false, workspace_tensor.data(), stream);
+    // QKV packed in q: [batch*seqlen, 3, heads, dim]
+    NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
+    NVTE_CHECK(qk_head_dim == v_head_dim,
+               "For QKV packed layout, qk_head_dim must equal v_head_dim");
+    size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
+    q_ptr = q;
+    k_ptr = static_cast<void *>(static_cast<uint8_t *>(q) + stride);
+    v_ptr = static_cast<void *>(static_cast<uint8_t *>(q) + 2 * stride);
+    dq_ptr = dq;
+    dk_ptr = static_cast<void *>(static_cast<uint8_t *>(dq) + stride);
+    dv_ptr = static_cast<void *>(static_cast<uint8_t *>(dq) + 2 * stride);
+    k_shape = q_shape;
+    v_shape = q_shape;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
-    auto kv_shape =
-        std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
-    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
-    auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype);
-    if (is_ragged) {
-      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
-      cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype),
-                      stream);
-    }
-    nvte_fused_attn_bwd_kvpacked(
-        q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
-        s_tensor.data(),  // not used for F16
-        s_tensor.data(),  // not used for F16
-        &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
-        dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
-        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
-        mask_type, softmax_type, window_size_left, window_size_right, deterministic, false,
-        workspace_tensor.data(), stream);
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
-    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
-    auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto k_tensor = TensorWrapper(k, k_shape, dtype);
-    auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
-    auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
-    auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
-    if (is_ragged) {
-      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
-      cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
-      cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
+    // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
+    NVTE_CHECK(qk_head_dim == v_head_dim,
+               "For KV packed layout, qk_head_dim must equal v_head_dim");
+    size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
+    q_ptr = q;
+    k_ptr = k;
+    v_ptr = static_cast<void *>(static_cast<uint8_t *>(k) + stride);
+    dq_ptr = dq;
+    dk_ptr = dk;
+    dv_ptr = static_cast<void *>(static_cast<uint8_t *>(dk) + stride);
+    // V has same shape as K since they're packed together
+    v_shape = k_shape;
+  }
+
+  auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
+  auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
+  auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
+  auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
+  auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
+  auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
+
+  if (is_ragged) {
+    size_t dtype_size = typeToSize(dtype);
+    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
+      // For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once
+      cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream);
+    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
+      // Clear dq
+      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
+      // For packed KV, dk contains both dk and dv - clear all at once
+      cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream);
+    } else {
+      // All separate - clear each individually
+      cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
+      cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream);
+      cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream);
     }
-    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                        doutput_tensor.data(),
-                        s_tensor.data(),  // not used for F16
-                        s_tensor.data(),  // not used for F16
-                        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                        dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
-                        kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
-                        mask_type, softmax_type, window_size_left, window_size_right, deterministic,
-                        false, workspace_tensor.data(), stream);
-  } else {
-    NVTE_ERROR("Unsupported qkv_layout.");
   }
 
+  nvte_fused_attn_bwd(
+      q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+      doutput_tensor.data(),
+      s_tensor.data(),  // not used for F16
+      s_tensor.data(),  // not used for F16
+      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
+      dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
+      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+      window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
+
   nvte_tensor_pack_destroy(&aux_input_tensors);
 }
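
Note on the pointer arithmetic introduced above: for the packed layouts the patch carves the K and V views directly out of the packed buffer by offsetting the base pointer by heads * dim elements (in bytes, typeToSize(dtype) * h * d), while each view keeps a per-token stride of 3 * h * d elements (2 * h * d for packed KV); the unchanged qkv_layout argument still describes the packed layout to the fused-attention call. The following standalone host-side sketch is an illustration only, not TE code, and every name in it (tokens, heads, dim, k_view, v_view) is made up here; it just checks that byte-offset arithmetic for the 3HD case.

// Illustration only: extracting K/V views from a packed [tokens][3][heads][dim] buffer
// by offsetting the base pointer, as the patch does with typeToSize(dtype) * h * d.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  const size_t tokens = 4, heads = 2, dim = 3;
  std::vector<float> qkv(tokens * 3 * heads * dim);
  for (size_t i = 0; i < qkv.size(); ++i) qkv[i] = static_cast<float>(i);

  // Byte offsets mirroring the patch: stride = element size * heads * dim.
  const size_t stride_bytes = sizeof(float) * heads * dim;
  const auto *base = reinterpret_cast<const uint8_t *>(qkv.data());
  const auto *k_view = reinterpret_cast<const float *>(base + stride_bytes);
  const auto *v_view = reinterpret_cast<const float *>(base + 2 * stride_bytes);

  // Each view still steps 3 * heads * dim elements from one token to the next.
  const size_t token_stride = 3 * heads * dim;
  for (size_t t = 0; t < tokens; ++t) {
    for (size_t h = 0; h < heads; ++h) {
      for (size_t d = 0; d < dim; ++d) {
        const float expected_k = qkv[t * token_stride + 1 * heads * dim + h * dim + d];
        const float expected_v = qkv[t * token_stride + 2 * heads * dim + h * dim + d];
        assert(k_view[t * token_stride + h * dim + d] == expected_k);
        assert(v_view[t * token_stride + h * dim + d] == expected_v);
      }
    }
  }
  return 0;
}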
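The ragged (THD) gradient clearing in FusedAttnBackwardImpl relies on the packed buffers being contiguous: with 3HD the dq buffer holds dq, dk and dv back to back, so clearing 3 * product(q_shape) bytes covers the same region the old code cleared via product(qkv_shape); with HD_2HD, 2 * product(k_shape) matches the old product(kv_shape) because the added NVTE_CHECK enforces qk_head_dim == v_head_dim. The tiny standalone check below is an illustration only, with arbitrary example sizes, of that size equivalence.

// Illustration only: the new ragged-clear byte counts equal the old packed-shape sizes.
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

static size_t product(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
}

int main() {
  // Arbitrary example sizes; dtype_size = 2 mimics bf16.
  const size_t batch = 2, q_seqlen = 8, kv_seqlen = 8;
  const size_t attn_heads = 4, num_gqa_groups = 2, head_dim = 16, dtype_size = 2;

  const std::vector<size_t> q_shape{batch * q_seqlen, attn_heads, head_dim};
  const std::vector<size_t> k_shape{batch * kv_seqlen, num_gqa_groups, head_dim};
  const std::vector<size_t> qkv_shape{batch * q_seqlen, 3, attn_heads, head_dim};    // old 3HD dgrad
  const std::vector<size_t> kv_shape{batch * kv_seqlen, 2, num_gqa_groups, head_dim};  // old HD_2HD dgrad

  // NVTE_3HD: dq holds dq|dk|dv, so 3 * |q_shape| bytes equals the old qkv clear.
  assert(3 * product(q_shape) * dtype_size == product(qkv_shape) * dtype_size);
  // NVTE_HD_2HD: dk holds dk|dv, so 2 * |k_shape| bytes equals the old kv clear.
  assert(2 * product(k_shape) * dtype_size == product(kv_shape) * dtype_size);
  return 0;
}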