feat: TRTLLM FMHAv2 backend for ctx attention #2142
Conversation
Important: Review skipped. Review was skipped as selected files did not have any reviewable changes. 💤 Files selected but had no reviewable changes (8).

Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Adds extensive FMHA v2 CUDA infrastructure: numeric conversions, fragment/accumulator primitives, SMEM/GMEM tile loaders and storers (including packed/TMA paths), Hopper GMMA/TMA utilities, masking and kernel trait configurations across Volta/Turing/Ampere/Ada/Hopper targets.
Sequence Diagram(s)
sequenceDiagram
participant Host
participant Kernel
participant GMEM
participant SMEM
participant GMMA
rect rgb(240,248,255)
note right of Host: Setup
Host->>Kernel: pass Params, TMA descriptors, pointers
end
rect rgb(255,248,240)
note right of Kernel: QKV load phase
Kernel->>GMEM: Gmem_tile_qkv.load() (LDG / LDGSTS / TMA)
GMEM->>SMEM: write via Smem_tile (swizzled / TMA store)
end
rect rgb(248,255,240)
note right of Kernel: Compute phase
Kernel->>GMMA: Compute_tile_with_gmma.compute() -> wgmma/igmma/qgmma ops
GMMA->>Kernel: accumulate into Fragment_accumulator
Kernel->>Kernel: increment descriptors / advance tiles
end
rect rgb(240,240,255)
note right of Kernel: Softmax & O write
Kernel->>Kernel: Tile_o_normalizer (max/sum)
Kernel->>GMEM: Gmem_tile_o.store() (quantize/pack, optional I2F trick)
end
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Areas needing extra attention:
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Summary of Changes
Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request ports the trtllm fmhav2 library to FlashInfer, adding a new backend that supports prefill cases for context attention. This enhancement involves incorporating CUDA kernels for numerical type conversions and introducing fragment structures for optimized memory handling, significantly expanding FlashInfer's capabilities.
Code Review
This pull request ports the TensorRT-LLM FMHAv2 library to support prefill cases. The changes are extensive and introduce a large amount of low-level CUDA code, including kernels for various GPU architectures (Volta, Turing, Ampere, Hopper) and support for features like paged KV-cache and TMA. The code is highly templated and specialized for performance. My review focused on correctness and potential issues in the newly added files. I've found one issue related to unsafe type casting that could lead to incorrect behavior.
out.x = reinterpret_cast<int8_t const&>(a);
out.y = reinterpret_cast<int8_t const&>(b);
out.z = reinterpret_cast<int8_t const&>(c);
out.w = reinterpret_cast<int8_t const&>(d);
The use of reinterpret_cast here is unsafe. The cvt.rni.sat.s8.f32 instruction writes the 8-bit result to the lower byte of the 32-bit destination register, but the contents of the upper 24 bits are not guaranteed to be zero. reinterpret_casting this to an int8_t is endian-dependent and relies on this assumption. Using static_cast is safer, more explicit, and avoids potential correctness issues.
out.x = static_cast<int8_t>(a);
out.y = static_cast<int8_t>(b);
out.z = static_cast<int8_t>(c);
out.w = static_cast<int8_t>(d);
Actionable comments posted: 16
🧹 Nitpick comments (30)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (2)
48-98: Dead code branch due to prior static_assert. Line 58 enforces Mma_tile::MMAS_N == 1 at compile time, making the condition at line 92 always true. The else branch (lines 94-96) is unreachable dead code. Consider simplifying by removing the redundant conditional:

  // Each thread of a quad writes 16B per STS -> 64B per store.
- if (Mma_tile::MMAS_N == 1) {
-   this->smem_write_ ^= 64;
- } else {
-   assert(false && "Unsupported");
- }
+ // MMAS_N == 1 enforced by static_assert above.
+ this->smem_write_ ^= 64;

Additionally, line 81 contains commented-out code. If it's no longer needed, consider removing it; otherwise, add a comment explaining why it's retained.
131-154: Cryptic comment and loop-internal static_assert placement.
The comment at line 140 ("inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1,") is unclear. Consider clarifying what these multiples represent or referencing documentation.
The static_assert(MMAS_M_PER_LOOP == 1) at line 139 is placed inside the loop body. While functionally correct (compile-time check), placing it before the loop would be cleaner and more conventional.

  inline __device__ void store(Accumulator const (&acc)[1][1], int mi) {
    enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
+   // The number of MMAs that are stored per loop iteration.
+   enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };
+   static_assert(MMAS_M_PER_LOOP == 1);
  #pragma unroll
    for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) {
-     // The number of MMAs that are stored per loop iteration.
-     enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };
-
-     static_assert(MMAS_M_PER_LOOP == 1);
-     // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1,
+     // XOR pattern for swizzled smem addressing (pattern: 0, 16, 32, 48, ...)
      auto smem_write = this->smem_write_ ^ (ni * 16);

This same pattern applies to the FP32 (lines 187-216), BF16 (lines 249-280), and other specializations.
csrc/fmha_v2/fmha/hopper/tma_types.h (1)
100-100: Minor typo in comment. "descritptro" should be "descriptor".

- // The 512 bit of descritptro for im2col mode.
+ // The 512 bit of descriptor for im2col mode.

csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)
799-801: Unused variable row_ptr. The variable row_ptr on line 800 is declared but never used. This appears to be leftover debug code.

  // Packs the 32bit/16bit values to 8bit.
  // Depending on the type, applies extra scaling with parameter scale_bmm2.
  Stg_packed_type dst = Acc_packer<Src_type, Dst_type>::run(this, src[ii]);
- float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);

1286-1294: Consider removing commented-out debug code. Lines 1287-1294 and 1310-1313 contain commented-out assertions and debug printf statements. These can be removed to improve code clarity.
csrc/fmha_v2/fmha/fragment.h (2)
1821-1825: Remove stray semicolon. There's a stray semicolon after the prev_max_ declaration on line 1824.

  float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
- ;
  float prev_sum_[ROWS_PER_THREAD] = {0};
1119-1121: NaN check pattern is intentional but non-obvious. The pattern sum[jj] != sum[jj] is used to detect NaN (since NaN != NaN is true). While this works, it's non-obvious. Consider using isnan() for clarity, or add a brief inline comment explaining the idiom if performance is critical.

Also applies to: 1149-1149, 1413-1413, 1436-1436, 2142-2143, 2170-2171
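As a side note, a minimal sketch of the two spellings (illustrative only, not code from this PR; isnan() is the standard CUDA/C math function):

  #include <cmath>

  // Self-comparison idiom: under IEEE-754, only NaN compares unequal to itself.
  __device__ inline bool is_nan_idiom(float x) { return x != x; }

  // Equivalent but more explicit; isnan() is available in both host and device code.
  __device__ inline bool is_nan_named(float x) { return isnan(x); }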
csrc/fmha_v2/fmha/numeric_types.h (1)
19-22: Consider using a more robust FP8 feature detection. The TODO comment suggests uncertainty about the approach. The current check CUDART_VERSION >= 11080 is reasonable, but using true as the macro value is unconventional. A more typical pattern would be to just define the macro without a value.

  #if CUDART_VERSION >= 11080
- // TODO Better way?
- #define FMHA_CUDA_SUPPORTS_FP8 true
+ #define FMHA_CUDA_SUPPORTS_FP8
  #endif

Then use #ifdef FMHA_CUDA_SUPPORTS_FP8 instead of #if FMHA_CUDA_SUPPORTS_FP8 throughout.

csrc/fmha_v2/fmha/paged_kv_cache.h (1)
46-58: Consider adding power-of-2 validation and using integer-based log2. The comment on lines 28-29 states mTokensPerBlock must be a power of 2, but this constraint is not enforced. Additionally, using floating-point log2() for an integer operation introduces unnecessary overhead and potential precision issues. Apply this diff to add validation and use an integer-based approach:

  Kv_block_array(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock,
                 int32_t bytesPerBlock, void* poolPtr)
      : mMaxSeqs(batchSize),
        mMaxBlocksPerSeq(maxBlocksPerSeq),
        mTokensPerBlock(tokensPerBlock),
        mBytesPerBlock{bytesPerBlock},
        mPoolPtr{poolPtr},
        mBlockOffsets{nullptr} {
-   float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
-   mTokensPerBlockLog2 = static_cast<int>(tokensPerBlockSeqLog2);
+   // Validate power-of-2 constraint
+   assert((tokensPerBlock & (tokensPerBlock - 1)) == 0 && tokensPerBlock > 0);
+   // Use integer bit counting instead of floating-point log2
+   mTokensPerBlockLog2 = __builtin_ctz(static_cast<uint32_t>(mTokensPerBlock));
  }

csrc/fmha_v2/fmha/hopper/utils_tma.h (1)
138-152: Runtime assert on unsupported architectures may cause unexpected failures. The tmastg_arrive() and tmastg_wait() functions call assert(false) when compiled for architectures below SM90. This could cause silent failures in release builds (where assert is typically a no-op) or unexpected crashes in debug builds if these functions are accidentally called. Consider using a static assertion or a more explicit compile-time guard if these functions should never be called on older architectures:

  inline __device__ void tmastg_arrive() {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.commit_group;");
  #else
-   assert(false);
+   // These functions are only valid on Hopper+ architectures
+   static_assert(__CUDA_ARCH__ >= 900, "tmastg_arrive requires SM90+");
  #endif
  }

  inline __device__ void tmastg_wait() {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory");
  #else
-   assert(false);
+   static_assert(__CUDA_ARCH__ >= 900, "tmastg_wait requires SM90+");
  #endif
  }

Alternatively, if backward compatibility is intended but calling these is an error, use a more informative runtime check that works in release builds.
csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)
341-345: Unused member desc_ptr_d. The member desc_ptr_d is declared but never initialized or used within the class. The comment says "Device desc ptr should be allocated outside the class and reused" (line 22), but having an unused member is confusing. Consider removing the unused member or documenting its intended external usage pattern:

  // The TMA descriptor. Each is of 512 bit.
  cudaTmaDesc* desc_ptr_h;
- // The TMA descriptor on the device memory.
- cudaTmaDesc* desc_ptr_d;
  // Number of batches
  int batch_size = 0;
  };

csrc/fmha_v2/fmha/hopper/arrive_wait.h (1)
373-378: Unused member variable id_. The id_ member is stored in the constructor (line 112) but never used anywhere in the class methods. All methods take id as a parameter instead. Consider removing id_ if it's not needed, or using it in methods that currently take id as a parameter:

  private:
  // smem barrier base pointer
  uint64_t* bar_base_;
- // barrier id
- int id_;
  };

csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1)
128-142: Unused member variables declared. The members preds_, fetch_, and row_ are declared but never used in the visible methods (commit, load, move, store). If these are intentionally reserved for future use or subclass consumption, consider adding a comment. Otherwise, they contribute to unnecessary memory overhead per tile instance.

csrc/fmha_v2/fmha/gmem_tile_o.h (2)
156-158: Consider using bool type for predicate flags. The is_active_for_last_stg_ member is declared as int but used as a boolean predicate. Using bool would better convey intent and potentially allow compiler optimizations.

- // Is the thread active for the last STG?
- int is_active_for_last_stg_;
+ // Is the thread active for the last STG?
+ bool is_active_for_last_stg_;
424-425: Same type consistency suggestion for boolean flags. Similar to Hmma_gmem_tile_o, consider using bool for is_active_ and is_active_for_last_stg_ to better express intent.

csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (1)
228-231: Unused template parameter in wrapper function. The bool /*ignored*/ parameter in hgmma_bf16 serves no purpose. If this is for API consistency with other GMMA wrappers, consider adding a brief comment explaining the pattern.

csrc/fmha_v2/fmha/hopper/kernel_traits.h (2)
98-108: FLAGS bit layout could benefit from named constants. The bitfield extraction for USE_LDGSTS_Q/K/V and USE_TMA_Q/K/V is correct, but the magic numbers (0x1u, 0x2u, 0x4u) could be replaced with named constants for better maintainability.

+ // FLAGS bit definitions
+ enum {
+   FLAG_LDGSTS_Q = 0x1u,
+   FLAG_LDGSTS_K = 0x2u,
+   FLAG_LDGSTS_V = 0x4u,
+   FLAG_HEADS_NOT_INTERLEAVED = 0x20u,
+   FLAG_BMM1_SOFTCAPPING = 0x800u
+ };
+
  // Do we use LDGSTS for Q, K or V. If not, TMA is used!
- enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u };
- enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u };
- enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u };
+ enum { USE_LDGSTS_Q = (FLAGS & FLAG_LDGSTS_Q) != 0u };
+ enum { USE_LDGSTS_K = (FLAGS & FLAG_LDGSTS_K) != 0u };
+ enum { USE_LDGSTS_V = (FLAGS & FLAG_LDGSTS_V) != 0u };
358-363: Hardcoded VERSION=2 in alias template. The FMHA_kernel_traits_hopper_v2 alias hardcodes VERSION=2. This is likely intentional for versioning, but consider documenting why this specific version is used here.

csrc/fmha_v2/fmha/hopper/fragment.h (1)
369-384: Use if constexpr for compile-time type dispatch. The type dispatch using std::is_same_v with regular if/else works but generates unnecessary runtime branches. Since Input_type_A and Input_type_B are template parameters known at compile time, if constexpr is more appropriate. Apply this diff for cleaner compile-time dispatch:

- if (std::is_same_v<Input_type_A, e4m3_t> && std::is_same_v<Input_type_B, e4m3_t>) {
+ if constexpr (std::is_same_v<Input_type_A, e4m3_t> && std::is_same_v<Input_type_B, e4m3_t>) {
    qgmma_rfa_e4m3_e4m3_fp32<GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(), this->regs_);
- } else if (std::is_same_v<Input_type_A, e5m2_t> && std::is_same_v<Input_type_B, e4m3_t>) {
+ } else if constexpr (std::is_same_v<Input_type_A, e5m2_t> && std::is_same_v<Input_type_B, e4m3_t>) {
    // ... rest of the chain

The same change should be applied to the SMEM-SMEM variant at lines 405-419.
csrc/fmha_v2/fmha/hopper/compute_tile.h (1)
273-277: Remove or enable commented-out static_asserts. These commented static_asserts suggest incomplete compile-time validation. Either remove them if they're no longer relevant, or uncomment them with proper conditions to catch configuration errors early.
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (4)
150-152: Unused member variable is_active_for_last_stg_. This member variable is declared but never assigned or read in the class. It appears to be dead code.

Consider removing this unused member or implementing the intended functionality. The same issue exists in Gmem_tile_o_gmma_32bit_8bit at line 880.

- // Is the thread active for the last STG?
- int is_active_for_last_stg_;
289-306: Incorrect comment labels in nested loops. The closing brace comments don't match the actual loop variables. Line 290 says // row_idx but it closes col_idx, and line 306 says // mma_ni but it closes row_idx. Fix the comment labels for clarity:

- }  // row_idx
- }  // col_idx
+ }  // col_idx
+ }  // mma_ni
  ...
- }  // row_idx
- }  // mma_ni
+ }  // col_idx
+ }  // row_idx
584-586: Document the compiler workaround for predicate behavior. The comment mentions that without this __shfl_sync, "the predicate will not behavior as expected for unknown reason." This suggests a compiler bug or undefined behavior that should be tracked. Consider adding a more detailed comment or a TODO to track this issue:

- // WARNING: Without this line, the predicate will not behavior as expected for unknown reason.
+ // WARNING: Without this line, the predicate will not behave as expected.
+ // This appears to be a compiler optimization issue. See: [link to bug/issue if available]
+ // TODO: Remove this workaround when the underlying issue is resolved.
  num_valid_rows_ = __shfl_sync(0xffffffff, num_valid_rows_, 0);
1082-1115: Add #undef STORE_COLUMNS after macro usage. The STORE_COLUMNS macro is defined but never undefined, which could cause redefinition warnings if this pattern is used elsewhere in the codebase. Add #undef STORE_COLUMNS after line 1115:

    for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) {
      STORE_COLUMNS()
    }
  }
+ #undef STORE_COLUMNS
  }

csrc/fmha_v2/fmha/hopper/smem_tile.h (4)
376-383: Remove or implement commented-out buffer management code. The move_next_write_buffer() method contains commented-out implementation code. If the buffer cycling is not needed for this path, the comments should be removed. If it is needed, the implementation should be restored. The same pattern appears at lines 538-544. Either remove the commented code or implement the functionality:

  inline __device__ void move_next_write_buffer() {
-   // if( BUFFERS_PER_TILE > 1 ) {
-   //   this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY )
-   //                                   ? -BYTES_PER_TILE_INC_BOUNDARY
-   //                                   : BYTES_PER_BUFFER;
-   // }
  }
169-172: Add parentheses to clarify operator precedence in XOR expression. The expression on lines 170-171 relies on operator precedence where + binds tighter than ^. While likely correct, this is subtle and should be explicitly parenthesized for clarity.

  } else if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B) {
-   smem_write_col = (tidx % (THREADS_PER_ROW / 2)) ^
-                    smem_write_xor + ((tidx % THREADS_PER_ROW) / (THREADS_PER_ROW / 2)) * 4;
+   smem_write_col = (tidx % (THREADS_PER_ROW / 2)) ^
+                    (smem_write_xor + ((tidx % THREADS_PER_ROW) / (THREADS_PER_ROW / 2)) * 4);
  }

The same pattern appears at lines 343-345.
1326-1485: Remove large commented-out code block. This ~160-line commented block significantly increases file size and reduces readability. If this code is needed for reference, consider moving it to a separate documentation file or keeping it in version control history.

The load_and_store() method should either implement the required functionality or remain as an empty stub without the commented implementation details.
1728-1730: Consider static_assert instead of runtime assert for unsupported configurations. The assert(false) for unsupported warp/dimension configurations will only fail at runtime. Since these are compile-time known values, a static_assert would catch misconfigurations earlier.

- } else {
-   assert(false);
- }
+ } else {
+   static_assert(sizeof(Traits) == 0, "Unsupported warp/dimension configuration for Transposer");
+ }

Note: The sizeof(Traits) == 0 idiom creates a dependent false that only triggers if this branch is instantiated.

csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (2)
119-133: Complex ternary chain is hard to maintain. The nested ternary expression for BYTES_PER_DESC spans 14 lines and is difficult to verify. Consider using a constexpr helper function or a lookup table approach for better readability. Consider refactoring to a constexpr function:

  static constexpr uint32_t compute_bytes_per_desc() {
    if (Gmma_vector_size == Gmma_descriptor_size::ALL) return 0;
    if (Gmma_trans == Gmma_descriptor_transpose::TRANS) {
      if (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) return GMMA_K * BYTES_PER_LEADING_DIM;
      if (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B) return (GMMA_K / 2) * BYTES_PER_LEADING_DIM;
      // ... etc
    }
    // ...
  }

  static constexpr uint32_t BYTES_PER_DESC = compute_bytes_per_desc();
175-201: Consider consolidating descriptor initialization loops. The constructor has four separate loops that all iterate over NUM_DESCRIPTORS. These could be merged into a single loop for cleaner code.

  inline __device__ Gmma_descriptor_a() {
  #pragma unroll
    for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
-     desc[desc_idx] = 0;
-   }
-
- #pragma unroll
-   for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
-     desc[desc_idx] |= DESCRIPTOR_MODE_IN_BIT_LOCATION;
-   }
-
- #pragma unroll
-   for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
-     desc[desc_idx] |= STRIDE_BYTE_OFFSET_IN_BIT_LOCATION;
-   }
-
-   if (LEADING_BYTE_OFFSET_NEEDED) {
- #pragma unroll
-     for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
+     desc[desc_idx] = DESCRIPTOR_MODE_IN_BIT_LOCATION | STRIDE_BYTE_OFFSET_IN_BIT_LOCATION;
+     if constexpr (LEADING_BYTE_OFFSET_NEEDED) {
        desc[desc_idx] |= LEADING_BYTE_OFFSET_IN_BIT_LOCATION;
      }
    }
  }
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (31)
- csrc/fmha_v2/convert.cu (1 hunks)
- csrc/fmha_v2/fmha/alibi_params.h (1 hunks)
- csrc/fmha_v2/fmha/fragment.h (1 hunks)
- csrc/fmha_v2/fmha/gemm.h (1 hunks)
- csrc/fmha_v2/fmha/gmem_tile_o.h (1 hunks)
- csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1 hunks)
- csrc/fmha_v2/fmha/gmem_tile_ps.h (1 hunks)
- csrc/fmha_v2/fmha/gmem_tile_qkv.h (1 hunks)
- csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/arrive_wait.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/compute_tile.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/fragment.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/kernel_traits.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/smem_tile.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/tma_types.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/utils_gmma.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/utils_igmma.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/utils_tma.h (1 hunks)
- csrc/fmha_v2/fmha/hopper/utils_warpgroup.h (1 hunks)
- csrc/fmha_v2/fmha/kernel_traits.h (1 hunks)
- csrc/fmha_v2/fmha/mask.h (1 hunks)
- csrc/fmha_v2/fmha/numeric_types.h (1 hunks)
- csrc/fmha_v2/fmha/paged_kv_cache.h (1 hunks)
- csrc/fmha_v2/fmha/smem_tile.h (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- csrc/fmha_v2/fmha/gemm.h
- csrc/fmha_v2/fmha/hopper/smem_tile_o.h
🧬 Code graph analysis (15)
csrc/fmha_v2/fmha/gemm.h (1)
csrc/fmha_v2/fmha/fragment.h (1)
fmha(20-182)
csrc/fmha_v2/fmha/numeric_types.h (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (1)
__nv_fp8_e5m2(91-93)
csrc/fmha_v2/fmha/mask.h (1)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
int(186-186)
csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)
csrc/cudnn_sdpa_utils.h (1)
set_tensor_common_0(92-96)
csrc/fmha_v2/fmha/gmem_tile_qkv.h (3)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
int(186-186)csrc/fmha_v2/fmha/warpspec/epilogue.h (1)
int(95-104)csrc/fmha_v2/fmha/utils.h (1)
pack_predicates(1366-1370)
csrc/fmha_v2/fmha/hopper/arrive_wait.h (2)
csrc/fmha_v2/fmha/hopper/smem_tile.h (1)
fmha(22-562)csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (4)
bar_create(795-806)bar_base(951-951)bar_wait(853-869)set_bar_base_dsmem(901-904)
csrc/fmha_v2/fmha/hopper/fragment.h (3)
csrc/fmha_v2/fmha/fragment.h (1)
fmha(20-182)csrc/fmha_v2/fmha/gemm.h (1)
fmha(18-35)csrc/fmha_v2/fmha/hopper/utils_igmma.h (1)
igmma_int8_int32(201-203)
csrc/fmha_v2/fmha/hopper/compute_tile.h (1)
csrc/fmha_v2/fmha/hopper/smem_tile.h (1)
fmha(22-562)
csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (2)
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1)
fmha(17-200)csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
fmha(17-99)
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (7)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)
fmha(19-202)v2(20-180)csrc/fmha_v2/fmha/gmem_tile_qkv.h (1)
fmha(15-162)csrc/fmha_v2/fmha/hopper/fragment.h (1)
fmha(20-35)csrc/fmha_v2/fmha/hopper/smem_tile.h (5)
fmha(22-562)store(2189-2193)store(2197-2199)store(2203-2207)store(2218-2221)csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)
fmha(16-121)csrc/fmha_v2/fmha/warpspec/epilogue.h (1)
int(95-104)csrc/fmha_v2/fmha/utils.h (11)
char(2294-2302)float(838-853)float(899-904)float(1057-1068)float(1151-1155)float(1159-1163)float(1251-1251)float(2306-2311)float(2315-2319)float(2323-2338)float(2343-2351)
csrc/fmha_v2/fmha/smem_tile.h (1)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
int(186-186)
csrc/fmha_v2/fmha/gmem_tile_o.h (1)
csrc/fmha_v2/fmha/traits.h (4)
public(264-264)public(356-359)public(433-436)public(481-493)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1)
csrc/fmha_v2/fmha/traits.h (4)
public(264-264)public(356-359)public(433-436)public(481-493)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
csrc/fmha_v2/fmha/smem_tile_o.h (12)
Smem_tile_o_base_8bit_mma(1284-1360)Smem_tile_o(104-197)Smem_tile_o(691-691)Smem_tile_o(723-723)Smem_tile_o(753-753)Smem_tile_o(802-820)Smem_tile_o(957-975)Smem_tile_o(1452-1452)Smem_tile_o(1466-1466)Smem_tile_o(1480-1480)Smem_tile_o(1494-1494)Smem_tile_o(1508-1508)
csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1)
csrc/fmha_v2/fmha/gemm.h (1)
fmha(18-35)
🪛 Clang (14.0.6)
csrc/fmha_v2/fmha/gemm.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/alibi_params.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_gmma.h
[error] 15-15: 'fmha/hopper/utils_hgmma.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_warpgroup.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/numeric_types.h
[error] 13-13: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/paged_kv_cache.h
[error] 15-15: 'cuda_runtime.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_tma.h
[error] 15-15: 'fmha/hopper/tma_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h
[error] 14-14: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/mask.h
[error] 15-15: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/tma_descriptor.h
[error] 14-14: 'fmha/hopper/tma_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_qkv.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/gmma_descriptor.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/arrive_wait.h
[error] 66-66: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 66-66: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/fragment.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/kernel_traits.h
[error] 14-14: 'fmha/gmem_tile_qkv_packed.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/compute_tile.h
[error] 14-14: 'fmha/hopper/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_igmma.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h
[error] 14-14: 'fmha/gmem_tile_o_packed.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/smem_tile.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_o.h
[error] 15-15: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h
[error] 15-15: 'fmha/numeric_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/smem_tile.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h
[error] 15-15: 'fmha/smem_tile_o.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/tma_types.h
[error] 15-15: 'fmha/numeric_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
[error] 14-14: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/kernel_traits.h
[error] 15-15: 'fmha/alibi_params.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_hgmma.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_ps.h
[error] 14-14: 'fmha/hopper/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/fragment.h
[error] 15-15: 'fmha/traits.h' file not found
(clang-diagnostic-error)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (44)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (5)
21-23: LGTM!Simple placeholder struct with zero byte footprint - appropriate for conditional tile allocation patterns.
159-217: FP32 specialization mirrors FP16 structure correctly.The conversion via
float2_to_half2for FP32 accumulators to half-precision storage is appropriate.Same optional refactoring suggestions apply as noted for the FP16 specialization (lines 193-196).
221-281: BF16 specialization correctly uses float2_to_bf16_x2 conversion.The additional
static_assert(Mma_tile::CORES_M == 2)at line 252 is good defensive validation for this specific data type path.Same optional refactoring suggestions apply regarding loop-internal static_assert placement (lines 257-259).
285-321: LGTM!Clean thin wrappers for E4M3 (FP8) and INT8 quantized paths, correctly delegating to
Smem_tile_o_gmma_32bit_8bit.
1-16: I encountered a technical limitation - the repository couldn't be cloned for verification. However, based on the original review comment's analysis, I can provide the rewritten version. The review already identified the context clearly (ported TensorRT-LLM code with expected include path setup) and approved the changes, with the verification request being supplementary.
License and include structure are appropriate for ported TensorRT-LLM code.
The static analysis error about
'fmha/smem_tile_o.h' file not foundis a false positive expected in isolated analysis without the complete build system context. This file is part of the FMHA v2 library with include paths configured in the build system.csrc/fmha_v2/fmha/alibi_params.h (1)
40-48: Well-structured parameter encapsulation with clear documentation.The struct cleanly encapsulates ALiBi parameters with precomputed values. The comment explaining the design decision for
head_idx_offsetandsequence_pos_offsetis helpful for maintainability.csrc/fmha_v2/convert.cu (1)
58-62: Correct size_t casting for dimension multiplication.This wrapper correctly casts to
size_tbefore multiplication, preventing integer overflow for large tensors. The same pattern should be applied consistently to line 183.csrc/fmha_v2/fmha/hopper/tma_types.h (1)
114-120: TMA descriptor types are well-structured.The descriptor structures are correctly sized at 64 bytes (512 bits) with proper alignment for TMA hardware requirements. The enum definitions provide clear configuration options.
csrc/fmha_v2/fmha/mask.h (4)
143-200: Clean implementation of sequence-length masking (V2).The V2 mask correctly implements basic sequence-length masking with proper warp/lane decomposition and position calculation. The
static_asserton line 146 ensuringWARPS_K == 1is good defensive programming.
344-352: FIXME comment indicates incomplete support for certain traits.The
mtp_token_idx_array assumes exactly 2 rows per thread, which may not hold for Volta/Hopper-GMMA traits. This limitation should be addressed or clearly documented as a known restriction.Consider tracking this as a follow-up issue. The current implementation may produce incorrect results if used with unsupported trait configurations.
368-398: Sliding window mask (V4) correctly extends causal mask.The V4 mask properly implements the sliding window attention pattern. The
is_validcheck at line 393 correctly constrains attention to positions within the sliding window while maintaining causality.
403-515: Custom mask (V5) with flexible packed mask loading.The V5 mask provides a well-structured approach for loading custom masks from global memory with proper handling of MMA group boundaries and partial mask loads via the predicate mechanism (line 457).
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)
33-86: Well-designed base template with comprehensive tiling calculations.The
Hmma_gmem_tile_obase template correctly computes tiling parameters with appropriate static assertions to catch configuration errors at compile time. The calculation ofROWS_PER_LOOP,STGS_PER_LOOP, and handling of incomplete STGs shows careful attention to edge cases.
450-628: Comprehensive type conversion helpers viaAcc_packer.The
Acc_packertemplate specializations provide a clean, extensible pattern for accumulator-to-output type conversions with optional scaling. The specializations correctly handle:
- INT32 → INT8 with saturation
- FP32 → e4m3/e5m2 (FP8)
- FP32 → FP16/BF16
- FP16 → e4m3
csrc/fmha_v2/fmha/fragment.h (5)
13-18: LGTM: Header structure and includes are appropriate.The static analysis hint about missing
fmha/traits.his a false positive - the build system configures include paths correctly for this CUDA code.
543-578: LGTM: Fragment_accumulator MMA implementations.The MMA implementations use correct PTX instructions with proper register constraints. The accumulator patterns for Volta/Turing/Ampere/Ada architectures are well-structured.
943-958: Appropriate architecture guards for Ada FP8 instructions.The
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890guards correctly protect the e4m3 MMA instructions, with atrapfallback for unsupported architectures.Also applies to: 972-990
785-809: Based on my verification through web searches, I can now provide my assessment:The review comment's concern appears to be based on a misunderstanding of template inheritance semantics.
Key findings:
PTX instruction is correct: The instruction
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32explicitly uses FP32 accumulators (thef32at the end specifies the accumulator type D).32-bit registers store FP32: The
"+r"(reg(...))constraints in inline assembly refer to 32-bit registers (.b32in PTX), which correctly hold 32-bit FP32 floating-point values.CUTLASS pattern confirmation: CUTLASS separates input element types (BF16) from accumulator types (FP32) via template parameters. The comment "BF16 MMA must accumulate with at least FP32" is correct.
Fragment template semantics: In CUTLASS/flashinfer,
Fragment<bf16_t, 8>as a base class likely refers to a trait template specialization or output fragment type for subsequent operations, not the actual accumulator storage. The actual accumulator storage is defined by the PTX instruction (f32).The code at lines 785-809 appears to correctly implement BF16 inputs with FP32 accumulation, consistent with the comment and NVIDIA's standard BF16 Tensor Core workflow.
Without direct repository access to verify the exact template inheritance at line 785 and lines 858-898, I cannot provide absolute certainty.
2286-2307: Based on my verification attempts, I encountered limitations accessing the repository directly due to cloning issues. However, my web searches reveal important findings:
- FP32_I2F_MAGIC_NUMBER_HEX: GitHub searches found no matches for this symbol in the flashinfer repository
- USE_I2F_EMULATION_TRICK: GitHub searches found no matches for this macro in flashinfer
- Clear_accumulator structure: No matches found for the exact code snippet shown
This is expected behavior for code under review, as it may not yet be in the public main branch. However, since I cannot access the actual PR branch directly to verify the include chain and definitions, I must mark this for manual verification.
Verify that
FP32_I2F_MAGIC_NUMBER_HEXandUSE_I2F_EMULATION_TRICKare properly defined in included headers.The
Clear_accumulator<int32_t, WARPS_K>struct usesFP32_I2F_MAGIC_NUMBER_HEXwhenUSE_I2F_EMULATION_TRICKis defined. Ensure both the macro and constant are declared in reachable headers—either directly included in fragment.h or transitively through existing includes. Verify the preprocessor conditional logic (#if defined(USE_I2F_EMULATION_TRICK)) aligns with build configuration and that the fallback path (fmha::clear(acc)) is always available when the macro is undefined.csrc/fmha_v2/fmha/hopper/utils_gmma.h (1)
1-18: LGTM: Clean aggregator header for GMMA utilities.This header provides a single include point for all GMMA-related utilities (HGMMA, HGMMA BF16, IGMMA, QGMMA). The static analysis hint about missing headers is a false positive - the build system handles include paths.
csrc/fmha_v2/fmha/hopper/utils_warpgroup.h (1)
19-40: LGTM: Well-guarded SM90 warpgroup synchronization primitives.The warpgroup functions are correctly guarded for SM90+ architectures with the
__CUDA_ARCH_FEAT_SM90_ALLfeature test. The template parameterNinwarpgroup_waituses the correct"n"constraint for immediate values.Note: On non-SM90 architectures, these functions compile to no-ops, which is the intended behavior for this FMHA infrastructure.
csrc/fmha_v2/fmha/gemm.h (1)
22-31: LGTM: Clean templated GEMM accumulation primitive.The implementation correctly delegates to the Fragment's
mmamethod, which handles the architecture-specific MMA intrinsics defined infragment.h. The#pragma unrolldirectives ensure full loop unrolling for performance.csrc/fmha_v2/fmha/numeric_types.h (2)
44-53: LGTM: Softmax FP8 quantization scale. The Softmax_fp_quant_scale<e4m3_t>() specialization correctly computes the quantization scale as 256.f (2^8), which is the largest power-of-2 below the e4m3 max of 448. The comment explains the rationale well.
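A quick sanity check of that choice (an illustrative snippet, not part of the header; 448 is the e4m3 finite maximum):

  // 2^8 = 256 fits below the e4m3 max of 448, while the next power of two, 2^9 = 512, does not.
  static_assert(256.0f < 448.0f && 512.0f > 448.0f,
                "256 is the largest power of two below the e4m3 max");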
36-39: I'm unable to complete the automated verification as the repository cloning failed. However, I can still provide guidance on the review comment based on the learnings and context provided.Based on the retrieved learnings from your flashinfer work, I note that silent fallbacks and error-swallowing patterns are established practices in this codebase when they serve a legitimate purpose (e.g., probing multiple configurations in GEMM implementations).
The key question for your review is: Are the
e4m3_t/e5m2_t→charfallbacks intentional and safe, or do they genuinely risk runtime bugs?To properly verify this comment, I would need to:
- Examine how
e4m3_tande5m2_tare actually used in FMHA kernels- Check if there are runtime guards preventing their use on non-FP8 hardware
- Determine if a static_assert is already present elsewhere
- Verify whether the
#elsebranch is even reached in production buildsSince I cannot access the repository directly, I recommend you verify these points manually or provide repository access for automated analysis.
Fallback types may silently compile but cause runtime issues if FP8 operations are used on unsupporting hardware without guards. Consider verifying whether runtime checks prevent actual FP8 usage when the fallback is active, and whether a static_assert would be appropriate as additional safety insurance.
csrc/fmha_v2/fmha/gmem_tile_qkv.h (2)
122-147: Inconsistentis_active_usage betweenload()overloads.The parameterless
load()function (line 138-147) checksis_active_before performing loads, butload(Smem_tile&)(line 123-135) does not. This inconsistency could lead to out-of-bounds accesses ifload(Smem_tile&)is called when threads are inactive.Please verify this is intentional. If
load(Smem_tile&)relies on predicates alone for safety, consider adding a comment to clarify this design choice.
39-41: Based on my research, I found thatSLIDING_WINDOW_ATTENTIONis indeed a real and active feature in FlashInfer for implementing local/windowed attention patterns. Thegmem_tile_qkvstruct is part of the kernel infrastructure for handling tiled memory copies of QKV matrices during attention computation.However, I cannot clone the repository to verify whether the template parameter is actually unused within the implementation or to determine if it's intentionally present for future use. Given that this is a legitimate feature in FlashInfer, the presence of this parameter—even if currently unused—may indicate:
- Incomplete implementation of sliding window support in this component
- Dead code that was prepared but never integrated
- A placeholder for future sliding window functionality
Since I cannot perform definitive verification of the code implementation without repository access, I'll provide the rewritten comment:
SLIDING_WINDOW_ATTENTIONtemplate parameter appears unused but is a known FlashInfer feature.While
SLIDING_WINDOW_ATTENTIONis declared as a template parameter, it is not referenced in the implementation. This parameter corresponds to FlashInfer's sliding/local-window attention feature.Verify if this parameter should be integrated into the struct's logic (e.g., in tile loading or offset calculations) or if it represents incomplete functionality that needs implementation.
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1)
86-99: Coordinate calculation logic for MQA/GQA looks correct.The handling of
qkv_offsetfor Q (0), K (1), and V (2) tensors with proper head index calculations for MQA/GQA (h_kv < h) versus standard MHA appears correct. The coordinate assembly for TMA descriptors is well-structured.csrc/fmha_v2/fmha/gmem_tile_o.h (2)
282-288: Preprocessor-based conditional initialization may leaveo_scratch_ptr_uninitialized.When
USE_DEMO_BERT_PARAMSis defined,o_scratch_ptr_is set tonullptr. Ensure that all code paths that accesso_scratch_ptr_check for this case, or add a comment clarifying that scratch space is unused in demo mode.
80-105: LGTM - Constructor logic is well-structured.The pointer arithmetic correctly handles batch/head offsets with
int64_tto prevent overflow on large tensors. The partial-write predicate initialization is properly guarded.csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (2)
30-44: LGTM - Inline PTX assembly for BF16 GMMA is correctly implemented.The SM90 architecture guards (
__CUDA_ARCH__ >= 900 && __CUDA_ARCH_FEAT_SM90_ALL) are appropriate. The wgmma instruction format and register constraints match NVIDIA PTX ISA specifications for BF16 tensor core operations.
23-75: Repository access is unavailable for direct verification.The web search confirms that the m64n32k16 GMMA shape is a standard, documented Hopper instruction, taking 16 uint32_t accumulators. This makes the reviewer's concern technically sound—if specializations for N=8, 64, 128, 192, 256 exist but N=32 is missing from the base
Hgmma_bf16path (while present in the RFA variant), a linker error could occur for 64×32×16 operations requiring both operands from SMEM.However, I cannot directly inspect the file to confirm:
- Which specializations currently exist in
utils_hgmma_bf16.h- Whether
Hgmma_bf16<32, TA, TB>is truly missing- Whether any code paths actually invoke this specialization
The reasoning in the review comment is logically coherent and the underlying CUDA mechanics are correct, but the specific claim about the file's contents requires repository access to validate conclusively.
csrc/fmha_v2/fmha/hopper/utils_igmma.h (2)
30-53: LGTM - INT8 IGMMA implementation is correct.The
wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8instruction format is correct for INT8 tensor core operations with INT32 accumulators. The K dimension of 32 (vs 16 for FP16/BF16) correctly reflects the wider K for 8-bit operations.
200-203: Consistent API wrapper pattern.The
igmma_int8_int32wrapper follows the same pattern as the FP16/BF16 utilities, with an unused boolean template parameter for API consistency.csrc/fmha_v2/fmha/hopper/kernel_traits.h (2)
51-69: LGTM - Clear RF flag propagation from traits.The extraction of
GMMA_A_RFandGMMA_B_RFflags from bothTraits_pandTraits_ois well-documented and enables flexible GMMA operand sourcing (RF vs SMEM) per GEMM stage.
303-303: SMEM O size consistency validated with static_assert.Good use of compile-time validation to ensure
Smem_tile_o::BYTES_PER_TILEmatches the calculatedBYTES_PER_SMEM_O.csrc/fmha_v2/fmha/hopper/fragment.h (2)
78-83: Address TODO comment about fragment size calculation.There's an unresolved TODO questioning whether the
* 4factor is needed in the fragment size calculation. This could lead to incorrect register allocation if the sizing is wrong.Please verify whether the
* 4factor is needed by comparing with the FP16/BF16 Fragment_a specializations above (lines 29-46) which don't have this factor. If the factor is needed for FP8 due to different packing, please document the reasoning.
456-475: LGTM - Softmax state storage with proper bounds checking.The store method correctly handles the interleaved softmax sum/max storage with proper sequence length bounds checking at line 472. The conditional override for invalid runs (line 467-469) correctly sets sum to 1.0 to avoid division issues in downstream code.
csrc/fmha_v2/fmha/hopper/compute_tile.h (2)
28-81: Well-structured SMEM-SMEM compute tile initialization.The GMMA descriptor setup correctly initializes both A and B descriptors with proper SMEM pointers and max descriptor values for buffer cycling. The static_assert at lines 39-41 ensures descriptor compatibility between A and B tiles.
460-468: I'm unable to access the repository due to cloning issues. Based on the review comment content alone, I can see it raises a theoretically valid concern about index handling, but I cannot definitively verify whether:
- The
get_descriptor(ki)call actually uses direct indexing or modular arithmetic- Whether the descriptor arrays are sized to handle indices up to
NUM_KGROUPS * MMAS_K_PER_GROUP - 1- Whether this is an actual bug or intentional by design
Let me provide the rewritten comment reflecting this situation:
Cannot definitively verify the index mismatch claim without access to the compute() method implementation and descriptor structure.
The concern about
compute()being called at line 467 withki = NUM_KGROUPS * MMAS_K_PER_GROUP - 1while the loop at line 461 iterateskifrom 0 toMMAS_K_PER_GROUP - 2is mathematically plausible. However, verification requires examining howget_descriptor(ki)handles this index—specifically whether it uses direct array indexing, modular arithmetic (ki % MMAS_K_PER_GROUP), or global indexing. Please review thecompute()method signature at line 364 and theget_descriptor()implementation to confirm whether this represents an out-of-bounds access or is intentionally designed.csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (2)
260-268: LGTM - Descriptor increment with proper boundary handling.The increment logic correctly handles the loop-back case when reaching the end of a kblock, using the precomputed
BYTES_DESC_INC_BOUNDARY_NO_4LSBconstant. The use ofreinterpret_cast<int2&>to access the lower 32 bits of the 64-bit descriptor is a standard pattern for GMMA descriptor manipulation.
361-369: LGTM - RESET_SMEM logic for large K dimensions.The
RESET_SMEMandRESET_BYTES_NO_4LSBconstants correctly handle the case where the K dimension exceeds a single 128-byte leading dimension, requiring descriptor reset during iteration. This is specific to the B matrix handling for transposed layouts.csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1)
906-931: Verify paged-KV block index bounds inGmem_tile_paged_kv::load
paged_kv_block_idxis derived fromrow_idxand used to indexpaged_kv_global_block_offsets_before predicates are applied:int row_idx = row_ + ii * (int)ROWS_PER_LDG; int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_); char const* local_kv_ptr = paged_kv_block_pool_ptr_ + params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx]; // ... preds[ii] = row_idx < actual_seqlen_;This is safe only if
actual_seqlen_andparams.paged_kv_cache.mMaxBlocksPerSeqguarantee that
row_idx < actual_seqlen_impliespaged_kv_block_idx < mMaxBlocksPerSeqfor all rows a CTA can
touch (includingcta_row_offset). The TODO in the comment suggests this may not be fully audited.Please double-check those invariants (especially for large sequence lengths / chunked context) so we
never index beyondmBlockOffsets. If they already hold by construction, consider documenting that
relation near this code.csrc/fmha_v2/fmha/smem_tile.h (1)
337-436: TMA specialization ofSmem_tile_without_skewslooks consistentThe TMA-enabled specialization correctly reuses the base layout, overrides buffer sizing to ignore STS granularity, wires
utmaldgviasmem_write_offset_/smem_barrier_offset_, and keeps buffer-rotation logic in sync with the number of buffers and barriers. No issues stand out in the offset math or barrier handling.csrc/fmha_v2/fmha/gmem_tile_ps.h (1)
515-585:Gmem_tile_psbase implementation appears consistent with per-thread MMA layoutThe generic
Gmem_tile_ps<Traits, Cta_tile, BITS_PER_ELEMENT>correctly:
- Derives per-thread
(row, col)from(warp, lane)in line with MMA tiling.- Computes
step_m/step_nto cover an 8×8 region of 2-element packets per thread.- Delegates to
Store_accumulator<Traits, BITS_PER_ELEMENT>with offsets based onM_PER_MMA_PER_CTAandN_PER_MMA_PER_CTA.No issues stand out in the pointer math or scaling wiring.
__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) {
  // The step.
  size_t step = (size_t)gridDim.x * blockDim.x;

  // Iterate over the elements.
  for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
    // Load 4 integers.
    int4 tmp = reinterpret_cast<int4 const*>(src)[ii];

    // Convert to float and scale.
    float x = static_cast<float>(tmp.x) * scale;
    float y = static_cast<float>(tmp.y) * scale;
    float z = static_cast<float>(tmp.z) * scale;
    float w = static_cast<float>(tmp.w) * scale;

    // Convert to int8.
    uint32_t a;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x));
    uint32_t b;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(b) : "f"(y));
    uint32_t c;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(c) : "f"(z));
    uint32_t d;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(d) : "f"(w));

    // Compact.
    char4 out;
    out.x = reinterpret_cast<int8_t const&>(a);
    out.y = reinterpret_cast<int8_t const&>(b);
    out.z = reinterpret_cast<int8_t const&>(c);
    out.w = reinterpret_cast<int8_t const&>(d);

    // Store.
    reinterpret_cast<uint32_t*>(dst)[ii] = reinterpret_cast<uint32_t const&>(out);
  }
}
Tail elements are not converted when n is not divisible by 4.
The kernel only processes n/4 iterations, each handling 4 elements. If n % 4 != 0, the remaining 1-3 elements are silently skipped, leading to uninitialized/stale data in the output.
Consider adding a tail-handling path or documenting the alignment requirement:
__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) {
+ // Note: n must be a multiple of 4. Caller must ensure proper alignment.
+ assert(n % 4 == 0);
// The step.
  size_t step = (size_t)gridDim.x * blockDim.x;

Alternatively, add a scalar tail loop after the main vectorized loop.
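One possible shape for that tail path, sketched against the kernel above (the names dst, src, n, scale, and step are the kernel's own; the scalar path reuses the same saturating cvt instruction so rounding matches the vectorized loop):

  // Scalar tail: convert the last n % 4 elements one at a time.
  size_t base = (n / 4) * 4;
  for (size_t i = base + blockIdx.x * blockDim.x + threadIdx.x; i < n; i += step) {
    // Same scaling as the vector path.
    float v = static_cast<float>(reinterpret_cast<int32_t const*>(src)[i]) * scale;
    // Same round-to-nearest, saturating int8 conversion.
    uint32_t q;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(q) : "f"(v));
    // Write the single byte at its element offset in dst.
    reinterpret_cast<int8_t*>(dst)[i] = static_cast<int8_t>(q);
  }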
🤖 Prompt for AI Agents
In csrc/fmha_v2/convert.cu around lines 19-54, the kernel only iterates n/4
vector elements so when n is not divisible by 4 the last 1–3 elements are
skipped; add a tail-handling path after the main vectorized loop that computes
base = (n/4)*4 and a scalar loop for i from base to n-1 that loads the single
int32, converts to float, applies the same scale and saturating int8 conversion,
and writes the byte to dst (ensure correct byte-offsets and casts into dst), or
alternatively document and assert that n must be 4-aligned. Ensure the scalar
path uses the same rounding/saturation behavior as the vector path and respects
thread/work distribution and bounds checks to avoid out-of-bounds stores.
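To make the alternative concrete, here is a minimal sketch of such a scalar tail path (the base = (n/4)*4 split and the grid-stride indexing mirror the existing kernel; packaging it as a separate kernel is only for illustration and is not part of this PR — in practice the loop would simply be appended inside convert_int32_to_int8_kernel):

// Sketch only: scalar tail covering the last n % 4 elements with the same
// round-to-nearest, saturating int8 conversion as the vectorized loop.
__global__ void convert_int32_to_int8_tail_kernel(void* dst, void const* src, size_t n,
                                                  float scale) {
  size_t step = (size_t)gridDim.x * blockDim.x;
  size_t base = (n / 4) * 4;  // elements already covered by the int4 loop
  for (size_t ii = base + blockIdx.x * blockDim.x + threadIdx.x; ii < n; ii += step) {
    float x = static_cast<float>(reinterpret_cast<int32_t const*>(src)[ii]) * scale;
    uint32_t a;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x));
    reinterpret_cast<int8_t*>(dst)[ii] = static_cast<int8_t>(a);
  }
}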
void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d,
                                 float scale_o) {
  run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o);
}
Integer overflow risk in dimension multiplication.
The expression s * b * h * d performs int multiplication before being passed to the size_t parameter. For large tensors this can overflow: for example, s = 16384, b = 32, h = 32, d = 128 already gives 2^31, one past INT_MAX.
Cast to size_t before multiplication:
void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d,
float scale_o) {
- run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o);
+ run_conversion_fp32_to_e4m3(dst, src, (size_t)s * b * h * d, scale_o);
}
🤖 Prompt for AI Agents
In csrc/fmha_v2/convert.cu around lines 181 to 184, the multiplication s * b * h
* d is done in int and may overflow before being passed to the function
expecting a size_t; change the call to compute the product in size_t (e.g. cast
one operand to size_t before multiplying) or otherwise promote operands to
size_t so the multiplication is performed in size_t and then pass that value to
run_conversion_fp32_to_e4m3.
  constexpr static int round_down_to_power_two(int x) {
    x = x | (x >> 1);
    x = x | (x >> 2);
    x = x | (x >> 4);
    x = x | (x >> 8);
    x = x | (x >> 16);
    return x - (x >> 1);
  }

  AlibiParams() = default;

  AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
    h_pow_2 = round_down_to_power_two(h);
    alibi_neg4_div_h = -4.0f / h_pow_2;
  }
Potential division by zero when h=0.
If h=0 is passed to the constructor, round_down_to_power_two(0) returns 0, causing division by zero on line 31.
Consider adding a guard or assertion:
AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
h_pow_2 = round_down_to_power_two(h);
+ // Avoid division by zero; h_pow_2 == 0 implies h == 0 which is invalid.
+ assert(h_pow_2 > 0 && "Number of heads must be positive");
alibi_neg4_div_h = -4.0f / h_pow_2;
}
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/alibi_params.h around lines 18 to 32, the constructor calls
round_down_to_power_two(h) and then divides by h_pow_2 which will be zero when
h==0; add a guard to prevent division by zero by either asserting/throwing for
non-positive h or clamping h_pow_2 to a safe minimum (e.g. 1) before computing
alibi_neg4_div_h, and optionally update round_down_to_power_two to return 1 for
non-positive input; ensure the constructor documents/validates the precondition
and that alibi_neg4_div_h is computed only with a non-zero denominator.
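For the clamping variant mentioned above, a minimal sketch could look like this (returning 1 for non-positive input is an assumption about the desired fallback, not the committed behavior):

// Sketch: keep the bit-smearing round-down, but never return 0 so the
// later -4.0f / h_pow_2 division stays well-defined.
constexpr static int round_down_to_power_two(int x) {
  if (x <= 0) {
    return 1;  // assumed fallback for invalid head counts
  }
  x = x | (x >> 1);
  x = x | (x >> 2);
  x = x | (x >> 4);
  x = x | (x >> 8);
  x = x | (x >> 16);
  return x - (x >> 1);
}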
  inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) {
    int rows_so_far = mi * STGS_PER_LOOP * ROWS_PER_STG;
    int rows_so_far_per_slice = rows_so_far / 2;
Hardcoded division assumes NUM_SLICES == 2.
Line 1279 uses rows_so_far / 2 which assumes exactly 2 slices. However, line 1187 allows NUM_SLICES == 1, which would cause incorrect row calculations.
Consider using the enum value:
int rows_so_far = mi * STGS_PER_LOOP * ROWS_PER_STG;
- int rows_so_far_per_slice = rows_so_far / 2;
+ int rows_so_far_per_slice = rows_so_far / NUM_SLICES;
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/gmem_tile_o_packed.h around lines 1277-1279, the code
computes rows_so_far_per_slice using a hardcoded division by 2 which assumes
NUM_SLICES == 2; replace the literal 2 with the enum/constant for the number of
slices (e.g., divide by NUM_SLICES) so the computation works for NUM_SLICES == 1
or other values, ensuring you use the appropriate integer type/cast if
NUM_SLICES is an enum or template constant.
  // Ctor for other param classes (such as Qkv_params in train_ops)
  template <typename Params, typename Block_info>
  inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
                                  int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
      : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
                      qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}
Bug in Gmem_tile_qkv forwarding ctor: mis-ordered arguments corrupt row/col indexing
The generic Params ctor
template <typename Params, typename Block_info>
inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
: Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
                    qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}

forwards cta_row_offset into the num_kv_heads parameter of the main ctor and cta_col_offset_in_bytes into its cta_row_offset parameter, leaving cta_col_offset_in_bytes at its default. This will:

- Use an incorrect num_kv_heads in the num_kv_heads < num_heads (MQA/GQA) branch, and
- Miscompute row offsets (row + cta_row_offset) by treating a byte offset as a row count.
Both lead to wrong global-memory addresses, and the GQA path can even see division-by-zero if num_kv_heads ends up as 0 by accident.
A minimal safe fix is to pass an explicit num_kv_heads and then the row/column offsets in the correct positions. For non-GQA callers where Params has no KV-head count, using params.h keeps behavior identical to “no GQA/MQA” (the num_kv_heads < num_heads branch won’t trigger):
- template <typename Params, typename Block_info>
- inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
- int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
- : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
- qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}
+ template <typename Params, typename Block_info>
+ inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
+ int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
+ : Gmem_tile_qkv(params.qkv_ptr,
+ params.q_stride_in_bytes,
+ params.d,
+ params.dv,
+ params.h,
+ qkv_offset,
+ binfo,
+ tidx,
+ /*num_kv_heads=*/params.h, // or Params::num_kv_heads if/when available
+ cta_row_offset,
+                    cta_col_offset_in_bytes) {}

If some Params type does expose a distinct KV-head count (e.g., params.h_kv), it should be threaded through here instead of params.h.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h around lines 133-139, the forwarding
ctor passes cta_row_offset and cta_col_offset_in_bytes into the wrong parameters
of the main ctor (they are shifting into num_kv_heads and cta_row_offset),
corrupting head counts and row/column indexing; fix by explicitly supplying
num_kv_heads (use params.h_kv if available, otherwise params.h) as the sixth
argument and then pass qkv_offset, binfo, tidx, cta_row_offset,
cta_col_offset_in_bytes in that order so the main ctor receives (qkv_ptr,
q_stride_in_bytes, d, dv, h, num_kv_heads, qkv_offset, binfo, tidx,
cta_row_offset, cta_col_offset_in_bytes).
class Multiple_tma_descriptor {
 public:
  // ctor
  Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) {
    if (batch_size > 0) {
      // allocate host memory
      desc_ptr_h = new cudaTmaDesc[batch_size];
      // make sure all bit fields are zeros.
      memset(desc_ptr_h, 0, sizeof(cudaTmaDesc) * batch_size);
    }
  }

  // ctor
  Multiple_tma_descriptor() = default;

  // destructor.
  ~Multiple_tma_descriptor() {
    if (batch_size > 0) {
      // deallocate host memory
      delete[] desc_ptr_h;
    }
  }
🛠️ Refactor suggestion | 🟠 Major
Missing copy/move semantics (Rule of Five violation).
The class manages raw memory with new[]/delete[] but does not define copy constructor, copy assignment, move constructor, or move assignment operators. If an instance is copied, the default copy will lead to double-free on destruction.
Either delete copy/move operations or implement them properly:
class Multiple_tma_descriptor {
public:
+ // Prevent accidental copies that would cause double-free
+ Multiple_tma_descriptor(Multiple_tma_descriptor const&) = delete;
+ Multiple_tma_descriptor& operator=(Multiple_tma_descriptor const&) = delete;
+ Multiple_tma_descriptor(Multiple_tma_descriptor&&) = delete;
+ Multiple_tma_descriptor& operator=(Multiple_tma_descriptor&&) = delete;
+
// ctor
  Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) {
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_descriptor.h around lines 26-47, the class
Multiple_tma_descriptor manages raw memory (desc_ptr_h) but lacks copy/move
special members causing double-free; fix by implementing the Rule of Five:
ensure the default constructor initializes batch_size to 0 and desc_ptr_h to
nullptr, implement a copy constructor and copy-assignment that perform a deep
copy (allocate new cudaTmaDesc[batch_size], memcpy/set contents) and handle
self-assignment, implement a move constructor and move-assignment that transfer
desc_ptr_h and batch_size from the source and set the source's desc_ptr_h to
nullptr and batch_size to 0, and keep the existing destructor to delete[]
desc_ptr_h.
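If deep copies are actually needed rather than deleted operations, a minimal sketch of the Rule-of-Five route described in the prompt could look like the following (member names batch_size and desc_ptr_h come from the snippet above; the copy-and-swap assignment is an illustrative choice, not the committed fix, and it assumes std::swap and memcpy are available):

// Sketch: deep copy, ownership transfer on move, copy-and-swap assignment.
Multiple_tma_descriptor(Multiple_tma_descriptor const& other)
    : batch_size(other.batch_size),
      desc_ptr_h(other.batch_size > 0 ? new cudaTmaDesc[other.batch_size] : nullptr) {
  if (batch_size > 0) {
    memcpy(desc_ptr_h, other.desc_ptr_h, sizeof(cudaTmaDesc) * batch_size);
  }
}

Multiple_tma_descriptor(Multiple_tma_descriptor&& other) noexcept
    : batch_size(other.batch_size), desc_ptr_h(other.desc_ptr_h) {
  other.batch_size = 0;        // source no longer owns the buffer,
  other.desc_ptr_h = nullptr;  // so its destructor will not delete it
}

Multiple_tma_descriptor& operator=(Multiple_tma_descriptor other) noexcept {
  // 'other' is a copy (or moved-from source), so one swap covers both assignments.
  std::swap(batch_size, other.batch_size);
  std::swap(desc_ptr_h, other.desc_ptr_h);
  return *this;
}

This also assumes the default constructor initializes batch_size to 0 and desc_ptr_h to nullptr, as the prompt notes.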
  // set the desctriptor.
  int set_tma_desctriptor(
      // ptr to gmem
      void const* gmem_ptr,
      // format is really data_type in TMA terminology.
      cudaTmaDescFormat format,
      // interleave mode.
      cudaTmaDescInterleave interleave,
      // swizzle mode.
      cudaTmaDescSwizzle swizzle,
      // L2 sector promotion.
      cudaTmaDescPromotion promotion, uint32_t const (&tensor_size_array)[NUM_DIMS],
      uint64_t const (&tensor_stride_array)[NUM_DIMS - 1],
      uint32_t const (&traversal_stride_array)[NUM_DIMS],
      uint32_t const (&box_size_array)[NUM_DIMS],
      // OOB fill mode.
      uint32_t fill_oob,
      // FP32 to TF32 conversion.
      uint32_t round_to_tf32,
      // index to desc.
      cudaTmaDesc* desc_ptr = nullptr) {
    set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));
    set_tensor_common_1(desc_ptr, TILED, NUM_DIMS, format, interleave, swizzle, fill_oob,
                        round_to_tf32, promotion);

    set_tensor_stride(desc_ptr, tensor_stride_array);
    set_tensor_size(desc_ptr, tensor_size_array);

    set_traversal_stride_tiled(desc_ptr, traversal_stride_array);

    set_box_size(desc_ptr, box_size_array);
    return 0;
  }
Null pointer dereference risk in set_tma_desctriptor overload.
The second overload of set_tma_desctriptor (line 84-115) has desc_ptr = nullptr as the default parameter value. If called without providing a valid pointer, all the set_* helper functions will dereference nullptr, causing undefined behavior.
Add a null check or remove the default value:
int set_tma_desctriptor(
// ... parameters ...
- cudaTmaDesc* desc_ptr = nullptr) {
+ cudaTmaDesc* desc_ptr) {
+ assert(desc_ptr != nullptr);
  set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_descriptor.h around lines 83 to 115, the
overload of set_tma_desctriptor uses desc_ptr = nullptr by default which allows
calling without a valid descriptor and leads to null dereference when set_*
helpers are invoked; either remove the default so callers must supply a valid
cudaTmaDesc* or add an early null check at the top (e.g., if (desc_ptr ==
nullptr) return a non-zero error code or assert) and avoid calling any set_*
helpers when desc_ptr is null, keeping the function’s existing return
convention.
  // copy the desc to device memory
  void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
    FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
                               cudaMemcpyHostToDevice));
  }
Unused stream parameter and synchronous copy.
The copy_to_device function accepts a cudaStream_t stream parameter but uses synchronous cudaMemcpy instead of cudaMemcpyAsync. This is misleading and the stream is silently ignored.
Either use the stream parameter with async copy or remove it:
// copy the desc to device memory
- void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
- FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
- cudaMemcpyHostToDevice));
+ void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
+ FMHA_CHECK_CUDA(cudaMemcpyAsync(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
+ cudaMemcpyHostToDevice, stream));
 }
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_descriptor.h around lines 117 to 121, the
function copy_to_device accepts a cudaStream_t stream but performs a synchronous
cudaMemcpy (ignoring the stream); either replace cudaMemcpy with cudaMemcpyAsync
using the provided stream and keep the parameter (propagate error via
FMHA_CHECK_CUDA and ensure memory size is correct), or remove the stream
parameter from the function signature and callers if asynchronous behavior is
not needed—pick one approach and update all call sites and tests accordingly.
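One caveat on the async option: cudaMemcpyAsync can only be truly asynchronous and overlap with other work when the host buffer is page-locked; with the pageable new[] allocation above, the runtime may fall back to a synchronous copy. A small illustrative sketch of a pinned host allocation (names reuse the snippet above; this is not part of the PR):

// Sketch: pinned host-side descriptor storage so the async copy can overlap.
cudaTmaDesc* desc_ptr_h = nullptr;
FMHA_CHECK_CUDA(cudaMallocHost(reinterpret_cast<void**>(&desc_ptr_h),
                               sizeof(cudaTmaDesc) * batch_size));
memset(desc_ptr_h, 0, sizeof(cudaTmaDesc) * batch_size);
// ... fill the descriptors, then copy_to_device(desc_ptr_d, stream) ...
FMHA_CHECK_CUDA(cudaFreeHost(desc_ptr_h));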
template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    int const trans_b = TB ? 1 : 0;
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
        "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
        : "+r"(acc[0]), "+r"(acc[1])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
#endif
  }
Undeclared desc_a and mismatched operands in Hgmma_rfa_fp16<8, TB>::mma
In the RF-A FP16 specialization for N = 8:
template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
int const trans_b = TB ? 1 : 0;
asm volatile(
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
"{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
: "+r"(acc[0]), "+r"(acc[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
#endif
}
};

Issues:

- desc_a is not a parameter and is never defined in this scope → compile-time error.
- The instruction form with A from RF should only consume B's descriptor plus the transpose flag, just like the 16/32/... RF-A variants.
This should be simplified to use only desc_b and trans_b, with the operand indices adjusted accordingly. For example:
template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
static inline __device__ void mma(uint32_t const (&a)[4],
- uint64_t desc_b, uint32_t (&acc)[2]) {
+ uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
int const trans_b = TB ? 1 : 0;
asm volatile(
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
"{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
- : "+r"(acc[0]), "+r"(acc[1])
- : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
+ : "+r"(acc[0]), "+r"(acc[1])
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
+ "l"(desc_b), "n"(trans_b));
#endif
}
};

This matches the operand pattern used in Hgmma_rfa_fp16<16, TB> and other RF-A specializations and removes the undeclared identifier.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/utils_hgmma.h around lines 440-452, the inline asm
uses an undeclared desc_a and supplies too many operands; change the asm to only
consume B's descriptor and the transpose flag (desc_b and trans_b), remove the
"l"(desc_a) operand, and adjust the input operand ordering/indices so the a[]
registers remain as inputs, acc[] remain read-write outputs, and the last two
inputs are "l"(desc_b) and "n"(trans_b) matching the RF-A pattern used by other
specializations (e.g., the <16,TB> variant). Ensure the operand list and
referenced %n indices in the asm string are updated accordingly.
#include <cuda_runtime_api.h>

#include <cstdint>

#pragma once
🛠️ Refactor suggestion | 🟠 Major
Move #pragma once to the top of the file.
The #pragma once directive should appear before any #include statements to ensure proper header guard behavior. Currently, it appears after the includes.
+#pragma once
+
#include <cuda_runtime_api.h>
#include <cstdint>
-
-#pragma once
🧰 Tools
🪛 Clang (14.0.6)
[error] 13-13: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/numeric_types.h around lines 13 to 17, the #pragma once
header guard is placed after the #include directives; move the #pragma once line
to the very top of the file (before any #include or other preprocessor
directives) so the header guard applies immediately, preventing multiple
inclusion issues.
Actionable comments posted: 14
♻️ Duplicate comments (18)
csrc/fmha_v2/fmha/alibi_params.h (1)
29-32: Division by zero issue already flagged.This constructor can cause division by zero when
h=0, asround_down_to_power_two(0)returns0. This issue has already been flagged in a previous review with a suggested fix (adding an assertion or guard).csrc/fmha_v2/fmha/hopper/tma_descriptor.h (4)
26-47: Missing copy/move semantics (Rule of Five violation).The class manages raw memory with
new[]/delete[]but lacks copy constructor, copy assignment, move constructor, and move assignment operators. Copying an instance would cause double-free.Add deleted or properly implemented copy/move operations:
class Multiple_tma_descriptor { public: + // Prevent accidental copies + Multiple_tma_descriptor(Multiple_tma_descriptor const&) = delete; + Multiple_tma_descriptor& operator=(Multiple_tma_descriptor const&) = delete; + Multiple_tma_descriptor(Multiple_tma_descriptor&&) = delete; + Multiple_tma_descriptor& operator=(Multiple_tma_descriptor&&) = delete; + // ctor Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) {
83-115: Null pointer dereference risk with defaultdesc_ptr = nullptr.The second
set_tma_desctriptoroverload defaultsdesc_ptrtonullptr. If called without providing a valid pointer, allset_*helpers will dereference nullptr, causing undefined behavior.Remove the default or add validation:
int set_tma_desctriptor( // ... parameters ... - cudaTmaDesc* desc_ptr = nullptr) { + cudaTmaDesc* desc_ptr) { + if (!desc_ptr) return -1; set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));
117-121: Unusedstreamparameter with synchronous copy.The
copy_to_devicefunction accepts acudaStream_t streamparameter but uses synchronouscudaMemcpyinstead ofcudaMemcpyAsync. The stream parameter is silently ignored.Either use async copy or remove the stream parameter:
void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) { - FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size, - cudaMemcpyHostToDevice)); + FMHA_CHECK_CUDA(cudaMemcpyAsync(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size, + cudaMemcpyHostToDevice, stream)); }
49-50: Typo: "desctriptor" should be "descriptor".The method name contains a typo at line 50 and line 84.
Consider renaming for consistency:
- int set_tma_desctriptor( + int set_tma_descriptor(csrc/fmha_v2/fmha/numeric_types.h (1)
13-17: Move#pragma oncebefore includes.The
#pragma oncedirective should appear at the very top of the file, before any#includestatements, to ensure the header guard takes effect immediately.csrc/fmha_v2/convert.cu (3)
19-54: Tail elements are not converted whennis not divisible by 4.The kernel only processes
n/4iterations, leaving 1-3 trailing elements unconverted whenn % 4 != 0.
46-49: Use ofreinterpret_castfor extracting int8 from PTX result.The
cvt.rni.sat.s8.f32instruction writes to the lower byte of the 32-bit register. Usingreinterpret_casthere is fragile. Consider usingstatic_cast<int8_t>(a)instead.
181-184: Integer overflow risk in dimension multiplication.The expression
s * b * h * dperformsintmultiplication before conversion tosize_t, risking overflow for large tensors.csrc/fmha_v2/fmha/hopper/arrive_wait.h (2)
52-58: Missingdefined()guard and non-standardCUDACC_VERSIONmacro.Line 52 uses
__CUDA_ARCH__without adefined()check, and line 56 usesCUDACC_VERSIONwhich is not a standard CUDA macro.
153-189: Potential uninitialized predicate registerP3whenpredis false.When
predis 0, thembarrier.try_waitinstruction is skipped due to the@P2predicate, butselp.b32 %0, 1, 0, P3still readsP3. SinceP3was never set when the instruction was skipped, its value is undefined.csrc/fmha_v2/fmha/hopper/fragment.h (1)
133-139: Inconsistent use of NUM_REGS vs NUM_ELTS in bf16 SMEM-SMEM accumulator add().This
add()method usesBase::NUM_REGS(line 136), but the RF-SMEM variant at line 203 usesBase::NUM_ELTS. Since this is a float-based accumulator accessing viaelt(), it should useNUM_ELTSfor consistency and correctness.template <typename Other_fragment_> inline __device__ void add(Other_fragment_ const& other) { - for (int ii = 0; ii < Base::NUM_REGS; ++ii) { + for (int ii = 0; ii < Base::NUM_ELTS; ++ii) { this->elt(ii) = this->elt(ii) + other.elt(ii); } }csrc/fmha_v2/fmha/hopper/kernel_traits.h (1)
223-227: Typo: "row marjor" should be "row major".- // We know V is row marjor. So we can also deduce the descriptor mode. + // We know V is row major. So we can also deduce the descriptor mode.csrc/fmha_v2/fmha/hopper/compute_tile.h (1)
160-178: Wrong descriptor type used inincrement_gmma_desc_a_group()loop bound.Line 164 uses
Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORSbut this method operates ongmma_desc_a_. This should useSmem_tile_a::Gmma_descriptor::NUM_DESCRIPTORSto ensure the iteration count matches the A-side descriptors.inline __device__ void increment_gmma_desc_a_group() { bool reset_buffer = gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0(); #pragma unroll - for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { + for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) { #pragma unroll for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) {csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1)
278-291: Inconsistent use ofMma_tile::VALID_MMAS_NvsBase::VALID_MMAS_N.Line 279 uses
Mma_tile::VALID_MMAS_Nwhile line 294 usesBase::VALID_MMAS_N. The fp16 and fp32 variants consistently useBase::VALID_MMAS_N. This inconsistency could cause compilation errors ifMma_tiledoesn't defineVALID_MMAS_N.#pragma unroll - for (int mma_ni = 0; mma_ni < Mma_tile::VALID_MMAS_N - 1; ++mma_ni) { + for (int mma_ni = 0; mma_ni < Base::VALID_MMAS_N - 1; ++mma_ni) {csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1)
133-138: Bug inGmem_tile_qkvforwarding constructor: missingnum_kv_headsargument.The generic
Paramsconstructor forwards arguments but omits thenum_kv_headsparameter, causingcta_row_offsetto be passed asnum_kv_headsandcta_col_offset_in_bytesto becomecta_row_offset. This corrupts GQA/MQA head indexing and row offset calculations.Compare to the correct
bert::Fused_multihead_attention_params_v2constructor at lines 125-131 which properly passesparams.h_kv. Apply:template <typename Params, typename Block_info> inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo, int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0) : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h, - qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {} + qkv_offset, binfo, tidx, params.h, cta_row_offset, cta_col_offset_in_bytes) {}Note: If
Paramshas a distinct KV head count (e.g.,params.h_kv), use that instead ofparams.h.csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1)
1277-1279: Hardcoded division assumesNUM_SLICES == 2.This issue was already raised in a previous review. Line 1279 uses
rows_so_far / 2which assumes exactly 2 slices, but line 1187 allowsNUM_SLICES == 1.csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1)
440-452: Undeclareddesc_ain inline asm operand list.This issue was already raised in a previous review. The inline asm references
desc_awhich is not a parameter and is never defined in this scope, causing a compile-time error on SM90 builds.
🧹 Nitpick comments (31)
csrc/fmha_v2/fmha/paged_kv_cache.h (2)
48-58: Consider more robust handling ofmTokensPerBlockLog2instead of FPlog2Right now
mTokensPerBlockLog2is derived viafloat tokensPerBlockSeqLog2 = log2(mTokensPerBlock);and anintcast. This assumes:
tokensPerBlock > 0, and- it’s an exact power of two so FP rounding can’t bite.
Given the comment “It must be power of 2”, you might harden this by either:
- using an integer‑based computation (e.g.,
std::countr_zero/__builtin_ctzonuint32_t(tokensPerBlock)), and optionally- adding a debug‑time assertion that
tokensPerBlockis a power of two (e.g.,(x & (x-1)) == 0).That keeps behavior well‑defined even if a misconfigured value slips through, without affecting the ABI/layout of the struct.
15-16:cuda_runtime.hstatic‑analysis error is likely a tooling configuration issueThe reported “
cuda_runtime.hfile not found” from clang usually means the analysis environment isn’t pointed at the CUDA toolkit includes, not that this include is wrong. The header itself is standard for CUDA‑related code and seems appropriate here.If this is breaking your CI linting, you probably want to:
- ensure your
compile_commands.json/clang‑tidy config passes the correct--cuda-pathor-Ito the CUDA include directory, or- gate CUDA‑dependent checks to only run where the toolkit is available.
No code change in this file seems necessary for that error.
If you’d like, I can suggest a small script to probe your current
compile_commands.jsonand locate places where CUDA include paths are missing so you can wire them into your static‑analysis config.csrc/fmha_v2/fmha/alibi_params.h (1)
18-25: Document preconditions or validate input range.The function correctly rounds down to the nearest power of two for positive integers, but has undefined behavior for negative inputs due to signed right shift. While negative head counts are semantically invalid, adding documentation or a debug assertion would prevent misuse.
Consider adding a precondition comment or assertion:
constexpr static int round_down_to_power_two(int x) { + // Precondition: x must be non-negative x = x | (x >> 1); x = x | (x >> 2); x = x | (x >> 4); x = x | (x >> 8); x = x | (x >> 16); return x - (x >> 1); }Alternatively, note that
softmax.huses a different approach with__clzfor similar rounding—consider aligning implementations for consistency.csrc/fmha_v2/fmha/smem_tile.h (3)
222-246: Track deprecated function removal.Multiple functions are marked with TODO comments indicating they should be removed:
move_next_read_buffer()(line 223)move_next_read_buffer(int N)(line 234)move_next_write_buffer()(line 246)These wrapper functions delegate to their preferred equivalents (
move_to_next_*). Consider creating an issue to track their removal in a future cleanup.Would you like me to open an issue to track the removal of these deprecated functions?
591-604: Consider extracting repeated XOR offset update pattern.This cascading if-else pattern for updating
smem_read_offset_based onkiandMMAS_Kthresholds is duplicated across multiple tile implementations (Volta, Turing, Ampere variants for both A and B tiles).The pattern appears at lines 592-603, 757-768, 982-992, 1265-1275, 1431-1441, 1664-1674, and more.
Consider extracting this into a helper template function to reduce duplication:
template <int MMAS_K_WITH_PADDING, int BYTES_PER_LDS, int SCALE = 1> inline __device__ void update_smem_read_offset_xor(int& offset, int ki) { static_assert(MMAS_K_WITH_PADDING < 64, "Not implemented"); if (MMAS_K_WITH_PADDING >= 32 && ki % 16 == 15) { offset ^= 31 * BYTES_PER_LDS * SCALE; } else if (MMAS_K_WITH_PADDING >= 16 && ki % 8 == 7) { offset ^= 15 * BYTES_PER_LDS * SCALE; } else if (MMAS_K_WITH_PADDING >= 8 && ki % 4 == 3) { offset ^= 7 * BYTES_PER_LDS * SCALE; } else if (MMAS_K_WITH_PADDING >= 4 && ki % 2 == 1) { offset ^= 3 * BYTES_PER_LDS * SCALE; } else if (MMAS_K_WITH_PADDING >= 2) { offset ^= 1 * BYTES_PER_LDS * SCALE; } }
1993-1998: Clarify or remove commented-out static_assert.This commented-out static assertion suggests there may be untested or unsupported configurations. If these constraints are no longer relevant, remove the dead code. If they're still important but temporarily disabled, add a comment explaining why.
- // static_assert(BYTES_PER_MMA_PER_CTA >= 128 || - // BYTES_PER_MMA_PER_CTA == 64 || - // (BYTES_PER_MMA_PER_CTA == 32 && - // (Mma_tile::MMAS_M == 4 || - // Mma_tile::MMAS_M == 2 || - // Mma_tile::MMAS_M == 1)), "");Either remove entirely or add a clarifying comment like:
// TODO: Re-enable this static_assert once all configurations are validated // static_assert(...);csrc/fmha_v2/fmha/gmem_tile_qkv.h (1)
137-147: Consider documenting theis_active_predicate usage pattern.The
load()overload (lines 138-147) usesis_active_to guard the load, while the otherload(Smem_tile&)overload (lines 124-135) does not. This asymmetry may be intentional for different use cases, but a brief comment explaining when each overload should be used would improve maintainability.csrc/fmha_v2/fmha/mask.h (1)
269-352: Noted: FIXME comment indicates incomplete support.The MtpMask implementation has a FIXME on line 346 noting that Volta/Hopper-GMMA traits are not yet supported for the
mtp_token_idx_array sizing. This should be tracked for future work.Would you like me to open an issue to track adding Volta/Hopper-GMMA trait support for MtpMask?
csrc/fmha_v2/fmha/gmem_tile_ps.h (1)
277-312: Remove commented-out debug code.Lines 303-306 contain commented-out debug code that should be cleaned up before merging.
Apply this diff to remove the debug code:
for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) { uint32_t acc_0 = fmha::hmul2(acc.reg(col_idx * ROWS_PER_THREAD + row_idx), scale); - // float one = 1.f; - // if(col_idx > 2){ - // acc_0 = float2_to_half2(one, one); - // } int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n; fmha::stg(ptr + offset, acc_0); }csrc/fmha_v2/fmha/numeric_types.h (1)
36-39: Fallbacke4m3_t/e5m2_tascharmay cause silent failures.When
FMHA_CUDA_SUPPORTS_FP8is not defined, these types becomechar. Any code performing arithmetic or conversions on FP8 types will compile but produce incorrect results. Consider using a static_assert or a distinct placeholder type that fails at compile time if FP8 operations are attempted on unsupported platforms.csrc/fmha_v2/fmha/hopper/arrive_wait.h (1)
373-378: Unused member variableid_.The
id_member is declared and assigned in the constructor but never read. Consider removing it if unused, or document its intended purpose.private: // smem barrier base pointer uint64_t* bar_base_; - // barrier id - int id_; };csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (2)
128-142: Several data members appear unused.The members
preds_,fetch_, androw_are declared but never read or written in this class. They may be vestigial from copy-paste of the non-TMAGmem_tile_qkvclass. Consider removing them to reduce memory footprint and improve clarity.// The stride between rows for the QKV matrice. int64_t params_qkv_stride_in_bytes_; // The pointer. char* qkv_ptr_; - // The register to store predicates. - uint32_t preds_[PRED_REGS]; - // The fetch registers. - uint4 fetch_[LDGS]; - // Keep track of the row the thread is processing as we move the tile. - int row_; // The sequence length. int actual_seqlen_; // tma descriptor cudaTmaDesc const* p_desc_; // coord use by TMA. For now hard code to 3D. int32_t coord[3]; };
107-118: Emptycommit()andstore()methods may need documentation.These methods are empty, which is likely intentional for the TMA path (TMA handles commit/store differently). Adding a brief comment would clarify that this is by design and not a missing implementation.
// Store data to shared memory. + // Note: No-op for TMA path; TMA handles commit internally. template <typename Smem_tile> inline __device__ void commit(Smem_tile& smem_tile) {} - // Load data from memory. + // Load data from memory via TMA. template <typename Smem_tile> inline __device__ void load(Smem_tile& smem_tile) { smem_tile.template store<TMA_DIMS, TMA_DESC_TYPE>(p_desc_, coord); } - // Store data to memory. + // Store data to memory. No-op for TMA read-only path. inline __device__ void store(uint4 const (&data)[LDGS]) {}csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
92-96: Redundant runtime check with dead code path.The
if (Mma_tile::MMAS_N == 1)check at line 92 is always true due to thestatic_assert(Mma_tile::MMAS_N == 1)at line 58. Theelsebranch withassert(false)is unreachable dead code.If this is intentional scaffolding for future multi-MMA-N support, consider adding a comment. Otherwise, simplify:
- if (Mma_tile::MMAS_N == 1) { - this->smem_write_ ^= 64; - } else { - assert(false && "Unsupported"); - } + // MMAS_N == 1 enforced by static_assert above + this->smem_write_ ^= 64;csrc/fmha_v2/fmha/hopper/fragment.h (1)
369-384: Consider usingif constexprfor compile-time type dispatch.The type dispatch using
std::is_same_vwill be evaluated at compile time, and the compiler should optimize away unreachable branches. However, usingif constexprwould make the compile-time nature explicit and guarantee dead code elimination:if constexpr (std::is_same_v<Input_type_A, e4m3_t> && std::is_same_v<Input_type_B, e4m3_t>) { // ... } else if constexpr (...) { // ... }This is a minor stylistic improvement that makes intent clearer.
csrc/fmha_v2/fmha/gmem_tile_o.h (1)
282-288: Clarify USE_DEMO_BERT_PARAMS macro purpose.The preprocessor conditional for
USE_DEMO_BERT_PARAMSsetso_scratch_ptr_to nullptr in one path. Consider adding a brief comment explaining when this macro is defined and why the scratch pointer behavior differs, to aid future maintainability.csrc/fmha_v2/fmha/hopper/smem_tile.h (2)
377-383: Commented-out buffer management code.The
move_next_write_buffer()implementations are commented out in multiple places (lines 377-383, 538-544). If this functionality is not needed, remove the commented code. If it's planned for future use, add a TODO comment explaining when/why it will be needed.
1326-1485: Large commented-out block in load_and_store method.The
load_and_storemethod (lines 1326-1485) contains ~160 lines of commented-out transpose logic. This significantly impacts readability. Consider:
- Removing if the code is obsolete
- Moving to a separate function with a clear TODO if planned for future implementation
- Adding a brief comment explaining why it's preserved if there's a specific reason
csrc/fmha_v2/fmha/hopper/compute_tile.h (2)
273-277: Commented-out static_asserts and debug artifact.Lines 273-277 contain commented-out
static_assertstatements (including a typopstatic_assert). These appear to be debug artifacts that should either be removed or properly enabled.- // static_assert(Cta_tile::K == 128); - // static_assert(Mma_tile::K_PER_MMA_PER_CTA == 64 ); - // pstatic_assert(NUM_KBLOCKS == 384 / 64); static constexpr int NUM_KBLOCKS = Smem_tile_b::BUFFERS_PER_TILE / Cta_tile::WARPS_K; - // static_assert(NUM_KBLOCKS * Cta_tile::WARPS_K == Smem_tile_b::BUFFERS_PER_TILE);
488-491: Outdated comment regarding fragment declaration.The comment at line 489 asks "is is better to declare as Fragment a_?" which appears to be a leftover design note. Consider removing or updating this comment.
// The fragments to load A. - // Need to think about is is better to declare as Fragment a_? - // for the second GEMM, MMAS_M is most likely 1. (at least for now. ) Fragment a_[MMAS_M];csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (3)
119-132: Complex nested ternary forBYTES_PER_DESCis hard to maintain.The nested ternary expression spanning lines 119-132 is difficult to read and verify. Consider refactoring to a constexpr helper function or using
if constexprfor clarity.Consider restructuring as:
static constexpr uint32_t compute_bytes_per_desc() { if constexpr (Gmma_vector_size == Gmma_descriptor_size::ALL) return 0; if constexpr (Gmma_trans == Gmma_descriptor_transpose::TRANS) { if constexpr (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) return GMMA_K * BYTES_PER_LEADING_DIM; // ... etc } // ... } static constexpr uint32_t BYTES_PER_DESC = compute_bytes_per_desc();
260-268:increment_single_descriptormodifiesdesc[0]directly without bounds check.The method assumes
NUM_DESCRIPTORS >= 1. While the static_assert at line 88 enforcesGmma_vector_size == ONE, this assumption should be documented or an assert added for defensive coding.inline __device__ void increment_single_descriptor(bool last_of_kblock) { + static_assert(NUM_DESCRIPTORS >= 1, "At least one descriptor required"); // update smem start address, which is in lower 32bits. int2& tmp = reinterpret_cast<int2&>(desc[0]);
514-538: Three-argumentincrement_single_descriptoroverload has complex branching.The overload at lines 516-538 with
last_of_kblockandswitch_kblockparameters has nested conditionals that are difficult to follow. Consider adding a brief comment explaining the state machine or expected call sequences.+ // Advances descriptor address based on k-block traversal state: + // - switch_kblock: transitioning to a new k-block group + // - last_of_kblock: at the boundary requiring reset within group inline __device__ void increment_single_descriptor(bool last_of_kblock, bool switch_kblock) {csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (2)
151-152: Unused member variableis_active_for_last_stg_.The member
is_active_for_last_stg_at line 151 is declared but never assigned or used in this class or its derived classes in this file.- // Is the thread active for the last STG? - int is_active_for_last_stg_;
879-880: Multiple unused member variables.
is_active_for_last_stg_(line 880) is never usedparams_enable_i2f_trick_(line 899) is declared asconst bool = falsebut never referenced- // Is the thread active for the last STG? - int is_active_for_last_stg_; ... - bool const params_enable_i2f_trick_ = false;Also applies to: 899-899
csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (3)
539-539: Inconsistent type forcol_in_bytes_across tile classes.In
Gmem_tile_q_k_v,col_in_bytes_is declared asint64_t(line 539), but inGmem_tile_qkv(line 331) it'sint. This inconsistency could cause issues if these types are used interchangeably or in template contexts.Consider unifying the type to
int64_tacross all tile classes for consistency and to avoid potential overflow with large column offsets:// In Gmem_tile_qkv (line 331): - int col_in_bytes_; + int64_t col_in_bytes_;
860-861:past_seqlen_calculated but never used.In
Gmem_tile_paged_kv,past_seqlen_is calculated at line 861 and stored at line 975 but is never referenced in any method.If this is planned for future sliding window support, add a comment. Otherwise, remove the unused member:
- past_seqlen_(binfo.actual_seqlen - binfo.actual_q_seqlen), ... - // The past sequence length (kv_seqlen - q_seqlen) considering chunked context. - int past_seqlen_;
976-977:sliding_window_size_stored but never used.Similar to
past_seqlen_, thesliding_window_size_member is initialized but never used inGmem_tile_paged_kv. The template parameterSLIDING_WINDOW_ATTENTION_is available but the runtime value appears unused.csrc/fmha_v2/fmha/kernel_traits.h (1)
266-271: Consider documenting buffer count logic.The buffer count logic at lines 266-271 has non-trivial conditions (e.g.,
USE_GRANULAR_TILING && D > 64). A brief comment explaining when 2 buffers are needed vs 1 would improve maintainability.+ // Use 2 buffers for software pipelining when: + // - Q: granular tiling with large D, or LDGSTS with looping + // - K/V: granular tiling only enum { BUFFERS_PER_TILE_SMEM_Q = (USE_GRANULAR_TILING && D > 64) || (USE_LDGSTS_Q && !NO_LOOP) ? 2 : 1 };csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)
800-800: Unused variablerow_ptr.This variable is defined but never used in the
storefunction.Stg_packed_type dst = Acc_packer<Src_type, Dst_type>::run(this, src[ii]); - float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);
1119-1119: Unused variablerow_ptr.Same issue as in
Gmem_tile_o_8bit::store- this variable is defined but never used.Stg_packed_type dst = Acc_packer<Src_type, Output_type>::run(this, src[ii]); - float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (31)
csrc/fmha_v2/convert.cu(1 hunks)csrc/fmha_v2/fmha/alibi_params.h(1 hunks)csrc/fmha_v2/fmha/fragment.h(1 hunks)csrc/fmha_v2/fmha/gemm.h(1 hunks)csrc/fmha_v2/fmha/gmem_tile_o.h(1 hunks)csrc/fmha_v2/fmha/gmem_tile_o_packed.h(1 hunks)csrc/fmha_v2/fmha/gmem_tile_ps.h(1 hunks)csrc/fmha_v2/fmha/gmem_tile_qkv.h(1 hunks)csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h(1 hunks)csrc/fmha_v2/fmha/hopper/arrive_wait.h(1 hunks)csrc/fmha_v2/fmha/hopper/compute_tile.h(1 hunks)csrc/fmha_v2/fmha/hopper/fragment.h(1 hunks)csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h(1 hunks)csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h(1 hunks)csrc/fmha_v2/fmha/hopper/gmma_descriptor.h(1 hunks)csrc/fmha_v2/fmha/hopper/kernel_traits.h(1 hunks)csrc/fmha_v2/fmha/hopper/smem_tile.h(1 hunks)csrc/fmha_v2/fmha/hopper/smem_tile_o.h(1 hunks)csrc/fmha_v2/fmha/hopper/tma_descriptor.h(1 hunks)csrc/fmha_v2/fmha/hopper/tma_types.h(1 hunks)csrc/fmha_v2/fmha/hopper/utils_gmma.h(1 hunks)csrc/fmha_v2/fmha/hopper/utils_hgmma.h(1 hunks)csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h(1 hunks)csrc/fmha_v2/fmha/hopper/utils_igmma.h(1 hunks)csrc/fmha_v2/fmha/hopper/utils_tma.h(1 hunks)csrc/fmha_v2/fmha/hopper/utils_warpgroup.h(1 hunks)csrc/fmha_v2/fmha/kernel_traits.h(1 hunks)csrc/fmha_v2/fmha/mask.h(1 hunks)csrc/fmha_v2/fmha/numeric_types.h(1 hunks)csrc/fmha_v2/fmha/paged_kv_cache.h(1 hunks)csrc/fmha_v2/fmha/smem_tile.h(1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/fmha_v2/fmha/gemm.hcsrc/fmha_v2/fmha/hopper/fragment.hcsrc/fmha_v2/convert.cucsrc/fmha_v2/fmha/hopper/kernel_traits.hcsrc/fmha_v2/fmha/hopper/compute_tile.hcsrc/fmha_v2/fmha/hopper/smem_tile_o.hcsrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.hcsrc/fmha_v2/fmha/gmem_tile_o_packed.hcsrc/fmha_v2/fmha/hopper/utils_hgmma.h
🧬 Code graph analysis (13)
csrc/fmha_v2/fmha/hopper/arrive_wait.h (3)
- csrc/fmha_v2/fmha/hopper/smem_tile.h (1): fmha (22-562)
- csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (2): bar_create (795-806), bar_wait (853-869)
- csrc/fmha_v2/fmha/warpspec/circular_buffer.h (10): int (53-53), int (56-56), int (59-63), int (66-69), int (71-79), int (82-86), int (129-131), int (134-137), int (148-148), int (203-203)
csrc/fmha_v2/fmha/alibi_params.h (1)
- csrc/fmha_v2/fmha/softmax.h (1): float (2543-2545)
csrc/fmha_v2/fmha/gmem_tile_qkv.h (2)
- csrc/fmha_v2/fmha/warpspec/compute.h (1): int (186-186)
- csrc/fmha_v2/fmha/utils.h (1): pack_predicates (1366-1370)
csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (4)
- csrc/fmha_v2/fmha/hopper/compute_tile.h (1): fmha (16-240)
- csrc/fmha_v2/fmha/hopper/smem_tile.h (1): fmha (22-562)
- csrc/fmha_v2/fmha/mask.h (2): int (295-301), int (303-309)
- csrc/fmha_v2/fmha/warpspec/epilogue.h (1): int (95-104)
csrc/fmha_v2/fmha/hopper/fragment.h (4)
- csrc/fmha_v2/fmha/fragment.h (17): fmha (20-182), add (694-698), add (820-824), add (865-869), void (284-289), void (292-294), void (297-299), void (308-313), void (316-318), void (398-403), void (560-577), void (592-628), void (643-655), void (700-704), void (741-753), void (826-830), void (871-875)
- csrc/fmha_v2/fmha/mask.h (2): int (295-301), int (303-309)
- csrc/fmha_v2/fmha/traits.h (4): public (264-264), public (356-359), public (433-436), public (481-493)
- csrc/fmha_v2/fmha/hopper/utils_qgmma.h (6): qgmma_rfa_e4m3_e4m3_fp32 (438-441), qgmma_rfa_e5m2_e4m3_fp32 (1550-1553), qgmma_rfa_e4m3_e5m2_fp32 (994-997), qgmma_rfa_e5m2_e5m2_fp32 (2082-2085), qgmma_e4m3_e5m2_fp32 (716-719), qgmma_e5m2_e5m2_fp32 (1804-1807)
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (3)
- csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (4): fmha (17-276), v2 (18-213), commit (217-221), load (225-243)
- csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1): fmha (16-121)
- csrc/fmha_v2/fmha/warpspec/compute.h (1): int (186-186)
csrc/fmha_v2/fmha/mask.h (2)
- csrc/fmha_v2/fmha/warpspec/compute.h (1): int (186-186)
- csrc/fmha_v2/fmha/traits.h (4): public (264-264), public (356-359), public (433-436), public (481-493)
csrc/fmha_v2/fmha/hopper/kernel_traits.h (2)
- csrc/fmha_v2/fmha/gmem_tile_qkv.h (1): fmha (15-162)
- csrc/fmha_v2/fmha/hopper/compute_tile.h (7): fmha (16-240), Compute_tile_with_gmma (51-51), Compute_tile_with_gmma (54-56), Compute_tile_with_gmma (58-81), Compute_tile_with_gmma (293-293), Compute_tile_with_gmma (296-298), Compute_tile_with_gmma (300-312)
csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (2)
- csrc/fmha_v2/fmha/gemm.h (1): fmha (18-35)
- csrc/fmha_v2/fmha/mask.h (2): int (295-301), int (303-309)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
- csrc/fmha_v2/fmha/smem_tile_o.h (16): Smem_tile_o_base_8bit_mma (1284-1360), store (250-286), store (695-697), store (702-704), store (727-729), Smem_tile_o (104-197), Smem_tile_o (691-691), Smem_tile_o (723-723), Smem_tile_o (753-753), Smem_tile_o (802-820), Smem_tile_o (957-975), Smem_tile_o (1452-1452), Smem_tile_o (1466-1466), Smem_tile_o (1480-1480), Smem_tile_o (1494-1494), Smem_tile_o (1508-1508)
csrc/fmha_v2/fmha/smem_tile.h (1)
- csrc/fmha_v2/fmha/warpspec/compute.h (1): int (186-186)
csrc/fmha_v2/fmha/kernel_traits.h (1)
- csrc/fmha_v2/fmha/warpspec/compute.h (1): int (186-186)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1)
- csrc/fmha_v2/fmha/gmem_tile_o.h (1): fmha (18-426)
🪛 Clang (14.0.6)
csrc/fmha_v2/fmha/hopper/utils_warpgroup.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gemm.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/arrive_wait.h
[error] 66-66: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 66-66: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_gmma.h
[error] 15-15: 'fmha/hopper/utils_hgmma.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/alibi_params.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_qkv.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/paged_kv_cache.h
[error] 15-15: 'cuda_runtime.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/numeric_types.h
[error] 13-13: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/gmma_descriptor.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_tma.h
[error] 15-15: 'fmha/hopper/tma_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/fragment.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_o.h
[error] 15-15: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/tma_descriptor.h
[error] 14-14: 'fmha/hopper/tma_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h
[error] 14-14: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_igmma.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/mask.h
[error] 15-15: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/kernel_traits.h
[error] 14-14: 'fmha/gmem_tile_qkv_packed.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/smem_tile.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/compute_tile.h
[error] 14-14: 'fmha/hopper/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h
[error] 15-15: 'fmha/smem_tile_o.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h
[error] 14-14: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h
[error] 14-14: 'fmha/gmem_tile_o_packed.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/tma_types.h
[error] 15-15: 'fmha/numeric_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/smem_tile.h
[error] 15-15: 'fmha/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/kernel_traits.h
[error] 15-15: 'fmha/alibi_params.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h
[error] 15-15: 'fmha/numeric_types.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/gmem_tile_ps.h
[error] 14-14: 'fmha/hopper/fragment.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/fragment.h
[error] 15-15: 'fmha/traits.h' file not found
(clang-diagnostic-error)
csrc/fmha_v2/fmha/hopper/utils_hgmma.h
[error] 15-15: unknown type name 'namespace'
(clang-diagnostic-error)
[error] 15-15: expected ';' after top level declarator
(clang-diagnostic-error)
| template <typename T> | ||
| __global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { | ||
| using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type; | ||
|
|
||
| // The step. | ||
| size_t step = (size_t)gridDim.x * blockDim.x; | ||
|
|
||
| // Iterate over the elements. | ||
| for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { | ||
| // Load 4 floats. | ||
| float4 tmp = reinterpret_cast<float4 const*>(src)[ii]; | ||
| // Scale. | ||
| tmp.x *= scale; | ||
| tmp.y *= scale; | ||
| tmp.z *= scale; | ||
| tmp.w *= scale; | ||
| // Convert to 4 Ts. | ||
| auto out = pack_float4<T>(tmp); | ||
|
|
||
| // Store. | ||
| reinterpret_cast<Dst*>(dst)[ii] = reinterpret_cast<Dst const&>(out); | ||
| } | ||
| } |
Same tail-handling issue exists in convert_fp32_to_T_kernel.
This kernel also only iterates over n / 4 packed groups, skipping the last 1-3 elements when n is not divisible by 4. The same fix (tail loop or alignment assertion) should be applied here.
Apply a similar fix as suggested for convert_int32_to_int8_kernel:
template <typename T>
__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;
+ // Note: n must be a multiple of 4. Caller must ensure proper alignment.
+ assert(n % 4 == 0);
// The step.
  size_t step = (size_t)gridDim.x * blockDim.x;
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| template <typename T> | |
| __global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { | |
| using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type; | |
| // The step. | |
| size_t step = (size_t)gridDim.x * blockDim.x; | |
| // Iterate over the elements. | |
| for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { | |
| // Load 4 floats. | |
| float4 tmp = reinterpret_cast<float4 const*>(src)[ii]; | |
| // Scale. | |
| tmp.x *= scale; | |
| tmp.y *= scale; | |
| tmp.z *= scale; | |
| tmp.w *= scale; | |
| // Convert to 4 Ts. | |
| auto out = pack_float4<T>(tmp); | |
| // Store. | |
| reinterpret_cast<Dst*>(dst)[ii] = reinterpret_cast<Dst const&>(out); | |
| } | |
| } | |
| template <typename T> | |
| __global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { | |
| using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type; | |
| // Note: n must be a multiple of 4. Caller must ensure proper alignment. | |
| assert(n % 4 == 0); | |
| // The step. | |
| size_t step = (size_t)gridDim.x * blockDim.x; | |
| // Iterate over the elements. | |
| for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { | |
| // Load 4 floats. | |
| float4 tmp = reinterpret_cast<float4 const*>(src)[ii]; | |
| // Scale. | |
| tmp.x *= scale; | |
| tmp.y *= scale; | |
| tmp.z *= scale; | |
| tmp.w *= scale; | |
| // Convert to 4 Ts. | |
| auto out = pack_float4<T>(tmp); | |
| // Store. | |
| reinterpret_cast<Dst*>(dst)[ii] = reinterpret_cast<Dst const&>(out); | |
| } | |
| } |
| template <typename T> | ||
| __global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) { | ||
| using Src = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type; | ||
|
|
||
| union { | ||
| Src raw; | ||
| T elt[4]; | ||
| } data; | ||
|
|
||
| // The step. | ||
| size_t step = (size_t)gridDim.x * blockDim.x; | ||
|
|
||
| // Iterate over the elements. | ||
| for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) { | ||
| // Load 4 floats. | ||
| data.raw = reinterpret_cast<Src const*>(src)[ii]; | ||
| float4 out; | ||
| // Scale. | ||
| out.x = float(data.elt[0]) * scale; | ||
| out.y = float(data.elt[1]) * scale; | ||
| out.z = float(data.elt[2]) * scale; | ||
| out.w = float(data.elt[3]) * scale; | ||
|
|
||
| // Store. | ||
| reinterpret_cast<float4*>(dst)[ii] = reinterpret_cast<float4 const&>(out); | ||
| } | ||
| } |
Same tail-handling issue exists in convert_T_to_fp32_kernel.
This kernel also skips trailing elements when n % 4 != 0. Add an assertion or tail-handling loop for consistency.
🤖 Prompt for AI Agents
In csrc/fmha_v2/convert.cu around lines 123-149 the convert_T_to_fp32_kernel
iterates only over n/4 packed groups and thus silently skips trailing elements
when n % 4 != 0; fix by handling the tail: after the existing loop that
processes groups of 4, compute remainder = n % 4 and the start index of the tail
= (n/4)*4, then have each thread process any remaining scalar elements whose
global index < n by loading the appropriate Src/element or byte-offset and
converting them to float* dst with the same scale, or alternately add an
assertion enforcing n % 4 == 0 if non-handling is intended; ensure all pointer
casts and bounds checks use the same element indexing as the main loop so no
out-of-bounds accesses occur.
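If callers cannot guarantee n % 4 == 0, a scalar tail loop is the other fix described above; a minimal sketch (simplified to typed pointers and assuming T is convertible to float; illustrative only, not the PR's kernel):

template <typename T>
__global__ void convert_T_to_fp32_with_tail_kernel(float* dst, T const* src, size_t n,
                                                   float scale = 1.f) {
  size_t step = (size_t)gridDim.x * blockDim.x;
  size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  // Main loop over full groups of 4 (same packed load/convert/store as the kernel above).
  for (size_t ii = tid; ii < n / 4; ii += step) {
    // ... vectorized path elided ...
  }
  // Scalar tail: the last n % 4 elements, starting at the first index not covered above.
  for (size_t ii = (n / 4) * 4 + tid; ii < n; ii += step) {
    dst[ii] = float(src[ii]) * scale;
  }
}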
| float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; | ||
| float curr_sum_[ROWS_PER_THREAD] = {0}; | ||
| float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; | ||
| ; | ||
| float prev_sum_[ROWS_PER_THREAD] = {0}; |
Array member initializers only set the first element.
In C++, = {-HUGE_VALF} only initializes the first element; remaining elements are value-initialized to 0.0f. If the parameterized constructor (lines 1678-1694) is called without subsequently calling the default constructor's initialization logic, the arrays will have incorrect values (first element -HUGE_VALF, rest 0.0f).
Also, line 1824 has a stray semicolon.
- float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
- float curr_sum_[ROWS_PER_THREAD] = {0};
- float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
- ;
- float prev_sum_[ROWS_PER_THREAD] = {0};
+ float curr_max_[ROWS_PER_THREAD];
+ float curr_sum_[ROWS_PER_THREAD];
+ float prev_max_[ROWS_PER_THREAD];
+ float prev_sum_[ROWS_PER_THREAD];
Then ensure both constructors call a common initialization helper, or initialize in the member initializer list of each constructor.
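A minimal sketch of such a shared initialization helper (member names taken from the snippet above; the surrounding updater class is assumed):

// Set every row's running statistics explicitly instead of relying on
// brace-initialization, which only gives element 0 the intended value.
inline __device__ void init_stats() {
#pragma unroll
  for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) {
    curr_max_[ii] = -HUGE_VALF;
    prev_max_[ii] = -HUGE_VALF;
    curr_sum_[ii] = 0.f;
    prev_sum_[ii] = 0.f;
  }
}
// Both constructors can then delegate to it, e.g.
//   inline __device__ Fragment_updater() { init_stats(); }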
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/fragment.h around lines 1821-1825, the array initializers
like "= {-HUGE_VALF}" only set the first element (others become 0.0f) and line
1824 contains a stray semicolon; update the code to explicitly initialize all
elements (e.g. use std::fill_n or std::array and fill(), or a loop) so curr_max_
and prev_max_ are all -HUGE_VALF and the sum arrays are all 0.0f, remove the
stray semicolon, and refactor the two constructors to call a single init helper
(or perform the full initialization in each constructor’s member initializer
list) to guarantee consistent initialization regardless of which constructor is
used.
| // Update scales. | ||
| float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; | ||
| float curr_sum_[ROWS_PER_THREAD] = {0}; | ||
| float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; | ||
| float prev_sum_[ROWS_PER_THREAD] = {0}; |
Same array initialization issue in Fragment_updater_ampere_fp32.
Same concern as Fragment_updater - brace initialization only sets the first array element.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/fragment.h around lines 2007-2011, the four local arrays
curr_max_, curr_sum_, prev_max_, prev_sum_ are incorrectly initialized with a
single braced value (which only sets element 0); replace the brace initializers
with explicit full initialization (e.g., a small for-loop setting all
ROWS_PER_THREAD entries to -HUGE_VALF or 0 respectively, or use
std::fill_n/std::fill) so every element is initialized correctly.
| float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; | ||
| float curr_sum_[ROWS_PER_THREAD] = {0}; | ||
| float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF}; | ||
| float prev_sum_[ROWS_PER_THREAD] = {0}; |
Same array initialization issue in Fragment_updater<Volta_hmma_fp16_16x16x16_traits>.
Same pattern with incomplete array initialization via brace initializers.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/fragment.h around lines 2215 to 2218, the brace
initializers only set the first array element leaving the rest uninitialized;
replace the brace initialization with explicit full initialization (e.g. use a
loop or std::fill_n/ std::fill to set all ROWS_PER_THREAD entries of curr_max_
and prev_max_ to -HUGE_VALF and curr_sum_ and prev_sum_ to 0) so every element
is initialized deterministically.
| // The 512 bit of descritptro for im2col mode. | ||
| typedef struct { | ||
| uint64_t tensor_common0; | ||
| uint32_t tensor_common1; | ||
|
|
||
| uint32_t tensor_stride_lower[4]; | ||
| uint32_t tensor_stride_upper; | ||
| uint32_t tensor_size[5]; | ||
| uint32_t traversal_stride_range_c; | ||
|
|
||
| uint32_t box_corner_dhw; | ||
| uint32_t range_ndhw; | ||
| } cudaTmaDescIm2Col; |
Minor typo in comment: "descritptro" should be "descriptor".
Line 100 has a typo in the comment.
-// The 512 bit of descritptro for im2col mode.
+// The 512 bit of descriptor for im2col mode.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| // The 512 bit of descritptro for im2col mode. | |
| typedef struct { | |
| uint64_t tensor_common0; | |
| uint32_t tensor_common1; | |
| uint32_t tensor_stride_lower[4]; | |
| uint32_t tensor_stride_upper; | |
| uint32_t tensor_size[5]; | |
| uint32_t traversal_stride_range_c; | |
| uint32_t box_corner_dhw; | |
| uint32_t range_ndhw; | |
| } cudaTmaDescIm2Col; | |
| // The 512 bit of descriptor for im2col mode. | |
| typedef struct { | |
| uint64_t tensor_common0; | |
| uint32_t tensor_common1; | |
| uint32_t tensor_stride_lower[4]; | |
| uint32_t tensor_stride_upper; | |
| uint32_t tensor_size[5]; | |
| uint32_t traversal_stride_range_c; | |
| uint32_t box_corner_dhw; | |
| uint32_t range_ndhw; | |
| } cudaTmaDescIm2Col; |
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_types.h around lines 100 to 112, fix the typo in
the comment "descritptro" by replacing it with "descriptor" so the comment reads
"The 512 bit descriptor for im2col mode."; leave the surrounding code and
formatting unchanged.
|
|
||
| namespace fmha { | ||
|
|
||
| inline __device__ uint32_t elect_one_sync(); |
elect_one_sync is declared but not defined.
This function is declared but has no definition in this file. This will cause a linker error if called. Either provide a definition or include the header where it's defined.
#!/bin/bash
# Search for elect_one_sync definition in the codebase
ast-grep --pattern 'inline __device__ uint32_t elect_one_sync() {
$$$
}'
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/utils_tma.h around line 20, elect_one_sync is
declared as inline __device__ uint32_t elect_one_sync(); but no definition is
provided, which will cause a linker error; fix by either (A) adding the inline
__device__ definition for elect_one_sync in this header (implement the function
body here so it remains header-only and compiled into device code), or (B)
include the header that contains the existing definition (add the appropriate
#include at top of this file) and remove the stray declaration to avoid
duplicate declarations; ensure the chosen definition is marked inline/__device__
and matches all call sites to prevent ODR/linker issues.
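If pulling in the real definition is not an option, a functional stand-in is easy to sketch. This is not the SM90 elect.sync-based implementation; it simply elects lane 0 of each warp (assuming lane 0 is active), which satisfies the contract of returning non-zero for exactly one thread per warp:

#include <cstdint>

inline __device__ uint32_t elect_one_sync() {
  // Stand-in only: the production version would typically use the elect.sync PTX instruction.
  return (threadIdx.x % 32) == 0 ? 1u : 0u;
}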
| inline __device__ void tmastg_arrive() { | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 | ||
| asm volatile("cp.async.bulk.commit_group;"); | ||
| #else | ||
| assert(false); | ||
| #endif | ||
| } | ||
|
|
||
| inline __device__ void tmastg_wait() { | ||
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 | ||
| asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory"); | ||
| #else | ||
| assert(false); | ||
| #endif | ||
| } |
assert(false) as a fallback becomes a silent no-op in release builds.
On non-SM90 architectures, tmastg_arrive() and tmastg_wait() call assert(false), which is removed in release builds (NDEBUG defined), leaving an empty function. This could lead to subtle bugs where TMA operations appear to succeed but do nothing.
Consider using a compile-time error or a runtime trap that isn't removed in release builds:
inline __device__ void tmastg_arrive() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("cp.async.bulk.commit_group;");
#else
- assert(false);
+ __trap(); // Always terminates, not stripped in release
#endif
}
inline __device__ void tmastg_wait() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory");
#else
- assert(false);
+ __trap();
#endif
}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| inline __device__ void tmastg_arrive() { | |
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 | |
| asm volatile("cp.async.bulk.commit_group;"); | |
| #else | |
| assert(false); | |
| #endif | |
| } | |
| inline __device__ void tmastg_wait() { | |
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 | |
| asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory"); | |
| #else | |
| assert(false); | |
| #endif | |
| } | |
| inline __device__ void tmastg_arrive() { | |
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 | |
| asm volatile("cp.async.bulk.commit_group;"); | |
| #else | |
| __trap(); // Always terminates, not stripped in release | |
| #endif | |
| } | |
| inline __device__ void tmastg_wait() { | |
| #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 | |
| asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory"); | |
| #else | |
| __trap(); | |
| #endif | |
| } |
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/utils_tma.h around lines 138-152, the current
fallback uses assert(false) which is removed in release builds leaving empty
no-op functions; change the fallback to a non-removable failure: either emit a
compile-time failure for non-SM90 builds (e.g., preprocessor #error when
__CUDA_ARCH__ is not >=900) to prevent building unsupported code, or replace
assert(false) with a runtime trap that is not stripped in release (e.g., call a
device-level abort/trap such as __builtin_trap() or an equivalent GPU trap
intrinsic) so that calls on unsupported architectures fail loudly at runtime.
Ensure the chosen approach is used in both tmastg_arrive() and tmastg_wait().
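For context, a sketch of the intended call pattern around these helpers (the store step is a placeholder comment, not a real function in this PR); with the release-build no-op fallback, both synchronization points would silently disappear on unsupported architectures:

// ... issue cp.async.bulk / TMA stores from SMEM to GMEM here ...
tmastg_arrive();  // commit everything issued so far into one bulk-async group
tmastg_wait();    // block until all committed groups have finished reading SMEM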
| template <typename T> | ||
| __host__ __device__ constexpr inline float Softmax_fp_quant_scale(); | ||
|
|
||
| template <> | ||
| __host__ __device__ constexpr inline float Softmax_fp_quant_scale<e4m3_t>() { | ||
| // Softmax has max output of 1.0, therefore we choose fp32-to-fp8 quantization scale as the | ||
| // largest power-of-2 below the e4m3 limit: | ||
| // 2^(floor(log2(E4M3_MAX / amax_exp_p))) = 2^(floor(log2(448 / 1))) = 2 ^ 8 | ||
| return 256.f; | ||
| } |
Based on my verification:
- Web search confirmed that FlashInfer officially supports the e5m2_t type (__nv_fp8_e5m2) as a quantized data type.
- Active issues exist around fp8_e5m2 support, KV cache usage, and vLLM integration, indicating e5m2_t is a real, actively-used feature.
- The code snippet shows that only the e4m3_t specialization of Softmax_fp_quant_scale exists, while the generic template is only declared.
The review comment's concern is valid: if Softmax_fp_quant_scale<e5m2_t>() is instantiated anywhere in the codebase (which is likely given e5m2_t is a supported type), it would fail to link since only the e4m3_t specialization is defined.
Add Softmax_fp_quant_scale specialization for e5m2_t.
The template is declared for generic T and specialized only for e4m3_t. Since FlashInfer supports e5m2_t as a quantized type, the missing specialization for e5m2_t will cause linker errors if instantiated. A similar quantization scale calculation is needed for the e5m2 format.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/numeric_types.h around lines 44 to 53, add a missing
specialization for Softmax_fp_quant_scale<e5m2_t>() similar to the e4m3_t case;
implement it to compute the fp32-to-fp8 quantization scale as the largest
power-of-two below the e5m2 maximum (use the project constant for E5M2_MAX or
the equivalent value) and return that as a float (i.e., mirror the comment/logic
used for e4m3_t but referencing E5M2_MAX and e5m2_t so instantiation/linking for
e5m2_t succeeds).
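A sketch of what the missing specialization could look like, mirroring the e4m3 reasoning and assuming the e5m2 maximum finite value of 57344 with the same amax_exp_p = 1 convention (illustrative only, not code from the PR):

template <>
__host__ __device__ constexpr inline float Softmax_fp_quant_scale<e5m2_t>() {
  // Largest power-of-2 below the e5m2 limit:
  // 2^(floor(log2(E5M2_MAX / amax_exp_p))) = 2^(floor(log2(57344 / 1))) = 2^15
  return 32768.f;
}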
| // Store to the tile in shared memory. | ||
| template <int N> | ||
| inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { | ||
| this->store(data, preds); | ||
| } |
Infinite recursion in store function.
This overload calls itself recursively without termination. The function signature matches itself, causing infinite recursion at runtime.
The intent appears to be calling the array-predicate overload at line 264. Apply this fix:
// Store to the tile in shared memory.
template <int N>
inline __device__ void store(Store_type const (&data)[N], uint32_t preds) {
- this->store(data, preds);
+ uint32_t tmp[1] = {preds};
+ this->store(data, tmp);
}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| // Store to the tile in shared memory. | |
| template <int N> | |
| inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { | |
| this->store(data, preds); | |
| } | |
| // Store to the tile in shared memory. | |
| template <int N> | |
| inline __device__ void store(Store_type const (&data)[N], uint32_t preds) { | |
| uint32_t tmp[1] = {preds}; | |
| this->store(data, tmp); | |
| } |
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/smem_tile.h around lines 270-274, the template overload
store(Store_type const (&data)[N], uint32_t preds) calls itself causing infinite
recursion; replace the self-call with a call to the array-predicate overload at
line 264 by converting the scalar preds into a small uint32_t array (e.g.
uint32_t tmp[1] = { preds };) and then call the overload that accepts a
predicate array: this->store(data, tmp); ensure the temporary array has
appropriate scope and matches the expected predicate-array parameter type.
| lse_ref = torch.logsumexp(logits, -1) # (batch, num_heads, seq_len) | ||
| # Transpose to match expected shape (batch, seq_len, num_heads) | ||
| lse_ref = lse_ref.transpose(1, 2) | ||
| p = torch.softmax(logits, dim=-1) |
Directly using lse_ref to get the softmax results may be better?
| descale = 256 | ||
| lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1] / descale) | ||
| else: | ||
| lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1]) |
Should this be done by the user outside the API? And what is lse used for outside of attention, e.g., for training?
I think we should do this inside the kernel.
yzh119
left a comment
LGTM overall, let's merge this one first and implement the in-kernel LSE calculation in a future PR.
/bot run
📌 Description
Porting over the trtllm fmhav2 library to support prefill cases.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Enhanced Support