Skip to content

Conversation

@timlee0212
Copy link
Contributor

@timlee0212 timlee0212 commented Nov 20, 2025

📌 Description

This PR porting all changes in TensorRT-LLM#8018 into Flashinfer.

Apart from the changes mentioned in the original PR, this PR also introduce new API interface as trtllm_mnnvl_allreduce and trtllm_mnnvl_fused_allreduce_add_rmsnorm to replace the original ones. The workspace allocation is wrapped as an entire class with a given buffer size and the user does not need to worry about the details inside.

This PR adds support for IPC Socket based mcast device memory bootstrap so that it can run on DGX machine that does not support fabric handle.

@wenscarl This PR also incorporate the changes made in #2056 and should be able to replace that PR. A bcast interface is added to the comm backend as this is needed during the handle transfer.

The old API is tagged as deprecated and redirected to the new APIs. The user of the old API should not need to make any changes.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Fused all‑reduce with optional RMS normalization fusion and selectable one‑shot/two‑shot strategies; new workspace‑based Python APIs and migration wrappers.
  • Improvements

    • IPC POSIX‑FD handle transfer and pluggable handle‑exchange backends; stricter input/output validation and deprecated legacy interfaces; cached CUDA SM count for runtime tuning.
  • Tests

    • MPI‑aware tests covering fused and legacy flows, workspace usage, synchronization, and expanded sequence/hidden size coverage.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Replaces legacy MNNVL all-reduce with a fused Lamport-buffer allreduce exposed as trtllm_mnnvl_allreduce_fusion; adds optional RMSNorm fusion and one-/two-shot dispatch, IPC-based POSIX-FD handle exchange and MPI bcast/barrier, updated CUDA kernels/params, Python workspace/backends, and MPI-aware tests.

Changes

Cohort / File(s) Summary
CUDA entry
csrc/trtllm_mnnvl_allreduce.cu
Renamed public entry to trtllm_mnnvl_allreduce_fusion, expanded parameters (rmsnorm_fusion, launch_with_pdl, use_oneshot, residual/gamma/epsilon), renamed in/out→input/output, added stricter shape/validation checks, built AllReduceFusionParams, and dispatches oneshot/twoshot fusion kernels.
CUDA header / kernels
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
AllReduceParamsAllReduceFusionParams, updated RMSNorm params, added flashinfer::utils helpers, packed-load primitives, Lamport buffer layout, oneshot/twoshot fused kernel declarations and grid/launch helpers, and integrated RMSNorm fusion paths.
CUDA utils
include/flashinfer/utils.cuh
Added thread-safe cached flashinfer::GetCudaMultiProcessorCount() using std::atomic.
Python IPC & MNNVL backend
flashinfer/comm/mnnvl.py
Added IpcSocket for POSIX FD transfer, introduced HandleExchanger abstractions (Fabric/PosixFD), added CommBackend.bcast/MPI bcast implementation, added comm_backend_for_handle_transfer plumbing, and updated allocation/share/import flows; alloc_and_copy_to_cuda now returns int.
High-level Python API & workspace
flashinfer/comm/trtllm_mnnvl_ar.py
Added MNNVLAllreduceFusionStrategy enum and MNNVLAllreduceFusionWorkspace, exposed trtllm_mnnvl_allreduce_fusion, added trtllm_mnnvl_allreduce and trtllm_mnnvl_fused_allreduce_add_rmsnorm with strategy selection and deprecated legacy wrappers.
Tests
tests/comm/test_trtllm_mnnvl_allreduce.py
Converted tests to MPI-aware orchestration, added prepare_test_data, refactored fused vs legacy flows (test_mnnvl_allreduce_refactored, test_mnnvl_allreduce_legacy), added barriers/logging/traceback handling, and adapted workspace usage and validations.

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant PyAPI as Python API / Workspace
    participant Strategy as Strategy Selector
    participant Comm as Comm Backend (MPI / IpcSocket)
    participant Kernel as CUDA Kernel (trtllm_mnnvl_allreduce_fusion)
    participant Buff as Buffers / Output

    App->>PyAPI: call trtllm_mnnvl_allreduce(...) or fused API
    PyAPI->>PyAPI: validate inputs, prepare workspace & outputs
    PyAPI->>Strategy: select ONESHOT / TWOSHOT (AUTO inspects workspace/problem)
    PyAPI->>Comm: exchange/share handles (MPI bcast/barrier or IpcSocket FD exchange)
    PyAPI->>Kernel: invoke trtllm_mnnvl_allreduce_fusion(params)
    rect rgb(245,250,255)
      Kernel->>Kernel: lamport-stage broadcast & per-token reduction
      alt RMSNorm fusion enabled
        Kernel->>Kernel: compute RMS, apply gamma, add residuals
      end
      Kernel->>Buff: write output (and residual_out if present)
    end
    Buff-->>PyAPI: return tensor(s)
    PyAPI-->>App: deliver result(s)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh — Lamport layout, packed loads/stores, synchronization and RMSNorm math.
    • csrc/trtllm_mnnvl_allreduce.cu — parameter mapping, strict shape/validation logic, oneshot/twoshot dispatch paths and error messages.
    • flashinfer/comm/mnnvl.pyIpcSocket/HandleExchanger semantics, FD lifecycle, and MPI/IPC interplay.
    • flashinfer/comm/trtllm_mnnvl_ar.py and tests — workspace sizing/validation, strategy heuristics, and MPI test orchestration.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • cyx-6
  • yzh119
  • nvmbreughe
  • wenscarl
  • IwakuraRein
  • bkryu

Poem

🐇 I hop through lamport lanes where token shards play,

I ferry file descriptors softly, across ranks on their way,
One-shot or two-shot — fusion hums into the night,
Gamma and residuals twirl till the math feels right,
Kernels clap and buffers cheer — reduce done, carrots bright! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 52.11% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Refactor trtllm_mnnvl_allreduce' accurately describes the main change—a refactoring of the allreduce kernel and API, which is the primary objective of the PR.
Description check ✅ Passed The PR description covers the main objectives, related issues/PRs, and includes completed pre-commit and testing checklists, but lacks explicit detail about specific improvements and architectural changes.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @timlee0212, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a comprehensive refactoring of the multi-node NVLink (MNNVL) all-reduce system within FlashInfer. It unifies the all-reduce and RMSNorm operations into a single, highly configurable C++ kernel, exposed through intuitive new Python APIs. A key improvement is the new workspace management class, which automates and optimizes buffer allocation. Furthermore, the PR adds crucial support for IPC Socket-based handle transfer, broadening compatibility to hardware environments like DGX machines. These changes collectively enhance the flexibility, performance, and overall robustness of distributed computations.

Highlights

  • Refactored MNNVL All-Reduce Implementation: The core multi-node NVLink (MNNVL) all-reduce logic has been significantly refactored, consolidating all-reduce and RMSNorm functionalities into a single, flexible C++ kernel (trtllm_mnnvl_allreduce_fusion).
  • New Python API Interfaces: New Python APIs, trtllm_mnnvl_allreduce and trtllm_mnnvl_fused_allreduce_add_rmsnorm, have been introduced to provide clearer and more flexible usage for non-fused and fused all-reduce operations, respectively.
  • Enhanced Workspace Management: A new MNNVLAllreduceFusionWorkspace class now handles buffer allocation and management, simplifying the process for users and ensuring robust workspace sizing based on problem dimensions and chosen strategy (one-shot or two-shot).
  • IPC Socket Support for Handle Transfer: Support for IPC Socket-based device memory bootstrap has been added, enabling MNNVL operations on DGX machines and other environments that may not support fabric handles.
  • Deprecated Old APIs: The previous get_allreduce_mnnvl_workspace, trtllm_mnnvl_all_reduce, and trtllm_mnnvl_fused_allreduce_rmsnorm APIs are now marked as deprecated, though they remain functional for backward compatibility by internally calling the new fusion kernel.
  • Performance Optimizations: The CUDA kernels have been optimized with one-shot and two-shot strategies, refined Lamport synchronization, and dynamic grid configuration adjustments for improved efficiency across various problem sizes.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a significant refactoring of the MNNVL all-reduce implementation, introducing a new, cleaner API with a dedicated workspace manager class, and adding support for IPC sockets for single-node communication. The changes are extensive and substantially improve the code's structure and capabilities. My review focuses on ensuring backward compatibility is fully maintained as intended, removing leftover debug code, improving memory usage efficiency, adding a critical safety check for buffer sizes in the new API, and suggesting a minor precision improvement in a CUDA kernel.

