Skip to content

Commit

Permalink
Make ffi_call return a callable
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 19, 2024
1 parent 48bddc6 commit b37ccb0
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 102 deletions.
45 changes: 18 additions & 27 deletions docs/ffi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -287,27 +287,27 @@
" if x.dtype != jnp.float32:\n",
" raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n",
"\n",
" # In this case, the output of our FFI function is just a single array with the\n",
" # same shape and dtype as the input. We discuss a case with a more interesting\n",
" # output type below.\n",
" out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
"\n",
" return jex.ffi.ffi_call(\n",
" call = jex.ffi.ffi_call(\n",
" # The target name must be the same string as we used to register the target\n",
" # above in `register_custom_call_target`\n",
" \"rms_norm\",\n",
" out_type,\n",
" x,\n",
" # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
" # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n",
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
" # static parameter (i.e. not a JAX array).\n",
" eps=np.float32(eps),\n",
"\n",
" # In this case, the output of our FFI function is just a single array with\n",
" # the same shape and dtype as the input. We discuss a case with a more\n",
" # interesting output type below.\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
"\n",
" # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
" # as discussed below.\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
"\n",
" # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
" # the attribute `eps`. Our FFI function expects this to have the C++ `float`\n",
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
" # static parameter (i.e. not a JAX array).\n",
" return call(x, eps=np.float32(eps))\n",
"\n",
"\n",
"# Test that this gives the same result as our reference implementation\n",
"x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n",
Expand Down Expand Up @@ -403,10 +403,8 @@
" return jex.ffi.ffi_call(\n",
" \"rms_norm\",\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" x,\n",
" eps=np.float32(eps),\n",
" vmap_method=\"sequential\",\n",
" )\n",
" )(x, eps=np.float32(eps))\n",
"\n",
"\n",
"jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)"
Expand Down Expand Up @@ -462,10 +460,8 @@
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n",
" ),\n",
" x,\n",
" eps=np.float32(eps),\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
" )(x, eps=np.float32(eps))\n",
" return y, (res, x)\n",
"\n",
"\n",
Expand All @@ -478,11 +474,8 @@
" jex.ffi.ffi_call(\n",
" \"rms_norm_bwd\",\n",
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
" res,\n",
" x,\n",
" ct,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" ),\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )(res, x, ct),\n",
" )\n",
"\n",
"\n",
Expand Down Expand Up @@ -569,10 +562,8 @@
" return lambda x: jex.ffi.ffi_call(\n",
" target_name,\n",
" out_type,\n",
" x,\n",
" eps=np.float32(eps),\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
" )(x, eps=np.float32(eps))\n",
"\n",
" return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",
"\n",
Expand Down
45 changes: 18 additions & 27 deletions docs/ffi.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,27 +248,27 @@ def rms_norm(x, eps=1e-5):
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
# In this case, the output of our FFI function is just a single array with the
# same shape and dtype as the input. We discuss a case with a more interesting
# output type below.
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jex.ffi.ffi_call(
call = jex.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_custom_call_target`
"rms_norm",
out_type,
x,
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# In this case, the output of our FFI function is just a single array with
# the same shape and dtype as the input. We discuss a case with a more
# interesting output type below.
jax.ShapeDtypeStruct(x.shape, x.dtype),
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vmap_method="broadcast_fullrank",
)
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return call(x, eps=np.float32(eps))
# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
Expand Down Expand Up @@ -334,10 +334,8 @@ def rms_norm_sequential(x, eps=1e-5):
return jex.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
x,
eps=np.float32(eps),
vmap_method="sequential",
)
)(x, eps=np.float32(eps))
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
Expand Down Expand Up @@ -380,10 +378,8 @@ def rms_norm_fwd(x, eps=1e-5):
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
x,
eps=np.float32(eps),
vmap_method="broadcast_fullrank",
)
)(x, eps=np.float32(eps))
return y, (res, x)
Expand All @@ -396,11 +392,8 @@ def rms_norm_bwd(eps, res, ct):
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
res,
x,
ct,
vmap_method="broadcast_fullrank",
),
vmap_method="broadcast_fullrank",
)(res, x, ct),
)
Expand Down Expand Up @@ -477,10 +470,8 @@ def rms_norm_cross_platform(x, eps=1e-5):
return lambda x: jex.ffi.ffi_call(
target_name,
out_type,
x,
eps=np.float32(eps),
vmap_method="broadcast_fullrank",
)
)(x, eps=np.float32(eps))
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
Expand Down
6 changes: 2 additions & 4 deletions examples/ffi/src/jax_ffi_example/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,11 @@ def array_attr(num: int):
return jex.ffi.ffi_call(
"array_attr",
jax.ShapeDtypeStruct((), np.int32),
array=np.arange(num, dtype=np.int32),
)
)(array=np.arange(num, dtype=np.int32))


def dictionary_attr(**kwargs):
return jex.ffi.ffi_call(
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
**kwargs,
)
)(**kwargs)
21 changes: 7 additions & 14 deletions examples/ffi/src/jax_ffi_example/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,17 @@ def rms_norm(x, eps=1e-5):
# same shape and dtype as the input.
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return jex.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_ffi_target`
"rms_norm",
out_type,
x,
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
vmap_method="broadcast_fullrank",
)
)(x, eps=np.float32(eps))


