Merge pull request #23881 from dfm:deprecate-default-vmap-callback
PiperOrigin-RevId: 681862488
Google-ML-Automation committed Oct 3, 2024
2 parents 11fe32f + 1d27d42 commit ba4052d
Showing 9 changed files with 316 additions and 99 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -46,6 +46,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
`jax.errors.JaxRuntimeError` instead.
* The default behavior of {func}`jax.pure_callback` and
  {func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated, and so has
  the `vectorized` parameter to those functions. The `vmap_method` parameter
  should be used instead for better-defined behavior. See the discussion in
  {jax-issue}`#23881` for more details, and the migration sketch below.

* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since its deprecation
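A rough migration sketch for {func}`jax.pure_callback` (the callback here is hypothetical and used only for illustration; the old call would have passed `vectorized=True` in place of `vmap_method`):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_cumsum(x):
    # Handles any leading batch dimensions itself, which is the contract
    # that vmap_method="broadcast_fullrank" assumes.
    return np.cumsum(x, axis=-1)

def cumsum(x):
    return jax.pure_callback(
        host_cumsum,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
        vmap_method="broadcast_fullrank",  # was: vectorized=True
    )

jax.vmap(cumsum)(jnp.arange(6.0).reshape(2, 3))
```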
44 changes: 27 additions & 17 deletions docs/ffi.ipynb
@@ -303,9 +303,9 @@
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
" # static parameter (i.e. not a JAX array).\n",
" eps=np.float32(eps),\n",
" # The `vectorized` parameter controls this function's behavior under `vmap`\n",
" # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
" # as discussed below.\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
"\n",
"\n",
@@ -325,7 +325,7 @@
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
"\n",
"The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"\n",
"```{tip}\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
@@ -336,19 +336,29 @@
"(ffi-call-vmap)=\n",
"### Batching with `vmap`\n",
"\n",
"All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n",
"By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
"This default implementation is general purpose, but it doesn't parallelize very well.\n",
"But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n",
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
"\n",
"The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n",
"The simplest `vmap_method` is `\"sequential\"`.\n",
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
"This implementation is general purpose, but it doesn't parallelize very well.\n",
"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
"\n",
"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
"The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
"Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
"\n",
"```python\n",
"ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n",
"```\n",
"\n",
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:"
"```{tip}\n",
"Note that things get a bit more complicated when we have multiple input arguments.\n",
"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
"The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
"```\n",
"\n",
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
]
},
{
@@ -380,7 +390,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
"Using `vmap_method=\"sequential\"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
]
},
{
@@ -389,24 +399,24 @@
"metadata": {},
"outputs": [],
"source": [
"def rms_norm_not_vectorized(x, eps=1e-5):\n",
"def rms_norm_sequential(x, eps=1e-5):\n",
" return jex.ffi.ffi_call(\n",
" \"rms_norm\",\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" x,\n",
" eps=np.float32(eps),\n",
" vectorized=False, # This is the default behavior\n",
" vmap_method=\"sequential\",\n",
" )\n",
"\n",
"\n",
"jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)"
"jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
]
},
{
@@ -454,7 +464,7 @@
" ),\n",
" x,\n",
" eps=np.float32(eps),\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
" return y, (res, x)\n",
"\n",
@@ -471,7 +481,7 @@
" res,\n",
" x,\n",
" ct,\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" ),\n",
" )\n",
"\n",
@@ -561,7 +571,7 @@
" out_type,\n",
" x,\n",
" eps=np.float32(eps),\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
"\n",
" return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",
44 changes: 27 additions & 17 deletions docs/ffi.md
@@ -264,9 +264,9 @@ def rms_norm(x, eps=1e-5):
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# The `vectorized` parameter controls this function's behavior under `vmap`
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vectorized=True,
vmap_method="broadcast_fullrank",
)
@@ -282,7 +282,7 @@ It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_cal
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.

The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.

```{tip}
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
@@ -293,19 +293,29 @@ One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support so
(ffi-call-vmap)=
### Batching with `vmap`

All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.
By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
This default implementation is general purpose, but it doesn't parallelize very well.
But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.
{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.

The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.
The simplest `vmap_method` is `"sequential"`.
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
This implementation is general purpose, but it doesn't parallelize very well.
Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.

In our example, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:

```python
ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
```
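As a concrete check of this property, using the `rms_norm` wrapper and the input `x` defined earlier in this tutorial (a sketch; the tolerance is chosen loosely):

```python
np.testing.assert_allclose(
    jax.vmap(rms_norm)(x),
    jnp.stack([rms_norm(row) for row in x]),
    rtol=1e-5,
)
```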

Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:
```{tip}
Note that things get a bit more complicated when we have multiple input arguments.
For simplicity, we will use the `"broadcast_fullrank"` method throughout this tutorial, which guarantees that all inputs are broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
The documentation for {func}`~jax.pure_callback` includes some examples of this.
```
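As a rough illustration of the multiple-argument case, here is a hypothetical {func}`~jax.pure_callback` example (not part of this tutorial's FFI library) where `"broadcast_fullrank"` broadcasts an unbatched argument up to the full batch shape:

```python
import jax
import jax.numpy as jnp

def host_add(x, y):
    # With vmap_method="broadcast_fullrank", both arguments arrive with
    # identical leading batch dimensions.
    return x + y

def add(x, y):
    out_shape = jnp.broadcast_shapes(x.shape, y.shape)
    return jax.pure_callback(
        host_add,
        jax.ShapeDtypeStruct(out_shape, x.dtype),
        x,
        y,
        vmap_method="broadcast_fullrank",
    )

xs = jnp.arange(6.0).reshape(3, 2)
y = jnp.ones(2)
jax.vmap(add, in_axes=(0, None))(xs, y)  # y is broadcast to shape (3, 2)
```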

Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:

```{code-cell} ipython3
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
@@ -317,23 +327,23 @@ We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms
jax.make_jaxpr(jax.vmap(rms_norm))(x)
```

If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
Using `vmap_method="sequential"`, `vmap`ping an `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:

```{code-cell} ipython3
def rms_norm_not_vectorized(x, eps=1e-5):
def rms_norm_sequential(x, eps=1e-5):
return jex.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
x,
eps=np.float32(eps),
vectorized=False, # This is the default behavior
vmap_method="sequential",
)
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
```

If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's also worth opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
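For reference, a minimal sketch of that experimental interface (via `jax.custom_batching.custom_vmap`, shown here with a plain JAX function rather than an FFI call, and assuming the rule signature described in its docstring):

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def vector_norm(x):
    return jnp.sqrt(jnp.sum(x**2))

@vector_norm.def_vmap
def vector_norm_vmap_rule(axis_size, in_batched, xs):
    # When batched, `xs` carries a leading batch dimension of size `axis_size`.
    (x_batched,) = in_batched
    out = jnp.sqrt(jnp.sum(xs**2, axis=-1))
    return out, x_batched

jax.vmap(vector_norm)(jnp.ones((4, 3)))  # shape (4,), via the custom rule
```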

+++

@@ -372,7 +382,7 @@ def rms_norm_fwd(x, eps=1e-5):
),
x,
eps=np.float32(eps),
vectorized=True,
vmap_method="broadcast_fullrank",
)
return y, (res, x)
@@ -389,7 +399,7 @@ def rms_norm_bwd(eps, res, ct):
res,
x,
ct,
vectorized=True,
vmap_method="broadcast_fullrank",
),
)
@@ -469,7 +479,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
out_type,
x,
eps=np.float32(eps),
vectorized=True,
vmap_method="broadcast_fullrank",
)
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
7 changes: 3 additions & 4 deletions examples/ffi/src/jax_ffi_example/rms_norm.py
@@ -60,8 +60,7 @@ def rms_norm(x, eps=1e-5):
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# The `vectorized` parameter controls this function's behavior under `vmap`.
vectorized=True,
vmap_method="broadcast_fullrank",
)


@@ -74,7 +73,7 @@ def rms_norm_fwd(x, eps=1e-5):
),
x,
eps=np.float32(eps),
vectorized=True,
vmap_method="broadcast_fullrank",
)
return y, (res, x)

@@ -91,7 +90,7 @@ def rms_norm_bwd(eps, res, ct):
res,
x,
ct,
vectorized=True,
vmap_method="broadcast_fullrank",
),
)

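A hypothetical usage sketch for this example module (assuming the package is installed and that importing it registers the FFI targets, as the surrounding example code suggests):

```python
import jax
import jax.numpy as jnp
from jax_ffi_example import rms_norm as rms_norm_lib

x = jnp.ones((8, 64), dtype=jnp.float32)
y = rms_norm_lib.rms_norm(x)             # direct call
yb = jax.vmap(rms_norm_lib.rms_norm)(x)  # batched without a scan fallback
```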