Skip to content

Commit 65a777e

Browse files
committed
clean code
1 parent ee36d6b commit 65a777e

File tree

4 files changed

+173
-36
lines changed

4 files changed

+173
-36
lines changed

jax_dataloader/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def __init__(
3131
):
3232
assert all(arrays[0].shape[0] == arr.shape[0] for arr in arrays), \
3333
"All arrays must have the same dimension."
34-
self.arrays = arrays
34+
self.arrays = tuple(arrays)
3535

3636
def __len__(self):
3737
return self.arrays[0].shape[0]
3838

3939
def __getitem__(self, index):
40-
return tuple(arr[index] for arr in self.arrays)
40+
return jax.tree_util.tree_map(lambda x: x[index], self.arrays)
4141

4242
def to_tf_dataset(self):
4343
return tf.data.Dataset.from_tensor_slices(self.arrays)

jax_dataloader/loaders.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# %% auto 0
1010
__all__ = ['BaseDataLoader', 'DataLoaderJax', 'DataLoaderPytorch', 'to_tf_dataset', 'DataLoaderTensorflow']
1111

12-
# %% ../nbs/loader.ipynb 5
12+
# %% ../nbs/loader.ipynb 6
1313
class BaseDataLoader:
1414
"""Dataloader Interface"""
1515

@@ -32,7 +32,7 @@ def __next__(self):
3232
def __iter__(self):
3333
raise NotImplementedError
3434

35-
# %% ../nbs/loader.ipynb 7
35+
# %% ../nbs/loader.ipynb 8
3636
class DataLoaderJax(BaseDataLoader):
3737
"""Dataloder in Vanilla Jax"""
3838

