1 change: 1 addition & 0 deletions src/MaxText/layers/attention_op.py
@@ -852,6 +852,7 @@ def apply_attention(
or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
or (self.attention_kernel == "autoselected" and length < 128)
or (self.attention_kernel == "paged")
or (self.attention_kernel == "vllm_rpa")
):
return self.apply_attention_dot(
query,
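For context: with this one-line change, a model configured for `vllm_rpa` that reaches `apply_attention` without vLLM metadata (training, eval, plain unit tests) simply falls back to dot-product attention; the ragged-paged kernel itself is only invoked from the `Attention.__call__` path added in attentions.py below. A rough sketch of the resulting routing, where `has_vllm_metadata` is an illustrative placeholder for the check done in `__call__` (metadata provided and `model_mode != train`), not MaxText code:

```python
# Rough sketch (not MaxText code) of how "vllm_rpa" is routed after this change.
def pick_attention_path(attention_kernel: str, has_vllm_metadata: bool) -> str:
  if attention_kernel == "vllm_rpa":
    if has_vllm_metadata:
      return "ragged_paged_attention"  # forward_serve_vllm, added in attentions.py below
    return "dot_product"  # the fallthrough added to apply_attention above
  return "other_kernels"  # flash, paged, cudnn, ... unchanged by this PR
```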
51 changes: 50 additions & 1 deletion src/MaxText/layers/attentions.py
@@ -864,6 +864,42 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
)
return [prefill_kv_cache, ar_kv_cache]

def forward_serve_vllm(
self, query: Array, key: Array, value: Array, rpa_kv_cache: list[Array], rpa_metadata: dict[str, Any]
) -> tuple[list[Array], Array]:
"""Forward function for vLLM serving with RPA attention."""
try:
# pylint: disable=import-outside-toplevel
from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
except ImportError as e:
raise ImportError(
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
) from e

if self.config.attention_sink:
raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")

query = query.reshape(-1, query.shape[2], query.shape[3])
key = key.reshape(-1, key.shape[2], key.shape[3])
value = value.reshape(-1, value.shape[2], value.shape[3])

attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
q_scale, k_scale, v_scale = None, None, None

md = rpa_metadata

output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
query,
key,
value,
rpa_kv_cache,
md.seq_lens,
md.block_tables,
md.query_start_loc,
md.request_distribution,
)
return kv_cache, output

def __call__(
self,
inputs_q: Array,
@@ -878,6 +914,8 @@ def __call__(
slot: Optional[int] = None,
page_state: Optional[page_manager.PageState] = None,
bidirectional_mask: Any = None,
kv_cache: Optional[Array] = None,
attention_metadata: Optional[dict[str, Any]] = None,
):
"""Applies Attention on the input data.

@@ -905,6 +943,8 @@ def __call__(
slot: The batch slot index for paged attention.
page_state: The current state of the paged attention manager.
bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
kv_cache: Optional KV cache input, used when invoking from vLLM.
attention_metadata: Optional attention metadata supplied by vLLM when invoking for serving.

Returns:
output of shape `[batch, length, q_features]`.
@@ -1000,6 +1040,15 @@ def __call__(
query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
)
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out

elif self.config.attention == "vllm_rpa" and attention_metadata and model_mode != MODEL_MODE_TRAIN:
batch, seq_len, num_heads, head_dim = query.shape
updated_kv, attn_out = self.forward_serve_vllm(
query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
)
out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
kv_cache = updated_kv

else:
cached_values = [None, None]
if model_mode != MODEL_MODE_TRAIN:
@@ -1028,4 +1077,4 @@ def __call__(
out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
out = self.out_projection(out, out_sharding=out_sharding)
out = checkpoint_name(out, "out_proj")
return out
return out, kv_cache
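A hedged usage sketch for the new `forward_serve_vllm` hook. The metadata field names (`seq_lens`, `block_tables`, `query_start_loc`, `request_distribution`) are the ones read in the hunk above; the shapes, dtypes, and the `SimpleNamespace` stand-in for vLLM's metadata object are illustrative assumptions, not part of the PR.

```python
# Illustrative only: fabricate the inputs forward_serve_vllm expects.
from types import SimpleNamespace

import jax.numpy as jnp

batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 8, 8, 2, 128

query = jnp.zeros((batch, seq_len, num_q_heads, head_dim), dtype=jnp.bfloat16)
key = jnp.zeros((batch, seq_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
value = jnp.zeros((batch, seq_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)

# Attribute names mirror what forward_serve_vllm reads from rpa_metadata;
# the values are placeholders for two requests of eight tokens each.
rpa_metadata = SimpleNamespace(
    seq_lens=jnp.array([8, 8], dtype=jnp.int32),              # total tokens per request
    block_tables=jnp.zeros((2, 4), dtype=jnp.int32),          # page ids per request
    query_start_loc=jnp.array([0, 8, 16], dtype=jnp.int32),   # cumulative query offsets
    request_distribution=jnp.array([0, 0, 2], dtype=jnp.int32),
)

# With an Attention instance `attn` configured with attention="vllm_rpa" and the page
# buffers owned by vLLM, the call would look like:
# updated_pages, out = attn.forward_serve_vllm(
#     query, key, value, rpa_kv_cache=kv_cache_pages, rpa_metadata=rpa_metadata)
```

Note that the reshapes at the top of `forward_serve_vllm` collapse batch and sequence into a single token axis before handing the tensors to the ragged-paged kernel.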
26 changes: 21 additions & 5 deletions src/MaxText/layers/decoders.py
@@ -87,6 +87,8 @@ def __call__(
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
kv_cache: jax.Array | None = None,
attention_metadata: dict[str, Any] | None = None,
):
cfg = self.config
mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
model_mode=model_mode,
)

attention_lnx = attention_layer(
attention_lnx, kv_cache = attention_layer(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)

if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,7 @@ def __call__(
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)

return layer_output, None if cfg.scan_layers else layer_output
return layer_output, None if cfg.scan_layers else layer_output, kv_cache


class SequentialBlockDecoderLayers(nn.Module):
@@ -690,6 +694,8 @@ def __call__(
bidirectional_mask: None | Any = None,
image_embeddings: None | jnp.ndarray = None,
image_masks: None | jnp.ndarray = None,
kv_caches: list[jax.Array] | None = None,
attention_metadata=None,
):
cfg = self.config
mesh = self.mesh
@@ -843,7 +849,8 @@ def __call__(
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
for index in range(num_layers):
y = layer(
kv_cache = kv_caches[index] if kv_caches is not None else None
y, kv_cache = layer(
config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
)(
y,
@@ -854,7 +861,11 @@ def __call__(
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if kv_caches is not None:
kv_caches[index] = kv_cache
else:
for lyr in range(cfg.num_decoder_layers):
RemattedBlockLayer = RemattedBlockLayers[0]
@@ -876,7 +887,8 @@ def __call__(
layer = RemattedBlockLayer(
config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
)
y = layer(
kv_cache = kv_caches[lyr] if kv_caches is not None else None
y, kv_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
@@ -885,8 +897,12 @@ def __call__(
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
**layer_call_kwargs,
)
if kv_caches is not None:
kv_caches[lyr] = kv_cache

assert isinstance(y, jax.Array)

@@ -903,7 +919,7 @@ def __call__(

# The API of the Decoder is now a tuple, providing both the main output
# and the raw hidden state needed for auxiliary tasks.
return logits, hidden_state
return logits, hidden_state, kv_caches

def _apply_gemma3_scanned_blocks(
self,
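The caller-facing effect of these decoder changes, restated as a sketch: both layer loops thread one cache entry per layer and the Decoder now returns the caches alongside `logits` and `hidden_state`. `run_decoder_layers` below is a hypothetical helper that omits the other positional arguments, not code from this PR.

```python
# Hypothetical helper showing the per-layer cache threading used in both decoder loops.
def run_decoder_layers(layers, y, kv_caches, attention_metadata):
  for index, layer in enumerate(layers):
    kv_cache = kv_caches[index] if kv_caches is not None else None
    y, kv_cache = layer(y, kv_cache=kv_cache, attention_metadata=attention_metadata)
    if kv_caches is not None:
      kv_caches[index] = kv_cache  # hand the updated pages back to the serving runtime
  return y, kv_caches
```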
24 changes: 9 additions & 15 deletions src/MaxText/layers/deepseek.py
@@ -54,9 +54,7 @@ def __init__(
self.quant = quant
self.rngs = rngs

batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(
self.config, self.model_mode
)
batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)

self.pre_self_attention_layer_norm = RMSNorm(
@@ -108,9 +106,7 @@ def __init__(
rngs=rngs,
)

self.dropout = Dropout(
rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs
)
self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

def __call__(
self,
@@ -122,6 +118,8 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
):
x = self.with_logical_constraint(inputs)
x = checkpoint_name(x, "decoder_layer_input")
@@ -141,7 +139,7 @@ def __call__(
layer_output = mlp_lnx + intermediate_inputs
layer_output = self.dropout_op(layer_output, deterministic=deterministic)

return self.post_process(layer_output)
return self.post_process(layer_output, kv_cache)

def mlp_op(self, x, deterministic):
"""Executes the MLP operation. To be implemented by subclasses."""
@@ -151,9 +149,7 @@ def with_logical_constraint(self, x):
return nn.with_logical_constraint(x, self.logical_axis_names)

def dropout_op(self, x, deterministic):
return self.with_logical_constraint(
self.dropout(x, deterministic=deterministic)
)
return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))

def pre_attention_norm_op(self, x):
return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
@@ -201,7 +197,7 @@ def logical_axis_names(self):
"activation_embed",
)

def post_process(self, layer_output):
def post_process(self, layer_output, kv_cache):
"""postprocessing."""
if self.config.record_internal_nn_metrics:
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
Expand All @@ -214,7 +210,7 @@ def post_process(self, layer_output):

if self.config.scan_layers:
return layer_output, None
return layer_output
return layer_output, kv_cache

def self_attention_with_norm_op(
self,
@@ -300,9 +296,7 @@ def __init__(
self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
config=self.config,
mesh=mesh,
kernel_init=initializers.nd_dense_init(
1.0, "fan_in", "truncated_normal"
),
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
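The same return contract is repeated in the Gemma-family layers below. A compact, illustrative restatement (the real `post_process` also sows activation metrics when `record_internal_nn_metrics` is set):

```python
# Illustrative restatement of the updated layer return contract (not MaxText code).
def post_process(layer_output, kv_cache, scan_layers: bool):
  if scan_layers:
    return layer_output, None   # scanned layers keep returning (output, None), unchanged
  return layer_output, kv_cache  # unscanned layers also return the (possibly updated) cache
```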
8 changes: 6 additions & 2 deletions src/MaxText/layers/gemma.py
@@ -129,6 +129,8 @@ def __call__(
page_manager=None,
page_state=None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(

lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

attention_lnx = self.self_attention(
attention_lnx, kv_cache = self.self_attention(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)

attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
if self.config.scan_layers:
return layer_output, None
else:
return layer_output
return layer_output, kv_cache


GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(
8 changes: 6 additions & 2 deletions src/MaxText/layers/gemma2.py
@@ -223,20 +223,24 @@ def __call__(
previous_chunk=None,
page_state=None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = self.pre_self_attention_norm_local(inputs)
lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

attention_lnx = self.self_attention_local(
attention_lnx, kv_cache = self.self_attention_local(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if self.config.use_post_attn_norm:
attention_lnx = self.post_self_attention_norm_local(attention_lnx)
@@ -311,7 +315,7 @@ def __call__(
if self.config.scan_layers:
return layer_output, None
else:
return layer_output
return layer_output, kv_cache


Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(
8 changes: 6 additions & 2 deletions src/MaxText/layers/gemma3.py
@@ -189,6 +189,8 @@ def __call__(
page_state=None,
slot=None,
bidirectional_mask=None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
@@ -198,14 +200,16 @@ def __call__(
lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

# Self-attention block
attention_lnx = self.self_attention(
attention_lnx, kv_cache = self.self_attention(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
bidirectional_mask=bidirectional_mask,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if cfg.use_post_attn_norm:
attention_lnx = self.post_self_attention_norm(attention_lnx)
@@ -240,7 +244,7 @@ def __call__(
if cfg.scan_layers:
return layer_output, None
else:
return layer_output
return layer_output, kv_cache


Gemma3DecoderLayerToLinen = nnx_wrappers.to_linen_class(
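Putting the pieces together, a serving step would hand vLLM's page buffers and metadata to the decoder and collect the updated caches from the new third return value. Only `kv_caches`, `attention_metadata`, and the three-element return come from this PR; every other name below is a placeholder.

```python
# Hedged end-to-end sketch; `decoder_apply` stands for the patched Decoder call and
# `decoder_inputs` for whatever positional arguments the real call site already passes.
def serving_step(decoder_apply, decoder_inputs, kv_cache_pages, rpa_metadata):
  logits, _hidden_state, kv_cache_pages = decoder_apply(
      *decoder_inputs,
      kv_caches=kv_cache_pages,
      attention_metadata=rpa_metadata,
  )
  return logits, kv_cache_pages
```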