From 3c3b08dfd63a685f959927026ee47061d46cbffe Mon Sep 17 00:00:00 2001
From: Ruturaj4
Date: Mon, 14 Oct 2024 15:27:15 -0500
Subject: [PATCH] [ROCm] Fix README.md to update AMD JAX installation instructions

---
 README.md            |  4 ++--
 build/rocm/README.md | 34 +++++++++++++++++-----------------
 docs/developer.md    | 12 ++++++++----
 3 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/README.md b/README.md
index 79cf8fc6d358..058db9158029 100644
--- a/README.md
+++ b/README.md
@@ -388,7 +388,7 @@ Some standouts:
 | CPU | yes | yes | yes | yes | yes | yes |
 | NVIDIA GPU | yes | yes | no | n/a | no | experimental |
 | Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
-| AMD GPU | experimental | no | no | n/a | no | no |
+| AMD GPU | yes | no | no | n/a | no | no |
 | Apple GPU | n/a | no | experimental | experimental | n/a | n/a |


@@ -399,7 +399,7 @@ Some standouts:
 | CPU | `pip install -U jax` |
 | NVIDIA GPU | `pip install -U "jax[cuda12]"` |
 | Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
-| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
+| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
 | Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

 See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
diff --git a/build/rocm/README.md b/build/rocm/README.md
index 0d5f850af112..8cfe10e2d8b2 100644
--- a/build/rocm/README.md
+++ b/build/rocm/README.md
@@ -5,34 +5,34 @@ This directory contains files and setup instructions to build and test JAX for R
 1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).

-2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in Python version. The example below builds Python 3.9 container.
+2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in appropriate
+options. The example below builds Python 3.12 container.

-    ./build/rocm/ci_build.sh --keep_image --py_version==3.9.0 --runtime bash -c "./build/rocm/build_rocm.sh"
+```Bash
+./build/rocm/ci_build.sh --py_version 3.12
+```

 3. To launch a JAX-ROCm container: If the build was successful, there should be a docker image with name "jax-rocm:latest" in list of docker images (use "docker images" command to list them).
-```
-sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax-rocm:latest
+
+```Bash
+docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v ./:/jax --name rocm_jax jax-rocm:latest /bin/bash
 ```

 ***
 ### JAX ROCm Releases
-We strive to push all ROCm related changes to the OpenXLA repository. However, at times some JAX/JAXLIB changes for ROCm may not be present in upstream JAX repo.Therefore, we have ROCm Jax/Jaxlib branches that are associated with a Jaxlib release. These
-are available in ROCm fork of JAX https://github.com/ROCmSoftwarePlatform/jax. See branches named as rocm-jaxlib-[jaxlib-version]. For examples, for jaxlib-v0.4.10, the branch is named rocm-jaxlib-v0.4.10. See path https://github.com/ROCmSoftwarePlatform/jax/tree/rocm-jaxlib-v0.4.10
+We aim to push all ROCm-related changes to the OpenXLA repository. However, there may be times when certain JAX/jaxlib updates for
+ROCm are not yet reflected in the upstream JAX repository. To address this, we maintain ROCm-specific JAX/jaxlib branches tied to JAX
+releases. These branches are available in the ROCm fork of JAX at https://github.com/ROCm/jax. Look for branches named in the format
+rocm-jaxlib-[jaxlib-version]. You can also find corresponding branches in https://github.com/ROCm/xla. For example, for JAX version
+0.4.33, the branch is named rocm-jaxlib-v0.4.33, which can be accessed at https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.33.

-JAX and Jaxlib wheels for ROCm are available here
-```
-https://github.com/ROCmSoftwarePlatform/jax/releases
+JAX source-code and related wheels for ROCm are available here
+
+```Bash
+https://github.com/ROCm/jax/releases
 ```

 ***Note:*** Some earlier jaxlib versions on ROCm were released on ***PyPi***.
 ```
 https://pypi.org/project/jaxlib-rocm/#history
 ```
-However, due to strict naming PyPI requirement we had to name our wheels slightly differently. This would then result in Jax/Jaxlib dependent not recognizing jaxlib-rocm wheels and would end up with multiple jaxlib installations and also runtime issues
-
-
-***
-### XLA for JAX ROCm
-We strive to push all ROCm related changes to the OpenXLA repository. However, at times some XLA changes for ROCm may not be upstreamed to XLA repo.Therefore, we have ROCm XLA branches that are associated with a Jaxlib release. These are available in ROCm fork of XLA here https://github.com/ROCmSoftwarePlatform/xla. See branches named as rocm-jaxlib-[jaxlib version]. For example, for jaxlib-v0.4.10, the branch is named rocm-jaxlib-v0.4.10. See path https://github.com/ROCmSoftwarePlatform/xla/tree/rocm-jaxlib-v0.4.10
-
-
diff --git a/docs/developer.md b/docs/developer.md
index 5f57b2499860..68e8e931e2e5 100644
--- a/docs/developer.md
+++ b/docs/developer.md
@@ -195,11 +195,14 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
     rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
 ```

+The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`,
+and selecting the appropriate options.
+
 To build jaxlib with ROCM support, you can run the following build command,
 suitably adjusted for your paths and ROCM version.

 ```
-python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0
+python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=/opt/rocm-6.2.3
 ```

 AMD's fork of the XLA repository may include fixes not present in the upstream
@@ -207,16 +210,17 @@ XLA repository. If you experience problems with the upstream repository, you can
 try AMD's fork, by cloning their repository:

 ```
-git clone https://github.com/ROCmSoftwarePlatform/xla.git
+git clone https://github.com/ROCm/xla.git
 ```

 and override the XLA repository with which JAX is built:

 ```
-python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \
-  --bazel_options=--override_repository=xla=/path/to/xla-rocm
+python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --bazel_options=--override_repository=xla=/rel/xla/ --rocm_path=/opt/rocm-6.2.3
 ```

+For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`.
+
 ## Managing hermetic Python

 To make sure that JAX's build is reproducible, behaves uniformly across