Comment on lines +229 to +279
def trtllm_mnnvl_allreduce(
input: torch.Tensor,
workspace: MNNVLAllreduceFusionWorkspace,
launch_with_pdl: bool,
output: Optional[torch.Tensor] = None,
strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The new trtllm_mnnvl_allreduce function is missing a check to ensure that the input tensor fits within the allocated workspace. The old API had a check like if inp.shape[0] > buffer_M: raise ValueError(...). A similar check should be added here to prevent potential out-of-bounds memory access, which could lead to crashes or incorrect results. The required buffer size depends on the strategy (one-shot vs. two-shot) and can be calculated using MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this check to be on the execution path? Or should we assuming it is the user's liability to ensure it does not overflow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do want this check. I recently added it because it did bite others.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

self.mcast_device_memory.lamport_initialize(rank, dtype)

def get_mc_buffer(
def get_multicast_buffer(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Renaming get_mc_buffer to get_multicast_buffer is a breaking change. The pull request description states an intention to maintain backward compatibility. To align with this, please consider re-introducing get_mc_buffer as a deprecated function that calls get_multicast_buffer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is used internally, and left as a placeholder but not implemented. Thus, a breaking changes is fine. Tag @nvmbreughe for confirmation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/comm/mnnvl.py (1)

132-149: alloc_and_copy_to_cuda return type and empty-input behavior are inconsistent

The function is annotated as returning int but returns None when host_ptr_array is empty. Callers currently pass non‑empty lists, but this mismatch can trip type checkers and hide bugs if an empty list is ever passed.

Either tighten behavior or relax the signature, for example:

  • If empty input is invalid, raise:
def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
    if not host_ptr_array:
        raise ValueError("host_ptr_array must be non-empty")
  • Or, if you want the sentinel, change the annotation to int | None and document the None case.
tests/comm/test_trtllm_mnnvl_allreduce.py (1)

328-427: Move allgather() and final mpi_barrier() to finally block to ensure all ranks participate in collectives

Lines 414 and 434 create a deadlock risk in error scenarios. The allgather() at line 414 is inside the except block, so only ranks that hit an exception call it. Meanwhile, the mpi_barrier() at line 434 is unconditionally called after try/except/finally. If an error occurs on some but not all ranks, failing ranks block in allgather() waiting for non-failing ranks that never enter the except block, while non-failing ranks block in the final barrier—both unable to proceed.

Move the allgather() call and final mpi_barrier() to the finally block to ensure all ranks participate in these collective operations:

rank_failed = False
try:
    ...
except Exception as e:
    rank_failed = True
    failure_message = ...
    print(failure_message)
    import traceback
    print(traceback.format_exc())
    raise
finally:
    all_failures = MPI.COMM_WORLD.allgather(rank_failed)
    if any(all_failures):
        failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
        if rank == 0:
            print(f"Test failed on ranks: {failed_ranks}")
    if "workspace" in locals():
        del workspace
    trtllm_mnnvl_ar.mpi_barrier()

This applies to line 328–426 (main try/except) and line 434 (final barrier).

🧹 Nitpick comments (8)
flashinfer/comm/mnnvl.py (1)

640-655: Minor polish: unused recvmsg outputs and predictable opId

Two small, non‑blocking cleanups:

  • In IpcSocket.recv_fd(), the unpacked msg, flags, and addr from recvmsg are unused. Renaming them to _msg, _flags, _addr will make that explicit and silence linters:
_msg, ancdata, _flags, _addr = self.sock.recvmsg(...)
  • opId for the socket name is generated with random.randint. Since it’s only used as a uniqueness hint (not security‑sensitive), this is fine; if you want to appease S311 you could switch to secrets.randbits(64) or document that it’s non‑cryptographic.

Both are optional, but would make static analysis quieter.

Also applies to: 885-889

include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)

23-25: Explicitly include <array> and <tuple>, and guard adjustGridConfig against smCount == 0

Within this header:

  • LamportBufferLayout, LamportFlags, PackedVec, and several kernels use std::array.
  • adjustGridConfig returns std::tuple<int, int, int> and callers use std::get.

But only <type_traits> is included; <array> and <tuple> are currently pulled in (if at all) via transitive includes, which is fragile.

Also, adjustGridConfig relies on GetCudaMultiProcessorCount():

int smCount = GetCudaMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) {
  ...
}

If GetCudaMultiProcessorCount() ever returns 0 (e.g., CUDA error or misconfiguration), this loop will keep shrinking clusterSize and inflating blockSize in a somewhat opaque way.

Suggestions:

  • Add explicit includes at the top of the header:
#include <array>
#include <tuple>
  • Make adjustGridConfig robust to a 0 or negative SM count by early‑clamping:
int smCount = GetCudaMultiProcessorCount();
if (smCount <= 0) {
  // Fall back to single-SM configuration
  clusterSize = 1;
  blockSize = std::min(threadsNeeded, 1024);
  return {blockSize, clusterSize, loadsPerThread};
}

This keeps the fused path predictable even if the helper cannot obtain a valid SM count.

Also applies to: 54-177, 143-163, 291-313, 348-359, 385-419, 449-497


509-651: Confirm lamport clear / wait protocol assumptions for oneshot kernel

The oneshot fused kernel uses LamportFlags as follows:

  • Out‑of‑bounds threads call ctaArrive() then clearDirtyLamportBuf() and return.
  • In‑bounds threads:
    • write their shard into the multicast lamport buffer,
    • call ctaArrive() again,
    • then call clearDirtyLamportBuf() and spin on the Lamport buffers until all entries are non‑negZero.

This protocol assumes:

  • Every thread in the grid calls clearDirtyLamportBuf() exactly once per iteration.
  • Buffer flags and bytesToClear are correctly initialized to match the configured numTokens * tokenDim * WorldSize.

Given that this is a direct Lamport port, the logic looks consistent, but the protocol is subtle. I’d recommend:

  • Double‑checking the initialization of buffer_flags in MNNVLAllreduceFusionWorkspace matches the expectations here (current index, dirty index, bytes per buffer, and stage counts).
  • Adding a brief comment near the kernel launch documenting that buffer_flags must follow the [cur, dirty, bytes_per_buffer, dirty_num_stages, bytes_to_clear[4], access_ptr] layout used by LamportFlags.

No code change strictly required, but the invariants are nontrivial and worth locking down in comments/tests.


754-885: Two‑shot path & RMSNorm fusion: validate world sizes and loads‑per‑thread bounds

The two‑shot kernels and dispatchers introduce several constraints:

  • twoshotAllreduceFusionDispatch<T> only supports nRanks in {2, 4, 8, 16, 32, 64} and enforces tokenDim % (sizeof(float4) / sizeof(T)) == 0.
  • rmsNormLamport is instantiated with LoadsPerThread in [1, 8] and uses float4 loads into shared memory; dynamic shared memory is sized as 3 * rnBlockSize * iters * sizeof(T) and indexed accordingly.

The implementation looks coherent, but a few invariants are implicit:

  • MNNVLTwoShotStage::NUM_STAGES must stay in sync with the LamportFlags<float4> usage and the two bytes_to_clear entries in waitAndUpdate.
  • rnLoadsPerThread retrieved from adjustGridConfig must remain in [1, 8]; the default: branch already errors if it’s out of range, which is good.
  • rnClusterSize from adjustGridConfig is assumed to be <= 8 given __shared__ float sharedVal[8]; in the RMSNorm kernel.

Given these contracts, I’d suggest:

  • Adding asserts (or comments) that rnClusterSize <= 8 when CGA is used, to guard future changes to adjustGridConfig.
  • Extending tests to cover the corner cases where tokenDim is just at or above the supported boundary (e.g., maximum hidden size and multiple world sizes) so we don’t regress the FLASHINFER_CHECK conditions.

Functionally the code looks sound; this is mainly about making the implicit constraints explicit.

Also applies to: 898-959, 1062-1219

csrc/trtllm_mnnvl_allreduce.cu (1)

99-107: Error message still mentions “twoshot” even for oneshot path

Regardless of use_oneshot, the failure message says:

TVM_FFI_ICHECK(status == cudaSuccess)
    << "twoshot_allreduce_dispatch_world_size failed with error code "
    << cudaGetErrorString(status);

This is slightly misleading when the oneshot dispatch is used. Consider making the message neutral (e.g., “allreduce_fusion_dispatch failed…”) or switching on use_oneshot to provide a more accurate label. Behavior is otherwise fine.

tests/comm/test_trtllm_mnnvl_allreduce.py (2)

232-270: Use the same eps for reference RMSNorm as the fused kernel

In prepare_test_data, the fused reference path uses:

norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)

But the actual fused kernel is driven by the eps argument passed into row_linear_residual_norm_fusion_forward (eps = 1e-5 in run_mnnvl_ar_full).

To keep the reference as close as possible to the fused implementation (and not rely on loose tolerances), consider:

def prepare_test_data(..., fusion: bool, eps: float):
    ...
    if fusion:
        ...
        norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False)

and threading eps through the call sites.


273-281: Annotate legacy_explicit_workspace_bytes as optional

Ruff’s RUF013 warning here is valid:

def run_mnnvl_ar_full(...,
    legacy_explicit_workspace_bytes: int = None,
    legacy_api: bool = False,
):

Changing the signature to make the optionality explicit improves readability and typing:

from typing import Optional

def run_mnnvl_ar_full(
    ...,
    legacy_explicit_workspace_bytes: Optional[int] = None,
    legacy_api: bool = False,
) -> None:
    ...

or, in Python 3.10+:

legacy_explicit_workspace_bytes: int | None = None
flashinfer/comm/trtllm_mnnvl_ar.py (1)

203-205: Drop debug print from hot path.
This unconditional print will spam stdout for every call to the fused kernel. Please remove it or guard it behind a proper debug logger.

-        print(
-            f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}"
-        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0753095 and a2670e8.

📒 Files selected for processing (6)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/mnnvl.py (18 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
  • include/flashinfer/utils.cuh (1 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (7)
  • MNNVLAllreduceFusionWorkspace (47-141)
  • mpi_barrier (23-27)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (301-391)
  • MNNVLAllreduceFusionStrategy (30-40)
  • trtllm_mnnvl_allreduce (229-298)
  • get_allreduce_mnnvl_workspace (398-451)
  • get_required_buffer_size_bytes (116-141)
flashinfer/comm/mnnvl.py (10)
  • barrier (168-168)
  • barrier (227-228)
  • bcast (165-165)
  • bcast (224-225)
  • get_multicast_ptr (868-872)
  • get_multicast_ptr (1191-1193)
  • get_buffer_ptrs_dev (854-856)
  • get_buffer_ptrs_dev (1199-1201)
  • get_unicast_ptr (858-866)
  • get_unicast_ptr (1195-1197)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (168-222)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1121-1201)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1101-1118)
  • lamport_initialize (1160-1161)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (854-856)
  • get_buffer_ptrs_dev (1199-1201)
  • get_unicast_ptr (858-866)
  • get_unicast_ptr (1195-1197)
  • get_multicast_ptr (868-872)
  • get_multicast_ptr (1191-1193)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (31-109)
  • trtllm_mnnvl_allreduce_fusion (31-37)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
🪛 Ruff (0.14.5)
tests/comm/test_trtllm_mnnvl_allreduce.py

279-279: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

flashinfer/comm/trtllm_mnnvl_ar.py

74-76: Avoid specifying long messages outside the exception class

(TRY003)


261-263: Avoid specifying long messages outside the exception class

(TRY003)


268-270: Avoid specifying long messages outside the exception class

(TRY003)


338-340: Avoid specifying long messages outside the exception class

(TRY003)


342-344: Avoid specifying long messages outside the exception class

(TRY003)


346-348: Avoid specifying long messages outside the exception class

(TRY003)


352-354: Avoid specifying long messages outside the exception class

(TRY003)


358-360: Avoid specifying long messages outside the exception class

(TRY003)


500-502: Avoid specifying long messages outside the exception class

(TRY003)


571-573: Avoid specifying long messages outside the exception class

(TRY003)


577-579: Avoid specifying long messages outside the exception class

(TRY003)


582-584: Avoid specifying long messages outside the exception class

(TRY003)


586-588: Avoid specifying long messages outside the exception class

(TRY003)


591-593: Avoid specifying long messages outside the exception class

(TRY003)


596-598: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


