From 963158168792b0c94cfd453254b901007d411737 Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Sun, 30 Nov 2025 22:34:52 -0800 Subject: [PATCH] Support Custom MaxText model (with vLLM engine) in RL rollouts. Fix formatting. Refactor model creation and error handling in RL training fix linting. adding no-op mappings to tunix adapter. removing kvcache init for vllm case. latest updates from debugging. adding null logical axis rules to adapter. adding linting fixes. fixing pyink remove unused imports attentions test. adding fixes. --- src/MaxText/configs/base.yml | 6 +++ src/MaxText/configs/types.py | 12 +++--- src/MaxText/configs/vllm.yml | 13 +++++- .../integration/tunix/tunix_adapter.py | 14 +++++++ src/MaxText/integration/tunix/utils.py | 5 ++- .../vllm/maxtext_vllm_adapter/adapter.py | 40 +++++++------------ src/MaxText/layers/attentions.py | 13 ++++-- src/MaxText/model_creation_utils.py | 6 +-- src/MaxText/rl/evaluate_rl.py | 13 ++++-- src/MaxText/rl/train_rl.py | 30 +++++++++++++- .../ckpt_conversion/utils/param_mapping.py | 2 +- 11 files changed, 108 insertions(+), 46 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index e07363d08..edb0a661b 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -979,3 +979,9 @@ use_tokamax_gmm: false use_tokamax_splash: false # Setting this flag will use a non-pallas implementation. use_jax_splash: false + +# vLLM Adapter Configurations +# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter) +vllm_hf_config_path: "" +# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') +vllm_additional_config: {} diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 8a2449192..d03f8da93 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -279,9 +279,7 @@ class Checkpointing(BaseModel): save_checkpoint_on_completion: bool = Field( True, description="If True, saves a final checkpoint upon training completion." ) - enable_continuous_checkpointing: bool = Field( - False, description="If True, enables continuous checkpointing." - ) + enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.") class OrbaxStorage(BaseModel): @@ -463,9 +461,7 @@ class Attention(BaseModel): ragged_block_size: int = Field(256, description="Block size for ragged attention.") enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.") use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.") - use_jax_splash: bool = Field( - False, description="Whether to use jax splash attention." - ) + use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.") class MoBa(BaseModel): @@ -1376,6 +1372,8 @@ class VLLM(BaseModel): kv_cache_buffer: int = Field(256, description="Buffer for KV cache.") hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.") swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.") + vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.") + vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.") class GRPO(BaseModel): @@ -2163,6 +2161,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "tensor": self.ici_tensor_parallelism, "tensor_transpose": self.ici_tensor_transpose_parallelism, "tensor_sequence": self.ici_tensor_sequence_parallelism, + "model": self.ici_tensor_parallelism, "expert": self.ici_expert_parallelism, "autoregressive": self.ici_autoregressive_parallelism, } @@ -2179,6 +2178,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "tensor": self.dcn_tensor_parallelism, "tensor_transpose": self.dcn_tensor_transpose_parallelism, "tensor_sequence": self.dcn_tensor_sequence_parallelism, + "model": self.dcn_tensor_parallelism, "expert": self.dcn_expert_parallelism, "autoregressive": self.dcn_autoregressive_parallelism, } diff --git a/src/MaxText/configs/vllm.yml b/src/MaxText/configs/vllm.yml index 2cf9b4020..1f2a839e5 100644 --- a/src/MaxText/configs/vllm.yml +++ b/src/MaxText/configs/vllm.yml @@ -41,6 +41,7 @@ logical_axis_rules: [ ['activation_kv_batch_no_exp', []], ['activation_kv_head_dim', ['model']], ['activation_vocab', ['model']], + ['activation_embed', ['model']], ['activation_exp', ['expert']], ['decode_batch', ['expert']], ['mlp', ['model']], @@ -49,13 +50,23 @@ logical_axis_rules: [ ['heads', ['model']], ['q_heads', ['model']], ['kv_heads', ['model']], + ['kv_head_dim', []], + ['kv', []], ['embed', ['expert']], + ['embed_no_exp', []], ['q_lora', ['expert']], ['kv_lora', ['expert']], ['norm', ['model']], ['cache_heads', ['model']], ['exp', ['expert']], ['paged_kv_heads', ['model']], + ['autoregressive', ['model']], + ['tensor', ['model']], + ['tensor_transpose', ['model']], + ['fsdp', ['data']], + ['fsdp_transpose', ['data']], + ['sequence', ['model']], + ['context', ['model']], ] data_sharding: [['data', 'model', 'expert']] -input_data_sharding_logical_axes: ['activation_embed_and_logits_batch'] +input_data_sharding_logical_axes: ['activation_embed_and_logits_batch'] \ No newline at end of file diff --git a/src/MaxText/integration/tunix/tunix_adapter.py b/src/MaxText/integration/tunix/tunix_adapter.py index d6adb5108..6ffceb404 100644 --- a/src/MaxText/integration/tunix/tunix_adapter.py +++ b/src/MaxText/integration/tunix/tunix_adapter.py @@ -37,6 +37,7 @@ def __init__( self, base_model: Transformer, use_standalone_mappings: bool = True, + use_no_op_mappings: bool = False, ): super().__init__() self.base = base_model @@ -45,6 +46,7 @@ def __init__( HF_MODEL_CONFIGS[self.base.config.model_name].to_dict(), use_standalone_mappings, ) + self.use_no_op_mappings = use_no_op_mappings # ------------------------------------------------------------------ # # Tunix call signature @@ -69,13 +71,25 @@ def __call__( return logits, None def to_hf_mappings(self): + if self.use_no_op_mappings: + return {} + return self._vllm_weight_mapping.to_hf_mapping() def to_hf_transpose_keys(self): + if self.use_no_op_mappings: + return {} + return self._vllm_weight_mapping.to_hf_transpose_keys() def to_hf_hook_fns(self): + if self.use_no_op_mappings: + return {} + return self._vllm_weight_mapping.to_hf_hook_fns() def lora_to_hf_mappings(self): + if self.use_no_op_mappings: + return {} + return self._vllm_weight_mapping.lora_to_hf_mappings() diff --git a/src/MaxText/integration/tunix/utils.py b/src/MaxText/integration/tunix/utils.py index 2cf12c048..463cff8ee 100644 --- a/src/MaxText/integration/tunix/utils.py +++ b/src/MaxText/integration/tunix/utils.py @@ -147,7 +147,10 @@ def to_hf_hook_fns(self): return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns() model_family = self.model_name.split("-")[0] - return VLLM_HOOK_FNS[model_family]() + if model_family in VLLM_HOOK_FNS: + return VLLM_HOOK_FNS[model_family]() + else: + return {} def lora_to_hf_mappings(self): if self.use_standalone_mappings: diff --git a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py index 9899ef941..ab3642810 100644 --- a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -49,26 +49,17 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters Raises: ValueError: If `hf_config_path` is not provided in the vLLM model config. """ - - def _path_exists(path: str) -> bool: - if not path: - return False - return epath.Path(path).exists() - if "maxtext_config" in vllm_config.additional_config: overrides = vllm_config.additional_config["maxtext_config"] else: overrides = {} - load_path = None - if _path_exists(vllm_config.load.download_dir): - load_path = vllm_config.load.download_dir - elif _path_exists(vllm_config.model.model): - load_path = vllm_config.model.model - if load_path: - overrides["load_parameters_path"] = load_path - elif vllm_config.model.model: - overrides["model_name"] = vllm_config.model.model + if vllm_config.load_config.load_format == "dummy": + if overrides.get("load_parameters_path") is not None: + max_logging.log( + "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped." + ) + overrides["load_parameters_path"] = None if vllm_config.model_config.hf_config_path is None: raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.") @@ -110,12 +101,6 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> N # Handle dummy weight loading during initialization if vllm_config.load_config.load_format == "dummy": - if self.maxtext_config.load_parameters_path is not None: - max_logging.log( - "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped." - ) - self.maxtext_config.load_parameters_path = None - with self.mesh: self.load_weights(rng_key) @@ -173,7 +158,7 @@ def __call__( hidden = jnp.squeeze(hidden, axis=0) logits = jnp.squeeze(logits, axis=0) - self.logits = logits # cache logits for compute_logits call + self.logits = nnx.data(logits) # cache logits for compute_logits call return kv_caches, hidden, aux_hidden_states @@ -199,9 +184,14 @@ def load_weights(self, rng_key: jax.Array) -> None: Args: rng_key: A JAX random key for model initialization. """ - self.model, _ = model_creation_utils.create_nnx_model( - self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key - ) + if self.model is not None: + return + + with nn.logical_axis_rules(""): + model, _ = model_creation_utils.create_nnx_model( + self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key + ) + self.model = nnx.data(model) class MaxTextForCausalLM(nnx.Module): diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index e3f1d1950..fec00614d 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -423,7 +423,7 @@ def __init__( # Module attribute names must match names previously passed to Linen for checkpointing self.KVCache_0 = ( self.init_kv_caches(inputs_kv_shape=inputs_kv_shape) - if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache + if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache and config.attention != "vllm_rpa" else None ) @@ -909,7 +909,7 @@ def forward_serve_vllm( try: # pylint: disable=import-outside-toplevel # pytype: disable=import-error - from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops + from tpu_inference.layers.common.attention_interface import sharded_ragged_paged_attention as rpa_ops except ImportError as e: raise ImportError( "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`." @@ -930,7 +930,8 @@ def forward_serve_vllm( md = rpa_metadata - output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)( + output, kv_cache = rpa_ops( + self.mesh, query, key, value, @@ -939,6 +940,12 @@ def forward_serve_vllm( md.block_tables, md.query_start_loc, md.request_distribution, + None, + 1.0, + attention_chunk_size, + q_scale, + k_scale, + v_scale, ) return kv_cache, output diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index ade242d43..816f83815 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -78,9 +78,9 @@ def from_config( Example: model = from_config(config) """ - devices_array = maxtext_utils.create_device_mesh(config, devices) - if mesh is None: + devices_array = maxtext_utils.create_device_mesh(config, devices) + if config.shard_mode == ShardMode.EXPLICIT: axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) else: @@ -154,7 +154,7 @@ def create_sharded_state(): model = _create_model_partial() return nnx.state(model) - with jax.set_mesh(mesh): + with mesh: # Create the model with sharded parameters. with nn.logical_axis_rules(config.logical_axis_rules): sharded_state = create_sharded_state() diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py index 29610ef58..03a003074 100644 --- a/src/MaxText/rl/evaluate_rl.py +++ b/src/MaxText/rl/evaluate_rl.py @@ -121,13 +121,18 @@ def score_responses(tmvp_config, question, responses, answer): # Check exact correctness try: - if float(extracted_response.strip()) == float(answer.strip()): + # Remove ',' and '$' then convert to float + val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip()) + val_answer = float(answer.replace(",", "").replace("$", "").strip()) + + if val_extracted == val_answer: is_correct = True # Check partial correctness (within 10%) - ratio = float(extracted_response.strip()) / float(answer.strip()) - if 0.9 <= ratio <= 1.1: - is_partially_correct = True + if val_answer != 0.0: + ratio = val_extracted / val_answer + if 0.9 <= ratio <= 1.1: + is_partially_correct = True except Exception as e: if tmvp_config.debug["rl"]: max_logging.log(f"Evaluation Exception: {e}") diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index fc83ae60f..0e4466c3b 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -48,6 +48,7 @@ import collections import grain import jax +import json import os import pathwaysutils import tensorflow_datasets as tfds @@ -70,6 +71,7 @@ from MaxText import max_logging, max_utils, maxtext_utils, pyconfig from MaxText import model_creation_utils +from MaxText.globals import MAXTEXT_PKG_DIR from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter from MaxText.rl.evaluate_rl import evaluate from MaxText.rl import utils_rl @@ -93,7 +95,8 @@ def get_maxtext_model(config, devices=None): """ model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) with jax.set_mesh(mesh): - tunix_model = TunixMaxTextAdapter(base_model=model) + use_no_op_mappings = "maxtext_config" in config.vllm_additional_config + tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) tunix_model.config = None return tunix_model, mesh @@ -352,6 +355,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): set_profile_options=False, ) + # Parse vllm_additional_config + rollout_additional_config = None + if trainer_config.vllm_additional_config: + if isinstance(trainer_config.vllm_additional_config, dict): + # It's already parsed into a dict + rollout_additional_config = trainer_config.vllm_additional_config + elif isinstance(trainer_config.vllm_additional_config, str): + # It's a string, so we need to parse it + try: + rollout_additional_config = json.loads(trainer_config.vllm_additional_config) + except json.JSONDecodeError as e: + raise ValueError(f"Failed to parse additional_config JSON: {e}") from e + + max_logging.log(f"Parsed additional config: {rollout_additional_config}") + # RL Cluster config # Note that we use vLLM as the rollout engine. # and we are using Tensor Parallelism for rollout @@ -394,6 +412,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm, rollout_vllm_tpu_backend_type="jax", rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb, + rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path, + rollout_vllm_additional_config=rollout_additional_config, + rollout_vllm_init_with_random_weights=True, **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)), ), ) @@ -423,7 +444,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): max_logging.log( "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics." ) - with nn_partitioning.axis_rules(trainer_config.logical_axis_rules): + + vllm_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml" + argv_list = ["", str(vllm_config_path), "log_config=False"] + vllm_config = pyconfig.initialize(argv_list) + + with nn_partitioning.axis_rules(vllm_config.logical_axis_rules): rl_cluster = rl_cluster_lib.RLCluster( actor=actor_model, reference=reference_model, diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py index bde6a62e4..3d4a194bc 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py +++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py @@ -1479,5 +1479,5 @@ def transform_query_kernel(arr): VLLM_HOOK_FNS = { "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN, "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN, - "deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN, + "deepseek3": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN, }