@@ -864,6 +864,42 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous
     )
     return [prefill_kv_cache, ar_kv_cache]
 
+  def forward_serve_vllm(
+      self, query: Array, key: Array, value: Array, rpa_kv_cache: list[Array], rpa_metadata: dict[str, Any]
+  ) -> tuple[list[Array], Array]:
+    """Forward pass for vLLM serving with ragged paged attention (RPA)."""
+    try:
+      # pylint: disable=import-outside-toplevel
+      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+    except ImportError as e:
+      raise ImportError(
+          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
+      ) from e
+
+    if self.config.attention_sink:
+      raise NotImplementedError("Attention sink is not supported in MaxText vLLM RPA attention.")
+
+    query = query.reshape(-1, query.shape[2], query.shape[3])
+    key = key.reshape(-1, key.shape[2], key.shape[3])
+    value = value.reshape(-1, value.shape[2], value.shape[3])
+
+    attention_chunk_size = self.config.chunk_attn_window_size if self.config.chunk_attn_window_size > 0 else None
+    q_scale, k_scale, v_scale = None, None, None
+
+    md = rpa_metadata
+
+    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+        query,
+        key,
+        value,
+        rpa_kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+    )
+    return kv_cache, output
+
   def __call__(
       self,
       inputs_q: Array,
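For orientation, here is a minimal, hypothetical sketch of the token layout `forward_serve_vllm` works with. The metadata field names (`seq_lens`, `block_tables`, `query_start_loc`, `request_distribution`) are the ones the method reads off `rpa_metadata`; the container type, shapes, dtypes, and example values below are illustrative assumptions, not the tpu_inference API.

```python
# Hypothetical layout sketch for the RPA path (assumed shapes/semantics).
from types import SimpleNamespace

import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 4, 8, 128
query = jnp.zeros((batch, seq_len, num_heads, head_dim))

# [batch, seq, heads, head_dim] -> [batch * seq, heads, head_dim]:
# every token of every request is packed into one ragged token dimension.
flat_query = query.reshape(-1, query.shape[2], query.shape[3])
assert flat_query.shape == (batch * seq_len, num_heads, head_dim)

# Toy metadata for two requests of 4 query tokens each; attribute access
# mirrors the `md.<field>` usage in the diff.
rpa_metadata = SimpleNamespace(
    seq_lens=jnp.array([4, 4], dtype=jnp.int32),             # per-request sequence lengths (assumed)
    query_start_loc=jnp.array([0, 4, 8], dtype=jnp.int32),   # cumulative query offsets (assumed)
    block_tables=jnp.zeros((2, 1), dtype=jnp.int32),         # KV-cache page ids per request (assumed)
    request_distribution=jnp.array([0, 0, 2], dtype=jnp.int32),  # scheduler bookkeeping (semantics assumed)
)
```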
@@ -878,6 +914,8 @@ def __call__(
       slot: Optional[int] = None,
       page_state: Optional[page_manager.PageState] = None,
       bidirectional_mask: Any = None,
+      kv_cache: Optional[Array] = None,
+      attention_metadata: Optional[dict[str, Any]] = None,
   ):
     """Applies Attention on the input data.
 
@@ -905,6 +943,8 @@ def __call__(
       slot: The batch slot index for paged attention.
       page_state: The current state of the paged attention manager.
       bidirectional_mask: A mask for bidirectional attention, used in multimodal models.
+      kv_cache: Optional KV cache input, used when invoking from vLLM.
+      attention_metadata: Optional mapping carrying attention metadata, used when invoking from vLLM.
 
     Returns:
       output of shape `[batch, length, q_features]`.
@@ -1000,6 +1040,15 @@ def __call__(
           query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
       )
       out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+
+    elif self.config.attention == "vllm_rpa" and attention_metadata and model_mode != MODEL_MODE_TRAIN:
+      batch, seq_len, num_heads, head_dim = query.shape
+      updated_kv, attn_out = self.forward_serve_vllm(
+          query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
+      )
+      out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
+      kv_cache = updated_kv
+
     else:
       cached_values = [None, None]
       if model_mode != MODEL_MODE_TRAIN:
@@ -1028,4 +1077,4 @@ def __call__(
     out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
     out = self.out_projection(out, out_sharding=out_sharding)
     out = checkpoint_name(out, "out_proj")
-    return out
+    return out, kv_cache
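For context, a self-contained sketch of the new dispatch and return contract in `__call__`. The helper names (`call_sketch`, `serve_fn`, `fallback_fn`) are placeholders, not MaxText code; only the `attention == "vllm_rpa"` condition, the reshape back to `[batch, seq, heads, head_dim]`, and the `(out, kv_cache)` return pair come from the diff.

```python
# Hedged emulation of the new control flow; not the MaxText implementation.
import jax.numpy as jnp

MODEL_MODE_TRAIN = "train"

def call_sketch(attention_kind, model_mode, query, kv_cache, attention_metadata, serve_fn, fallback_fn):
  if attention_kind == "vllm_rpa" and attention_metadata and model_mode != MODEL_MODE_TRAIN:
    batch, seq_len, num_heads, head_dim = query.shape
    updated_kv, attn_out = serve_fn(query, kv_cache, attention_metadata)  # stands in for forward_serve_vllm
    out = attn_out.reshape(batch, seq_len, num_heads, head_dim)
    kv_cache = updated_kv          # thread the updated pages back to the caller
  else:
    out = fallback_fn(query)       # training / built-in kernels: kv_cache passes through untouched
  return out, kv_cache             # __call__ now always returns a pair

# Exercise the vLLM path with trivial stand-ins.
q = jnp.zeros((2, 4, 8, 16))
serve = lambda query, cache, md: (cache, query.reshape(-1, 8, 16))
out, cache = call_sketch("vllm_rpa", "autoregressive", q, kv_cache=["pages"],
                         attention_metadata={"seq_lens": [4, 4]},
                         serve_fn=serve, fallback_fn=lambda x: x)
assert out.shape == q.shape and cache == ["pages"]
```

Existing callers that use the built-in kernels simply unpack the pair and discard the second element, which stays `None` unless a cache was passed in.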