Skip to content

Commit 56bf0d1

Browse files
committed
updates
1 parent be7c2de commit 56bf0d1

File tree

10 files changed

+132
-51
lines changed

10 files changed

+132
-51
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
ARG BASEIMAGE
16+
FROM ${BASEIMAGE}
17+
ARG MODE
18+
ENV MODE=$MODE
19+
20+
RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"
21+
RUN pip uninstall -y jax jaxlib libtpu
22+
23+
RUN pip install aiohttp==3.12.15
24+
25+
# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
26+
RUN pip install keyring keyrings.google-artifactregistry-auth
27+
28+
RUN pip install numba==0.61.2
29+
30+
COPY tunix /tunix
31+
RUN pip install -e /tunix --no-cache-dir
32+
33+
34+
COPY vllm /vllm
35+
RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
36+
--extra-index-url https://pypi.org/simple/ \
37+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
38+
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
39+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
40+
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
41+
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
42+
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
43+
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
44+
45+
46+
COPY tpu-inference /tpu-inference
47+
RUN pip install -e /tpu-inference --no-cache-dir --pre \
48+
--extra-index-url https://pypi.org/simple/ \
49+
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
50+
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
51+
52+
53+
RUN if [ "$MODE" = "grpo-experimental" ]; then \
54+
echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
55+
pip uninstall -y jax jaxlib libtpu && \
56+
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
57+
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
58+
fi

dependencies/scripts/docker_build_dependency_image.sh

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# you may not use this file except in compliance with the License.
77
# You may obtain a copy of the License at
88
#
9-
# https://www.apache.org/licenses/LICENSE-2.0
9+
# https://www.apache.org/licenses/LICENSE-2.0
1010
#
1111
# Unless required by applicable law or agreed to in writing, software
1212
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,14 +20,15 @@
2020
# bash docker_build_dependency_image.sh MODE=nightly
2121
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
2222
# Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax:
23-
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
23+
# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly
2424
# MODE=custom_wheels is the same as nightly except that it reinstalls any
2525
# additional wheels that are present in the maxtext directory.
2626
# The main use case is to install custom jax or jaxlib wheels but it also
2727
# works with any custom wheels.
2828
# bash docker_build_dependency_image.sh MODE=custom_wheels
2929

3030
# bash docker_build_dependency_image.sh MODE=post-training
31+
# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local
3132

3233
if [ "${BASH_SOURCE-}" ]; then
3334
this_file="${BASH_SOURCE[0]}"
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then
9798
echo "Default DEVICE=${DEVICE}"
9899
fi
99100

