Refactor trtllm_mnnvl_allreduce #2118
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Replaces the legacy MNNVL all-reduce with a fused Lamport-buffer allreduce exposed as `trtllm_mnnvl_allreduce_fusion`.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App as Application
    participant PyAPI as Python API / Workspace
    participant Strategy as Strategy Selector
    participant Comm as Comm Backend (MPI / IpcSocket)
    participant Kernel as CUDA Kernel (trtllm_mnnvl_allreduce_fusion)
    participant Buff as Buffers / Output
    App->>PyAPI: call trtllm_mnnvl_allreduce(...) or fused API
    PyAPI->>PyAPI: validate inputs, prepare workspace & outputs
    PyAPI->>Strategy: select ONESHOT / TWOSHOT (AUTO inspects workspace/problem)
    PyAPI->>Comm: exchange/share handles (MPI bcast/barrier or IpcSocket FD exchange)
    PyAPI->>Kernel: invoke trtllm_mnnvl_allreduce_fusion(params)
    rect rgb(245,250,255)
    Kernel->>Kernel: lamport-stage broadcast & per-token reduction
    alt RMSNorm fusion enabled
    Kernel->>Kernel: compute RMS, apply gamma, add residuals
    end
    Kernel->>Buff: write output (and residual_out if present)
    end
    Buff-->>PyAPI: return tensor(s)
    PyAPI-->>App: deliver result(s)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 failed check (warning), ✅ 2 passed checks.
Summary of Changes

Hello @timlee0212, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a comprehensive refactoring of the multi-node NVLink (MNNVL) all-reduce system within FlashInfer. It unifies the all-reduce and RMSNorm operations into a single, highly configurable C++ kernel, exposed through intuitive new Python APIs. A key improvement is the new workspace management class, which automates and optimizes buffer allocation. Furthermore, the PR adds crucial support for IPC Socket-based handle transfer, broadening compatibility to hardware environments like DGX machines. These changes collectively enhance the flexibility, performance, and overall robustness of distributed computations.

Highlights
Code Review
This pull request is a significant refactoring of the MNNVL all-reduce implementation, introducing a new, cleaner API with a dedicated workspace manager class, and adding support for IPC sockets for single-node communication. The changes are extensive and substantially improve the code's structure and capabilities. My review focuses on ensuring backward compatibility is fully maintained as intended, removing leftover debug code, improving memory usage efficiency, adding a critical safety check for buffer sizes in the new API, and suggesting a minor precision improvement in a CUDA kernel.
```python
def trtllm_mnnvl_allreduce(
    input: torch.Tensor,
    workspace: MNNVLAllreduceFusionWorkspace,
    launch_with_pdl: bool,
    output: Optional[torch.Tensor] = None,
    strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
```
The new `trtllm_mnnvl_allreduce` function is missing a check to ensure that the input tensor fits within the allocated workspace. The old API had a check like `if inp.shape[0] > buffer_M: raise ValueError(...)`. A similar check should be added here to prevent potential out-of-bounds memory access, which could lead to crashes or incorrect results. The required buffer size depends on the strategy (one-shot vs. two-shot) and can be calculated using `MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes`.
Do we want this check on the execution path? Or should we assume it is the user's responsibility to ensure it does not overflow?
We do want this check. I recently added it because it did bite others.
Added.
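The bounds check agreed on above can be modeled as a standalone sketch. Everything below is hypothetical: the real helper is `MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes`, and the one-shot sizing formula used here (every rank's shard of every token is staged, so the buffer scales with world size) is only an illustrative assumption, not the library's actual arithmetic.

```python
def required_buffer_size_bytes(num_tokens: int, token_dim: int,
                               dtype_bytes: int, world_size: int) -> int:
    # Assumed one-shot sizing; the real formula lives in
    # MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes.
    return num_tokens * token_dim * dtype_bytes * world_size


def check_workspace_capacity(num_tokens: int, token_dim: int, dtype_bytes: int,
                             world_size: int, capacity_bytes: int) -> None:
    # The check the thread asks for: fail loudly before the kernel can
    # write past the end of the Lamport buffers.
    required = required_buffer_size_bytes(num_tokens, token_dim,
                                          dtype_bytes, world_size)
    if required > capacity_bytes:
        raise ValueError(
            f"input needs {required} bytes but workspace holds "
            f"{capacity_bytes} bytes"
        )


# 128 tokens x 4096 dim x 2-byte dtype x 8 ranks = 8 MiB required
check_workspace_capacity(128, 4096, 2, 8, 16 * 1024 * 1024)  # fits, no error
```

Keeping the check in Python (rather than the CUDA path) keeps the hot kernel untouched while still catching the overflow that "did bite others".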
```diff
 self.mcast_device_memory.lamport_initialize(rank, dtype)

-def get_mc_buffer(
+def get_multicast_buffer(
```
This class is used internally, and left as a placeholder but not implemented. Thus, a breaking change is fine. Tag @nvmbreughe for confirmation.
agreed
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/comm/mnnvl.py (1)
132-149: `alloc_and_copy_to_cuda` return type and empty-input behavior are inconsistent

The function is annotated as returning `int` but returns `None` when `host_ptr_array` is empty. Callers currently pass non-empty lists, but this mismatch can trip type checkers and hide bugs if an empty list is ever passed. Either tighten behavior or relax the signature, for example:

- If empty input is invalid, raise:

```python
def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
    if not host_ptr_array:
        raise ValueError("host_ptr_array must be non-empty")
```

- Or, if you want the sentinel, change the annotation to `int | None` and document the `None` case.

tests/comm/test_trtllm_mnnvl_allreduce.py (1)
328-427: Move `allgather()` and final `mpi_barrier()` to a `finally` block to ensure all ranks participate in collectives

Lines 414 and 434 create a deadlock risk in error scenarios. The `allgather()` at line 414 is inside the `except` block, so only ranks that hit an exception call it. Meanwhile, the `mpi_barrier()` at line 434 is unconditionally called after `try/except/finally`. If an error occurs on some but not all ranks, failing ranks block in `allgather()` waiting for non-failing ranks that never enter the `except` block, while non-failing ranks block in the final barrier; both sides are unable to proceed.

Move the `allgather()` call and final `mpi_barrier()` to the `finally` block to ensure all ranks participate in these collective operations:

```python
rank_failed = False
try:
    ...
except Exception as e:
    rank_failed = True
    failure_message = ...
    print(failure_message)
    import traceback
    print(traceback.format_exc())
    raise
finally:
    all_failures = MPI.COMM_WORLD.allgather(rank_failed)
    if any(all_failures):
        failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
        if rank == 0:
            print(f"Test failed on ranks: {failed_ranks}")
    if "workspace" in locals():
        del workspace
    trtllm_mnnvl_ar.mpi_barrier()
```

This applies to lines 328-426 (main `try/except`) and line 434 (final barrier).
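The deadlock argument above can be modeled without MPI. In this toy sketch (all names are hypothetical), two "ranks" run as threads and a `threading.Barrier` stands in for the collective: placing the barrier in `finally` guarantees every rank reaches it whether or not its work raised, which is exactly why the non-failing ranks no longer hang.

```python
import threading


def run_ranks(n_ranks, work):
    barrier = threading.Barrier(n_ranks)  # stand-in for MPI allgather/barrier
    results = [None] * n_ranks

    def rank_fn(rank):
        failed = False
        try:
            work(rank)
        except Exception:
            failed = True
        finally:
            # Collective in `finally`: failing and non-failing ranks both
            # arrive, so nobody blocks forever waiting for the other side.
            barrier.wait(timeout=5)
            results[rank] = failed

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(n_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


def work(rank):
    if rank == 0:
        raise RuntimeError("simulated failure on rank 0")


print(run_ranks(2, work))  # [True, False]
```

If the `barrier.wait` were instead only inside the `except` handler, rank 1 would never arrive and rank 0's wait would time out, mirroring the MPI hang described in the comment.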
🧹 Nitpick comments (8)
flashinfer/comm/mnnvl.py (1)
640-655: Minor polish: unused recvmsg outputs and predictable `opId`

Two small, non-blocking cleanups:

- In `IpcSocket.recv_fd()`, the unpacked `msg`, `flags`, and `addr` from `recvmsg` are unused. Renaming them to `_msg`, `_flags`, `_addr` will make that explicit and silence linters:

```python
_msg, ancdata, _flags, _addr = self.sock.recvmsg(...)
```

- `opId` for the socket name is generated with `random.randint`. Since it's only used as a uniqueness hint (not security-sensitive), this is fine; if you want to appease S311 you could switch to `secrets.randbits(64)` or document that it's non-cryptographic.

Both are optional, but would make static analysis quieter.

Also applies to: 885-889
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)
23-25: Explicitly include `<array>` and `<tuple>`, and guard `adjustGridConfig` against `smCount == 0`

Within this header:

- `LamportBufferLayout`, `LamportFlags`, `PackedVec`, and several kernels use `std::array`.
- `adjustGridConfig` returns `std::tuple<int, int, int>` and callers use `std::get`.

But only `<type_traits>` is included; `<array>` and `<tuple>` are currently pulled in (if at all) via transitive includes, which is fragile.

Also, `adjustGridConfig` relies on `GetCudaMultiProcessorCount()`:

```cpp
int smCount = GetCudaMultiProcessorCount();
while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { ... }
```

If `GetCudaMultiProcessorCount()` ever returns 0 (e.g., CUDA error or misconfiguration), this loop will keep shrinking `clusterSize` and inflating `blockSize` in a somewhat opaque way.

Suggestions:

- Add explicit includes at the top of the header:

```cpp
#include <array>
#include <tuple>
```

- Make `adjustGridConfig` robust to a 0 or negative SM count by early-clamping:

```cpp
int smCount = GetCudaMultiProcessorCount();
if (smCount <= 0) {
  // Fall back to single-SM configuration
  clusterSize = 1;
  blockSize = std::min(threadsNeeded, 1024);
  return {blockSize, clusterSize, loadsPerThread};
}
```

This keeps the fused path predictable even if the helper cannot obtain a valid SM count.

Also applies to: 54-177, 143-163, 291-313, 348-359, 385-419, 449-497
509-651: Confirm lamport clear / wait protocol assumptions for oneshot kernel

The oneshot fused kernel uses `LamportFlags` as follows:

- Out-of-bounds threads call `ctaArrive()` then `clearDirtyLamportBuf()` and return.
- In-bounds threads:
  - write their shard into the multicast lamport buffer,
  - call `ctaArrive()` again,
  - then call `clearDirtyLamportBuf()` and spin on the Lamport buffers until all entries are non-negZero.

This protocol assumes:

- Every thread in the grid calls `clearDirtyLamportBuf()` exactly once per iteration.
- Buffer flags and `bytesToClear` are correctly initialized to match the configured `numTokens * tokenDim * WorldSize`.

Given that this is a direct Lamport port, the logic looks consistent, but the protocol is subtle. I'd recommend:

- Double-checking the initialization of `buffer_flags` in `MNNVLAllreduceFusionWorkspace` matches the expectations here (current index, dirty index, bytes per buffer, and stage counts).
- Adding a brief comment near the kernel launch documenting that `buffer_flags` must follow the `[cur, dirty, bytes_per_buffer, dirty_num_stages, bytes_to_clear[4], access_ptr]` layout used by `LamportFlags`.

No code change strictly required, but the invariants are nontrivial and worth locking down in comments/tests.
754-885: Two-shot path & RMSNorm fusion: validate world sizes and loads-per-thread bounds

The two-shot kernels and dispatchers introduce several constraints:

- `twoshotAllreduceFusionDispatch<T>` only supports `nRanks` in `{2, 4, 8, 16, 32, 64}` and enforces `tokenDim % (sizeof(float4) / sizeof(T)) == 0`.
- `rmsNormLamport` is instantiated with `LoadsPerThread` in `[1, 8]` and uses `float4` loads into shared memory; dynamic shared memory is sized as `3 * rnBlockSize * iters * sizeof(T)` and indexed accordingly.

The implementation looks coherent, but a few invariants are implicit:

- `MNNVLTwoShotStage::NUM_STAGES` must stay in sync with the `LamportFlags<float4>` usage and the two `bytes_to_clear` entries in `waitAndUpdate`.
- `rnLoadsPerThread` retrieved from `adjustGridConfig` must remain in `[1, 8]`; the `default:` branch already errors if it's out of range, which is good.
- `rnClusterSize` from `adjustGridConfig` is assumed to be `<= 8` given `__shared__ float sharedVal[8];` in the RMSNorm kernel.

Given these contracts, I'd suggest:

- Adding asserts (or comments) that `rnClusterSize <= 8` when CGA is used, to guard future changes to `adjustGridConfig`.
- Extending tests to cover the corner cases where `tokenDim` is just at or above the supported boundary (e.g., maximum hidden size and multiple world sizes) so we don't regress the `FLASHINFER_CHECK` conditions.

Functionally the code looks sound; this is mainly about making the implicit constraints explicit.

Also applies to: 898-959, 1062-1219
csrc/trtllm_mnnvl_allreduce.cu (1)
99-107: Error message still mentions "twoshot" even for oneshot path

Regardless of `use_oneshot`, the failure message says:

```cpp
TVM_FFI_ICHECK(status == cudaSuccess)
    << "twoshot_allreduce_dispatch_world_size failed with error code "
    << cudaGetErrorString(status);
```

This is slightly misleading when the oneshot dispatch is used. Consider making the message neutral (e.g., "allreduce_fusion_dispatch failed…") or switching on `use_oneshot` to provide a more accurate label. Behavior is otherwise fine.

tests/comm/test_trtllm_mnnvl_allreduce.py (2)
232-270: Use the same `eps` for reference RMSNorm as the fused kernel

In `prepare_test_data`, the fused reference path uses:

```python
norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)
```

But the actual fused kernel is driven by the `eps` argument passed into `row_linear_residual_norm_fusion_forward` (`eps = 1e-5` in `run_mnnvl_ar_full`).

To keep the reference as close as possible to the fused implementation (and not rely on loose tolerances), consider:

```python
def prepare_test_data(..., fusion: bool, eps: float):
    ...
    if fusion:
        ...
        norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False)
```

and threading `eps` through the call sites.
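To see why the eps choice matters, here is a small framework-free RMSNorm (a hypothetical stand-in for the test's `rmsnorm` helper, not the actual implementation): for small-magnitude inputs, eps = 1e-5 versus fp16's machine epsilon (2**-10 ≈ 9.77e-4) produces visibly different outputs, so a mismatched reference forces loose tolerances.

```python
import math


def rmsnorm_ref(x, gamma, eps):
    # Plain-Python RMSNorm: y_i = gamma_i * x_i / sqrt(mean(x^2) + eps)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gamma, x)]


x = [1e-2] * 4
gamma = [1.0] * 4
out_kernel_eps = rmsnorm_ref(x, gamma, 1e-5)    # kernel's built-in default
out_fp16_eps = rmsnorm_ref(x, gamma, 9.77e-4)   # ~torch.finfo(torch.float16).eps
# mean(x^2) = 1e-4, so an eps of ~1e-3 dominates the denominator
# and shrinks the normalized output by roughly 3x.
```

With these numbers the first output is about 0.95 per element and the second about 0.30, which is far outside any sane test tolerance.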
273-281: Annotate `legacy_explicit_workspace_bytes` as optional

Ruff's RUF013 warning here is valid:

```python
def run_mnnvl_ar_full(
    ...,
    legacy_explicit_workspace_bytes: int = None,
    legacy_api: bool = False,
):
```

Changing the signature to make the optionality explicit improves readability and typing:

```python
from typing import Optional

def run_mnnvl_ar_full(
    ...,
    legacy_explicit_workspace_bytes: Optional[int] = None,
    legacy_api: bool = False,
) -> None:
    ...
```

or, in Python 3.10+:

```python
legacy_explicit_workspace_bytes: int | None = None
```

flashinfer/comm/trtllm_mnnvl_ar.py (1)
203-205: Drop debug print from hot path

This unconditional debug print runs on every call and should be removed:

```diff
- print(
-     f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}"
- )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- `csrc/trtllm_mnnvl_allreduce.cu` (1 hunks)
- `flashinfer/comm/mnnvl.py` (18 hunks)
- `flashinfer/comm/trtllm_mnnvl_ar.py` (5 hunks)
- `include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh` (2 hunks)
- `include/flashinfer/utils.cuh` (1 hunks)
- `tests/comm/test_trtllm_mnnvl_allreduce.py` (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
- flashinfer/comm/mapping.py (2): `Mapping` (21-475), `tp_rank` (325-326)
- flashinfer/comm/trtllm_mnnvl_ar.py (7): `MNNVLAllreduceFusionWorkspace` (47-141), `mpi_barrier` (23-27), `trtllm_mnnvl_fused_allreduce_add_rmsnorm` (301-391), `MNNVLAllreduceFusionStrategy` (30-40), `trtllm_mnnvl_allreduce` (229-298), `get_allreduce_mnnvl_workspace` (398-451), `get_required_buffer_size_bytes` (116-141)
- flashinfer/comm/mnnvl.py (10): `barrier` (168-168, 227-228), `bcast` (165-165, 224-225), `get_multicast_ptr` (868-872, 1191-1193), `get_buffer_ptrs_dev` (854-856, 1199-1201), `get_unicast_ptr` (858-866, 1195-1197)

csrc/trtllm_mnnvl_allreduce.cu (3)
- flashinfer/comm/cuda_ipc.py (2): `cudaSetDevice` (149-150), `cudaGetErrorString` (146-147)
- csrc/tvm_ffi_utils.h (1): `get_stream` (272-274)
- flashinfer/comm/trtllm_mnnvl_ar.py (1): `trtllm_mnnvl_allreduce_fusion` (168-222)

flashinfer/comm/trtllm_mnnvl_ar.py (5)
- flashinfer/comm/mapping.py (5): `rank` (311-312, 315-322), `tp_rank` (325-326), `local_rank` (391-392), `is_multi_node` (403-404)
- flashinfer/jit/comm.py (1): `gen_trtllm_mnnvl_comm_module` (33-39)
- flashinfer/utils.py (2): `register_custom_op` (273-282, 292-311)
- flashinfer/comm/mnnvl.py (13): `McastGPUBuffer` (1121-1201), `CommBackend` (152-171), `MPIBackend` (211-232), `lamport_initialize` (1101-1118, 1160-1161), `barrier` (168-168, 227-228), `get_buffer_ptrs_dev` (854-856, 1199-1201), `get_unicast_ptr` (858-866, 1195-1197), `get_multicast_ptr` (868-872, 1191-1193)
- csrc/trtllm_mnnvl_allreduce.cu (2): `trtllm_mnnvl_allreduce_fusion` (31-109, 31-37)

flashinfer/comm/mnnvl.py (1)
- flashinfer/cuda_utils.py (1): `checkCudaErrors` (51-61)
🪛 Ruff (0.14.5)
tests/comm/test_trtllm_mnnvl_allreduce.py
- 279: PEP 484 prohibits implicit Optional; convert to `T | None` (RUF013)
flashinfer/comm/trtllm_mnnvl_ar.py
- TRY003 (Avoid specifying long messages outside the exception class) at lines 74-76, 261-263, 268-270, 338-340, 342-344, 346-348, 352-354, 358-360, 500-502, 571-573, 577-579, 582-584, 586-588, 591-593, 596-598
flashinfer/comm/mnnvl.py
- 587, 612: Probable insecure usage of temporary file or directory: `/tmp/mcastmem-socket-` (S108)
- 640: Unpacked variables `msg`, `flags`, and `addr` are never used; prefix them with an underscore or another dummy-variable pattern (RUF059)
- 656: Avoid specifying long messages outside the exception class (TRY003)
- 885: Standard pseudo-random generators are not suitable for cryptographic purposes (S311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
include/flashinfer/utils.cuh
Outdated
```cpp
inline int GetCudaMultiProcessorCount() {
  static int sm_count = 0;
  if (sm_count == 0) {
    int device_id;
    cudaGetDevice(&device_id);
    cudaDeviceProp device_prop;
    cudaGetDeviceProperties(&device_prop, device_id);
    sm_count = device_prop.multiProcessorCount;
  }
  return sm_count;
}
```
Make GetCudaMultiProcessorCount thread‑safe and clarify multi‑device semantics
- `static int sm_count` is written without synchronization; concurrent calls from multiple host threads can cause a data race and undefined behavior.
- The function also permanently caches the SM count of whichever device is current on the first call; if the process later switches devices, the cached value will be wrong.
Consider making the cache atomic (or using std::call_once) and, if needed, keying by device ID. For example:
```diff
- inline int GetCudaMultiProcessorCount() {
-   static int sm_count = 0;
-   if (sm_count == 0) {
-     int device_id;
-     cudaGetDevice(&device_id);
-     cudaDeviceProp device_prop;
-     cudaGetDeviceProperties(&device_prop, device_id);
-     sm_count = device_prop.multiProcessorCount;
-   }
-   return sm_count;
- }
+ inline int GetCudaMultiProcessorCount() {
+   static std::atomic<int> sm_count{0};
+   int cached = sm_count.load(std::memory_order_relaxed);
+   if (cached == 0) {
+     int device_id = 0;
+     cudaGetDevice(&device_id);
+     cudaDeviceProp device_prop{};
+     cudaGetDeviceProperties(&device_prop, device_id);
+     cached = device_prop.multiProcessorCount;
+     sm_count.store(cached, std::memory_order_relaxed);
+   }
+   return cached;
+ }
```

(You'd also need to `#include <atomic>` in this header.)
🤖 Prompt for AI Agents
In include/flashinfer/utils.cuh around lines 292-301, the function uses a
non-synchronized static sm_count that creates a data race and also permanently
caches the SM count for whichever device was current on first call; fix by
making the lookup thread-safe and per-device: first call cudaGetDevice() to get
the current device id, then fetch properties if no cached value exists for that
device; protect the cache with thread-safety (e.g., use std::call_once
per-device, or a mutex + an unordered_map<int,int> keyed by device id, or an
atomic<int> array sized by device count) and include the required headers
(<mutex>, <unordered_map> or <atomic>); return the cached per-device SM count
after initialization.
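The per-device caching idea in the prompt above can be sketched independently of CUDA. The following is a hypothetical Python transposition of the pattern (a lock-guarded map keyed by device id); `query_device_sm_count` is a stub standing in for `cudaGetDeviceProperties(...).multiProcessorCount`, not a real binding.

```python
import threading

_sm_cache: dict[int, int] = {}
_sm_cache_lock = threading.Lock()


def query_device_sm_count(device_id: int) -> int:
    # Stub for the CUDA property query; returns a fake SM count so the
    # caching pattern is testable without a GPU.
    return 80 + device_id


def get_sm_count_cached(device_id: int) -> int:
    # Lock-guarded, per-device cache: each device is queried at most once,
    # fixing both the data race and the "first device wins" caching problem.
    with _sm_cache_lock:
        if device_id not in _sm_cache:
            _sm_cache[device_id] = query_device_sm_count(device_id)
        return _sm_cache[device_id]
```

In the C++ header the same shape would be a mutex (or `std::call_once`) plus an `unordered_map<int, int>` keyed by the result of `cudaGetDevice`, as the prompt describes.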
Actionable comments posted: 2
♻️ Duplicate comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)
56-69: Guard RMSNorm fusion against missing `residual_in` and validate its shape

The precondition at line 56 enforces presence of `residual_out`, `gamma`, and `epsilon` when `rmsnorm_fusion=true`, but omits `residual_in`. The kernel will dereference `residualInPtr` unconditionally when `RMSNormFusion` is true, causing undefined behavior if `residual_in` is absent.

Additionally, shape validation (lines 61-68) only covers `residual_out` and `gamma`; `residual_in` is not validated.

Extend the precondition to include `residual_in`:

```diff
- TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) ||
+ TVM_FFI_ICHECK((residual_out.has_value() && residual_in.has_value() &&
+                 gamma.has_value() && epsilon.has_value()) ||
                 !rmsnorm_fusion)
-     << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true";
+     << "residual_out, residual_in, gamma, and epsilon must be provided if rmsnorm_fusion is true";
```

Add shape validation for `residual_in` within the `if (rmsnorm_fusion)` block:

```diff
  if (rmsnorm_fusion) {
    TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                   residual_out.value().size(1) == token_dim)
        << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
        << ") but got (" << residual_out.value().size(0) << ", "
        << residual_out.value().size(1) << ")";
+   TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
+                  residual_in.value().size(1) == token_dim)
+       << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dim
+       << ") but got (" << residual_in.value().size(0) << ", "
+       << residual_in.value().size(1) << ")";
    TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
        << "gamma must have the same shape as token dimension (" << token_dim
        << ") but got (" << gamma.value().size(0) << ")";
  }
```

flashinfer/comm/trtllm_mnnvl_ar.py (1)
residual_inwithin theif (rmsnorm_fusion)block:if (rmsnorm_fusion) { TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && residual_out.value().size(1) == token_dim) << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1) << ")"; + TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens && + residual_in.value().size(1) == token_dim) + << "residual_in shape mismatch: expected (" << num_tokens << ", " << token_dim + << ") but got (" << residual_in.value().size(0) << ", " + << residual_in.value().size(1) << ")"; TVM_FFI_ICHECK(gamma.value().size(0) == token_dim) << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" << gamma.value().size(0) << ")"; }flashinfer/comm/trtllm_mnnvl_ar.py (1)
331-332: Restore RMSNorm epsilon default to 1e-5.Overriding
epsilonwithtorch.finfo(input.dtype).epsreplaces the kernel's built-in 1e-5 default (see line 91 incsrc/trtllm_mnnvl_allreduce.cu). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking parity with TensorRT-LLM.Apply this diff:
if epsilon is None: - epsilon = torch.finfo(input.dtype).eps + epsilon = 1e-5
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
502-504: Clarify assertion for legacy API compatibility.The assertion at lines 502-504 will fail with a cryptic message if
wait_for_results=Falseis passed. Since this is deprecated legacy code, the assertion is reasonable, but consider improving the error message for clarity:- assert wait_for_results and (out is not None), ( - "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." - ) + if not wait_for_results or out is None: + raise ValueError( + "Legacy trtllm_mnnvl_all_reduce requires wait_for_results=True and a valid output tensor. " + "Please use the new trtllm_mnnvl_allreduce API instead." + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `csrc/trtllm_mnnvl_allreduce.cu` (1 hunks)
- `flashinfer/comm/trtllm_mnnvl_ar.py` (5 hunks)
- `tests/comm/test_trtllm_mnnvl_allreduce.py` (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (3)
csrc/trtllm_mnnvl_allreduce.cu (3)
- flashinfer/comm/cuda_ipc.py (2): `cudaSetDevice` (149-150), `cudaGetErrorString` (146-147)
- csrc/tvm_ffi_utils.h (1): `get_stream` (272-274)
- flashinfer/comm/trtllm_mnnvl_ar.py (1): `trtllm_mnnvl_allreduce_fusion` (168-219)

flashinfer/comm/trtllm_mnnvl_ar.py (5)
- flashinfer/comm/mapping.py (5): `rank` (311-312, 315-322), `tp_rank` (325-326), `local_rank` (391-392), `is_multi_node` (403-404)
- flashinfer/jit/comm.py (1): `gen_trtllm_mnnvl_comm_module` (33-39)
- flashinfer/utils.py (2): `register_custom_op` (273-282, 292-311)
- flashinfer/comm/mnnvl.py (13): `McastGPUBuffer` (1121-1201), `CommBackend` (152-171), `MPIBackend` (211-232), `lamport_initialize` (1101-1118, 1160-1161), `barrier` (168-168, 227-228), `get_buffer_ptrs_dev` (854-856, 1199-1201), `get_unicast_ptr` (858-866, 1195-1197), `get_multicast_ptr` (868-872, 1191-1193)
- csrc/trtllm_mnnvl_allreduce.cu (2): `trtllm_mnnvl_allreduce_fusion` (31-108, 31-37)

tests/comm/test_trtllm_mnnvl_allreduce.py (2)
- flashinfer/comm/mapping.py (2): `Mapping` (21-475), `tp_rank` (325-326)
- flashinfer/comm/trtllm_mnnvl_ar.py (7): `MNNVLAllreduceFusionWorkspace` (47-141), `mpi_barrier` (23-27), `trtllm_mnnvl_fused_allreduce_add_rmsnorm` (298-388), `MNNVLAllreduceFusionStrategy` (30-40), `trtllm_mnnvl_allreduce` (226-295), `get_allreduce_mnnvl_workspace` (395-448), `get_required_buffer_size_bytes` (116-141)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py
- TRY003 (Avoid specifying long messages outside the exception class) at lines 74-76, 258-260, 265-267, 335-337, 339-341, 343-345, 349-351, 355-357, 497-499, 568-570, 574-576, 579-581, 583-585, 588-590, 593-595
tests/comm/test_trtllm_mnnvl_allreduce.py
- 280: PEP 484 prohibits implicit Optional; convert to `T | None` (RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
226-232: Add workspace capacity check to prevent buffer overflow

The new `trtllm_mnnvl_allreduce` function doesn't verify that the input tensor fits within the allocated workspace buffer. A previous review comment suggested adding a check similar to the legacy API's `if inp.shape[0] > buffer_M` validation.

While the author questioned whether this should be on the execution path, buffer overflow can cause crashes or silent memory corruption. Consider adding a validation check:

```python
required_size = MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes(
    workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
)
if required_size > workspace.buffer_size_bytes:
    raise ValueError(
        f"Input tensor requires {required_size} bytes but workspace only has "
        f"{workspace.buffer_size_bytes} bytes. Please increase workspace size."
    )
```

Based on past review comments, the maintainer questioned if this check should be on the execution path. If this is intentionally omitted for performance, please document this as a user responsibility in the docstring.
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)
566-664: Close remaining POSIX FDs in IPC path to avoid leaks

In the POSIX handle path of `_alloc_mn_mcast_mem`, a few FDs are still never closed:

- `local_shareable_uc_handle` returned by `cuMemExportToShareableHandle` (line 958) is used in the IPC ring allgather but never closed.
- During the ring, each rank sends its `local_shareable_uc_handle` to all peers, including itself. The self-recv for `p == group_rank` populates `all_shareable_uc_handles[self.group_rank]`, but that FD is never imported (due to `if p != self.group_rank`) and also never closed.

You already close imported POSIX FDs after `cuMemImportFromShareableHandle` and close the multicast FD after import; closing the remaining two FDs will complete the cleanup and prevent per-allocation FD leaks in long-running jobs.

One way to fix this:

```diff
 if (
     self._shareable_handle_type
     == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
 ):
     # All-gather fabric handles
     all_shareable_uc_handles = self.comm_backend.allgather(
         local_shareable_uc_handle.data
     )
 else:
     # Implement the allgather logic with ipc socket
     all_shareable_uc_handles = [None] * self.group_size
     for i in range(self.group_size):
         self.comm_backend.barrier()
         # Send to peer at offset i
         dest_rank = (self.group_rank + i) % self.group_size
         self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
         # Receive from peer at offset -i
         src_rank = (self.group_rank + self.group_size - i) % self.group_size
         all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
+    if (
+        self._shareable_handle_type
+        == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+    ):
+        # Close our exported UC handle FD and the self-received FD
+        os.close(local_shareable_uc_handle)
+        if all_shareable_uc_handles[self.group_rank] is not None:
+            os.close(all_shareable_uc_handles[self.group_rank])
```

The existing per-peer close after import:

```python
if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
    os.close(all_shareable_uc_handles[p])
```

and the multicast close:

```python
if self._shareable_handle_type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR:
    os.close(shareable_mc_handle)
```

can stay as-is.

Together with the new `__del__` logic calling `self._ipc_socket.close()`, this fully addresses the descriptor-leak concern in the IPC path.

Also applies to: 957-1005, 1008-1055
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
233-271: Align reference RMSNorm epsilon with kernel default (still using `torch.finfo(dtype).eps`)

`prepare_test_data` still uses `torch.finfo(dtype).eps` as the epsilon for the reference RMSNorm:

```python
norm_out = rmsnorm(
    residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
)
```

while the kernel and test harness default to `eps = 1e-5` (see `run_mnnvl_ar_full` and the C++ FFI wrapper's `params.epsilon` default). This inconsistency can mask subtle discrepancies behind loose tolerances or cause avoidable test drift.

To keep the reference path exactly aligned with the implementation, switch this to the same constant:

```diff
-    norm_out = rmsnorm(
-        residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
-    )
+    norm_out = rmsnorm(
+        residual_out,
+        norm_weight,
+        1e-5,
+        enable_pdl=False,
+    )
```

(or better, reuse the same `eps` value passed into `run_mnnvl_ar_full` to avoid hard-coding the constant twice).
🧹 Nitpick comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)
100-114: Ensure epsilon defaults stay consistent with Python API and tests

Here `params.epsilon` falls back to `1e-5` when the Optional `epsilon` is not provided:

```cpp
params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;
```

The Python wrapper in `flashinfer/comm/trtllm_mnnvl_ar.py` and the tests in `tests/comm/test_trtllm_mnnvl_allreduce.py` should use the same default to avoid silent discrepancies between the kernel and reference paths. The core test harness already sets `eps = 1e-5`; the remaining mismatch is in the reference RMSNorm computation (see `prepare_test_data`), which still uses `torch.finfo(dtype).eps`.

flashinfer/comm/mnnvl.py (3)
132-150: Fix `alloc_and_copy_to_cuda` return type vs `None` behavior

`alloc_and_copy_to_cuda` is annotated as returning `int` but still returns `None` for an empty `host_ptr_array`:

```python
def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
    if not host_ptr_array:
        return None
```

Current call sites (`signal_pads` and `uc_ptrs`) always pass non-empty lists, so behavior is correct, but the annotation is misleading and could hide bugs if the function gets reused.

Either make the return type explicit about `None` or enforce non-emptiness by raising:

```diff
-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
-    if not host_ptr_array:
-        return None
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+    if not host_ptr_array:
+        raise ValueError("host_ptr_array must be non-empty")
```

(or change the annotation to `Optional[int]` if you prefer the sentinel behavior).
885-893: IPC opId bootstrap looks fine; consider documenting ordering guarantees
`_init_ipc_socket` uses an MPI-like `bcast` to distribute a randomly chosen `opId` from rank 0, then uses it to construct `IpcSocket` endpoints on all ranks. This nicely avoids hard-coding operation IDs and lines up with the C++ IPC model.

Given the reliance on collective barriers around `send_fd`/`recv_fd`, it would help future maintainers to mention in a comment here that all ranks are expected to participate in the same sequence of IPC operations for a given `opId`, and that mismatched usage will deadlock. The code is correct as written; this is just a documentation/clarity suggestion.
1143-1170: McastGPUBuffer workspace integration and pointer getters look consistent

The new `comm_backend_for_handle_transfer` parameter is threaded through to `McastDeviceMemory`, and the added `get_unicast_ptr` wrapper simply delegates to `mcast_device_memory.get_unicast_ptr(rank)`. This lines up with how tests and `get_allreduce_mnnvl_workspace` use these pointers and keeps pointer access encapsulated.

The placeholder buffer-view methods (`get_multicast_buffer`, `get_unicast_buffer`) are clearly marked `NotImplementedError`, so they won't be hit accidentally. If you plan to expose tensor views later, you can implement them via `create_tensor_from_cuda_memory`.

Also applies to: 1209-1212
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `csrc/trtllm_mnnvl_allreduce.cu` (1 hunks)
- `flashinfer/comm/mnnvl.py` (18 hunks)
- `include/flashinfer/utils.cuh` (2 hunks)
- `tests/comm/test_trtllm_mnnvl_allreduce.py` (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/utils.cuh
🧬 Code graph analysis (3)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
- `Mapping` (21-475)
- `tp_rank` (325-326)

flashinfer/comm/trtllm_mnnvl_ar.py (7)
- `MNNVLAllreduceFusionWorkspace` (47-141)
- `mpi_barrier` (23-27)
- `trtllm_mnnvl_fused_allreduce_add_rmsnorm` (298-388)
- `MNNVLAllreduceFusionStrategy` (30-40)
- `trtllm_mnnvl_allreduce` (226-295)
- `get_allreduce_mnnvl_workspace` (395-448)
- `get_required_buffer_size_bytes` (116-141)

flashinfer/comm/mnnvl.py (14)
- `barrier` (168-168, 227-228)
- `Get_rank` (156-156, 215-216)
- `Get_size` (159-159, 218-219)
- `bcast` (165-165, 224-225)
- `get_multicast_ptr` (871-875, 1205-1207)
- `get_buffer_ptrs_dev` (857-859, 1213-1215)
- `get_unicast_ptr` (861-869, 1209-1211)
csrc/trtllm_mnnvl_allreduce.cu (2)
csrc/tvm_ffi_utils.h (1)
- `get_stream` (272-274)

flashinfer/comm/trtllm_mnnvl_ar.py (1)
- `trtllm_mnnvl_allreduce_fusion` (168-219)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
- `checkCudaErrors` (51-61)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
888-888: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
include/flashinfer/utils.cuh (1)
293-307: Thread-safe SM count cache looks good; confirm single-GPU-per-process assumption

Using `static std::atomic<int>` with relaxed loads/stores makes this helper thread-safe and avoids the previous `static int` data race. The comment explicitly assumes one CUDA device per process, since the cached `sm_count` is never recomputed if the current device changes.

If there are any call sites that may run in a multi-GPU-per-process setup, consider extending this to a per-device cache (e.g., keyed by device id) rather than a single global integer; otherwise, this implementation is fine as long as the single-device assumption holds.
csrc/trtllm_mnnvl_allreduce.cu (1)
41-76: RMSNorm fusion validation and shape checks look correct

The updated precondition now correctly requires `residual_in`, `residual_out`, `gamma`, and `epsilon` when `rmsnorm_fusion` is true, and the subsequent shape checks on `residual_in`, `residual_out`, and `gamma` guard the fused path against mismatched tensors. This should prevent the fused kernels from ever seeing invalid residual/norm inputs via the FFI boundary.

The overall parameter wiring into `AllReduceFusionParams` (including buffer pointers and flags) also looks consistent with the Python side.

flashinfer/comm/mnnvl.py (1)
781-790: Good: IPC socket is now closed in destructor

The addition of:

```python
if hasattr(self, "_ipc_socket"):
    self._ipc_socket.close()
```

inside `__del__` ensures the Unix domain socket is closed and, for non-abstract sockets, the filesystem entry is unlinked. This addresses the earlier socket-leak concern while remaining safe when construction fails before `_ipc_socket` is set.

tests/comm/test_trtllm_mnnvl_allreduce.py (1)
16-103: Test harness refactor cleanly exercises both refactored and legacy APIs

The new helpers (`row_linear_residual_norm_fusion_forward`, `_legacy`, `run_mnnvl_ar_full`) and parametrized tests (`test_mnnvl_allreduce_refactored`, `test_mnnvl_allreduce_legacy`) do a good job of:
- Sharing core logic between fused and non‑fused paths.
- Covering both the new workspace‑based API and the legacy pointer‑based API.
- Exercising a variety of sequence lengths, dtypes, and hidden sizes.
- Integrating MPI barriers and rank‑aware logging to make multi‑rank failures diagnosable.
Once the epsilon alignment in `prepare_test_data` is fixed, this test suite should give solid coverage for the new fused implementation and its backward-compatibility guarantees.

Also applies to: 274-397, 439-465
flashinfer/comm/trtllm_mnnvl_ar.py
Outdated
```python
    comm_backend: Optional[CommBackend] = None,
):
    """
    Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.
```
Is there a way we can check this?
Forgot to update the doc. Fixed.
```python
def __init__(
    self,
    mapping: Mapping,
    buffer_size_in_bytes: Optional[int] = None,
```
Could you provide guidance for buffer_size_in_bytes? E.g., as a function of the number of tokens and hidden size? Or just refer to get_required_buffer_size_bytes
just refer to get_required_buffer_size_bytes
```python
def trtllm_mnnvl_allreduce(
    input: torch.Tensor,
    workspace: MNNVLAllreduceFusionWorkspace,
    launch_with_pdl: bool,
    output: Optional[torch.Tensor] = None,
    strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
) -> torch.Tensor:
```
We do want this check. I recently added it because it did bite others.
```python
def __init__(
    self,
    mapping: Mapping,
    buffer_size_in_bytes: Optional[int] = None,
```
Another option would be to replace "buffer_size_in_bytes" by the parameters that get_required_buffer_size_bytes takes, and just call this from the init function. Seems more user friendly.
If you do want to just allocate a blob of memory, we could still have buffer_size_in_bytes as an additional parameter that would override whatever is calculated.
I don't think that is a good design, as it might give the user the impression that the allocated workspace will ONLY support the set of parameters (max_num_tokens, hidden_dim, dtype, strategy).

But actually, the workspace usage is quite flexible: as long as the required workspace size is smaller than the allocation, it will work. Thus, the intended usage is that the user checks the required workspace size (or we can check it when calling the allreduce function, at some cost) before using it.
csrc/trtllm_mnnvl_allreduce.cu
Outdated
```cpp
                                     Optional<TensorView> out) {
  cudaSetDevice(in.device().device_id);
  auto stream = get_stream(in.device());
  // FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this
```
I don't think it is a problem
Comment removed.
flashinfer/comm/trtllm_mnnvl_ar.py
Outdated
```python
    AUTO = 99

    @staticmethod
    def is_one_shot(
```
Maybe rename this to "heuristic_for_one_shot" or something like that.
Done.
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
361-362: Critical: Restore epsilon default to 1e-5 to match kernel.

This epsilon fallback was flagged as critical in a previous review but remains unresolved. Using `torch.finfo(input.dtype).eps` sets epsilon to approximately 1e-3 for fp16, diverging from the kernel's built-in 1e-5 default (see `csrc/trtllm_mnnvl_allreduce.cu` line 96). This materially alters RMSNorm results and breaks compatibility with TensorRT-LLM.

Apply this fix:

```diff
-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        epsilon = 1e-5
```
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
118-136: Consider replacing `@functools.cache` on instance method.

Using `@functools.cache` on an instance method can prevent the instance from being garbage collected, leading to memory leaks. Since this method takes `self` as the first parameter, the cache will hold references to the instance.

Consider either:

- Making this a standalone function that takes workspace parameters explicitly
- Using `@functools.lru_cache(maxsize=...)` with a reasonable limit
- Implementing manual caching in the instance if needed
Based on learnings
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `csrc/trtllm_mnnvl_allreduce.cu` (1 hunks)
- `flashinfer/comm/trtllm_mnnvl_ar.py` (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (6)
- `Mapping` (21-475)
- `rank` (311-312, 315-322)
- `tp_rank` (325-326)
- `local_rank` (391-392)
- `is_multi_node` (403-404)

flashinfer/jit/comm.py (1)
- `gen_trtllm_mnnvl_comm_module` (33-39)

flashinfer/utils.py (2)
- `register_custom_op` (273-282, 292-311)

flashinfer/comm/mnnvl.py (13)
- `McastGPUBuffer` (1135-1215)
- `CommBackend` (152-171)
- `MPIBackend` (211-232)
- `lamport_initialize` (1115-1132, 1174-1175)
- `barrier` (168-168, 227-228)
- `get_buffer_ptrs_dev` (857-859, 1213-1215)
- `get_unicast_ptr` (861-869, 1209-1211)
- `get_multicast_ptr` (871-875, 1205-1207)

csrc/trtllm_mnnvl_allreduce.cu (2)
- `trtllm_mnnvl_allreduce_fusion` (29-113, 29-35)
csrc/trtllm_mnnvl_allreduce.cu (3)
flashinfer/comm/cuda_ipc.py (2)
- `cudaSetDevice` (149-150)
- `cudaGetErrorString` (146-147)

csrc/tvm_ffi_utils.h (1)
- `get_stream` (272-274)

flashinfer/comm/trtllm_mnnvl_ar.py (1)
- `trtllm_mnnvl_allreduce_fusion` (192-243)
🪛 Ruff (0.14.5)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
118-118: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
282-284: Avoid specifying long messages outside the exception class
(TRY003)
289-291: Avoid specifying long messages outside the exception class
(TRY003)
303-305: Avoid specifying long messages outside the exception class
(TRY003)
365-367: Avoid specifying long messages outside the exception class
(TRY003)
369-371: Avoid specifying long messages outside the exception class
(TRY003)
373-375: Avoid specifying long messages outside the exception class
(TRY003)
379-381: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
398-400: Avoid specifying long messages outside the exception class
(TRY003)
528-530: Avoid specifying long messages outside the exception class
(TRY003)
599-601: Avoid specifying long messages outside the exception class
(TRY003)
605-607: Avoid specifying long messages outside the exception class
(TRY003)
610-612: Avoid specifying long messages outside the exception class
(TRY003)
614-616: Avoid specifying long messages outside the exception class
(TRY003)
619-621: Avoid specifying long messages outside the exception class
(TRY003)
624-626: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_mnnvl_allreduce.cu (1)
29-113: LGTM! Fusion entry point is well-structured.

The refactored entry point properly validates all inputs, including the RMSNorm fusion parameters that were flagged in previous reviews. The dispatch logic cleanly selects between oneshot and twoshot strategies, and error messages are clear and actionable.
flashinfer/comm/trtllm_mnnvl_ar.py (3)
30-48: Strategy enum and heuristic look good.

The `MNNVLAllreduceFusionStrategy` enum provides a clear interface for selecting between oneshot and twoshot approaches, with a sensible AUTO mode that uses an empirically-derived threshold.

250-326: Buffer size validation properly implemented.

The function now includes the buffer size check that was requested in previous reviews (lines 300-305), preventing potential out-of-bounds access. Input validation is comprehensive and error messages are clear.

422-646: Deprecation strategy is well-executed.

The legacy APIs are properly marked with `@deprecated` decorators and include clear migration guidance. The wrappers correctly redirect to the new fusion-based implementations, maintaining backward compatibility while encouraging adoption of the improved APIs.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
132-149: Fix return type inconsistency.

The function returns `None` at line 137 when `host_ptr_array` is empty, but the return type annotation at line 132 indicates `int`. This creates a type mismatch.

Consider one of these fixes:

Option 1: Return `Optional[int]` and update callers to handle None:

```diff
-def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
```

Option 2: Raise an error instead of returning None:

```diff
     if not host_ptr_array:
-        return None
+        raise ValueError("host_ptr_array cannot be empty")
```
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
377-378: Restore RMSNorm epsilon default to 1e-5.

Overriding `epsilon` with `torch.finfo(input.dtype).eps` replaces the kernel's built-in 1e-5 default (see `trtllm_mnnvl_allreduce_fusion` in `csrc/trtllm_mnnvl_allreduce.cu` line ~35: `params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5`). For fp16 this becomes ~1e-3, materially changing RMSNorm results and breaking numerical parity.

Apply this diff to fix:

```diff
     if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+        epsilon = 1e-5
```
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
134-152: Consider alternatives to `@functools.cache` on instance methods.

Using `@functools.cache` (or `@lru_cache`) on instance methods can prevent garbage collection of instances because the cache holds references to bound methods, which in turn hold references to `self`. Since `MNNVLAllreduceFusionWorkspace` instances are likely long-lived in typical usage, this may be acceptable, but consider these alternatives:

- Use `@functools.lru_cache(maxsize=128)` to limit cache growth
- Move caching logic to a module-level cache keyed on relevant parameters
- Document the caching behavior and its memory implications
Based on learnings
Apply this diff if you want to limit cache size:

```diff
-    @functools.cache
+    @functools.lru_cache(maxsize=128)
     def is_buffer_size_sufficient(
```

flashinfer/comm/mnnvl.py (2)
640-654: Prefix unused unpacked variables with underscore.

The variables `msg`, `flags`, and `addr` from `recvmsg` are unpacked but never used. Prefix them with `_` to indicate they're intentionally ignored.

Apply this diff:

```diff
-        msg, ancdata, flags, addr = self.sock.recvmsg(
+        _msg, ancdata, _flags, _addr = self.sock.recvmsg(
```
893-900: Consider using `secrets` module for opId generation.

While cryptographic randomness is not strictly required for socket naming, using `secrets.randbelow(2**64)` instead of `random.randint` provides better collision resistance if multiple jobs run concurrently on the same node.

Apply this diff:

```diff
+import secrets
+
 def _init_ipc_socket(self):
     if self.group_rank == 0:
-        # Gnerate the opId
-        opId = random.randint(0, 2**64 - 1)
+        # Generate the opId
+        opId = secrets.randbelow(2**64)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `flashinfer/comm/mnnvl.py` (19 hunks)
- `flashinfer/comm/trtllm_mnnvl_ar.py` (5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (2)
flashinfer/comm/mnnvl.py (1)
flashinfer/cuda_utils.py (1)
- `checkCudaErrors` (51-61)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/comm/mapping.py (5)
- `rank` (311-312, 315-322)
- `tp_rank` (325-326)
- `local_rank` (391-392)
- `is_multi_node` (403-404)

flashinfer/utils.py (2)
- `register_custom_op` (273-282, 292-311)

flashinfer/comm/mnnvl.py (13)
- `McastGPUBuffer` (1143-1224)
- `CommBackend` (152-171)
- `MPIBackend` (211-232)
- `lamport_initialize` (1123-1140, 1183-1184)
- `barrier` (168-168, 227-228)
- `get_buffer_ptrs_dev` (857-859, 1222-1224)
- `get_unicast_ptr` (861-869, 1218-1220)
- `get_multicast_ptr` (871-875, 1214-1216)

csrc/trtllm_mnnvl_allreduce.cu (2)
- `trtllm_mnnvl_allreduce_fusion` (29-113, 29-35)
🪛 Ruff (0.14.5)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
896-896: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
298-300: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
319-321: Avoid specifying long messages outside the exception class
(TRY003)
381-383: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
389-391: Avoid specifying long messages outside the exception class
(TRY003)
395-397: Avoid specifying long messages outside the exception class
(TRY003)
401-403: Avoid specifying long messages outside the exception class
(TRY003)
414-416: Avoid specifying long messages outside the exception class
(TRY003)
544-546: Avoid specifying long messages outside the exception class
(TRY003)
615-617: Avoid specifying long messages outside the exception class
(TRY003)
621-623: Avoid specifying long messages outside the exception class
(TRY003)
626-628: Avoid specifying long messages outside the exception class
(TRY003)
630-632: Avoid specifying long messages outside the exception class
(TRY003)
635-637: Avoid specifying long messages outside the exception class
(TRY003)
640-642: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/trtllm_mnnvl_ar.py (1)
266-341: LGTM!

The function correctly validates inputs, selects strategy, checks buffer size sufficiency (addressing past review feedback), and invokes the fusion kernel with appropriate parameters.
flashinfer/comm/mnnvl.py (1)
788-789: LGTM!

The IPC socket cleanup correctly uses `hasattr` to check for existence before closing, addressing the file descriptor leak concern from past reviews.
```cpp
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
  asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1)
               : "memory");
#else
```
Does this mean the kernel is supported for archs < 700? What's the minimal requirement?
For the API we use the @backend_requirement decorator, which lists supported SMs. So as a minimum I think we can list: 70,80,90,100,103,110,120
Would you agree? Further back is probably not as relevant.
This macro is fairly common in barrier and semaphore utility functions, largely due to the memory consistency qualifiers introduced with the Volta architecture. For example, CUTLASS uses a similar pattern:
That said, I believe our usage here simply follows established convention. Given that our minimum supported architecture is sm_75, we shouldn't actually need these qualifiers.
The kernel needs multicast to work, which requires at least SM90 and NVSwitch.
```cpp
if (NUM_INPUTS > 0) {
  T_IN accum[ELTS_PER_THREAD];
  float4* accum4 = (float4*)&accum;
  flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER);
```
Instead of clearing the buffer here, could we assume we have properly initialized lamport buffers at the start?
And then, at the end (e.g., right after we call PDL), we can clear the buffers, so that a next kernel using the same workspace can also assume properly initialized lamport buffers.
We use a triple buffer, so the buffer clear can be moved anywhere. I found the current arrangement gets the best performance.
If the assumption is a single buffer (clear the buffer at the end, then assume it is initialized for the next kernel), this won't work, as it requires membar.sys, which is very expensive. Moreover, we would need to order the next kernel's buffer writes from other ranks after this kernel's buffer clear, which requires a flag and invalidates the Lamport protocol.
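The rotation described above can be sketched in a few lines. This is an illustrative host-side model of triple buffering (names are hypothetical, not the kernel's): each call writes one buffer, lazily clears the previous one, and relies on the one before that being guaranteed clean.

```python
NUM_BUFFERS = 3

class TripleBuffer:
    def __init__(self):
        self.call_count = 0

    def indices(self):
        cur = self.call_count % NUM_BUFFERS          # written this call
        dirty = (self.call_count - 1) % NUM_BUFFERS  # cleared during this call
        clean = (self.call_count - 2) % NUM_BUFFERS  # safe for the next call
        self.call_count += 1
        return cur, dirty, clean

tb = TripleBuffer()
print(tb.indices())  # (0, 2, 1)
print(tb.indices())  # (1, 0, 2)
print(tb.indices())  # (2, 1, 0)
```

Because clearing happens one call "behind" the write, no cross-rank ordering (and no `membar.sys`) is needed between the clear and the next kernel's writes.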
#pragma unroll
for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) {
  valid &= !isNegZero(valuesLamport[r].elements[i]);
}
As long as one of the WorldSize values is negative zero (`isNegZero`), we will keep reading the others over and over again. Perhaps these are all cache hits, in which case: ignore my comment. Otherwise, would it make sense to break out of the inner loop (`kLAMPORT_ELTS_PER_PACKED`) early?
It will be a cache hit if it is still invalid.
For the inner loop, I had a version that only checks one element for validity. But there is no architectural guarantee that a 128B read/write is atomic, which means breaking out of the inner loop early could, in rare cases, produce a wrong result.
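The sentinel check discussed above hinges on a detail worth spelling out: -0.0 compares equal to 0.0, so a plain equality test cannot tell an unwritten slot from a written zero. A minimal host-side sketch of the device-side `isNegZero` idea (inspecting the sign bit instead):

```python
import math

def is_neg_zero(x: float) -> bool:
    # -0.0 == 0.0 is True, so we must check the sign bit explicitly.
    return x == 0.0 and math.copysign(1.0, x) < 0

# A "packed" vector where some slots still hold the -0.0 sentinel.
packed = [-0.0, 1.5, -0.0, 2.0]
valid = all(not is_neg_zero(v) for v in packed)
print(valid)        # False: at least one slot is still unwritten
print(-0.0 == 0.0)  # True: why plain equality cannot work
```

Every element is checked, mirroring the kernel's inner loop: since a 128B transaction is not guaranteed atomic, one valid element does not imply the rest of the packed value has landed.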
flashinfer/comm/mnnvl.py
Outdated
all_shareable_uc_handles = self.comm_backend.allgather(
    local_shareable_uc_handle.data
)
else:
[nit] Add a comment regarding handle type in this case: CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
Branching logic no longer exists in the latest commit.
flashinfer/comm/mnnvl.py
Outdated
    mc_fabric_handle.data if mc_fabric_handle else None, root=0
)
shareable_mc_handle = None
if (
With all the branching on this handle type, would it make sense to build a class that encapsulate this logic?
Roughly thinking:
class ABCCommBufferAllocator:
    def allocate_unicast_buffer(): ...
    def allocate_multicast_buffer(): ...

class CUFabricCommBufferAllocator(ABCCommBufferAllocator): ...
class CUPosixFDCommBufferAllocator(ABCCommBufferAllocator): ...
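Fleshing out the sketch above, a minimal runnable version could look like the following. Class and method names follow the reviewer's suggestion and are hypothetical, not flashinfer APIs; the return values are placeholders standing in for real CUDA allocations.

```python
from abc import ABC, abstractmethod

class CommBufferAllocator(ABC):
    @abstractmethod
    def allocate_unicast_buffer(self, size: int): ...

    @abstractmethod
    def allocate_multicast_buffer(self, size: int): ...

class CUFabricCommBufferAllocator(CommBufferAllocator):
    """Would export/import CU_MEM_HANDLE_TYPE_FABRIC handles via allgather."""
    def allocate_unicast_buffer(self, size: int):
        return ("fabric-uc", size)

    def allocate_multicast_buffer(self, size: int):
        return ("fabric-mc", size)

class CUPosixFDCommBufferAllocator(CommBufferAllocator):
    """Would circulate CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR FDs over an IPC socket."""
    def allocate_unicast_buffer(self, size: int):
        return ("posix-fd-uc", size)

    def allocate_multicast_buffer(self, size: int):
        return ("posix-fd-mc", size)

def make_allocator(use_fabric: bool) -> CommBufferAllocator:
    # All the handle-type branching collapses into one constructor-site choice.
    return CUFabricCommBufferAllocator() if use_fabric else CUPosixFDCommBufferAllocator()
```

The benefit is that callers of `allocate_*` never branch on the handle type again; only the factory does.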
Refactored in the latest commit.
self.mcast_device_memory.lamport_initialize(rank, dtype)

-    def get_mc_buffer(
+    def get_multicast_buffer(
agreed
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
612-648: Duplicate constructor parameter breaks `McastDeviceMemory`
`McastDeviceMemory.__init__` declares `comm_backend_for_handle_transfer` twice in the parameter list, which is a syntax error and prevents the module from importing. Remove the duplicate parameter so only one `comm_backend_for_handle_transfer` argument remains:

-    is_multi_node: bool = True,
-    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
-    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
+    is_multi_node: bool = True,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
♻️ Duplicate comments (3)
flashinfer/comm/mnnvl.py (1)
875-901: Local POSIX FD from `cuMemExportToShareableHandle` is never closed
In `_alloc_mn_mcast_mem`, the POSIX path exports `local_shareable_uc_handle` and circulates it via `IpcSocket`. Remote FDs are correctly closed after import, but the original exported FD (`local_shareable_uc_handle`) is never closed, leading to a per-rank FD leak over time. This was previously flagged; the remote-close part is fixed, but the local FD still needs closing. Close the local FD after the ring allgather in the POSIX branch:

    if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC:
        # All-gather fabric handles
        all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data)
    else:
        # Implement the allgather logic with ipc socket
        all_shareable_uc_handles = [None] * self.group_size
        for i in range(self.group_size):
            self.comm_backend.barrier()
            # Send to peer at offset i
            dest_rank = (self.group_rank + i) % self.group_size
            self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank)
            # Receive from peer at offset -i
            src_rank = (self.group_rank + self.group_size - i) % self.group_size
            all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd()
+       # Local FD no longer needed after all peers have imported it
+       os.close(local_shareable_uc_handle)

flashinfer/comm/trtllm_mnnvl_ar.py (1)
354-385: RMSNorm epsilon default diverges from CUDA/kernel and prior review
`trtllm_mnnvl_fused_allreduce_add_rmsnorm` currently does:

    if epsilon is None:
        epsilon = torch.finfo(input.dtype).eps

The CUDA entry point (`trtllm_mnnvl_allreduce_fusion`) still defaults `epsilon` to 1e-5 when the Optional is not set, and TensorRT-LLM uses 1e-5 as its RMSNorm default. Because the Python wrapper always passes an explicit epsilon, the kernel's 1e-5 default is never used, and for fp16 this changes behavior materially (~1e-3 vs 1e-5). To preserve parity and avoid surprising numerics, consider restoring the 1e-5 default here and updating the docstring accordingly:

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        epsilon = 1e-5

and:

-    epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided.
+    epsilon: The epsilon parameter for RMSNorm; defaults to 1e-5 if not provided.

You'll also want tests to use the same constant for their reference RMSNorm (see `prepare_test_data` in the test file).

tests/comm/test_trtllm_mnnvl_allreduce.py (1)
242-280: Reference RMSNorm epsilon should match kernel/API default
`prepare_test_data` builds the reference RMSNorm output with:

    norm_out = rmsnorm(
        residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
    )

While `run_mnnvl_ar_full` sets `eps = 1e-5` and passes that into the forward helpers, and the CUDA side uses 1e-5 as its default. Once the Python API is updated to default epsilon to 1e-5, the reference path here should also use 1e-5 to avoid systematic discrepancies. Suggested change:

-    norm_out = rmsnorm(
-        residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
-    )
+    norm_out = rmsnorm(
+        residual_out, norm_weight, 1e-5, enable_pdl=False
+    )

This keeps tests aligned with the kernel's behavior and the Python wrapper's default.
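To see why the two epsilon choices matter numerically, a torch-free illustration: fp16 machine epsilon is 2**-10, roughly two orders of magnitude larger than the kernel's 1e-5 default, so for small activations the RMSNorm denominators diverge noticeably. (`rms_denom` is a hypothetical helper, not a flashinfer function.)

```python
fp16_eps = 2.0 ** -10   # value of torch.finfo(torch.float16).eps
kernel_eps = 1e-5       # CUDA kernel / TensorRT-LLM default

def rms_denom(mean_sq: float, eps: float) -> float:
    # The denominator of RMSNorm: sqrt(mean(x^2) + eps)
    return (mean_sq + eps) ** 0.5

mean_sq = 1e-4  # small activations make the epsilon term dominate
print(fp16_eps)                       # 0.0009765625
print(rms_denom(mean_sq, fp16_eps))   # ~0.0328
print(rms_denom(mean_sq, kernel_eps)) # ~0.0105
```

With activations this small, the fp16-eps reference normalizes about 3x more aggressively than the kernel, which is exactly the systematic discrepancy the test comparison would surface.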
🧹 Nitpick comments (2)
flashinfer/comm/mnnvl.py (1)
127-143: `alloc_and_copy_to_cuda` return type vs behavior
`alloc_and_copy_to_cuda` is annotated to return `int` but returns `None` when `host_ptr_array` is empty. This is harmless for current call sites (they always pass non-empty lists) but makes the annotation misleading and can trip static type checkers. Consider either:
- Changing the annotation to `Optional[int]` and documenting the `None` case, or
- Removing the early return and letting the function always return a valid device pointer (e.g., disallow empty input or treat it as a zero-byte allocation).
include/flashinfer/utils.cuh (1)
24-25: Thread-safe SM count cache looks good; note single-device assumption
Using `std::atomic<int>` for `sm_count` removes the data race from the previous static-int pattern while keeping the lookup cheap. The comment correctly calls out that this caches the SM count for whichever device is current on first use, assuming one CUDA device per process. If you ever need multi-device-per-process support, consider extending this to cache per-device (e.g., a small `std::vector<std::atomic<int>>` indexed by device ID); otherwise this implementation is fine for the stated assumption.

Also applies to: 339-353
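A hedged sketch of that per-device extension, using a fixed-size `std::array` of atomics rather than `std::vector` (since `std::atomic` is neither copyable nor movable, resizing a vector of them is awkward). The `query` callback stands in for the real `cudaDeviceGetAttribute` call; names here are illustrative, not flashinfer's.

```cpp
#include <array>
#include <atomic>
#include <cassert>

constexpr int kMaxDevices = 64;

// Lazily caches one SM count per device ID. A count of 0 means "not yet
// queried"; real devices always have a positive SM count, so 0 is a safe
// sentinel. Racing threads may both query, but they store the same value.
template <typename QueryFn>
int GetMultiProcessorCount(int device_id, QueryFn&& query) {
  static std::array<std::atomic<int>, kMaxDevices> counts{};  // zero-initialized
  assert(device_id >= 0 && device_id < kMaxDevices);
  int cached = counts[device_id].load(std::memory_order_relaxed);
  if (cached == 0) {
    cached = query(device_id);  // e.g. cudaDeviceGetAttribute(..., device_id)
    counts[device_id].store(cached, std::memory_order_relaxed);
  }
  return cached;
}
```

Relaxed ordering suffices because the cached value is self-contained; no other memory is published through it.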
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/trtllm_mnnvl_allreduce.cu (1 hunks)
- flashinfer/comm/mnnvl.py (33 hunks)
- flashinfer/comm/trtllm_mnnvl_ar.py (6 hunks)
- include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2 hunks)
- include/flashinfer/utils.cuh (2 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce.py (9 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- include/flashinfer/utils.cuh
- flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/comm/mapping.py (5)
- Mapping (21-475), rank (311-312), rank (315-322), tp_rank (325-326), is_multi_node (403-404)
flashinfer/jit/comm.py (1)
- gen_trtllm_mnnvl_comm_module (33-39)
flashinfer/comm/mnnvl.py (13)
- McastGPUBuffer (1032-1113), CommBackend (145-164), MPIBackend (204-225), lamport_initialize (1014-1029), lamport_initialize (1076-1077), barrier (161-161), barrier (220-221), get_buffer_ptrs_dev (778-780), get_buffer_ptrs_dev (1111-1113), get_unicast_ptr (782-790), get_unicast_ptr (1107-1109), get_multicast_ptr (792-796), get_multicast_ptr (1103-1105)
csrc/trtllm_mnnvl_allreduce.cu (2)
- trtllm_mnnvl_allreduce_fusion (29-113), trtllm_mnnvl_allreduce_fusion (29-35)
🪛 GitHub Actions: pre-commit
flashinfer/comm/mnnvl.py
[warning] 62-62: pre-commit: minor formatting/consistency note after fixes (merge-conflict markers touched code).
[warning] 1-1: pre-commit: potential changes applied by formatting hook (ruff-format had modifications).
flashinfer/comm/trtllm_mnnvl_ar.py
[error] 36-36: invalid-syntax: Expected class, function definition or async function definition after decorator (likely due to merge conflict markers <<<<<<< HEAD / ======= / >>>>>>> ...)
[error] 169-169: invalid-syntax: Duplicate merge conflict marker present (HEAD) or conflict resolution remnants.
[error] 314-314: invalid-syntax: Duplicate merge conflict marker present (HEAD) or conflict lines not resolved.
tests/comm/test_trtllm_mnnvl_allreduce.py
[error] 29-29: invalid-syntax: Duplicate merge conflict marker present (HEAD) or unresolved conflict markers in test file.
[error] 1-1: Parse error: file appears to contain merge conflict markers from an unresolved conflict.
🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py
393-393: Avoid specifying long messages outside the exception class
(TRY003)
533-533: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
558-558: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
586-586: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
586-586: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
586-586: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
598-598: Avoid specifying long messages outside the exception class
(TRY003)
620-620: Duplicate parameter "comm_backend_for_handle_transfer"
(invalid-syntax)
672-672: Avoid specifying long messages outside the exception class
(TRY003)
692-692: Avoid specifying long messages outside the exception class
(TRY003)
750-750: Do not catch blind exception: Exception
(BLE001)
817-817: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
832-832: Do not catch blind exception: Exception
(BLE001)
flashinfer/comm/trtllm_mnnvl_ar.py
36-36: Expected class, function definition or async function definition after decorator
(invalid-syntax)
36-36: Expected a statement
(invalid-syntax)
36-36: Expected a statement
(invalid-syntax)
36-36: Expected a statement
(invalid-syntax)
37-37: Unexpected indentation
(invalid-syntax)
37-38: Expected an indented block after function definition
(invalid-syntax)
38-38: Expected a statement
(invalid-syntax)
38-38: Expected a statement
(invalid-syntax)
38-38: Expected a statement
(invalid-syntax)
38-38: Expected a statement
(invalid-syntax)
38-39: Expected a statement
(invalid-syntax)
39-39: Unexpected indentation
(invalid-syntax)
41-42: Expected an indented block after function definition
(invalid-syntax)
42-42: Expected a statement
(invalid-syntax)
42-42: Expected a statement
(invalid-syntax)
42-42: Expected a statement
(invalid-syntax)
42-42: Expected a statement
(invalid-syntax)
42-42: Expected ,, found name
(invalid-syntax)
42-42: Expected ,, found name
(invalid-syntax)
42-42: Expected an identifier
(invalid-syntax)
43-43: Unexpected indentation
(invalid-syntax)
51-51: Expected a statement
(invalid-syntax)
169-169: Expected a statement
(invalid-syntax)
169-169: Expected a statement
(invalid-syntax)
169-169: Expected a statement
(invalid-syntax)
169-169: Expected a statement
(invalid-syntax)
170-170: Unexpected indentation
(invalid-syntax)
173-174: Expected an indented block after if statement
(invalid-syntax)
174-174: Expected a statement
(invalid-syntax)
174-174: Expected a statement
(invalid-syntax)
174-174: Expected a statement
(invalid-syntax)
174-174: Expected a statement
(invalid-syntax)
174-175: Expected a statement
(invalid-syntax)
175-175: Unexpected indentation
(invalid-syntax)
180-181: Expected an indented block after if statement
(invalid-syntax)
181-181: Expected a statement
(invalid-syntax)
181-181: Expected a statement
(invalid-syntax)
181-181: Expected a statement
(invalid-syntax)
181-181: Expected a statement
(invalid-syntax)
181-181: Expected ,, found name
(invalid-syntax)
181-181: Expected ,, found name
(invalid-syntax)
181-181: Expected an identifier
(invalid-syntax)
183-183: Unexpected indentation
(invalid-syntax)
184-184: unindent does not match any outer indentation level
(invalid-syntax)
184-184: Expected a statement
(invalid-syntax)
184-184: Expected a statement
(invalid-syntax)
184-185: Expected a statement
(invalid-syntax)
187-187: Unexpected indentation
(invalid-syntax)
188-188: unindent does not match any outer indentation level
(invalid-syntax)
314-314: Expected a statement
(invalid-syntax)
314-314: Expected a statement
(invalid-syntax)
314-314: Expected a statement
(invalid-syntax)
314-314: Expected a statement
(invalid-syntax)
315-315: Unexpected indentation
(invalid-syntax)
319-319: Expected a statement
(invalid-syntax)
319-319: Expected a statement
(invalid-syntax)
319-319: Expected a statement
(invalid-syntax)
319-319: Expected a statement
(invalid-syntax)
319-320: Expected a statement
(invalid-syntax)
320-320: Unexpected indentation
(invalid-syntax)
332-332: Expected a statement
(invalid-syntax)
332-332: Expected a statement
(invalid-syntax)
332-332: Expected a statement
(invalid-syntax)
332-332: Expected a statement
(invalid-syntax)
332-332: Expected ,, found name
(invalid-syntax)
332-332: Expected ,, found name
(invalid-syntax)
332-332: Expected an identifier
(invalid-syntax)
333-333: Unexpected indentation
(invalid-syntax)
354-354: Expected a statement
(invalid-syntax)
tests/comm/test_trtllm_mnnvl_allreduce.py
29-29: Expected a statement
(invalid-syntax)
29-29: Expected a statement
(invalid-syntax)
29-29: Expected a statement
(invalid-syntax)
29-29: Expected a statement
(invalid-syntax)
30-30: Unexpected indentation
(invalid-syntax)
35-35: Expected a statement
(invalid-syntax)
35-35: Expected a statement
(invalid-syntax)
35-35: Expected a statement
(invalid-syntax)
35-35: Expected a statement
(invalid-syntax)
35-36: Expected a statement
(invalid-syntax)
36-36: Unexpected indentation
(invalid-syntax)
37-37: Expected a statement
(invalid-syntax)
37-37: Expected a statement
(invalid-syntax)
37-37: Expected a statement
(invalid-syntax)
37-37: Expected a statement
(invalid-syntax)
37-37: Expected ,, found name
(invalid-syntax)
37-37: Expected ,, found name
(invalid-syntax)
37-37: Expected an identifier
(invalid-syntax)
39-39: Unexpected indentation
(invalid-syntax)
114-114: Expected a statement
(invalid-syntax)
325-325: Expected a statement
(invalid-syntax)
325-325: Expected a statement
(invalid-syntax)
325-325: Expected a statement
(invalid-syntax)
325-325: Expected a statement
(invalid-syntax)
326-326: Unexpected indentation
(invalid-syntax)
329-329: unindent does not match any outer indentation level
(invalid-syntax)
330-330: Expected a statement
(invalid-syntax)
330-330: Expected a statement
(invalid-syntax)
330-330: Expected a statement
(invalid-syntax)
330-330: Expected a statement
(invalid-syntax)
330-331: Expected a statement
(invalid-syntax)
331-331: Unexpected indentation
(invalid-syntax)
337-337: Expected a statement
(invalid-syntax)
337-337: Expected a statement
(invalid-syntax)
337-337: Expected a statement
(invalid-syntax)
337-337: Expected a statement
(invalid-syntax)
337-337: Expected ,, found name
(invalid-syntax)
337-337: Expected ,, found name
(invalid-syntax)
337-337: Expected an identifier
(invalid-syntax)
338-338: Unexpected indentation
(invalid-syntax)
382-383: Expected an indented block after if statement
(invalid-syntax)
383-383: Expected except or finally after try block
(invalid-syntax)
383-383: Expected a statement
(invalid-syntax)
383-383: Expected a statement
(invalid-syntax)
383-383: Expected a statement
(invalid-syntax)
384-384: Unexpected indentation
(invalid-syntax)
388-388: unindent does not match any outer indentation level
(invalid-syntax)
393-393: Unexpected indentation
(invalid-syntax)
413-413: Expected a statement
(invalid-syntax)
413-413: Expected a statement
(invalid-syntax)
413-413: Expected a statement
(invalid-syntax)
413-413: Expected a statement
(invalid-syntax)
413-414: Expected a statement
(invalid-syntax)
414-414: Unexpected indentation
(invalid-syntax)
417-417: unindent does not match any outer indentation level
(invalid-syntax)
434-434: unindent does not match any outer indentation level
(invalid-syntax)
434-434: Expected a statement
(invalid-syntax)
434-434: Expected a statement
(invalid-syntax)
434-435: Expected a statement
(invalid-syntax)
435-435: Unexpected indentation
(invalid-syntax)
447-447: unindent does not match any outer indentation level
(invalid-syntax)
449-449: Unexpected indentation
(invalid-syntax)
451-451: unindent does not match any outer indentation level
(invalid-syntax)
451-451: Expected a statement
(invalid-syntax)
451-451: Expected a statement
(invalid-syntax)
451-452: Expected an expression
(invalid-syntax)
452-452: Unexpected indentation
(invalid-syntax)
472-472: unindent does not match any outer indentation level
(invalid-syntax)
472-472: Expected a statement
(invalid-syntax)
472-472: Expected a statement
(invalid-syntax)
472-473: Expected a statement
(invalid-syntax)
474-474: Unexpected indentation
(invalid-syntax)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/trtllm_mnnvl_allreduce.cu (1)
29-112: FFI wrapper validation for fusion/RMSNorm is solid
The new `trtllm_mnnvl_allreduce_fusion` wrapper wires all parameters into `AllReduceFusionParams` correctly and adds the right guards:
- token_dim alignment to float4-based packing
- nranks/rank range checks
- mandatory presence of `residual_in`, `residual_out`, `gamma`, and `epsilon` when `rmsnorm_fusion` is true
- shape checks for residual tensors and gamma
This closes the gap where fused kernels could be launched with missing or mismatched residual inputs.
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (1)
32-1219: Lamport-based fusion kernels and dispatch look structurally sound
The header refactor introduces:
- A richer `AllReduceFusionParams` tailored for fusion (buffer pointers, multicast pointer, flags, RMSNorm toggle, PDL toggle).
- A focused `utils` namespace with Lamport buffer layout/flags, packed-vector helpers, warp/block reduction helpers, and `adjustGridConfig`, which uses `GetCudaMultiProcessorCount` to pick reasonable block/cluster/load settings.
- New `oneshotAllreduceFusionKernel` and `twoshotAllreduceKernel` specializations with explicit world-size dispatch (2-64) and an optional fused RMSNorm path via `rmsNormLamport`.

From a structural standpoint, the buffer layout and synchronization logic (Lamport flags, negative-zero sentinels, `waitAndUpdate`/`clearDirtyLamportBuf`, and cluster-aware reductions) are coherent and match established TensorRT-LLM patterns. The error paths for unsupported world sizes and loads-per-thread are clear and return CUDA errors instead of silently misbehaving.
def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
    MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
    comm = config.comm_backend.Split(
        mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
    )
    MnnvlMemory.comm = comm  # type: ignore[assignment]
|
Bug in set_comm_from_config when config is None
set_comm_from_config assigns MnnvlMemory.config = config or MnnvlConfig(...) but still uses the original config variable to access comm_backend. When config is None, this will raise AttributeError.
Suggested fix:
@staticmethod
def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
- MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined]
- comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank)
+ config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined]
+ MnnvlMemory.config = config
+ comm = config.comm_backend.Split(
+ mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
+ )
    MnnvlMemory.comm = comm  # type: ignore[assignment]

🤖 Prompt for AI Agents
In flashinfer/comm/mnnvl.py around lines 293 to 297, set_comm_from_config
assigns MnnvlMemory.config = config or MnnvlConfig(...) but then still uses the
original config variable to access comm_backend, which will raise AttributeError
when config is None; fix by computing an effective_config = config or
MnnvlConfig(comm_backend=MPIBackend()), assign MnnvlMemory.config =
effective_config, then call effective_config.comm_backend.Split(...) and assign
the returned comm to MnnvlMemory.comm so the fallback config is actually used
when config is None.
Actionable comments posted: 0
♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (1)
566-664: Still leaking the local POSIX FD in_alloc_mn_mcast_mem’s IPC path.The POSIX-FD flow now:
- Uses
IpcSocket+send_fd/recv_fdfor allgather/bcast.- Closes imported FDs after
cuMemImportFromShareableHandle.- Closes
_ipc_socketin__del__.However,
local_shareable_uc_handle(the FD returned bycuMemExportToShareableHandle) is never closed in the POSIX path after the ring allgather finishes. Over long runs this will accumulate FDs on each rank.You already close:
all_shareable_uc_handles[p]after import; andshareable_mc_handleafter multicast import.Please also close the local exported FD after the ring completes, e.g.:
if ( self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC ): # All-gather fabric handles all_shareable_uc_handles = self.comm_backend.allgather( local_shareable_uc_handle.data ) else: # Implement the allgather logic with ipc socket all_shareable_uc_handles = [None] * self.group_size for i in range(self.group_size): self.comm_backend.barrier() # Send to peer at offset i dest_rank = (self.group_rank + i) % self.group_size self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) # Receive from peer at offset -i src_rank = (self.group_rank + self.group_size - i) % self.group_size all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() - cuda.cuCtxSynchronize() + cuda.cuCtxSynchronize() + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + os.close(local_shareable_uc_handle)This complements the existing closes for imported FDs and multicast FDs and avoids descriptor leaks.
Also applies to: 893-901, 965-993, 1004-1009, 1058-1063
flashinfer/comm/trtllm_mnnvl_ar.py (1)
344-379: Restore RMSNorm epsilon default to 1e-5 instead of `torch.finfo(...).eps`.
Using `torch.finfo(input.dtype).eps` when `epsilon is None` changes RMSNorm behavior materially, especially for fp16 (eps ~1e-3), and diverges from the CUDA kernel's built-in default of 1e-5 used by TensorRT-LLM. This was previously called out and agreed to be corrected; the current code reintroduces the same behavior.

To keep parity with the C++ kernel and prior APIs, prefer a fixed 1e-5 default:

-    if epsilon is None:
-        epsilon = torch.finfo(input.dtype).eps
+    if epsilon is None:
+        # Match TensorRT-LLM kernel default for RMSNorm
+        epsilon = 1e-5

Also update the docstring so it no longer claims that `torch.finfo.eps` is used by default.
🧹 Nitpick comments (3)
flashinfer/comm/mnnvl.py (2)
132-149: Clarify empty-input behavior and return typing for `alloc_and_copy_to_cuda`.
`alloc_and_copy_to_cuda` is annotated to return `int` but returns `None` when `host_ptr_array` is empty. That mismatch can confuse callers and type-checkers. Consider either:
- Making the return type `Optional[int]` and documenting that `None` means "no allocation performed", or
- Raising on empty input (if that's never expected here), or
- Returning `0` as a sentinel pointer value instead of `None`.
630-645: Clean up unused `recvmsg` outputs in `IpcSocket.recv_fd`.
`msg`, `flags`, and `addr` from `self.sock.recvmsg(...)` are never used, which also triggers Ruff's RUF059 warnings. You can keep the call but acknowledge the unused values to improve clarity and silence the linter:

-    fds = array.array("i")
-    msg, ancdata, flags, addr = self.sock.recvmsg(
+    fds = array.array("i")
+    _msg, ancdata, _flags, _addr = self.sock.recvmsg(
         1,  # Buffer size for dummy data
         socket.CMSG_SPACE(fds.itemsize),  # Ancillary data size
     )

flashinfer/comm/trtllm_mnnvl_ar.py (1)
30-47: Strategy selection and workspace sizing logic are sound; reconsider caching on instance method.
`MNNVLAllreduceFusionStrategy.select_strategy` with `MNNVL_ONE_SHOT_THRESHOLD` and `elem_size` is straightforward and keeps the heuristic in bytes. `get_required_buffer_size_bytes` mirrors the ONESHOT vs TWOSHOT layouts and is already cached as a `@staticmethod`. Given that, `is_buffer_size_sufficient` is a very light wrapper over the static helper but is also decorated with `@functools.cache` on an instance method, which can accumulate entries keyed by `self` (B019 concern) without much benefit.

You can simplify and avoid the method-level cache by dropping `@functools.cache` there:

-    @functools.cache
     def is_buffer_size_sufficient(
         self,
         tp_size: int,
         num_tokens: int,
         hidden_dim: int,
         dtype: torch.dtype,
         strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
     ) -> bool:

`get_required_buffer_size_bytes` will still provide caching for the heavy calculation.

Also applies to: 134-182
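The B019 concern above is concrete: `functools.cache` on an instance method stores `self` in the cache key, so the cache holds a strong reference that keeps the instance alive. A minimal demonstration (the `Workspace` class here is a hypothetical stand-in, not the flashinfer one):

```python
import functools
import weakref

class Workspace:
    @functools.cache  # B019: the cache holds a strong reference to `self`
    def is_buffer_size_sufficient(self, num_tokens: int) -> bool:
        return num_tokens <= 128

ws = Workspace()
ws.is_buffer_size_sufficient(64)  # populates the cache, keyed by (ws, 64)
ref = weakref.ref(ws)
del ws
print(ref() is not None)  # True: the instance survives via the method cache

Workspace.is_buffer_size_sufficient.cache_clear()
print(ref() is None)  # True: clearing the cache finally releases it
```

Caching the underlying `@staticmethod` instead avoids this entirely, since no `self` ever enters the cache key.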
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/comm/mnnvl.py(19 hunks)flashinfer/comm/trtllm_mnnvl_ar.py(5 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/comm/trtllm_mnnvl_ar.py
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/comm/mapping.py (6)
Mapping(21-475)rank(311-312)rank(315-322)tp_rank(325-326)local_rank(391-392)is_multi_node(403-404)flashinfer/jit/comm.py (1)
gen_trtllm_mnnvl_comm_module(33-39)flashinfer/utils.py (2)
register_custom_op(314-323)register_custom_op(333-352)flashinfer/comm/mnnvl.py (9)
McastGPUBuffer(1143-1224)lamport_initialize(1123-1140)lamport_initialize(1183-1184)get_buffer_ptrs_dev(857-859)get_buffer_ptrs_dev(1222-1224)get_unicast_ptr(861-869)get_unicast_ptr(1218-1220)get_multicast_ptr(871-875)get_multicast_ptr(1214-1216)csrc/trtllm_mnnvl_allreduce.cu (2)
trtllm_mnnvl_allreduce_fusion(29-113)trtllm_mnnvl_allreduce_fusion(29-35)
🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
896-896: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
flashinfer/comm/trtllm_mnnvl_ar.py
77-79: Avoid specifying long messages outside the exception class
(TRY003)
134-134: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
298-300: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
319-321: Avoid specifying long messages outside the exception class
(TRY003)
381-383: Avoid specifying long messages outside the exception class
(TRY003)
385-387: Avoid specifying long messages outside the exception class
(TRY003)
389-391: Avoid specifying long messages outside the exception class
(TRY003)
395-397: Avoid specifying long messages outside the exception class
(TRY003)
401-403: Avoid specifying long messages outside the exception class
(TRY003)
414-416: Avoid specifying long messages outside the exception class
(TRY003)
544-546: Avoid specifying long messages outside the exception class
(TRY003)
615-617: Avoid specifying long messages outside the exception class
(TRY003)
621-623: Avoid specifying long messages outside the exception class
(TRY003)
626-628: Avoid specifying long messages outside the exception class
(TRY003)
630-632: Avoid specifying long messages outside the exception class
(TRY003)
635-637: Avoid specifying long messages outside the exception class
(TRY003)
640-642: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
flashinfer/comm/mnnvl.py (1)
717-718: Handle-type selection, workspace getters, and Lamport initialization look consistent.
- `_shareable_handle_type` cleanly distinguishes FABRIC vs POSIX-FD paths and is threaded through allocation/multicast props.
- New `get_allocation_size`/`get_usable_buffer_size` and the update of `McastGPUBuffer.buf_size` to the usable portion (excluding signal pad) make the Python-facing API less surprising.
- `lamport_initialize` now explicitly excludes the signal pad region, and using FP32 there matches the earlier rationale about word-aligned LDG.128 sentinels.
- `get_multicast_buffer`/`get_unicast_buffer` remaining as `NotImplementedError` stubs is a reasonable placeholder until you wire them via `create_tensor_from_cuda_memory`; callers will fail fast instead of silently misusing raw pointers.

Also applies to: 744-766, 781-790, 885-892, 1135-1137, 1164-1181, 1186-1213, 1218-1221
flashinfer/comm/trtllm_mnnvl_ar.py (4)
266-341: New `trtllm_mnnvl_allreduce` workspace size check and validation look good.

- 2D shape validation for `input`/`output` is explicit and user-friendly.
- Strategy selection falls back to `AUTO` and uses the same `select_strategy` heuristic as the workspace utilities.
- Calling `workspace.is_buffer_size_sufficient(...)` and raising with a detailed message (including current vs required bytes) restores the missing guard against buffer overrun that was raised in earlier reviews.

This aligns the Python API's safety guarantees with the underlying kernel expectations.
438-490: Deprecation wrapper for `get_allreduce_mnnvl_workspace` preserves behavior.

Redirecting the legacy workspace helper to `MNNVLAllreduceFusionWorkspace` keeps the public return type (`McastGPUBuffer`, flags tensor, `max_num_elements`) while converging all allocation through the new path and flag layout.

The use of `stride` and `workspace.buffer_size_bytes // stride` to compute `max_num_elements` is consistent with the original intent.
575-662: Legacy fused RMSNorm wrapper looks consistent with the new fused kernel.
- Shape checks for `shard_input`, `residual`, `gamma`, and both outputs are thorough and mirror the new high-level API.
- The `buffer_M` check preserves the legacy guard against oversized inputs.
- Calling `trtllm_mnnvl_allreduce_fusion` with `rmsnorm_fusion=True`, `use_oneshot=False`, and wiring `normed_output`/`prenorm_output`/`residual`/`gamma`/`epsilon` matches the intended "two-shot + RMSNorm" behavior.

Given the function is now deprecated in favor of `trtllm_mnnvl_fused_allreduce_add_rmsnorm`, this looks like a safe compatibility bridge.
498-569: Unfortunately, the repository clone failed, which prevents me from directly inspecting the CUDA kernel implementation to verify the core concern about `buffer_ptr_local` usage.

The review comment raises a valid architectural concern about the contract between the deprecated Python wrapper and the underlying CUDA kernel:

- The wrapper passes `buffer_ptr_local=0` with an assumption that the kernel "does not use this local pointer" when `rmsnorm_fusion=false`.
- Without access to the CUDA kernel source (`csrc/trtllm_mnnvl_allreduce.cu`), I cannot definitively verify whether this assumption is correct.
- If the assumption is wrong, this represents a latent bug; if correct, the concern is more about defensive programming.

Since I cannot verify the technical facts needed to assess this concern (i.e., whether `buffer_ptr_local` is genuinely unused in the non-RMSNorm code path), I must preserve the original review comment as-is.
Legacy `trtllm_mnnvl_all_reduce` wrapper relies on `buffer_ptr_local = 0`.

The deprecation wrapper now forwards to `trtllm_mnnvl_allreduce_fusion` and passes `buffer_ptr_local=0` with a comment that the allreduce kernel "does not use this local pointer".

That contract is subtle and brittle: if the CUDA kernel ever starts using `buffer_ptr_local` even in the non-RMSNorm path, this wrapper could cause a hard-to-debug crash.

If possible, consider:

- Threading through the actual local unicast pointer (e.g., via a workspace or by evolving callers), or
- Adding a clear assertion / comment on the CUDA side that `buffer_ptr_local` is ignored when `rmsnorm_fusion == false`, to document the invariant this wrapper depends on.

[warn-only as it's legacy and deprecated, but worth making the dependency explicit.]
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
406-425: Potential deadlock: `allgather` called only by failing ranks.

The `MPI.COMM_WORLD.allgather(rank_failed)` call is inside the `except` block, so only ranks that throw an exception will participate. If rank 0 fails while rank 1 succeeds, rank 1 continues to the barrier at line 433 while rank 0 blocks on `allgather`, causing a distributed deadlock.

Consider moving failure synchronization outside the exception handler:

```diff
     except Exception as e:
         rank_failed = True
         failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
         print(failure_message)
         print(traceback.format_exc())
-        # Gather failure status from all ranks for logging
-        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
-
-        if any(all_failures):
-            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
-            if rank == 0:
-                print(f"Test failed on ranks: {failed_ranks}")
-
         # Cleanup before re-raising
         if "workspace" in locals():
             del workspace
         # Re-raise the original exception so it can be caught by pytest.raises in negative tests
         raise
     finally:
         # Ensure cleanup happens for this list's workspace
         if "workspace" in locals():
             del workspace
+
+        # Gather failure status from all ranks for logging (must be outside except to avoid deadlock)
+        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
+        if any(all_failures):
+            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
+            if rank == 0:
+                print(f"Test failed on ranks: {failed_ranks}")
```

Note: With this change, the `raise` at line 425 would need to be moved after the `finally` block completes, or handled differently to ensure all ranks synchronize before any rank exits.
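The hazard is easy to reproduce without MPI. Below is a minimal single-process sketch that uses `threading.Barrier` as a stand-in for the collective `allgather` (all names and the two-thread setup are illustrative, not taken from the test): the collective participation happens in `finally`, so both the failing and the succeeding "rank" reach it and neither hangs.

```python
import threading

# A Barrier stands in for MPI allgather: it only completes when
# every participant ("rank") calls wait().
barrier = threading.Barrier(2, timeout=2)
results = {}

def run_rank(rank, should_fail):
    rank_failed = False
    try:
        if should_fail:
            raise RuntimeError(f"rank {rank} kernel failure")
    except RuntimeError:
        rank_failed = True
        # Do NOT call the collective here: only failing ranks would reach it,
        # and the healthy ranks would never arrive -- a distributed deadlock.
    finally:
        # Collective participation happens on every code path instead.
        barrier.wait()
        results[rank] = rank_failed

threads = [
    threading.Thread(target=run_rank, args=(0, True)),
    threading.Thread(target=run_rank, args=(1, False)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # both ranks recorded their status; neither deadlocked
```

If the `barrier.wait()` were moved into the `except` block, the non-failing thread would never call it and the failing thread would raise `BrokenBarrierError` on timeout, which is the single-process analogue of the MPI hang described above.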
♻️ Duplicate comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
257-264: Epsilon mismatch between reference computation and kernel execution.

The reference RMSNorm uses `torch.finfo(dtype).eps` (line 263), which is ~9.8e-4 for float16 and ~7.8e-3 for bfloat16, while the actual kernel execution uses `eps = 1e-5` (line 322). This roughly 100-800x difference in epsilon values will produce different numerical results between reference and actual outputs.

Additionally, line 257 has a type annotation issue: `Tuple[torch.Tensor, ...]` assigned to `None`.

```diff
-    reference_output: Tuple[torch.Tensor, ...] = None
+    reference_output: Optional[Tuple[torch.Tensor, ...]] = None
     if fusion:
         # Fused case: AllReduce + Residual Add + RMS Norm
         allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
         residual_out = allreduce_result + residual  # Add residual
         norm_out = rmsnorm(
-            residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False
+            residual_out, norm_weight, 1e-5, enable_pdl=False
         )
```
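To see why the epsilon matters numerically, here is a small standalone sketch (plain Python, no torch; `rmsnorm_ref` and the sample values are illustrative, not taken from the test). With small-magnitude activations, swapping the float16 machine epsilon for the kernel's 1e-5 changes the output by tens of percent:

```python
import math

def rmsnorm_ref(x, gamma, eps):
    # RMSNorm: x_i * gamma_i / sqrt(mean(x^2) + eps)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * g * inv_rms for v, g in zip(x, gamma)]

x = [0.01, -0.02, 0.015, -0.005]  # small activations: eps dominates mean(x^2)
gamma = [1.0, 1.0, 1.0, 1.0]

out_fp16_eps = rmsnorm_ref(x, gamma, 9.77e-4)  # ~torch.finfo(float16).eps
out_kernel_eps = rmsnorm_ref(x, gamma, 1e-5)   # the kernel's epsilon

max_rel_diff = max(
    abs(a - b) / abs(b) for a, b in zip(out_fp16_eps, out_kernel_eps)
)
print(f"max relative difference: {max_rel_diff:.3%}")
```

With activations on the scale of typical normalized hidden states, the discrepancy shrinks, but it never vanishes, which is why reference and kernel should share the same epsilon.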
🧹 Nitpick comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
60-69: Remove unnecessary tensor allocation.

Line 61 allocates `output = torch.empty_like(input)` but it's immediately overwritten by the return value of `trtllm_mnnvl_allreduce` on line 63. This wastes GPU memory unnecessarily.

```diff
     else:
-        output = torch.empty_like(input)
-
         output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce(
             input,
             workspace,
             launch_with_pdl=use_pdl,
             strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO,
         )
         return (output.view(shape),)
```
274-276: Unused `monkeypatch` parameter.

The `monkeypatch` fixture is passed but never used in the function body. If it's not needed, consider removing it from the signature and test function parameters.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mapping.py (2)
- `Mapping` (21-475)
- `tp_rank` (325-326)

flashinfer/comm/trtllm_mnnvl_ar.py (5)
- `MNNVLAllreduceFusionWorkspace` (50-181)
- `MNNVLAllreduceFusionStrategy` (30-43)
- `trtllm_mnnvl_allreduce` (266-341)
- `get_allreduce_mnnvl_workspace` (442-495)
- `get_required_buffer_size_bytes` (156-181)

flashinfer/comm/mnnvl.py (2)
- `barrier` (168-168)
- `barrier` (227-228)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)
105-124: LGTM! The legacy API test helper correctly passes buffer pointers and uses MPI barrier for synchronization.
439-465: LGTM! Good test coverage with comprehensive parameterization across sequence lengths, fusion modes, data types, and hidden sizes. The separation between refactored and legacy API tests is well-structured.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/mnnvl.py (1)
132-149: Clarify empty-input behavior in `alloc_and_copy_to_cuda`.

`alloc_and_copy_to_cuda` is annotated to return `int` but returns `None` when `host_ptr_array` is empty (Line 136). That mismatch can surface as a runtime bug if an empty list ever slips through.

Consider failing fast instead of returning `None`, e.g.:

```diff
 def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int:
-    """
-    A helper function that allocates memory on cuda and copies the data from the host to the device.
-    """
-    if not host_ptr_array:
-        return None
+    """
+    A helper function that allocates memory on cuda and copies the data from the host to the device.
+    """
+    if not host_ptr_array:
+        raise ValueError("host_ptr_array must be non-empty")
```

This keeps the return type consistent and surfaces misuse early.
♻️ Duplicate comments (2)
flashinfer/comm/mnnvl.py (2)
300-305: Fix `set_comm_from_config` when `config` is `None`.

When `config` is `None`, you assign the fallback to `MnnvlMemory.config` but still call `config.comm_backend.Split(...)` (Line 302), which will raise `AttributeError`.

Use the effective config for both the assignment and the `Split` call:

```diff
 @staticmethod
 def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
-    MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
-    comm = config.comm_backend.Split(
-        mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
-    )
+    config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
+    MnnvlMemory.config = config
+    comm = config.comm_backend.Split(
+        mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+        mapping.tp_rank,
+    )
     MnnvlMemory.comm = comm  # type: ignore[assignment]
```

This matches the earlier suggested fix and ensures the fallback path actually works.

```bash
#!/bin/bash
# Simple sanity check: search for other direct uses of bare `config` in this method.
rg -n "set_comm_from_config" flashinfer/comm/mnnvl.py -C5
```
720-759: Close local POSIX shareable handles to avoid slow FD leaks.

The new exchanger abstraction correctly closes imported POSIX FDs via `PosixFDHandleExchanger.cleanup()` (Lines 760–762), but two categories of FDs still remain unclosed in the POSIX path:

- `local_shareable_uc_handle` produced by `cuMemExportToShareableHandle` (Lines 1065–1072) is never passed through `cleanup`.
- For the multicast handle, rank 0's `shareable_mc_handle` (Lines 1117–1125) is broadcast but only non-root ranks call `cleanup` after import (Lines 1135–1141).

Over long-running runs that create/destroy `McastDeviceMemory` repeatedly, these will accumulate OS FDs.

You can reuse the exchanger cleanup hook to fix both without special-casing for handle type:

```diff
 # All-gather shareable handles
 all_shareable_uc_handles = self._exchanger.allgather(local_shareable_uc_handle)
 cuda.cuCtxSynchronize()

 # Import remote handles
 for p in range(self.group_size):
     if p != self.group_rank:
         self.uc_handles[p] = checkCudaErrors(
             cuda.cuMemImportFromShareableHandle(
                 all_shareable_uc_handles[p],
                 self._exchanger.handle_type,
             )
         )
         self._exchanger.cleanup(all_shareable_uc_handles[p])
+
+# We no longer need our own exported shareable handle.
+self._exchanger.cleanup(local_shareable_uc_handle)
```

And in `_setup_multicast`:

```diff
 # Broadcast multicast handle from rank 0
 shareable_mc_handle = self._exchanger.broadcast(shareable_mc_handle, root=0)
 cuda.cuCtxSynchronize()

 # Import multicast handle for non-root ranks
 if self.group_rank != 0:
     self.mc_handle = checkCudaErrors(
         cuda.cuMemImportFromShareableHandle(
             shareable_mc_handle,
             self._exchanger.handle_type,
         )
     )
     self._exchanger.cleanup(shareable_mc_handle)
+else:
+    # Root rank can now drop its exported handle (POSIX FD path).
+    self._exchanger.cleanup(shareable_mc_handle)
```

`FabricHandleExchanger.cleanup()` is a no-op, so these calls are harmless in the fabric path and fix the leak in the POSIX-FD path.

```bash
#!/bin/bash
# Check remaining uses of cuMemExportToShareableHandle and ensure every handle is cleaned up.
rg -n "cuMemExportToShareableHandle" flashinfer/comm/mnnvl.py -C5
```

Also applies to: 760-765, 1065-1076, 1079-1088, 1129-1142
🧹 Nitpick comments (3)
flashinfer/comm/mnnvl.py (3)
566-664: Tidy `IpcSocket.recv_fd` unused variables and document `/tmp` path usage.

`recv_fd` currently binds `msg`, `flags`, and `addr` but never uses them (Line 640), which Ruff flags (RUF059). You can keep the signature while silencing the warning:

```diff
 fds = array.array("i")
-msg, ancdata, flags, addr = self.sock.recvmsg(
+_msg, ancdata, _flags, _addr = self.sock.recvmsg(
     1,  # Buffer size for dummy data
     socket.CMSG_SPACE(fds.itemsize),  # Ancillary data size
 )
```

On the S108 `/tmp/mcastmem-socket-` warning: since you default to `use_abstract=True`, the filesystem path is never actually used in normal operation. If you expect `use_abstract=False` in multi-tenant environments, consider adding a brief comment noting that this path is intended for trusted deployments only.
1003-1013: Consider narrowing the broad `Exception` catch in `_verify_cuda_context`.

`_verify_cuda_context` currently catches a bare `Exception` (Lines 1005–1012). Given this is only used for diagnostics, that's acceptable, but narrowing to CUDA-related exceptions (or at least logging the exception type) would align better with BLE001 and avoid swallowing unrelated programming errors.

Not strictly required, but worth considering:

```diff
-except Exception as e:
-    print(f"Error checking CUDA context: {e}")
+except Exception as e:
+    # Broad catch is intentional: any CUDA context error should only warn here.
+    print(f"Error checking CUDA context: {type(e).__name__}: {e}")
```
1242-1268: Explicitly mark `get_multicast_buffer`/`get_unicast_buffer` as internal or implement via the new helper.

Both `get_multicast_buffer` and `get_unicast_buffer` currently raise `NotImplementedError` (Lines 1257–1258 and 1267–1268). Given prior discussion that this class is internal, that's acceptable, but the public-sounding names can confuse users if they discover them.

Two options:

- If they are truly internal placeholders, consider prefixing with `_` or clarifying in the docstring that they are not expected to be called yet.
- If you want them usable, wire them up through `create_tensor_from_cuda_memory` at the top of the file, e.g., using `self.get_multicast_ptr()`/`self.get_unicast_ptr(rank)` plus the desired shape/dtype.

Happy to sketch an implementation using `create_tensor_from_cuda_memory` if you plan to expose these in this PR.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/comm/mnnvl.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/mnnvl.py (2)
flashinfer/cuda_utils.py (1)
- `checkCudaErrors` (51-61)

flashinfer/utils.py (1)
- `round_up` (631-633)
🪛 Ruff (0.14.6)
flashinfer/comm/mnnvl.py
587-587: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
612-612: Probable insecure usage of temporary file or directory: "/tmp/mcastmem-socket-"
(S108)
640-640: Unpacked variable msg is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable flags is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
640-640: Unpacked variable addr is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
656-656: Avoid specifying long messages outside the exception class
(TRY003)
726-726: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
1011-1011: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/comm/mnnvl.py (2)
878-887: Allocation sizing and usable buffer exposure look consistent
`get_allocation_size` and `get_usable_buffer_size` (Lines 982–989) align with `lamport_initialize`, which uses `allocation_size - SIGNAL_PAD_SIZE` (Lines 1191–1193), and `McastGPUBuffer` now surfaces this usable size via `self.buf_size` (Lines 1235–1237).

This keeps the Python view of capacity in sync with what Lamport initialization actually touches, which is a nice improvement over passing the raw requested size around.
Also applies to: 982-989, 1191-1193, 1235-1237
164-166: I'm unable to access the repository due to persistent cloning issues. However, based on the information provided in the review comment itself, I can provide analysis.

The review comment acknowledges uncertainty with the phrase "if any exist in this repo", which suggests:

- The reviewer was aware that there may not be other `CommBackend` implementations.
- The reviewer already identified `MPIBackend` as implementing `bcast` (lines 224-226).
- The concern is conditional on whether other subclasses exist.

Without direct repository access to verify all `CommBackend` subclasses and their implementations, I cannot definitively confirm or refute the original review comment's concern.

Recommended next steps:

- Manually search the repository for classes inheriting from `CommBackend`.
- Verify each implementation includes the `bcast` method.
- Or provide confirmation if only `MPIBackend` exists as a `CommBackend` subclass.

Since I cannot complete the verification due to technical limitations:

New `bcast` abstract requires all `CommBackend` implementations to be updated.

Adding `CommBackend.bcast(...)` (Lines 164–166) and implementing it in `MPIBackend` (Lines 224–226) is correct, but any other `CommBackend` subclasses must also implement `bcast` to avoid `TypeError` at instantiation. Verify that all subclasses (if any exist) have been updated accordingly.
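The failure mode being flagged, that a subclass missing the new abstract method breaks at instantiation rather than at call time, can be illustrated with a small sketch. `CommBackendSketch` and `StaleBackend` are hypothetical stand-ins, not the actual flashinfer classes:

```python
from abc import ABC, abstractmethod

class CommBackendSketch(ABC):
    @abstractmethod
    def allgather(self, obj): ...

    @abstractmethod
    def bcast(self, obj, root=0): ...  # newly added abstract method

# A subclass written before `bcast` existed: it implements the old surface only.
class StaleBackend(CommBackendSketch):
    def allgather(self, obj):
        return [obj]

err = None
try:
    StaleBackend()  # fails here, at instantiation, not when bcast is called
except TypeError as e:
    err = e

print(err)
```

This is why adding an abstract method is a breaking change for every subclass, even ones that never call `bcast`.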
📌 Description
This PR porting all changes in TensorRT-LLM#8018 into Flashinfer.
Apart from the changes mentioned in the original PR, this PR also introduces new API interfaces, `trtllm_mnnvl_allreduce` and `trtllm_mnnvl_fused_allreduce_add_rmsnorm`, to replace the original ones. The workspace allocation is wrapped in a single class with a given buffer size, so the user does not need to worry about the internal details.

This PR adds support for an IPC-socket-based mcast device memory bootstrap so that it can run on DGX machines that do not support fabric handles.
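The FD-based bootstrap ultimately rests on passing file descriptors over a UNIX domain socket via `SCM_RIGHTS` ancillary data. Here is a minimal standard-library sketch of that mechanism (the `socketpair` setup and the pipe standing in for a CUDA shareable handle are illustrative, not the PR's actual `IpcSocket` API):

```python
import array
import os
import socket

# A connected pair of UNIX sockets; the real bootstrap would use a named
# (or abstract-namespace) socket between ranks instead.
sender, receiver = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)

# Stand-in for an exported CUDA shareable handle: any OS file descriptor
# works, so we use the write end of a pipe.
r_fd, w_fd = os.pipe()

# Sender: attach the FD as SCM_RIGHTS ancillary data alongside one dummy byte.
fds = array.array("i", [w_fd])
sender.sendmsg([b"x"], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])

# Receiver: unpack the kernel-duplicated FD from the ancillary data.
recv_fds = array.array("i")
_msg, ancdata, _flags, _addr = receiver.recvmsg(
    1, socket.CMSG_SPACE(recv_fds.itemsize)
)
for level, ctype, data in ancdata:
    if level == socket.SOL_SOCKET and ctype == socket.SCM_RIGHTS:
        recv_fds.frombytes(data[: recv_fds.itemsize])

imported_fd = recv_fds[0]
os.write(imported_fd, b"hello")  # write through the imported descriptor
out = os.read(r_fd, 5)
print(out)  # b'hello'
```

The kernel duplicates the descriptor into the receiving process, which is why the "imported" FD keeps working even after the sender closes its copy; it is also why each imported FD must eventually be closed, or they accumulate.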
@wenscarl This PR also incorporates the changes made in #2056 and should be able to replace that PR. A `bcast` interface is added to the comm backend, as it is needed during the handle transfer.
The old API is tagged as deprecated and redirected to the new APIs. The user of the old API should not need to make any changes.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
Summary by CodeRabbit
New Features
Improvements
Tests
✏️ Tip: You can customize this high-level summary in your review settings.