Skip to content

Commit

Permalink
Merge pull request #21916 from ROCm:ci_pjrt
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646793145
  • Loading branch information
jax authors committed Jun 26, 2024
2 parents ca70ebb + 385283c commit 96cf5d5
Show file tree
Hide file tree
Showing 24 changed files with 732 additions and 73 deletions.
3 changes: 3 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ build:win_clang --compiler=clang-cl
build:cuda_plugin --@xla//xla/python:enable_gpu=false
build:cuda_plugin --define=xla_python_enable_gpu=false

build:rocm_plugin --@xla//xla/python:enable_gpu=false
build:rocm_plugin --define=xla_python_enable_gpu=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand Down
65 changes: 44 additions & 21 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,10 @@ def write_bazelrc(*, remote_build,
if not enable_nccl:
f.write("build --config=nonccl\n")
if build_gpu_plugin:
f.write("build --config=cuda_plugin\n")
if enable_cuda:
f.write("build --config=cuda_plugin\n")
elif enable_rocm:
f.write("build --config=rocm_plugin\n")
if python_version:
f.write(
"build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format(
Expand Down Expand Up @@ -431,21 +434,21 @@ def main():
"plugin is still experimental and is not ready for use yet."
),
)
add_boolean_argument(
parser,
"build_cuda_kernel_plugin",
default=False,
help_str=(
"Are we building the cuda kernel plugin? jaxlib will not be built "
"when this flag is True."
parser.add_argument(
"--build_gpu_kernel_plugin",
choices=["cuda", "rocm"],
default="",
help=(
"Specify 'cuda' or 'rocm' to build the respective kernel plugin."
" When this flag is set, jaxlib will not be built."
),
)
add_boolean_argument(
parser,
"build_cuda_pjrt_plugin",
"build_gpu_pjrt_plugin",
default=False,
help_str=(
"Are we building the cuda pjrt plugin? jaxlib will not be built "
"Are we building the cuda/rocm pjrt plugin? jaxlib will not be built "
"when this flag is True."
),
)
Expand All @@ -454,6 +457,11 @@ def main():
choices=["11", "12"],
default="12",
help="Which CUDA major version the gpu plugin is for.")
parser.add_argument(
"--gpu_plugin_rocm_version",
choices=["60"],
default="60",
help="Which ROCM major version the gpu plugin is for.")
add_boolean_argument(
parser,
"enable_rocm",
Expand Down Expand Up @@ -675,8 +683,8 @@ def main():
"--verbose_failures=true",
*args.bazel_options,
)

if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin:
if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin:
build_cpu_wheel_command = [
*command_base,
"//jaxlib/tools:build_wheel", "--",
Expand All @@ -691,29 +699,44 @@ def main():
print(" ".join(build_cpu_wheel_command))
shell(build_cpu_wheel_command)

if args.build_gpu_plugin or args.build_cuda_kernel_plugin:
build_cuda_kernels_command = [
if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \
(args.build_gpu_kernel_plugin == "rocm"):
build_gpu_kernels_command = [
*command_base,
"//jaxlib/tools:build_cuda_kernels_wheel", "--",
"//jaxlib/tools:build_gpu_kernels_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"
]
if args.enable_cuda:
build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}")
build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}")
elif args.enable_rocm:
build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}")
build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}")
else:
raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.")
if args.editable:
build_cuda_kernels_command.append("--editable")
print(" ".join(build_cuda_kernels_command))
shell(build_cuda_kernels_command)
build_gpu_kernels_command.append("--editable")
print(" ".join(build_gpu_kernels_command))
shell(build_gpu_kernels_command)

if args.build_gpu_plugin or args.build_cuda_pjrt_plugin:
if args.build_gpu_plugin or args.build_gpu_pjrt_plugin:
build_pjrt_plugin_command = [
*command_base,
"//jaxlib/tools:build_gpu_plugin_wheel", "--",
f"--output_path={output_path}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"
]
if args.enable_cuda:
build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}")
build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}")
elif args.enable_rocm:
build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}")
build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}")
else:
raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.")
if args.editable:
build_pjrt_plugin_command.append("--editable")
print(" ".join(build_pjrt_plugin_command))
Expand Down
2 changes: 1 addition & 1 deletion build/rocm/build_rocm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1)
export JAX_ROCM_VERSION=${rocm_version//./}

#Build and install wheel
python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR}
python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR}

JAX_RELEASE=1 python -m build
pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def supported_dtypes():
return types

def is_device_rocm():
  """Return True when the active backend reports a ROCm platform version.

  Uses a substring check (not a prefix check) so that it matches regardless
  of where 'rocm' appears in the platform version string, mirroring
  `is_device_cuda`.
  """
  return 'rocm' in xla_bridge.get_backend().platform_version

def is_device_cuda():
  """Return True when the active backend's platform version mentions 'cuda'."""
  platform_version = xla_bridge.get_backend().platform_version
  return 'cuda' in platform_version
Expand Down
10 changes: 10 additions & 0 deletions jax/tools/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,13 @@ def update_setup_with_cuda_version(file_dir: pathlib.Path, cuda_version: str):
)
with open(src_file, "w") as f:
f.write(content)

