Skip to content

Commit

Permalink
Improve docs in main DataLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Feb 16, 2024
1 parent 66437b4 commit a05aef7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
10 changes: 5 additions & 5 deletions jax_dataloader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ class DataLoader:

def __init__(
self,
dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset
backend: str, # Dataloader backend
batch_size: int = 1, # batch size
shuffle: bool = False, # if true, dataloader shuffles before sampling each batch
drop_last: bool = False, # drop last batches or not
dataset, # Dataset from which to load the data
backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset
batch_size: int = 1, # How many samples per batch to load
shuffle: bool = False, # If true, dataloader reshuffles every epoch
drop_last: bool = False, # If true, drop the last incomplete batch
**kwargs
):
dl_cls = _dispatch_dataloader(backend)
Expand Down
23 changes: 6 additions & 17 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-12-26 15:13:36.437449: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2023-12-26 15:13:36.437528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2023-12-26 15:13:36.439236: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2023-12-26 15:13:37.500782: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"outputs": [],
"source": [
"#| export\n",
"from __future__ import print_function, division, annotations\n",
Expand Down Expand Up @@ -179,11 +168,11 @@
"\n",
" def __init__(\n",
" self,\n",
" dataset, # Dataset or Pytorch Dataset or HuggingFace Dataset\n",
" backend: str, # Dataloader backend\n",
" batch_size: int = 1, # batch size\n",
" shuffle: bool = False, # if true, dataloader shuffles before sampling each batch\n",
" drop_last: bool = False, # drop last batches or not\n",
" dataset, # Dataset from which to load the data\n",
" backend: Literal['jax', 'pytorch', 'tensorflow'], # Dataloader backend to load the dataset\n",
" batch_size: int = 1, # How many samples per batch to load\n",
" shuffle: bool = False, # If true, dataloader reshuffles every epoch\n",
" drop_last: bool = False, # If true, drop the last incomplete batch\n",
" **kwargs\n",
" ):\n",
" dl_cls = _dispatch_dataloader(backend)\n",
Expand Down

0 comments on commit a05aef7

Please sign in to comment.