Commit e1fe512

Adding MaxTextForCausalLM interface.
1 parent cb136bc commit e1fe512

5 files changed: +359 -12 lines changed

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MaxText vLLM adapter package."""

from tpu_inference.logger import init_logger
from tpu_inference.models.common.model_loader import register_model
from .adapter import MaxTextForCausalLM


logger = init_logger(__name__)


def register():
  logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.")
  register_model("MaxTextForCausalLM", MaxTextForCausalLM)
  logger.info("Successfully registered MaxTextForCausalLM model.")

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""vLLM adapter for MaxText models."""

import jax
import pathlib
import os
import jax.numpy as jnp

from flax import nnx
from jax.sharding import Mesh
from MaxText import model_creation_utils
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
from MaxText.globals import MAXTEXT_PKG_DIR
from MaxText.utils import gcs_utils

from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from vllm.config import VllmConfig


def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters:
  """Generates a MaxText configuration from a vLLM configuration.

  This function takes a vLLM configuration object and translates relevant
  parameters into a MaxText `HyperParameters` object. It handles loading
  paths and model names from the vLLM config, and applies a base MaxText
  vLLM configuration file.

  Args:
    vllm_config: The vLLM configuration object containing model and load
      parameters.

  Returns:
    A `pyconfig.HyperParameters` object configured for MaxText.

  Raises:
    ValueError: If `hf_config_path` is not provided in the vLLM model config.
  """

  def _path_exists(path: str) -> bool:
    if not path:
      return False
    return os.path.exists(path) or gcs_utils.gcs_path_exists(path)

  if "maxtext_config" in vllm_config.additional_config:
    overrides = vllm_config.additional_config["maxtext_config"]
  else:
    overrides = {}
  load_path = None
  if _path_exists(vllm_config.load.download_dir):
    load_path = vllm_config.load.download_dir
  elif _path_exists(vllm_config.model.model):
    load_path = vllm_config.model.model

  if load_path:
    overrides["load_parameters_path"] = load_path
  elif vllm_config.model.model:
    overrides["model_name"] = vllm_config.model.model

  if vllm_config.model_config.hf_config_path is None:
    raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")

  # Add base config path to positional args
  base_config_path = pathlib.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
  argv_list = ["", str(base_config_path)]

  maxtext_config = pyconfig.initialize(argv_list, **overrides)
  return maxtext_config


class MaxTextDecoderModel(nnx.Module):
  """A vLLM-compatible decoder model wrapper for MaxText.

  This class adapts a MaxText model for use within the vLLM framework,
  handling configuration generation, model initialization, and execution
  of the decoding step.
  """

  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:
    """Initializes the MaxTextDecoderModel.

    Args:
      vllm_config: The vLLM configuration object.
      rng_key: A JAX random key for model initialization.
      mesh: The JAX mesh device for model sharding.
    """
    self.vllm_config = vllm_config
    self.maxtext_config = generate_maxtext_config(vllm_config)

    # Model configuration
    self.mesh = mesh
    self.model_mode = MODEL_MODE_AUTOREGRESSIVE

    # Model creation
    self.model: nnx.Module | None = None
    self.logits: jax.Array | None = None

  def __call__(
      self,
      kv_caches: list[jax.Array],
      input_ids: jax.Array,
      attention_metadata: AttentionMetadata,
      *args,
      **kwargs,
  ) -> tuple[list[jax.Array], jax.Array]:
    """Performs a forward pass through the decoder model.

    Args:
      kv_caches: A list of JAX arrays representing the KV caches.
      input_ids: A JAX array of input token IDs.
      attention_metadata: Attention metadata for the decoding process.
      *args: Variable length argument list.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      A tuple containing:
        - updated_kv_caches: A list of updated KV caches.
        - hidden: The hidden states (Q, d_model).
        - aux_hidden_states: A list of auxiliary hidden states.

    Raises:
      ValueError: If the model is not an instance of `nnx.Module`.
    """
    if not isinstance(self.model, nnx.Module):
      raise ValueError("Model must be an instance of type nnx.Module.")

    if input_ids.ndim < 2:
      input_ids = jnp.expand_dims(input_ids, axis=0)

    input_positions = attention_metadata.input_positions
    if input_positions.ndim < 2:
      input_positions = jnp.expand_dims(input_positions, axis=0)

    aux_hidden_states = []
    logits, hidden, kv_caches = self.model(
        decoder_input_tokens=input_ids,
        decoder_positions=input_positions,
        kv_caches=kv_caches,
        attention_metadata=attention_metadata,
        model_mode=self.model_mode,
        **kwargs,
    )
    if hidden.ndim > 1:
      hidden = jnp.squeeze(hidden, axis=0)
      logits = jnp.squeeze(logits, axis=0)

    self.logits = logits  # cache logits for compute_logits call

    return kv_caches, hidden, aux_hidden_states

  def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
    """Computes the logits from the hidden states.

    Args:
      hidden_states: A JAX array of hidden states.

    Returns:
      A JAX array of logits (Q, vocab_size).
    """
    if self.logits is not None:
      return self.logits

    embeddings = self.model.token_embedder
    # pylint: disable=protected-access
    return self.model.decoder._apply_output_head(embeddings, hidden_states, True, self.model_mode)

  def load_weights(self, rng_key: jax.Array) -> None:
    """Loads model parameters on the provided mesh.

    Args:
      rng_key: A JAX random key for model initialization.
    """
    self.model, _ = model_creation_utils.create_nnx_model(
        self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
    )


class MaxTextForCausalLM(nnx.Module):
  """A vLLM-compatible causal language model wrapper for MaxText.

  This class serves as the primary interface for integrating MaxText models
  into the vLLM serving framework, specifically for causal language modeling
  tasks. It wraps the `MaxTextDecoderModel` and exposes methods expected
  by vLLM.
  """

  def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
    """Initializes the MaxTextForCausalLM model.

    Args:
      vllm_config: The vLLM configuration object.
      rng_key: A JAX random key for model initialization.
      mesh: The JAX mesh device for model sharding.
    """
    self.cfg = vllm_config.model_config
    self.mesh = mesh
    self.model = MaxTextDecoderModel(vllm_config, rng_key, mesh)
    self.is_text_generation_model = True

  def __call__(
      self, kv_caches: list[jax.Array], input_ids: jax.Array, attention_metadata: AttentionMetadata, *args, **kwargs
  ) -> tuple[list[jax.Array], jax.Array]:
    """Performs a forward pass through the causal language model.

    Args:
      kv_caches: A list of JAX arrays representing the KV caches.
      input_ids: A JAX array of input token IDs.
      attention_metadata: Attention metadata for the decoding process.
      *args: Variable length argument list.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      A tuple containing:
        - updated_kv_caches: A list of updated KV caches.
        - hidden: The hidden states.
        - aux_hidden_states: A list of auxiliary hidden states.
    """
    kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs)
    return kv_caches, hidden, aux_hidden_states

  def forward(self, *args, **kwargs):
    """Alias for __call__ for compatibility.

    Args:
      *args: Variable length argument list.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The result of the `__call__` method.
    """
    return self(*args, **kwargs)

  def get_input_embeddings(self) -> jax.Array:
    """Returns the input embeddings of the model.

    Returns:
      A JAX array representing the input embeddings.
    """
    return self.model.model.token_embedder.embedding

  def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
    """Computes the logits from the hidden states using the underlying decoder model.

    Args:
      hidden_states: A JAX array of hidden states.

    Returns:
      A JAX array of logits.
    """
    return self.model.compute_logits(hidden_states)

  def load_weights(self, rng_key: jax.Array) -> None:
    """Loads model weights using the underlying decoder model.

    Args:
      rng_key: A JAX random key for model initialization.
    """
    self.model.load_weights(rng_key)
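
To make the intended call sequence of the adapter explicit, here is a minimal sketch, assuming vLLM has already built a VllmConfig and that a single-axis mesh named "data" is acceptable; both are illustrative assumptions, not part of this commit. In serving, vLLM constructs the model itself and drives load_weights, __call__, and compute_logits.

# Sketch under stated assumptions: direct construction and the mesh axis
# name are illustrative; vLLM normally performs these steps itself.
import jax
import numpy as np
from jax.sharding import Mesh

from maxtext_vllm_adapter.adapter import MaxTextForCausalLM


def build_model(vllm_config) -> MaxTextForCausalLM:
  rng_key = jax.random.PRNGKey(0)
  mesh = Mesh(np.array(jax.devices()), axis_names=("data",))  # hypothetical 1-axis mesh
  model = MaxTextForCausalLM(vllm_config, rng_key, mesh)
  model.load_weights(rng_key)  # creates the underlying MaxText nnx model on the mesh
  return model

# At decode time the runner would then call, per step:
#   kv_caches, hidden, aux_hidden_states = model(kv_caches, input_ids, attention_metadata)
#   logits = model.compute_logits(hidden)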

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setup for MaxText vLLM adapter package."""

from setuptools import setup

setup(
    name="maxtext_vllm_adapter",
    version="0.1.0",
    packages=["maxtext_vllm_adapter"],
    entry_points={"vllm.general_plugins": ["register_maxtext_vllm_adapter = maxtext_vllm_adapter:register"]},
)
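
Once the package is installed, the entry point above is how the adapter is discovered: vLLM's plugin loader looks up the `vllm.general_plugins` group and invokes the registered callables. A small sketch, using only the standard library (Python 3.10+ for the group= keyword), of checking that the entry point resolves to the register function:

# Sketch: confirm the plugin entry point is visible after installation.
from importlib.metadata import entry_points

for ep in entry_points(group="vllm.general_plugins"):
  if ep.name == "register_maxtext_vllm_adapter":
    register_fn = ep.load()  # resolves to maxtext_vllm_adapter:register
    register_fn()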
