Skip to content

Commit

Permalink
Better docs for jnp.fromfunction
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 20, 2024
1 parent ca2d158 commit b856473
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6182,9 +6182,63 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None,
from jax.dlpack import from_dlpack # pylint: disable=g-import-not-at-top
return from_dlpack(x, device=device, copy=copy)

@util.implements(np.fromfunction)

def fromfunction(function: Callable[..., Array], shape: Any,
*, dtype: DTypeLike = float, **kwargs) -> Array:
"""Create an array from a function applied over indices.
JAX implementation of :func:`numpy.fromfunction`. The JAX implementation
differs in that it dispatches via :func:`jax.vmap`, and so unlike in NumPy
the function logically operates on scalar inputs, and need not explicitly
handle broadcasted inputs.
Args:
function: a function that takes *N* dynamic scalars and outputs a scalar.
shape: a length-*N* tuple of integers specifying the output shape.
dtype: optionally specify the dtype of the inputs. Defaults to floating-point.
kwargs: additional keyword arguments are passed statically to ``function``.
Returns:
An array of shape ``shape`` if ``function`` returns a scalar, or in general
a pytree of arrays with leading dimensions ``shape``, as determined by the
output of ``function``.
See also:
- :func:`jax.vmap`: the core transformation that the :func:`fromfunction`
API is built on.
Examples:
Generate a multiplication table of a given shape:
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0, 0, 0, 0, 0, 0],
[ 0, 1, 2, 3, 4, 5],
[ 0, 2, 4, 6, 8, 10]], dtype=int32)
When ``function`` returns a non-scalar the output will have leading
dimension of ``shape``:
>>> def f(x):
... return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
[0., 2., 4.]], dtype=float32)
``function`` may return multiple results, in which case each is mapped
independently:
>>> def f(x, y):
... return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
[1. 2. 3. 4. 5.]
[2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
[0. 1. 2. 3. 4.]
[0. 2. 4. 6. 8.]]
"""
shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()")
for i in range(len(shape)):
in_axes = [0 if i == j else None for j in range(len(shape))]
Expand Down

0 comments on commit b856473

Please sign in to comment.