def update_setup_with_rocm_version(file_dir: pathlib.Path, rocm_version: str):
  """Stamp the ROCm version into the ``setup.py`` located in ``file_dir``.

  Every occurrence of the ``rocm_version = 0 # placeholder`` marker is
  replaced with ``rocm_version = <rocm_version>``. If the marker is absent
  the file is rewritten with its content unchanged.

  Args:
    file_dir: Directory that contains the ``setup.py`` to rewrite.
    rocm_version: Version text inserted verbatim after ``rocm_version = ``.
  """
  src_file = file_dir / "setup.py"
  # Read and write explicitly as UTF-8 so the result does not depend on the
  # locale of the machine running the build.
  content = src_file.read_text(encoding="utf-8")
  content = content.replace(
      "rocm_version = 0 # placeholder", f"rocm_version = {rocm_version}"
  )
  src_file.write_text(content, encoding="utf-8")
5 changes: 4 additions & 1 deletion jax_plugins/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ licenses(["notice"])
load(
"//jaxlib:jax.bzl",
"if_cuda_is_configured",
"if_rocm_is_configured",
"py_library_providing_imports_info",
)

Expand All @@ -30,5 +31,7 @@ py_library(
":jax_plugins",
] + if_cuda_is_configured([
"//jax_plugins/cuda:cuda_plugin",
]) + if_rocm_is_configured([
"//jax_plugins/rocm:rocm_plugin",
]),
)
)
55 changes: 55 additions & 0 deletions jax_plugins/rocm/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"if_windows",
"py_library_providing_imports_info",
"pytype_library",
)

package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
)

exports_files([
"__init__.py",
"plugin_pyproject.toml",
"plugin_setup.py",
"pyproject.toml",
"setup.py",
])

# Symlinks XLA's PJRT C API GPU plugin binary into this package directory
# (dst = ".", flattened) so the Python plugin can find it next to __init__.py.
symlink_files(
    name = "pjrt_c_api_gpu_plugin",
    srcs = if_windows(
        # The package and target must be separated by ":"; the previous
        # "c/pjrt_c_api_gpu_plugin.pyd" form named a non-existent package.
        ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.pyd"],
        ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"],
    ),
    dst = ".",
    flatten = True,
)

# Python library target for the ROCm plugin package; the symlinked PJRT plugin
# binary is attached as runtime data.
py_library_providing_imports_info(
    name = "rocm_plugin",
    srcs = [
        "__init__.py",
    ],
    data = [":pjrt_c_api_gpu_plugin"],
    lib_rule = pytype_library,
)
91 changes: 91 additions & 0 deletions jax_plugins/rocm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import importlib
import logging
import pathlib
import platform

from jax._src.lib import xla_client
import jax._src.xla_bridge as xb

# rocm_plugin_extension is normally shipped in the ROCm plugin wheel. The
# `jaxlib` entry is a fallback for testing without the ROCm plugin packages
# preinstalled. The first package that imports successfully wins; if none
# does, rocm_plugin_extension stays None.
rocm_plugin_extension = None
for pkg_name in ['jax_rocm60_plugin', 'jaxlib']:
  try:
    rocm_plugin_extension = importlib.import_module(
        f'{pkg_name}.rocm_plugin_extension'
    )
    break
  except ImportError:
    continue

logger = logging.getLogger(__name__)


def _get_library_path():
  """Locate the ROCm PJRT plugin shared library next to this module.

  Returns:
    The path of `xla_rocm_plugin.so` when it exists; otherwise the path of the
    local test library `pjrt_c_api_gpu_plugin.so` when that exists; otherwise
    None.
  """
  plugin_dir = pathlib.Path(__file__).resolve().parent
  installed_path = plugin_dir / 'xla_rocm_plugin.so'
  local_path = plugin_dir / 'pjrt_c_api_gpu_plugin.so'

  # Preferred location: the library shipped with the installed plugin.
  if installed_path.exists():
    return installed_path

  # Neither candidate is present: log it and let the caller skip registration.
  if not local_path.exists():
    logger.debug(
        'WARNING: Native library %s and local test library path %s do not'
        ' exist. This most likely indicates an issue with how %s was built or'
        ' installed or missing src files.',
        installed_path,
        local_path,
        __package__,
    )
    return None

  # Only the local test library is present: use it as a fallback.
  logger.debug(
      'Native library %s does not exist. This most likely indicates an issue'
      ' with how %s was built or installed. Fallback to local test'
      ' library %s',
      installed_path,
      __package__,
      local_path,
  )
  return local_path


def initialize():
  """Register the ROCm PJRT plugin and, if available, its custom-call hooks.

  Does nothing when no plugin library can be located. When the
  `rocm_plugin_extension` module was not importable, the plugin is still
  registered but a warning is logged and no custom calls are registered.
  """
  library_path = _get_library_path()
  if library_path is None:
    return

  plugin_options = xla_client.generate_pjrt_gpu_plugin_options()
  plugin_options["platform_name"] = "ROCM"
  c_api = xb.register_plugin(
      'rocm',
      priority=500,
      library_path=str(library_path),
      options=plugin_options,
  )

  if not rocm_plugin_extension:
    logger.warning('rocm_plugin_extension is not found.')
    return

  # Bind the plugin's C API so the handler only needs the per-target arguments.
  handler = functools.partial(
      rocm_plugin_extension.register_custom_call_target, c_api
  )
  xla_client.register_custom_call_handler("ROCM", handler)
  for target_name, target in rocm_plugin_extension.registrations().items():
    xla_client.register_custom_call_target(
        target_name, target, platform="ROCM"
    )
3 changes: 3 additions & 0 deletions jax_plugins/rocm/plugin_pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
Loading

0 comments on commit 96cf5d5

Please sign in to comment.