From c06aa8f8de99724c7df594e5b92fdf78dadccc23 Mon Sep 17 00:00:00 2001
From: A9isha
Date: Thu, 6 Nov 2025 23:57:08 +0000
Subject: [PATCH] updates

---
 ...ost_training_local_dependencies.Dockerfile | 58 +++++++++++++++++++
 .../scripts/docker_build_dependency_image.sh  | 57 +++++++++++-------
 docs/tutorials/grpo.md                        | 12 ++--
 docs/tutorials/grpo_with_pathways.md          | 12 ++--
 .../integration/tunix/tunix_adapter.py        |  2 +-
 src/MaxText/integration/tunix/utils.py        |  2 +-
 .../tunix/weight_mapping/__init__.py          |  4 +-
 src/MaxText/rl/evaluate_rl.py                 |  2 +-
 src/MaxText/rl/train_rl.py                    | 37 +++++++-----
 src/MaxText/rl/utils_rl.py                    |  2 +-
 10 files changed, 135 insertions(+), 53 deletions(-)
 create mode 100644 dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

diff --git a/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
new file mode 100644
index 000000000..79d29bc93
--- /dev/null
+++ b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
@@ -0,0 +1,58 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARG BASEIMAGE
+FROM ${BASEIMAGE}
+ARG MODE
+ENV MODE=$MODE
+
+RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
+RUN pip uninstall -y jax jaxlib libtpu
+
+RUN pip install aiohttp==3.12.15
+
+# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
+RUN pip install keyring keyrings.google-artifactregistry-auth
+
+RUN pip install numba==0.61.2
+
+COPY tunix /tunix
+RUN pip install -e /tunix --no-cache-dir
+
+
+COPY vllm /vllm
+RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
+    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
+    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+
+COPY tpu-inference /tpu-inference
+RUN pip install -e /tpu-inference --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
+
+RUN if [ "$MODE" = "post-training-experimental" ]; then \
+    echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
+    pip uninstall -y jax jaxlib libtpu && \
+    pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
+    pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
+  fi
diff --git a/dependencies/scripts/docker_build_dependency_image.sh b/dependencies/scripts/docker_build_dependency_image.sh
index 695ee29b4..fa717825e 100644
--- a/dependencies/scripts/docker_build_dependency_image.sh
+++ b/dependencies/scripts/docker_build_dependency_image.sh
@@ -6,7 +6,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# https://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,7 +20,7 @@
 # bash docker_build_dependency_image.sh MODE=nightly
 # bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
 # Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
-# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
+# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
 # MODE=custom_wheels is the same as nightly except that it reinstalls any
 # additional wheels that are present in the maxtext directory.
 # The main use case is to install custom jax or jaxlib wheels but it also
@@ -28,6 +28,7 @@
 # bash docker_build_dependency_image.sh MODE=custom_wheels
 # bash docker_build_dependency_image.sh MODE=post-training
+# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
 
 if [ "${BASH_SOURCE-}" ]; then
   this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@
   echo "Default DEVICE=${DEVICE}"
 fi
 
+# New flag for the post-training source
+if [[ -z ${POST_TRAINING_SOURCE} ]]; then
+  export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
+  echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
+fi
+
 # Function to build with MODE=jax_ai_image
 build_ai_image() {
   if [[ -z ${BASEIMAGE+x} ]]; then
@@ -171,24 +178,34 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
     exit 1
   fi
 
-  # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
-  # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
-  # rsync -a --exclude='__pycache__' ../tpu_commons .
-  # # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
-  # # This assumes vllm is a sibling directory to the current one (maxtext).
-  # rsync -a --exclude='__pycache__' ../vllm .
-
-  # rsync -a --exclude='__pycache__' ../tunix .
-
-  # # The cleanup is set to run even if the build fails to remove the copied directory.
-  # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
-
-  docker build \
-    --network host \
-    --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
-    --build-arg MODE=${MODE} \
-    -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
-    -t ${LOCAL_IMAGE_NAME} .
+  DOCKERFILE_NAME=""
+  if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then
+
+    # To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
+    # This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
+    rsync -a --exclude='__pycache__' ../tpu-inference .
+    # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
+    # This assumes vllm is a sibling directory of the current one (maxtext).
+    rsync -a --exclude='__pycache__' ../vllm .
+
+    rsync -a --exclude='__pycache__' ../tunix .
+
+    # The cleanup trap runs even if the build fails, so the copied directories are always removed.
+    trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM
+
+    DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
+    echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME"
+  else
+    DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
+    echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME"
+  fi
+
+  docker build \
+    --network host \
+    --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
+    --build-arg MODE=${MODE} \
+    -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \
+    -t ${LOCAL_IMAGE_NAME} .
 fi
 
 if [[ ${CUSTOM_JAX} -eq 1 ]] ; then
diff --git a/docs/tutorials/grpo.md b/docs/tutorials/grpo.md
index ee201948d..36207e168 100644
--- a/docs/tutorials/grpo.md
+++ b/docs/tutorials/grpo.md
@@ -62,12 +62,12 @@ Finally, run the command
 
 ```
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  --model_name=llama3.1-8b \
-  --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-  --load_parameters_path=gs://path/to/checkpoint/0/items \
-  --run_name=$WORKLOAD \
-  --base_output_directory=$OUTPUT_PATH \
-  --hf_access_token=$HF_TOKEN
+  model_name=llama3.1-8b \
+  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+  load_parameters_path=gs://path/to/checkpoint/0/items \
+  run_name=$WORKLOAD \
+  base_output_directory=$OUTPUT_PATH \
+  hf_access_token=$HF_TOKEN
 ```
 
 The overview of the demo script is as follows:
diff --git a/docs/tutorials/grpo_with_pathways.md b/docs/tutorials/grpo_with_pathways.md
index c2ce421c7..0d6561c8d 100644
--- a/docs/tutorials/grpo_with_pathways.md
+++ b/docs/tutorials/grpo_with_pathways.md
@@ -61,12 +61,12 @@ xpk workload create-pathways --workload $WORKLOAD \
 --project=$PROJECT_ID --priority=high \
 --command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1'
 # Llama3.1-70B-Instruct
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  --model_name=llama3.1-70b \
-  --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
-  --load_parameters_path=gs://path/to/checkpoint/0/items \
-  --run_name=$WORKLOAD \
-  --base_output_directory=$OUTPUT_PATH \
-  --hf_access_token=$HF_TOKEN"
+  model_name=llama3.1-70b \
+  tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
+  load_parameters_path=gs://path/to/checkpoint/0/items \
+  run_name=$WORKLOAD \
+  base_output_directory=$OUTPUT_PATH \
+  hf_access_token=$HF_TOKEN"
 ```
 
 The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:
diff --git a/src/MaxText/integration/tunix/tunix_adapter.py b/src/MaxText/integration/tunix/tunix_adapter.py
index 499adda94..4d54f3aba 100644
--- a/src/MaxText/integration/tunix/tunix_adapter.py
+++ b/src/MaxText/integration/tunix/tunix_adapter.py
@@ -26,7 +26,7 @@
 from jax import Array
 from flax import nnx
 from MaxText.layers.models import Transformer
-from maxtext.src.maxtext.integration.tunix.utils import VllmWeightMapping
+from MaxText.integration.tunix.utils import VllmWeightMapping
 from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
 
 # pylint: disable=ungrouped-imports
diff --git a/src/MaxText/integration/tunix/utils.py b/src/MaxText/integration/tunix/utils.py
index 43ca8577c..6843f191e 100644
--- a/src/MaxText/integration/tunix/utils.py
+++ b/src/MaxText/integration/tunix/utils.py
@@ -16,7 +16,7 @@
 
 import re
 
-import maxtext.src.maxtext.integration.tunix.weight_mapping as weight_mapping  # pylint: disable=consider-using-from-import
+from MaxText.integration.tunix import weight_mapping  # pylint: disable=consider-using-from-import
 
 from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
 from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
diff --git a/src/MaxText/integration/tunix/weight_mapping/__init__.py b/src/MaxText/integration/tunix/weight_mapping/__init__.py
index c1dc34a88..d250ee2fe 100644
--- a/src/MaxText/integration/tunix/weight_mapping/__init__.py
+++ b/src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -19,8 +19,8 @@ model name.
 This allows for easy extension to support new models.
 """
 
-from maxtext.src.maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
-from maxtext.src.maxtext.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
+from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
+from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
 
 
 class StandaloneVllmWeightMapping:
diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py
index d8c6b6151..29610ef58 100644
--- a/src/MaxText/rl/evaluate_rl.py
+++ b/src/MaxText/rl/evaluate_rl.py
@@ -68,7 +68,7 @@ def generate_responses(
   responses = rl_cluster.rollout.generate(
       prompts,
       rollout_config=RolloutConfig(
-          max_tokens_to_generate=tmvp_config.max_target_length,
+          max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length,
           temperature=eval_strategy["eval_temperature"],
           top_k=eval_strategy["eval_top_k"],
           top_p=eval_strategy["eval_top_p"],
diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py
index 2a2bbeeae..e76aa41a6 100644
--- a/src/MaxText/rl/train_rl.py
+++ b/src/MaxText/rl/train_rl.py
@@ -26,21 +26,21 @@
 
 # Llama3.1-8B-Instruct
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  --model_name=llama3.1-8b \
-  --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-  --load_parameters_path=gs://path/to/checkpoint/0/items \
-  --run_name=$WORKLOAD \
-  --base_output_directory=$OUTPUT_PATH \
-  --hf_access_token=$HF_TOKEN
+  model_name=llama3.1-8b \
+  tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+  load_parameters_path=gs://path/to/checkpoint/0/items \
+  run_name=$WORKLOAD \
+  base_output_directory=$OUTPUT_PATH \
+  hf_access_token=$HF_TOKEN
 
 # Llama3.1-70B-Instruct
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-  --model_name=llama3.1-70b \
-  --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
-  --load_parameters_path=gs://path/to/checkpoint/0/items \
-  --run_name=$WORKLOAD \
-  --base_output_directory=$OUTPUT_PATH \
-  --hf_access_token=$HF_TOKEN
+  model_name=llama3.1-70b \
+  tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
+  load_parameters_path=gs://path/to/checkpoint/0/items \
+  run_name=$WORKLOAD \
+  base_output_directory=$OUTPUT_PATH \
+  hf_access_token=$HF_TOKEN
 """
 
@@ -52,6 +52,7 @@
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 import grain
+from etils import epath
 from vllm.outputs import PoolingRequestOutput  # pylint: disable=unused-import
 
 import jax
@@ -186,6 +187,13 @@ def rl_train(tmvp_config):
   """
   max_logging.log("Starting GRPO Training")
 
+  LOG_DIR = f"{tmvp_config.base_output_directory}-{tmvp_config.run_name}/tensorboard/"
+  if not epath.Path(LOG_DIR).exists():
+    epath.Path(LOG_DIR).mkdir(parents=True)
+
+  ckpt_dir = tmvp_config.checkpoint_dir
+  if not epath.Path(ckpt_dir).exists():
+    epath.Path(ckpt_dir).mkdir(parents=True)
 
   # Number of training steps.
   max_train_steps = int(
@@ -266,16 +274,15 @@ def rl_train(tmvp_config):
   optimizer = utils_rl.get_optimizer(tmvp_config, max_train_steps)
 
   # Setup checkpointing
-  ckpt_dir = tmvp_config.checkpoint_dir
   checkpointing_options = ocp.CheckpointManagerOptions(
       save_interval_steps=tmvp_config.checkpoint_period, max_to_keep=tmvp_config.max_num_checkpoints_to_keep
   )
 
   # Setup metrics logging
-  max_logging.log(f"TensorBoard logs directory: {tmvp_config.tensorboard_dir}")
+  max_logging.log(f"TensorBoard logs directory: {LOG_DIR}")
   # Metrics logger
   metrics_logging_options = metrics_logger.MetricsLoggerOptions(
-      log_dir=tmvp_config.tensorboard_dir, flush_every_n_steps=tmvp_config.log_period
+      log_dir=LOG_DIR, flush_every_n_steps=tmvp_config.log_period
   )
 
   # Profiler configurations
diff --git a/src/MaxText/rl/utils_rl.py b/src/MaxText/rl/utils_rl.py
index 3422cd01a..8b1cede37 100644
--- a/src/MaxText/rl/utils_rl.py
+++ b/src/MaxText/rl/utils_rl.py
@@ -172,7 +172,7 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
   return scores
 
 
-def extract_hash_answer(text: str, debug: bool = False) -> str | None:
+def extract_hash_answer(text: str) -> str | None:
   """Function to extract only the answer hash from the text."""
   if "####" not in text:
     return None
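
A minimal usage sketch for the new flag, assuming the script is invoked from the maxtext repo root and that `tunix`, `vllm`, and `tpu-inference` are checked out as sibling directories of `maxtext` (as the script's rsync comments require):

```sh
# Build the post-training dependency image from local checkouts of
# tunix, vllm, and tpu-inference (rsynced into the build context).
bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local

# Omitting POST_TRAINING_SOURCE keeps the previous behavior: it defaults
# to "remote" and builds with maxtext_post_training_dependencies.Dockerfile.
bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
```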