Skip to content

Commit

Permalink
Update READM
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Dec 26, 2023
1 parent 16163a1 commit 9cbc538
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,20 @@ or install directly from the repository:
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
```

!!! note
<div>

We will only install `jax`-related dependencies.
If you wish to use integration of `pytorch` or huggingface `datasets`,
you should try to manually install them,
or run `pip install jax-dataloader[all]` for installing all the dependencies.
> **Note**
>
> We will only install `jax`-related dependencies. If you wish to use
> integration of `pytorch` or huggingface `datasets`, you should try to
> manually install them, or run `pip install jax-dataloader[all]` for
> installing all the dependencies.
</div>

## Usage

[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core/#dataloader)
[`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)
follows similar API as the pytorch dataloader.

- The `dataset` argument takes `jax_dataloader.core.Dataset` or
Expand All @@ -75,11 +79,11 @@ import jax_dataloader as jdl
import jax.numpy as jnp
```

### Using [`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset/#arraydataset)
### Using [`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)

The `jax_dataloader.core.ArrayDataset` is an easy way to wrap multiple
`jax.numpy.array` into one Dataset. For example, we can create an
[`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset/#arraydataset)
[`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)
as follows:

``` python
Expand Down Expand Up @@ -109,10 +113,15 @@ ecosystems (e.g.,
built-in datasets. `jax_dataloader` supports directly passing the
pytorch Dataset.

!!! note
<div>

> **Note**
>
> Unfortuantely, the [pytorch
> Dataset](https://pytorch.org/docs/stable/data.html) can only work with
> `backend=pytorch`. See the belowing example.
Unfortuantely, the [pytorch Dataset](https://pytorch.org/docs/stable/data.html)
can only work with `backend=pytorch`. See the belowing example.
</div>

``` python
from torchvision.datasets import MNIST
Expand Down

0 comments on commit 9cbc538

Please sign in to comment.