|
6 | 6 | # you may not use this file except in compliance with the License. |
7 | 7 | # You may obtain a copy of the License at |
8 | 8 | # |
9 | | -# https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# https://www.apache.org/licenses/LICENSE-2.0 |
10 | 10 | # |
11 | 11 | # Unless required by applicable law or agreed to in writing, software |
12 | 12 | # distributed under the License is distributed on an "AS IS" BASIS, |
|
20 | 20 | # bash docker_build_dependency_image.sh MODE=nightly |
21 | 21 | # bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13 |
22 | 22 | # Nightly build with JAX_VERSION for GPUs. Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax: |
23 | | -# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly |
| 23 | +# bash docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 # Note: this sets both jax-nightly and jaxlib-nightly |
24 | 24 | # MODE=custom_wheels is the same as nightly except that it reinstalls any |
25 | 25 | # additional wheels that are present in the maxtext directory. |
26 | 26 | # The main use case is to install custom jax or jaxlib wheels but it also |
27 | 27 | # works with any custom wheels. |
28 | 28 | # bash docker_build_dependency_image.sh MODE=custom_wheels |
29 | 29 |
|
30 | 30 | # bash docker_build_dependency_image.sh MODE=post-training |
| 31 | +# bash docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local |
31 | 32 |
|
32 | 33 | if [ "${BASH_SOURCE-}" ]; then |
33 | 34 | this_file="${BASH_SOURCE[0]}" |
@@ -97,6 +98,12 @@ if [[ -z ${DEVICE} ]]; then |
97 | 98 | echo "Default DEVICE=${DEVICE}" |
98 | 99 | fi |
99 | 100 |
|
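| 101 | +# POST_TRAINING_SOURCE selects the post-training Dockerfile: "local" uses the local-dependencies Dockerfile; any other value uses the remote one. |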
| 102 | +if [[ -z ${POST_TRAINING_SOURCE} ]]; then |
| 103 | + export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile |
| 104 | + echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}" |
| 105 | +fi |
| 106 | + |
100 | 107 | # Function to build with MODE=jax_ai_image |
101 | 108 | build_ai_image() { |
102 | 109 | if [[ -z ${BASEIMAGE+x} ]]; then |
@@ -170,25 +177,36 @@ if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then |
170 | 177 | echo "Error: MODE=post-training is only supported for DEVICE=tpu" |
171 | 178 | exit 1 |
172 | 179 | fi |
173 | | - |
174 | | - # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__. |
175 | | - # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext). |
176 | | - # rsync -a --exclude='__pycache__' ../tpu_commons . |
177 | | - # # To install vllm from a local path, we copy it into the build context, excluding __pycache__. |
178 | | - # # This assumes vllm is a sibling directory to the current one (maxtext). |
179 | | - # rsync -a --exclude='__pycache__' ../vllm . |
180 | | - |
181 | | - # rsync -a --exclude='__pycache__' ../tunix . |
182 | | - |
183 | | - # # The cleanup is set to run even if the build fails to remove the copied directory. |
184 | | - # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM |
185 | | - |
186 | | - docker build \ |
187 | | - --network host \ |
188 | | - --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ |
189 | | - --build-arg MODE=${MODE} \ |
190 | | - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile' \ |
191 | | - -t ${LOCAL_IMAGE_NAME} . |
| 180 | + fi |
| 181 | + |
| 182 | + DOCKERFILE_NAME="" |
| 183 | + if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then |
| 184 | + |
| 185 | + # To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__. |
| 186 | + # This assumes vllm, tunix, and tpu-inference are sibling directories of the current one (maxtext). |
| 187 | + rsync -a --exclude='__pycache__' ../tpu-inference . |
| 188 | + # To install vllm from a local path, we copy it into the build context, excluding __pycache__. |
| 189 | + # This assumes vllm is a sibling directory to the current one (maxtext). |
| 190 | + rsync -a --exclude='__pycache__' ../vllm . |
| 191 | + |
| 192 | + rsync -a --exclude='__pycache__' ../tunix . |
| 193 | + |
| 194 | + # The cleanup trap removes the copied directories on exit, even if the build fails or is interrupted. |
| 195 | + trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM |
| 196 | + |
| 197 | + DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile' |
| 198 | + echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME" |
| 199 | + else |
| 200 | + DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile' |
| 201 | + echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME" |
| 202 | + fi |
| 203 | + |
| 204 | + docker build \ |
| 205 | + --network host \ |
| 206 | + --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ |
| 207 | + --build-arg MODE=${MODE} \ |
| 208 | + -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \ |
| 209 | + -t ${LOCAL_IMAGE_NAME} . |
192 | 210 | fi |
193 | 211 |
|
194 | 212 | if [[ ${CUSTOM_JAX} -eq 1 ]] ; then |
|
0 commit comments