Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better documentation for jnp.load #24403

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Oct 19, 2024

Part of #21461

@jakevdp jakevdp self-assigned this Oct 19, 2024
@@ -320,11 +321,43 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
return clip(val, min_val, max_val).astype(dtype)


@util.implements(np.load, update_doc=False)
def load(*args: Any, **kwargs: Any) -> Array:
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOC can we use os.PathLike[str] here or do we want to allow both str and bytes?

This function is a simple wrapper of :func:`numpy.load`, but in the case of
``.npy`` files created with :func:`numpy.save` or :func:`jax.numpy.save`,
the output will be returned as a :class:`jax.Array`, and ``bfloat16`` data
types will be restored. For ``.npz`` files, results will be returned as
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect it should be possible to return jax.Arrays for .npz files as well. Is there a reason why we don't do that?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants