
Commit d0c2503

Modifying decoders and attention for vLLM.
Removing calls into specialized attention modules. Adding a vllm_rpa unit test. Fixing additional unit tests. Adding validation support for vllm_rpa. Rebasing deepseek and gpt-oss. Adding a skip for the vllm-tpu test.
1 parent be7c2de commit d0c2503

17 files changed: +312 additions, -76 deletions

src/MaxText/layers/attention_op.py

Lines changed: 1 addition & 0 deletions
@@ -852,6 +852,7 @@ def apply_attention(
         or (self.attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
         or (self.attention_kernel == "autoselected" and length < 128)
         or (self.attention_kernel == "paged")
+        or (self.attention_kernel == "vllm_rpa")
     ):
       return self.apply_attention_dot(
           query,
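Outside of a vLLM serving call (for example during training or in unit tests, when no RPA metadata reaches the layer), the new "vllm_rpa" kernel simply falls through to apply_attention_dot, the same way "paged" does. The sketch below restates only the branches visible in the hunk above as a standalone predicate; the helper name routes_to_dot_attention and the constant value are illustrative, not MaxText code, and the predicate's leading clause (elided by the diff context) is omitted.

MODEL_MODE_AUTOREGRESSIVE = "autoregressive"  # assumed value, for the sketch only

def routes_to_dot_attention(attention_kernel: str, model_mode: str, length: int) -> bool:
  """Mirrors the visible branches of the dispatch predicate in apply_attention."""
  return (
      (attention_kernel == "autoselected" and model_mode == MODEL_MODE_AUTOREGRESSIVE)
      or (attention_kernel == "autoselected" and length < 128)
      or attention_kernel == "paged"
      or attention_kernel == "vllm_rpa"
  )

assert routes_to_dot_attention("vllm_rpa", "train", 2048)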

src/MaxText/layers/attentions.py

Lines changed: 50 additions & 1 deletion
@@ -864,6 +864,42 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
     )
     return [prefill_kv_cache, ar_kv_cache]
 
+  def forward_serve_vllm(
+      self, query: Array, key: Array, value: Array, rpa_kv_cache: list[Array], rpa_metadata: dict[str, Any]
+  ) -> tuple[list[Array], Array]:
+    """Forward function for vLLM serving with RPA attention."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
@@ -878,6 +914,8 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.
@@ -905,6 +943,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional mapping to store attention metadata, used when invoking from vLLM.
 
     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1000,6 +1040,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
+    elif self.config.attention == "vllm_rpa" and attention_metadata and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1028,4 +1077,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache
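The ragged paged attention kernel consumes flattened [total_tokens, num_heads, head_dim] tensors, so forward_serve_vllm collapses batch and sequence into a single token axis and __call__ folds the kernel output back into [batch, seq, heads, head_dim]; the metadata object is expected to expose seq_lens, block_tables, query_start_loc and request_distribution, as in the diff. Below is a shape-only sketch of that round trip with made-up dimensions; the identity assignment stands in for the sharded_ragged_paged_attention call.

import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 8, 4, 128
query = jnp.zeros((batch, seq_len, num_heads, head_dim))

# forward_serve_vllm: collapse batch and sequence into one flat token axis.
flat_q = query.reshape(-1, query.shape[2], query.shape[3])  # (16, 4, 128)

# ... the RPA kernel returns one output row per token; identity stands in here ...
attn_out = flat_q

# __call__: restore the [batch, seq, heads, head_dim] layout before the output projection.
out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
assert out.shape == query.shape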

src/MaxText/layers/decoders.py

Lines changed: 21 additions & 5 deletions
@@ -87,6 +87,8 @@ def __call__(
       previous_chunk=None,
       slot: None | int = None,
       page_state: None | page_manager.PageState = None,
+      kv_cache: jax.Array | None = None,
+      attention_metadata: dict[str, Any] | None = None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -149,13 +151,15 @@ def __call__(
         model_mode=model_mode,
     )
 
-    attention_lnx = attention_layer(
+    attention_lnx, kv_cache = attention_layer(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
 
     if model_mode == MODEL_MODE_PREFILL:
@@ -209,7 +213,7 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
 
-    return layer_output, None if cfg.scan_layers else layer_output
+    return layer_output, None if cfg.scan_layers else layer_output, kv_cache
 
 
 class SequentialBlockDecoderLayers(nn.Module):
@@ -690,6 +694,8 @@ def __call__(
       bidirectional_mask: None | Any = None,
       image_embeddings: None | jnp.ndarray = None,
       image_masks: None | jnp.ndarray = None,
+      kv_caches: list[jax.Array] | None = None,
+      attention_metadata=None,
   ):
     cfg = self.config
     mesh = self.mesh
@@ -843,7 +849,8 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
         for index in range(num_layers):
-          y = layer(
+          kv_cache = kv_caches[index] if kv_caches is not None else None
+          y, kv_cache = layer(
               config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
           )(
               y,
@@ -854,7 +861,11 @@ def __call__(
               previous_chunk=previous_chunk,
               page_state=page_state,
               slot=slot,
+              kv_cache=kv_cache,
+              attention_metadata=attention_metadata,
           )
+          if kv_caches is not None:
+            kv_caches[index] = kv_cache
     else:
       for lyr in range(cfg.num_decoder_layers):
         RemattedBlockLayer = RemattedBlockLayers[0]
@@ -876,7 +887,8 @@ def __call__(
         layer = RemattedBlockLayer(
             config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
         )
-        y = layer(
+        kv_cache = kv_caches[lyr] if kv_caches is not None else None
+        y, kv_cache = layer(
            y,
            decoder_segment_ids,
            decoder_positions,
@@ -885,8 +897,12 @@ def __call__(
            previous_chunk=previous_chunk,
            page_state=page_state,
            slot=slot,
+            kv_cache=kv_cache,
+            attention_metadata=attention_metadata,
            **layer_call_kwargs,
        )
+        if kv_caches is not None:
+          kv_caches[lyr] = kv_cache
 
     assert isinstance(y, jax.Array)
 
@@ -903,7 +919,7 @@ def __call__(
 
     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
-    return logits, hidden_state
+    return logits, hidden_state, kv_caches
 
   def _apply_gemma3_scanned_blocks(
       self,
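At the decoder level the new kv_caches argument is a per-layer list: each iteration reads slot i, hands it to the layer, writes the returned cache back, and the whole list is returned alongside logits and hidden_state. A stripped-down sketch of that threading, with a hypothetical run_decoder_layers helper standing in for the (scan-disabled) loops above:

def run_decoder_layers(y, layers, kv_caches=None, attention_metadata=None):
  """Thread an optional per-layer KV-cache list through a stack of decoder layers."""
  for i, layer in enumerate(layers):
    kv_cache = kv_caches[i] if kv_caches is not None else None
    y, kv_cache = layer(y, kv_cache=kv_cache, attention_metadata=attention_metadata)
    if kv_caches is not None:
      kv_caches[i] = kv_cache
  return y, kv_caches

# Toy usage: identity layers that pass the cache straight through.
identity_layer = lambda y, kv_cache=None, attention_metadata=None: (y, kv_cache)
y, caches = run_decoder_layers([1, 2, 3], [identity_layer] * 2, kv_caches=[None, None])
assert caches == [None, None]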

src/MaxText/layers/deepseek.py

Lines changed: 9 additions & 15 deletions
@@ -54,9 +54,7 @@ def __init__(
     self.quant = quant
     self.rngs = rngs
 
-    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(
-        self.config, self.model_mode
-    )
+    batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
     self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)
 
     self.pre_self_attention_layer_norm = RMSNorm(
@@ -108,9 +106,7 @@ def __init__(
         rngs=rngs,
     )
 
-    self.dropout = Dropout(
-        rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs
-    )
+    self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)
 
   def __call__(
       self,
@@ -122,6 +118,8 @@ def __call__(
       previous_chunk=None,
       page_state: None | page_manager.PageState = None,
       slot: None | int = None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     x = self.with_logical_constraint(inputs)
     x = checkpoint_name(x, "decoder_layer_input")
@@ -141,7 +139,7 @@ def __call__(
     layer_output = mlp_lnx + intermediate_inputs
     layer_output = self.dropout_op(layer_output, deterministic=deterministic)
 
-    return self.post_process(layer_output)
+    return self.post_process(layer_output, kv_cache)
 
   def mlp_op(self, x, deterministic):
     """Executes the MLP operation. To be implemented by subclasses."""
@@ -151,9 +149,7 @@ def with_logical_constraint(self, x):
     return nn.with_logical_constraint(x, self.logical_axis_names)
 
   def dropout_op(self, x, deterministic):
-    return self.with_logical_constraint(
-        self.dropout(x, deterministic=deterministic)
-    )
+    return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))
 
   def pre_attention_norm_op(self, x):
     return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
@@ -201,7 +197,7 @@ def logical_axis_names(self):
         "activation_embed",
     )
 
-  def post_process(self, layer_output):
+  def post_process(self, layer_output, kv_cache):
     """postprocessing."""
     if self.config.record_internal_nn_metrics:
       self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
@@ -214,7 +210,7 @@ def post_process(self, layer_output):
 
     if self.config.scan_layers:
       return layer_output, None
-    return layer_output
+    return layer_output, kv_cache
 
   def self_attention_with_norm_op(
       self,
@@ -300,9 +296,7 @@ def __init__(
     self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
         config=self.config,
         mesh=mesh,
-        kernel_init=initializers.nd_dense_init(
-            1.0, "fan_in", "truncated_normal"
-        ),
+        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
         kernel_axes=("embed", None),
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
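The DeepSeek layers route their result through post_process, which now carries the cache along with the activations: under scan_layers the layer still returns (layer_output, None) so nn.scan keeps working, otherwise the cache travels back to the decoder loop. A minimal sketch of that return contract (a standalone function, not the DeepSeek class):

def post_process(layer_output, kv_cache, scan_layers: bool):
  """(output, carry) under scan; (output, kv_cache) otherwise."""
  if scan_layers:
    return layer_output, None
  return layer_output, kv_cache

assert post_process("out", "cache", scan_layers=False) == ("out", "cache")
assert post_process("out", "cache", scan_layers=True) == ("out", None)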

src/MaxText/layers/gemma.py

Lines changed: 6 additions & 2 deletions
@@ -129,6 +129,8 @@ def __call__(
       page_manager=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
@@ -137,13 +139,15 @@ def __call__(
 
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
 
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
 
     attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
@@ -177,7 +181,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache
 
 
 GemmaDecoderLayerToLinen = nnx_wrappers.to_linen_class(

src/MaxText/layers/gemma2.py

Lines changed: 6 additions & 2 deletions
@@ -223,20 +223,24 @@ def __call__(
       previous_chunk=None,
       page_state=None,
       slot=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
     lnx = self.pre_self_attention_norm_local(inputs)
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
 
-    attention_lnx = self.self_attention_local(
+    attention_lnx, kv_cache = self.self_attention_local(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if self.config.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm_local(attention_lnx)
@@ -311,7 +315,7 @@ def __call__(
     if self.config.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache
 
 
 Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(

src/MaxText/layers/gemma3.py

Lines changed: 6 additions & 2 deletions
@@ -189,6 +189,8 @@ def __call__(
       page_state=None,
       slot=None,
       bidirectional_mask=None,
+      kv_cache=None,
+      attention_metadata=None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
@@ -198,14 +200,16 @@ def __call__(
     lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
 
     # Self-attention block
-    attention_lnx = self.self_attention(
+    attention_lnx, kv_cache = self.self_attention(
         lnx,
         lnx,
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
         model_mode=model_mode,
         bidirectional_mask=bidirectional_mask,
+        kv_cache=kv_cache,
+        attention_metadata=attention_metadata,
     )
     if cfg.use_post_attn_norm:
       attention_lnx = self.post_self_attention_norm(attention_lnx)
@@ -240,7 +244,7 @@ def __call__(
     if cfg.scan_layers:
       return layer_output, None
     else:
-      return layer_output
+      return layer_output, kv_cache
 
 
 Gemma3DecoderLayerToLinen = nnx_wrappers.to_linen_class(
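The reason every decoder layer in this commit changes is that the attention module's __call__ now returns a pair (out, kv_cache) instead of a bare array, so even callers that never hit the vLLM path must unpack two values, as the Gemma-family layers above now do. A toy illustration of that API shift (hypothetical attention_fn, not the MaxText Attention class):

import jax.numpy as jnp

def attention_fn(inputs_q, inputs_kv, kv_cache=None, attention_metadata=None):
  out = inputs_q  # stand-in for the real attention computation
  return out, kv_cache  # pair-valued return, matching the new __call__ contract

lnx = jnp.ones((1, 4, 8))
# Old call sites (attention_lnx = attention_fn(...)) would now bind a tuple by mistake;
# the layers therefore unpack explicitly and forward or discard the cache as needed.
attention_lnx, kv_cache = attention_fn(lnx, lnx)
assert attention_lnx.shape == (1, 4, 8) and kv_cache is None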
