|
41 | 41 | "name": "stderr",
|
42 | 42 | "output_type": "stream",
|
43 | 43 | "text": [
|
44 |
| - "2023-04-05 18:20:59.105985: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:\n", |
45 |
| - "2023-04-05 18:20:59.106076: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:\n", |
46 |
| - "2023-04-05 18:20:59.106084: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" |
| 44 | + "2023-11-28 00:39:22.955810: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", |
| 45 | + "2023-11-28 00:39:22.955851: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", |
| 46 | + "2023-11-28 00:39:22.956517: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", |
| 47 | + "2023-11-28 00:39:23.532258: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" |
47 | 48 | ]
|
48 | 49 | }
|
49 | 50 | ],
|
|
120 | 121 | " assert len(_X) == len(X_list) * batch_size\n"
|
121 | 122 | ]
|
122 | 123 | },
|
| 124 | + { |
| 125 | + "cell_type": "code", |
| 126 | + "execution_count": null, |
| 127 | + "metadata": {}, |
| 128 | + "outputs": [], |
| 129 | + "source": [ |
| 130 | + "#| hide\n", |
| 131 | + "def test_keras_dataloader(samples=1000, batch_size=12):\n", |
| 132 | + " from keras.trainers.epoch_iterator import EpochIterator\n", |
| 133 | + "\n", |
| 134 | + " feats = jnp.arange(samples).repeat(10).reshape(samples, 10)\n", |
| 135 | + " labels = jnp.arange(samples).reshape(samples, 1)\n", |
| 136 | + " ds = ArrayDataset(feats, labels)\n", |
| 137 | + " # N % batchsize != 0\n", |
| 138 | + " dl = EpochIterator(feats, labels, batch_size=batch_size, shuffle=False)\n", |
| 139 | + " for _ in range(2):\n", |
| 140 | + " X_list, Y_list = [], []\n", |
| 141 | + " for step, batch in dl.enumerate_epoch('np'):\n", |
| 142 | + " x, y = batch[0]\n", |
| 143 | + " X_list.append(x)\n", |
| 144 | + " Y_list.append(y)\n", |
| 145 | + " _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n", |
| 146 | + " assert jnp.array_equal(_X, feats)\n", |
| 147 | + " assert jnp.array_equal(_Y, labels)\n", |
| 148 | + "\n", |
| 149 | + " dl = EpochIterator(feats, labels, batch_size=batch_size, shuffle=False, )\n", |
| 150 | + " for _ in range(2):\n", |
| 151 | + " X_list, Y_list = [], []\n", |
| 152 | + " for step, batch in dl.enumerate_epoch('np'):\n", |
| 153 | + " x, y = batch[0]\n", |
| 154 | + " X_list.append(x)\n", |
| 155 | + " Y_list.append(y)\n", |
| 156 | + " _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n", |
| 157 | + " last_idx = len(X_list) * batch_size\n", |
| 158 | + " jnp.array_equal(_X, feats[: last_idx])\n", |
| 159 | + " jnp.array_equal(_Y, labels[: last_idx])\n", |
| 160 | + "\n", |
| 161 | + "\n", |
| 162 | + " dl_shuffle = EpochIterator(feats, labels, batch_size=batch_size, shuffle=True, )\n", |
| 163 | + " last_X, last_Y = jnp.array([]), jnp.array([])\n", |
| 164 | + " for _ in range(2):\n", |
| 165 | + " X_list, Y_list = [], []\n", |
| 166 | + " for step, batch in dl_shuffle.enumerate_epoch('np'):\n", |
| 167 | + " x, y = batch[0]\n", |
| 168 | + " assert jnp.array_equal(x[:, :1], y)\n", |
| 169 | + " X_list.append(x)\n", |
| 170 | + " Y_list.append(y)\n", |
| 171 | + " _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n", |
| 172 | + " not jnp.array_equal(_X, feats)\n", |
| 173 | + " not jnp.array_equal(_Y, labels)\n", |
| 174 | + " jnp.sum(_X) == jnp.sum(feats), \\\n", |
| 175 | + " f\"jnp.sum(_X)={jnp.sum(_X)}, jnp.sum(feats)={jnp.sum(feats)}\"\n", |
| 176 | + " not jnp.array_equal(_X, last_X)\n", |
| 177 | + " not jnp.array_equal(_Y, last_Y)\n", |
| 178 | + " last_X, last_Y = _X, _Y\n", |
| 179 | + "\n", |
| 180 | + "\n", |
| 181 | + " dl_shuffle = EpochIterator(feats, labels, batch_size=batch_size, shuffle=True, )\n", |
| 182 | + " for _ in range(2):\n", |
| 183 | + " X_list, Y_list = [], []\n", |
| 184 | + " for step, batch in dl_shuffle.enumerate_epoch('np'):\n", |
| 185 | + " x, y = batch[0]\n", |
| 186 | + " assert jnp.array_equal(x[:, :1], y)\n", |
| 187 | + " X_list.append(x)\n", |
| 188 | + " Y_list.append(y)\n", |
| 189 | + " _X, _Y = map(jnp.concatenate, (X_list, Y_list))\n", |
| 190 | + " not jnp.array_equal(_X, feats)\n", |
| 191 | + " not jnp.array_equal(_Y, labels)\n", |
| 192 | + " len(_X) == len(X_list) * batch_size\n" |
| 193 | + ] |
| 194 | + }, |
123 | 195 | {
|
124 | 196 | "cell_type": "code",
|
125 | 197 | "execution_count": null,
|
|
188 | 260 | " self.pose = 0 # record the current position in the dataset\n",
|
189 | 261 | " self._shuffle()\n",
|
190 | 262 | "\n",
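| + "        # cache the batch count once; __next__ checks self.pose against it\n", |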
|
| 263 | + " self.num_batches = len(self)\n", |
| 264 | + "\n", |
191 | 265 | " def _shuffle(self):\n",
|
192 | 266 | " if self.shuffle:\n",
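| + "            # draw a fresh PRNG key so every epoch gets a different permutation\n", |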
|
193 | 267 | "            self.indices = jax.random.permutation(next(self.keys), self.indices)\n",
|
|
196 | 270 | "        self.pose = 0\n",

197 | 271 | "        self._shuffle()\n",

198 | 272 | "        raise StopIteration\n",
|
199 |
| - "\n", |
| 273 | + " \n", |
200 | 274 | "    def __len__(self):\n",

201 | 275 | "        if self.drop_last:\n",

202 | 276 | "            batches = len(self.dataset) // self.batch_size  # get the floor of division\n",
|
|
205 | 279 | "        return batches\n",
|
206 | 280 | "\n",
|
207 | 281 | "    def __next__(self):\n",
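| + "        # emit batch number self.pose, or reset state and end the epoch\n", |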
|
208 |
| - " if self.pose + self.batch_size <= self.data_len:\n", |
209 |
| - " batch_indices = self.indices[self.pose: self.pose + self.batch_size]\n", |
| 282 | + " if self.pose < self.num_batches:\n", |
| 283 | + " batch_indices = self.indices[self.pose * self.batch_size: (self.pose + 1) * self.batch_size]\n", |
210 | 284 | " batch_data = self.dataset[batch_indices]\n",
|
211 |
| - " self.pose += self.batch_size\n", |
212 |
| - " return batch_data\n", |
213 |
| - " elif self.pose < self.data_len and not self.drop_last:\n", |
214 |
| - " batch_indices = self.indices[self.pose:]\n", |
215 |
| - " batch_data = self.dataset[batch_indices]\n", |
216 |
| - " self.pose += self.batch_size\n", |
| 285 | + " self.pose += 1\n", |
217 | 286 | " return batch_data\n",
|
218 | 287 | "        else:\n",

219 | 288 | "            self._stop_iteration()\n",

220 | 289 | "\n",

221 | 290 | "    def __iter__(self):\n",

222 |
| - "        return self"
| 291 | + "        return self\n"
223 | 292 | ]
|
224 | 293 | },
|
225 | 294 | {
|
226 | 295 | "cell_type": "code",
|
227 | 296 | "execution_count": null,
|
228 | 297 | "metadata": {},
|
229 |
| - "outputs": [], |
| 298 | + "outputs": [ |
| 299 | + { |
| 300 | + "name": "stderr", |
| 301 | + "output_type": "stream", |
| 302 | + "text": [ |
| 303 | + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" |
| 304 | + ] |
| 305 | + } |
| 306 | + ], |
230 | 307 | "source": [
|
231 | 308 | "#| hide\n",
|
232 | 309 | "test_dataloader(DataLoaderJax, samples=20, batch_size=12)\n",
|
233 | 310 | "test_dataloader(DataLoaderJax, samples=20, batch_size=10)\n",
|
234 |
| - "test_dataloader(DataLoaderJax, samples=11, batch_size=10)" |
| 311 | + "test_dataloader(DataLoaderJax, samples=11, batch_size=10)\n", |
| 312 | + "test_dataloader(DataLoaderJax, samples=40, batch_size=12)" |
| 313 | + ] |
| 314 | + }, |
| 315 | + { |
| 316 | + "cell_type": "code", |
| 317 | + "execution_count": null, |
| 318 | + "metadata": {}, |
| 319 | + "outputs": [], |
| 320 | + "source": [ |
| 321 | + "#| hide\n", |
| 322 | + "test_keras_dataloader(samples=20, batch_size=12)\n", |
| 323 | + "test_keras_dataloader(samples=20, batch_size=10)\n", |
| 324 | + "test_keras_dataloader(samples=11, batch_size=10)\n", |
| 325 | + "test_keras_dataloader(samples=40, batch_size=12)" |
| 326 | + ] |
| 327 | + }, |
| 328 | + { |
| 329 | + "cell_type": "code", |
| 330 | + "execution_count": null, |
| 331 | + "metadata": {}, |
| 332 | + "outputs": [ |
| 333 | + { |
| 334 | + "name": "stdout", |
| 335 | + "output_type": "stream", |
| 336 | + "text": [ |
| 337 | + "1.48 s ± 29.8 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)\n" |
| 338 | + ] |
| 339 | + } |
| 340 | + ], |
| 341 | + "source": [ |
| 342 | + "%%timeit -n 5 -r 3\n", |
| 343 | + "test_dataloader(DataLoaderJax, samples=1280, batch_size=10)" |
| 344 | + ] |
| 345 | + }, |
| 346 | + { |
| 347 | + "cell_type": "code", |
| 348 | + "execution_count": null, |
| 349 | + "metadata": {}, |
| 350 | + "outputs": [ |
| 351 | + { |
| 352 | + "name": "stdout", |
| 353 | + "output_type": "stream", |
| 354 | + "text": [ |
| 355 | + "301 ms ± 2.4 ms per loop (mean ± std. dev. of 3 runs, 5 loops each)\n" |
| 356 | + ] |
| 357 | + } |
| 358 | + ], |
| 359 | + "source": [ |
| 360 | + "#| hide\n", |
| 361 | + "%%timeit -n 5 -r 3\n", |
| 362 | + "test_keras_dataloader(samples=1280, batch_size=10)" |
235 | 363 | ]
|
236 | 364 | },
|
237 | 365 | {
|
|