|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""vLLM adapter for MaxText models.""" |
| 16 | + |
| 17 | +import jax |
| 18 | +import pathlib |
| 19 | +import os |
| 20 | +import jax.numpy as jnp |
| 21 | + |
| 22 | +from flax import nnx |
| 23 | +from jax.sharding import Mesh |
| 24 | +from MaxText import model_creation_utils |
| 25 | +from MaxText import pyconfig |
| 26 | +from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE |
| 27 | +from MaxText.globals import MAXTEXT_PKG_DIR |
| 28 | +from MaxText.utils import gcs_utils |
| 29 | + |
| 30 | +from tpu_inference.layers.common.attention_metadata import AttentionMetadata |
| 31 | +from vllm.config import VllmConfig |
| 32 | + |
| 33 | + |
| 34 | +def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters: |
| 35 | + """Generates a MaxText configuration from a vLLM configuration. |
| 36 | +
|
| 37 | + This function takes a vLLM configuration object and translates relevant |
| 38 | + parameters into a MaxText `HyperParameters` object. It handles loading |
| 39 | + paths and model names from the vLLM config, and applies a base MaxText |
| 40 | + vLLM configuration file. |
| 41 | +
|
| 42 | + Args: |
| 43 | + vllm_config: The vLLM configuration object containing model and load |
| 44 | + parameters. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + A `pyconfig.HyperParameters` object configured for MaxText. |
| 48 | +
|
| 49 | + Raises: |
| 50 | + ValueError: If `hf_config_path` is not provided in the vLLM model config. |
| 51 | + """ |
| 52 | + |
| 53 | + def _path_exists(path: str) -> bool: |
| 54 | + if not path: |
| 55 | + return False |
| 56 | + return os.path.exists(path) or gcs_utils.gcs_path_exists(path) |
| 57 | + |
| 58 | + if "maxtext_config" in vllm_config.additional_config: |
| 59 | + overrides = vllm_config.additional_config["maxtext_config"] |
| 60 | + else: |
| 61 | + overrides = {} |
| 62 | + load_path = None |
| 63 | + if _path_exists(vllm_config.load.download_dir): |
| 64 | + load_path = vllm_config.load.download_dir |
| 65 | + elif _path_exists(vllm_config.model.model): |
| 66 | + load_path = vllm_config.model.model |
| 67 | + |
| 68 | + if load_path: |
| 69 | + overrides["load_parameters_path"] = load_path |
| 70 | + elif vllm_config.model.model: |
| 71 | + overrides["model_name"] = vllm_config.model.model |
| 72 | + |
| 73 | + if vllm_config.model_config.hf_config_path is None: |
| 74 | + raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.") |
| 75 | + |
| 76 | + # Add base config path to positional args |
| 77 | + base_config_path = pathlib.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml" |
| 78 | + argv_list = ["", str(base_config_path)] |
| 79 | + |
| 80 | + maxtext_config = pyconfig.initialize(argv_list, **overrides) |
| 81 | + return maxtext_config |
| 82 | + |
| 83 | + |
| 84 | +class MaxTextDecoderModel(nnx.Module): |
| 85 | + """A vLLM-compatible decoder model wrapper for MaxText. |
| 86 | +
|
| 87 | + This class adapts a MaxText model for use within the vLLM framework, |
| 88 | + handling configuration generation, model initialization, and execution |
| 89 | + of the decoding step. |
| 90 | + """ |
| 91 | + |
| 92 | + def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None: |
| 93 | + """Initializes the MaxTextDecoderModel. |
| 94 | +
|
| 95 | + Args: |
| 96 | + vllm_config: The vLLM configuration object. |
| 97 | + rng_key: A JAX random key for model initialization. |
| 98 | + mesh: The JAX mesh device for model sharding. |
| 99 | + """ |
| 100 | + self.vllm_config = vllm_config |
| 101 | + self.maxtext_config = generate_maxtext_config(vllm_config) |
| 102 | + |
| 103 | + # Model configuration |
| 104 | + self.mesh = mesh |
| 105 | + self.model_mode = MODEL_MODE_AUTOREGRESSIVE |
| 106 | + |
| 107 | + # Model creation |
| 108 | + self.model: nnx.Module | None = None |
| 109 | + self.logits: jax.Array | None = None |
| 110 | + |
| 111 | + def __call__( |
| 112 | + self, |
| 113 | + kv_caches: list[jax.Array], |
| 114 | + input_ids: jax.Array, |
| 115 | + attention_metadata: AttentionMetadata, |
| 116 | + *args, |
| 117 | + **kwargs, |
| 118 | + ) -> tuple[list[jax.Array], jax.Array]: |
| 119 | + """Performs a forward pass through the decoder model. |
| 120 | +
|
| 121 | + Args: |
| 122 | + kv_caches: A list of JAX arrays representing the KV caches. |
| 123 | + input_ids: A JAX array of input token IDs. |
| 124 | + attention_metadata: Attention metadata for the decoding process. |
| 125 | + *args: Variable length argument list. |
| 126 | + **kwargs: Arbitrary keyword arguments. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + A tuple containing: |
| 130 | + - updated_kv_caches: A list of updated KV caches. |
| 131 | + - hidden: The hidden states (Q, d_model). |
| 132 | + - aux_hidden_states: A list of auxiliary hidden states. |
| 133 | +
|
| 134 | + Raises: |
| 135 | + ValueError: If the model is not an instance of `nnx.Module`. |
| 136 | + """ |
| 137 | + if not isinstance(self.model, nnx.Module): |
| 138 | + raise ValueError("Model must be an instance of type nnx.Module.") |
| 139 | + |
| 140 | + if input_ids.ndim < 2: |
| 141 | + input_ids = jnp.expand_dims(input_ids, axis=0) |
| 142 | + |
| 143 | + input_positions = attention_metadata.input_positions |
| 144 | + if input_positions.ndim < 2: |
| 145 | + input_positions = jnp.expand_dims(input_positions, axis=0) |
| 146 | + |
| 147 | + aux_hidden_states = [] |
| 148 | + logits, hidden, kv_caches = self.model( |
| 149 | + decoder_input_tokens=input_ids, |
| 150 | + decoder_positions=input_positions, |
| 151 | + kv_caches=kv_caches, |
| 152 | + attention_metadata=attention_metadata, |
| 153 | + model_mode=self.model_mode, |
| 154 | + **kwargs, |
| 155 | + ) |
| 156 | + if hidden.ndim > 1: |
| 157 | + hidden = jnp.squeeze(hidden, axis=0) |
| 158 | + logits = jnp.squeeze(logits, axis=0) |
| 159 | + |
| 160 | + self.logits = logits # cache logits for compute_logits call |
| 161 | + |
| 162 | + return kv_caches, hidden, aux_hidden_states |
| 163 | + |
| 164 | + def compute_logits(self, hidden_states: jax.Array) -> jax.Array: |
| 165 | + """Computes the logits from the hidden states. |
| 166 | +
|
| 167 | + Args: |
| 168 | + hidden_states: A JAX array of hidden states. |
| 169 | +
|
| 170 | + Returns: |
| 171 | + A JAX array of logits (Q, vocab_size). |
| 172 | + """ |
| 173 | + if self.logits is not None: |
| 174 | + return self.logits |
| 175 | + |
| 176 | + embeddings = self.model.token_embedder |
| 177 | + # pylint: disable=protected-access |
| 178 | + return self.model.decoder._apply_output_head(embeddings, hidden_states, True, self.model_mode) |
| 179 | + |
| 180 | + def load_weights(self, rng_key: jax.Array) -> None: |
| 181 | + """Loads model parameters on the provided mesh. |
| 182 | +
|
| 183 | + Args: |
| 184 | + rng_key: A JAX random key for model initialization. |
| 185 | + """ |
| 186 | + self.model, _ = model_creation_utils.create_nnx_model( |
| 187 | + self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key |
| 188 | + ) |
| 189 | + |
| 190 | + |
| 191 | +class MaxTextForCausalLM(nnx.Module): |
| 192 | + """A vLLM-compatible causal language model wrapper for MaxText. |
| 193 | +
|
| 194 | + This class serves as the primary interface for integrating MaxText models |
| 195 | + into the vLLM serving framework, specifically for causal language modeling |
| 196 | + tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected |
| 197 | + by vLLM. |
| 198 | + """ |
| 199 | + |
| 200 | + def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh): |
| 201 | + """Initializes the MaxTextForCausalLM model. |
| 202 | +
|
| 203 | + Args: |
| 204 | + vllm_config: The vLLM configuration object. |
| 205 | + rng_key: A JAX random key for model initialization. |
| 206 | + mesh: The JAX mesh device for model sharding. |
| 207 | + """ |
| 208 | + self.cfg = vllm_config.model_config |
| 209 | + self.mesh = mesh |
| 210 | + self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh) |
| 211 | + self.is_text_generation_model = True |
| 212 | + |
| 213 | + def __call__( |
| 214 | + self, kv_caches: list[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, *args, **kwargs |
| 215 | + ) -> tuple[list[jax.Array], jax.Array]: |
| 216 | + """Performs a forward pass through the causal language model. |
| 217 | +
|
| 218 | + Args: |
| 219 | + kv_caches: A list of JAX arrays representing the KV caches. |
| 220 | + input_ids: A JAX array of input token IDs. |
| 221 | + attention_metadata: Attention metadata for the decoding process. |
| 222 | + *args: Variable length argument list. |
| 223 | + **kwargs: Arbitrary keyword arguments. |
| 224 | +
|
| 225 | + Returns: |
| 226 | + A tuple containing: |
| 227 | + - updated_kv_caches: A list of updated KV caches. |
| 228 | + - hidden: The hidden states. |
| 229 | + - aux_hidden_states: A list of auxiliary hidden states. |
| 230 | + """ |
| 231 | + kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs) |
| 232 | + return kv_caches, hidden, aux_hidden_states |
| 233 | + |
| 234 | + def forward(self, *args, **kwargs): |
| 235 | + """Alias for __call__ for compatibility. |
| 236 | +
|
| 237 | + Args: |
| 238 | + *args: Variable length argument list. |
| 239 | + **kwargs: Arbitrary keyword arguments. |
| 240 | +
|
| 241 | + Returns: |
| 242 | + The result of the `__call__` method. |
| 243 | + """ |
| 244 | + return self(*args, **kwargs) |
| 245 | + |
| 246 | + def get_input_embeddings(self) -> jax.Array: |
| 247 | + """Returns the input embeddings of the model. |
| 248 | +
|
| 249 | + Returns: |
| 250 | + A JAX array representing the input embeddings. |
| 251 | + """ |
| 252 | + return self.model.model.token_embedder.embedding |
| 253 | + |
| 254 | + def compute_logits(self, hidden_states: jax.Array) -> jax.Array: |
| 255 | + """Computes the logits from the hidden states using the underlying decoder model. |
| 256 | +
|
| 257 | + Args: |
| 258 | + hidden_states: A JAX array of hidden states. |
| 259 | +
|
| 260 | + Returns: |
| 261 | + A JAX array of logits. |
| 262 | + """ |
| 263 | + return self.model.compute_logits(hidden_states) |
| 264 | + |
| 265 | + def load_weights(self, rng_key: jax.Array) -> None: |
| 266 | + """Loads model weights using the underlying decoder model. |
| 267 | +
|
| 268 | + Args: |
| 269 | + rng_key: A JAX random key for model initialization. |
| 270 | + """ |
| 271 | + self.model.load_weights(rng_key) |
0 commit comments