diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 8b5d5ea6c907..c11de92c0627 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -287,27 +287,27 @@ " if x.dtype != jnp.float32:\n", " raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n", "\n", - " # In this case, the output of our FFI function is just a single array with the\n", - " # same shape and dtype as the input. We discuss a case with a more interesting\n", - " # output type below.\n", - " out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - "\n", - " return jex.ffi.ffi_call(\n", + " call = jex.ffi.ffi_call(\n", " # The target name must be the same string as we used to register the target\n", " # above in `register_custom_call_target`\n", " \"rms_norm\",\n", - " out_type,\n", - " x,\n", - " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n", - " # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n", - " # type (which corresponds to numpy's `float32` type), and it must be a\n", - " # static parameter (i.e. not a JAX array).\n", - " eps=np.float32(eps),\n", + "\n", + " # In this case, the output of our FFI function is just a single array with\n", + " # the same shape and dtype as the input. We discuss a case with a more\n", + " # interesting output type below.\n", + " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + "\n", " # The `vmap_method` parameter controls this function's behavior under `vmap`\n", " # as discussed below.\n", " vmap_method=\"broadcast_fullrank\",\n", " )\n", "\n", + " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n", + " # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n", + " # type (which corresponds to numpy's `float32` type), and it must be a\n", + " # static parameter (i.e. 
not a JAX array).\n", + " return call(x, eps=np.float32(eps))\n", + "\n", "\n", "# Test that this gives the same result as our reference implementation\n", "x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n", @@ -403,10 +403,8 @@ " return jex.ffi.ffi_call(\n", " \"rms_norm\",\n", " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", - " x,\n", - " eps=np.float32(eps),\n", " vmap_method=\"sequential\",\n", - " )\n", + " )(x, eps=np.float32(eps))\n", "\n", "\n", "jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)" @@ -462,10 +460,8 @@ " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", " jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n", " ),\n", - " x,\n", - " eps=np.float32(eps),\n", " vmap_method=\"broadcast_fullrank\",\n", - " )\n", + " )(x, eps=np.float32(eps))\n", " return y, (res, x)\n", "\n", "\n", @@ -478,11 +474,8 @@ " jex.ffi.ffi_call(\n", " \"rms_norm_bwd\",\n", " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n", - " res,\n", - " x,\n", - " ct,\n", - " vmap_method=\"broadcast_fullrank\",\n", - " ),\n", + " vmap_method=\"broadcast_fullrank\",\n", + " )(res, x, ct),\n", " )\n", "\n", "\n", @@ -569,10 +562,8 @@ " return lambda x: jex.ffi.ffi_call(\n", " target_name,\n", " out_type,\n", - " x,\n", - " eps=np.float32(eps),\n", " vmap_method=\"broadcast_fullrank\",\n", - " )\n", + " )(x, eps=np.float32(eps))\n", "\n", " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index b3d1dcf46364..2c70f87f0d03 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -248,27 +248,27 @@ def rms_norm(x, eps=1e-5): if x.dtype != jnp.float32: raise ValueError("Only the float32 dtype is implemented by rms_norm") - # In this case, the output of our FFI function is just a single array with the - # same shape and dtype as the input. We discuss a case with a more interesting - # output type below. 
- out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) - - return jex.ffi.ffi_call( + call = jex.ffi.ffi_call( # The target name must be the same string as we used to register the target # above in `register_custom_call_target` "rms_norm", - out_type, - x, - # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for - # the attribute `eps`. Our FFI function expects this to have the C++ `float` - # type (which corresponds to numpy's `float32` type), and it must be a - # static parameter (i.e. not a JAX array). - eps=np.float32(eps), + + # In this case, the output of our FFI function is just a single array with + # the same shape and dtype as the input. We discuss a case with a more + # interesting output type below. + jax.ShapeDtypeStruct(x.shape, x.dtype), + # The `vmap_method` parameter controls this function's behavior under `vmap` # as discussed below. vmap_method="broadcast_fullrank", ) + # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). 
+ return call(x, eps=np.float32(eps)) + # Test that this gives the same result as our reference implementation x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) @@ -334,10 +334,8 @@ def rms_norm_sequential(x, eps=1e-5): return jex.ffi.ffi_call( "rms_norm", jax.ShapeDtypeStruct(x.shape, x.dtype), - x, - eps=np.float32(eps), vmap_method="sequential", - ) + )(x, eps=np.float32(eps)) jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x) @@ -380,10 +378,8 @@ def rms_norm_fwd(x, eps=1e-5): jax.ShapeDtypeStruct(x.shape, x.dtype), jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), ), - x, - eps=np.float32(eps), vmap_method="broadcast_fullrank", - ) + )(x, eps=np.float32(eps)) return y, (res, x) @@ -396,11 +392,8 @@ def rms_norm_bwd(eps, res, ct): jex.ffi.ffi_call( "rms_norm_bwd", jax.ShapeDtypeStruct(ct.shape, ct.dtype), - res, - x, - ct, - vmap_method="broadcast_fullrank", - ), + vmap_method="broadcast_fullrank", + )(res, x, ct), ) @@ -477,10 +470,8 @@ def rms_norm_cross_platform(x, eps=1e-5): return lambda x: jex.ffi.ffi_call( target_name, out_type, - x, - eps=np.float32(eps), vmap_method="broadcast_fullrank", - ) + )(x, eps=np.float32(eps)) return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda")) diff --git a/examples/ffi/src/jax_ffi_example/attrs.py b/examples/ffi/src/jax_ffi_example/attrs.py index 30d7d6c74344..2f215e8e25b1 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.py +++ b/examples/ffi/src/jax_ffi_example/attrs.py @@ -35,13 +35,11 @@ def array_attr(num: int): return jex.ffi.ffi_call( "array_attr", jax.ShapeDtypeStruct((), np.int32), - array=np.arange(num, dtype=np.int32), - ) + )(array=np.arange(num, dtype=np.int32)) def dictionary_attr(**kwargs): return jex.ffi.ffi_call( "dictionary_attr", (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), - **kwargs, - ) + )(**kwargs) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index d063f1cf319c..51913232c212 100644 --- 
a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -49,19 +49,17 @@ def rms_norm(x, eps=1e-5): # same shape and dtype as the input. out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + # Note that here we use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). return jex.ffi.ffi_call( # The target name must be the same string as we used to register the target # above in `register_ffi_target` "rms_norm", out_type, - x, - # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for - # the attribute `eps`. Our FFI function expects this to have the C++ `float` - # type (which corresponds to numpy's `float32` type), and it must be a - # static parameter (i.e. not a JAX array). - eps=np.float32(eps), vmap_method="broadcast_fullrank", - ) + )(x, eps=np.float32(eps)) def rms_norm_fwd(x, eps=1e-5): @@ -71,10 +69,8 @@ def rms_norm_fwd(x, eps=1e-5): jax.ShapeDtypeStruct(x.shape, x.dtype), jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), ), - x, - eps=np.float32(eps), vmap_method="broadcast_fullrank", - ) + )(x, eps=np.float32(eps)) return y, (res, x) @@ -87,11 +83,8 @@ def rms_norm_bwd(eps, res, ct): jex.ffi.ffi_call( "rms_norm_bwd", jax.ShapeDtypeStruct(ct.shape, ct.dtype), - res, - x, - ct, vmap_method="broadcast_fullrank", - ), + )(res, x, ct), ) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 6bbba0cbd88e..579b8ee17440 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -14,11 +14,11 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence import ctypes import functools import os -from typing import Any +from typing import overload, Any import numpy as np @@ -38,6 +38,11 @@ from jax._src.typing import 
(Array, ArrayLike, DeprecatedArg, DuckTypedArray, Shape) +# TODO(dfm): Remove after 6 months or less because there aren't any official +# compatibility guarantees for jax.extend (see JEP 15856) +# Added Oct 13, 2024 +deprecations.register("jax-ffi-call-args") + map, unsafe_map = util.safe_map, map FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None @@ -197,17 +202,43 @@ def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue return tuple(avals) +@overload def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], - *args: ArrayLike, + *, has_side_effect: bool = False, vmap_method: str | None = None, vectorized: bool | DeprecatedArg = DeprecatedArg(), - **kwargs: Any, -) -> Array | list[Array]: +) -> Callable[..., Array | Sequence[Array]]: + ... + +@overload +def ffi_call( + target_name: str, + result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], + *deprecated_args: ArrayLike, + has_side_effect: bool = False, + vmap_method: str | None = None, + vectorized: bool | DeprecatedArg = DeprecatedArg(), + **deprecated_kwargs: Any, +) -> Array | Sequence[Array]: + ... + + +def ffi_call( + target_name: str, + result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], + *deprecated_args: ArrayLike, + has_side_effect: bool = False, + vmap_method: str | None = None, + vectorized: bool | DeprecatedArg = DeprecatedArg(), + **deprecated_kwargs: Any, +) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]: """Call a foreign function interface (FFI) target. + See the :ref:`ffi-tutorial` tutorial for more information. + Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under :func:`~jax.vmap` depends on the value of ``vmap_method``. See the :func:`~jax.pure_callback` documenation for more details about the allowed @@ -226,18 +257,16 @@ def ffi_call( the custom call output or outputs. 
:class:`~jax.ShapeDtypeStruct` is often used to define the elements of ``result_shape_dtypes``. ``jax.core.abstract_token`` may be used to represent a token-typed output. - *args: the arguments passed to the custom call. has_side_effect: boolean specifying whether the custom call has side effects. When ``True``, the FFI call will be executed even when the outputs are not used. vmap_method: string specifying how the FFI call transforms under :func:`~jax.vmap` as described above. - **kwargs: keyword arguments that are passed as named attributes to the - custom call using XLA's FFI interface. Returns: - One or more :class:`~jax.Array` objects whose shapes and dtypes match - ``result_shape_dtypes``. + A function that can be called with the input arrays as positional arguments + to execute the FFI handler. Any keyword arguments are passed as named + attributes to the FFI handler using XLA's FFI interface. """ if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: deprecations.warn( @@ -264,19 +293,35 @@ def ffi_call( else: multiple_results = False result_avals = _result_avals((result_shape_dtypes,)) - results = ffi_call_p.bind( - *args, - result_avals=result_avals, - vectorized=vectorized, - vmap_method=vmap_method, - target_name=target_name, - has_side_effect=has_side_effect, - **_wrap_kwargs_hashable(kwargs), - ) - if multiple_results: - return results + + def wrapped(*args: ArrayLike, **kwargs: Any): + results = ffi_call_p.bind( + *args, + result_avals=result_avals, + vectorized=vectorized, + vmap_method=vmap_method, + target_name=target_name, + has_side_effect=has_side_effect, + **_wrap_kwargs_hashable(kwargs), + ) + if multiple_results: + return results + else: + return results[0] + + if deprecated_args or deprecated_kwargs: + deprecations.warn( + "jax-ffi-call-args", + "Calling ffi_call directly with input arguments is deprecated. 
" + "Instead, ffi_call should be used to construct a callable, which can " + "then be called with the appropriate inputs. For example,\n" + " ffi_call('target_name', output_type, x, argument=5)\n" + "should be replaced with\n" + " ffi_call('target_name', output_type)(x, argument=5)", + stacklevel=2) + return wrapped(*deprecated_args, **deprecated_kwargs) else: - return results[0] + return wrapped # ffi_call must support some small non-hashable input arguments, like np.arrays diff --git a/tests/extend_test.py b/tests/extend_test.py index 805ad937bc02..8ff864ef1ced 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -154,7 +154,7 @@ def lowering_rule(ctx, x): ]) def testParams(self, param, expected_builder): def fun(x): - return jex.ffi.ffi_call("test_ffi", x, x, param=param) + return jex.ffi.ffi_call("test_ffi", x)(x, param=param) # Here we inspect the lowered IR to test that the parameter has been # serialized with the appropriate type. @@ -171,7 +171,7 @@ def fun(x): def testToken(self): def fun(): token = lax.create_token() - return jex.ffi.ffi_call("test_ffi", core.abstract_token, token) + return jex.ffi.ffi_call("test_ffi", core.abstract_token)(token) # Ensure that token inputs and outputs are translated to the correct type module = jax.jit(fun).lower().compiler_ir("stablehlo") @@ -192,7 +192,7 @@ def testEffectsHlo(self): else: raise unittest.SkipTest("Unsupported device") def fun(): - jex.ffi.ffi_call(target_name, (), has_side_effect=True) + jex.ffi.ffi_call(target_name, (), has_side_effect=True)() hlo = jax.jit(fun).lower() self.assertIn(target_name, hlo.as_text()) self.assertIn("has_side_effect = true", hlo.as_text()) @@ -200,14 +200,14 @@ def fun(): def testJvpError(self): def fun(x): - return jex.ffi.ffi_call("test_ffi", x, x, non_hashable_arg={"a": 1}) + return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) with self.assertRaisesRegex( ValueError, "The FFI call to `.+` cannot be differentiated."): jax.jvp(fun, (0.5,), (0.5,)) 
def testNonHashableAttributes(self): def fun(x): - return jex.ffi.ffi_call("test_ffi", x, x, non_hashable_arg={"a": 1}) + return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5)))) hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() @@ -220,7 +220,7 @@ def fun(x): self.assertNotIsInstance(manager.exception, TypeError) def fun(x): - return jex.ffi.ffi_call("test_ffi", x, x, non_hashable_arg=np.arange(3)) + return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3)) self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5)))) hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() self.assertIn("non_hashable_arg = array", hlo) @@ -274,6 +274,12 @@ def testVectorizedDeprecation(self): jax.vmap( lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots) + def testBackwardCompatSyntax(self): + def fun(x): + return jex.ffi.ffi_call("test_ffi", x, x, param=0.5) + with self.assertWarns(DeprecationWarning): + jax.jit(fun).lower(jnp.ones(5)) + # TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` # custom call target because that's the only one in jaxlib that uses the @@ -286,9 +292,8 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs): shape=pivots.shape[:-1] + (permutation_size,), dtype=pivots.dtype, ), - pivots, **kwargs, - ) + )(pivots) if __name__ == "__main__":