diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2667176d24d1..85dd03ee3407 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2037,6 +2037,12 @@ def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding, memory_kind=s.memory_kind) +# Dummy function which is a no-op in OSS since enhanced barrier is switched on +# in OSS. +def spmd_mode_check(da_object, inline): + return + + @profiler.annotate_function def lower_sharding_computation( closed_jaxpr: core.ClosedJaxpr, @@ -2127,19 +2133,7 @@ def lower_sharding_computation( da_object, it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding])) # type: ignore - if not da_object.is_fully_addressable: # type: ignore - if inline and config.spmd_mode.value != 'allow_all': - raise RuntimeError( - "Running operations on `Array`s that are not fully addressable by this " - "process (i.e. `Array`s with data sharded across multiple devices and " - "processes.) is dangerous. It’s very important that all processes run " - "the same cross-process computations in the same order otherwise it " - "can lead to hangs. " - "If you’re not already familiar with JAX’s multi-process " - "programming model, please read " - "https://jax.readthedocs.io/en/latest/multi_process.html. " - "To fix this error, run your `jitted` computation inside " - "`with jax.spmd_mode('allow_all'):` context manager.") + spmd_mode_check(da_object, inline) # 2. Build up the HLO semantic_in_shardings = SemanticallyEqualShardings(