@@ -56,6 +56,8 @@ def __init__(
5656
self.pose = 0 # record the current position in the dataset
5757
self._shuffle()
5858

59+
self.num_batches = len(self)
60+
5961
def _shuffle(self):
6062
if self.shuffle:
6163
self.indices = jax.random.permutation(next(self.keys), self.indices)
@@ -64,7 +66,7 @@ def _stop_iteration(self):
6466
self.pose = 0
6567
self._shuffle()
6668
raise StopIteration
67-
69+
6870
def __len__(self):
6971
if self.drop_last:
7072
batches = len(self.dataset) // self.batch_size # get the floor of division
@@ -73,23 +75,19 @@ def __len__(self):
7375
return batches
7476

7577
def __next__(self):
76-
if self.pose + self.batch_size <= self.data_len:
77-
batch_indices = self.indices[self.pose: self.pose + self.batch_size]
78-
batch_data = self.dataset[batch_indices]
79-
self.pose += self.batch_size
80-
return batch_data
81-
elif self.pose < self.data_len and not self.drop_last:
82-
batch_indices = self.indices[self.pose:]
78+
if self.pose < self.num_batches:
79+
batch_indices = self.indices[self.pose * self.batch_size: (self.pose + 1) * self.batch_size]
8380
batch_data = self.dataset[batch_indices]
84-
self.pose += self.batch_size
81+
self.pose += 1
8582
return batch_data
8683
else:
8784
self._stop_iteration()
8885

8986
def __iter__(self):
9087
return self
9188

92-
# %% ../nbs/loader.ipynb 10
89+
90+
# %% ../nbs/loader.ipynb 14
9391
# adapted from https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
9492
def _numpy_collate(batch):
9593
if isinstance(batch[0], (np.ndarray, jax.Array)):
@@ -110,7 +108,7 @@ def __getitem__(self, idx): return self.dataset[idx]
110108

111109
return DatasetPytorch(dataset)
112110

113-
# %% ../nbs/loader.ipynb 11
111+
# %% ../nbs/loader.ipynb 15
114112
class DataLoaderPytorch(BaseDataLoader):
115113
"""Pytorch Dataloader"""
116114
def __init__(
@@ -151,7 +149,7 @@ def __next__(self):
151149
def __iter__(self):
152150
return self.dataloader.__iter__()
153151

154-
# %% ../nbs/loader.ipynb 14
152+
# %% ../nbs/loader.ipynb 18
155153
def to_tf_dataset(dataset) -> tf.data.Dataset:
156154
if is_tf_dataset(dataset):
157155
return dataset
@@ -162,7 +160,7 @@ def to_tf_dataset(dataset) -> tf.data.Dataset:
162160
else:
163161
raise ValueError(f"Dataset type {type(dataset)} is not supported.")
164162

165-
# %% ../nbs/loader.ipynb 15
163+
# %% ../nbs/loader.ipynb 19
166164
class DataLoaderTensorflow(BaseDataLoader):
167165
"""Tensorflow Dataloader"""
168166
def __init__(

nbs/dataset.ipynb

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,18 @@
3636
"cell_type": "code",
3737
"execution_count": null,
3838
"metadata": {},
39-
"outputs": [],
39+
"outputs": [
40+
{
41+
"name": "stderr",
42+
"output_type": "stream",
43+
"text": [
44+
"2023-11-28 00:38:49.613518: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
45+
"2023-11-28 00:38:49.613568: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
46+
"2023-11-28 00:38:49.614328: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
47+
"2023-11-28 00:38:50.239626: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
48+
]
49+
}
50+
],
4051
"source": [
4152
"#| export\n",
4253
"from __future__ import print_function, division, annotations\n",
@@ -80,13 +91,13 @@
8091
" ):\n",
8192
" assert all(arrays[0].shape[0] == arr.shape[0] for arr in arrays), \\\n",
8293
" \"All arrays must have the same dimension.\"\n",
83-
" self.arrays = arrays\n",
94+
" self.arrays = tuple(arrays)\n",
8495
"\n",
8596
" def __len__(self):\n",
8697
" return self.arrays[0].shape[0]\n",
8798
"\n",
8899
" def __getitem__(self, index):\n",
89-
" return tuple(arr[index] for arr in self.arrays)\n",
100+
" return jax.tree_util.tree_map(lambda x: x[index], self.arrays)\n",
90101
" \n",
91102
" def to_tf_dataset(self):\n",
92103
" return tf.data.Dataset.from_tensor_slices(self.arrays)"
@@ -109,7 +120,7 @@
109120
"name": "stderr",
110121
"output_type": "stream",
111122
"text": [
112-
"WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
123+
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
113124
]
114125
}
115126
],

nbs/loader.ipynb

Lines changed: 143 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141
"name": "stderr",
4242
"output_type": "stream",
4343
"text": [
44-
"2023-04-05 18:20:59.105985: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:\n",
45-
"2023-04-05 18:20:59.106076: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:\n",
46-
"2023-04-05 18:20:59.106084: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
44+
"2023-11-28 00:39:22.955810: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
45+
"2023-11-28 00:39:22.955851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
46+
"2023-11-28 00:39:22.956517: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
47+
"2023-11-28 00:39:23.532258: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
4748
]
4849
}
4950
],
@@ -120,6 +121,77 @@
120121
" assert len(_X) == len(X_list) * batch_size\n"
121122
]
122123
},
124+
{
125+
"cell_type": "code",
126+
"execution_count": null,
127+
"metadata": {},
128+
"outputs": [],
129+
"source": [
130+
"#| hide\n",
131+
"def test_keras_dataloader(samples=1000, batch_size=12):\n",
132+
" from keras.trainers.epoch_iterator import EpochIterator\n",
133+
"\n",
134+
" feats = jnp.arange(samples).repeat(10).reshape(samples, 10)\n",
135+
" labels = jnp.arange(samples).reshape(samples, 1)\n",
136+
" ds = ArrayDataset(feats, labels)\n",
137+
" # N % batchsize != 0\n",
138+
" dl = EpochIterator(feats, labels, batch_size=batch_size, shuffle=False)\n",
139+
" for _ in range(2):\n",
140+
" X_list, Y_list = [], []\n",
141+
" for step, batch in dl.enumerate_epoch('np'):\n",
142+
" x, y = batch[0]\n",
143+
" X_list.append(x)\n",
144+
" Y_list.append(y)\n",
145+
" _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n",
146+
" assert jnp.array_equal(_X, feats)\n",
147+
" assert jnp.array_equal(_Y, labels)\n",
148+
"\n",
149+
" dl = EpochIterator(feats, labels, batch_size=batch_size, shuffle=False, )\n",
150+
" for _ in range(2):\n",
151+
" X_list, Y_list = [], []\n",
152+
" for step, batch in dl.enumerate_epoch('np'):\n",
153+
" x, y = batch[0]\n",
154+
" X_list.append(x)\n",
155+
" Y_list.append(y)\n",
156+
" _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n",
157+
" last_idx = len(X_list) * batch_size\n",
158+
" jnp.array_equal(_X, feats[: last_idx])\n",
159+
" jnp.array_equal(_Y, labels[: last_idx])\n",
160+
"\n",
161+
"\n",
162+
" dl_shuffle = EpochIterator(feats, labels, batch_size=batch_size, shuffle=True, )\n",
163+
" last_X, last_Y = jnp.array([]), jnp.array([])\n",
164+
" for _ in range(2):\n",
165+
" X_list, Y_list = [], []\n",
166+
" for step, batch in dl_shuffle.enumerate_epoch('np'):\n",
167+
" x, y = batch[0]\n",
168+
" assert jnp.array_equal(x[:, :1], y)\n",
169+
" X_list.append(x)\n",
170+
" Y_list.append(y)\n",
171+
" _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n",
172+
" not jnp.array_equal(_X, feats)\n",
173+
" not jnp.array_equal(_Y, labels)\n",
174+
" jnp.sum(_X) == jnp.sum(feats), \\\n",
175+
" f\"jnp.sum(_X)={jnp.sum(_X)}, jnp.sum(feats)={jnp.sum(feats)}\"\n",
176+
" not jnp.array_equal(_X, last_X)\n",
177+
" not jnp.array_equal(_Y, last_Y)\n",
178+
" last_X, last_Y = _X, _Y\n",
179+
"\n",
180+
"\n",
181+
" dl_shuffle = EpochIterator(feats, labels, batch_size=batch_size, shuffle=True, )\n",
182+
" for _ in range(2):\n",
183+
" X_list, Y_list = [], []\n",
184+
" for step, batch in dl_shuffle.enumerate_epoch('np'):\n",
185+
" x, y = batch[0]\n",
186+
" assert jnp.array_equal(x[:, :1], y)\n",
187+
" X_list.append(x)\n",
188+
" Y_list.append(y)\n",
189+
" _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n",
190+
" not jnp.array_equal(_X, feats)\n",
191+
" not jnp.array_equal(_Y, labels)\n",
192+
" len(_X) == len(X_list) * batch_size\n"
193+
]
194+
},
123195
{
124196
"cell_type": "code",
125197
"execution_count": null,
@@ -188,6 +260,8 @@
188260
" self.pose = 0 # record the current position in the dataset\n",
189261
" self._shuffle()\n",
190262
"\n",
263+
" self.num_batches = len(self)\n",
264+
"\n",
191265
" def _shuffle(self):\n",
192266
" if self.shuffle:\n",
193267
" self.indices = jax.random.permutation(next(self.keys), self.indices)\n",
@@ -196,7 +270,7 @@
196270
" self.pose = 0\n",
197271
" self._shuffle()\n",
198272
" raise StopIteration\n",
199-
"\n",
273+
" \n",
200274
" def __len__(self):\n",
201275
" if self.drop_last:\n",
202276
" batches = len(self.dataset) // self.batch_size # get the floor of division\n",
@@ -205,33 +279,87 @@
205279
" return batches\n",
206280
"\n",
207281
" def __next__(self):\n",
208-
" if self.pose + self.batch_size <= self.data_len:\n",
209-
" batch_indices = self.indices[self.pose: self.pose + self.batch_size]\n",
282+
" if self.pose < self.num_batches:\n",
283+
" batch_indices = self.indices[self.pose * self.batch_size: (self.pose + 1) * self.batch_size]\n",
210284
" batch_data = self.dataset[batch_indices]\n",
211-
" self.pose += self.batch_size\n",
212-
" return batch_data\n",
213-
" elif self.pose < self.data_len and not self.drop_last:\n",
214-
" batch_indices = self.indices[self.pose:]\n",
215-
" batch_data = self.dataset[batch_indices]\n",
216-
" self.pose += self.batch_size\n",
285+
" self.pose += 1\n",
217286
" return batch_data\n",
218287
" else:\n",
219288
" self._stop_iteration()\n",
220289
"\n",
221290
" def __iter__(self):\n",
222-
" return self"
291+
" return self\n"
223292
]
224293
},
225294
{
226295
"cell_type": "code",
227296
"execution_count": null,
228297
"metadata": {},
229-
"outputs": [],
298+
"outputs": [
299+
{
300+
"name": "stderr",
301+
"output_type": "stream",
302+
"text": [
303+
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
304+
]
305+
}
306+
],
230307
"source": [
231308
"#| hide\n",
232309
"test_dataloader(DataLoaderJax, samples=20, batch_size=12)\n",
233310
"test_dataloader(DataLoaderJax, samples=20, batch_size=10)\n",
234-
"test_dataloader(DataLoaderJax, samples=11, batch_size=10)"
311+
"test_dataloader(DataLoaderJax, samples=11, batch_size=10)\n",
312+
"test_dataloader(DataLoaderJax, samples=40, batch_size=12)"
313+
]
314+
},
315+
{
316+
"cell_type": "code",
317+
"execution_count": null,
318+
"metadata": {},
319+
"outputs": [],
320+
"source": [
321+
"#| hide\n",
322+
"test_keras_dataloader(samples=20, batch_size=12)\n",
323+
"test_keras_dataloader(samples=20, batch_size=10)\n",
324+
"test_keras_dataloader(samples=11, batch_size=10)\n",
325+
"test_keras_dataloader(samples=40, batch_size=12)"
326+
]
327+
},
328+
{
329+
"cell_type": "code",
330+
"execution_count": null,
331+
"metadata": {},
332+
"outputs": [
333+
{
334+
"name": "stdout",
335+
"output_type": "stream",
336+
"text": [
337+
"1.48 s ± 29.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)\n"
338+
]
339+
}
340+
],
341+
"source": [
342+
"%%timeit -n 5 -r 3\n",
343+
"test_dataloader(DataLoaderJax, samples=1280, batch_size=10)"
344+
]
345+
},
346+
{
347+
"cell_type": "code",
348+
"execution_count": null,
349+
"metadata": {},
350+
"outputs": [
351+
{
352+
"name": "stdout",
353+
"output_type": "stream",
354+
"text": [
355+
"301 ms ± 2.4 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)\n"
356+
]
357+
}
358+
],
359+
"source": [
360+
"#| hide\n",
361+
"%%timeit -n 5 -r 3\n",
362+
"test_keras_dataloader(samples=1280, batch_size=10)"
235363
]
236364
},
237365
{

0 commit comments

Comments
 (0)