# Use bazel for PR tests #115
# Workflow file for this run

# Display name of this workflow.
name: ROCm GPU CI

# NOTE: generic YAML 1.1 parsers read the bare key `on` as a boolean; GitHub's
# workflow loader treats it as the trigger map, so it is left unquoted.
on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main branch
  push:
    branches:
      - rocm-main
  pull_request:
    branches:
      - rocm-main

# Runs are grouped by workflow name plus the PR head ref (or the pushed ref
# when there is no head ref); a newer run in the same group cancels the one
# still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Single job: clean the runner, check out the sources, build the wheels and
  # a test image through build/rocm/ci_build, archive the wheels, then run the
  # bazel test targets through ci_build.
  build-jax-in-docker:  # strategy and matrix come here
    runs-on: mi-250
    env:
      BASE_IMAGE: "ubuntu:22.04"
      # Image tag and work directory are made unique per run and per attempt.
      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
      # Quoted so the versions stay strings ("3.10" must not become 3.1).
      PYTHON_VERSION: "3.10"
      ROCM_VERSION: "6.2.4"
      WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
    steps:
      - name: Clean up old runs
        run: |
          ls
          # Make sure that we own all of the files so that we have permissions to delete them.
          # --rm keeps this helper container from piling up on the runner, and the
          # mount source is given as an absolute path instead of "./".
          # $UID is expanded by the runner's shell, the glob by the shell in the container.
          docker run --rm -v "$PWD:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
          # Remove any old work directories from this machine
          rm -rf workdir_*
          ls

      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi

      # Check the repository out into the per-run work directory.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          path: ${{ env.WORKSPACE_DIR }}

      - name: Build JAX
        run: |
          pushd "$WORKSPACE_DIR"
          python3 build/rocm/ci_build \
            --rocm-version "$ROCM_VERSION" \
            --base-docker "$BASE_IMAGE" \
            --python-versions "$PYTHON_VERSION" \
            --compiler=clang \
            dist_docker \
            --image-tag "$TEST_IMAGE"

      - name: Archive jax wheels
        uses: actions/upload-artifact@v4
        with:
          name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
          # The build step runs from inside WORKSPACE_DIR, while this action
          # resolves relative paths against the runner workspace, so the wheels
          # are looked up under the checkout directory.
          # NOTE(review): assumes ci_build writes wheels to ./dist relative to
          # its working directory - confirm.
          path: ${{ env.WORKSPACE_DIR }}/dist/*.whl

      - name: Run tests
        env:
          GPU_COUNT: "8"
          GFX: "gfx90a"
        run: |
          cd "$WORKSPACE_DIR"
          # The command is kept as an array so every flag sits on its own line.
          # "${bazel_cmd[*]}" joins the elements with spaces into one string; the
          # quoted "&&" elements only act as operators once that string is run
          # (presumably by a shell inside the test image - confirm in ci_build).
          bazel_cmd=(
            wget "https://github.com/bazelbuild/bazelisk/releases/download/v1.25.0/bazelisk-linux-amd64"
            "&&"
            chmod +x bazelisk-linux-amd64
            "&&"
            ./bazelisk-linux-amd64 test -k
            --jobs=4
            --test_verbose_timeout_warnings=true
            --test_output=all
            --test_summary=detailed
            --local_test_jobs=1
            --test_env=JAX_ACCELERATOR_COUNT=${GPU_COUNT}
            --test_env=JAX_TESTS_PER_ACCELERATOR=${GPU_COUNT}
            --test_env=JAX_SKIP_SLOW_TESTS=0
            --verbose_failures=true
            --config=rocm
            --action_env=ROCM_PATH=/opt/rocm
            --test_env=CLANG_PATH=/opt/rocm/lib/llvm/bin/clang-18
            --action_env=CLANG_COMPILER_PATH=/opt/rocm/lib/llvm/bin/clang-18
            --repo_env=CC=/opt/rocm/lib/llvm/bin/clang-18
            --repo_env=BAZEL_COMPILER=/opt/rocm/lib/llvm/bin/clang-18
            --action_env=TF_ROCM_AMDGPU_TARGETS=${GFX}
            --test_tag_filters=-multiaccelerator
            --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
            --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow
            //tests:gpu_tests
            //tests:backend_independent_tests
            //tests/pallas:gpu_tests
            //tests/pallas:backend_independent_tests
          )
          # Echo the assembled command into the job log before running it.
          printf '%s\n' "${bazel_cmd[*]}"
          python3 build/rocm/ci_build test "$TEST_IMAGE" --test-cmd "${bazel_cmd[*]}"