885-885: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines 292 to 301
inline int GetCudaMultiProcessorCount() {
static int sm_count = 0;
if (sm_count == 0) {
int device_id;
cudaGetDevice(&device_id);
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
sm_count = device_prop.multiProcessorCount;
}
return sm_count;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Make GetCudaMultiProcessorCount thread‑safe and clarify multi‑device semantics

  • static int sm_count is written without synchronization; concurrent calls from multiple host threads can cause a data race and undefined behavior.
  • The function also permanently caches the SM count of whichever device is current on the first call; if the process later switches devices, the cached value will be wrong.

Consider making the cache atomic (or using std::call_once) and, if needed, keying by device ID. For example:

- inline int GetCudaMultiProcessorCount() {
-  static int sm_count = 0;
-  if (sm_count == 0) {
-    int device_id;
-    cudaGetDevice(&device_id);
-    cudaDeviceProp device_prop;
-    cudaGetDeviceProperties(&device_prop, device_id);
-    sm_count = device_prop.multiProcessorCount;
-  }
-  return sm_count;
-}
+ inline int GetCudaMultiProcessorCount() {
+  static std::atomic<int> sm_count{0};
+  int cached = sm_count.load(std::memory_order_relaxed);
+  if (cached == 0) {
+    int device_id = 0;
+    cudaGetDevice(&device_id);
+    cudaDeviceProp device_prop{};
+    cudaGetDeviceProperties(&device_prop, device_id);
+    cached = device_prop.multiProcessorCount;
+    sm_count.store(cached, std::memory_order_relaxed);
+  }
+  return cached;
+}

(You’d also need to #include <atomic> in this header.)

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
inline int GetCudaMultiProcessorCount() {
static int sm_count = 0;
if (sm_count == 0) {
int device_id;
cudaGetDevice(&device_id);
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
sm_count = device_prop.multiProcessorCount;
}
return sm_count;
inline int GetCudaMultiProcessorCount() {
static std::atomic<int> sm_count{0};
int cached = sm_count.load(std::memory_order_relaxed);
if (cached == 0) {
int device_id = 0;
cudaGetDevice(&device_id);
cudaDeviceProp device_prop{};
cudaGetDeviceProperties(&device_prop, device_id);
cached = device_prop.multiProcessorCount;
sm_count.store(cached, std::memory_order_relaxed);
}
return cached;
}
🤖 Prompt for AI Agents
In include/flashinfer/utils.cuh around lines 292-301, the function uses a
non-synchronized static sm_count that creates a data race and also permanently
caches the SM count for whichever device was current on first call; fix by
making the lookup thread-safe and per-device: first call cudaGetDevice() to get
the current device id, then fetch properties if no cached value exists for that
device; protect the cache with thread-safety (e.g., use std::call_once
per-device, or a mutex + an unordered_map<int,int> keyed by device id, or an
atomic<int> array sized by device count) and include the required headers
(<mutex>, <unordered_map> or <atomic>); return the cached per-device SM count
after initialization.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)

56-69: Guard RMSNorm fusion against missing residual_in and validate its shape.

The precondition at line 56 enforces presence of residual_out, gamma, and epsilon when rmsnorm_fusion=true, but omits residual_in. The kernel will dereference residualInPtr unconditionally when RMSNormFusion is true, causing undefined behavior if residual_in is absent.

Additionally, shape validation (lines 61-68) only covers residual_out and gamma; residual_in is not validated.

Extend the precondition to include residual_in:

-    TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) ||
+    TVM_FFI_ICHECK((residual_out.has_value() && residual_in.has_value() &&
+                    gamma.has_value() && epsilon.has_value()) ||
                    !rmsnorm_fusion)
-        << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true";
+        << "residual_out, residual_in, gamma, and epsilon must be provided if rmsnorm_fusion is true";

Add shape validation for residual_in within the if (rmsnorm_fusion) block:

     if (rmsnorm_fusion) {
       TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                      residual_out.value().size(1) == token_dim)
           << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
           << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
           << ")";
+      TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
+                     residual_in.value().size(1) == token_dim)
+          << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dim
+          << ") but got (" << residual_in.value().size(0) << ", "
+          << residual_in.value().size(1) << ")";
       TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
           << "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
           << gamma.value().size(0) << ")";
     }
flashinfer/comm/trtllm_mnnvl_ar.py (1)

331-332: Restore RMSNorm epsilon default to 1e-5.

Overriding epsilon with torch.finfo(input.dtype).eps replaces the kernel's built-in 1e-5 default (see line 91 in csrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM.

Apply this diff:

     if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+        epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

502-504: Clarify assertion for legacy API compatibility.

The assertion at lines 502-504 will fail with a cryptic message if wait_for_results=False is passed. Since this is deprecated legacy code, the assertion is reasonable, but consider improving the error message for clarity:

-    assert wait_for_results and (out is not None), (
-        "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead."
-    )
+    if not wait_for_results or out is None:
+        raise ValueError(
+            "Legacy trtllm_mnnvl_all_reduce requires wait_for_results=True and a valid output tensor. "
+            "Please use the new trtllm_mnnvl_allreduce API instead."
+        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a2670e8 and 92cbd48.

📒 Files selected for processing (3)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (3)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (168-219)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1121-1201)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1101-1118)
  • lamport_initialize (1160-1161)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (854-856)
  • get_buffer_ptrs_dev (1199-1201)
  • get_unicast_ptr (858-866)
  • get_unicast_ptr (1195-1197)
  • get_multicast_ptr (868-872)
  • get_multicast_ptr (1191-1193)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (31-108)
  • trtllm_mnnvl_allreduce_fusion (31-37)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (7)
  • MNNVLAllreduceFusionWorkspace (47-141)
  • mpi_barrier (23-27)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (298-388)
  • MNNVLAllreduceFusionStrategy (30-40)
  • trtllm_mnnvl_allreduce (226-295)
  • get_allreduce_mnnvl_workspace (395-448)
  • get_required_buffer_size_bytes (116-141)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py

74-76: Avoid specifying long messages outside the exception class

(TRY003)


258-260: Avoid specifying long messages outside the exception class

(TRY003)


265-267: Avoid specifying long messages outside the exception class

(TRY003)


335-337: Avoid specifying long messages outside the exception class

(TRY003)


339-341: Avoid specifying long messages outside the exception class

(TRY003)


343-345: Avoid specifying long messages outside the exception class

(TRY003)


349-351: Avoid specifying long messages outside the exception class

(TRY003)


355-357: Avoid specifying long messages outside the exception class

(TRY003)


497-499: Avoid specifying long messages outside the exception class

(TRY003)


568-570: Avoid specifying long messages outside the exception class

(TRY003)


574-576: Avoid specifying long messages outside the exception class

(TRY003)


579-581: Avoid specifying long messages outside the exception class

(TRY003)


583-585: Avoid specifying long messages outside the exception class

(TRY003)


588-590: Avoid specifying long messages outside the exception class

(TRY003)


593-595: Avoid specifying long messages outside the exception class

(TRY003)

tests/comm/test_trtllm_mnnvl_allreduce.py

280-280: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

226-232: Add workspace capacity check to prevent buffer overflow.

The new trtllm_mnnvl_allreduce function doesn't verify that the input tensor fits within the allocated workspace buffer. A previous review comment suggested adding a check similar to the legacy API's if inp.shape[0] > buffer_M validation.

While the author questioned whether this should be on the execution path, buffer overflow can cause crashes or silent memory corruption. Consider adding a validation check:

required_size = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
    workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
)
if required_size > workspace.buffer_size_bytes:
    raise ValueError(
        f"Input tensor requires {required_size} bytes but workspace only has "
        f"{workspace.buffer_size_bytes} bytes. Please increase workspace size."
    )

Based on past review comments, the maintainer questioned if this check should be on the execution path. If this is intentionally omitted for performance, please document this as a user responsibility in the docstring.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)

566-664: Close remaining POSIX FDs in IPC path to avoid leaks

In the POSIX handle path of _alloc_mn_mcast_mem, a few FDs are still never closed:

  • local_shareable_uc_handle returned by cuMemExportToShareableHandle (line 958) is used in the IPC ring allgather but never closed.
  • During the ring, each rank sends its local_shareable_uc_handle to all peers, including itself. The self‑recv for p == group_rank populates all_shareable_uc_handles[self.group_rank], but that FD is never imported (due to if p != self.group_rank) and also never closed.

You already close imported POSIX FDs after cuMemImportFromShareableHandle and close the multicast FD after import; closing the remaining two FDs will complete the cleanup and prevent per‑allocation FD leaks in long‑running jobs.

One way to fix this:

        if (
            self._shareable_handle_type
            == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
        ):
            # All-gather fabric handles
            all_shareable_uc_handles = self.comm_backend.allgather(
                local_shareable_uc_handle.data
            )
        else:
            # Implement the allgather logic with ipc socket
            all_shareable_uc_handles = [None] * self.group_size
            for i in range(self.group_size):
                self.comm_backend.barrier()
                # Send to peer at offset i
                dest_rank = (self.group_rank + i) % self.group_size
                self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
                # Receive from peer at offset -i
                src_rank = (self.group_rank + self.group_size - i) % self.group_size
                all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
+           if (
+               self._shareable_handle_type
+               == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+           ):
+               # Close our exported UC handle FD and the self-received FD
+               os.close(local_shareable_uc_handle)
+               if all_shareable_uc_handles[self.group_rank] is not None:
+                   os.close(all_shareable_uc_handles[self.group_rank])

The existing per‑peer close after import:

if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
    os.close(all_shareable_uc_handles[p])

and the multicast close:

if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
    os.close(shareable_mc_handle)

can stay as‑is.

Together with the new __del__ logic calling self._ipc_socket.close(), this fully addresses the descriptor‑leak concern in the IPC path.

Also applies to: 957-1005, 1008-1055

tests/comm/test_trtllm_mnnvl_allreduce.py (1)

233-271: Align reference RMSNorm epsilon with kernel default (still using torch.finfo(dtype).eps)

prepare_test_data still uses torch.finfo(dtype).eps as the epsilon for the reference RMSNorm:

norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)

while the kernel and test harness default to eps = 1e-5 (see run_mnnvl_ar_full and the C++ FFI wrapper’s params.epsilon default). This inconsistency can mask subtle discrepancies behind loose tolerances or cause avoidable test drift.

To keep the reference path exactly aligned with the implementation, switch this to the same constant:

-        norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
-        )
+        norm_out = rmsnorm(
+            residual_out,
+            norm_weight,
+            1e-5,
+            enable_pdl=False,
+        )

(or better, reuse the same eps value passed into run_mnnvl_ar_full to avoid hard‑coding the constant twice).

🧹 Nitpick comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)

100-114: Ensure epsilon defaults stay consistent with Python API and tests

Here params.epsilon falls back to 1e-5 when the Optional epsilon is not provided:

params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;

The Python wrapper in flashinfer/comm/trtllm_mnnvl_ar.py and the tests in tests/comm/test_trtllm_mnnvl_allreduce.py should use the same default to avoid silent discrepancies between the kernel and reference paths. The core test harness already sets eps = 1e-5; the remaining mismatch is in the reference RMSNorm computation (see prepare_test_data), which still uses torch.finfo(dtype).eps.

flashinfer/comm/mnnvl.py (3)

132-150: Fix alloc_and_copy_to_cuda return type vs None behavior

alloc_and_copy_to_cuda is annotated as returning int but still returns None for an empty host_ptr_array:

def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
    if not host_ptr_array:
        return None

Current call sites (signal_pads and uc_ptrs) always pass non‑empty lists, so behavior is correct, but the annotation is misleading and could hide bugs if the function gets reused.

Either make the return type explicit about None or enforce non‑emptiness by raising:

-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
-    if not host_ptr_array:
-        return None
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+    if not host_ptr_array:
+        raise ValueError("host_ptr_array must be non-empty")

(or change the annotation to Optional[int] if you prefer the sentinel behavior).


885-893: IPC opId bootstrap looks fine; consider documenting ordering guarantees

_init_ipc_socket uses an MPI‑like bcast to distribute a randomly chosen opId from rank 0, then uses it to construct IpcSocket endpoints on all ranks. This nicely avoids hard‑coding operation IDs and lines up with the C++ IPC model.

Given the reliance on collective barriers around send_fd/recv_fd, it would help future maintainers to mention in a comment here that all ranks are expected to participate in the same sequence of IPC operations for a given opId, and that mismatched usage will deadlock. The code is correct as written; this is just a documentation/clarity suggestion.


1143-1170: McastGPUBuffer workspace integration and pointer getters look consistent

The new comm_backend_for_handle_transfer parameter is threaded through to McastDeviceMemory, and the added get_unicast_ptr wrapper simply delegates to mcast_device_memory.get_unicast_ptr(rank). This lines up with how tests and get_allreduce_mnnvl_workspace use these pointers and keeps pointer access encapsulated.

