Skip to content

Commit

Permalink
Error out in donation if:
Browse files Browse the repository at this point in the history
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)

2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO

This is because there is a conflict in such cases and almost always leads to the wrong layout being chosen by the compiler. For example, let's talk about (1): since input layout is AUTO and output layout is default and since they are aliased, XLA will end up choose default layout for input which is not what you want in majority of the cases.
Erroring is best in such cases and the user can mark the input layout to be default if they want to do that.

The correct choice is to always make both of them AUTO since you want the compiler to choose the best possible layout instead of choosing the input or output layout if the other one is AUTO.

PiperOrigin-RevId: 683688470
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Oct 8, 2024
1 parent 55153cc commit e5fa965
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
30 changes: 20 additions & 10 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,21 +1220,34 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,

xla_donated_args = None
out_donated_args = list(donated_args)
in_out_layout_not_none = in_layouts is not None and out_layouts is not None
for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)):
# Only donate if memory kinds match. Relax this when the compiler can
# donate across memories.
key = (aval, rm)
if donations.get(key, ()):
input_id = donations[key].popleft()
out_donated_args[input_id] = False
# We can alias if XLA performs layout assignment because XLA will
# respect the aliases when assigning layouts. Its only for two
# mismatched explicitly assigned layouts that XLA will certainly fail.
if (in_layouts is None or
out_layouts is None or
in_layouts[input_id] == out_layouts[i] or
isinstance(in_layouts[input_id], AutoLayout) or
if (in_out_layout_not_none and
isinstance(in_layouts[input_id], AutoLayout) and
not isinstance(out_layouts[i], AutoLayout)):
raise ValueError(
f"Input layout being donated was {in_layouts[input_id]} while"
f" output layout was {out_layouts[i]}. Did you mean to set the"
" **output layout** to **DeviceLocalLayout.AUTO**?\nThis will"
" allow for the input and output layout to be chosen by XLA and"
" not the layout of the output which might not be optimal.")
if (in_out_layout_not_none and
not isinstance(in_layouts[input_id], AutoLayout) and
isinstance(out_layouts[i], AutoLayout)):
raise ValueError(
f"Input layout being donated was {in_layouts[input_id]} while"
f" output layout was {out_layouts[i]}. Did you mean to set the"
" **input layout** to **DeviceLocalLayout.AUTO**?\nThis will allow"
" for the input and output layout to be chosen by XLA and not the"
" layout of the input which might not be optimal.")
if (in_layouts is None or out_layouts is None or
in_layouts[input_id] == out_layouts[i]):
input_output_aliases[input_id] = i
else:
# Fallback to xla donation if layouts don't match.
Expand Down Expand Up @@ -1508,7 +1521,6 @@ def lower_jaxpr_to_fun(
aliases.extend([None] * len_ir_types(itypes))
else:
aliases.extend(output_ids[alias])

for attrs, alias in zip(arg_attrs, aliases):
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
Expand Down Expand Up @@ -2595,8 +2607,6 @@ def merge_mlir_modules(dst_module: ir.Module,
return renamings["main"]




DEVICE_TO_DEVICE_TYPE = 1
SEND_TO_HOST_TYPE = 2
RECV_FROM_HOST_TYPE = 3
Expand Down
17 changes: 17 additions & 0 deletions tests/layout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,23 @@ def f(x):
f.lower(sds).compile()(arr)
self.assertFalse(arr.is_deleted())

def test_donation_error_on_auto(self):
@partial(jax.jit, donate_argnums=0, in_shardings=Layout(DLL.AUTO))
def f(x):
return x * 2

with self.assertRaisesRegex(
ValueError, ".*Did you mean to set the.*output layout.*AUTO.*"):
f(jnp.arange(8))

@partial(jax.jit, donate_argnums=0, out_shardings=Layout(DLL.AUTO))
def g(x):
return x * 2

with self.assertRaisesRegex(
ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"):
g(jnp.arange(8))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit e5fa965

Please sign in to comment.