Merge pull request #24124 from hawkinsp:shims

PiperOrigin-RevId: 682363064

Google-ML-Automation committed Oct 4, 2024
2 parents 7b5842c + d3f63a6, commit d48d96c
Showing 10 changed files with 60 additions and 291 deletions.
12 changes: 3 additions & 9 deletions jax/_src/interpreters/pxla.py
@@ -62,7 +62,6 @@
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@@ -3055,14 +3054,9 @@ def aot_cache_miss(*args, **kwargs):
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry

-if xla_extension_version >= 286:
-return xc._xla.pjit(
-self.unsafe_call.name, None, aot_cache_miss, [], [],
-JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
-else:
-return xc._xla.pjit(
-self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-tree_util.dispatch_registry, cc_shard_arg)
+return xc._xla.pjit(
+self.unsafe_call.name, None, aot_cache_miss, [], [],
+JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)

def cc_shard_arg(x, sharding, layout):
return shard_args([sharding], [layout], [x])[0]
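For readers unfamiliar with the version-shim pattern this commit removes across files, here is a minimal self-contained sketch of the idea; the helper names and the standalone XLA_EXTENSION_VERSION constant below are hypothetical illustrations, not JAX or jaxlib APIs:

# A toy sketch of a jaxlib version shim like the one deleted in this hunk:
# `new_pjit_api` and `old_pjit_api` stand in for the two xc._xla.pjit
# signatures, and 286 mirrors the xla_extension_version cutoff in the diff.
XLA_EXTENSION_VERSION = 286  # assumed: version exposed by the installed runtime

def new_pjit_api(name, static_argnums, static_argnames, cache_keys):
  return f"pjit[{name}] with cache keys {sorted(cache_keys)}"

def old_pjit_api(name, static_argnums, static_argnames, donate_argnums):
  return f"pjit[{name}] donating {list(donate_argnums)}"

def build_pjit(name):
  # Before the cleanup: branch on the runtime version so one JAX source tree
  # works against both old and new jaxlib releases.
  if XLA_EXTENSION_VERSION >= 286:
    return new_pjit_api(name, (), (), cache_keys={"donate_argnums", "device"})
  return old_pjit_api(name, (), (), donate_argnums=(0,))

# Once the minimum supported jaxlib passes the cutoff, the guard and the old
# branch can be deleted, leaving only the unconditional new-signature call.
print(build_pjit("f"))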
11 changes: 3 additions & 8 deletions jax/_src/lax/linalg.py
@@ -47,7 +47,6 @@
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
-from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@@ -709,8 +708,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-ctx_args = (ctx,)
-w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
+w, vl, vr, info = lapack.geev_hlo(ctx, operand_aval.dtype, operand,
input_shape_vals=op_shape_vals,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
@@ -2033,8 +2031,7 @@ def _svd_cpu_gpu_lowering(
compute_uv=compute_uv)
else:
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-ctx_args = (ctx,)
-s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
+s, u, vt, info = gesvd_impl(ctx, operand_aval.dtype, operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
a_shape_vals=a_shape_vals)
@@ -2477,9 +2474,7 @@ def _hessenberg_batching_rule(batched_args, batch_dims):
def _hessenberg_cpu_hlo(ctx, a):
a_aval, = ctx.avals_in
batch_dims = a_aval.shape[:-2]
-# TODO(b/344892332): Remove the conditional after the compatibility period.
-ctx_args = (ctx,) if jaxlib_version >= (0, 4, 34) else ()
-a, taus, info = lapack.gehrd_hlo(*ctx_args, a_aval.dtype, a)
+a, taus, info = lapack.gehrd_hlo(ctx, a_aval.dtype, a)
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"EQ", "SIGNED")
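The deleted lines above used conditional argument packing to stay compatible with jaxlib releases whose LAPACK helpers did not yet accept a ctx argument. A minimal sketch of that trick follows; gehrd here is a hypothetical stand-in for lapack.gehrd_hlo, and the version tuple is assumed for illustration:

# Old call site: pass `ctx` only when the installed jaxlib is new enough,
# by splatting either a one-element or an empty tuple into the call.
JAXLIB_VERSION = (0, 4, 34)  # assumed installed version

def gehrd(ctx, dtype, a):
  return f"gehrd(ctx={ctx!r}, dtype={dtype!r}, a={a!r})"

ctx = "lowering-rule-context"
ctx_args = (ctx,) if JAXLIB_VERSION >= (0, 4, 34) else ()
print(gehrd(*ctx_args, "float32", "A"))

# After the compatibility period the indirection disappears and `ctx` is
# simply passed positionally, as the replacement lines in this file do.
print(gehrd(ctx, "float32", "A"))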
107 changes: 35 additions & 72 deletions jax/_src/pjit.py
@@ -62,7 +62,6 @@
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
@@ -322,28 +321,11 @@ def _cpp_pjit_evict_fn(self):
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)


-if xla_extension_version < 286:
-def _get_cpp_global_cache(pjit_has_explicit_sharding):
-if pjit_has_explicit_sharding:
-return xc._xla.PjitFunctionCache()
-else:
-return _cpp_pjit_cache_fun_only
-
-def _pjit_explicit_sharding_and_layout(
-in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
-device, backend) -> bool:
-return (device is not None or
-backend is not None or
-any(not is_unspecified(i) for i in in_shardings_flat) or
-any(not is_unspecified(o) for o in out_shardings_flat) or
-any(i is not None for i in in_layouts_flat) or
-any(o is not None for o in out_layouts_flat))
-else:
-def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore
-if contains_explicit_attributes:
-return _cpp_pjit_cache_explicit_attributes
-else:
-return _cpp_pjit_cache_fun_only
+def _get_cpp_global_cache(contains_explicit_attributes: bool):
+if contains_explicit_attributes:
+return _cpp_pjit_cache_explicit_attributes
+else:
+return _cpp_pjit_cache_fun_only


def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
@@ -364,35 +346,24 @@ def cache_miss(*args, **kwargs):

return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

-if xla_extension_version >= 286:
-cache_key = pxla.JitGlobalCppCacheKeys(
-donate_argnums=jit_info.donate_argnums,
-donate_argnames=jit_info.donate_argnames,
-device=jit_info.device, backend=jit_info.backend,
-in_shardings_treedef=jit_info.in_shardings_treedef,
-in_shardings_leaves=jit_info.in_shardings_leaves,
-out_shardings_treedef=jit_info.out_shardings_treedef,
-out_shardings_leaves=jit_info.out_shardings_leaves,
-in_layouts_treedef=jit_info.in_layouts_treedef,
-in_layouts_leaves=jit_info.in_layouts_leaves,
-out_layouts_treedef=jit_info.out_layouts_treedef,
-out_layouts_leaves=jit_info.out_layouts_leaves,
-use_resource_env=jit_info.use_resource_env)
-cpp_pjit_f = xc._xla.pjit(
-fun_name(fun), fun, cache_miss, jit_info.static_argnums,
-jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
-pxla.cc_shard_arg,
-_get_cpp_global_cache(cache_key.contains_explicit_attributes))
-else:
-has_explicit_sharding = _pjit_explicit_sharding_and_layout(
-jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
-jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
-jit_info.device, jit_info.backend)
-cpp_pjit_f = xc._xla.pjit(
-fun_name(fun), fun, cache_miss, jit_info.static_argnums,
-jit_info.static_argnames, jit_info.donate_argnums,
-tree_util.dispatch_registry, pxla.cc_shard_arg,
-_get_cpp_global_cache(has_explicit_sharding))
+cache_key = pxla.JitGlobalCppCacheKeys(
+donate_argnums=jit_info.donate_argnums,
+donate_argnames=jit_info.donate_argnames,
+device=jit_info.device, backend=jit_info.backend,
+in_shardings_treedef=jit_info.in_shardings_treedef,
+in_shardings_leaves=jit_info.in_shardings_leaves,
+out_shardings_treedef=jit_info.out_shardings_treedef,
+out_shardings_leaves=jit_info.out_shardings_leaves,
+in_layouts_treedef=jit_info.in_layouts_treedef,
+in_layouts_leaves=jit_info.in_layouts_leaves,
+out_layouts_treedef=jit_info.out_layouts_treedef,
+out_layouts_leaves=jit_info.out_layouts_leaves,
+use_resource_env=jit_info.use_resource_env)
+cpp_pjit_f = xc._xla.pjit(
+fun_name(fun), fun, cache_miss, jit_info.static_argnums,
+jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
+pxla.cc_shard_arg,
+_get_cpp_global_cache(cache_key.contains_explicit_attributes))

cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
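The two context lines just above wrap the C++-backed callable with functools.wraps and keep the original Python function reachable on a private attribute. A small self-contained illustration of that pattern follows; the names are generic stand-ins, not JAX's:

# functools.wraps copies the user function's metadata onto the wrapper, and
# the original is stashed on `_fun`, mirroring the lines above.
import functools

def user_fn(x):
  """Docstring preserved onto the wrapper."""
  return x + 1

def make_fast_callable(fun):
  def fast_path(*args, **kwargs):   # stand-in for the xc._xla.pjit object
    return fun(*args, **kwargs)
  wrapped = functools.wraps(fun)(fast_path)
  wrapped._fun = fun                # keep the original around
  return wrapped

f = make_fast_callable(user_fn)
assert f.__name__ == "user_fn" and f.__doc__ == user_fn.__doc__
assert f._fun is user_fn and f(1) == 2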
@@ -1752,26 +1723,18 @@ def call_impl_cache_miss(*args_, **kwargs_):
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
-if xla_extension_version >= 286:
-cache_key = pxla.JitGlobalCppCacheKeys(
-donate_argnums=donated_argnums, donate_argnames=None,
-device=None, backend=None,
-in_shardings_treedef=None, in_shardings_leaves=in_shardings,
-out_shardings_treedef=None, out_shardings_leaves=out_shardings,
-in_layouts_treedef=None, in_layouts_leaves=in_layouts,
-out_layouts_treedef=None, out_layouts_leaves=out_layouts,
-use_resource_env=resource_env is not None)
-return xc._xla.pjit(
-name, f, call_impl_cache_miss, [], [], cache_key,
-tree_util.dispatch_registry, pxla.cc_shard_arg,
-_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
-else:
-has_explicit_sharding = _pjit_explicit_sharding_and_layout(
-in_shardings, out_shardings, in_layouts, out_layouts, None, None)
-return xc._xla.pjit(
-name, f, call_impl_cache_miss, [], [], donated_argnums,
-tree_util.dispatch_registry, pxla.cc_shard_arg,
-_get_cpp_global_cache(has_explicit_sharding))(*args)
+cache_key = pxla.JitGlobalCppCacheKeys(
+donate_argnums=donated_argnums, donate_argnames=None,
+device=None, backend=None,
+in_shardings_treedef=None, in_shardings_leaves=in_shardings,
+out_shardings_treedef=None, out_shardings_leaves=out_shardings,
+in_layouts_treedef=None, in_layouts_leaves=in_layouts,
+out_layouts_treedef=None, out_layouts_leaves=out_layouts,
+use_resource_env=resource_env is not None)
+return xc._xla.pjit(
+name, f, call_impl_cache_miss, [], [], cache_key,
+tree_util.dispatch_registry, pxla.cc_shard_arg,
+_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)

pjit_p.def_impl(_pjit_call_impl)

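After this change every pjit caller builds a JitGlobalCppCacheKeys object and routes cache selection through its contains_explicit_attributes flag. The toy sketch below mimics that routing; CacheKeysSketch and the dict "caches" are simplified stand-ins, not the real pxla or xc classes:

# Functions carrying explicit attributes (donation, device/backend, shardings)
# go to one global cache, plain functions to another.
from dataclasses import dataclass

@dataclass(frozen=True)
class CacheKeysSketch:
  donate_argnums: tuple = ()
  device: object = None
  backend: object = None
  in_shardings_leaves: tuple = ()
  out_shardings_leaves: tuple = ()

  @property
  def contains_explicit_attributes(self) -> bool:
    return bool(self.donate_argnums or self.device or self.backend or
                self.in_shardings_leaves or self.out_shardings_leaves)

_cache_fun_only = {}
_cache_explicit_attributes = {}

def get_cpp_global_cache_sketch(contains_explicit_attributes: bool):
  return (_cache_explicit_attributes if contains_explicit_attributes
          else _cache_fun_only)

plain = CacheKeysSketch()
donating = CacheKeysSketch(donate_argnums=(0,))
assert get_cpp_global_cache_sketch(plain.contains_explicit_attributes) is _cache_fun_only
assert get_cpp_global_cache_sketch(donating.contains_explicit_attributes) is _cache_explicit_attributes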
144 changes: 9 additions & 135 deletions jax/experimental/host_callback.py
@@ -536,7 +536,6 @@ def power3_with_cotangents(x):
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo

@@ -1079,117 +1078,6 @@ def _outside_call_impl(*args, **params):
outside_call_p.def_impl(_outside_call_impl)


def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()

def _outside_call_translation_rule(ctx,
avals_in,
avals_out,
*args_op: XlaOp,
has_token,
identity,
device_index,
flat_results_aval=(),
**params):
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
use_outfeed = _use_outfeed(ctx.platform)
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
current_token = args_op[-2]
current_itoken = args_op[-1]
comp = ctx.builder
assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
"The last two arguments must be tokens")

args_to_outfeed = args_op[:-2]
# Some platforms refuse to infeed empty arrays. We generate constants
# instead.
non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
flat_results_aval))
need_callback_results_on_device = (not identity and
len(non_empty_flat_results_aval) > 0)
send_infeed = use_outfeed and need_callback_results_on_device
generated_infeed = False # Keep track if we emitted an infeed op

_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
callback_id = _register_callback(
functools.partial(
_outside_call_run_callback,
send_infeed=send_infeed,
identity=identity,
flat_results_aval=flat_results_aval,
**params))
next_token = _callback_handler_data.receiver.add_outfeed(
comp, current_token, callback_id, args_to_outfeed, device_index)
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
else:
empty_results = [
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
for aval in flat_results_aval
if _aval_is_empty(aval)
]
if non_empty_flat_results_aval:
assert need_callback_results_on_device
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
# We shard the infeed as AssignedDevice(device_index). This must match the
# outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
# this kind of sharding, we use a custom translation for infeed.
array_sharding_proto = xla_client.OpSharding()
array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
array_sharding_proto.tile_assignment_dimensions = [1]
array_sharding_proto.tile_assignment_devices = [device_index]

token_sharding_proto = xla_client.OpSharding()
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
infeed_sharding_proto = xla.tuple_sharding_proto(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])

shape = [
shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x)
]

build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)
non_empty_results = [
xops.GetTupleElement(outs, i)
for i in range(len(non_empty_flat_results_aval))
]
generated_infeed = True
results = [
empty_results.pop(0)
if _aval_is_empty(result_aval) else non_empty_results.pop(0)
for result_aval in flat_results_aval
]
else:
results = empty_results
next_itoken = current_itoken

assert generated_infeed == send_infeed, (
f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
assert identity or len(results) == len(flat_results_aval), (
f"got {len(results)} but expected {len(flat_results_aval)}. "
f"identity = {identity}")
return results + [next_token, next_itoken]

if xla_extension_version < 287:
xla.register_translation(outside_call_p, _outside_call_translation_rule)


def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext,
*args_op,
identity,
@@ -1318,25 +1206,14 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
platform = ctx.module_context.platforms[0]
use_outfeed = _use_outfeed(platform)
if use_outfeed:
-if xla_extension_version < 287:
-return mlir.xla_fallback_lowering(outside_call_p)(
-ctx,
-*args,
-has_token=has_token,
-identity=identity,
-device_index=device_index,
-flat_results_aval=flat_results_aval,
-**params,
-)
-else:
-return _outside_call_outfeed_lowering(
-ctx, *args,
-has_token=has_token,
-identity=identity,
-flat_results_aval=flat_results_aval,
-device_index=device_index,
-**params,
-)
+return _outside_call_outfeed_lowering(
+ctx, *args,
+has_token=has_token,
+identity=identity,
+flat_results_aval=flat_results_aval,
+device_index=device_index,
+**params,
+)
else:
# TODO(necula): It seems that on CPU, with custom call, the device_index
# does not work, and the callback is always run on device_index=0
@@ -1405,10 +1282,7 @@ def wrapped_callback(*args):
f"identity = {identity}")
return list(results) + [next_token, next_itoken]

-if xla_extension_version < 287:
-mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
-else:
-mlir.register_lowering(outside_call_p, _outside_call_lowering)
+mlir.register_lowering(outside_call_p, _outside_call_lowering)

def _outside_call_run_callback(
arrays, device, *,
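Among the code deleted from this file is the _with_sharding_proto helper from the old translation rule. Its set/try/finally/clear shape is a generally useful builder pattern, sketched below with a fake builder that is only an illustrative stand-in, not xla_client's:

# Wrap an op-builder call in set_sharding/clear_sharding so the annotation
# applies to exactly one op and cannot leak past it, even on error.
class FakeBuilder:
  def __init__(self):
    self.sharding = None
    self.ops = []

  def set_sharding(self, proto):
    self.sharding = proto

  def clear_sharding(self):
    self.sharding = None

  def infeed(self, shape):
    # Record the op together with whatever sharding is currently set.
    self.ops.append((shape, self.sharding))
    return len(self.ops) - 1

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
  builder.set_sharding(sharding_proto)
  try:
    return op_fn(*args, **kwargs)
  finally:
    builder.clear_sharding()

b = FakeBuilder()
with_sharding_proto(b, {"type": "MAXIMAL", "devices": [3]}, b.infeed, "token")
assert b.ops[0][1] == {"type": "MAXIMAL", "devices": [3]}
assert b.sharding is None  # the annotation did not leak past the wrapped op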
