Skip to content

Commit

Permalink
Remove the spmd_mode check from OSS JAX since enhanced barrier is swi…
Browse files Browse the repository at this point in the history
…tched on for OSS JAX

PiperOrigin-RevId: 625763988
  • Loading branch information
yashk2810 authored and jax authors committed Apr 17, 2024
1 parent 1c8534e commit 7cb0e60
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2037,6 +2037,12 @@ def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding,
memory_kind=s.memory_kind)


# Dummy function which is a no-op in OSS since enhanced barrier is switched on
# in OSS.
def spmd_mode_check(da_object, inline):
return


@profiler.annotate_function
def lower_sharding_computation(
closed_jaxpr: core.ClosedJaxpr,
Expand Down Expand Up @@ -2127,19 +2133,7 @@ def lower_sharding_computation(
da_object,
it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding])) # type: ignore

if not da_object.is_fully_addressable: # type: ignore
if inline and config.spmd_mode.value != 'allow_all':
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. It’s very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs. "
"If you’re not already familiar with JAX’s multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html. "
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
spmd_mode_check(da_object, inline)

# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(
Expand Down

0 comments on commit 7cb0e60

Please sign in to comment.