
Conversation

Contributor

@jimmyzho jimmyzho commented Nov 25, 2025

📌 Description

Porting over the trtllm fmhav2 library to support prefill cases.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • INT8 quantization and FP8 (E4M3/E5M2) conversion utilities, plus broad packed 8/16‑bit output paths.
    • Hopper GMMA/TMA optimizations and SM90 GMMA/IGMMA helpers for high‑performance kernels.
    • Extensive FMHA v2 tiling/load/store primitives (Q/K/V/O), TMA descriptor management, and paged KV cache.
  • Enhanced Support

    • Alibi positional-bias params, BF16/mixed-precision conversions, causal/sliding-window masks and multi‑token prediction.

✏️ Tip: You can customize this high-level summary in your review settings.

Contributor

coderabbitai bot commented Nov 25, 2025

Important

Review skipped

Review was skipped as selected files did not have any reviewable changes.

💤 Files selected but had no reviewable changes (8)
  • csrc/fmha_v2/softmax_bf16.cu
  • csrc/fmha_v2/softmax_fp8.cu
  • csrc/trtllm_fmha_v2_binding.cu
  • flashinfer/jit/attention/fmha_v2/generate_kernels.py
  • flashinfer/jit/attention/fmha_v2/generator_utils.py
  • flashinfer/jit/attention/modules.py
  • flashinfer/prefill.py
  • tests/attention/test_fmha_v2_prefill_deepseek.py

You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds extensive FMHA v2 CUDA infrastructure: numeric conversions, fragment/accumulator primitives, SMEM/GMEM tile loaders and storers (including packed/TMA paths), Hopper GMMA/TMA utilities, masking and kernel trait configurations across Volta/Turing/Ampere/Ada/Hopper targets.

Changes

Cohort / File(s) / Summary:

  • Numeric conversions & types (csrc/fmha_v2/convert.cu, csrc/fmha_v2/fmha/numeric_types.h): New conversion kernels and host wrappers for FP32↔FP16/BF16/FP8 (e4m3/e5m2) and int32→int8 with per-entry scaling; adds FP8/FP16/BF16 type aliases and softmax quantization scale constants.
  • Fragment & accumulator system (csrc/fmha_v2/fmha/fragment.h, csrc/fmha_v2/fmha/hopper/fragment.h): Introduces fragment primitives (ldg/lds/stg), Fragment/Fragment_accumulator templates with many architecture-specialized specializations (HMMA/IMMA/GMMA/QGMMA/IGMMA), and Softmax_saver utilities.
  • SMEM tiling framework (csrc/fmha_v2/fmha/smem_tile.h, csrc/fmha_v2/fmha/hopper/smem_tile.h, csrc/fmha_v2/fmha/hopper/smem_tile_o.h): Adds shared-memory tile implementations with XOR/swizzle patterns, multi-buffer/TMA integration, Row/Col/GMMA variants, and many architecture-specialized Smem_tile types.
  • GMEM tile loaders/storers (csrc/fmha_v2/fmha/gmem_tile_o.h, .../gmem_tile_o_packed.h, .../gmem_tile_ps.h, .../gmem_tile_qkv.h, .../gmem_tile_qkv_packed.h, .../hopper/gmem_tile_o_packed.h, .../hopper/gmem_tile_qkv_packed.h): Adds Gmem_tile types for Q/K/V/O (LDG/LDGSTS/TMA) with packing/quantization helpers, 8/16/32-bit paths, interleaved/contiguous/paged layouts, and TMA support.
  • Hopper GMMA/TMA infra (csrc/fmha_v2/fmha/hopper/{arrive_wait,compute_tile,gmma_descriptor,kernel_traits,utils_gmma,utils_hgmma,utils_hgmma_bf16,utils_igmma,utils_tma,utils_warpgroup,tma_types,tma_descriptor}.h): Adds barriers (Arrive_wait), warpgroup utilities, GMMA/IGMMA/QGMMA wgmma wrappers, GMMA descriptor classes, TMA types with a host-side multiple-descriptor manager, and Hopper-specific kernel trait wiring.
  • Kernel traits & dispatch (csrc/fmha_v2/fmha/kernel_traits.h, csrc/fmha_v2/fmha/hopper/kernel_traits.h): New comprehensive Kernel_traits templates and adapters (including Traits_reuse, Traits_o_adapter) to configure Q/K/V/O tiling, epilogues, LDGSTS vs TMA selection, and Hopper-specific trait sets.
  • Masking, GEMM, ALiBi, KV cache (csrc/fmha_v2/fmha/{mask,gemm,alibi_params,paged_kv_cache}.h): Adds multi-version masking (causal/sliding-window/MTP), a small device GEMM wrapper, AlibiParams, and a paged KV-cache block array struct.
  • Misc utilities & aggregators (csrc/fmha_v2/fmha/hopper/utils_gmma.h, various utils headers): Aggregates and exposes GMMA/TMA helper headers and low-level helpers (warpgroup, TMA/TMA descriptor helpers, quantize/pack utilities).

Sequence Diagram(s)

sequenceDiagram
    participant Host
    participant Kernel
    participant GMEM
    participant SMEM
    participant GMMA

    rect rgb(240,248,255)
    note right of Host: Setup
    Host->>Kernel: pass Params, TMA descriptors, pointers
    end

    rect rgb(255,248,240)
    note right of Kernel: QKV load phase
    Kernel->>GMEM: Gmem_tile_qkv.load() (LDG / LDGSTS / TMA)
    GMEM->>SMEM: write via Smem_tile (swizzled / TMA store)
    end

    rect rgb(248,255,240)
    note right of Kernel: Compute phase
    Kernel->>GMMA: Compute_tile_with_gmma.compute() -> wgmma/igmma/qgmma ops
    GMMA->>Kernel: accumulate into Fragment_accumulator
    Kernel->>Kernel: increment descriptors / advance tiles
    end

    rect rgb(240,240,255)
    note right of Kernel: Softmax & O write
    Kernel->>Kernel: Tile_o_normalizer (max/sum)
    Kernel->>GMEM: Gmem_tile_o.store() (quantize/pack, optional I2F trick)
    end

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Areas needing extra attention:

  • GMMA descriptor bitfield math and descriptor increment logic (hopper/gmma_descriptor.h).
  • TMA descriptor assembly and host→device copying (hopper/tma_descriptor.h, hopper/tma_types.h).
  • SMEM swizzle/XOR addressing and buffer management across Volta/Ampere/Hopper (smem_tile.h, hopper/smem_tile.h).
  • GMMA/IGMMA/QGMMA wrapper correctness and RF-from-RF/A-from-RF variants (hopper/utils_*.h, hopper/fragment.h).
  • Quantization/packing, scaling factors and I2F-emulation branches (convert.cu, gmem_tile_*_packed.h).
  • Mask progression correctness across FMHA versions (mask.h).
  • Conditional compilation and architecture guards for SM90/CP-async/TMA paths.

Suggested reviewers

  • djmmoss
  • yzh119
  • cyx-6
  • Anerudhan
  • aleozlx
  • wenscarl

Poem

🐰 Hop, hop — tiles align in rows,
Bits and descriptors dance in prose.
GMMA hums beneath Hopper's moon,
Pack, quantize, then store soon.
A rabbit cheers: "FMHA — go zoom!"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 54.24%, which is below the required threshold of 80.00%. Resolution: you can run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (❓ Inconclusive): The PR description gives a brief summary of what the PR does (porting the TRTLLM FMHAv2 library) but is largely incomplete. The template sections ('📌 Description', '🔍 Related Issues', '🚀 Pull Request Checklist') are present, yet the checklist items are all unchecked and key details are minimal. Resolution: complete the checklist by confirming pre-commit setup, test additions, and verification, and expand the description with more detail about what changed and why.

✅ Passed checks (1 passed)

  • Title check (✅ Passed): The title 'feat: TRTLLM FMHAv2 backend for ctx attention' is clear and directly related to the changeset, which ports the TRTLLM FMHAv2 library for prefill context attention support.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request ports the trtllm fmhav2 library to FlashInfer, adding a new backend that supports prefill cases for context attention. This enhancement involves incorporating CUDA kernels for numerical type conversions and introducing fragment structures for optimized memory handling, significantly expanding FlashInfer's capabilities.

Highlights

  • Feature: Introduces a TRTLLM FMHAv2 backend to support prefill cases for context attention, leveraging the trtllm fmhav2 library.
  • Codebase Addition: Adds new CUDA files (convert.cu) and header files (alibi_params.h, fragment.h, gemm.h, gmem_tile_o.h, gmem_tile_o_packed.h, gmem_tile_qkv.h, gmem_tile_qkv_packed.h, hopper/, etc.) to implement the FMHAv2 backend.
  • Conversion Utilities: Implements CUDA kernels for converting between different numerical types (int32 to int8, fp32 to fp16/bf16/e4m3/e5m2, and e4m3/e5m2 to fp32) to ensure compatibility across different hardware and software components; a minimal illustrative sketch follows this list.
  • Fragment Management: Introduces fragment structures and utilities for efficient memory access and manipulation within CUDA kernels, optimizing data loading, storing, and processing.
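
For readers skimming the conversion utilities bullet above, here is a minimal CUDA sketch of a scaled FP32→e4m3 conversion kernel in the spirit of what is described; the kernel name, launch shape, and scaling scheme are assumptions for illustration, not the PR's convert.cu API (requires CUDA 11.8+ for <cuda_fp8.h>):

#include <cuda_fp8.h>

// Illustrative only: scale each element in FP32, then narrow to e4m3 (saturating).
__global__ void convert_fp32_to_e4m3_sketch(__nv_fp8_e4m3* dst, float const* src,
                                            size_t n, float scale) {
  size_t idx = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (idx < n) {
    dst[idx] = __nv_fp8_e4m3(src[idx] * scale);
  }
}

// Example launch: convert_fp32_to_e4m3_sketch<<<(n + 255) / 256, 256>>>(dst, src, n, scale);
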
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request ports the TensorRT-LLM FMHAv2 library to support prefill cases. The changes are extensive and introduce a large amount of low-level CUDA code, including kernels for various GPU architectures (Volta, Turing, Ampere, Hopper) and support for features like paged KV-cache and TMA. The code is highly templated and specialized for performance. My review focused on correctness and potential issues in the newly added files. I've found one issue related to unsafe type casting that could lead to incorrect behavior.

Comment on lines +46 to +49
out.x = reinterpret_cast<int8_t const&>(a);
out.y = reinterpret_cast<int8_t const&>(b);
out.z = reinterpret_cast<int8_t const&>(c);
out.w = reinterpret_cast<int8_t const&>(d);

high

The use of reinterpret_cast here is unsafe. The cvt.rni.sat.s8.f32 instruction writes the 8-bit result to the lower byte of the 32-bit destination register, but the contents of the upper 24 bits are not guaranteed to be zero. reinterpret_casting this to an int8_t is endian-dependent and relies on this assumption. Using static_cast is safer, more explicit, and avoids potential correctness issues.

    out.x = static_cast<int8_t>(a);
    out.y = static_cast<int8_t>(b);
    out.z = static_cast<int8_t>(c);
    out.w = static_cast<int8_t>(d);

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 16

🧹 Nitpick comments (30)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (2)

48-98: Dead code branch due to prior static_assert.

Line 58 enforces Mma_tile::MMAS_N == 1 at compile time, making the condition at line 92 always true. The else branch (lines 94-96) is unreachable dead code.

Consider simplifying by removing the redundant conditional:

       // Each thread of a quad writes 16B per STS -> 64B per store.
-      if (Mma_tile::MMAS_N == 1) {
-        this->smem_write_ ^= 64;
-      } else {
-        assert(false && "Unsupported");
-      }
+      // MMAS_N == 1 enforced by static_assert above.
+      this->smem_write_ ^= 64;

Additionally, line 81 contains commented-out code. If it's no longer needed, consider removing it; otherwise, add a comment explaining why it's retained.


131-154: Cryptic comment and loop-internal static_assert placement.

  1. The comment at line 140 ("inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1,") is unclear. Consider clarifying what these multiples represent or referencing documentation.

  2. The static_assert(MMAS_M_PER_LOOP == 1) at line 139 is placed inside the loop body. While functionally correct (compile-time check), placing it before the loop would be cleaner and more conventional.

   inline __device__ void store(Accumulator const (&acc)[1][1], int mi) {
     enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
+    // The number of MMAs that are stored per loop iteration.
+    enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };
+    static_assert(MMAS_M_PER_LOOP == 1);

 #pragma unroll
     for (int ni = 0; ni < Mma_tile::CORES_N; ++ni) {
-      // The number of MMAs that are stored per loop iteration.
-      enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };
-
-      static_assert(MMAS_M_PER_LOOP == 1);
-      // inplace multiples seem to be 1, 3, 1, 7, 1, 3, 1,
+      // XOR pattern for swizzled smem addressing (pattern: 0, 16, 32, 48, ...)
       auto smem_write = this->smem_write_ ^ (ni * 16);

This same pattern applies to the FP32 (lines 187-216), BF16 (lines 249-280), and other specializations.

csrc/fmha_v2/fmha/hopper/tma_types.h (1)

100-100: Minor typo in comment.

"descritptro" should be "descriptor".

-// The 512 bit of descritptro for im2col mode.
+// The 512 bit of descriptor for im2col mode.
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)

799-801: Unused variable row_ptr.

The variable row_ptr on line 800 is declared but never used. This appears to be leftover debug code.

       // Packs the 32bit/16bit values to 8bit.
       // Depending on the type, applies extra scaling with parameter scale_bmm2.
       Stg_packed_type dst = Acc_packer<Src_type, Dst_type>::run(this, src[ii]);
-      float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);

1286-1294: Consider removing commented-out debug code.

Lines 1287-1294 and 1310-1313 contain commented-out assertions and debug printf statements. These can be removed to improve code clarity.

csrc/fmha_v2/fmha/fragment.h (2)

1821-1825: Remove stray semicolon.

There's a stray semicolon after the prev_max_ declaration on line 1824.

   float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
-  ;
   float prev_sum_[ROWS_PER_THREAD] = {0};

1119-1121: NaN check pattern is intentional but non-obvious.

The pattern sum[jj] != sum[jj] is used to detect NaN (since NaN != NaN is true). While this works, it's non-obvious. Consider using isnan() for clarity, or add a brief inline comment explaining the idiom if performance is critical.

Also applies to: 1149-1149, 1413-1413, 1436-1436, 2142-2143, 2170-2171
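
For reference, a minimal sketch contrasting the two forms discussed above (illustrative only; note that finite-math-only compilation can fold either check away):

// The self-comparison idiom relies on IEEE-754: NaN is the only value for which x != x.
__device__ inline bool is_nan_self_compare(float x) { return x != x; }
// Equivalent but more explicit alternative suggested above.
__device__ inline bool is_nan_builtin(float x) { return isnan(x); }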

csrc/fmha_v2/fmha/numeric_types.h (1)

19-22: Consider using a more robust FP8 feature detection.

The TODO comment suggests uncertainty about the approach. The current check CUDART_VERSION >= 11080 is reasonable but using true as the macro value is unconventional. A more typical pattern would be to just define the macro without a value.

 #if CUDART_VERSION >= 11080
-// TODO Better way?
-#define FMHA_CUDA_SUPPORTS_FP8 true
+#define FMHA_CUDA_SUPPORTS_FP8
 #endif

Then use #ifdef FMHA_CUDA_SUPPORTS_FP8 instead of #if FMHA_CUDA_SUPPORTS_FP8 throughout.

csrc/fmha_v2/fmha/paged_kv_cache.h (1)

46-58: Consider adding power-of-2 validation and using integer-based log2.

The comment on line 28-29 states mTokensPerBlock must be a power of 2, but this constraint is not enforced. Additionally, using floating-point log2() for an integer operation introduces unnecessary overhead and potential precision issues.

Apply this diff to add validation and use an integer-based approach:

   Kv_block_array(int32_t batchSize, int32_t maxBlocksPerSeq, int32_t tokensPerBlock,
                  int32_t bytesPerBlock, void* poolPtr)
       : mMaxSeqs(batchSize),
         mMaxBlocksPerSeq(maxBlocksPerSeq),
         mTokensPerBlock(tokensPerBlock),
         mBytesPerBlock{bytesPerBlock},
         mPoolPtr{poolPtr},
         mBlockOffsets{nullptr} {
-    float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
-    mTokensPerBlockLog2 = static_cast<int>(tokensPerBlockSeqLog2);
+    // Validate power-of-2 constraint
+    assert((tokensPerBlock & (tokensPerBlock - 1)) == 0 && tokensPerBlock > 0);
+    // Use integer bit counting instead of floating-point log2
+    mTokensPerBlockLog2 = __builtin_ctz(static_cast<uint32_t>(mTokensPerBlock));
   }
csrc/fmha_v2/fmha/hopper/utils_tma.h (1)

138-152: Runtime assert on unsupported architectures may cause unexpected failures.

The tmastg_arrive() and tmastg_wait() functions call assert(false) when compiled for architectures below SM90. This could cause silent failures in release builds (where assert is typically a no-op) or unexpected crashes in debug builds if these functions are accidentally called.

Consider using a static assertion or a more explicit compile-time guard if these functions should never be called on older architectures:

 inline __device__ void tmastg_arrive() {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
   asm volatile("cp.async.bulk.commit_group;");
 #else
-  assert(false);
+  // These functions are only valid on Hopper+ architectures
+  static_assert(__CUDA_ARCH__ >= 900, "tmastg_arrive requires SM90+");
 #endif
 }

 inline __device__ void tmastg_wait() {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
   asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory");
 #else
-  assert(false);
+  static_assert(__CUDA_ARCH__ >= 900, "tmastg_wait requires SM90+");
 #endif
 }

Alternatively, if backward compatibility is intended but calling these is an error, use a more informative runtime check that works in release builds.

csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)

341-345: Unused member desc_ptr_d.

The member desc_ptr_d is declared but never initialized or used within the class. The comment says "Device desc ptr should be allocated outside the class and reused" (line 22), but having an unused member is confusing.

Consider removing the unused member or documenting its intended external usage pattern:

   // The TMA descriptor. Each is of 512 bit.
   cudaTmaDesc* desc_ptr_h;
-  // The TMA descriptor on the device memory.
-  cudaTmaDesc* desc_ptr_d;
   // Number of batches
   int batch_size = 0;
 };
csrc/fmha_v2/fmha/hopper/arrive_wait.h (1)

373-378: Unused member variable id_.

The id_ member is stored in the constructor (line 112) but never used anywhere in the class methods. All methods take id as a parameter instead.

Consider removing id_ if it's not needed, or using it in methods that currently take id as a parameter:

  private:
   // smem barrier base pointer
   uint64_t* bar_base_;
-  // barrier id
-  int id_;
 };
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1)

128-142: Unused member variables declared.

The members preds_, fetch_, and row_ are declared but never used in the visible methods (commit, load, move, store). If these are intentionally reserved for future use or subclass consumption, consider adding a comment. Otherwise, they contribute to unnecessary memory overhead per tile instance.

csrc/fmha_v2/fmha/gmem_tile_o.h (2)

156-158: Consider using bool type for predicate flags.

The is_active_for_last_stg_ member is declared as int but used as a boolean predicate. Using bool would better convey intent and potentially allow compiler optimizations.

-  // Is the thread active for the last STG?
-  int is_active_for_last_stg_;
+  // Is the thread active for the last STG?
+  bool is_active_for_last_stg_;

424-425: Same type consistency suggestion for boolean flags.

Similar to Hmma_gmem_tile_o, consider using bool for is_active_ and is_active_for_last_stg_ to better express intent.

csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (1)

228-231: Unused template parameter in wrapper function.

The bool /*ignored*/ parameter in hgmma_bf16 serves no purpose. If this is for API consistency with other GMMA wrappers, consider adding a brief comment explaining the pattern.

csrc/fmha_v2/fmha/hopper/kernel_traits.h (2)

98-108: FLAGS bit layout could benefit from named constants.

The bitfield extraction for USE_LDGSTS_Q/K/V and USE_TMA_Q/K/V is correct, but the magic numbers (0x1u, 0x2u, 0x4u) could be replaced with named constants for better maintainability.

+  // FLAGS bit definitions
+  enum {
+    FLAG_LDGSTS_Q = 0x1u,
+    FLAG_LDGSTS_K = 0x2u,
+    FLAG_LDGSTS_V = 0x4u,
+    FLAG_HEADS_NOT_INTERLEAVED = 0x20u,
+    FLAG_BMM1_SOFTCAPPING = 0x800u
+  };
+
   // Do we use LDGSTS for Q, K or V. If not, TMA is used!
-  enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u };
-  enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u };
-  enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u };
+  enum { USE_LDGSTS_Q = (FLAGS & FLAG_LDGSTS_Q) != 0u };
+  enum { USE_LDGSTS_K = (FLAGS & FLAG_LDGSTS_K) != 0u };
+  enum { USE_LDGSTS_V = (FLAGS & FLAG_LDGSTS_V) != 0u };

358-363: Hardcoded VERSION=2 in alias template.

The FMHA_kernel_traits_hopper_v2 alias hardcodes VERSION=2. This is likely intentional for versioning, but consider documenting why this specific version is used here.

csrc/fmha_v2/fmha/hopper/fragment.h (1)

369-384: Use if constexpr for compile-time type dispatch.

The type dispatch using std::is_same_v with regular if/else works but generates unnecessary runtime branches. Since Input_type_A and Input_type_B are template parameters known at compile time, if constexpr is more appropriate.

Apply this diff for cleaner compile-time dispatch:

-    if (std::is_same_v<Input_type_A, e4m3_t> && std::is_same_v<Input_type_B, e4m3_t>) {
+    if constexpr (std::is_same_v<Input_type_A, e4m3_t> && std::is_same_v<Input_type_B, e4m3_t>) {
       qgmma_rfa_e4m3_e4m3_fp32<GMMA_N, INCREMENT_SCORE_BOARD>(a.regs_, single_desc_b.get(),
                                                               this->regs_);
-    } else if (std::is_same_v<Input_type_A, e5m2_t> && std::is_same_v<Input_type_B, e4m3_t>) {
+    } else if constexpr (std::is_same_v<Input_type_A, e5m2_t> && std::is_same_v<Input_type_B, e4m3_t>) {
       // ... rest of the chain

The same change should be applied to the SMEM-SMEM variant at lines 405-419.

csrc/fmha_v2/fmha/hopper/compute_tile.h (1)

273-277: Remove or enable commented-out static_asserts.

These commented static_asserts suggest incomplete compile-time validation. Either remove them if they're no longer relevant, or uncomment them with proper conditions to catch configuration errors early.

csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (4)

150-152: Unused member variable is_active_for_last_stg_.

This member variable is declared but never assigned or read in the class. It appears to be dead code.

Consider removing this unused member or implementing the intended functionality. The same issue exists in Gmem_tile_o_gmma_32bit_8bit at line 880.

-  // Is the thread active for the last STG?
-  int is_active_for_last_stg_;

289-306: Incorrect comment labels in nested loops.

The closing brace comments don't match the actual loop variables. Line 290 says // row_idx but it closes col_idx, and line 306 says // mma_ni but it closes row_idx.

Fix the comment labels for clarity:

-        }  // row_idx
-      }  // col_idx
+        }  // col_idx
+      }  // mma_ni

...

-      }  // row_idx
-    }  // mma_ni
+      }  // col_idx
+    }  // row_idx

584-586: Document the compiler workaround for predicate behavior.

The comment mentions that without this __shfl_sync, "the predicate will not behavior as expected for unknown reason." This suggests a compiler bug or undefined behavior that should be tracked.

Consider adding a more detailed comment or a TODO to track this issue:

-      // WARNING: Without this line, the predicate will not behavior as expected for unknown reason.
+      // WARNING: Without this line, the predicate will not behave as expected.
+      // This appears to be a compiler optimization issue. See: [link to bug/issue if available]
+      // TODO: Remove this workaround when the underlying issue is resolved.
       num_valid_rows_ = __shfl_sync(0xffffffff, num_valid_rows_, 0);

1082-1115: Add #undef STORE_COLUMNS after macro usage.

The STORE_COLUMNS macro is defined but never undefined, which could cause redefinition warnings if this pattern is used elsewhere in the codebase.

Add #undef STORE_COLUMNS after line 1115:

       for (int ci = 0; ci < VALID_COLS_PER_THREAD_FOR_LAST_MMA; ci += 2) {
         STORE_COLUMNS()
       }
     }
+#undef STORE_COLUMNS
   }
csrc/fmha_v2/fmha/hopper/smem_tile.h (4)

376-383: Remove or implement commented-out buffer management code.

The move_next_write_buffer() method contains commented-out implementation code. If the buffer cycling is not needed for this path, the comments should be removed. If it is needed, the implementation should be restored.

The same pattern appears at lines 538-544. Either remove the commented code or implement the functionality:

   inline __device__ void move_next_write_buffer() {
-    // if( BUFFERS_PER_TILE > 1 ) {
-    //     this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY )
-    //                                     ? -BYTES_PER_TILE_INC_BOUNDARY
-    //                                     : BYTES_PER_BUFFER;
-    // }
   }

169-172: Add parentheses to clarify operator precedence in XOR expression.

The expression on lines 170-171 relies on operator precedence where + binds tighter than ^. While likely correct, this is subtle and should be explicitly parenthesized for clarity.

     } else if (desc_mode == fmha::Gmma_descriptor_mode::SWIZZLE_64B) {
-      smem_write_col = (tidx % (THREADS_PER_ROW / 2)) ^
-                       smem_write_xor + ((tidx % THREADS_PER_ROW) / (THREADS_PER_ROW / 2)) * 4;
+      smem_write_col = (tidx % (THREADS_PER_ROW / 2)) ^
+                       (smem_write_xor + ((tidx % THREADS_PER_ROW) / (THREADS_PER_ROW / 2)) * 4);
     }

The same pattern appears at lines 343-345.


1326-1485: Remove large commented-out code block.

This ~160-line commented block significantly increases file size and reduces readability. If this code is needed for reference, consider moving it to a separate documentation file or keeping it in version control history.

The load_and_store() method should either implement the required functionality or remain as an empty stub without the commented implementation details.


1728-1730: Consider static_assert instead of runtime assert for unsupported configurations.

The assert(false) for unsupported warp/dimension configurations will only fail at runtime. Since these are compile-time known values, a static_assert would catch misconfigurations earlier.

-    } else {
-      assert(false);
-    }
+    } else {
+      static_assert(sizeof(Traits) == 0, "Unsupported warp/dimension configuration for Transposer");
+    }

Note: The sizeof(Traits) == 0 idiom creates a dependent false that only triggers if this branch is instantiated.

csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (2)

119-133: Complex ternary chain is hard to maintain.

The nested ternary expression for BYTES_PER_DESC spans 14 lines and is difficult to verify. Consider using a constexpr helper function or a lookup table approach for better readability.

Consider refactoring to a constexpr function:

static constexpr uint32_t compute_bytes_per_desc() {
  if (Gmma_vector_size == Gmma_descriptor_size::ALL) return 0;
  if (Gmma_trans == Gmma_descriptor_transpose::TRANS) {
    if (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) 
      return GMMA_K * BYTES_PER_LEADING_DIM;
    if (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_64B)
      return (GMMA_K / 2) * BYTES_PER_LEADING_DIM;
    // ... etc
  }
  // ...
}
static constexpr uint32_t BYTES_PER_DESC = compute_bytes_per_desc();

175-201: Consider consolidating descriptor initialization loops.

The constructor has four separate loops that all iterate over NUM_DESCRIPTORS. These could be merged into a single loop for cleaner code.

   inline __device__ Gmma_descriptor_a() {
 #pragma unroll
     for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
-      desc[desc_idx] = 0;
-    }
-
-#pragma unroll
-    for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
-      desc[desc_idx] |= DESCRIPTOR_MODE_IN_BIT_LOCATION;
-    }
-
-#pragma unroll
-    for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
-      desc[desc_idx] |= STRIDE_BYTE_OFFSET_IN_BIT_LOCATION;
-    }
-
-    if (LEADING_BYTE_OFFSET_NEEDED) {
-#pragma unroll
-      for (int desc_idx = 0; desc_idx < NUM_DESCRIPTORS; ++desc_idx) {
+      desc[desc_idx] = DESCRIPTOR_MODE_IN_BIT_LOCATION | STRIDE_BYTE_OFFSET_IN_BIT_LOCATION;
+      if constexpr (LEADING_BYTE_OFFSET_NEEDED) {
         desc[desc_idx] |= LEADING_BYTE_OFFSET_IN_BIT_LOCATION;
       }
     }
   }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d0d99d2 and 1c36cdc.

📒 Files selected for processing (31)
  • csrc/fmha_v2/convert.cu (1 hunks)
  • csrc/fmha_v2/fmha/alibi_params.h (1 hunks)
  • csrc/fmha_v2/fmha/fragment.h (1 hunks)
  • csrc/fmha_v2/fmha/gemm.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_o.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_ps.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_qkv.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/arrive_wait.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/compute_tile.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/fragment.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/kernel_traits.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/smem_tile.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/tma_types.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_gmma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_igmma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_tma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_warpgroup.h (1 hunks)
  • csrc/fmha_v2/fmha/kernel_traits.h (1 hunks)
  • csrc/fmha_v2/fmha/mask.h (1 hunks)
  • csrc/fmha_v2/fmha/numeric_types.h (1 hunks)
  • csrc/fmha_v2/fmha/paged_kv_cache.h (1 hunks)
  • csrc/fmha_v2/fmha/smem_tile.h (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fmha_v2/fmha/gemm.h
  • csrc/fmha_v2/fmha/hopper/smem_tile_o.h
🧬 Code graph analysis (15)
csrc/fmha_v2/fmha/gemm.h (1)
csrc/fmha_v2/fmha/fragment.h (1)
  • fmha (20-182)
csrc/fmha_v2/fmha/numeric_types.h (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (1)
  • __nv_fp8_e5m2 (91-93)
csrc/fmha_v2/fmha/mask.h (1)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)
csrc/cudnn_sdpa_utils.h (1)
  • set_tensor_common_0 (92-96)
csrc/fmha_v2/fmha/gmem_tile_qkv.h (3)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/warpspec/epilogue.h (1)
  • int (95-104)
csrc/fmha_v2/fmha/utils.h (1)
  • pack_predicates (1366-1370)
csrc/fmha_v2/fmha/hopper/arrive_wait.h (2)
csrc/fmha_v2/fmha/hopper/smem_tile.h (1)
  • fmha (22-562)
csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (4)
  • bar_create (795-806)
  • bar_base (951-951)
  • bar_wait (853-869)
  • set_bar_base_dsmem (901-904)
csrc/fmha_v2/fmha/hopper/fragment.h (3)
csrc/fmha_v2/fmha/fragment.h (1)
  • fmha (20-182)
csrc/fmha_v2/fmha/gemm.h (1)
  • fmha (18-35)
csrc/fmha_v2/fmha/hopper/utils_igmma.h (1)
  • igmma_int8_int32 (201-203)
csrc/fmha_v2/fmha/hopper/compute_tile.h (1)
csrc/fmha_v2/fmha/hopper/smem_tile.h (1)
  • fmha (22-562)
csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (2)
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1)
  • fmha (17-200)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
  • fmha (17-99)
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (7)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)
  • fmha (19-202)
  • v2 (20-180)
csrc/fmha_v2/fmha/gmem_tile_qkv.h (1)
  • fmha (15-162)
csrc/fmha_v2/fmha/hopper/fragment.h (1)
  • fmha (20-35)
csrc/fmha_v2/fmha/hopper/smem_tile.h (5)
  • fmha (22-562)
  • store (2189-2193)
  • store (2197-2199)
  • store (2203-2207)
  • store (2218-2221)
csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)
  • fmha (16-121)
csrc/fmha_v2/fmha/warpspec/epilogue.h (1)
  • int (95-104)
csrc/fmha_v2/fmha/utils.h (11)
  • char (2294-2302)
  • float (838-853)
  • float (899-904)
  • float (1057-1068)
  • float (1151-1155)
  • float (1159-1163)
  • float (1251-1251)
  • float (2306-2311)
  • float (2315-2319)
  • float (2323-2338)
  • float (2343-2351)
csrc/fmha_v2/fmha/smem_tile.h (1)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/gmem_tile_o.h (1)
csrc/fmha_v2/fmha/traits.h (4)
  • public (264-264)
  • public (356-359)
  • public (433-436)
  • public (481-493)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1)
csrc/fmha_v2/fmha/traits.h (4)
  • public (264-264)
  • public (356-359)
  • public (433-436)
  • public (481-493)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
csrc/fmha_v2/fmha/smem_tile_o.h (12)
  • Smem_tile_o_base_8bit_mma (1284-1360)
  • Smem_tile_o (104-197)
  • Smem_tile_o (691-691)
  • Smem_tile_o (723-723)
  • Smem_tile_o (753-753)
  • Smem_tile_o (802-820)
  • Smem_tile_o (957-975)
  • Smem_tile_o (1452-1452)
  • Smem_tile_o (1466-1466)
  • Smem_tile_o (1480-1480)
  • Smem_tile_o (1494-1494)
  • Smem_tile_o (1508-1508)
csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1)
csrc/fmha_v2/fmha/gemm.h (1)
  • fmha (18-35)
🪛 Clang (14.0.6)
csrc/fmha_v2/fmha/gemm.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/alibi_params.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_gmma.h

[error] 15-15: 'fmha/hopper/utils_hgmma.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_warpgroup.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/numeric_types.h

[error] 13-13: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/paged_kv_cache.h

[error] 15-15: 'cuda_runtime.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_tma.h

[error] 15-15: 'fmha/hopper/tma_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h

[error] 14-14: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/mask.h

[error] 15-15: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/tma_descriptor.h

[error] 14-14: 'fmha/hopper/tma_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_qkv.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/gmma_descriptor.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/arrive_wait.h

[error] 66-66: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 66-66: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/fragment.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/kernel_traits.h

[error] 14-14: 'fmha/gmem_tile_qkv_packed.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/compute_tile.h

[error] 14-14: 'fmha/hopper/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_igmma.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h

[error] 14-14: 'fmha/gmem_tile_o_packed.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/smem_tile.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_o.h

[error] 15-15: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_o_packed.h

[error] 15-15: 'fmha/numeric_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/smem_tile.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/smem_tile_o.h

[error] 15-15: 'fmha/smem_tile_o.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/tma_types.h

[error] 15-15: 'fmha/numeric_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h

[error] 14-14: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/kernel_traits.h

[error] 15-15: 'fmha/alibi_params.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_hgmma.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_ps.h

[error] 14-14: 'fmha/hopper/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/fragment.h

[error] 15-15: 'fmha/traits.h' file not found

(clang-diagnostic-error)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (44)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (5)

21-23: LGTM!

Simple placeholder struct with zero byte footprint - appropriate for conditional tile allocation patterns.


159-217: FP32 specialization mirrors FP16 structure correctly.

The conversion via float2_to_half2 for FP32 accumulators to half-precision storage is appropriate.

Same optional refactoring suggestions apply as noted for the FP16 specialization (lines 193-196).


221-281: BF16 specialization correctly uses float2_to_bf16_x2 conversion.

The additional static_assert(Mma_tile::CORES_M == 2) at line 252 is good defensive validation for this specific data type path.

Same optional refactoring suggestions apply regarding loop-internal static_assert placement (lines 257-259).


285-321: LGTM!

Clean thin wrappers for E4M3 (FP8) and INT8 quantized paths, correctly delegating to Smem_tile_o_gmma_32bit_8bit.


1-16: I encountered a technical limitation - the repository couldn't be cloned for verification. However, based on the original review comment's analysis, I can provide the rewritten version. The review already identified the context clearly (ported TensorRT-LLM code with expected include path setup) and approved the changes, with the verification request being supplementary.


License and include structure are appropriate for ported TensorRT-LLM code.

The static analysis error about 'fmha/smem_tile_o.h' file not found is a false positive expected in isolated analysis without the complete build system context. This file is part of the FMHA v2 library with include paths configured in the build system.

csrc/fmha_v2/fmha/alibi_params.h (1)

40-48: Well-structured parameter encapsulation with clear documentation.

The struct cleanly encapsulates ALiBi parameters with precomputed values. The comment explaining the design decision for head_idx_offset and sequence_pos_offset is helpful for maintainability.
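
For context, a minimal sketch of how a per-head ALiBi slope is typically applied to a BMM1 logit; the function and parameter names are assumptions, not the struct's actual fields:

__device__ inline float add_alibi_bias(float logit, float slope, int key_pos, int query_pos) {
  // ALiBi adds a linear bias that grows (more negative) with the key/query distance.
  return logit + slope * static_cast<float>(key_pos - query_pos);
}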

csrc/fmha_v2/convert.cu (1)

58-62: Correct size_t casting for dimension multiplication.

This wrapper correctly casts to size_t before multiplication, preventing integer overflow for large tensors. The same pattern should be applied consistently to line 183.
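
A minimal sketch of the overflow-safe pattern referred to above (variable names are assumptions; assumes <cstddef> for size_t):

inline size_t total_elements(int batch, int seq_len, int num_heads, int head_dim) {
  // Promote to size_t before multiplying so the product is computed in 64 bits.
  return static_cast<size_t>(batch) * seq_len * num_heads * head_dim;
}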

csrc/fmha_v2/fmha/hopper/tma_types.h (1)

114-120: TMA descriptor types are well-structured.

The descriptor structures are correctly sized at 64 bytes (512 bits) with proper alignment for TMA hardware requirements. The enum definitions provide clear configuration options.
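
Illustrative only: the shape of a 512-bit, 64-byte-aligned opaque descriptor container like the one described above (a sketch, not the actual cudaTmaDesc definition):

#include <cstdint>

struct alignas(64) Tma_desc_512_sketch {
  uint64_t data[8];  // 8 x 64 bits = 512 bits
};
static_assert(sizeof(Tma_desc_512_sketch) == 64, "TMA descriptors are 512 bits");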

csrc/fmha_v2/fmha/mask.h (4)

143-200: Clean implementation of sequence-length masking (V2).

The V2 mask correctly implements basic sequence-length masking with proper warp/lane decomposition and position calculation. The static_assert on line 146 ensuring WARPS_K == 1 is good defensive programming.


344-352: FIXME comment indicates incomplete support for certain traits.

The mtp_token_idx_ array assumes exactly 2 rows per thread, which may not hold for Volta/Hopper-GMMA traits. This limitation should be addressed or clearly documented as a known restriction.

Consider tracking this as a follow-up issue. The current implementation may produce incorrect results if used with unsupported trait configurations.


368-398: Sliding window mask (V4) correctly extends causal mask.

The V4 mask properly implements the sliding window attention pattern. The is_valid check at line 393 correctly constrains attention to positions within the sliding window while maintaining causality.
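
A minimal sketch of a causal sliding-window validity predicate of the kind described, with assumed names (not the file's Mask implementation):

__device__ inline bool sliding_window_valid(int row, int col, int window_size) {
  // Causal: a query at `row` may only attend to keys at or before it.
  // Sliding window: and only to the most recent `window_size` keys.
  return (col <= row) && (col > row - window_size);
}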


403-515: Custom mask (V5) with flexible packed mask loading.

The V5 mask provides a well-structured approach for loading custom masks from global memory with proper handling of MMA group boundaries and partial mask loads via the predicate mechanism (line 457).

csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)

33-86: Well-designed base template with comprehensive tiling calculations.

The Hmma_gmem_tile_o base template correctly computes tiling parameters with appropriate static assertions to catch configuration errors at compile time. The calculation of ROWS_PER_LOOP, STGS_PER_LOOP, and handling of incomplete STGs shows careful attention to edge cases.


450-628: Comprehensive type conversion helpers via Acc_packer.

The Acc_packer template specializations provide a clean, extensible pattern for accumulator-to-output type conversions with optional scaling. The specializations correctly handle:

  • INT32 → INT8 with saturation
  • FP32 → e4m3/e5m2 (FP8)
  • FP32 → FP16/BF16
  • FP16 → e4m3
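
Illustrative only: the general shape of such a packer, with assumed names and signature (not the file's actual Acc_packer API; requires <cuda_fp8.h>):

template <typename Src, typename Dst>
struct Acc_packer_sketch;

// FP32 accumulator -> e4m3 output, applying the BMM2 scale before narrowing.
template <>
struct Acc_packer_sketch<float, __nv_fp8_e4m3> {
  static __device__ __nv_fp8_e4m3 run(float acc, float scale_bmm2) {
    return __nv_fp8_e4m3(acc * scale_bmm2);
  }
};
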
csrc/fmha_v2/fmha/fragment.h (5)

13-18: LGTM: Header structure and includes are appropriate.

The static analysis hint about missing fmha/traits.h is a false positive - the build system configures include paths correctly for this CUDA code.


543-578: LGTM: Fragment_accumulator MMA implementations.

The MMA implementations use correct PTX instructions with proper register constraints. The accumulator patterns for Volta/Turing/Ampere/Ada architectures are well-structured.


943-958: Appropriate architecture guards for Ada FP8 instructions.

The #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 guards correctly protect the e4m3 MMA instructions, with a trap fallback for unsupported architectures.

Also applies to: 972-990


785-809: Based on my verification through web searches, I can now provide my assessment:

The review comment's concern appears to be based on a misunderstanding of template inheritance semantics.

Key findings:

  1. PTX instruction is correct: The instruction mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 explicitly uses FP32 accumulators (the f32 at the end specifies the accumulator type D).

  2. 32-bit registers store FP32: The "+r"(reg(...)) constraints in inline assembly refer to 32-bit registers (.b32 in PTX), which correctly hold 32-bit FP32 floating-point values.

  3. CUTLASS pattern confirmation: CUTLASS separates input element types (BF16) from accumulator types (FP32) via template parameters. The comment "BF16 MMA must accumulate with at least FP32" is correct.

  4. Fragment template semantics: In CUTLASS/flashinfer, Fragment<bf16_t, 8> as a base class likely refers to a trait template specialization or output fragment type for subsequent operations, not the actual accumulator storage. The actual accumulator storage is defined by the PTX instruction (f32).

The code at lines 785-809 appears to correctly implement BF16 inputs with FP32 accumulation, consistent with the comment and NVIDIA's standard BF16 Tensor Core workflow.

Without direct repository access to verify the exact template inheritance at line 785 and lines 858-898, I cannot provide absolute certainty.


2286-2307: Based on my verification attempts, I encountered limitations accessing the repository directly due to cloning issues. However, my web searches reveal important findings:

  1. FP32_I2F_MAGIC_NUMBER_HEX: GitHub searches found no matches for this symbol in the flashinfer repository
  2. USE_I2F_EMULATION_TRICK: GitHub searches found no matches for this macro in flashinfer
  3. Clear_accumulator structure: No matches found for the exact code snippet shown

This is expected behavior for code under review, as it may not yet be in the public main branch. However, since I cannot access the actual PR branch directly to verify the include chain and definitions, I must mark this for manual verification.


Verify that FP32_I2F_MAGIC_NUMBER_HEX and USE_I2F_EMULATION_TRICK are properly defined in included headers.

The Clear_accumulator<int32_t, WARPS_K> struct uses FP32_I2F_MAGIC_NUMBER_HEX when USE_I2F_EMULATION_TRICK is defined. Ensure both the macro and constant are declared in reachable headers—either directly included in fragment.h or transitively through existing includes. Verify the preprocessor conditional logic (#if defined(USE_I2F_EMULATION_TRICK)) aligns with build configuration and that the fallback path (fmha::clear(acc)) is always available when the macro is undefined.

csrc/fmha_v2/fmha/hopper/utils_gmma.h (1)

1-18: LGTM: Clean aggregator header for GMMA utilities.

This header provides a single include point for all GMMA-related utilities (HGMMA, HGMMA BF16, IGMMA, QGMMA). The static analysis hint about missing headers is a false positive - the build system handles include paths.

csrc/fmha_v2/fmha/hopper/utils_warpgroup.h (1)

19-40: LGTM: Well-guarded SM90 warpgroup synchronization primitives.

The warpgroup functions are correctly guarded for SM90+ architectures with the __CUDA_ARCH_FEAT_SM90_ALL feature test. The template parameter N in warpgroup_wait uses the correct "n" constraint for immediate values.

Note: On non-SM90 architectures, these functions compile to no-ops, which is the intended behavior for this FMHA infrastructure.
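
For reference, a hedged sketch of the wait wrapper pattern described above (a sketch only; the header's actual signature and guards may differ):

template <int N>
__device__ inline void warpgroup_wait_sketch() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
  // Wait until at most N previously committed wgmma groups are still pending; N must be an immediate.
  asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N));
#endif
}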

csrc/fmha_v2/fmha/gemm.h (1)

22-31: LGTM: Clean templated GEMM accumulation primitive.

The implementation correctly delegates to the Fragment's mma method, which handles the architecture-specific MMA intrinsics defined in fragment.h. The #pragma unroll directives ensure full loop unrolling for performance.
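
A minimal sketch of the delegation pattern being approved here; the array shapes and the Fragment::mma interface are assumptions, not the header's exact signature:

template <typename Acc, typename Frag_a, typename Frag_b, int M, int N, int K>
__device__ inline void gemm_sketch(Acc (&acc)[M][N], Frag_a const (&a)[M][K],
                                   Frag_b const (&b)[N][K]) {
#pragma unroll
  for (int mi = 0; mi < M; ++mi) {
#pragma unroll
    for (int ni = 0; ni < N; ++ni) {
#pragma unroll
      for (int ki = 0; ki < K; ++ki) {
        // The architecture-specific MMA intrinsic lives inside the accumulator fragment.
        acc[mi][ni].mma(a[mi][ki], b[ni][ki]);
      }
    }
  }
}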

csrc/fmha_v2/fmha/numeric_types.h (2)

44-53: LGTM: Softmax FP8 quantization scale.

The Softmax_fp_quant_scale<e4m3_t>() specialization correctly computes the quantization scale as 256.f (2^8), which is the largest power-of-2 below the e4m3 max of 448. The comment explains the rationale well.
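
The rationale in constant form, as a sketch with assumed names (not the header's actual declarations):

constexpr float kE4m3Max = 448.f;               // largest finite e4m3 magnitude
constexpr float kSoftmaxFp8QuantScale = 256.f;  // largest power of two below 448, i.e. 2^8
static_assert(kSoftmaxFp8QuantScale <= kE4m3Max, "quantization scale must stay representable");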


36-39: I'm unable to complete the automated verification as the repository cloning failed. However, I can still provide guidance on the review comment based on the learnings and context provided.

Based on the retrieved learnings from your flashinfer work, I note that silent fallbacks and error-swallowing patterns are established practices in this codebase when they serve a legitimate purpose (e.g., probing multiple configurations in GEMM implementations).

The key question for your review is: Are the e4m3_t/e5m2_t → char fallbacks intentional and safe, or do they genuinely risk runtime bugs?

To properly verify this comment, I would need to:

  1. Examine how e4m3_t and e5m2_t are actually used in FMHA kernels
  2. Check if there are runtime guards preventing their use on non-FP8 hardware
  3. Determine if a static_assert is already present elsewhere
  4. Verify whether the #else branch is even reached in production builds

Since I cannot access the repository directly, I recommend you verify these points manually or provide repository access for automated analysis.

Fallback types may silently compile but cause runtime issues if FP8 operations are used on unsupporting hardware without guards. Consider verifying whether runtime checks prevent actual FP8 usage when the fallback is active, and whether a static_assert would be appropriate as additional safety insurance.

csrc/fmha_v2/fmha/gmem_tile_qkv.h (2)

122-147: Inconsistent is_active_ usage between load() overloads.

The parameterless load() function (line 138-147) checks is_active_ before performing loads, but load(Smem_tile&) (line 123-135) does not. This inconsistency could lead to out-of-bounds accesses if load(Smem_tile&) is called when threads are inactive.

Please verify this is intentional. If load(Smem_tile&) relies on predicates alone for safety, consider adding a comment to clarify this design choice.


39-41: Based on my research, I found that SLIDING_WINDOW_ATTENTION is indeed a real and active feature in FlashInfer for implementing local/windowed attention patterns. The gmem_tile_qkv struct is part of the kernel infrastructure for handling tiled memory copies of QKV matrices during attention computation.

However, I cannot clone the repository to verify whether the template parameter is actually unused within the implementation or to determine if it's intentionally present for future use. Given that this is a legitimate feature in FlashInfer, the presence of this parameter—even if currently unused—may indicate:

  1. Incomplete implementation of sliding window support in this component
  2. Dead code that was prepared but never integrated
  3. A placeholder for future sliding window functionality

Since I cannot perform definitive verification of the code implementation without repository access, I'll provide the rewritten comment:


SLIDING_WINDOW_ATTENTION template parameter appears unused but is a known FlashInfer feature.

While SLIDING_WINDOW_ATTENTION is declared as a template parameter, it is not referenced in the implementation. This parameter corresponds to FlashInfer's sliding/local-window attention feature.

Verify if this parameter should be integrated into the struct's logic (e.g., in tile loading or offset calculations) or if it represents incomplete functionality that needs implementation.

csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1)

86-99: Coordinate calculation logic for MQA/GQA looks correct.

The handling of qkv_offset for Q (0), K (1), and V (2) tensors with proper head index calculations for MQA/GQA (h_kv < h) versus standard MHA appears correct. The coordinate assembly for TMA descriptors is well-structured.

csrc/fmha_v2/fmha/gmem_tile_o.h (2)

282-288: Preprocessor-based conditional initialization may leave o_scratch_ptr_ uninitialized.

When USE_DEMO_BERT_PARAMS is defined, o_scratch_ptr_ is set to nullptr. Ensure that all code paths that access o_scratch_ptr_ check for this case, or add a comment clarifying that scratch space is unused in demo mode.


80-105: LGTM - Constructor logic is well-structured.

The pointer arithmetic correctly handles batch/head offsets with int64_t to prevent overflow on large tensors. The partial-write predicate initialization is properly guarded.

csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (2)

30-44: LGTM - Inline PTX assembly for BF16 GMMA is correctly implemented.

The SM90 architecture guards (__CUDA_ARCH__ >= 900 && __CUDA_ARCH_FEAT_SM90_ALL) are appropriate. The wgmma instruction format and register constraints match NVIDIA PTX ISA specifications for BF16 tensor core operations.


23-75: Repository access is unavailable for direct verification.

The web search confirms that the m64n32k16 GMMA shape is a standard, documented Hopper instruction, taking 16 uint32_t accumulators. This makes the reviewer's concern technically sound—if specializations for N=8, 64, 128, 192, 256 exist but N=32 is missing from the base Hgmma_bf16 path (while present in the RFA variant), a linker error could occur for 64×32×16 operations requiring both operands from SMEM.

However, I cannot directly inspect the file to confirm:

  1. Which specializations currently exist in utils_hgmma_bf16.h
  2. Whether Hgmma_bf16<32, TA, TB> is truly missing
  3. Whether any code paths actually invoke this specialization

The reasoning in the review comment is logically coherent and the underlying CUDA mechanics are correct, but the specific claim about the file's contents requires repository access to validate conclusively.

csrc/fmha_v2/fmha/hopper/utils_igmma.h (2)

30-53: LGTM - INT8 IGMMA implementation is correct.

The wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 instruction format is correct for INT8 tensor core operations with INT32 accumulators. The K dimension of 32 (vs 16 for FP16/BF16) correctly reflects the wider K for 8-bit operations.


200-203: Consistent API wrapper pattern.

The igmma_int8_int32 wrapper follows the same pattern as the FP16/BF16 utilities, with an unused boolean template parameter for API consistency.

csrc/fmha_v2/fmha/hopper/kernel_traits.h (2)

51-69: LGTM - Clear RF flag propagation from traits.

The extraction of GMMA_A_RF and GMMA_B_RF flags from both Traits_p and Traits_o is well-documented and enables flexible GMMA operand sourcing (RF vs SMEM) per GEMM stage.


303-303: SMEM O size consistency validated with static_assert.

Good use of compile-time validation to ensure Smem_tile_o::BYTES_PER_TILE matches the calculated BYTES_PER_SMEM_O.

csrc/fmha_v2/fmha/hopper/fragment.h (2)

78-83: Address TODO comment about fragment size calculation.

There's an unresolved TODO questioning whether the * 4 factor is needed in the fragment size calculation. This could lead to incorrect register allocation if the sizing is wrong.

Please verify whether the * 4 factor is needed by comparing with the FP16/BF16 Fragment_a specializations above (lines 29-46) which don't have this factor. If the factor is needed for FP8 due to different packing, please document the reasoning.


456-475: LGTM - Softmax state storage with proper bounds checking.

The store method correctly handles the interleaved softmax sum/max storage with proper sequence length bounds checking at line 472. The conditional override for invalid runs (line 467-469) correctly sets sum to 1.0 to avoid division issues in downstream code.

csrc/fmha_v2/fmha/hopper/compute_tile.h (2)

28-81: Well-structured SMEM-SMEM compute tile initialization.

The GMMA descriptor setup correctly initializes both A and B descriptors with proper SMEM pointers and max descriptor values for buffer cycling. The static_assert at lines 39-41 ensures descriptor compatibility between A and B tiles.


460-468: Possible out-of-bounds descriptor index in the final compute() call.

compute() is invoked at line 467 with ki = NUM_KGROUPS * MMAS_K_PER_GROUP - 1, while the loop at line 461 only iterates ki from 0 to MMAS_K_PER_GROUP - 2, so an index mismatch is mathematically plausible but cannot be confirmed from the diff alone. Whether it is an actual out-of-bounds access or intentional by design depends on how get_descriptor(ki) handles the index (direct array indexing, modular arithmetic such as ki % MMAS_K_PER_GROUP, or a global index) and on whether the descriptor arrays are sized for indices up to NUM_KGROUPS * MMAS_K_PER_GROUP - 1. Please review the compute() method signature at line 364 and the get_descriptor() implementation to confirm.

csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (2)

260-268: LGTM - Descriptor increment with proper boundary handling.

The increment logic correctly handles the loop-back case when reaching the end of a kblock, using the precomputed BYTES_DESC_INC_BOUNDARY_NO_4LSB constant. The use of reinterpret_cast<int2&> to access the lower 32 bits of the 64-bit descriptor is a standard pattern for GMMA descriptor manipulation.


361-369: LGTM - RESET_SMEM logic for large K dimensions.

The RESET_SMEM and RESET_BYTES_NO_4LSB constants correctly handle the case where the K dimension exceeds a single 128-byte leading dimension, requiring descriptor reset during iteration. This is specific to the B matrix handling for transposed layouts.

csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1)

906-931: Verify paged-KV block index bounds in Gmem_tile_paged_kv::load

paged_kv_block_idx is derived from row_idx and used to index paged_kv_global_block_offsets_ before predicates are applied:

int row_idx = row_ + ii * (int)ROWS_PER_LDG;
int paged_kv_block_idx = (row_idx >> paged_kv_log2_block_size_);
char const* local_kv_ptr =
    paged_kv_block_pool_ptr_ +
    params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx];
// ...
preds[ii] = row_idx < actual_seqlen_;

This is safe only if actual_seqlen_ and params.paged_kv_cache.mMaxBlocksPerSeq guarantee that
row_idx < actual_seqlen_ implies paged_kv_block_idx < mMaxBlocksPerSeq for all rows a CTA can
touch (including cta_row_offset). The TODO in the comment suggests this may not be fully audited.

Please double-check those invariants (especially for large sequence lengths / chunked context) so we
never index beyond mBlockOffsets. If they already hold by construction, consider documenting that
relation near this code.
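
If a defensive guard is preferred over documenting the invariant, a minimal sketch would gate the table lookup on the row predicate before dereferencing the block-offset table (member names follow the snippet above; the commented assert uses a hypothetical max_blocks_per_seq_ mirroring params.paged_kv_cache.mMaxBlocksPerSeq):

int row_idx = row_ + ii * (int)ROWS_PER_LDG;
bool valid_row = row_idx < actual_seqlen_;
// Only form a block index for rows that are actually in range.
int paged_kv_block_idx = valid_row ? (row_idx >> paged_kv_log2_block_size_) : 0;
// assert(!valid_row || paged_kv_block_idx < max_blocks_per_seq_);
char const* local_kv_ptr =
    paged_kv_block_pool_ptr_ +
    params_kv_block_size_in_bytes_ * paged_kv_global_block_offsets_[paged_kv_block_idx];
// ...
preds[ii] = valid_row;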

csrc/fmha_v2/fmha/smem_tile.h (1)

337-436: TMA specialization of Smem_tile_without_skews looks consistent

The TMA-enabled specialization correctly reuses the base layout, overrides buffer sizing to ignore STS granularity, wires utmaldg via smem_write_offset_ / smem_barrier_offset_, and keeps buffer-rotation logic in sync with the number of buffers and barriers. No issues stand out in the offset math or barrier handling.

csrc/fmha_v2/fmha/gmem_tile_ps.h (1)

515-585: Gmem_tile_ps base implementation appears consistent with per-thread MMA layout

The generic Gmem_tile_ps<Traits, Cta_tile, BITS_PER_ELEMENT> correctly:

  • Derives per-thread (row, col) from (warp, lane) in line with MMA tiling.
  • Computes step_m / step_n to cover an 8×8 region of 2-element packets per thread.
  • Delegates to Store_accumulator<Traits, BITS_PER_ELEMENT> with offsets based on M_PER_MMA_PER_CTA and N_PER_MMA_PER_CTA.

No issues stand out in the pointer math or scaling wiring.

Comment on lines +19 to +54
__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) {
// The step.
size_t step = (size_t)gridDim.x * blockDim.x;

// Iterate over the elements.
for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
// Load 4 integers.
int4 tmp = reinterpret_cast<int4 const*>(src)[ii];

// Convert to float and scale.
float x = static_cast<float>(tmp.x) * scale;
float y = static_cast<float>(tmp.y) * scale;
float z = static_cast<float>(tmp.z) * scale;
float w = static_cast<float>(tmp.w) * scale;

// Convert to int8.
uint32_t a;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x));
uint32_t b;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(b) : "f"(y));
uint32_t c;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(c) : "f"(z));
uint32_t d;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(d) : "f"(w));

// Compact.
char4 out;
out.x = reinterpret_cast<int8_t const&>(a);
out.y = reinterpret_cast<int8_t const&>(b);
out.z = reinterpret_cast<int8_t const&>(c);
out.w = reinterpret_cast<int8_t const&>(d);

// Store.
reinterpret_cast<uint32_t*>(dst)[ii] = reinterpret_cast<uint32_t const&>(out);
}
}

⚠️ Potential issue | 🟠 Major

Tail elements are not converted when n is not divisible by 4.

The kernel only processes n/4 iterations, each handling 4 elements. If n % 4 != 0, the remaining 1-3 elements are silently skipped, leading to uninitialized/stale data in the output.

Consider adding a tail-handling path or documenting the alignment requirement:

 __global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) {
+  // Note: n must be a multiple of 4. Caller must ensure proper alignment.
+  assert(n % 4 == 0);
   // The step.
   size_t step = (size_t)gridDim.x * blockDim.x;

Alternatively, add a scalar tail loop after the main vectorized loop.

🤖 Prompt for AI Agents
In csrc/fmha_v2/convert.cu around lines 19-54, the kernel only iterates n/4
vector elements so when n is not divisible by 4 the last 1–3 elements are
skipped; add a tail-handling path after the main vectorized loop that computes
base = (n/4)*4 and a scalar loop for i from base to n-1 that loads the single
int32, converts to float, applies the same scale and saturating int8 conversion,
and writes the byte to dst (ensure correct byte-offsets and casts into dst), or
alternatively document and assert that n must be 4-aligned. Ensure the scalar
path uses the same rounding/saturation behavior as the vector path and respects
thread/work distribution and bounds checks to avoid out-of-bounds stores.
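
For reference, a minimal sketch of such a scalar tail, appended after the existing vectorized loop (variable names and the grid-stride pattern follow the kernel above; the PTX conversion is identical to the vector path):

// Scalar tail: convert the last n % 4 elements with the same rounding/saturation.
size_t base = (n / 4) * 4;
for (size_t ii = base + (size_t)blockIdx.x * blockDim.x + threadIdx.x; ii < n; ii += step) {
  float x = static_cast<float>(reinterpret_cast<int32_t const*>(src)[ii]) * scale;
  uint32_t a;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x));
  reinterpret_cast<int8_t*>(dst)[ii] = static_cast<int8_t>(a);
}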

Comment on lines +181 to +184
void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d,
float scale_o) {
run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o);
}

⚠️ Potential issue | 🟠 Major

Integer overflow risk in dimension multiplication.

The expression s * b * h * d performs int multiplication before being passed to the size_t parameter. For large tensors, this can overflow.

Cast to size_t before multiplication:

 void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d,
                                  float scale_o) {
-  run_conversion_fp32_to_e4m3(dst, src, s * b * h * d, scale_o);
+  run_conversion_fp32_to_e4m3(dst, src, (size_t)s * b * h * d, scale_o);
 }
🤖 Prompt for AI Agents
In csrc/fmha_v2/convert.cu around lines 181 to 184, the multiplication s * b * h
* d is done in int and may overflow before being passed to the function
expecting a size_t; change the call to compute the product in size_t (e.g. cast
one operand to size_t before multiplying) or otherwise promote operands to
size_t so the multiplication is performed in size_t and then pass that value to
run_conversion_fp32_to_e4m3.

Comment on lines +18 to +32
constexpr static int round_down_to_power_two(int x) {
x = x | (x >> 1);
x = x | (x >> 2);
x = x | (x >> 4);
x = x | (x >> 8);
x = x | (x >> 16);
return x - (x >> 1);
}

AlibiParams() = default;

AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
h_pow_2 = round_down_to_power_two(h);
alibi_neg4_div_h = -4.0f / h_pow_2;
}

⚠️ Potential issue | 🟡 Minor

Potential division by zero when h=0.

If h=0 is passed to the constructor, round_down_to_power_two(0) returns 0, causing division by zero on line 31.

Consider adding a guard or assertion:

 AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
   h_pow_2 = round_down_to_power_two(h);
+  // Avoid division by zero; h_pow_2 == 0 implies h == 0 which is invalid.
+  assert(h_pow_2 > 0 && "Number of heads must be positive");
   alibi_neg4_div_h = -4.0f / h_pow_2;
 }
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/alibi_params.h around lines 18 to 32, the constructor calls
round_down_to_power_two(h) and then divides by h_pow_2 which will be zero when
h==0; add a guard to prevent division by zero by either asserting/throwing for
non-positive h or clamping h_pow_2 to a safe minimum (e.g. 1) before computing
alibi_neg4_div_h, and optionally update round_down_to_power_two to return 1 for
non-positive input; ensure the constructor documents/validates the precondition
and that alibi_neg4_div_h is computed only with a non-zero denominator.

Comment on lines +1277 to +1279
inline __device__ void store(uint4 const (&src)[STGS_PER_LOOP], int mi) {
int rows_so_far = mi * STGS_PER_LOOP * ROWS_PER_STG;
int rows_so_far_per_slice = rows_so_far / 2;

⚠️ Potential issue | 🟠 Major

Hardcoded division assumes NUM_SLICES == 2.

Line 1279 uses rows_so_far / 2 which assumes exactly 2 slices. However, line 1187 allows NUM_SLICES == 1, which would cause incorrect row calculations.

Consider using the enum value:

   int rows_so_far = mi * STGS_PER_LOOP * ROWS_PER_STG;
-  int rows_so_far_per_slice = rows_so_far / 2;
+  int rows_so_far_per_slice = rows_so_far / NUM_SLICES;
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/gmem_tile_o_packed.h around lines 1277-1279, the code
computes rows_so_far_per_slice using a hardcoded division by 2 which assumes
NUM_SLICES == 2; replace the literal 2 with the enum/constant for the number of
slices (e.g., divide by NUM_SLICES) so the computation works for NUM_SLICES == 1
or other values, ensuring you use the appropriate integer type/cast if
NUM_SLICES is an enum or template constant.

Comment on lines +133 to +139
// Ctor for other param classes (such as Qkv_params in train_ops)
template <typename Params, typename Block_info>
inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
: Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}


⚠️ Potential issue | 🔴 Critical

Bug in Gmem_tile_qkv forwarding ctor: mis-ordered arguments corrupt row/col indexing

The generic Params ctor

template <typename Params, typename Block_info>
inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
                                int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
    : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
                    qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}

forwards cta_row_offset into the num_kv_heads parameter of the main ctor and cta_col_offset_in_bytes into its cta_row_offset parameter, leaving the main ctor's cta_col_offset_in_bytes at its default. This will:

  • Use an incorrect num_kv_heads in the num_kv_heads < num_heads (MQA/GQA) branch, and
  • Miscompute row offsets (row + cta_row_offset) by treating a byte offset as a row count.

Both lead to wrong global-memory addresses, and the GQA path can even see division-by-zero if num_kv_heads ends up as 0 by accident.

A minimal safe fix is to pass an explicit num_kv_heads and then the row/column offsets in the correct positions. For non-GQA callers where Params has no KV-head count, using params.h keeps behavior identical to “no GQA/MQA” (the num_kv_heads < num_heads branch won’t trigger):

-  template <typename Params, typename Block_info>
-  inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
-                                  int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
-      : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
-                      qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}
+  template <typename Params, typename Block_info>
+  inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
+                                  int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
+      : Gmem_tile_qkv(params.qkv_ptr,
+                      params.q_stride_in_bytes,
+                      params.d,
+                      params.dv,
+                      params.h,
+                      qkv_offset,
+                      binfo,
+                      tidx,
+                      /*num_kv_heads=*/params.h,  // or Params::num_kv_heads if/when available
+                      cta_row_offset,
+                      cta_col_offset_in_bytes) {}

If some Params type does expose a distinct KV-head count (e.g., params.h_kv), it should be threaded through here instead of params.h.

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h around lines 133-139, the forwarding
ctor passes cta_row_offset and cta_col_offset_in_bytes into the wrong parameters
of the main ctor (they are shifting into num_kv_heads and cta_row_offset),
corrupting head counts and row/column indexing; fix by explicitly supplying
num_kv_heads (use params.h_kv if available, otherwise params.h) as the sixth
argument and then pass qkv_offset, binfo, tidx, cta_row_offset,
cta_col_offset_in_bytes in that order so the main ctor receives (qkv_ptr,
q_stride_in_bytes, d, dv, h, num_kv_heads, qkv_offset, binfo, tidx,
cta_row_offset, cta_col_offset_in_bytes).

Comment on lines +26 to +47
class Multiple_tma_descriptor {
public:
// ctor
Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) {
if (batch_size > 0) {
// allocate host memory
desc_ptr_h = new cudaTmaDesc[batch_size];
// make sure all bit fields are zeros.
memset(desc_ptr_h, 0, sizeof(cudaTmaDesc) * batch_size);
}
}

// ctor
Multiple_tma_descriptor() = default;

// destructor.
~Multiple_tma_descriptor() {
if (batch_size > 0) {
// deallocate host memory
delete[] desc_ptr_h;
}
}

🛠️ Refactor suggestion | 🟠 Major

Missing copy/move semantics (Rule of Five violation).

The class manages raw memory with new[]/delete[] but does not define copy constructor, copy assignment, move constructor, or move assignment operators. If an instance is copied, the default copy will lead to double-free on destruction.

Either delete copy/move operations or implement them properly:

 class Multiple_tma_descriptor {
  public:
+  // Prevent accidental copies that would cause double-free
+  Multiple_tma_descriptor(Multiple_tma_descriptor const&) = delete;
+  Multiple_tma_descriptor& operator=(Multiple_tma_descriptor const&) = delete;
+  Multiple_tma_descriptor(Multiple_tma_descriptor&&) = delete;
+  Multiple_tma_descriptor& operator=(Multiple_tma_descriptor&&) = delete;
+
   // ctor
   Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) {
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_descriptor.h around lines 26-47, the class
Multiple_tma_descriptor manages raw memory (desc_ptr_h) but lacks copy/move
special members causing double-free; fix by implementing the Rule of Five:
ensure the default constructor initializes batch_size to 0 and desc_ptr_h to
nullptr, implement a copy constructor and copy-assignment that perform a deep
copy (allocate new cudaTmaDesc[batch_size], memcpy/set contents) and handle
self-assignment, implement a move constructor and move-assignment that transfer
desc_ptr_h and batch_size from the source and set the source's desc_ptr_h to
nullptr and batch_size to 0, and keep the existing destructor to delete[]
desc_ptr_h.
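
If copies are actually needed rather than simply deleted, a minimal sketch of the deep-copy and move members could look like this (assumes the existing desc_ptr_h / batch_size members, with the default constructor leaving them nullptr / 0; std::swap from <utility> and copy-and-swap cover both assignment forms):

Multiple_tma_descriptor(Multiple_tma_descriptor const& other) {
  batch_size = other.batch_size;
  desc_ptr_h = nullptr;
  if (batch_size > 0) {
    // Deep copy so the two objects never share (and never double-free) the array.
    desc_ptr_h = new cudaTmaDesc[batch_size];
    memcpy(desc_ptr_h, other.desc_ptr_h, sizeof(cudaTmaDesc) * batch_size);
  }
}

Multiple_tma_descriptor(Multiple_tma_descriptor&& other) noexcept {
  batch_size = other.batch_size;
  desc_ptr_h = other.desc_ptr_h;
  // Leave the source empty so its destructor is a no-op.
  other.batch_size = 0;
  other.desc_ptr_h = nullptr;
}

Multiple_tma_descriptor& operator=(Multiple_tma_descriptor other) {
  std::swap(batch_size, other.batch_size);
  std::swap(desc_ptr_h, other.desc_ptr_h);
  return *this;
}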

Comment on lines +83 to +115
// set the desctriptor.
int set_tma_desctriptor(
// ptr to gmem
void const* gmem_ptr,
// format is really data_type in TMA terminology.
cudaTmaDescFormat format,
// interleave mode.
cudaTmaDescInterleave interleave,
// swizzle mode.
cudaTmaDescSwizzle swizzle,
// L2 sector promotion.
cudaTmaDescPromotion promotion, uint32_t const (&tensor_size_array)[NUM_DIMS],
uint64_t const (&tensor_stride_array)[NUM_DIMS - 1],
uint32_t const (&traversal_stride_array)[NUM_DIMS],
uint32_t const (&box_size_array)[NUM_DIMS],
// OOB fill mode.
uint32_t fill_oob,
// FP32 to TF32 conversion.
uint32_t round_to_tf32,
// index to desc.
cudaTmaDesc* desc_ptr = nullptr) {
set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));
set_tensor_common_1(desc_ptr, TILED, NUM_DIMS, format, interleave, swizzle, fill_oob,
round_to_tf32, promotion);

set_tensor_stride(desc_ptr, tensor_stride_array);
set_tensor_size(desc_ptr, tensor_size_array);

set_traversal_stride_tiled(desc_ptr, traversal_stride_array);

set_box_size(desc_ptr, box_size_array);
return 0;
}

⚠️ Potential issue | 🟠 Major

Null pointer dereference risk in set_tma_desctriptor overload.

The second overload of set_tma_desctriptor (line 84-115) has desc_ptr = nullptr as the default parameter value. If called without providing a valid pointer, all the set_* helper functions will dereference nullptr, causing undefined behavior.

Add a null check or remove the default value:

   int set_tma_desctriptor(
       // ... parameters ...
-      cudaTmaDesc* desc_ptr = nullptr) {
+      cudaTmaDesc* desc_ptr) {
+    assert(desc_ptr != nullptr);
     set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_descriptor.h around lines 83 to 115, the
overload of set_tma_desctriptor uses desc_ptr = nullptr by default which allows
calling without a valid descriptor and leads to null dereference when set_*
helpers are invoked; either remove the default so callers must supply a valid
cudaTmaDesc* or add an early null check at the top (e.g., if (desc_ptr ==
nullptr) return a non-zero error code or assert) and avoid calling any set_*
helpers when desc_ptr is null, keeping the function’s existing return
convention.

Comment on lines +117 to +121
// copy the desc to device memory
void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
cudaMemcpyHostToDevice));
}

⚠️ Potential issue | 🟡 Minor

Unused stream parameter and synchronous copy.

The copy_to_device function accepts a cudaStream_t stream parameter but uses synchronous cudaMemcpy instead of cudaMemcpyAsync. This is misleading and the stream is silently ignored.

Either use the stream parameter with async copy or remove it:

   // copy the desc to device memory
-  void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
-    FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
-                               cudaMemcpyHostToDevice));
+  void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
+    FMHA_CHECK_CUDA(cudaMemcpyAsync(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
+                                    cudaMemcpyHostToDevice, stream));
   }
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_descriptor.h around lines 117 to 121, the
function copy_to_device accepts a cudaStream_t stream but performs a synchronous
cudaMemcpy (ignoring the stream); either replace cudaMemcpy with cudaMemcpyAsync
using the provided stream and keep the parameter (propagate error via
FMHA_CHECK_CUDA and ensure memory size is correct), or remove the stream
parameter from the function signature and callers if asynchronous behavior is
not needed—pick one approach and update all call sites and tests accordingly.

Comment on lines +440 to +452
template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
int const trans_b = TB ? 1 : 0;
asm volatile(
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
"{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"

: "+r"(acc[0]), "+r"(acc[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
#endif
}

⚠️ Potential issue | 🔴 Critical

Undeclared desc_a and mismatched operands in Hgmma_rfa_fp16<8, TB>::mma

In the RF-A FP16 specialization for N = 8:

template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    int const trans_b = TB ? 1 : 0;
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
        "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"

        : "+r"(acc[0]), "+r"(acc[1])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
#endif
  }
};

Issues:

  • desc_a is not a parameter and is never defined in this scope → compile-time error.
  • The instruction form with A from RF should only consume B’s descriptor plus the transpose flag, just like the 16/32/... RF-A variants.

This should be simplified to use only desc_b and trans_b, with the operand indices adjusted accordingly. For example:

 template <bool TB>
 struct Hgmma_rfa_fp16<8, TB> {
   static inline __device__ void mma(uint32_t const (&a)[4],
-                                    uint64_t desc_b, uint32_t (&acc)[2]) {
+                                    uint64_t desc_b, uint32_t (&acc)[2]) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
     int const trans_b = TB ? 1 : 0;
     asm volatile(
         "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
         "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
 
-        : "+r"(acc[0]), "+r"(acc[1])
-        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
+        : "+r"(acc[0]), "+r"(acc[1])
+        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
+          "l"(desc_b), "n"(trans_b));
 #endif
   }
 };

This matches the operand pattern used in Hgmma_rfa_fp16<16, TB> and other RF-A specializations and removes the undeclared identifier.

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/utils_hgmma.h around lines 440-452, the inline asm
uses an undeclared desc_a and supplies too many operands; change the asm to only
consume B's descriptor and the transpose flag (desc_b and trans_b), remove the
"l"(desc_a) operand, and adjust the input operand ordering/indices so the a[]
registers remain as inputs, acc[] remain read-write outputs, and the last two
inputs are "l"(desc_b) and "n"(trans_b) matching the RF-A pattern used by other
specializations (e.g., the <16,TB> variant). Ensure the operand list and
referenced %n indices in the asm string are updated accordingly.

Comment on lines +13 to +17
#include <cuda_runtime_api.h>

#include <cstdint>

#pragma once

🛠️ Refactor suggestion | 🟠 Major

Move #pragma once to the top of the file.

The #pragma once directive should appear before any #include statements to ensure proper header guard behavior. Currently, it appears after the includes.

+#pragma once
+
 #include <cuda_runtime_api.h>
 
 #include <cstdint>
-
-#pragma once
🧰 Tools
🪛 Clang (14.0.6)

[error] 13-13: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/numeric_types.h around lines 13 to 17, the #pragma once
header guard is placed after the #include directives; move the #pragma once line
to the very top of the file (before any #include or other preprocessor
directives) so the header guard applies immediately, preventing multiple
inclusion issues.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 14

♻️ Duplicate comments (18)
csrc/fmha_v2/fmha/alibi_params.h (1)

29-32: Division by zero issue already flagged.

This constructor can cause division by zero when h=0, as round_down_to_power_two(0) returns 0. This issue has already been flagged in a previous review with a suggested fix (adding an assertion or guard).

csrc/fmha_v2/fmha/hopper/tma_descriptor.h (4)

26-47: Missing copy/move semantics (Rule of Five violation).

The class manages raw memory with new[]/delete[] but lacks copy constructor, copy assignment, move constructor, and move assignment operators. Copying an instance would cause double-free.

Add deleted or properly implemented copy/move operations:

 class Multiple_tma_descriptor {
  public:
+  // Prevent accidental copies
+  Multiple_tma_descriptor(Multiple_tma_descriptor const&) = delete;
+  Multiple_tma_descriptor& operator=(Multiple_tma_descriptor const&) = delete;
+  Multiple_tma_descriptor(Multiple_tma_descriptor&&) = delete;
+  Multiple_tma_descriptor& operator=(Multiple_tma_descriptor&&) = delete;
+
   // ctor
   Multiple_tma_descriptor(int batch_size_) : batch_size(batch_size_) {

83-115: Null pointer dereference risk with default desc_ptr = nullptr.

The second set_tma_desctriptor overload defaults desc_ptr to nullptr. If called without providing a valid pointer, all set_* helpers will dereference nullptr, causing undefined behavior.

Remove the default or add validation:

   int set_tma_desctriptor(
       // ... parameters ...
-      cudaTmaDesc* desc_ptr = nullptr) {
+      cudaTmaDesc* desc_ptr) {
+    if (!desc_ptr) return -1;
     set_tensor_common_0(desc_ptr, reinterpret_cast<uint64_t>(gmem_ptr));

117-121: Unused stream parameter with synchronous copy.

The copy_to_device function accepts a cudaStream_t stream parameter but uses synchronous cudaMemcpy instead of cudaMemcpyAsync. The stream parameter is silently ignored.

Either use async copy or remove the stream parameter:

   void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) {
-    FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
-                               cudaMemcpyHostToDevice));
+    FMHA_CHECK_CUDA(cudaMemcpyAsync(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size,
+                                    cudaMemcpyHostToDevice, stream));
   }

49-50: Typo: "desctriptor" should be "descriptor".

The method name contains a typo at line 50 and line 84.

Consider renaming for consistency:

-  int set_tma_desctriptor(
+  int set_tma_descriptor(
csrc/fmha_v2/fmha/numeric_types.h (1)

13-17: Move #pragma once before includes.

The #pragma once directive should appear at the very top of the file, before any #include statements, to ensure the header guard takes effect immediately.

csrc/fmha_v2/convert.cu (3)

19-54: Tail elements are not converted when n is not divisible by 4.

The kernel only processes n/4 iterations, leaving 1-3 trailing elements unconverted when n % 4 != 0.


46-49: Use of reinterpret_cast for extracting int8 from PTX result.

The cvt.rni.sat.s8.f32 instruction writes to the lower byte of the 32-bit register. Using reinterpret_cast here is fragile. Consider using static_cast<int8_t>(a) instead.


181-184: Integer overflow risk in dimension multiplication.

The expression s * b * h * d performs int multiplication before conversion to size_t, risking overflow for large tensors.

csrc/fmha_v2/fmha/hopper/arrive_wait.h (2)

52-58: Missing defined() guard and non-standard CUDACC_VERSION macro.

Line 52 uses __CUDA_ARCH__ without a defined() check, and line 56 uses CUDACC_VERSION which is not a standard CUDA macro.


153-189: Potential uninitialized predicate register P3 when pred is false.

When pred is 0, the mbarrier.try_wait instruction is skipped due to the @P2 predicate, but selp.b32 %0, 1, 0, P3 still reads P3. Since P3 was never set when the instruction was skipped, its value is undefined.

csrc/fmha_v2/fmha/hopper/fragment.h (1)

133-139: Inconsistent use of NUM_REGS vs NUM_ELTS in bf16 SMEM-SMEM accumulator add().

This add() method uses Base::NUM_REGS (line 136), but the RF-SMEM variant at line 203 uses Base::NUM_ELTS. Since this is a float-based accumulator accessing via elt(), it should use NUM_ELTS for consistency and correctness.

   template <typename Other_fragment_>
   inline __device__ void add(Other_fragment_ const& other) {
-    for (int ii = 0; ii < Base::NUM_REGS; ++ii) {
+    for (int ii = 0; ii < Base::NUM_ELTS; ++ii) {
       this->elt(ii) = this->elt(ii) + other.elt(ii);
     }
   }
csrc/fmha_v2/fmha/hopper/kernel_traits.h (1)

223-227: Typo: "row marjor" should be "row major".

-  // We know V is row marjor. So we can also deduce the descriptor mode.
+  // We know V is row major. So we can also deduce the descriptor mode.
csrc/fmha_v2/fmha/hopper/compute_tile.h (1)

160-178: Wrong descriptor type used in increment_gmma_desc_a_group() loop bound.

Line 164 uses Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS but this method operates on gmma_desc_a_. This should use Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS to ensure the iteration count matches the A-side descriptors.

   inline __device__ void increment_gmma_desc_a_group() {
     bool reset_buffer = gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0();

 #pragma unroll
-    for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) {
+    for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) {
 #pragma unroll
       for (int mma_m_idx = 0; mma_m_idx < MMAS_M; ++mma_m_idx) {
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1)

278-291: Inconsistent use of Mma_tile::VALID_MMAS_N vs Base::VALID_MMAS_N.

Line 279 uses Mma_tile::VALID_MMAS_N while line 294 uses Base::VALID_MMAS_N. The fp16 and fp32 variants consistently use Base::VALID_MMAS_N. This inconsistency could cause compilation errors if Mma_tile doesn't define VALID_MMAS_N.

 #pragma unroll
-      for (int mma_ni = 0; mma_ni < Mma_tile::VALID_MMAS_N - 1; ++mma_ni) {
+      for (int mma_ni = 0; mma_ni < Base::VALID_MMAS_N - 1; ++mma_ni) {
csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1)

133-138: Bug in Gmem_tile_qkv forwarding constructor: missing num_kv_heads argument.

The generic Params constructor forwards arguments but omits the num_kv_heads parameter, causing cta_row_offset to be passed as num_kv_heads and cta_col_offset_in_bytes to become cta_row_offset. This corrupts GQA/MQA head indexing and row offset calculations.

Compare to the correct bert::Fused_multihead_attention_params_v2 constructor at lines 125-131 which properly passes params.h_kv. Apply:

   template <typename Params, typename Block_info>
   inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo,
                                   int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
       : Gmem_tile_qkv(params.qkv_ptr, params.q_stride_in_bytes, params.d, params.dv, params.h,
-                      qkv_offset, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes) {}
+                      qkv_offset, binfo, tidx, params.h, cta_row_offset, cta_col_offset_in_bytes) {}

Note: If Params has a distinct KV head count (e.g., params.h_kv), use that instead of params.h.

csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1)

1277-1279: Hardcoded division assumes NUM_SLICES == 2.

This issue was already raised in a previous review. Line 1279 uses rows_so_far / 2 which assumes exactly 2 slices, but line 1187 allows NUM_SLICES == 1.

csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1)

440-452: Undeclared desc_a in inline asm operand list.

This issue was already raised in a previous review. The inline asm references desc_a which is not a parameter and is never defined in this scope, causing a compile-time error on SM90 builds.

🧹 Nitpick comments (31)
csrc/fmha_v2/fmha/paged_kv_cache.h (2)

48-58: Consider more robust handling of mTokensPerBlockLog2 instead of FP log2

Right now mTokensPerBlockLog2 is derived via float tokensPerBlockSeqLog2 = log2(mTokensPerBlock); and an int cast. This assumes:

  • tokensPerBlock > 0, and
  • it’s an exact power of two so FP rounding can’t bite.

Given the comment “It must be power of 2”, you might harden this by either:

  • using an integer‑based computation (e.g., std::countr_zero / __builtin_ctz on uint32_t(tokensPerBlock)), and optionally
  • adding a debug‑time assertion that tokensPerBlock is a power of two (e.g., (x & (x-1)) == 0).

That keeps behavior well‑defined even if a misconfigured value slips through, without affecting the ABI/layout of the struct.
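
A minimal sketch of the integer-only variant, assuming tokensPerBlock is a positive power of two and that this runs host-side with the GCC/Clang builtin mentioned above (member names follow the struct discussed here):

// log2 of an exact power of two == its trailing-zero count; no floating-point rounding involved.
assert(mTokensPerBlock > 0 && (mTokensPerBlock & (mTokensPerBlock - 1)) == 0);
mTokensPerBlockLog2 = __builtin_ctz(static_cast<uint32_t>(mTokensPerBlock));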


15-16: cuda_runtime.h static‑analysis error is likely a tooling configuration issue

The reported “cuda_runtime.h file not found” from clang usually means the analysis environment isn’t pointed at the CUDA toolkit includes, not that this include is wrong. The header itself is standard for CUDA‑related code and seems appropriate here.

If this is breaking your CI linting, you probably want to:

  • ensure your compile_commands.json/clang‑tidy config passes the correct --cuda-path or -I to the CUDA include directory, or
  • gate CUDA‑dependent checks to only run where the toolkit is available.

No code change in this file seems necessary for that error.

If you’d like, I can suggest a small script to probe your current compile_commands.json and locate places where CUDA include paths are missing so you can wire them into your static‑analysis config.

csrc/fmha_v2/fmha/alibi_params.h (1)

18-25: Document preconditions or validate input range.

The function correctly rounds down to the nearest power of two for positive integers, but has undefined behavior for negative inputs due to signed right shift. While negative head counts are semantically invalid, adding documentation or a debug assertion would prevent misuse.

Consider adding a precondition comment or assertion:

 constexpr static int round_down_to_power_two(int x) {
+  // Precondition: x must be non-negative
   x = x | (x >> 1);
   x = x | (x >> 2);
   x = x | (x >> 4);
   x = x | (x >> 8);
   x = x | (x >> 16);
   return x - (x >> 1);
 }

Alternatively, note that softmax.h uses a different approach with __clz for similar rounding—consider aligning implementations for consistency.

csrc/fmha_v2/fmha/smem_tile.h (3)

222-246: Track deprecated function removal.

Multiple functions are marked with TODO comments indicating they should be removed:

  • move_next_read_buffer() (line 223)
  • move_next_read_buffer(int N) (line 234)
  • move_next_write_buffer() (line 246)

These wrapper functions delegate to their preferred equivalents (move_to_next_*). Consider creating an issue to track their removal in a future cleanup.

Would you like me to open an issue to track the removal of these deprecated functions?


591-604: Consider extracting repeated XOR offset update pattern.

This cascading if-else pattern for updating smem_read_offset_ based on ki and MMAS_K thresholds is duplicated across multiple tile implementations (Volta, Turing, Ampere variants for both A and B tiles).

The pattern appears at lines 592-603, 757-768, 982-992, 1265-1275, 1431-1441, 1664-1674, and more.

Consider extracting this into a helper template function to reduce duplication:

template <int MMAS_K_WITH_PADDING, int BYTES_PER_LDS, int SCALE = 1>
inline __device__ void update_smem_read_offset_xor(int& offset, int ki) {
    static_assert(MMAS_K_WITH_PADDING < 64, "Not implemented");
    if (MMAS_K_WITH_PADDING >= 32 && ki % 16 == 15) {
        offset ^= 31 * BYTES_PER_LDS * SCALE;
    } else if (MMAS_K_WITH_PADDING >= 16 && ki % 8 == 7) {
        offset ^= 15 * BYTES_PER_LDS * SCALE;
    } else if (MMAS_K_WITH_PADDING >= 8 && ki % 4 == 3) {
        offset ^= 7 * BYTES_PER_LDS * SCALE;
    } else if (MMAS_K_WITH_PADDING >= 4 && ki % 2 == 1) {
        offset ^= 3 * BYTES_PER_LDS * SCALE;
    } else if (MMAS_K_WITH_PADDING >= 2) {
        offset ^= 1 * BYTES_PER_LDS * SCALE;
    }
}

1993-1998: Clarify or remove commented-out static_assert.

This commented-out static assertion suggests there may be untested or unsupported configurations. If these constraints are no longer relevant, remove the dead code. If they're still important but temporarily disabled, add a comment explaining why.

-      // static_assert(BYTES_PER_MMA_PER_CTA >= 128 ||
-      //               BYTES_PER_MMA_PER_CTA ==  64 ||
-      //               (BYTES_PER_MMA_PER_CTA == 32 &&
-      //               (Mma_tile::MMAS_M == 4 ||
-      //               Mma_tile::MMAS_M == 2 ||
-      //               Mma_tile::MMAS_M == 1)), "");

Either remove entirely or add a clarifying comment like:

// TODO: Re-enable this static_assert once all configurations are validated
// static_assert(...);
csrc/fmha_v2/fmha/gmem_tile_qkv.h (1)

137-147: Consider documenting the is_active_ predicate usage pattern.

The load() overload (lines 138-147) uses is_active_ to guard the load, while the other load(Smem_tile&) overload (lines 124-135) does not. This asymmetry may be intentional for different use cases, but a brief comment explaining when each overload should be used would improve maintainability.

csrc/fmha_v2/fmha/mask.h (1)

269-352: Noted: FIXME comment indicates incomplete support.

The MtpMask implementation has a FIXME on line 346 noting that Volta/Hopper-GMMA traits are not yet supported for the mtp_token_idx_ array sizing. This should be tracked for future work.

Would you like me to open an issue to track adding Volta/Hopper-GMMA trait support for MtpMask?

csrc/fmha_v2/fmha/gmem_tile_ps.h (1)

277-312: Remove commented-out debug code.

Lines 303-306 contain commented-out debug code that should be cleaned up before merging.

Apply this diff to remove the debug code:

      for (int row_idx = 0; row_idx < ROWS_PER_THREAD; ++row_idx) {
        uint32_t acc_0 = fmha::hmul2(acc.reg(col_idx * ROWS_PER_THREAD + row_idx), scale);
-        // float one = 1.f;
-        // if(col_idx > 2){
-        //   acc_0 = float2_to_half2(one, one);
-        // }
        int64_t offset = (int64_t)row_idx * step_m + (int64_t)col_idx * step_n;
        fmha::stg(ptr + offset, acc_0);
      }
csrc/fmha_v2/fmha/numeric_types.h (1)

36-39: Fallback e4m3_t/e5m2_t as char may cause silent failures.

When FMHA_CUDA_SUPPORTS_FP8 is not defined, these types become char. Any code performing arithmetic or conversions on FP8 types will compile but produce incorrect results. Consider using a static_assert or a distinct placeholder type that fails at compile time if FP8 operations are attempted on unsupported platforms.
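
One possible shape for such a placeholder on the fallback path, so accidental arithmetic fails at compile time instead of silently producing wrong results (illustrative only; the struct name is hypothetical):

#if !defined(FMHA_CUDA_SUPPORTS_FP8)
// 1-byte opaque stand-in: no conversions or operators, so any arithmetic use fails to compile.
struct e4m3_placeholder_t {
  char storage;
};
using e4m3_t = e4m3_placeholder_t;
#endif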

csrc/fmha_v2/fmha/hopper/arrive_wait.h (1)

373-378: Unused member variable id_.

The id_ member is declared and assigned in the constructor but never read. Consider removing it if unused, or document its intended purpose.

  private:
   // smem barrier base pointer
   uint64_t* bar_base_;
-  // barrier id
-  int id_;
 };
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (2)

128-142: Several data members appear unused.

The members preds_, fetch_, and row_ are declared but never read or written in this class. They may be vestigial from copy-paste of the non-TMA Gmem_tile_qkv class. Consider removing them to reduce memory footprint and improve clarity.

   // The stride between rows for the QKV matrice.
   int64_t params_qkv_stride_in_bytes_;
   // The pointer.
   char* qkv_ptr_;
-  // The register to store predicates.
-  uint32_t preds_[PRED_REGS];
-  // The fetch registers.
-  uint4 fetch_[LDGS];
-  // Keep track of the row the thread is processing as we move the tile.
-  int row_;
   // The sequence length.
   int actual_seqlen_;
   // tma descriptor
   cudaTmaDesc const* p_desc_;
   // coord use by TMA. For now hard code to 3D.
   int32_t coord[3];
 };

107-118: Empty commit() and store() methods may need documentation.

These methods are empty, which is likely intentional for the TMA path (TMA handles commit/store differently). Adding a brief comment would clarify that this is by design and not a missing implementation.

   // Store data to shared memory.
+  // Note: No-op for TMA path; TMA handles commit internally.
   template <typename Smem_tile>
   inline __device__ void commit(Smem_tile& smem_tile) {}
 
-  // Load data from memory.
+  // Load data from memory via TMA.
   template <typename Smem_tile>
   inline __device__ void load(Smem_tile& smem_tile) {
     smem_tile.template store<TMA_DIMS, TMA_DESC_TYPE>(p_desc_, coord);
   }
 
-  // Store data to memory.
+  // Store data to memory. No-op for TMA read-only path.
   inline __device__ void store(uint4 const (&data)[LDGS]) {}
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)

92-96: Redundant runtime check with dead code path.

The if (Mma_tile::MMAS_N == 1) check at line 92 is always true due to the static_assert(Mma_tile::MMAS_N == 1) at line 58. The else branch with assert(false) is unreachable dead code.

If this is intentional scaffolding for future multi-MMA-N support, consider adding a comment. Otherwise, simplify:

-      if (Mma_tile::MMAS_N == 1) {
-        this->smem_write_ ^= 64;
-      } else {
-        assert(false && "Unsupported");
-      }
+      // MMAS_N == 1 enforced by static_assert above
+      this->smem_write_ ^= 64;
csrc/fmha_v2/fmha/hopper/fragment.h (1)

369-384: Consider using if constexpr for compile-time type dispatch.

The type dispatch using std::is_same_v will be evaluated at compile time, and the compiler should optimize away unreachable branches. However, using if constexpr would make the compile-time nature explicit and guarantee dead code elimination:

if constexpr (std::is_same_v<Input_type_A, e4m3_t> && std::is_same_v<Input_type_B, e4m3_t>) {
  // ...
} else if constexpr (...) {
  // ...
}

This is a minor stylistic improvement that makes intent clearer.

csrc/fmha_v2/fmha/gmem_tile_o.h (1)

282-288: Clarify USE_DEMO_BERT_PARAMS macro purpose.

The preprocessor conditional for USE_DEMO_BERT_PARAMS sets o_scratch_ptr_ to nullptr in one path. Consider adding a brief comment explaining when this macro is defined and why the scratch pointer behavior differs, to aid future maintainability.

csrc/fmha_v2/fmha/hopper/smem_tile.h (2)

377-383: Commented-out buffer management code.

The move_next_write_buffer() implementations are commented out in multiple places (lines 377-383, 538-544). If this functionality is not needed, remove the commented code. If it's planned for future use, add a TODO comment explaining when/why it will be needed.


1326-1485: Large commented-out block in load_and_store method.

The load_and_store method (lines 1326-1485) contains ~160 lines of commented-out transpose logic. This significantly impacts readability. Consider:

  1. Removing if the code is obsolete
  2. Moving to a separate function with a clear TODO if planned for future implementation
  3. Adding a brief comment explaining why it's preserved if there's a specific reason
csrc/fmha_v2/fmha/hopper/compute_tile.h (2)

273-277: Commented-out static_asserts and debug artifact.

Lines 273-277 contain commented-out static_assert statements (including a typo pstatic_assert). These appear to be debug artifacts that should either be removed or properly enabled.

-  // static_assert(Cta_tile::K == 128);
-  // static_assert(Mma_tile::K_PER_MMA_PER_CTA == 64 );
-  // pstatic_assert(NUM_KBLOCKS == 384 / 64);
   static constexpr int NUM_KBLOCKS = Smem_tile_b::BUFFERS_PER_TILE / Cta_tile::WARPS_K;
-  // static_assert(NUM_KBLOCKS * Cta_tile::WARPS_K == Smem_tile_b::BUFFERS_PER_TILE);

488-491: Outdated comment regarding fragment declaration.

The comment at line 489 asks "is is better to declare as Fragment a_?" which appears to be a leftover design note. Consider removing or updating this comment.

   // The fragments to load A.
-  // Need to think about is is better to declare as Fragment a_?
-  // for the second GEMM, MMAS_M is most likely 1. (at least for now. )
   Fragment a_[MMAS_M];
csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (3)

119-132: Complex nested ternary for BYTES_PER_DESC is hard to maintain.

The nested ternary expression spanning lines 119-132 is difficult to read and verify. Consider refactoring to a constexpr helper function or using if constexpr for clarity.

Consider restructuring as:

static constexpr uint32_t compute_bytes_per_desc() {
  if constexpr (Gmma_vector_size == Gmma_descriptor_size::ALL) return 0;
  if constexpr (Gmma_trans == Gmma_descriptor_transpose::TRANS) {
    if constexpr (Gmma_mode == Gmma_descriptor_mode::SWIZZLE_128B) 
      return GMMA_K * BYTES_PER_LEADING_DIM;
    // ... etc
  }
  // ...
}
static constexpr uint32_t BYTES_PER_DESC = compute_bytes_per_desc();

260-268: increment_single_descriptor modifies desc[0] directly without bounds check.

The method assumes NUM_DESCRIPTORS >= 1. While the static_assert at line 88 enforces Gmma_vector_size == ONE, this assumption should be documented or an assert added for defensive coding.

   inline __device__ void increment_single_descriptor(bool last_of_kblock) {
+    static_assert(NUM_DESCRIPTORS >= 1, "At least one descriptor required");
     // update smem start address, which is in lower 32bits.
     int2& tmp = reinterpret_cast<int2&>(desc[0]);

514-538: Three-argument increment_single_descriptor overload has complex branching.

The overload at lines 516-538 with last_of_kblock and switch_kblock parameters has nested conditionals that are difficult to follow. Consider adding a brief comment explaining the state machine or expected call sequences.

+  // Advances descriptor address based on k-block traversal state:
+  // - switch_kblock: transitioning to a new k-block group
+  // - last_of_kblock: at the boundary requiring reset within group
   inline __device__ void increment_single_descriptor(bool last_of_kblock, bool switch_kblock) {
csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (2)

151-152: Unused member variable is_active_for_last_stg_.

The member is_active_for_last_stg_ at line 151 is declared but never assigned or used in this class or its derived classes in this file.

-  // Is the thread active for the last STG?
-  int is_active_for_last_stg_;

879-880: Multiple unused member variables.

  • is_active_for_last_stg_ (line 880) is never used
  • params_enable_i2f_trick_ (line 899) is declared as const bool = false but never referenced
-  // Is the thread active for the last STG?
-  int is_active_for_last_stg_;
   ...
-  bool const params_enable_i2f_trick_ = false;

Also applies to: 899-899

csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (3)

539-539: Inconsistent type for col_in_bytes_ across tile classes.

In Gmem_tile_q_k_v, col_in_bytes_ is declared as int64_t (line 539), but in Gmem_tile_qkv (line 331) it's int. This inconsistency could cause issues if these types are used interchangeably or in template contexts.

Consider unifying the type to int64_t across all tile classes for consistency and to avoid potential overflow with large column offsets:

// In Gmem_tile_qkv (line 331):
-  int col_in_bytes_;
+  int64_t col_in_bytes_;

860-861: past_seqlen_ calculated but never used.

In Gmem_tile_paged_kv, past_seqlen_ is calculated at line 861 and stored at line 975 but is never referenced in any method.

If this is planned for future sliding window support, add a comment. Otherwise, remove the unused member:

-        past_seqlen_(binfo.actual_seqlen - binfo.actual_q_seqlen),
...
-  // The past sequence length (kv_seqlen - q_seqlen) considering chunked context.
-  int past_seqlen_;

976-977: sliding_window_size_ stored but never used.

Similar to past_seqlen_, the sliding_window_size_ member is initialized but never used in Gmem_tile_paged_kv. The template parameter SLIDING_WINDOW_ATTENTION_ is available but the runtime value appears unused.

csrc/fmha_v2/fmha/kernel_traits.h (1)

266-271: Consider documenting buffer count logic.

The buffer count logic at lines 266-271 has non-trivial conditions (e.g., USE_GRANULAR_TILING && D > 64). A brief comment explaining when 2 buffers are needed vs 1 would improve maintainability.

+  // Use 2 buffers for software pipelining when:
+  // - Q: granular tiling with large D, or LDGSTS with looping
+  // - K/V: granular tiling only
   enum {
     BUFFERS_PER_TILE_SMEM_Q = (USE_GRANULAR_TILING && D > 64) || (USE_LDGSTS_Q && !NO_LOOP) ? 2 : 1
   };
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (2)

800-800: Unused variable row_ptr.

This variable is defined but never used in the store function.

       Stg_packed_type dst = Acc_packer<Src_type, Dst_type>::run(this, src[ii]);
-      float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);

1119-1119: Unused variable row_ptr.

Same issue as in Gmem_tile_o_8bit::store - this variable is defined but never used.

       Stg_packed_type dst = Acc_packer<Src_type, Output_type>::run(this, src[ii]);
-      float const* row_ptr = reinterpret_cast<float const*>(&src[ii]);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d0d99d2 and 4d8f356.

📒 Files selected for processing (31)
  • csrc/fmha_v2/convert.cu (1 hunks)
  • csrc/fmha_v2/fmha/alibi_params.h (1 hunks)
  • csrc/fmha_v2/fmha/fragment.h (1 hunks)
  • csrc/fmha_v2/fmha/gemm.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_o.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_ps.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_qkv.h (1 hunks)
  • csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/arrive_wait.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/compute_tile.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/fragment.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/kernel_traits.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/smem_tile.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/tma_types.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_gmma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_hgmma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_igmma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_tma.h (1 hunks)
  • csrc/fmha_v2/fmha/hopper/utils_warpgroup.h (1 hunks)
  • csrc/fmha_v2/fmha/kernel_traits.h (1 hunks)
  • csrc/fmha_v2/fmha/mask.h (1 hunks)
  • csrc/fmha_v2/fmha/numeric_types.h (1 hunks)
  • csrc/fmha_v2/fmha/paged_kv_cache.h (1 hunks)
  • csrc/fmha_v2/fmha/smem_tile.h (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/fmha_v2/fmha/gemm.h
  • csrc/fmha_v2/fmha/hopper/fragment.h
  • csrc/fmha_v2/convert.cu
  • csrc/fmha_v2/fmha/hopper/kernel_traits.h
  • csrc/fmha_v2/fmha/hopper/compute_tile.h
  • csrc/fmha_v2/fmha/hopper/smem_tile_o.h
  • csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h
  • csrc/fmha_v2/fmha/gmem_tile_o_packed.h
  • csrc/fmha_v2/fmha/hopper/utils_hgmma.h
🧬 Code graph analysis (13)
csrc/fmha_v2/fmha/hopper/arrive_wait.h (3)
csrc/fmha_v2/fmha/hopper/smem_tile.h (1)
  • fmha (22-562)
csrc/nv_internal/include/tensorrt_llm/common/cudaUtils.h (2)
  • bar_create (795-806)
  • bar_wait (853-869)
csrc/fmha_v2/fmha/warpspec/circular_buffer.h (10)
  • int (53-53)
  • int (56-56)
  • int (59-63)
  • int (66-69)
  • int (71-79)
  • int (82-86)
  • int (129-131)
  • int (134-137)
  • int (148-148)
  • int (203-203)
csrc/fmha_v2/fmha/alibi_params.h (1)
csrc/fmha_v2/fmha/softmax.h (1)
  • float (2543-2545)
csrc/fmha_v2/fmha/gmem_tile_qkv.h (2)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/utils.h (1)
  • pack_predicates (1366-1370)
csrc/fmha_v2/fmha/hopper/gmma_descriptor.h (4)
csrc/fmha_v2/fmha/hopper/compute_tile.h (1)
  • fmha (16-240)
csrc/fmha_v2/fmha/hopper/smem_tile.h (1)
  • fmha (22-562)
csrc/fmha_v2/fmha/mask.h (2)
  • int (295-301)
  • int (303-309)
csrc/fmha_v2/fmha/warpspec/epilogue.h (1)
  • int (95-104)
csrc/fmha_v2/fmha/hopper/fragment.h (4)
csrc/fmha_v2/fmha/fragment.h (17)
  • fmha (20-182)
  • add (694-698)
  • add (820-824)
  • add (865-869)
  • void (284-289)
  • void (292-294)
  • void (297-299)
  • void (308-313)
  • void (316-318)
  • void (398-403)
  • void (560-577)
  • void (592-628)
  • void (643-655)
  • void (700-704)
  • void (741-753)
  • void (826-830)
  • void (871-875)
csrc/fmha_v2/fmha/mask.h (2)
  • int (295-301)
  • int (303-309)
csrc/fmha_v2/fmha/traits.h (4)
  • public (264-264)
  • public (356-359)
  • public (433-436)
  • public (481-493)
csrc/fmha_v2/fmha/hopper/utils_qgmma.h (6)
  • qgmma_rfa_e4m3_e4m3_fp32 (438-441)
  • qgmma_rfa_e5m2_e4m3_fp32 (1550-1553)
  • qgmma_rfa_e4m3_e5m2_fp32 (994-997)
  • qgmma_rfa_e5m2_e5m2_fp32 (2082-2085)
  • qgmma_e4m3_e5m2_fp32 (716-719)
  • qgmma_e5m2_e5m2_fp32 (1804-1807)
csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h (3)
csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h (4)
  • fmha (17-276)
  • v2 (18-213)
  • commit (217-221)
  • load (225-243)
csrc/fmha_v2/fmha/hopper/tma_descriptor.h (1)
  • fmha (16-121)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/mask.h (2)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/traits.h (4)
  • public (264-264)
  • public (356-359)
  • public (433-436)
  • public (481-493)
csrc/fmha_v2/fmha/hopper/kernel_traits.h (2)
csrc/fmha_v2/fmha/gmem_tile_qkv.h (1)
  • fmha (15-162)
csrc/fmha_v2/fmha/hopper/compute_tile.h (7)
  • fmha (16-240)
  • Compute_tile_with_gmma (51-51)
  • Compute_tile_with_gmma (54-56)
  • Compute_tile_with_gmma (58-81)
  • Compute_tile_with_gmma (293-293)
  • Compute_tile_with_gmma (296-298)
  • Compute_tile_with_gmma (300-312)
csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h (2)
csrc/fmha_v2/fmha/gemm.h (1)
  • fmha (18-35)
csrc/fmha_v2/fmha/mask.h (2)
  • int (295-301)
  • int (303-309)
csrc/fmha_v2/fmha/hopper/smem_tile_o.h (1)
csrc/fmha_v2/fmha/smem_tile_o.h (16)
  • Smem_tile_o_base_8bit_mma (1284-1360)
  • store (250-286)
  • store (695-697)
  • store (702-704)
  • store (727-729)
  • Smem_tile_o (104-197)
  • Smem_tile_o (691-691)
  • Smem_tile_o (723-723)
  • Smem_tile_o (753-753)
  • Smem_tile_o (802-820)
  • Smem_tile_o (957-975)
  • Smem_tile_o (1452-1452)
  • Smem_tile_o (1466-1466)
  • Smem_tile_o (1480-1480)
  • Smem_tile_o (1494-1494)
  • Smem_tile_o (1508-1508)
csrc/fmha_v2/fmha/smem_tile.h (1)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/kernel_traits.h (1)
csrc/fmha_v2/fmha/warpspec/compute.h (1)
  • int (186-186)
csrc/fmha_v2/fmha/gmem_tile_o_packed.h (1)
csrc/fmha_v2/fmha/gmem_tile_o.h (1)
  • fmha (18-426)
🪛 Clang (14.0.6)
csrc/fmha_v2/fmha/hopper/utils_warpgroup.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gemm.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/arrive_wait.h

[error] 66-66: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 66-66: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_gmma.h

[error] 15-15: 'fmha/hopper/utils_hgmma.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/alibi_params.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_qkv.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/paged_kv_cache.h

[error] 15-15: 'cuda_runtime.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/numeric_types.h

[error] 13-13: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/gmma_descriptor.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_tma.h

[error] 15-15: 'fmha/hopper/tma_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/fragment.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_o.h

[error] 15-15: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/tma_descriptor.h

[error] 14-14: 'fmha/hopper/tma_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h

[error] 14-14: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_igmma.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/mask.h

[error] 15-15: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/kernel_traits.h

[error] 14-14: 'fmha/gmem_tile_qkv_packed.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_hgmma_bf16.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/smem_tile.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/compute_tile.h

[error] 14-14: 'fmha/hopper/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/smem_tile_o.h

[error] 15-15: 'fmha/smem_tile_o.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h

[error] 14-14: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h

[error] 14-14: 'fmha/gmem_tile_o_packed.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/tma_types.h

[error] 15-15: 'fmha/numeric_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/smem_tile.h

[error] 15-15: 'fmha/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/kernel_traits.h

[error] 15-15: 'fmha/alibi_params.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_o_packed.h

[error] 15-15: 'fmha/numeric_types.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/gmem_tile_ps.h

[error] 14-14: 'fmha/hopper/fragment.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/fragment.h

[error] 15-15: 'fmha/traits.h' file not found

(clang-diagnostic-error)

csrc/fmha_v2/fmha/hopper/utils_hgmma.h

[error] 15-15: unknown type name 'namespace'

(clang-diagnostic-error)


[error] 15-15: expected ';' after top level declarator

(clang-diagnostic-error)

Comment on lines +99 to +121
template <typename T>
__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;

// The step.
size_t step = (size_t)gridDim.x * blockDim.x;

// Iterate over the elements.
for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
// Load 4 floats.
float4 tmp = reinterpret_cast<float4 const*>(src)[ii];
// Scale.
tmp.x *= scale;
tmp.y *= scale;
tmp.z *= scale;
tmp.w *= scale;
// Convert to 4 Ts.
auto out = pack_float4<T>(tmp);

// Store.
reinterpret_cast<Dst*>(dst)[ii] = reinterpret_cast<Dst const&>(out);
}
}

⚠️ Potential issue | 🟠 Major

Same tail-handling issue exists in convert_fp32_to_T_kernel.

This kernel also iterates over only n / 4 packed groups, skipping the last 1-3 elements when n is not divisible by 4. The same fix (tail loop or alignment assertion) should be applied here.

Apply a similar fix as suggested for convert_int32_to_int8_kernel:

 template <typename T>
 __global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
   using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;
+  // Note: n must be a multiple of 4. Caller must ensure proper alignment.
+  assert(n % 4 == 0);
 
   // The step.
   size_t step = (size_t)gridDim.x * blockDim.x;
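
If the tail-loop route is preferred over the assertion, a minimal sketch of the idea (hypothetical, not part of the PR; it assumes T is constructible from a float, which pack_float4<T> already requires, and lets a single thread cover the at-most-three trailing elements):

// Hypothetical tail handling, appended after the vectorized loop in the kernel.
if (blockIdx.x == 0 && threadIdx.x == 0) {
  for (size_t jj = (n / 4) * 4; jj < n; ++jj) {
    float v = reinterpret_cast<float const*>(src)[jj] * scale;
    reinterpret_cast<T*>(dst)[jj] = T(v);
  }
}

The same pattern, with the conversion direction reversed, would cover the tail in convert_T_to_fp32_kernel below.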

Comment on lines +123 to +149
template <typename T>
__global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
using Src = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;

union {
Src raw;
T elt[4];
} data;

// The step.
size_t step = (size_t)gridDim.x * blockDim.x;

// Iterate over the elements.
for (size_t ii = blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
// Load 4 floats.
data.raw = reinterpret_cast<Src const*>(src)[ii];
float4 out;
// Scale.
out.x = float(data.elt[0]) * scale;
out.y = float(data.elt[1]) * scale;
out.z = float(data.elt[2]) * scale;
out.w = float(data.elt[3]) * scale;

// Store.
reinterpret_cast<float4*>(dst)[ii] = reinterpret_cast<float4 const&>(out);
}
}

⚠️ Potential issue | 🟠 Major

Same tail-handling issue exists in convert_T_to_fp32_kernel.

This kernel also skips trailing elements when n % 4 != 0. Add an assertion or tail-handling loop for consistency.

🤖 Prompt for AI Agents
In csrc/fmha_v2/convert.cu around lines 123-149 the convert_T_to_fp32_kernel
iterates only over n/4 packed groups and thus silently skips trailing elements
when n % 4 != 0; fix by handling the tail: after the existing loop that
processes groups of 4, compute remainder = n % 4 and the start index of the tail
= (n/4)*4, then have each thread process any remaining scalar elements whose
global index < n by loading the appropriate Src/element or byte-offset and
converting them to float* dst with the same scale, or alternately add an
assertion enforcing n % 4 == 0 if non-handling is intended; ensure all pointer
casts and bounds checks use the same element indexing as the main loop so no
out-of-bounds accesses occur.

Comment on lines +1821 to +1825
float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
float curr_sum_[ROWS_PER_THREAD] = {0};
float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
;
float prev_sum_[ROWS_PER_THREAD] = {0};

⚠️ Potential issue | 🟡 Minor

Array member initializers only set the first element.

In C++, = {-HUGE_VALF} only initializes the first element; remaining elements are value-initialized to 0.0f. If the parameterized constructor (lines 1678-1694) is called without subsequently calling the default constructor's initialization logic, the arrays will have incorrect values (first element -HUGE_VALF, rest 0.0f).

Also, line 1824 has a stray semicolon.

-  float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
-  float curr_sum_[ROWS_PER_THREAD] = {0};
-  float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
-  ;
-  float prev_sum_[ROWS_PER_THREAD] = {0};
+  float curr_max_[ROWS_PER_THREAD];
+  float curr_sum_[ROWS_PER_THREAD];
+  float prev_max_[ROWS_PER_THREAD];
+  float prev_sum_[ROWS_PER_THREAD];

Then ensure both constructors call a common initialization helper, or initialize in the member initializer list of each constructor.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/fragment.h around lines 1821-1825, the array initializers
like "= {-HUGE_VALF}" only set the first element (others become 0.0f) and line
1824 contains a stray semicolon; update the code to explicitly initialize all
elements (e.g. use std::fill_n or std::array and fill(), or a loop) so curr_max_
and prev_max_ are all -HUGE_VALF and the sum arrays are all 0.0f, remove the
stray semicolon, and refactor the two constructors to call a single init helper
(or perform the full initialization in each constructor’s member initializer
list) to guarantee consistent initialization regardless of which constructor is
used.
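
A minimal sketch of the shared-initialization idea (illustrative only; the helper name is hypothetical, and the member arrays are the ones declared in Fragment_updater):

// Hypothetical helper: call it from every constructor so all elements are set.
inline __device__ void init_softmax_state() {
#pragma unroll
  for (int ii = 0; ii < ROWS_PER_THREAD; ++ii) {
    curr_max_[ii] = -HUGE_VALF;
    prev_max_[ii] = -HUGE_VALF;
    curr_sum_[ii] = 0.f;
    prev_sum_[ii] = 0.f;
  }
}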

Comment on lines +2007 to +2011
// Update scales.
float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
float curr_sum_[ROWS_PER_THREAD] = {0};
float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
float prev_sum_[ROWS_PER_THREAD] = {0};

⚠️ Potential issue | 🟡 Minor

Same array initialization issue in Fragment_updater_ampere_fp32.

Same concern as Fragment_updater - brace initialization only sets the first array element.

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/fragment.h around lines 2007-2011, the four local arrays
curr_max_, curr_sum_, prev_max_, prev_sum_ are incorrectly initialized with a
single braced value (which only sets element 0); replace the brace initializers
with explicit full initialization (e.g., a small for-loop setting all
ROWS_PER_THREAD entries to -HUGE_VALF or 0 respectively, or use
std::fill_n/std::fill) so every element is initialized correctly.

Comment on lines +2215 to +2218
float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
float curr_sum_[ROWS_PER_THREAD] = {0};
float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
float prev_sum_[ROWS_PER_THREAD] = {0};

⚠️ Potential issue | 🟡 Minor

Same array initialization issue in Fragment_updater<Volta_hmma_fp16_16x16x16_traits>.

Same pattern with incomplete array initialization via brace initializers.

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/fragment.h around lines 2215 to 2218, the brace
initializers only set the first array element leaving the rest uninitialized;
replace the brace initialization with explicit full initialization (e.g. use a
loop or std::fill_n/ std::fill to set all ROWS_PER_THREAD entries of curr_max_
and prev_max_ to -HUGE_VALF and curr_sum_ and prev_sum_ to 0) so every element
is initialized deterministically.

Comment on lines +100 to +112
// The 512 bit of descritptro for im2col mode.
typedef struct {
uint64_t tensor_common0;
uint32_t tensor_common1;

uint32_t tensor_stride_lower[4];
uint32_t tensor_stride_upper;
uint32_t tensor_size[5];
uint32_t traversal_stride_range_c;

uint32_t box_corner_dhw;
uint32_t range_ndhw;
} cudaTmaDescIm2Col;

⚠️ Potential issue | 🟡 Minor

Minor typo in comment: "descritptro" should be "descriptor".

Line 100 has a typo in the comment.

-// The 512 bit of descritptro for im2col mode.
+// The 512 bit of descriptor for im2col mode.
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/tma_types.h around lines 100 to 112, fix the typo in
the comment "descritptro" by replacing it with "descriptor" so the comment reads
"The 512 bit descriptor for im2col mode."; leave the surrounding code and
formatting unchanged.


namespace fmha {

inline __device__ uint32_t elect_one_sync();

⚠️ Potential issue | 🟠 Major

elect_one_sync is declared but not defined.

This function is declared but has no definition in this file. This will cause a linker error if called. Either provide a definition or include the header where it's defined.

#!/bin/bash
# Search for elect_one_sync definition in the codebase
ast-grep --pattern 'inline __device__ uint32_t elect_one_sync() {
  $$$
}'
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/utils_tma.h around line 20, elect_one_sync is
declared as inline __device__ uint32_t elect_one_sync(); but no definition is
provided, which will cause a linker error; fix by either (A) adding the inline
__device__ definition for elect_one_sync in this header (implement the function
body here so it remains header-only and compiled into device code), or (B)
include the header that contains the existing definition (add the appropriate
#include at top of this file) and remove the stray declaration to avoid
duplicate declarations; ensure the chosen definition is marked inline/__device__
and matches all call sites to prevent ODR/linker issues.
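
For reference, one common SM90 definition follows the CUTLASS-style elect.sync pattern; this is only a sketch of option (A) and may not be the exact definition this header expects:

inline __device__ uint32_t elect_one_sync() {
  uint32_t pred = 0;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  uint32_t laneid = 0;
  // elect.sync picks exactly one active lane of the warp; pred becomes 1 on that lane.
  asm volatile(
      "{\n"
      ".reg .b32 %%rx;\n"
      ".reg .pred %%px;\n"
      "  elect.sync %%rx|%%px, %2;\n"
      "@%%px mov.s32 %1, 1;\n"
      "  mov.s32 %0, %%rx;\n"
      "}\n"
      : "+r"(laneid), "+r"(pred)
      : "r"(0xFFFFFFFF));
#endif
  return pred;
}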

Comment on lines +138 to +152
inline __device__ void tmastg_arrive() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("cp.async.bulk.commit_group;");
#else
assert(false);
#endif
}

inline __device__ void tmastg_wait() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory");
#else
assert(false);
#endif
}

⚠️ Potential issue | 🟠 Major

assert(false) as fallback will cause silent runtime abort in release builds.

On non-SM90 architectures, tmastg_arrive() and tmastg_wait() call assert(false), which is removed in release builds (NDEBUG defined), leaving an empty function. This could lead to subtle bugs where TMA operations appear to succeed but do nothing.

Consider using a compile-time error or a runtime trap that isn't removed in release builds:

 inline __device__ void tmastg_arrive() {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
   asm volatile("cp.async.bulk.commit_group;");
 #else
-  assert(false);
+  __trap();  // Always terminates, not stripped in release
 #endif
 }
 
 inline __device__ void tmastg_wait() {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
   asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(0) : "memory");
 #else
-  assert(false);
+  __trap();
 #endif
 }
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/utils_tma.h around lines 138-152, the current
fallback uses assert(false) which is removed in release builds leaving empty
no-op functions; change the fallback to a non-removable failure: either emit a
compile-time failure for non-SM90 builds (e.g., preprocessor #error when
__CUDA_ARCH__ is not >=900) to prevent building unsupported code, or replace
assert(false) with a runtime trap that is not stripped in release (e.g., call a
device-level abort/trap such as __builtin_trap() or an equivalent GPU trap
intrinsic) so that calls on unsupported architectures fail loudly at runtime.
Ensure the chosen approach is used in both tmastg_arrive() and tmastg_wait().

Comment on lines +44 to +53
template <typename T>
__host__ __device__ constexpr inline float Softmax_fp_quant_scale();

template <>
__host__ __device__ constexpr inline float Softmax_fp_quant_scale<e4m3_t>() {
// Softmax has max output of 1.0, therefore we choose fp32-to-fp8 quantization scale as the
// largest power-of-2 below the e4m3 limit:
// 2^(floor(log2(E4M3_MAX / amax_exp_p))) = 2^(floor(log2(448 / 1))) = 2 ^ 8
return 256.f;
}

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Script executed:

#!/bin/bash
# Search for Softmax_fp_quant_scale usage with e5m2_t
rg "Softmax_fp_quant_scale.*e5m2" --type cpp -A 2 -B 2

# Also search more broadly for Softmax_fp_quant_scale usage
echo "=== All Softmax_fp_quant_scale usages ==="
rg "Softmax_fp_quant_scale" --type cpp -A 2 -B 2

Repository: flashinfer-ai/flashinfer


<shell_result>
Repository: flashinfer-ai/flashinfer
Script execution failed with unrecoverable error: [CAPABILITY:RUN_SHELL_SCRIPT] Failed to clone repository flashinfer-ai/flashinfer: Clone operation failed
</shell_result>


flashinfer e5m2_t quantization type support


<web_search_result>
Yes — FlashInfer supports the FP8 E5M2 format (e5m2 / __nv_fp8_e5m2). It is listed among supported data/quantized types and used in FP8 kernels (CUTLASS / TRT-LLM) and vector layouts. [1][2][3]

References:

  • Vector/data types (includes __nv_fp8_e5m2). [1]
  • Core components / quantization formats (FP8 E5M2 supported). [2]
  • FP8/FP4 quantization docs and formats (E5M2 listed). [3]
    </web_search_result>

site:github.com flashinfer Softmax_fp_quant_scale e5m2


<web_search_result>
I found several GitHub results mentioning flashinfer and the terms softmax/Softmax_fp_quant_scale and fp8_e5m2 (KV-cache dtype). Relevant pages:

  1. FlashInfer repo (main project). [1]
  2. Issue discussing float8 (fp8_e5m2) support / dlpack error. [2]
  3. FlashInfer issue about using low-bit KV cache (per-head fp8/int8). [3]
  4. vLLM issue threads reporting errors when using FlashInfer with fp8_e5m2 KV cache. [4]

Sources:
[1] https://github.com/flashinfer-ai/flashinfer. (github.com)
[2] Issue: "Is float8 supported for flashinfer trtllm attention?" (flashinfer-ai/flashinfer). (github.com)
[3] Issue: "How to use low bit KV Cache" (flashinfer-ai/flashinfer). (github.com)
[4] vLLM issues about FlashInfer + FP8 KV cache (examples). (github.com)

If you want, I can open any of those pages and extract specific lines or error messages.


flashinfer Softmax_fp_quant_scale template instantiation


<web_search_result>
I don’t have enough detail to diagnose that template-instantiation problem. Please paste the exact compiler error message and these details:

  • FlashInfer version (git commit or pip/package version)
  • File and line where Softmax_fp_quant_scale is defined/used (or paste the function/template code)
  • Exact compile command / compiler (g++, clang++) and flags
  • OS and CUDA/cuDNN/NVCC versions (if building CUDA)

With that I can pinpoint the template-instantiation cause and suggest a fix.
</web_search_result>


Based on my verification:

  1. Web search confirmed that FlashInfer officially supports the e5m2_t type (__nv_fp8_e5m2) as a quantized data type
  2. Active issues exist around fp8_e5m2 support, KV cache usage, and vLLM integration, indicating e5m2_t is a real, actively-used feature
  3. Code snippet shows only e4m3_t specialization exists for Softmax_fp_quant_scale, while the generic template declaration exists

The review comment's concern is valid: if Softmax_fp_quant_scale<e5m2_t>() is instantiated anywhere in the codebase (which is likely given e5m2_t is a supported type), it would fail to link since only the e4m3_t specialization is defined.


Add Softmax_fp_quant_scale specialization for e5m2_t.

The template is declared for generic T and specialized only for e4m3_t. Since FlashInfer supports e5m2_t as a quantized type, the missing specialization for e5m2_t will cause linker errors if instantiated. A similar quantization scale calculation is needed for the e5m2 format.

🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/numeric_types.h around lines 44 to 53, add a missing
specialization for Softmax_fp_quant_scale<e5m2_t>() similar to the e4m3_t case;
implement it to compute the fp32-to-fp8 quantization scale as the largest
power-of-two below the e5m2 maximum (use the project constant for E5M2_MAX or
the equivalent value) and return that as a float (i.e., mirror the comment/logic
used for e4m3_t but referencing E5M2_MAX and e5m2_t so instantiation/linking for
e5m2_t succeeds).
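
A minimal sketch of what that specialization could look like, assuming the same power-of-two convention and the e5m2 finite maximum of 57344:

template <>
__host__ __device__ constexpr inline float Softmax_fp_quant_scale<e5m2_t>() {
  // Largest power-of-2 below the e5m2 limit:
  // 2^(floor(log2(E5M2_MAX / amax_exp_p))) = 2^(floor(log2(57344 / 1))) = 2^15
  return 32768.f;
}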

Comment on lines +270 to +274
// Store to the tile in shared memory.
template <int N>
inline __device__ void store(Store_type const (&data)[N], uint32_t preds) {
this->store(data, preds);
}

⚠️ Potential issue | 🔴 Critical

Infinite recursion in store function.

This overload calls itself recursively without termination. The function signature matches itself, causing infinite recursion at runtime.

The intent appears to be calling the array-predicate overload at line 264. Apply this fix:

   // Store to the tile in shared memory.
   template <int N>
   inline __device__ void store(Store_type const (&data)[N], uint32_t preds) {
-    this->store(data, preds);
+    uint32_t tmp[1] = {preds};
+    this->store(data, tmp);
   }
🤖 Prompt for AI Agents
In csrc/fmha_v2/fmha/smem_tile.h around lines 270-274, the template overload
store(Store_type const (&data)[N], uint32_t preds) calls itself causing infinite
recursion; replace the self-call with a call to the array-predicate overload at
line 264 by converting the scalar preds into a small uint32_t array (e.g.
uint32_t tmp[1] = { preds };) and then call the overload that accepts a
predicate array: this->store(data, tmp); ensure the temporary array has
appropriate scope and matches the expected predicate-array parameter type.

lse_ref = torch.logsumexp(logits, -1) # (batch, num_heads, seq_len)
# Transpose to match expected shape (batch, seq_len, num_heads)
lse_ref = lse_ref.transpose(1, 2)
p = torch.softmax(logits, dim=-1)

Directly using lse_ref to get the softmax results may be better?

descale = 256
lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1] / descale)
else:
lse = lse[:, :, :, 0] + torch.log(lse[:, :, :, 1])

Should this be done by the user outside the API? And what is lse used for outside of attention, for training?


I think we should do this inside the kernel.

@yzh119 yzh119 left a comment

LGTM overall, let's merge this one first and implement the in-kernel LSE calculation in a future PR.

@yzh119

yzh119 commented Nov 28, 2025

/bot run

@flashinfer-bot

GitLab MR !167 has been created, and the CI pipeline #39288528 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 merged commit b14408b into flashinfer-ai:main Nov 28, 2025
4 checks passed