Commit e114fe9

Merge pull request #21 from BirkhoffG/readme
Update readme
2 parents 9cbc538 + ed75931 commit e114fe9

4 files changed

+218
-110
lines changed

README.md

Lines changed: 57 additions & 28 deletions
@@ -17,13 +17,13 @@ License](https://img.shields.io/github/license/BirkhoffG/jax-dataloader.svg)
 - **downloading and pre-processing datasets** via [huggingface
 datasets](https://github.com/huggingface/datasets), [pytorch
 Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset),
-and tensorflow dataset (forthcoming)
+and [tensorflow dataset](www.tensorflow.org/datasets);

 - **iteratively loading batches** via (vanillla) [jax
 dataloader](https://birkhoffg.github.io/jax-dataloader/core.html#jax-dataloader),
 [pytorch
-dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader),
-tensorflow (forthcoming), and merlin (forthcoming).
+dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
+and [tensorflow dataset](www.tensorflow.org/datasets).

 A minimum `jax-dataloader` example:

@@ -57,9 +57,11 @@ pip install git+https://github.com/BirkhoffG/jax-dataloader.git
 > **Note**
 >
 > We will only install `jax`-related dependencies. If you wish to use
-> integration of `pytorch` or huggingface `datasets`, you should try to
-> manually install them, or run `pip install jax-dataloader[all]` for
-> installing all the dependencies.
+> integration of `pytorch`, huggingface `datasets`, or `tensorflow`, we
+> recommend manually install those dependencies.
+>
+> You can also run `pip install jax-dataloader[all]` to install
+> everything (not recommended).

 </div>

@@ -68,16 +70,21 @@ pip install git+https://github.com/BirkhoffG/jax-dataloader.git
 [`jax_dataloader.core.DataLoader`](https://birkhoffg.github.io/jax-dataloader/core.html#dataloader)
 follows similar API as the pytorch dataloader.

-- The `dataset` argument takes `jax_dataloader.core.Dataset` or
-`torch.utils.data.Dataset` or (the huggingface) `datasets.Dataset` as
-an input from which to load the data.
-- The `backend` argument takes `"jax"` or`"pytorch"` as an input, which
-specifies which backend dataloader to use batches.
+- The `dataset` should be an object of the subclass of
+`jax_dataloader.core.Dataset` or `torch.utils.data.Dataset` or (the
+huggingface) `datasets.Dataset` or `tf.data.Dataset`.
+- The `backend` should be one of `"jax"` or `"pytorch"` or
+`"tensorflow"`. This argument specifies which backend dataloader to
+load batches.

-``` python
-import jax_dataloader as jdl
-import jax.numpy as jnp
-```
+Note that not every dataset is compatible with every backend. See the
+compatibility table below:
+
+| | `jdl.Dataset` | `torch_data.Dataset` | `tf.data.Dataset` | `datasets.Dataset` |
+|:---------------|:--------------|:---------------------|:------------------|:-------------------|
+| `"jax"` |||||
+| `"pytorch"` |||||
+| `"tensorflow"` |||||

 ### Using [`ArrayDataset`](https://birkhoffg.github.io/jax-dataloader/dataset.html#arraydataset)

@@ -94,7 +101,7 @@ y = jnp.arange(10)
 arr_ds = jdl.ArrayDataset(X, y)
 ```

-This `arr_ds` can be loaded by both `"jax"` and `"pytorch"` dataloaders.
+This `arr_ds` can be loaded by *every* backends.

 ``` python
 # Create a `DataLoader` from the `ArrayDataset` via jax backend
@@ -103,6 +110,31 @@ dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
 dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
 ```

+### Using Huggingface Datasets
+
+The huggingface [datasets](https://github.com/huggingface/datasets) is a
+morden library for downloading, pre-processing, and sharing datasets.
+`jax_dataloader` supports directly passing the huggingface datasets.
+
+``` python
+from datasets import load_dataset
+```
+
+For example, We load the `"squad"` dataset from `datasets`:
+
+``` python
+hf_ds = load_dataset("squad")
+```
+
+Then, we can use `jax_dataloader` to load batches of `hf_ds`.
+
+``` python
+# Create a `DataLoader` from the `datasets.Dataset` via jax backend
+dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
+# Or we can use the pytorch backend
+dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
+```
+
 ### Using Pytorch Datasets

 The [pytorch Dataset](https://pytorch.org/docs/stable/data.html) and its
@@ -147,27 +179,24 @@ This `pt_ds` can **only** be loaded via `"pytorch"` dataloaders.
 dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
 ```

-### Using Huggingface Datasets
+### Using Tensowflow Datasets

-The huggingface [datasets](https://github.com/huggingface/datasets) is a
-morden library for downloading, pre-processing, and sharing datasets.
-`jax_dataloader` supports directly passing the huggingface datasets.
+`jax_dataloader` supports directly passing the [tensorflow
+datasets](www.tensorflow.org/datasets).

 ``` python
-from datasets import load_dataset
+import tensorflow_datasets as tfds
+import tensorflow as tf
 ```

-For example, We load the `"squad"` dataset from `datasets`:
+For instance, we can load the MNIST dataset from `tensorflow_datasets`

 ``` python
-hf_ds = load_dataset("squad")
+tf_ds = tfds.load('mnist', split='test', as_supervised=True)
 ```

-This `hf_ds` can be loaded via `"jax"` and `"pytorch"` dataloaders.
+and use `jax_dataloader` for iterating the dataset.

 ``` python
-# Create a `DataLoader` from the `datasets.Dataset` via jax backend
-dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
-# Or we can use the pytorch backend
-dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
+dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
 ```
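
To make the `batch_size=5, shuffle=True` arguments in the README examples above concrete, here is a minimal pure-Python sketch of shuffled mini-batching over paired features and labels. It is not how `jax_dataloader` implements any of its backends; `iter_batches` and its `seed` parameter are hypothetical names introduced only for illustration.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def iter_batches(
    xs: Sequence,
    ys: Sequence,
    batch_size: int,
    shuffle: bool = False,
    seed: int = 0,  # hypothetical parameter, used here for reproducibility
) -> Iterator[Tuple[List, List]]:
    """Yield (features, labels) batches, optionally in shuffled order."""
    assert len(xs) == len(ys), "features and labels must have equal length"
    indices = list(range(len(xs)))
    if shuffle:
        # Shuffle indices, not the data, so features and labels stay aligned.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]
        yield [xs[i] for i in idx], [ys[i] for i in idx]


# Ten samples, as in the README's `ArrayDataset` example (plain lists here).
X = [[i, i + 1] for i in range(10)]
y = list(range(10))
batches = list(iter_batches(X, y, batch_size=5, shuffle=True))
```

With ten samples and `batch_size=5`, the iterator yields two batches; each sample appears exactly once per pass, and every feature row stays paired with its label.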

jax_dataloader/core.py

Lines changed: 6 additions & 5 deletions
@@ -62,14 +62,15 @@ def _check_backend_compatibility(ds, backend: str):
     return DataLoader(ds, backend=backend)

 # %% ../nbs/core.ipynb 8
-def get_backend_compatibilities():
+def get_backend_compatibilities() -> dict[str, list[type]]:

     ds = {
-        'JAX': ArrayDataset(np.array([1,2,3])),
-        'Pytorch': torch_data.Dataset(),
-        'Tensorflow': tf.data.Dataset.from_tensor_slices(np.array([1,2,3])),
-        'Huggingface': hf_datasets.Dataset.from_dict({'a': [1,2,3]})
+        JAXDataset: ArrayDataset(np.array([1,2,3])),
+        TorchDataset: torch_data.Dataset(),
+        TFDataset: tf.data.Dataset.from_tensor_slices(np.array([1,2,3])),
+        HFDataset: hf_datasets.Dataset.from_dict({'a': [1,2,3]})
     }
+    assert len(ds) == len(SUPPORTED_DATASETS)
     backends = {b: [] for b in _get_backends()}
     for b in _get_backends():
         for name, dataset in ds.items():
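
The hunk above ends before the loop body, so the following is only a sketch of the pattern the change appears to follow: key the sample datasets by type rather than by display string, check that every supported type is covered, then probe each backend. All classes, the `_RULES` table, and `make_loader` are hypothetical stand-ins, and the try/except probing is an assumption rather than the library's confirmed code.

```python
from typing import Dict, List


# Hypothetical stand-ins for the real dataset classes.
class JAXDataset: ...
class TorchDataset: ...
class TFDataset: ...
class HFDataset: ...


SUPPORTED_DATASETS = [JAXDataset, TorchDataset, TFDataset, HFDataset]

# Assumed pairing rules, invented for this sketch only.
_RULES = {
    "jax": {JAXDataset, HFDataset},
    "pytorch": {JAXDataset, TorchDataset, HFDataset},
    "tensorflow": {JAXDataset, TFDataset},
}


def make_loader(dataset, backend: str):
    """Toy loader constructor that rejects unsupported pairings."""
    if type(dataset) not in _RULES[backend]:
        raise ValueError(f"{type(dataset).__name__} unsupported by {backend!r}")
    return dataset, backend


def get_backend_compatibilities() -> Dict[str, List[type]]:
    # One sample instance per dataset type, keyed by the type itself.
    samples = {cls: cls() for cls in SUPPORTED_DATASETS}
    # Guard against adding a supported type without a sample.
    assert len(samples) == len(SUPPORTED_DATASETS)
    backends: Dict[str, List[type]] = {b: [] for b in _RULES}
    for b in backends:
        for ds_type, sample in samples.items():
            try:
                make_loader(sample, b)
            except ValueError:
                continue  # incompatible pairing: skip it
            backends[b].append(ds_type)
    return backends


compat = get_backend_compatibilities()
```

Keying by type lets callers compare entries against dataset classes directly, and the length assertion fails early if the sample dictionary and the supported-type list drift apart.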

nbs/core.ipynb

Lines changed: 19 additions & 16 deletions
@@ -23,16 +23,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      "  %reload_ext autoreload\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "#| include: false\n",
     "%load_ext autoreload\n",
@@ -47,7 +38,18 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2023-12-26 15:13:36.437449: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+      "2023-12-26 15:13:36.437528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+      "2023-12-26 15:13:36.439236: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+      "2023-12-26 15:13:37.500782: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+     ]
+    }
+   ],
    "source": [
     "#| export\n",
     "from __future__ import print_function, division, annotations\n",
@@ -143,14 +145,15 @@
    "outputs": [],
    "source": [
     "#| export\n",
-    "def get_backend_compatibilities():\n",
+    "def get_backend_compatibilities() -> dict[str, list[type]]:\n",
     "\n",
     "    ds = {\n",
-    "        'JAX': ArrayDataset(np.array([1,2,3])),\n",
-    "        'Pytorch': torch_data.Dataset(),\n",
-    "        'Tensorflow': tf.data.Dataset.from_tensor_slices(np.array([1,2,3])),\n",
-    "        'Huggingface': hf_datasets.Dataset.from_dict({'a': [1,2,3]})\n",
+    "        JAXDataset: ArrayDataset(np.array([1,2,3])),\n",
+    "        TorchDataset: torch_data.Dataset(),\n",
+    "        TFDataset: tf.data.Dataset.from_tensor_slices(np.array([1,2,3])),\n",
+    "        HFDataset: hf_datasets.Dataset.from_dict({'a': [1,2,3]})\n",
     "    }\n",
+    "    assert len(ds) == len(SUPPORTED_DATASETS)\n",
     "    backends = {b: [] for b in _get_backends()}\n",
     "    for b in _get_backends():\n",
     "        for name, dataset in ds.items():\n",