def rms_norm_fwd(x, eps=1e-5):
Expand All @@ -71,10 +69,8 @@ def rms_norm_fwd(x, eps=1e-5):
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
x,
eps=np.float32(eps),
vmap_method="broadcast_fullrank",
)
)(x, eps=np.float32(eps))
return y, (res, x)


Expand All @@ -87,11 +83,8 @@ def rms_norm_bwd(eps, res, ct):
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
res,
x,
ct,
vmap_method="broadcast_fullrank",
),
)(res, x, ct),
)


Expand Down
89 changes: 67 additions & 22 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
import ctypes
import functools
import os
from typing import Any
from typing import overload, Any

import numpy as np

Expand All @@ -38,6 +38,11 @@
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
Shape)

# TODO(dfm): Remove after 6 months or less because there aren't any offical
# compatibility guarantees for jax.extend (see JEP 15856)
# Added Oct 13, 2024
deprecations.register("jax-ffi-call-args")

map, unsafe_map = util.safe_map, map
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None

Expand Down Expand Up @@ -197,17 +202,43 @@ def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue
return tuple(avals)


@overload
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*args: ArrayLike,
*,
has_side_effect: bool = False,
vmap_method: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**kwargs: Any,
) -> Array | list[Array]:
) -> Callable[..., Array | Sequence[Array]]:
...

@overload
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
has_side_effect: bool = False,
vmap_method: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**deprecated_kwargs: Any,
) -> Array | Sequence[Array]:
...


def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
has_side_effect: bool = False,
vmap_method: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**deprecated_kwargs: Any,
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
"""Call a foreign function interface (FFI) target.
See the :ref:`ffi-tutorial` tutorial for more information.
Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
:func:`~jax.vmap` depends on the value of ``vmap_method``. See the
:func:`~jax.pure_callback` documenation for more details about the allowed
Expand All @@ -226,18 +257,16 @@ def ffi_call(
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
used to define the elements of ``result_shape_dtypes``.
``jax.core.abstract_token`` may be used to represent a token-typed output.
*args: the arguments passed to the custom call.
has_side_effect: boolean specifying whether the custom call has side
effects. When ``True``, the FFI call will be executed even when the
outputs are not used.
vmap_method: string specifying how the FFI call transforms under
:func:`~jax.vmap` as described above.
**kwargs: keyword arguments that are passed as named attributes to the
custom call using XLA's FFI interface.
Returns:
One or more :class:`~jax.Array` objects whose shapes and dtypes match
``result_shape_dtypes``.
A function that can be called with the input arrays as positional arguments
to execute the FFI handler. Any keyword arguments are passed as named
attributes to the FFI handler using XLA's FFI interface.
"""
if not isinstance(vectorized, DeprecatedArg) and not vectorized is None:
deprecations.warn(
Expand All @@ -264,19 +293,35 @@ def ffi_call(
else:
multiple_results = False
result_avals = _result_avals((result_shape_dtypes,))
results = ffi_call_p.bind(
*args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
target_name=target_name,
has_side_effect=has_side_effect,
**_wrap_kwargs_hashable(kwargs),
)
if multiple_results:
return results

def wrapped(*args: ArrayLike, **kwargs: Any):
results = ffi_call_p.bind(
*args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
target_name=target_name,
has_side_effect=has_side_effect,
**_wrap_kwargs_hashable(kwargs),
)
if multiple_results:
return results
else:
return results[0]

if deprecated_args or deprecated_kwargs:
deprecations.warn(
"jax-ffi-call-args",
"Calling ffi_call directly with input arguments is deprecated. "
"Instead, ffi_call should be used to construct a callable, which can "
"then be called with the appropriate inputs. For example,\n"
" ffi_call('target_name', output_type, x, argument=5)\n"
"should be replaced with\n"
" ffi_call('target_name', output_type)(x, argument=5)",
stacklevel=2)
return wrapped(*deprecated_args, **deprecated_kwargs)
else:
return results[0]
return wrapped


# ffi_call must support some small non-hashable input arguments, like np.arrays
Expand Down
Loading

0 comments on commit b37ccb0

Please sign in to comment.