101+
# POST_TRAINING_SOURCE selects which post-training dependencies Dockerfile is used: "local" or "remote" (the default).
102+
if [[ -z ${POST_TRAINING_SOURCE} ]]; then
103+
export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile
104+
echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}"
105+
fi
106+
100107
# Function to build with MODE=jax_ai_image
101108
build_ai_image() {
102109
if [[ -z ${BASEIMAGE+x} ]]; then
@@ -170,25 +177,36 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then
170177
echo "Error: MODE=post-training is only supported for DEVICE=tpu"
171178
exit 1
172179
fi
173-
174-
# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
175-
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
176-
# rsync -a --exclude='__pycache__' ../tpu_commons .
177-
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
178-
# # This assumes vllm is a sibling directory to the current one (maxtext).
179-
# rsync -a --exclude='__pycache__' ../vllm .
180-
181-
# rsync -a --exclude='__pycache__' ../tunix .
182-
183-
# # The cleanup is set to run even if the build fails to remove the copied directory.
184-
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
185-
186-
docker build \
187-
--network host \
188-
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
189-
--build-arg MODE=${MODE} \
190-
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \
191-
-t ${LOCAL_IMAGE_NAME} .
180+
fi
181+
182+
DOCKERFILE_NAME=""
183+
if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then
184+
185+
# To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__.
186+
# This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext).
187+
rsync -a --exclude='__pycache__' ../tpu-inference .
188+
# To install vllm from a local path, we copy it into the build context, excluding __pycache__.
189+
# This assumes vllm is a sibling directory to the current one (maxtext).
190+
rsync -a --exclude='__pycache__' ../vllm .
191+
192+
rsync -a --exclude='__pycache__' ../tunix .
193+
194+
# The cleanup that removes the copied directories is set to run even if the build fails.
195+
trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM
196+
197+
DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile'
198+
echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME"
199+
else
200+
DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile'
201+
echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME"
202+
fi
203+
204+
docker build \
205+
--network host \
206+
--build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
207+
--build-arg MODE=${MODE} \
208+
-f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \
209+
-t ${LOCAL_IMAGE_NAME} .
192210
fi
193211

194212
if [[ ${CUSTOM_JAX} -eq 1 ]] ; then

docs/tutorials/grpo.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ Finally, run the command
6262

6363
```
6464
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
65-
--model_name=llama3.1-8b \
66-
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
67-
--load_parameters_path=gs://path/to/checkpoint/0/items \
68-
--run_name=$WORKLOAD \
69-
--base_output_directory=$OUTPUT_PATH \
70-
--hf_access_token=$HF_TOKEN
65+
model_name=llama3.1-8b \
66+
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
67+
load_parameters_path=gs://path/to/checkpoint/0/items \
68+
run_name=$WORKLOAD \
69+
base_output_directory=$OUTPUT_PATH \
70+
hf_access_token=$HF_TOKEN
7171
```
7272

7373
The overview of the demo script is as follows:

docs/tutorials/grpo_with_pathways.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ xpk workload create-pathways --workload $WORKLOAD \
6161
--project=$PROJECT_ID --priority=high \
6262
--command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
6363
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
64-
--model_name=llama3.1-70b \
65-
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
66-
--load_parameters_path=gs://path/to/checkpoint/0/items \
67-
--run_name=$WORKLOAD \
68-
--base_output_directory=$OUTPUT_PATH \
69-
--hf_access_token=$HF_TOKEN"
64+
model_name=llama3.1-70b \
65+
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
66+
load_parameters_path=gs://path/to/checkpoint/0/items \
67+
run_name=$WORKLOAD \
68+
base_output_directory=$OUTPUT_PATH \
69+
hf_access_token=$HF_TOKEN"
7070
```
7171

7272
The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from jax import Array
2727
from flax import nnx
2828
from MaxText.layers.models import Transformer
29-
from maxtext.src.maxtext.integration.tunix.utils import VllmWeightMapping
29+
from MaxText.integration.tunix.utils import VllmWeightMapping
3030
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS # pylint: disable=ungrouped-imports
3131

3232

src/MaxText/integration/tunix/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import re
1818

19-
import maxtext.src.maxtext.integration.tunix.weight_mapping as weight_mapping # pylint: disable=consider-using-from-import
19+
from MaxText.integration.tunix import weight_mapping # pylint: disable=consider-using-from-import
2020
from MaxText.utils.ckpt_conversion.utils.param_mapping import PARAM_MAPPING
2121
from MaxText.utils.ckpt_conversion.utils.param_mapping import VLLM_HOOK_FNS
2222

src/MaxText/integration/tunix/weight_mapping/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
model name. This allows for easy extension to support new models.
2020
"""
2121

22-
from maxtext.src.maxtext.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
23-
from maxtext.src.maxtext.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
22+
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
23+
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING
2424

2525

2626
class StandaloneVllmWeightMapping:

src/MaxText/rl/evaluate_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def generate_responses(
6868
responses = rl_cluster.rollout.generate(
6969
prompts,
7070
rollout_config=RolloutConfig(
71-
max_tokens_to_generate=tmvp_config.max_target_length,
71+
max_tokens_to_generate=tmvp_config.max_target_length - tmvp_config.max_prefill_predict_length,
7272
temperature=eval_strategy["eval_temperature"],
7373
top_k=eval_strategy["eval_top_k"],
7474
top_p=eval_strategy["eval_top_p"],

src/MaxText/rl/train_rl.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@
2626
2727
# Llama3.1-8B-Instruct
2828
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
29-
--model_name=llama3.1-8b \
30-
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
31-
--load_parameters_path=gs://path/to/checkpoint/0/items \
32-
--run_name=$WORKLOAD \
33-
--base_output_directory=$OUTPUT_PATH \
34-
--hf_access_token=$HF_TOKEN
29+
model_name=llama3.1-8b \
30+
tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
31+
load_parameters_path=gs://path/to/checkpoint/0/items \
32+
run_name=$WORKLOAD \
33+
base_output_directory=$OUTPUT_PATH \
34+
hf_access_token=$HF_TOKEN
3535
3636
# Llama3.1-70B-Instruct
3737
python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
38-
--model_name=llama3.1-70b \
39-
--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
40-
--load_parameters_path=gs://path/to/checkpoint/0/items \
41-
--run_name=$WORKLOAD \
42-
--base_output_directory=$OUTPUT_PATH \
43-
--hf_access_token=$HF_TOKEN
38+
model_name=llama3.1-70b \
39+
tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
40+
load_parameters_path=gs://path/to/checkpoint/0/items \
41+
run_name=$WORKLOAD \
42+
base_output_directory=$OUTPUT_PATH \
43+
hf_access_token=$HF_TOKEN
4444
4545
"""
4646

@@ -52,6 +52,7 @@
5252
from flax import nnx
5353
from flax.linen import partitioning as nn_partitioning
5454
import grain
55+
from etils import epath
5556

5657
from vllm.outputs import PoolingRequestOutput # pylint: disable=unused-import
5758
import jax
@@ -267,10 +268,14 @@ def rl_train(tmvp_config):
267268

268269
# Setup checkpointing
269270
ckpt_dir = tmvp_config.checkpoint_dir
271+
if not epath.Path(ckpt_dir).exists():
272+
epath.Path(ckpt_dir).mkdir(parents=True)
270273
checkpointing_options = ocp.CheckpointManagerOptions(
271274
save_interval_steps=tmvp_config.checkpoint_period, max_to_keep=tmvp_config.max_num_checkpoints_to_keep
272275
)
273276

277+
if not os.path.exists(tmvp_config.tensorboard_dir):
278+
epath.Path(tmvp_config.tensorboard_dir).mkdir(parents=True)
274279
# Setup metrics logging
275280
max_logging.log(f"TensorBoard logs directory: {tmvp_config.tensorboard_dir}")
276281
# Metrics logger

src/MaxText/rl/utils_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
172172
return scores
173173

174174

175-
def extract_hash_answer(text: str, debug: bool = False) -> str | None:
175+
def extract_hash_answer(text: str) -> str | None:
176176
"""Function to extract only the answer hash from the text."""
177177
if "####" not in text:
178178
return None

0 commit comments

Comments
 (0)