diff --git a/jax_dataloader/core.py b/jax_dataloader/core.py index 8dd6315..50badd0 100644 --- a/jax_dataloader/core.py +++ b/jax_dataloader/core.py @@ -89,11 +89,11 @@ class DataLoader: def __init__( self, - dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset - backend: str, # Dataloader backend - batch_size: int = 1, # batch size - shuffle: bool = False, # if true, dataloader shuffles before sampling each batch - drop_last: bool = False, # drop last batches or not + dataset, # Dataset from which to load the data + backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset + batch_size: int = 1, # How many samples per batch to load + shuffle: bool = False, # If true, dataloader reshuffles every epoch + drop_last: bool = False, # If true, drop the last incomplete batch **kwargs ): dl_cls = _dispatch_dataloader(backend) diff --git a/nbs/core.ipynb b/nbs/core.ipynb index 46f1e14..67ea166 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -38,18 +38,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-12-26 15:13:36.437449: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-26 15:13:36.437528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-26 15:13:36.439236: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2023-12-26 15:13:37.500782: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" - ] - } - ], + "outputs": [], "source": [ "#| export\n", "from __future__ import print_function, division, annotations\n", @@ -179,11 +168,11 @@ "\n", " def __init__(\n", " self,\n", - " dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset\n", - " backend: str, # Dataloader backend\n", - " batch_size: int = 1, # batch size\n", - " shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n", - " drop_last: bool = False, # drop last batches or not\n", + " dataset, # Dataset from which to load the data\n", + " backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset\n", + " batch_size: int = 1, # How many samples per batch to load\n", + " shuffle: bool = False, # If true, dataloader reshuffles every epoch\n", + " drop_last: bool = False, # If true, drop the last incomplete batch\n", " **kwargs\n", " ):\n", " dl_cls = _dispatch_dataloader(backend)\n",