@@ -0,0 +1,58 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASEIMAGE
FROM ${BASEIMAGE}
ARG MODE
ENV MODE=$MODE

RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
RUN pip uninstall -y jax jaxlib libtpu

RUN pip install aiohttp==3.12.15

# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
RUN pip install keyring keyrings.google-artifactregistry-auth
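# Note: with these keyring backends installed, pip authenticates to Artifact Registry
# indexes (such as the us-python.pkg.dev index used below) via Application Default
# Credentials, so no token needs to be embedded in the index URL.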

RUN pip install numba==0.61.2

COPY tunix /tunix
RUN pip install -e /tunix --no-cache-dir


COPY vllm /vllm
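# VLLM_TARGET_DEVICE=tpu selects vLLM's TPU build path during the editable install below.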
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html


COPY tpu-inference /tpu-inference
RUN pip install -e /tpu-inference --no-cache-dir --pre \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html


RUN if [ "$MODE" = "post-training-experimental" ]; then \
echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
pip uninstall -y jax jaxlib libtpu && \
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
fi
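
For reference, a minimal sketch of how this Dockerfile is invoked (image tags are placeholders; the build script below copies `tunix/`, `vllm/`, and `tpu-inference/` into the build context first):

```
docker build \
  --network host \
  --build-arg BASEIMAGE=maxtext-base:latest \
  --build-arg MODE=post-training \
  -f dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile \
  -t maxtext-post-training:latest .
```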
57 changes: 37 additions & 20 deletions dependencies/scripts/docker_build_dependency_image.sh
@@ -6,7 +6,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,14 +20,15 @@
# bash docker_build_dependency_image.sh MODE=nightly
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
# Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
# MODE=custom_wheels is the same as nightly except that it reinstalls any
# additional wheels that are present in the maxtext directory.
# The main use case is to install custom jax or jaxlib wheels but it also
# works with any custom wheels.
# bash docker_build_dependency_image.sh MODE=custom_wheels

# bash docker_build_dependency_image.sh MODE=post-training
# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local

if [ "${BASH_SOURCE-}" ]; then
this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then
echo "Default DEVICE=${DEVICE}"
fi

# New flag for post-training source
if [[ -z ${POST_TRAINING_SOURCE} ]]; then
export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
fi

# Function to build with MODE=jax_ai_image
build_ai_image() {
if [[ -z ${BASEIMAGE+x} ]]; then
@@ -171,24 +178,34 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
exit 1
fi

# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
# rsync -a --exclude='__pycache__' ../tpu_commons .
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
# # This assumes vllm is a sibling directory to the current one (maxtext).
# rsync -a --exclude='__pycache__' ../vllm .

# rsync -a --exclude='__pycache__' ../tunix .

# # The cleanup is set to run even if the build fails to remove the copied directory.
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM

docker build \
--network host \
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
--build-arg MODE=${MODE} \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
-t ${LOCAL_IMAGE_NAME} .
DOCKERFILE_NAME=""
if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then

# To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
# This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
rsync -a --exclude='__pycache__' ../tpu-inference .
# To install vllm from a local path, we copy it into the build context, excluding __pycache__.
# This assumes vllm is a sibling directory to the current one (maxtext).
rsync -a --exclude='__pycache__' ../vllm .

rsync -a --exclude='__pycache__' ../tunix .

# The cleanup trap runs even if the build fails, removing the copied directories.
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM

DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME"
else
DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME"
fi

docker build \
--network host \
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
--build-arg MODE=${MODE} \
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \
-t ${LOCAL_IMAGE_NAME} .
fi

if [[ ${CUSTOM_JAX} -eq 1 ]] ; then
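A hedged usage sketch for the new flag (the repository URLs are assumptions; the script only requires that the three checkouts exist as siblings of the maxtext directory):

```
# Check out the dependencies next to maxtext (URLs are assumptions):
git clone https://github.com/vllm-project/vllm ../vllm
git clone https://github.com/vllm-project/tpu-inference ../tpu-inference
git clone https://github.com/google/tunix ../tunix

# Build against the local checkouts instead of the remote packages:
bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
```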
12 changes: 6 additions & 6 deletions docs/tutorials/grpo.md
@@ -62,12 +62,12 @@ Finally, run the command

```
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-8b \
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
--load_parameters_path=gs://path/to/checkpoint/0/items \
--run_name=$WORKLOAD \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN
model_name=llama3.1-8b \
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
load_parameters_path=gs://path/to/checkpoint/0/items \
run_name=$WORKLOAD \
base_output_directory=$OUTPUT_PATH \
hf_access_token=$HF_TOKEN
```

The overview of the demo script is as follows:
12 changes: 6 additions & 6 deletions docs/tutorials/grpo_with_pathways.md
@@ -61,12 +61,12 @@ xpk workload create-pathways --workload $WORKLOAD \
--project=$PROJECT_ID --priority=high \
--command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-70b \
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
--load_parameters_path=gs://path/to/checkpoint/0/items \
--run_name=$WORKLOAD \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN"
model_name=llama3.1-70b \
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
load_parameters_path=gs://path/to/checkpoint/0/items \
run_name=$WORKLOAD \
base_output_directory=$OUTPUT_PATH \
hf_access_token=$HF_TOKEN"
```

The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:
2 changes: 1 addition & 1 deletion src/MaxText/integration/tunix/tunix_adapter.py
@@ -26,7 +26,7 @@
from jax import Array
from flax import nnx
from MaxText.layers.models import Transformer
from maxtext.src.maxtext.integration.tunix.utils import VllmWeightMapping
from MaxText.integration.tunix.utils import VllmWeightMapping
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS # pylint: disable=ungrouped-imports


2 changes: 1 addition & 1 deletion src/MaxText/integration/tunix/utils.py
@@ -16,7 +16,7 @@

import re

import maxtext.src.maxtext.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import
from MaxText.integration.tunix import weight_mapping # pylint: disable=consider-using-from-import
from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS

4 changes: 2 additions & 2 deletions src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -19,8 +19,8 @@
model name. This allows for easy extension to support new models.
"""

from maxtext.src.maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from maxtext.src.maxtext.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING


class StandaloneVllmWeightMapping:
2 changes: 1 addition & 1 deletion src/MaxText/rl/evaluate_rl.py
@@ -68,7 +68,7 @@ def generate_responses(
responses = rl_cluster.rollout.generate(
prompts,
rollout_config=RolloutConfig(
max_tokens_to_generate=tmvp_config.max_target_length,
max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length,
temperature=eval_strategy["eval_temperature"],
top_k=eval_strategy["eval_top_k"],
top_p=eval_strategy["eval_top_p"],
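A quick worked example of the corrected budget (values hypothetical): the rollout can only generate what fits after the prompt, so the completion budget is the total sequence length minus the prefill allowance.

```python
# Hypothetical config values illustrating the token budget:
max_target_length = 2048           # total budget for prompt + completion
max_prefill_predict_length = 1024  # maximum prompt (prefill) length
max_tokens_to_generate = max_target_length - max_prefill_predict_length
assert max_tokens_to_generate == 1024  # completions no longer overflow the sequence budget
```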
32 changes: 19 additions & 13 deletions src/MaxText/rl/train_rl.py
Expand Up @@ -26,21 +26,21 @@
# Llama3.1-8B-Instruct
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-8b \
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
--load_parameters_path=gs://path/to/checkpoint/0/items \
--run_name=$WORKLOAD \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN
model_name=llama3.1-8b \
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
load_parameters_path=gs://path/to/checkpoint/0/items \
run_name=$WORKLOAD \
base_output_directory=$OUTPUT_PATH \
hf_access_token=$HF_TOKEN
# Llama3.1-70B-Instruct
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
--model_name=llama3.1-70b \
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
--load_parameters_path=gs://path/to/checkpoint/0/items \
--run_name=$WORKLOAD \
--base_output_directory=$OUTPUT_PATH \
--hf_access_token=$HF_TOKEN
model_name=llama3.1-70b \
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
load_parameters_path=gs://path/to/checkpoint/0/items \
run_name=$WORKLOAD \
base_output_directory=$OUTPUT_PATH \
hf_access_token=$HF_TOKEN
"""

@@ -52,6 +52,7 @@
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import grain
from etils import epath

from vllm.outputs import PoolingRequestOutput # pylint: disable=unused-import
import jax
@@ -186,6 +187,12 @@ def rl_train(tmvp_config):
"""

max_logging.log("Starting GRPO Training")
if not epath.Path(tmvp_config.tensorboard_dir).exists():
epath.Path(tmvp_config.tensorboard_dir).mkdir(parents=True)

ckpt_dir = tmvp_config.checkpoint_dir
if not epath.Path(ckpt_dir).exists():
epath.Path(ckpt_dir).mkdir(parents=True)

# Number of training steps.
max_train_steps = int(
@@ -266,7 +273,6 @@ def rl_train(tmvp_config):
optimizer = utils_rl.get_optimizer(tmvp_config, max_train_steps)

# Setup checkpointing
ckpt_dir = tmvp_config.checkpoint_dir
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=tmvp_config.checkpoint_period, max_to_keep=tmvp_config.max_num_checkpoints_to_keep
)
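For context, a minimal sketch of the directory-creation pattern introduced above, assuming `tensorboard_dir` and `checkpoint_dir` may be either local paths or `gs://` URIs (which is why `etils.epath` is used rather than `os.makedirs`):

```python
from etils import epath

def ensure_dir(path_str: str) -> None:
  """Create the directory (local or GCS) if it does not already exist."""
  path = epath.Path(path_str)  # epath.Path handles gs:// URIs as well as local paths
  if not path.exists():
    path.mkdir(parents=True)

ensure_dir("/tmp/tensorboard")      # local filesystem
ensure_dir("gs://my-bucket/ckpts")  # hypothetical GCS bucket
```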
2 changes: 1 addition & 1 deletion src/MaxText/rl/utils_rl.py
@@ -172,7 +172,7 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
return scores


def extract_hash_answer(text: str, debug: bool = False) -> str | None:
def extract_hash_answer(text: str) -> str | None:
"""Function to extract only the answer hash from the text."""
if "####" not in text:
return None
Expand Down
Loading