diff --git a/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile
index 277c1fe92..3cb1c0c38 100644
--- a/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile
+++ b/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile
@@ -22,38 +22,17 @@ RUN echo "Installing Post-Training dependencies (vLLM, tpu-common, tunix) with M
 
 # Uninstall existing jax to avoid conflicts
-RUN pip uninstall -y jax jaxlib libtpu
-
-RUN pip install aiohttp==3.12.15
-
-# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
-RUN pip install keyring keyrings.google-artifactregistry-auth
-
-RUN pip install numba==0.61.2
-
-# Install vLLM for Jax and TPUs from the artifact registry
-RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
-    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
-    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
-    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-    vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
-
-# Install tpu-commons from the artifact registry
-RUN pip install --no-cache-dir --pre \
-    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    tpu-commons==0.1.2
+RUN uv pip uninstall -y jax jaxlib libtpu
+
+RUN uv pip install aiohttp==3.12.15
+
+RUN uv pip install numba==0.61.2
+
+# Install vLLM for Jax and TPUs
+RUN uv pip install vllm-tpu
 
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
-    pip uninstall -y jax jaxlib libtpu && \
-    pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
-    pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
+    uv pip uninstall -y jax jaxlib libtpu && \
+    uv pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
+    uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
     fi
diff --git a/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
new file mode 100644
index 000000000..79d29bc93
--- /dev/null
+++ b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile
@@ -0,0 +1,58 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARG BASEIMAGE
+FROM ${BASEIMAGE}
+ARG MODE
+ENV MODE=$MODE
+
+RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
+RUN pip uninstall -y jax jaxlib libtpu
+
+RUN pip install aiohttp==3.12.15
+
+# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
+RUN pip install keyring keyrings.google-artifactregistry-auth
+
+RUN pip install numba==0.61.2
+
+COPY tunix /tunix
+RUN pip install -e /tunix --no-cache-dir
+
+
+COPY vllm /vllm
+RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
+    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
+    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+
+COPY tpu-inference /tpu-inference
+RUN pip install -e /tpu-inference --no-cache-dir --pre \
+    --extra-index-url https://pypi.org/simple/ \
+    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
+    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
+
+RUN if [ "$MODE" = "post-training-experimental" ]; then \
+    echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
+    pip uninstall -y jax jaxlib libtpu && \
+    pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
+    pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
+    fi
diff --git a/dependencies/scripts/docker_build_dependency_image.sh b/dependencies/scripts/docker_build_dependency_image.sh
index 695ee29b4..fa717825e 100644
--- a/dependencies/scripts/docker_build_dependency_image.sh
+++ b/dependencies/scripts/docker_build_dependency_image.sh
@@ -6,7 +6,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# https://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,7 +20,7 @@
 # bash docker_build_dependency_image.sh MODE=nightly
 # bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
 # Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
-# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
+# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
 # MODE=custom_wheels is the same as nightly except that it reinstalls any
 # additional wheels that are present in the maxtext directory.
 # The main use case is to install custom jax or jaxlib wheels but it also
@@ -28,6 +28,7 @@
 # bash docker_build_dependency_image.sh MODE=custom_wheels
 
 # bash docker_build_dependency_image.sh MODE=post-training
+# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
 
 if [ "${BASH_SOURCE-}" ]; then
   this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then
   echo "Default DEVICE=${DEVICE}"
 fi
 
+# New flag for post-training source
+if [[ -z ${POST_TRAINING_SOURCE} ]]; then
+  export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
+  echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
+fi
+
 # Function to build with MODE=jax_ai_image
 build_ai_image() {
   if [[ -z ${BASEIMAGE+x} ]]; then
@@ -171,24 +178,34 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
     exit 1
   fi
 
-  # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
-  # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
-  # rsync -a --exclude='__pycache__' ../tpu_commons .
-  # # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
-  # # This assumes vllm is a sibling directory to the current one (maxtext).
-  # rsync -a --exclude='__pycache__' ../vllm .
-
-  # rsync -a --exclude='__pycache__' ../tunix .
-
-  # # The cleanup is set to run even if the build fails to remove the copied directory.
-  # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
-
-  docker build \
-    --network host \
-    --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
-    --build-arg MODE=${MODE} \
-    -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
-    -t ${LOCAL_IMAGE_NAME} .
+  DOCKERFILE_NAME=""
+  if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then
+
+    # To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
+    # This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
+    rsync -a --exclude='__pycache__' ../tpu-inference .
+    # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
+    # This assumes vllm is a sibling directory of the current one (maxtext).
+    rsync -a --exclude='__pycache__' ../vllm .
+
+    rsync -a --exclude='__pycache__' ../tunix .
+
+    # The cleanup runs even if the build fails, removing the copied directories.
+ trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM + + DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile' + echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME" + else + DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile' + echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME" + fi + + docker build \ + --network host \ + --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ + --build-arg MODE=${MODE} \ + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \ + -t ${LOCAL_IMAGE_NAME} . fi if [[ ${CUSTOM_JAX} -eq 1 ]] ; then diff --git a/docs/tutorials/grpo.md b/docs/tutorials/grpo.md index ee201948d..32b4dd41e 100644 --- a/docs/tutorials/grpo.md +++ b/docs/tutorials/grpo.md @@ -25,35 +25,20 @@ And we use vLLM as the library for efficient model inference and generation. In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started! -## Setup your virtual environment +## Create virtual environment and Install MaxText dependencies +Follow instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md), but +recommend creating the virtual environment outside the `maxtext` directory. -### Create a Python3.12 venv if not already pre-existing and install MaxText dependencies -```sh -bash tools/setup/setup.sh -``` - -### Activate your virtual environment (Skip if you have already done this for running `bash tools/setup/setup.sh` ) -``` -# Replace with your virtual environment name if not using this default name -venv_name="maxtext_venv" -source ~/$venv_name/bin/activate -``` - -## vLLM and tpu-commons installations +## vLLM and tpu-inference installations -Next, run the following bash script to get all the necessary installations inside the virtual environment. +Next, run the following bash script to get all the necessary installations inside the virtual environment (for e.g., `maxtext_venv`). This will take few minutes. Follow along the installation logs and look out for any issues! ``` bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh ``` -1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically. -2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/` -3. Then, it installs `tpu-commons` from the same artifact registry. - -`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus. -We use the scheduler code from vLLM, and the model runner code from `tpu_commons` +Primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby providing TPU inference for vLLM, with unified JAX and PyTorch support. 
 ## Run GRPO
@@ -62,15 +47,15 @@
 Finally, run the command
 ```
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-    --model_name=llama3.1-8b \
-    --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-    --load_parameters_path=gs://path/to/checkpoint/0/items \
-    --run_name=$WORKLOAD \
-    --base_output_directory=$OUTPUT_PATH \
-    --hf_access_token=$HF_TOKEN
+    model_name=llama3.1-8b \
+    tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+    load_parameters_path=gs://path/to/checkpoint/0/items \
+    run_name=$WORKLOAD \
+    base_output_directory=$OUTPUT_PATH \
+    hf_access_token=$HF_TOKEN
 ```
 
-The overview of the demo script is as follows:
+The overview of what this run will do is as follows:
 
 1. We load a policy model and a reference model. Both are copies of `Llama3.1-8b-Instruct`.
 2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
diff --git a/docs/tutorials/grpo_with_pathways.md b/docs/tutorials/grpo_with_pathways.md
index c2ce421c7..5ec6ac537 100644
--- a/docs/tutorials/grpo_with_pathways.md
+++ b/docs/tutorials/grpo_with_pathways.md
@@ -24,27 +24,28 @@ We use Tunix as the library for GRPO.
 And we use vLLM as the library for efficient model inference and generation.
 Furthermore, we use Pathways for [orchestration](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro). Using Pathways, you can also run GRPO in a disaggregated mode where the trainer and the samplers are running on separate mesh. Try out the following recipe `v5p-64`. You can submit jobs to a Pathways enabled GKE cluster.
-
-## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-commons dependencies
-Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-commons installed.
-
-In addition to MaxText dependencies,
-1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
-2. Next, it installs `vLLM` for Jax and TPUs from the artifact registry `https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/`
-3. Then, it installs `tpu-commons` from the same artifact registry.
+## Create a virtual environment and install MaxText dependencies
+Follow the instructions in [Install MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/install_maxtext.md); we
+recommend creating the virtual environment outside the `maxtext` directory.
+## Build and Upload MaxText Docker Image with Tunix, vLLM, tpu-inference dependencies
 
-`tpu_commons` is the TPU backend for vLLM. You will need both libraries to run vLLM on tpus.
-We use the scheduler code from vLLM, and the model runner code from `tpu_commons`
+### Installing stable releases of tunix and vllm-tpu
+Run the following bash script to create a docker image with all the dependencies of MaxText, Tunix, vLLM and tpu-inference installed.
+In addition to the MaxText dependencies, it primarily installs `vllm-tpu`, which comprises [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference), thereby providing TPU inference for vLLM with unified JAX and PyTorch support.
+
 ```
-bash docker_build_dependency_image.sh MODE=post-training
+bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
 ```
 
-You can also use `bash docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API
+You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies, such as the improved pathwaysutils resharding API.
+### Install from locally cloned repos
+You can also `git clone` [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), and [vllm](https://github.com/vllm-project/vllm.git) locally, and then use the following command to build a docker image from those checkouts:
+`bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local`
 
 ### Upload the dependency docker image along with MaxText code
 ```
@@ -61,12 +62,12 @@ xpk workload create-pathways --workload $WORKLOAD \
   --project=$PROJECT_ID --priority=high \
   --command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1'
 # Llama3.1-70B-Instruct
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-    --model_name=llama3.1-70b \
-    --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
-    --load_parameters_path=gs://path/to/checkpoint/0/items \
-    --run_name=$WORKLOAD \
-    --base_output_directory=$OUTPUT_PATH \
-    --hf_access_token=$HF_TOKEN"
+    model_name=llama3.1-70b \
+    tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
+    load_parameters_path=gs://path/to/checkpoint/0/items \
+    run_name=$WORKLOAD \
+    base_output_directory=$OUTPUT_PATH \
+    hf_access_token=$HF_TOKEN"
 ```
 
 The overview of the demo script ~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:
diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml
index 5667b5063..67975e1c2 100644
--- a/src/MaxText/configs/rl.yml
+++ b/src/MaxText/configs/rl.yml
@@ -71,6 +71,7 @@ value_proj: 'offload'
 checkpoint_storage_use_ocdbt: False # For Pathways
 checkpoint_storage_use_zarr3: False # For Pathways
 use_pathways: True
+log_period: 20
 
 # ====== Debugging ======
 debug:
diff --git a/src/MaxText/examples/install_tunix_vllm_requirement.sh b/src/MaxText/examples/install_tunix_vllm_requirement.sh
index cb7184e28..600327a8a 100644
--- a/src/MaxText/examples/install_tunix_vllm_requirement.sh
+++ b/src/MaxText/examples/install_tunix_vllm_requirement.sh
@@ -19,34 +19,15 @@
 set -e
 set -x
 
-python -m ensurepip --default-pip
-
-pip uninstall -y jax jaxlib libtpu
-
-pip install aiohttp==3.12.15
-
-# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
-pip install keyring keyrings.google-artifactregistry-auth
-
-# Install vLLM for Jax and TPUs from the artifact registry
-VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
-    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
-    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
-    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
-    vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
-
-# Install tpu-commons from the artifact registry
-pip install --no-cache-dir --pre \
-    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    tpu-commons==0.1.2
-
-pip install numba==0.61.2
+uv pip uninstall -y jax jaxlib libtpu
+
+uv pip install aiohttp==3.12.15
+
+# Install vLLM for Jax and TPUs
+uv pip install vllm-tpu
+
+uv pip install numba==0.61.2
+
+uv pip install qwix==0.1.1
+
+uv pip install flax==0.11.1
diff --git a/src/MaxText/examples/sft_qwen3_demo.ipynb b/src/MaxText/examples/sft_qwen3_demo.ipynb
index 0ad68860f..93691b002 100644
--- a/src/MaxText/examples/sft_qwen3_demo.ipynb
+++ b/src/MaxText/examples/sft_qwen3_demo.ipynb
@@ -91,6 +91,11 @@
   },
   {
    "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "OSPRVbi7n6tB"
+   },
+   "outputs": [],
    "source": [
     "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
     "%cd /content/maxtext\n",
@@ -102,30 +107,9 @@
     "!uv pip install -e .[tpu] --resolution=lowest\n",
     "!python3 -m MaxText.install_maxtext_extra_deps\n",
     "\n",
-    "# Install vLLM\n",
-    "!VLLM_TARGET_DEVICE=\"tpu\" pip install --no-cache-dir --pre \\\n",
-    " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
-    " --extra-index-url https://pypi.org/simple/ \\\n",
-    " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
-    " --extra-index-url https://download.pytorch.org/whl/nightly/cpu \\\n",
-    " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
-    " --find-links https://storage.googleapis.com/libtpu-wheels/index.html \\\n",
-    " --find-links https://storage.googleapis.com/libtpu-releases/index.html \\\n",
-    " --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \\\n",
-    " --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \\\n",
-    " vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu\n",
-    "!pip install --no-cache-dir --pre \\\n",
-    " --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \\\n",
-    " --extra-index-url https://pypi.org/simple/ \\\n",
-    " --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \\\n",
-    " --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \\\n",
-    " tpu-commons==0.1.2"
-   ],
-   "metadata": {
-    "id": "OSPRVbi7n6tB"
-   },
-   "execution_count": null,
-   "outputs": []
+    "# Install vLLM for Jax and TPUs\n",
+    "!uv pip install vllm-tpu"
+   ]
  },
  {
   "cell_type": "markdown",
@@ -522,8 +506,8 @@
    "execution_count": null,
    "metadata": {
     "editable": true,
-    "tags": [],
-    "id": "-JtYTPvJZUQN"
+    "id": "-JtYTPvJZUQN",
+    "tags": []
    },
    "outputs": [],
    "source": [
diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py
index d8c6b6151..29610ef58 100644
--- a/src/MaxText/rl/evaluate_rl.py
+++ b/src/MaxText/rl/evaluate_rl.py
@@ -68,7 +68,7 @@ def generate_responses(
   responses = rl_cluster.rollout.generate(
       prompts,
       rollout_config=RolloutConfig(
-          max_tokens_to_generate=tmvp_config.max_target_length,
+          max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length,
           temperature=eval_strategy["eval_temperature"],
           top_k=eval_strategy["eval_top_k"],
           top_p=eval_strategy["eval_top_p"],
diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py
index a690c4c86..76af81bc4 100644
--- a/src/MaxText/rl/train_rl.py
+++ b/src/MaxText/rl/train_rl.py
@@ -26,21 +26,21 @@
 
 # Llama3.1-8B-Instruct
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-    --model_name=llama3.1-8b \
-    --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-    --load_parameters_path=gs://path/to/checkpoint/0/items \
-    --run_name=$WORKLOAD \
-    --base_output_directory=$OUTPUT_PATH \
-    --hf_access_token=$HF_TOKEN
+    model_name=llama3.1-8b \
+    tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+    load_parameters_path=gs://path/to/checkpoint/0/items \
+    run_name=$WORKLOAD \
+    base_output_directory=$OUTPUT_PATH \
+    hf_access_token=$HF_TOKEN
 
 # Llama3.1-70B-Instruct
 python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
-    --model_name=llama3.1-70b \
-    --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
-    --load_parameters_path=gs://path/to/checkpoint/0/items \
-    --run_name=$WORKLOAD \
-    --base_output_directory=$OUTPUT_PATH \
-    --hf_access_token=$HF_TOKEN
+    model_name=llama3.1-70b \
+    tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
+    load_parameters_path=gs://path/to/checkpoint/0/items \
+    run_name=$WORKLOAD \
+    base_output_directory=$OUTPUT_PATH \
+    hf_access_token=$HF_TOKEN
 """
 
@@ -52,6 +52,7 @@
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 import grain
+from etils import epath
 from vllm.outputs import PoolingRequestOutput # pylint: disable=unused-import
 
 import jax
@@ -186,6 +187,12 @@ def rl_train(tmvp_config):
   """
   max_logging.log("Starting GRPO Training")
 
+  max_logging.log(f"Ensuring TensorBoard log directory exists: {tmvp_config.tensorboard_dir}")
+  if not epath.Path(tmvp_config.tensorboard_dir).exists():
+    epath.Path(tmvp_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)
+
+  if not epath.Path(tmvp_config.checkpoint_dir).exists():
+    epath.Path(tmvp_config.checkpoint_dir).mkdir(parents=True)
 
   # Number of training steps.
   max_train_steps = int(
@@ -268,7 +275,6 @@ def rl_train(tmvp_config):
   optimizer = utils_rl.get_optimizer(tmvp_config, max_train_steps)
 
   # Setup checkpointing
-  ckpt_dir = tmvp_config.checkpoint_dir
   checkpointing_options = ocp.CheckpointManagerOptions(
       save_interval_steps=tmvp_config.checkpoint_period, max_to_keep=tmvp_config.max_num_checkpoints_to_keep
   )
@@ -277,10 +283,9 @@
   micro_batch_size = None if tmvp_config.micro_batch_size == -1 else tmvp_config.micro_batch_size
 
   # Setup metrics logging
-  max_logging.log(f"TensorBoard logs directory: {tmvp_config.tensorboard_dir}")
-  # Metrics logger
+  max_logging.log(f"Tensorboard logs directory: {tmvp_config.tensorboard_dir}")
   metrics_logging_options = metrics_logger.MetricsLoggerOptions(
-      log_dir=tmvp_config.tensorboard_dir, flush_every_n_steps=tmvp_config.log_period
+      log_dir=tmvp_config.tensorboard_dir, flush_every_n_steps=tmvp_config.log_period
   )
 
   profiler_options = None
@@ -316,7 +321,7 @@
           # Profiling
           profiler_options=profiler_options,
           # Checkpoint saving
-          checkpoint_root_directory=ckpt_dir,
+          checkpoint_root_directory=tmvp_config.checkpoint_dir,
           checkpointing_options=checkpointing_options,
       ),
       rollout_config=base_rollout.RolloutConfig(
@@ -330,7 +335,7 @@
           rollout_vllm_hbm_utilization=tmvp_config.hbm_utilization_vllm,
           rollout_vllm_tpu_backend_type="jax",
           rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb,
-      ),
+      ),
   )
   grpo_config = GrpoConfig(
       num_generations=tmvp_config.num_generations,
diff --git a/src/MaxText/rl/utils_rl.py b/src/MaxText/rl/utils_rl.py
index 3422cd01a..8b1cede37 100644
--- a/src/MaxText/rl/utils_rl.py
+++ b/src/MaxText/rl/utils_rl.py
@@ -172,7 +172,7 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
   return scores
 
 
-def extract_hash_answer(text: str, debug: bool = False) -> str | None:
+def extract_hash_answer(text: str) -> str | None:
   """Function to extract only the answer hash from the text."""
   if "####" not in text:
     return None
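A note on the final hunk: the diff shows only the `####` guard of `extract_hash_answer`; the change itself just drops the unused `debug` parameter. For context, GSM8K reference solutions place the final answer after a `####` marker, so a helper like this splits on that marker. Below is a minimal sketch consistent with the visible signature; the rest of the body is not shown in this patch, so the split-and-strip completion is illustrative rather than MaxText's actual code:

```python
def extract_hash_answer(text: str) -> str | None:
  """Return the GSM8K answer following '####', or None if the marker is absent."""
  if "####" not in text:
    return None
  # Illustrative completion: take everything after the marker and trim whitespace.
  return text.split("####", 1)[1].strip()

assert extract_hash_answer("6 * 7 = 42 #### 42") == "42"
assert extract_hash_answer("no marker here") is None
```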