The placeholder buffer‑view methods (get_multicast_buffer, get_unicast_buffer) are clearly marked NotImplementedError, so they won’t be hit accidentally. If you plan to expose tensor views later, you can implement them via create_tensor_from_cuda_memory.

Also applies to: 1209-1212

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 92cbd48 and 5be2697.

📒 Files selected for processing (4)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/mnnvl.py (18 hunks)
  • include/flashinfer/utils.cuh (2 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/utils.cuh
🧬 Code graph analysis (3)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (7)
  • MNNVLAllreduceFusionWorkspace (47-141)
  • mpi_barrier (23-27)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (298-388)
  • MNNVLAllreduceFusionStrategy (30-40)
  • trtllm_mnnvl_allreduce (226-295)
  • get_allreduce_mnnvl_workspace (395-448)
  • get_required_buffer_size_bytes (116-141)
flashinfer/comm/mnnvl.py (14)
  • barrier (168-168)
  • barrier (227-228)
  • Get_rank (156-156)
  • Get_rank (215-216)
  • Get_size (159-159)
  • Get_size (218-219)
  • bcast (165-165)
  • bcast (224-225)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1205-1207)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1213-1215)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1209-1211)
csrc/trtllm_mnnvl_allreduce.cu (2)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (168-219)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


888-888: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
include/flashinfer/utils.cuh (1)

293-307: Thread‑safe SM count cache looks good; confirm single‑GPU‑per‑process assumption

Using static std::atomic<int> with relaxed loads/stores makes this helper thread‑safe and avoids the previous static int data race. The comment explicitly assumes one CUDA device per process, since the cached sm_count is never recomputed if the current device changes.

If there are any call sites that may run in a multi‑GPU‑per‑process setup, consider extending this to a per‑device cache (e.g., keyed by device id) rather than a single global integer; otherwise, this implementation is fine as long as the single‑device assumption holds.

csrc/trtllm_mnnvl_allreduce.cu (1)

41-76: RMSNorm fusion validation and shape checks look correct

The updated precondition now correctly requires residual_in, residual_out, gamma, and epsilon when rmsnorm_fusion is true, and the subsequent shape checks on residual_in, residual_out, and gamma guard the fused path against mismatched tensors. This should prevent the fused kernels from ever seeing invalid residual/norm inputs via the FFI boundary.

The overall parameter wiring into AllReduceFusionParams (including buffer pointers and flags) also looks consistent with the Python side.

flashinfer/comm/mnnvl.py (1)

781-790: Good: IPC socket is now closed in destructor

The addition of:

if hasattr(self, "_ipc_socket"):
    self._ipc_socket.close()

inside __del__ ensures the Unix domain socket is closed and, for non‑abstract sockets, the filesystem entry is unlinked. This addresses the earlier socket‑leak concern while remaining safe when construction fails before _ipc_socket is set.

tests/comm/test_trtllm_mnnvl_allreduce.py (1)

16-103: Test harness refactor cleanly exercises both refactored and legacy APIs

The new helpers (row_linear_residual_norm_fusion_forward, _legacy, run_mnnvl_ar_full) and parametrized tests (test_mnnvl_allreduce_refactored, test_mnnvl_allreduce_legacy) do a good job of:

  • Sharing core logic between fused and non‑fused paths.
  • Covering both the new workspace‑based API and the legacy pointer‑based API.
  • Exercising a variety of sequence lengths, dtypes, and hidden sizes.
  • Integrating MPI barriers and rank‑aware logging to make multi‑rank failures diagnosable.

Once the epsilon alignment in prepare_test_data is fixed, this test suite should give solid coverage for the new fused implementation and its backward‑compatibility guarantees.

Also applies to: 274-397, 439-465

