diff --git a/build_tools/jax.py b/build_tools/jax.py index 1f9552eb69..b03c7e59d3 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -36,15 +36,17 @@ def xla_path() -> str: from jax.extend import ffi # pylint: disable=ungrouped-imports except ImportError: + print("Could not import jax. Looking for XLA source in $XLA_HOME or /opt/xla") if os.getenv("XLA_HOME"): - xla_home = Path(os.getenv("XLA_HOME")) + xla_home = Path(os.getenv("XLA_HOME").strip()) else: xla_home = "/opt/xla" else: xla_home = ffi.include_dir() + print(f"Found XLA source in {xla_home}") if not os.path.isdir(xla_home): - raise FileNotFoundError("Could not find xla source.") + raise FileNotFoundError(f"Could not find xla source. Searched: {xla_home}") return xla_home