Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/llm-api/quickstart_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ def add_llm_args(parser):
default='bfloat16',
choices=['auto', 'float16', 'bfloat16', 'float32'],
help='Data type for Mamba SSM cache.')
parser.add_argument(
'--mamba_ssm_stochastic_rounding',
default=False,
action='store_true',
help=
'Enable stochastic rounding for Mamba SSM state updates (fp16 only, FlashInfer limitation).'
)
parser.add_argument(
'--mamba_ssm_philox_rounds',
type=int,
default=10,
help=
'Number of Philox rounds for stochastic rounding PRNG (default: 10). Higher values give better randomness.'
)
parser.add_argument('--log_kv_cache_events',
default=False,
action='store_true')
Expand Down Expand Up @@ -222,6 +236,8 @@ def setup_llm(args, **kwargs):
tokens_per_block=args.tokens_per_block,
use_kv_cache_manager_v2=args.use_kv_cache_manager_v2,
mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
mamba_ssm_stochastic_rounding=args.mamba_ssm_stochastic_rounding,
mamba_ssm_philox_rounds=args.mamba_ssm_philox_rounds,
event_buffer_max_size=1024 if args.log_kv_cache_events else 0)

spec_decode_algo = args.spec_decode_algo.upper(
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ ordered-set
peft
patchelf
einops
flashinfer-python==0.6.4
flashinfer-python @ https://github.com/flashinfer-ai/flashinfer/releases/download/nightly-v0.6.5-20260308/flashinfer_python-0.6.5.dev20260308-py3-none-any.whl
opencv-python-headless
xgrammar==0.1.25
llguidance==0.7.29
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ def prepare(self, attn_metadata: AttentionMetadata):
initial_states = [
num_cached_tokens_per_seq[i] > 0 for i in range(num_contexts)
]
self.has_initial_states[:num_contexts] = torch.tensor(
initial_states, dtype=torch.bool)
self.use_initial_states = any(initial_states)
if self.use_initial_states:
self.has_initial_states[:num_contexts] = torch.tensor(
initial_states, dtype=torch.bool)
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton(
self.cu_seqlens[:num_contexts + 1], self.chunk_size)
else:
Expand Down
94 changes: 64 additions & 30 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,31 @@ def __init__(
# Choose between flashinfer and native implementation. (default to flashinfer)
self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
supported_head_dim_in_flashinfer = [64, 128]
if head_dim in supported_head_dim_in_flashinfer:
logger.info_once(
"Using flashinfer for selective state update for no MTP",
key="selective_state_update_no_mtp")
self._use_flashinfer = head_dim in supported_head_dim_in_flashinfer
# Stochastic rounding requires FlashInfer and fp16 cache
self._use_stochastic_rounding = (
config.quant_config.mamba_ssm_stochastic_rounding
and self._use_flashinfer
and self._mamba_ssm_cache_dtype == torch.float16)
self._philox_rounds = config.quant_config.mamba_ssm_philox_rounds

if self._use_flashinfer:
logger.info_once("Using flashinfer for selective state update",
key="selective_state_update")
self.selective_state_update_func_no_mtp = selective_state_update_fi
self.selective_state_update_func_mtp = selective_state_update_fi
else:
logger.info_once(
"Using native for selective state update for no MTP",
key="selective_state_update_no_mtp")
logger.info_once("Using native for selective state update",
key="selective_state_update")
self.selective_state_update_func_no_mtp = selective_state_update_native
# TODO: support MTP selective state update in flashinfer.
logger.info_once("Using native for selective state update for MTP",
key="selective_state_update_mtp")
self.selective_state_update_func_mtp = selective_state_update_native
self.selective_state_update_func_mtp = selective_state_update_native

# Warn if stochastic rounding was requested but couldn't be enabled
if config.quant_config.mamba_ssm_stochastic_rounding and not self._use_stochastic_rounding:
logger.warning_once(
f"Stochastic rounding requires FlashInfer and float16 SSM cache, "
f"but got head_dim={head_dim}, dtype={self._mamba_ssm_cache_dtype}. Disabled.",
key="stochastic_rounding_disabled")

# D
self.D = nn.Parameter(
Expand Down Expand Up @@ -409,6 +420,30 @@ def forward(
D = repeat(self.D, "h -> h p", p=self.head_dim)
if is_target_verify:
intermediate_ssm_states = layer_cache.intermediate_ssm
# Build kwargs for MTP selective_state_update
mtp_kwargs = dict(
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_d[:num_decodes],
out=preallocated_ssm_out_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
disable_state_update=True,
intermediate_states_buffer=intermediate_ssm_states,
cache_steps=draft_token_num,
intermediate_state_indices=self.intermediate_state_indices,
)
if self._use_stochastic_rounding:
mtp_kwargs['rand_seed'] = torch.randint(0,
2**62, (1, ),
device=x_d.device,
dtype=torch.int64)
mtp_kwargs['philox_rounds'] = self._philox_rounds

self.selective_state_update_func_mtp(
ssm_states,
x_d.view(
Expand All @@ -427,22 +462,26 @@ def forward(
B_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
C_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
D,
**mtp_kwargs,
)
else:
# Build kwargs for selective_state_update
ssu_kwargs = dict(
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_d[:num_decodes],
out=preallocated_ssm_out_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
disable_state_update=True,
intermediate_states_buffer=intermediate_ssm_states,
cache_steps=draft_token_num,
intermediate_state_indices=self.intermediate_state_indices,
dt_softplus=self.delta_softplus,
state_batch_indices=state_indices_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
else:

if self._use_stochastic_rounding:
ssu_kwargs['rand_seed'] = torch.randint(0,
2**62, (1, ),
device=x_d.device,
dtype=torch.int64)
ssu_kwargs['philox_rounds'] = self._philox_rounds

self.selective_state_update_func_no_mtp(
ssm_states,
x_d,
Expand All @@ -451,12 +490,7 @@ def forward(
B_d,
C_d,
D,
z=None,
dt_bias=dt_bias,
dt_softplus=self.delta_softplus,
state_batch_indices=state_indices_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
**ssu_kwargs,
)

# norm
Expand Down
13 changes: 10 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@
_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto")


def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig,
mamba_ssm_cache_dtype: str) -> None:
def validate_and_set_mamba_ssm_cache_dtype(
config: ModelConfig,
mamba_ssm_cache_dtype: str,
mamba_ssm_stochastic_rounding: bool = False,
mamba_ssm_philox_rounds: int = 10) -> None:
if mamba_ssm_cache_dtype == "auto":
hf_dtype = getattr(config.pretrained_config, "mamba_ssm_cache_dtype",
None)
Expand All @@ -47,6 +50,8 @@ def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig,
mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype)

config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
config.quant_config.mamba_ssm_stochastic_rounding = mamba_ssm_stochastic_rounding
config.quant_config.mamba_ssm_philox_rounds = mamba_ssm_philox_rounds


def validate_and_set_kv_cache_quant(model_config: ModelConfig,
Expand Down Expand Up @@ -421,7 +426,9 @@ def _load_and_validate_config(
validate_and_set_kv_cache_quant(config,
self.llm_args.kv_cache_config.dtype)
validate_and_set_mamba_ssm_cache_dtype(
config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype)
config, self.llm_args.kv_cache_config.mamba_ssm_cache_dtype,
self.llm_args.kv_cache_config.mamba_ssm_stochastic_rounding,
self.llm_args.kv_cache_config.mamba_ssm_philox_rounds)

# Allow overriding the number of layers via environment variable
# Note: This is kept for backward compatibility, but model_kwargs is preferred
Expand Down
16 changes: 16 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,22 @@ class KvCacheConfig(StrictBaseModel, PybindMirror):
"The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config."
)

# This is a pure python field, not a pybind field. It is only for the Pytorch backend.
mamba_ssm_stochastic_rounding: bool = Field(
default=False,
description=
"Enable stochastic rounding for Mamba SSM state updates. Only applicable with float16 cache dtype."
)

# This is a pure python field, not a pybind field. It is only for the Pytorch backend.
mamba_ssm_philox_rounds: int = Field(
default=10,
ge=1,
description=
"Number of Philox rounds for stochastic rounding PRNG. Higher values give better randomness "
"but increase compute cost. Only used when mamba_ssm_stochastic_rounding is enabled."
)

tokens_per_block: int = Field(default=32,
description="The number of tokens per block.")

Expand Down
11 changes: 11 additions & 0 deletions tensorrt_llm/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ class QuantConfig(StrictBaseModel):
description="Module name patterns that are skipped in quantization.")
mamba_ssm_cache_dtype: Optional[str] = Field(
default=None, description="Data type for mamba SSM cache.")
mamba_ssm_stochastic_rounding: bool = Field(
default=False,
description=
"Enable stochastic rounding for Mamba SSM state updates. Requires fp16 cache."
)
mamba_ssm_philox_rounds: int = Field(
default=10,
ge=1,
description=
"Number of Philox rounds for stochastic rounding PRNG. Higher values give better randomness."
)

@cached_property
def quant_mode(self) -> QuantModeWrapper:
Expand Down
Loading