Skip to content

Commit

Permalink
Merge pull request #24300 from ROCm:ci_rocm_readme
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686872994
  • Loading branch information
Google-ML-Automation committed Oct 17, 2024
2 parents 36ec513 + 3c3b08d commit 3bdc57d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ Some standouts:
| CPU | yes | yes | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes | no | n/a | no | experimental |
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | experimental | no | no | n/a | no | no |
| AMD GPU | yes | no | no | n/a | no | no |
| Apple GPU | n/a | no | experimental | experimental | n/a | n/a |


Expand All @@ -399,7 +399,7 @@ Some standouts:
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
Expand Down
34 changes: 17 additions & 17 deletions build/rocm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,34 @@ This directory contains files and setup instructions to build and test JAX for R

1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).

2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in Python version. The example below builds Python 3.9 container.
2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in appropriate
options. The example below builds Python 3.12 container.

./build/rocm/ci_build.sh --keep_image --py_version==3.9.0 --runtime bash -c "./build/rocm/build_rocm.sh"
```Bash
./build/rocm/ci_build.sh --py_version 3.12
```

3. To launch a JAX-ROCm container: If the build was successful, there should be a docker image with name "jax-rocm:latest" in list of docker images (use "docker images" command to list them).
```
sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax-rocm:latest

```Bash
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v ./:/jax --name rocm_jax jax-rocm:latest /bin/bash
```

***
### JAX ROCm Releases
We strive to push all ROCm related changes to the OpenXLA repository. However, at times some JAX/JAXLIB changes for ROCm may not be present in upstream JAX repo.Therefore, we have ROCm Jax/Jaxlib branches that are associated with a Jaxlib release. These
are available in ROCm fork of JAX https://github.com/ROCmSoftwarePlatform/jax. See branches named as rocm-jaxlib-[jaxlib-version]. For examples, for jaxlib-v0.4.10, the branch is named rocm-jaxlib-v0.4.10. See path https://github.com/ROCmSoftwarePlatform/jax/tree/rocm-jaxlib-v0.4.10
We aim to push all ROCm-related changes to the OpenXLA repository. However, there may be times when certain JAX/jaxlib updates for
ROCm are not yet reflected in the upstream JAX repository. To address this, we maintain ROCm-specific JAX/jaxlib branches tied to JAX
releases. These branches are available in the ROCm fork of JAX at https://github.com/ROCm/jax. Look for branches named in the format
rocm-jaxlib-[jaxlib-version]. You can also find corresponding branches in https://github.com/ROCm/xla. For example, for JAX version
0.4.33, the branch is named rocm-jaxlib-v0.4.33, which can be accessed at https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.33.

JAX and Jaxlib wheels for ROCm are available here
```
https://github.com/ROCmSoftwarePlatform/jax/releases
JAX source-code and related wheels for ROCm are available here

```Bash
https://github.com/ROCm/jax/releases
```

***Note:*** Some earlier jaxlib versions on ROCm were released on ***PyPi***.
```
https://pypi.org/project/jaxlib-rocm/#history
```
However, due to strict naming PyPI requirement we had to name our wheels slightly differently. This would then result in Jax/Jaxlib dependent not recognizing jaxlib-rocm wheels and would end up with multiple jaxlib installations and also runtime issues


***
### XLA for JAX ROCm
We strive to push all ROCm related changes to the OpenXLA repository. However, at times some XLA changes for ROCm may not be upstreamed to XLA repo.Therefore, we have ROCm XLA branches that are associated with a Jaxlib release. These are available in ROCm fork of XLA here https://github.com/ROCmSoftwarePlatform/xla. See branches named as rocm-jaxlib-[jaxlib version]. For example, for jaxlib-v0.4.10, the branch is named rocm-jaxlib-v0.4.10. See path https://github.com/ROCmSoftwarePlatform/xla/tree/rocm-jaxlib-v0.4.10


12 changes: 8 additions & 4 deletions docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,32 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
```
The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`,
and selecting the appropriate options.
To build jaxlib with ROCM support, you can run the following build command,
suitably adjusted for your paths and ROCM version.
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0
python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=/opt/rocm-6.2.3
```
AMD's fork of the XLA repository may include fixes not present in the upstream
XLA repository. If you experience problems with the upstream repository, you can
try AMD's fork, by cloning their repository:
```
git clone https://github.com/ROCmSoftwarePlatform/xla.git
git clone https://github.com/ROCm/xla.git
```
and override the XLA repository with which JAX is built:
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \
--bazel_options=--override_repository=xla=/path/to/xla-rocm
python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --bazel_options=--override_repository=xla=/rel/xla/ --rocm_path=/opt/rocm-6.2.3
```
For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`.
## Managing hermetic Python
To make sure that JAX's build is reproducible, behaves uniformly across
Expand Down

0 comments on commit 3bdc57d

Please sign in to comment.