comm_backend: Optional[CommBackend] = None,
):
"""
Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way we can check this?

Copy link
Contributor Author

@timlee0212 timlee0212 Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to update the doc. Fixed.

def __init__(
self,
mapping: Mapping,
buffer_size_in_bytes: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide guidance for buffer_size_in_bytes? E.g., in function of number of tokens and hidden size? Or just refer to get_required_buffer_size_bytes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just refer to get_required_buffer_size_bytes

Comment on lines +229 to +279
def trtllm_mnnvl_allreduce(
input: torch.Tensor,
workspace: MNNVLAllreduceFusionWorkspace,
launch_with_pdl: bool,
output: Optional[torch.Tensor] = None,
strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do want this check. I recently added it because it did bite others.

def __init__(
self,
mapping: Mapping,
buffer_size_in_bytes: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to replace "buffer_size_in_bytes" by the parameters that get_required_buffer_size_bytes takes, and just call this from the init function. Seems more user friendly.

If you do want to just allocate a blob of memory, we could still have buffer_size_in_bytes as an addtional parameter that would override whatever is calculated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that is a good design as it might give the user an impression that the allocated workspace will ONLY support the set of parameters (max_num_tokens, hidden_dim, dtype, strategy)

But actually, the workspace usage is quite flexible and as long as the required workspace size is smaller than the allocation, it will work. Thus, the intended usage is the user checks the required workspace size (or we can check it when calling the allreduce function, but at some cost) before using it.

Optional<TensorView> out) {
cudaSetDevice(in.device().device_id);
auto stream = get_stream(in.device());
// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is a problem

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment removed.

AUTO = 99

@staticmethod
def is_one_shot(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename this to "heuristic_for_one_shot" or something like that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

361-362: Critical: Restore epsilon default to 1e-5 to match kernel.

This epsilon fallback was flagged as critical in a previous review but remains unresolved. Using torch.finfo(input.dtype).eps sets epsilon to approximately 1e-3 for fp16, diverging from the kernel's built-in 1e-5 default (see csrc/trtllm_mnnvl_allreduce.cu line 96). This materially alters RMSNorm results and breaks compatibility with TensorRT-LLM.

Apply this fix:

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

118-136: Consider replacing @functools.cache on instance method.

Using @functools.cache on an instance method can prevent the instance from being garbage collected, leading to memory leaks. Since this method takes self as the first parameter, the cache will hold references to the instance.

Consider either:

  1. Making this a standalone function that takes workspace parameters explicitly
  2. Using @functools.lru_cache(maxsize=...) with a reasonable limit
  3. Implementing manual caching in the instance if needed

Based on learnings

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5be2697 and c6ed147.

📒 Files selected for processing (2)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (6)
  • Mapping (21-475)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1135-1215)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1115-1132)
  • lamport_initialize (1174-1175)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1213-1215)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1209-1211)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1205-1207)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (29-113)
  • trtllm_mnnvl_allreduce_fusion (29-35)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
  • trtllm_mnnvl_allreduce_fusion (192-243)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py

77-79: Avoid specifying long messages outside the exception class

(TRY003)


118-118: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)


282-284: Avoid specifying long messages outside the exception class

(TRY003)


289-291: Avoid specifying long messages outside the exception class

(TRY003)


303-305: Avoid specifying long messages outside the exception class

(TRY003)


365-367: Avoid specifying long messages outside the exception class

(TRY003)


369-371: Avoid specifying long messages outside the exception class

(TRY003)


373-375: Avoid specifying long messages outside the exception class

(TRY003)


379-381: Avoid specifying long messages outside the exception class

(TRY003)


385-387: Avoid specifying long messages outside the exception class

(TRY003)


398-400: Avoid specifying long messages outside the exception class

(TRY003)


528-530: Avoid specifying long messages outside the exception class

(TRY003)


599-601: Avoid specifying long messages outside the exception class

(TRY003)


605-607: Avoid specifying long messages outside the exception class

(TRY003)


610-612: Avoid specifying long messages outside the exception class

(TRY003)


614-616: Avoid specifying long messages outside the exception class

(TRY003)


619-621: Avoid specifying long messages outside the exception class

(TRY003)


624-626: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)

29-113: LGTM! Fusion entry point is well-structured.

The refactored entry point properly validates all inputs, including the RMSNorm fusion parameters that were flagged in previous reviews. The dispatch logic cleanly selects between oneshot and twoshot strategies, and error messages are clear and actionable.

flashinfer/comm/trtllm_mnnvl_ar.py (3)

30-48: Strategy enum and heuristic look good.

The MNNVLAllreduceFusionStrategy enum provides a clear interface for selecting between oneshot and twoshot approaches, with a sensible AUTO mode that uses an empirically-derived threshold.


250-326: Buffer size validation properly implemented.

The function now includes the buffer size check that was requested in previous reviews (lines 300-305), preventing potential out-of-bounds access. Input validation is comprehensive and error messages are clear.


422-646: Deprecation strategy is well-executed.

The legacy APIs are properly marked with @deprecated decorators and include clear migration guidance. The wrappers correctly redirect to the new fusion-based implementations, maintaining backward compatibility while encouraging adoption of the improved APIs.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)

132-149: Fix return type inconsistency.

The function returns None at line 137 when host_ptr_array is empty, but the return type annotation at line 132 indicates int. This creates a type mismatch.

Consider one of these fixes:

Option 1: Return Optional[int] and update callers to handle None:

-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:

Option 2: Raise an error instead of returning None:

     if not host_ptr_array:
-        return None
+        raise ValueError("host_ptr_array cannot be empty")
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

377-378: Restore RMSNorm epsilon default to 1e-5.

Overriding epsilon with torch.finfo(input.dtype).eps replaces the kernel's built-in 1e-5 default (see trtllm_mnnvl_allreduce_fusion in csrc/trtllm_mnnvl_allreduce.cu line ~35: params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking numerical parity.

Apply this diff to fix:

     if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+        epsilon = 1e-5
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

134-152: Consider alternatives to @functools.cache on instance methods.

Using @functools.cache (or @lru_cache) on instance methods can prevent garbage collection of instances because the cache holds references to bound methods, which in turn hold references to self. Since MNNVLAllreduceFusionWorkspace instances are likely long-lived in typical usage, this may be acceptable, but consider these alternatives:

  • Use @functools.lru_cache(maxsize=128) to limit cache growth
  • Move caching logic to a module-level cache keyed on relevant parameters
  • Document the caching behavior and its memory implications

Based on learnings

Apply this diff if you want to limit cache size:

-    @functools.cache
+    @functools.lru_cache(maxsize=128)
     def is_buffer_size_sufficient(
flashinfer/comm/mnnvl.py (2)

640-654: Prefix unused unpacked variables with underscore.

The variables msg, flags, and addr from recvmsg are unpacked but never used. Prefix them with _ to indicate they're intentionally ignored.

Apply this diff:

-        msg, ancdata, flags, addr = self.sock.recvmsg(
+        _msg, ancdata, _flags, _addr = self.sock.recvmsg(

893-900: Consider using secrets module for opId generation.

While cryptographic randomness is not strictly required for socket naming, using secrets.randbelow(2**64) instead of random.randint provides better collision resistance if multiple jobs run concurrently on the same node.

Apply this diff:

+import secrets
+
     def _init_ipc_socket(self):
         if self.group_rank == 0:
-            # Gnerate the opId
-            opId = random.randint(0, 2**64 - 1)
+            # Generate the opId
+            opId = secrets.randbelow(2**64)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c6ed147 and a390685.

📒 Files selected for processing (2)
  • flashinfer/comm/mnnvl.py (19 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/comm/mapping.py (5)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/utils.py (2)
  • register_custom_op (273-282)
  • register_custom_op (292-311)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1143-1224)
  • CommBackend (152-171)
  • MPIBackend (211-232)
  • lamport_initialize (1123-1140)
  • lamport_initialize (1183-1184)
  • barrier (168-168)
  • barrier (227-228)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1222-1224)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1218-1220)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1214-1216)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (29-113)
  • trtllm_mnnvl_allreduce_fusion (29-35)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


896-896: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

flashinfer/comm/trtllm_mnnvl_ar.py

77-79: Avoid specifying long messages outside the exception class

(TRY003)


134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)


298-300: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


319-321: Avoid specifying long messages outside the exception class

(TRY003)


381-383: Avoid specifying long messages outside the exception class

(TRY003)


385-387: Avoid specifying long messages outside the exception class

(TRY003)


389-391: Avoid specifying long messages outside the exception class

(TRY003)


395-397: Avoid specifying long messages outside the exception class

(TRY003)


401-403: Avoid specifying long messages outside the exception class

(TRY003)


414-416: Avoid specifying long messages outside the exception class

(TRY003)


544-546: Avoid specifying long messages outside the exception class

(TRY003)


615-617: Avoid specifying long messages outside the exception class

(TRY003)


621-623: Avoid specifying long messages outside the exception class

(TRY003)


626-628: Avoid specifying long messages outside the exception class

(TRY003)


630-632: Avoid specifying long messages outside the exception class

(TRY003)


635-637: Avoid specifying long messages outside the exception class

(TRY003)


640-642: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

266-341: LGTM!

The function correctly validates inputs, selects strategy, checks buffer size sufficiency (addressing past review feedback), and invokes the fusion kernel with appropriate parameters.

flashinfer/comm/mnnvl.py (1)

788-789: LGTM!

The IPC socket cleanup correctly uses hasattr to check for existence before closing, addressing the file descriptor leak concern from past reviews.

#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1)
: "memory");
#else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean the kernel is supported for archs < 700? What's the minimal requirement?
For the API we use the @backend_requirement decorator, which lists supported SMs. So as a minimum I think we can list: 70,80,90,100,103,110,120

Would you agree? Further back is probably not as relevant.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This macro is fairly common in barrier and semaphore utility functions, largely due to the memory consistency qualifiers introduced with the Volta architecture. For example, CUTLASS uses a similar pattern:

https://github.com/NVIDIA/cutlass/blob/e67e63c331d6e4b729047c95cf6b92c8454cba89/include/cutlass/barrier.h#L116-L129

That said, I believe our usage here simply follows established convention. Given that our minimum supported architecture is sm_75, we shouldn't actually need these qualifiers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kernel needs multicast to work, which at least requires SM90 and needs NVSwitch

if (NUM_INPUTS > 0) {
T_IN accum[ELTS_PER_THREAD];
float4* accum4 = (float4*)&accum;
flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of clearing the buffer here, could we assume we have properly initialized lamport buffers at the start?
And then, at the end (e.g., right after we call PDL), we can clear the buffers, so that a next kernel using the same workspace can also assume properly initialized lamport buffers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use triple buffer so you can move the buffer clear anywhere. I found the current arrangement get the best performance.

If your assumption is using single buffer, clear the buffer at the end then assume the buffer is initialized for the next kernel, this won't work as it requires membar.sys which is very expensive. Moreover, we need to order the buffer write of the next kernel, other ranks after the buffer clear of this kernel, which has to use a flag and invalidate the usage of lamport protocol.

#pragma unroll
for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) {
valid &= !isNegZero(valuesLamport[r].elements[i]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as one of the values in WorldSize isNegZero, we will keep reading the others over and over again. Perhaps these are all cache hits, in which case: ignore my comment. Otherwise, would it make sense to break out of the inner loop (kLAMPORT_ELTS_PER_PACKED) early?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be cache hit if it is still invalid.

For the inner loop, I had a version that only check one element for validity. But there is no architecture guarantee that a 128B read/write is atomic, which mean beak the inner loop early could, in rare case, cause wrong result.

all_shareable_uc_handles = self.comm_backend.allgather(
local_shareable_uc_handle.data
)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Add a comment regarding handle type in this case: CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Branching logic no longer exists in the latest commit.

mc_fabric_handle.data if mc_fabric_handle else None, root=0
)
shareable_mc_handle = None
if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With all the branching on this handle type, would it make sense to build a class that encapsulate this logic?

Roughly thinking:

class ABCCommBufferAllocator:
def allocate_UnicastBuffer()
def allocate_MulitcastBuffer()

class CUFabricCommBufferAllocator(ABCCommBufferAllocator)
class CUPosixFDCommBufferAllocator(ABCCommBufferAllocator)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored in the latest commit.

self.mcast_device_memory.lamport_initialize(rank, dtype)

def get_mc_buffer(
def get_multicast_buffer(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)

612-648: Duplicate constructor parameter breaks McastDeviceMemory

McastDeviceMemory.__init__ declares comm_backend_for_handle_transfer twice in the parameter list, which is a syntax error and prevents the module from importing.

Remove the duplicate parameter so only one comm_backend_for_handle_transfer argument remains:

-        is_multi_node: bool = True,
-        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
-        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
+        is_multi_node: bool = True,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
♻️ Duplicate comments (3)
flashinfer/comm/mnnvl.py (1)

875-901: Local POSIX FD from cuMemExportToShareableHandle is never closed

In _alloc_mn_mcast_mem, the POSIX path exports local_shareable_uc_handle and circulates it via IpcSocket. Remote FDs are correctly closed after import, but the original exported FD (local_shareable_uc_handle) is never closed, leading to a per-rank FD leak over time. This was previously flagged; the remote-close part is fixed, but the local FD still needs closing.

Close the local FD after the ring allgather in the POSIX branch:

        if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC:
            # All-gather fabric handles
            all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data)
        else:
            # Implement the allgather logic with ipc socket
            all_shareable_uc_handles = [None] * self.group_size
            for i in range(self.group_size):
                self.comm_backend.barrier()
                # Send to peer at offset i
                dest_rank = (self.group_rank + i) % self.group_size
                self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
                # Receive from peer at offset -i
                src_rank = (self.group_rank + self.group_size - i) % self.group_size
                all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
+           # Local FD no longer needed after all peers have imported it
+           os.close(local_shareable_uc_handle)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

354-385: RMSNorm epsilon default diverges from CUDA/kernel and prior review

trtllm_mnnvl_fused_allreduce_add_rmsnorm currently does:

if epsilon is None:
    epsilon = torch.finfo(input.dtype).eps

The CUDA entry point (trtllm_mnnvl_allreduce_fusion) still defaults epsilon to 1e-5 when the Optional is not set, and TensorRT-LLM uses 1e-5 as its RMSNorm default. Because the Python wrapper always passes an explicit epsilon, the kernel’s 1e-5 default is never used, and for fp16 this changes behavior materially (~1e-3 vs 1e-5).

To preserve parity and avoid surprising numerics, consider restoring the 1e-5 default here and updating the docstring accordingly:

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        epsilon = 1e-5

and:

-        epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided.
+        epsilon: The epsilon parameter for RMSNorm; defaults to 1e-5 if not provided.

You’ll also want tests to use the same constant for their reference RMSNorm (see prepare_test_data in the test file).

tests/comm/test_trtllm_mnnvl_allreduce.py (1)

242-280: Reference RMSNorm epsilon should match kernel/API default

prepare_test_data builds the reference RMSNorm output with:

norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)

While run_mnnvl_ar_full sets eps = 1e-5 and passes that into the forward helpers, and the CUDA side uses 1e-5 as its default. Once the Python API is updated to default epsilon to 1e-5, the reference path here should also use 1e-5 to avoid systematic discrepancies.

Suggested change:

-        norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
-        )
+        norm_out = rmsnorm(
+            residual_out, norm_weight, 1e-5, enable_pdl=False
+        )

This keeps tests aligned with the kernel’s behavior and the Python wrapper’s default.

🧹 Nitpick comments (2)
flashinfer/comm/mnnvl.py (1)

127-143: alloc_and_copy_to_cuda return type vs behavior

alloc_and_copy_to_cuda is annotated to return int but returns None when host_ptr_array is empty. This is harmless for current call sites (they always pass non-empty lists) but makes the annotation misleading and can trip static type checkers.

Consider either:

  • Changing the annotation to Optional[int] and documenting the None case, or
  • Removing the early return and letting the function always return a valid device pointer (e.g., disallow empty input or treat it as a zero-byte allocation).
include/flashinfer/utils.cuh (1)

24-25: Thread-safe SM count cache looks good; note single-device assumption

Using std::atomic<int> for sm_count removes the data race from the previous static-int pattern while keeping the lookup cheap. The comment correctly calls out that this caches the SM count for whichever device is current on first use, assuming one CUDA device per process.

If you ever need multi-device-per-process support, consider extending this to cache per-device (e.g., a small std::vector<std::atomic<int>> indexed by device ID); otherwise this implementation is fine for the stated assumption.

Also applies to: 339-353

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a390685 and 815aaf3.

📒 Files selected for processing (6)
  • csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
  • flashinfer/comm/mnnvl.py (33 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (6 hunks)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
  • include/flashinfer/utils.cuh (2 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (9 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/utils.cuh
  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/comm/mapping.py (5)
  • Mapping (21-475)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/comm/mnnvl.py (13)
  • McastGPUBuffer (1032-1113)
  • CommBackend (145-164)
  • MPIBackend (204-225)
  • lamport_initialize (1014-1029)
  • lamport_initialize (1076-1077)
  • barrier (161-161)
  • barrier (220-221)
  • get_buffer_ptrs_dev (778-780)
  • get_buffer_ptrs_dev (1111-1113)
  • get_unicast_ptr (782-790)
  • get_unicast_ptr (1107-1109)
  • get_multicast_ptr (792-796)
  • get_multicast_ptr (1103-1105)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (29-113)
  • trtllm_mnnvl_allreduce_fusion (29-35)
🪛 GitHub Actions: pre-commit
flashinfer/comm/mnnvl.py

[warning] 62-62: pre-commit: minor formatting/consistency note after fixes (merge-conflict markers touched code).


[warning] 1-1: pre-commit: potential changes applied by formatting hook (ruff-format had modifications).

flashinfer/comm/trtllm_mnnvl_ar.py

[error] 36-36: invalid-syntax: Expected class, function definition or async function definition after decorator (likely due to merge conflict markers <<<<<<< HEAD / ======= / >>>>>>> ...)


[error] 169-169: invalid-syntax: Duplicate merge conflict marker present (HEAD) or conflict resolution remnants.


[error] 314-314: invalid-syntax: Duplicate merge conflict marker present (HEAD) or conflict lines not resolved.

tests/comm/test_trtllm_mnnvl_allreduce.py

[error] 29-29: invalid-syntax: Duplicate merge conflict marker present (HEAD) or unresolved conflict markers in test file.


[error] 1-1: Parse error: file appears to contain merge conflict markers from an unresolved conflict.

🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py

393-393: Avoid specifying long messages outside the exception class

(TRY003)


533-533: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


558-558: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


586-586: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


586-586: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


586-586: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


598-598: Avoid specifying long messages outside the exception class

(TRY003)


620-620: Duplicate parameter "comm_backend_for_handle_transfer"

(invalid-syntax)


672-672: Avoid specifying long messages outside the exception class

(TRY003)


692-692: Avoid specifying long messages outside the exception class

(TRY003)


750-750: Do not catch blind exception: Exception

(BLE001)


817-817: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


832-832: Do not catch blind exception: Exception

(BLE001)

flashinfer/comm/trtllm_mnnvl_ar.py

36-36: Expected class, function definition or async function definition after decorator

(invalid-syntax)


36-36: Expected a statement

(invalid-syntax)


36-36: Expected a statement

(invalid-syntax)


36-36: Expected a statement

(invalid-syntax)


37-37: Unexpected indentation

(invalid-syntax)


37-38: Expected an indented block after function definition

(invalid-syntax)


38-38: Expected a statement

(invalid-syntax)


38-38: Expected a statement

(invalid-syntax)


38-38: Expected a statement

(invalid-syntax)


38-38: Expected a statement

(invalid-syntax)


38-39: Expected a statement

(invalid-syntax)


39-39: Unexpected indentation

(invalid-syntax)


41-42: Expected an indented block after function definition

(invalid-syntax)


42-42: Expected a statement

(invalid-syntax)


42-42: Expected a statement

(invalid-syntax)


42-42: Expected a statement

(invalid-syntax)


42-42: Expected a statement

(invalid-syntax)


42-42: Expected ,, found name

(invalid-syntax)


42-42: Expected ,, found name

(invalid-syntax)


42-42: Expected an identifier

(invalid-syntax)


43-43: Unexpected indentation

(invalid-syntax)


51-51: Expected a statement

(invalid-syntax)


169-169: Expected a statement

(invalid-syntax)


169-169: Expected a statement

(invalid-syntax)


169-169: Expected a statement

(invalid-syntax)


169-169: Expected a statement

(invalid-syntax)


170-170: Unexpected indentation

(invalid-syntax)


173-174: Expected an indented block after if statement

(invalid-syntax)


174-174: Expected a statement

(invalid-syntax)


174-174: Expected a statement

(invalid-syntax)


174-174: Expected a statement

(invalid-syntax)


174-174: Expected a statement

(invalid-syntax)


174-175: Expected a statement

(invalid-syntax)


175-175: Unexpected indentation

(invalid-syntax)


180-181: Expected an indented block after if statement

(invalid-syntax)


181-181: Expected a statement

(invalid-syntax)


181-181: Expected a statement

(invalid-syntax)


181-181: Expected a statement

(invalid-syntax)


181-181: Expected a statement

(invalid-syntax)


181-181: Expected ,, found name

(invalid-syntax)


181-181: Expected ,, found name

(invalid-syntax)


181-181: Expected an identifier

(invalid-syntax)


183-183: Unexpected indentation

(invalid-syntax)


184-184: unindent does not match any outer indentation level

(invalid-syntax)


184-184: Expected a statement

(invalid-syntax)


184-184: Expected a statement

(invalid-syntax)


184-185: Expected a statement

(invalid-syntax)


187-187: Unexpected indentation

(invalid-syntax)


188-188: unindent does not match any outer indentation level

(invalid-syntax)


314-314: Expected a statement

(invalid-syntax)


314-314: Expected a statement

(invalid-syntax)


314-314: Expected a statement

(invalid-syntax)


314-314: Expected a statement

(invalid-syntax)


315-315: Unexpected indentation

(invalid-syntax)


319-319: Expected a statement

(invalid-syntax)


319-319: Expected a statement

(invalid-syntax)


319-319: Expected a statement

(invalid-syntax)


319-319: Expected a statement

(invalid-syntax)


319-320: Expected a statement

(invalid-syntax)


320-320: Unexpected indentation

(invalid-syntax)


332-332: Expected a statement

(invalid-syntax)


332-332: Expected a statement

(invalid-syntax)


332-332: Expected a statement

(invalid-syntax)


332-332: Expected a statement

(invalid-syntax)


332-332: Expected ,, found name

(invalid-syntax)


332-332: Expected ,, found name

(invalid-syntax)


332-332: Expected an identifier

(invalid-syntax)


333-333: Unexpected indentation

(invalid-syntax)


354-354: Expected a statement

(invalid-syntax)

tests/comm/test_trtllm_mnnvl_allreduce.py

29-29: Expected a statement

(invalid-syntax)


29-29: Expected a statement

(invalid-syntax)


29-29: Expected a statement

(invalid-syntax)


29-29: Expected a statement

(invalid-syntax)


30-30: Unexpected indentation

(invalid-syntax)


35-35: Expected a statement

(invalid-syntax)


35-35: Expected a statement

(invalid-syntax)


35-35: Expected a statement

(invalid-syntax)


35-35: Expected a statement

(invalid-syntax)


35-36: Expected a statement

(invalid-syntax)


36-36: Unexpected indentation

(invalid-syntax)


37-37: Expected a statement

(invalid-syntax)


37-37: Expected a statement

(invalid-syntax)


37-37: Expected a statement

(invalid-syntax)


37-37: Expected a statement

(invalid-syntax)


37-37: Expected ,, found name

(invalid-syntax)


37-37: Expected ,, found name

(invalid-syntax)


37-37: Expected an identifier

(invalid-syntax)


39-39: Unexpected indentation

(invalid-syntax)


114-114: Expected a statement

(invalid-syntax)


325-325: Expected a statement

(invalid-syntax)


325-325: Expected a statement

(invalid-syntax)


325-325: Expected a statement

(invalid-syntax)


325-325: Expected a statement

(invalid-syntax)


326-326: Unexpected indentation

(invalid-syntax)


329-329: unindent does not match any outer indentation level

(invalid-syntax)


330-330: Expected a statement

(invalid-syntax)


330-330: Expected a statement

(invalid-syntax)


330-330: Expected a statement

(invalid-syntax)


330-330: Expected a statement

(invalid-syntax)


330-331: Expected a statement

(invalid-syntax)


331-331: Unexpected indentation

(invalid-syntax)


337-337: Expected a statement

(invalid-syntax)


337-337: Expected a statement

(invalid-syntax)


337-337: Expected a statement

(invalid-syntax)


337-337: Expected a statement

(invalid-syntax)


337-337: Expected ,, found name

(invalid-syntax)


337-337: Expected ,, found name

(invalid-syntax)


337-337: Expected an identifier

(invalid-syntax)


338-338: Unexpected indentation

(invalid-syntax)


382-383: Expected an indented block after if statement

(invalid-syntax)


383-383: Expected except or finally after try block

(invalid-syntax)


383-383: Expected a statement

(invalid-syntax)


383-383: Expected a statement

(invalid-syntax)


383-383: Expected a statement

(invalid-syntax)


384-384: Unexpected indentation

(invalid-syntax)


388-388: unindent does not match any outer indentation level

(invalid-syntax)


393-393: Unexpected indentation

(invalid-syntax)


413-413: Expected a statement

(invalid-syntax)


413-413: Expected a statement

(invalid-syntax)


413-413: Expected a statement

(invalid-syntax)


413-413: Expected a statement

(invalid-syntax)


413-414: Expected a statement

(invalid-syntax)


414-414: Unexpected indentation

(invalid-syntax)


417-417: unindent does not match any outer indentation level

(invalid-syntax)


434-434: unindent does not match any outer indentation level

(invalid-syntax)


434-434: Expected a statement

(invalid-syntax)


434-434: Expected a statement

(invalid-syntax)


434-435: Expected a statement

(invalid-syntax)


435-435: Unexpected indentation

(invalid-syntax)


447-447: unindent does not match any outer indentation level

(invalid-syntax)


449-449: Unexpected indentation

(invalid-syntax)


451-451: unindent does not match any outer indentation level

(invalid-syntax)


451-451: Expected a statement

(invalid-syntax)


451-451: Expected a statement

(invalid-syntax)


451-452: Expected an expression

(invalid-syntax)


452-452: Unexpected indentation

(invalid-syntax)


472-472: unindent does not match any outer indentation level

(invalid-syntax)


472-472: Expected a statement

(invalid-syntax)


472-472: Expected a statement

(invalid-syntax)


472-473: Expected a statement

(invalid-syntax)


474-474: Unexpected indentation

(invalid-syntax)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)

29-112: FFI wrapper validation for fusion/RMSNorm is solid

The new trtllm_mnnvl_allreduce_fusion wrapper wires all parameters into AllReduceFusionParams correctly and adds the right guards:

  • token_dim alignment to float4-based packing
  • nranks/rank range checks
  • mandatory presence of residual_in, residual_out, gamma, and epsilon when rmsnorm_fusion is true
  • shape checks for residual tensors and gamma

This closes the gap where fused kernels could be launched with missing or mismatched residual inputs.

include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (1)

32-1219: Lamport-based fusion kernels and dispatch look structurally sound

The header refactor introduces:

  • A richer AllReduceFusionParams tailored for fusion (buffer pointers, multicast pointer, flags, RMSNorm toggle, PDL toggle).
  • A focused utils namespace with Lamport buffer layout/flags, packed-vector helpers, warp/block reduction helpers, and adjustGridConfig that uses GetCudaMultiProcessorCount to pick reasonable block/cluster/load settings.
  • New oneshotAllreduceFusionKernel and twoshotAllreduceKernel specializations with explicit world-size dispatch (2–64) and an optional fused RMSNorm path via rmsNormLamport.

From a structural standpoint, the buffer layout and synchronization logic (Lamport flags, negative-zero sentinels, waitAndUpdate/clearDirtyLamportBuf, and cluster-aware reductions) are coherent and match established TensorRT-LLM patterns. The error paths for unsupported world sizes and loads-per-thread are clear and return CUDA errors instead of silently misbehaving.

Comment on lines 293 to 297
def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined]
comm = config.comm_backend.Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank)
MnnvlMemory.comm = comm # type: ignore[assignment]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Bug in set_comm_from_config when config is None

set_comm_from_config assigns MnnvlMemory.config = config or MnnvlConfig(...) but still uses the original config variable to access comm_backend. When config is None, this will raise AttributeError.

Suggested fix:

@staticmethod
def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
-    MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
-    comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank)
+    config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
+    MnnvlMemory.config = config
+    comm = config.comm_backend.Split(
+        mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
+    )
    MnnvlMemory.comm = comm  # type: ignore[assignment]
🤖 Prompt for AI Agents
In flashinfer/comm/mnnvl.py around lines 293 to 297, set_comm_from_config
assigns MnnvlMemory.config = config or MnnvlConfig(...) but then still uses the
original config variable to access comm_backend, which will raise AttributeError
when config is None; fix by computing an effective_config = config or
MnnvlConfig(comm_backend=MPIBackend()), assign MnnvlMemory.config =
effective_config, then call effective_config.comm_backend.Split(...) and assign
the returned comm to MnnvlMemory.comm so the fallback config is actually used
when config is None.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)

566-664: Still leaking the local POSIX FD in _alloc_mn_mcast_mem’s IPC path.

The POSIX-FD flow now:

  • Uses IpcSocket + send_fd / recv_fd for allgather/bcast.
  • Closes imported FDs after cuMemImportFromShareableHandle.
  • Closes _ipc_socket in __del__.

However, local_shareable_uc_handle (the FD returned by cuMemExportToShareableHandle) is never closed in the POSIX path after the ring allgather finishes. Over long runs this will accumulate FDs on each rank.

You already close:

  • all_shareable_uc_handles[p] after import; and
  • shareable_mc_handle after multicast import.

Please also close the local exported FD after the ring completes, e.g.:

        if (
            self._shareable_handle_type
            == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
        ):
            # All-gather fabric handles
            all_shareable_uc_handles = self.comm_backend.allgather(
                local_shareable_uc_handle.data
            )
        else:
            # Implement the allgather logic with ipc socket
            all_shareable_uc_handles = [None] * self.group_size
            for i in range(self.group_size):
                self.comm_backend.barrier()
                # Send to peer at offset i
                dest_rank = (self.group_rank + i) % self.group_size
                self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
                # Receive from peer at offset -i
                src_rank = (self.group_rank + self.group_size - i) % self.group_size
                all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
-        cuda.cuCtxSynchronize()
+        cuda.cuCtxSynchronize()
+        if (
+            self._shareable_handle_type
+            == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+        ):
+            os.close(local_shareable_uc_handle)

This complements the existing closes for imported FDs and multicast FDs and avoids descriptor leaks.

Also applies to: 893-901, 965-993, 1004-1009, 1058-1063

flashinfer/comm/trtllm_mnnvl_ar.py (1)

344-379: Restore RMSNorm epsilon default to 1e-5 instead of torch.finfo(...).eps.

Using torch.finfo(input.dtype).eps when epsilon is None changes RMSNorm behavior materially, especially for fp16 (eps ~1e-3), and diverges from the CUDA kernel’s built-in default of 1e-5 used by TensorRT-LLM. This was previously called out and agreed to be corrected; the current code reintroduces the same behavior.

To keep parity with the C++ kernel and prior APIs, prefer a fixed 1e-5 default:

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        # Match TensorRT-LLM kernel default for RMSNorm
+        epsilon = 1e-5

Also update the docstring so it no longer claims that torch.finfo.eps is used by default.

🧹 Nitpick comments (3)
flashinfer/comm/mnnvl.py (2)

132-149: Clarify empty-input behavior and return typing for alloc_and_copy_to_cuda.

alloc_and_copy_to_cuda is annotated to return int but returns None when host_ptr_array is empty. That mismatch can confuse callers and type-checkers.

Consider either:

  • Making the return type Optional[int] and documenting that None means “no allocation performed”, or
  • Raising on empty input (if that’s never expected here), or
  • Returning 0 as a sentinel pointer value instead of None.

630-645: Clean up unused recvmsg outputs in IpcSocket.recv_fd.

msg, flags, and addr from self.sock.recvmsg(...) are never used, which also triggers Ruff’s RUF059 warnings.

You can keep the call but acknowledge the unused values to improve clarity and silence the linter:

-        fds = array.array("i")
-        msg, ancdata, flags, addr = self.sock.recvmsg(
+        fds = array.array("i")
+        _msg, ancdata, _flags, _addr = self.sock.recvmsg(
             1,
             socket.CMSG_SPACE(
                 fds.itemsize
             ),  # Buffer size for dummy data  # Ancillary data size
         )
flashinfer/comm/trtllm_mnnvl_ar.py (1)

30-47: Strategy selection and workspace sizing logic are sound; reconsider caching on instance method.

  • MNNVLAllreduceFusionStrategy.select_strategy with MNNVL_ONE_SHOT_THRESHOLD and elem_size is straightforward and keeps the heuristic in bytes.
  • get_required_buffer_size_bytes mirrors the ONESHOT vs TWOSHOT layouts and is already cached as a @staticmethod.

Given that, is_buffer_size_sufficient is a very light wrapper over the static helper but is also decorated with @functools.cache on an instance method, which can accumulate entries keyed by self (B019 concern) without much benefit.

You can simplify and avoid the method-level cache by dropping @functools.cache there:

-    @functools.cache
     def is_buffer_size_sufficient(
         self,
         tp_size: int,
         num_tokens: int,
         hidden_dim: int,
         dtype: torch.dtype,
         strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
     ) -> bool:

get_required_buffer_size_bytes will still provide caching for the heavy calculation.

Also applies to: 134-182

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 815aaf3 and 68a9b9b.

📒 Files selected for processing (2)
  • flashinfer/comm/mnnvl.py (19 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (6)
  • Mapping (21-475)
  • rank (311-312)
  • rank (315-322)
  • tp_rank (325-326)
  • local_rank (391-392)
  • is_multi_node (403-404)
flashinfer/jit/comm.py (1)
  • gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/utils.py (2)
  • register_custom_op (314-323)
  • register_custom_op (333-352)
flashinfer/comm/mnnvl.py (9)
  • McastGPUBuffer (1143-1224)
  • lamport_initialize (1123-1140)
  • lamport_initialize (1183-1184)
  • get_buffer_ptrs_dev (857-859)
  • get_buffer_ptrs_dev (1222-1224)
  • get_unicast_ptr (861-869)
  • get_unicast_ptr (1218-1220)
  • get_multicast_ptr (871-875)
  • get_multicast_ptr (1214-1216)
csrc/trtllm_mnnvl_allreduce.cu (2)
  • trtllm_mnnvl_allreduce_fusion (29-113)
  • trtllm_mnnvl_allreduce_fusion (29-35)
🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


896-896: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

flashinfer/comm/trtllm_mnnvl_ar.py

77-79: Avoid specifying long messages outside the exception class

(TRY003)


134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)


298-300: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


319-321: Avoid specifying long messages outside the exception class

(TRY003)


381-383: Avoid specifying long messages outside the exception class

(TRY003)


385-387: Avoid specifying long messages outside the exception class

(TRY003)


389-391: Avoid specifying long messages outside the exception class

(TRY003)


395-397: Avoid specifying long messages outside the exception class

(TRY003)


401-403: Avoid specifying long messages outside the exception class

(TRY003)


414-416: Avoid specifying long messages outside the exception class

(TRY003)


544-546: Avoid specifying long messages outside the exception class

(TRY003)


615-617: Avoid specifying long messages outside the exception class

(TRY003)


621-623: Avoid specifying long messages outside the exception class

(TRY003)


626-628: Avoid specifying long messages outside the exception class

(TRY003)


630-632: Avoid specifying long messages outside the exception class

(TRY003)


635-637: Avoid specifying long messages outside the exception class

(TRY003)


640-642: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
flashinfer/comm/mnnvl.py (1)

717-718: Handle-type selection, workspace getters, and Lamport initialization look consistent.

  • _shareable_handle_type cleanly distinguishes FABRIC vs POSIX-FD paths and is threaded through allocation/multicast props.
  • New get_allocation_size / get_usable_buffer_size and the update of McastGPUBuffer.buf_size to the usable portion (excluding signal pad) make the Python-facing API less surprising.
  • lamport_initialize now explicitly excludes the signal pad region, and using FP32 there matches the earlier rationale about word-aligned LDG.128 sentinels.
  • get_multicast_buffer / get_unicast_buffer remaining as NotImplementedError stubs is a reasonable placeholder until you wire them via create_tensor_from_cuda_memory; callers will fail fast instead of silently misusing raw pointers.

Also applies to: 744-766, 781-790, 885-892, 1135-1137, 1164-1181, 1186-1213, 1218-1221

flashinfer/comm/trtllm_mnnvl_ar.py (4)

266-341: New trtllm_mnnvl_allreduce workspace size check and validation look good.

  • 2D shape validation for input / output is explicit and user-friendly.
  • Strategy selection falls back to AUTO and uses the same select_strategy heuristic as the workspace utilities.
  • Calling workspace.is_buffer_size_sufficient(...) and raising with a detailed message (including current vs required bytes) restores the missing guard against buffer overrun that was raised in earlier reviews.

This aligns the Python API’s safety guarantees with the underlying kernel expectations.


438-490: Deprecation wrapper for get_allreduce_mnnvl_workspace preserves behavior.

Redirecting the legacy workspace helper to MNNVLAllreduceFusionWorkspace keeps the public return type (McastGPUBuffer, flags tensor, max_num_elements) while converging all allocation through the new path and flag layout.

The use of stride and workspace.buffer_size_bytes // stride to compute max_num_elements is consistent with the original intent.


575-662: Legacy fused RMSNorm wrapper looks consistent with the new fused kernel.

  • Shape checks for shard_input, residual, gamma, and both outputs are thorough and mirror the new high-level API.
  • The buffer_M check preserves the legacy guard against oversized inputs.
  • Calling trtllm_mnnvl_allreduce_fusion with rmsnorm_fusion=True, use_oneshot=False, and wiring normed_output / prenorm_output / residual / gamma / epsilon matches the intended “two-shot + RMSNorm” behavior.

Given the function is now deprecated in favor of trtllm_mnnvl_fused_allreduce_add_rmsnorm, this looks like a safe compatibility bridge.


498-569: Unfortunately, the repository clone failed, which prevents me from directly inspecting the CUDA kernel implementation to verify the core concern about buffer_ptr_local usage.

The review comment raises a valid architectural concern about the contract between the deprecated Python wrapper and the underlying CUDA kernel:

  • The wrapper passes buffer_ptr_local=0 with an assumption that the kernel "does not use this local pointer" when rmsnorm_fusion=false
  • Without access to the CUDA kernel source (csrc/trtllm_mnnvl_allreduce.cu), I cannot definitively verify whether this assumption is correct
  • If the assumption is wrong, this represents a latent bug; if correct, the concern is more about defensive programming

Since I cannot verify the technical facts needed to assess this concern (i.e., whether buffer_ptr_local is genuinely unused in the non-RMSNorm code path), I must preserve the original review comment as-is.


Legacy trtllm_mnnvl_all_reduce wrapper relies on buffer_ptr_local = 0.

The deprecation wrapper now forwards to trtllm_mnnvl_allreduce_fusion and passes buffer_ptr_local=0 with a comment that the allreduce kernel "does not use this local pointer".

That contract is subtle and brittle: if the CUDA kernel ever starts using buffer_ptr_local even in the non-RMSNorm path, this wrapper could cause a hard-to-debug crash.

If possible, consider:

  • Threading through the actual local unicast pointer (e.g., via a workspace or by evolving callers), or
  • Adding a clear assertion / comment in the CUDA side that buffer_ptr_local is ignored when rmsnorm_fusion == false, to document the invariant this wrapper depends on.

[warn-only as it's legacy and deprecated, but worth making the dependency explicit.]

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)

406-425: Potential deadlock: allgather called only by failing ranks.

The MPI.COMM_WORLD.allgather(rank_failed) call is inside the except block, so only ranks that throw an exception will participate. If rank 0 fails while rank 1 succeeds, rank 1 continues to the barrier at line 433 while rank 0 blocks on allgather—causing a distributed deadlock.

Consider moving failure synchronization outside the exception handler:

     except Exception as e:
         rank_failed = True
         failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
         print(failure_message)
         print(traceback.format_exc())
 
-        # Gather failure status from all ranks for logging
-        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
-
-        if any(all_failures):
-            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
-            if rank == 0:
-                print(f"Test failed on ranks: {failed_ranks}")
-
         # Cleanup before re-raising
         if "workspace" in locals():
             del workspace
 
         # Re-raise the original exception so it can be caught by pytest.raises in negative tests
         raise
 
     finally:
         # Ensure cleanup happens for this list's workspace
         if "workspace" in locals():
             del workspace
+        
+        # Gather failure status from all ranks for logging (must be outside except to avoid deadlock)
+        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
+        if any(all_failures):
+            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
+            if rank == 0:
+                print(f"Test failed on ranks: {failed_ranks}")

Note: With this change, the raise at line 425 would need to be moved after the finally block completes, or handled differently to ensure all ranks synchronize before any rank exits.

♻️ Duplicate comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)

257-264: Epsilon mismatch between reference computation and kernel execution.

The reference RMSNorm uses torch.finfo(dtype).eps (line 263), which is ~6e-8 for float16 and ~1e-7 for bfloat16, while the actual kernel execution uses eps = 1e-5 (line 322). This 100-1000x difference in epsilon values will produce different numerical results between reference and actual outputs.

Additionally, line 257 has a type annotation issue: Tuple[torch.Tensor, ...] assigned to None.

-    reference_output: Tuple[torch.Tensor, ...] = None
+    reference_output: Optional[Tuple[torch.Tensor, ...]] = None
     if fusion:
         # Fused case: AllReduce + Residual Add + RMS Norm
         allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
         residual_out = allreduce_result + residual  # Add residual
         norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
+            residual_out, norm_weight, 1e-5, enable_pdl=False
         )
🧹 Nitpick comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)

60-69: Remove unnecessary tensor allocation.

Line 61 allocates output = torch.empty_like(input) but it's immediately overwritten by the return value of trtllm_mnnvl_allreduce on line 63. This wastes GPU memory unnecessarily.

         else:
-            output = torch.empty_like(input)
-
             output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce(
                 input,
                 workspace,
                 launch_with_pdl=use_pdl,
                 strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
             )
             return (output.view(shape),)

274-276: Unused monkeypatch parameter.

The monkeypatch fixture is passed but never used in the function body. If it's not needed, consider removing it from the signature and test function parameters.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 68a9b9b and 9e11752.

📒 Files selected for processing (1)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • tp_rank (325-326)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
  • MNNVLAllreduceFusionWorkspace (50-181)
  • MNNVLAllreduceFusionStrategy (30-43)
  • trtllm_mnnvl_allreduce (266-341)
  • get_allreduce_mnnvl_workspace (442-495)
  • get_required_buffer_size_bytes (156-181)
flashinfer/comm/mnnvl.py (2)
  • barrier (168-168)
  • barrier (227-228)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)

105-124: LGTM!

The legacy API test helper correctly passes buffer pointers and uses MPI barrier for synchronization.


439-465: LGTM!

Good test coverage with comprehensive parameterization across sequence lengths, fusion modes, data types, and hidden sizes. The separation between refactored and legacy API tests is well-structured.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)

132-149: Clarify empty-input behavior in alloc_and_copy_to_cuda

alloc_and_copy_to_cuda is annotated to return int but returns None when host_ptr_array is empty (Line 136). That mismatch can surface as a runtime bug if an empty list ever slips through.

Consider failing fast instead of returning None, e.g.:

 def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
-    """
-    A helper function that allocates memory on cuda and copies the data from the host to the device.
-    """
-    if not host_ptr_array:
-        return None
+    """
+    A helper function that allocates memory on cuda and copies the data from the host to the device.
+    """
+    if not host_ptr_array:
+        raise ValueError("host_ptr_array must be non-empty")

This keeps the return type consistent and surfaces misuse early.

♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (2)

300-305: Fix set_comm_from_config when config is None

When config is None, you assign the fallback to MnnvlMemory.config but still call config.comm_backend.Split(...) (Line 302), which will raise AttributeError.

Use the effective config for both the assignment and the Split call:

 @staticmethod
 def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
-    MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
-    comm = config.comm_backend.Split(
-        mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
-    )
+    config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
+    MnnvlMemory.config = config
+    comm = config.comm_backend.Split(
+        mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+        mapping.tp_rank,
+    )
     MnnvlMemory.comm = comm  # type: ignore[assignment]

This matches the earlier suggested fix and ensures the fallback path actually works.

#!/bin/bash
# Simple sanity check: search for other direct uses of bare `config` in this method.
rg -n "set_comm_from_config" flashinfer/comm/mnnvl.py -n -C5

720-759: Close local POSIX shareable handles to avoid slow FD leaks

The new exchanger abstraction correctly closes imported POSIX FDs via PosixFDHandleExchanger.cleanup() (Lines 760–762), but two categories of FDs still remain unclosed in the POSIX path:

  1. local_shareable_uc_handle produced by cuMemExportToShareableHandle (Lines 1065–1072) is never passed through cleanup.
  2. For the multicast handle, rank 0’s shareable_mc_handle (Lines 1117–1125) is broadcast but only non‑root ranks call cleanup after import (Lines 1135–1141).

Over long‑running runs that create/destroy McastDeviceMemory repeatedly, these will accumulate OS FDs.

You can reuse the exchanger cleanup hook to fix both without special‑casing for handle type:

        # All-gather shareable handles
        all_shareable_uc_handles = self._exchanger.allgather(local_shareable_uc_handle)
        cuda.cuCtxSynchronize()

        # Import remote handles
        for p in range(self.group_size):
            if p != self.group_rank:
                self.uc_handles[p] = checkCudaErrors(
                    cuda.cuMemImportFromShareableHandle(
                        all_shareable_uc_handles[p],
                        self._exchanger.handle_type,
                    )
                )
                self._exchanger.cleanup(all_shareable_uc_handles[p])
+
+        # We no longer need our own exported shareable handle.
+        self._exchanger.cleanup(local_shareable_uc_handle)

And in _setup_multicast:

        # Broadcast multicast handle from rank 0
        shareable_mc_handle = self._exchanger.broadcast(shareable_mc_handle, root=0)
        cuda.cuCtxSynchronize()

-        # Import multicast handle for non-root ranks
-        if self.group_rank != 0:
+        # Import multicast handle for non-root ranks
+        if self.group_rank != 0:
             self.mc_handle = checkCudaErrors(
                 cuda.cuMemImportFromShareableHandle(
                     shareable_mc_handle,
                     self._exchanger.handle_type,
                 )
             )
             self._exchanger.cleanup(shareable_mc_handle)
+        else:
+            # Root rank can now drop its exported handle (POSIX FD path).
+            self._exchanger.cleanup(shareable_mc_handle)

FabricHandleExchanger.cleanup() is a no‑op, so these calls are harmless in the fabric path and fix the leak in the POSIX‑FD path.

#!/bin/bash
# Check remaining uses of cuMemExportToShareableHandle and ensure every handle is cleaned up.
rg -n "cuMemExportToShareableHandle" flashinfer/comm/mnnvl.py -n -C5

Also applies to: 760-765, 1065-1076, 1079-1088, 1129-1142

🧹 Nitpick comments (3)
flashinfer/comm/mnnvl.py (3)

566-664: Tidy IpcSocket.recv_fd unused variables and document /tmp path usage

recv_fd currently binds msg, flags, and addr but never uses them (Line 640), which Ruff flags (RUF059). You can keep the signature while silencing the warning:

-        fds = array.array("i")
-        msg, ancdata, flags, addr = self.sock.recvmsg(
+        fds = array.array("i")
+        _msg, ancdata, _flags, _addr = self.sock.recvmsg(
             1,
             socket.CMSG_SPACE(
                 fds.itemsize
             ),  # Buffer size for dummy data  # Ancillary data size
         )

On the S108 /tmp/mcastmem-socket- warning: since you default to use_abstract=True, the filesystem path is never actually used in normal operation. If you expect use_abstract=False in multi‑tenant environments, consider adding a brief comment noting that this path is intended for trusted deployments only.


1003-1013: Consider narrowing the broad Exception catch in _verify_cuda_context

_verify_cuda_context currently catches a bare Exception (Lines 1005–1012). Given this is only used for diagnostics, that’s acceptable, but narrowing to CUDA‑related exceptions (or at least logging the exception type) would align better with BLE001 and avoid swallowing unrelated programming errors.

Not strictly required, but worth considering:

-        except Exception as e:
-            print(f"Error checking CUDA context: {e}")
+        except Exception as e:
+            # Broad catch is intentional: any CUDA context error should only warn here.
+            print(f"Error checking CUDA context: {type(e).__name__}: {e}")

1242-1268: Explicitly mark get_multicast_buffer / get_unicast_buffer as internal or implement via the new helper

Both get_multicast_buffer and get_unicast_buffer currently raise NotImplementedError (Lines 1257–1258 and 1267–1268). Given prior discussion that this class is internal, that’s acceptable, but the public‑sounding names can confuse users if they discover them.

Two options:

  • If they are truly internal placeholders, consider prefixing with _ or clarifying in the docstring that they are not expected to be called yet.
  • If you want them usable, wire them up through create_tensor_from_cuda_memory at the top of the file, e.g., using self.get_multicast_ptr() / self.get_unicast_ptr(rank) plus the desired shape/dtype.

Happy to sketch an implementation using create_tensor_from_cuda_memory if you plan to expose these in this PR.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9e11752 and 4a5faef.

📒 Files selected for processing (1)
  • flashinfer/comm/mnnvl.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/mnnvl.py (2)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
flashinfer/utils.py (1)
  • round_up (631-633)
🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py

587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"

(S108)


640-640: Unpacked variable msg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable flags is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


640-640: Unpacked variable addr is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


656-656: Avoid specifying long messages outside the exception class

(TRY003)


726-726: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


1011-1011: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/mnnvl.py (2)

878-887: Allocation sizing and usable buffer exposure look consistent

get_allocation_size and get_usable_buffer_size (Lines 982–989) align with lamport_initialize, which uses allocation_size - SIGNAL_PAD_SIZE (Lines 1191–1193), and McastGPUBuffer now surfaces this usable size via self.buf_size (Lines 1235–1237).

This keeps the Python view of capacity in sync with what Lamport initialization actually touches, which is a nice improvement over passing the raw requested size around.

Also applies to: 982-989, 1191-1193, 1235-1237


164-166: I'm unable to access the repository due to persistent cloning issues. However, based on the information provided in the review comment itself, I can provide analysis:

The review comment acknowledges uncertainty with the phrase "if any exist in this repo", which suggests:

  1. The reviewer was aware that there may not be other CommBackend implementations
  2. The reviewer already identified MPIBackend as implementing bcast (lines 224-226)
  3. The concern is conditional on whether other subclasses exist

Without direct repository access to verify all CommBackend subclasses and their implementations, I cannot definitively confirm or refute the original review comment's concern.

Recommended next steps:

  • Manually search the repository for classes inheriting from CommBackend
  • Verify each implementation includes the bcast method
  • Or provide confirmation if only MPIBackend exists as a CommBackend subclass

Since I cannot complete the verification due to technical limitations:


New bcast abstract requires all CommBackend implementations to be updated

Adding CommBackend.bcast(...) (Lines 164–166) and implementing it in MPIBackend (Lines 224–226) is correct, but any other CommBackend subclasses must also implement bcast to avoid TypeError at instantiation. Verify that all subclasses (if any exist) have been updated accordingly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants