
Conversation

@lanluo-nvidia
Collaborator

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@lanluo-nvidia added the WIP (Work is in progress, pull request should not be merged yet) label on Dec 10, 2025
@meta-cla bot added the cla signed label on Dec 10, 2025
@github-actions bot added the component: tests (Issues re: Tests) label on Dec 10, 2025
@github-actions bot requested a review from narendasan on December 10, 2025 02:02
@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/attention.py	2025-12-12 22:37:00.909120+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/attention.py	2025-12-12 22:37:42.243426+00:00
@@ -16,20 +16,25 @@
ConstBool = ct.Constant[bool]


# --- FMHA Kernel Implementation ---
@ct.kernel(occupancy=2)
-def fmha_kernel(Q, K, V, Out,
-                qk_scale: float,
-                input_pos: int,
-                TILE_D: ConstInt,  # TILE_D = hidden_size
-                H: ConstInt,
-                TILE_M: ConstInt,
-                TILE_N: ConstInt,
-                QUERY_GROUP_SIZE: ConstInt,
-                CAUSAL: ConstBool,
-                EVEN_K: ConstBool):
+def fmha_kernel(
+    Q,
+    K,
+    V,
+    Out,
+    qk_scale: float,
+    input_pos: int,
+    TILE_D: ConstInt,  # TILE_D = hidden_size
+    H: ConstInt,
+    TILE_M: ConstInt,
+    TILE_N: ConstInt,
+    QUERY_GROUP_SIZE: ConstInt,
+    CAUSAL: ConstBool,
+    EVEN_K: ConstBool,
+):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).
    Computes attention output for a specific batch item and head, using tiling and online softmax.
    """
    # Map block IDs to batch and head indices
@@ -57,11 +62,13 @@
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)

    # Load query tile for this batch, head, and M-chunk
    q = ct.load(
        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
-    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]
+    ).reshape(
+        (TILE_M, TILE_D)
+    )  # [TILE_M, TILE_D]

    # loop over k, v and update accumulator
    m_end = input_pos + (bid_x + 1) * TILE_M
    k_seqlen = K.shape[2]
    if CAUSAL:
@@ -76,16 +83,18 @@

    # Loop over K, V blocks (N-dimension chunks)
    for j in range(0, Tc):
        # --- Compute QK product ---
        k = ct.load(
-            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+            K,
+            index=(batch_idx, off_kv_h, 0, j),
+            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2,
        )
        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
-        qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
+        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
        qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]

        # --- Apply Causal Masking ---
        if (CAUSAL or not EVEN_K) and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
@@ -113,16 +122,20 @@
        # scale acc
        acc = acc * alpha  # [TILE_M, TILE_N]

        # --- Compute PV product ---
        v = ct.load(
-            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+            V,
+            index=(batch_idx, off_kv_h, j, 0),
+            shape=(1, 1, TILE_N, TILE_D),
            latency=4,
-        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
+        ).reshape(
+            (TILE_N, TILE_D)
+        )  # [TILE_N, TILE_D]
        p = p.astype(Q.dtype)
        acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
        m_i = m_ij  # [TILE_M, 1]

    # --- Final Normalization and Store ---
    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
-    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
\ No newline at end of file
+    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
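
For reference, the fmha_kernel being reformatted above implements a FlashAttention-style online softmax: it keeps a running row max m_i, a running normalizer l_i, and a rescaled accumulator acc while streaming over K/V tiles. A minimal NumPy sketch of that recurrence, assuming qk_scale multiplies the logits as the parameter name suggests (function and variable names here are illustrative, not from the PR):

    import numpy as np

    def online_softmax_attention(q, K_tiles, V_tiles, qk_scale):
        # q: (M, D); K_tiles: iterable of (D, N) tiles; V_tiles: iterable of (N, D) tiles.
        M, D = q.shape
        m_i = np.full((M, 1), -np.inf)  # running row-wise max of the logits
        l_i = np.zeros((M, 1))          # running softmax normalizer
        acc = np.zeros((M, D))          # running unnormalized output
        for k, v in zip(K_tiles, V_tiles):
            qk = (q @ k) * qk_scale                    # logits for this tile
            m_ij = np.maximum(m_i, qk.max(axis=1, keepdims=True))
            p = np.exp(qk - m_ij)                      # shifted exponentials
            alpha = np.exp(m_i - m_ij)                 # rescale factor for the old state
            l_i = l_i * alpha + p.sum(axis=1, keepdims=True)
            acc = acc * alpha + p @ v
            m_i = m_ij
        return acc / l_i                               # final normalization, as in the kernel
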
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/matmul.py	2025-12-12 22:37:00.909120+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/matmul.py	2025-12-12 22:37:42.315077+00:00
@@ -23,14 +23,18 @@
    bid_n = (bid % num_bid_in_group) // group_size_m
    return bid_m, bid_n


@ct.kernel(num_ctas=ct.ByTarget(sm_100=2))
-def matmul_kernel(A, B, C,
-                  tm: ConstInt,         # Tile size along M dimension (rows of C)
-                  tn: ConstInt,         # Tile size along N dimension (columns of C)
-                  tk: ConstInt):        # Tile size along K dimension (inner product dimension)
+def matmul_kernel(
+    A,
+    B,
+    C,
+    tm: ConstInt,  # Tile size along M dimension (rows of C)
+    tn: ConstInt,  # Tile size along N dimension (columns of C)
+    tk: ConstInt,
+):  # Tile size along K dimension (inner product dimension)
    """
    cuTile kernel for performing matrix multiplication C = A @ B.

    This kernel uses a tiled approach, where each CUDA thread block (CTA)
    computes a `tm` x `tn` tile of the output matrix C. The computation
@@ -72,16 +76,20 @@
    # are loaded, multiplied, and accumulated.
    for k in range(num_tiles_k):
        # Load tile from matrix A.
        # The `index=(bidx, k_tile_idx)` specifies which (M-tile, K-tile) to load
        # from global memory A. `shape=(tm, tk)` defines the size of this tile.
-        a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype(dtype)
+        a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype(
+            dtype
+        )

        # Load tile from matrix B.
        # The `index=(k_tile_idx, bidy)` specifies which (K-tile, N-tile) to load
        # from global memory B. `shape=(tk, tn)` defines the size of this tile.
-        b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype(dtype)
+        b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype(
+            dtype
+        )

        # Perform Matrix Multiplication for the current tiles.
        # `ct.mma` computes the product of the two loaded tiles and accumulates the result.
        accumulator = ct.mma(a, b, accumulator)

@@ -93,13 +101,13 @@
    # The `(bidx, bidy)` directly corresponds to the tile's position in the 2D output matrix.
    ct.store(C, index=(bidx, bidy), tile=accumulator)


@ct.kernel
-def matmul_split_k_kernel(A, B, C, LOCKS, COUNTS,
-                          tm: ConstInt, tn: ConstInt, tk: ConstInt,
-                          SPLIT_K: ConstInt):
+def matmul_split_k_kernel(
+    A, B, C, LOCKS, COUNTS, tm: ConstInt, tn: ConstInt, tk: ConstInt, SPLIT_K: ConstInt
+):
    GROUP_SIZE_M = 8
    M = A.shape[0]
    N = B.shape[1]
    bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M)
    bidz = ct.bid(1)
@@ -110,20 +118,25 @@

    # Convert fp32 to tf32 to use tensorcore
    dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype

    for k in range(bidz, num_tiles, SPLIT_K):
-        a = ct.load(A, index=(bidx, k), shape=(tm, tk),
-                    padding_mode=zero_pad).astype(dtype)
-        b = ct.load(B, index=(k, bidy), shape=(tk, tn),
-                    padding_mode=zero_pad).astype(dtype)
+        a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype(
+            dtype
+        )
+        b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype(
+            dtype
+        )
        sum = ct.mma(a, b, sum)

    sum = ct.astype(sum, C.dtype)
    lock_offset = ct.bid(0)
    count_offset = lock_offset
-    while ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE) == 1:
+    while (
+        ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE)
+        == 1
+    ):
        pass
    count = ct.gather(COUNTS, count_offset)
    if count == 0:
        ct.store(C, index=(bidx, bidy), tile=sum)
    else:
@@ -154,19 +167,23 @@
    zero_pad = ct.PaddingMode.ZERO
    # K-dimension loop
    for k in range(num_k_tiles):
        # Load tiles with 3D index and 3D shape
        # A is (Batch, M, K), load (1, tm, tk) tile
-        a = ct.load(A, index=(pid_batch, pidx, k), shape=(1, tm, tk), padding_mode=zero_pad)
+        a = ct.load(
+            A, index=(pid_batch, pidx, k), shape=(1, tm, tk), padding_mode=zero_pad
+        )
        a = ct.reshape(a, (tm, tk))  # Reshape to 2D for ct.mma

        # B is (Batch, K, N), load (1, tk, tn) tile
-        b = ct.load(B, index=(pid_batch, k, pidy), shape=(1, tk, tn), padding_mode=zero_pad)
+        b = ct.load(
+            B, index=(pid_batch, k, pidy), shape=(1, tk, tn), padding_mode=zero_pad
+        )
        b = ct.reshape(b, (tk, tn))  # Reshape to 2D for ct.mma

        accumulator = ct.mma(a, b, acc=accumulator)

    # Convert to output dtype and store
    result = ct.astype(accumulator, C.dtype)
    # Store with 3D index and 3D shape, C is (Batch, M, N)
    result_3d = ct.reshape(result, (1, tm, tn))
-    ct.store(C, index=(pid_batch, pidx, pidy), tile=result_3d)
\ No newline at end of file
+    ct.store(C, index=(pid_batch, pidx, pidy), tile=result_3d)
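
For context, matmul_split_k_kernel distributes the K loop across SPLIT_K blocks (each block takes K-tiles bidz, bidz + SPLIT_K, ...), and the LOCKS/COUNTS spin lock serializes the final read-modify-write of the shared C tile. A single-threaded NumPy sketch of the arithmetic being split, under the assumption that partial sums are simply added into C (no lock is needed here; the tile size tk and names are illustrative):

    import numpy as np

    def matmul_split_k(A, B, split_k, tk=32):
        # Reference for the work distribution only; on the GPU each z is a
        # separate block, and the `C +=` below is what the lock protects.
        M, K = A.shape
        _, N = B.shape
        num_tiles = (K + tk - 1) // tk
        C = np.zeros((M, N))
        for z in range(split_k):              # one partial sum per block index bidz
            partial = np.zeros((M, N))
            for k in range(z, num_tiles, split_k):
                a = A[:, k * tk : (k + 1) * tk]   # a ragged last tile slices fine here
                b = B[k * tk : (k + 1) * tk, :]
                partial += a @ b
            C += partial                      # first writer stores, later writers add
        return C
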
--- /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py	2025-12-12 22:37:00.923120+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py	2025-12-12 22:37:47.051575+00:00
@@ -16,20 +16,25 @@
ConstBool = ct.Constant[bool]


# --- FMHA Kernel Implementation ---
@ct.kernel(occupancy=2)
-def fmha_kernel(Q, K, V, Out,
-                qk_scale: float,
-                input_pos: int,
-                TILE_D: ConstInt,  # TILE_D = hidden_size
-                H: ConstInt,
-                TILE_M: ConstInt,
-                TILE_N: ConstInt,
-                QUERY_GROUP_SIZE: ConstInt,
-                CAUSAL: ConstBool,
-                EVEN_K: ConstBool):
+def fmha_kernel(
+    Q,
+    K,
+    V,
+    Out,
+    qk_scale: float,
+    input_pos: int,
+    TILE_D: ConstInt,  # TILE_D = hidden_size
+    H: ConstInt,
+    TILE_M: ConstInt,
+    TILE_N: ConstInt,
+    QUERY_GROUP_SIZE: ConstInt,
+    CAUSAL: ConstBool,
+    EVEN_K: ConstBool,
+):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).
    Computes attention output for a specific batch item and head, using tiling and online softmax.
    """
    # Map block IDs to batch and head indices
@@ -57,11 +62,13 @@
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)

    # Load query tile for this batch, head, and M-chunk
    q = ct.load(
        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
-    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]
+    ).reshape(
+        (TILE_M, TILE_D)
+    )  # [TILE_M, TILE_D]

    # loop over k, v and update accumulator
    m_end = input_pos + (bid_x + 1) * TILE_M
    k_seqlen = K.shape[2]
    if CAUSAL:
@@ -76,16 +83,18 @@

    # Loop over K, V blocks (N-dimension chunks)
    for j in range(0, Tc):
        # --- Compute QK product ---
        k = ct.load(
-            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+            K,
+            index=(batch_idx, off_kv_h, 0, j),
+            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2,
        )
        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
-        qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
+        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
        qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]

        # --- Apply Causal Masking ---
        if (CAUSAL or not EVEN_K) and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
@@ -113,16 +122,20 @@
        # scale acc
        acc = acc * alpha  # [TILE_M, TILE_N]

        # --- Compute PV product ---
        v = ct.load(
-            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+            V,
+            index=(batch_idx, off_kv_h, j, 0),
+            shape=(1, 1, TILE_N, TILE_D),
            latency=4,
-        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
+        ).reshape(
+            (TILE_N, TILE_D)
+        )  # [TILE_N, TILE_D]
        p = p.astype(Q.dtype)
        acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
        m_i = m_ij  # [TILE_M, 1]

    # --- Final Normalization and Store ---
    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
-    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
\ No newline at end of file
+    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

@github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py	2025-12-12 22:57:38.287784+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py	2025-12-12 22:58:21.945158+00:00
@@ -16,20 +16,25 @@
ConstBool = ct.Constant[bool]


# --- FMHA Kernel Implementation ---
@ct.kernel(occupancy=2)
-def fmha_kernel(Q, K, V, Out,
-                qk_scale: float,
-                input_pos: int,
-                TILE_D: ConstInt,  # TILE_D = hidden_size
-                H: ConstInt,
-                TILE_M: ConstInt,
-                TILE_N: ConstInt,
-                QUERY_GROUP_SIZE: ConstInt,
-                CAUSAL: ConstBool,
-                EVEN_K: ConstBool):
+def fmha_kernel(
+    Q,
+    K,
+    V,
+    Out,
+    qk_scale: float,
+    input_pos: int,
+    TILE_D: ConstInt,  # TILE_D = hidden_size
+    H: ConstInt,
+    TILE_M: ConstInt,
+    TILE_N: ConstInt,
+    QUERY_GROUP_SIZE: ConstInt,
+    CAUSAL: ConstBool,
+    EVEN_K: ConstBool,
+):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).
    Computes attention output for a specific batch item and head, using tiling and online softmax.
    """
    # Map block IDs to batch and head indices
@@ -57,11 +62,13 @@
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)

    # Load query tile for this batch, head, and M-chunk
    q = ct.load(
        Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
-    ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]
+    ).reshape(
+        (TILE_M, TILE_D)
+    )  # [TILE_M, TILE_D]

    # loop over k, v and update accumulator
    m_end = input_pos + (bid_x + 1) * TILE_M
    k_seqlen = K.shape[2]
    if CAUSAL:
@@ -76,16 +83,18 @@

    # Loop over K, V blocks (N-dimension chunks)
    for j in range(0, Tc):
        # --- Compute QK product ---
        k = ct.load(
-            K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+            K,
+            index=(batch_idx, off_kv_h, 0, j),
+            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),
            latency=2,
        )
        k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
-        qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
+        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
        qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]

        # --- Apply Causal Masking ---
        if (CAUSAL or not EVEN_K) and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
@@ -113,16 +122,20 @@
        # scale acc
        acc = acc * alpha  # [TILE_M, TILE_N]

        # --- Compute PV product ---
        v = ct.load(
-            V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+            V,
+            index=(batch_idx, off_kv_h, j, 0),
+            shape=(1, 1, TILE_N, TILE_D),
            latency=4,
-        ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
+        ).reshape(
+            (TILE_N, TILE_D)
+        )  # [TILE_N, TILE_D]
        p = p.astype(Q.dtype)
        acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
        m_i = m_ij  # [TILE_M, 1]

    # --- Final Normalization and Store ---
    acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
-    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
\ No newline at end of file
+    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
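
Both bot comments report the same class of fix: the files are reformatted to what looks like black's output style (one argument per line, 0. rewritten as 0.0, a trailing newline added). A hedged sketch for reproducing the suggested formatting locally, assuming the project's formatter is black with default settings (the repo's actual lint entry point may differ; paths are taken from the diff headers above):

    # Assumes `pip install black`; this is a local approximation of the CI check,
    # not the repository's own lint script.
    from pathlib import Path

    import black

    for rel in (
        "tests/py/dynamo/automatic_plugin/cutile/attention.py",
        "tests/py/dynamo/automatic_plugin/cutile/matmul.py",
        "tools/llm/torchtrt_ext/attention.py",
    ):
        path = Path(rel)
        src = path.read_text()
        formatted = black.format_str(src, mode=black.Mode())
        if formatted != src:
            path.write_text(formatted)  # applies the same edits the bot's diffs show
            print(f"reformatted {rel}")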


Labels

cla signed · component: tests (Issues re: Tests) · WIP (Work is in progress, pull request should not be merged yet)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants