Skip to content

Commit

Permalink
Merge pull request #1124 from AI-Hypercomputer:lance-bumpup-te
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 709158692
  • Loading branch information
maxtext authors committed Dec 23, 2024
2 parents 087f746 + 0299f7a commit 77f6459
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 19 deletions.
18 changes: 9 additions & 9 deletions constraints_gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ importlib_metadata==8.4.0
importlib_resources==6.4.5
iniconfig==2.0.0
isort==5.13.2
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.34
jax==0.4.38
jax-cuda12-pjrt==0.4.38
jax-cuda12-plugin==0.4.38
jaxlib==0.4.38
jaxtyping==0.2.34
Jinja2==3.1.4
jsonlines==4.0.0
Expand Down Expand Up @@ -108,16 +108,16 @@ networkx==3.4.2
ninja==1.11.1.1
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.6.3.3
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvcc-cu12==12.6.77
nvidia-cuda-nvcc-cu12==12.6.85
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.0.50
nvidia-cudnn-cu12==9.6.0.74
nvidia-cufft-cu12==11.3.0.4
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-nccl-cu12==2.23.4
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvjitlink-cu12==12.6.85
oauthlib==3.2.2
opentelemetry-api==1.27.0
opt_einsum==3.4.0
Expand Down Expand Up @@ -196,7 +196,7 @@ tomli==2.0.2
tomlkit==0.13.2
toolz==1.0.0
tqdm==4.66.5
transformer-engine==1.5.0+297459b
transformer-engine==1.13.0+e5edd6c
transformers==4.46.0
typeguard==2.13.3
typing_extensions==4.12.2
Expand Down
3 changes: 2 additions & 1 deletion docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
set -e

export LOCAL_IMAGE_NAME=maxtext_base_image
echo "Building to $LOCAL_IMAGE_NAME"

# Use Docker BuildKit so we can cache pip packages.
export DOCKER_BUILDKIT=1
Expand Down Expand Up @@ -90,7 +91,7 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
build_stable_stack
else
if [[ ${MODE} == "pinned" ]]; then
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-05-07
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-12-04
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
Expand Down
2 changes: 1 addition & 1 deletion maxtext_gpu_dependencies.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"

# Upgrade libcusparse to work with Jax
RUN apt-get update && apt-get install -y libcusparse-12-3
RUN apt-get update && apt-get install -y libcusparse-12-6

ARG MODE
ENV ENV_MODE=$MODE
Expand Down
3 changes: 2 additions & 1 deletion maxtext_transformerengine_builder.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ENV NVTE_FRAMEWORK=jax

RUN git clone https://github.com/NVIDIA/TransformerEngine
WORKDIR /root/TransformerEngine
RUN git checkout 297459bd08e1b791ca7a2872cfa8582220477782
RUN git pull
RUN git checkout e5edd6c
RUN git submodule update --init --recursive
RUN python setup.py bdist_wheel
13 changes: 6 additions & 7 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_SUSPEND=1
export NEEDRESTART_MODE=l


apt-get update && apt-get install -y sudo
(sudo bash || bash) <<'EOF'
apt update && \
apt install -y numactl lsb-release gnupg curl net-tools iproute2 procps lsof git ethtool && \
Expand Down Expand Up @@ -90,9 +90,9 @@ run_name_folder_path=$(pwd)
# Install dependencies from requirements.txt
cd $run_name_folder_path && pip install --upgrade pip
if [[ "$MODE" == "pinned" ]]; then
pip3 install -U -r requirements.txt -c constraints_gpu.txt
pip3 install --no-cache-dir -U -r requirements.txt -c constraints_gpu.txt
else
pip3 install -U -r requirements.txt
pip3 install --no-cache-dir -U -r requirements.txt
fi

# Uninstall existing jax, jaxlib and libtpu-nightly
Expand All @@ -110,11 +110,10 @@ if [[ "$MODE" == "pinned" ]]; then
echo "pinned mode is supported for GPU builds only."
exit 1
fi
echo "Installing pinned jax, jaxlib for NVIDIA gpu."
echo "Installing Jax and Transformer Engine."
pip3 install "jax[cuda12]" -c constraints_gpu.txt
pip3 install "transformer-engine==1.5.0+297459b" \
--extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ \
-c constraints_gpu.txt
pip install transformer-engine[jax]==1.13.0

elif [[ "$MODE" == "stable" || ! -v MODE ]]; then
# Stable mode
if [[ $DEVICE == "tpu" ]]; then
Expand Down

0 comments on commit 77f6459

Please sign in to comment.