
Commit fcec10f

Merge pull request #36 from BirkhoffG/reproducibility
Implement manual_seed to set the global seed for reproducibility
2 parents 6f4afe2 + 5742bbb commit fcec10f

File tree

7 files changed: +122, -28 lines
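The change is small: `jdl.manual_seed(seed)` stores a library-wide seed, and the PyTorch-backend loader reads it when it builds its sampler. A minimal usage sketch follows, assuming the library's top-level `jdl.ArrayDataset` and `jdl.DataLoader(dataset, backend, ...)` exports and the `'pytorch'` backend name; none of these appear in this diff, so treat them as assumptions, and the batch layout is assumed to be a `(x, y)` tuple as in tests.py.

import numpy as np
import jax_dataloader as jdl

# Assumed public API: jdl.ArrayDataset and jdl.DataLoader are not part of this
# diff; only jdl.manual_seed is introduced here.
X = np.arange(100).reshape(100, 1)
y = np.arange(100).reshape(100, 1)
ds = jdl.ArrayDataset(X, y)

jdl.manual_seed(0)  # set the library-wide global seed
dl_a = jdl.DataLoader(ds, 'pytorch', batch_size=10, shuffle=True)
dl_b = jdl.DataLoader(ds, 'pytorch', batch_size=10, shuffle=True)

# Both loaders were built under the same global seed, so they draw the same
# shuffle order (this is what test_shuffle_reproducible asserts below).
batches_a = [np.asarray(x) for x, _ in dl_a]
batches_b = [np.asarray(x) for x, _ in dl_b]
assert all(np.array_equal(a, b) for a, b in zip(batches_a, batches_b))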

jax_dataloader/_modidx.py

Lines changed: 5 additions & 2 deletions
@@ -113,7 +113,9 @@
                                             'jax_dataloader/tests.py'),
     'jax_dataloader.tests.test_shuffle': ('tests.html#test_shuffle', 'jax_dataloader/tests.py'),
     'jax_dataloader.tests.test_shuffle_drop_last': ( 'tests.html#test_shuffle_drop_last',
-                                                     'jax_dataloader/tests.py')},
+                                                     'jax_dataloader/tests.py'),
+    'jax_dataloader.tests.test_shuffle_reproducible': ( 'tests.html#test_shuffle_reproducible',
+                                                        'jax_dataloader/tests.py')},
 'jax_dataloader.utils': { 'jax_dataloader.utils.Config': ('utils.html#config', 'jax_dataloader/utils.py'),
     'jax_dataloader.utils.Config.default': ('utils.html#config.default', 'jax_dataloader/utils.py'),
     'jax_dataloader.utils.asnumpy': ('utils.html#asnumpy', 'jax_dataloader/utils.py'),
@@ -125,4 +127,5 @@
                                   'jax_dataloader/utils.py'),
     'jax_dataloader.utils.get_config': ('utils.html#get_config', 'jax_dataloader/utils.py'),
     'jax_dataloader.utils.has_pytorch_tensor': ( 'utils.html#has_pytorch_tensor',
-                                                 'jax_dataloader/utils.py')}}}
+                                                 'jax_dataloader/utils.py'),
+    'jax_dataloader.utils.manual_seed': ('utils.html#manual_seed', 'jax_dataloader/utils.py')}}}

jax_dataloader/loaders/torch.py

Lines changed: 9 additions & 2 deletions
@@ -5,7 +5,7 @@
 from ..imports import *
 from . import BaseDataLoader
 from ..datasets import Dataset, ArrayDataset, JAXDataset
-from ..utils import check_pytorch_installed
+from ..utils import check_pytorch_installed, get_config
 from ..tests import *
 from jax.tree_util import tree_map
 import warnings
@@ -53,13 +53,20 @@ def __init__(
         super().__init__(dataset, batch_size, shuffle, drop_last)
         check_pytorch_installed()
         from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler
+        import torch

         if 'sampler' in kwargs:
             warnings.warn("`sampler` is currently not supported. We will ignore it and use `shuffle` instead.")
             del kwargs['sampler']

+        # convert to torch dataset
         dataset = to_torch_dataset(dataset)
-        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
+        # init batch sampler
+        generator = torch.Generator().manual_seed(get_config().global_seed)
+        if shuffle:
+            sampler = RandomSampler(dataset, generator=generator)
+        else:
+            sampler = SequentialSampler(dataset)
         batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

         self.dataloader = torch_data.DataLoader(
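The reproducibility mechanism for the PyTorch backend is simply the seeded `torch.Generator` handed to `RandomSampler`. The standalone sketch below uses plain PyTorch only, with an illustrative toy dataset and seed value, to show why two loaders built from the same seed iterate in the same order.

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

# Toy dataset of 20 integers; real loaders wrap the user's dataset instead.
ds = TensorDataset(torch.arange(20))

def make_loader(seed: int) -> DataLoader:
    # Seeding the Generator fixes the permutation that RandomSampler draws.
    generator = torch.Generator().manual_seed(seed)
    sampler = RandomSampler(ds, generator=generator)
    batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
    return DataLoader(ds, batch_sampler=batch_sampler)

order_a = [batch[0].tolist() for batch in make_loader(seed=42)]
order_b = [batch[0].tolist() for batch in make_loader(seed=42)]
assert order_a == order_b  # same seed, same shuffle order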

jax_dataloader/tests.py

Lines changed: 28 additions & 1 deletion
@@ -4,9 +4,10 @@
 from __future__ import print_function, division, annotations
 from .imports import *
 from .datasets import ArrayDataset
+import jax_dataloader as jdl

 # %% auto 0
-__all__ = ['test_dataloader']
+__all__ = ['test_shuffle_reproducible', 'test_dataloader']

 # %% ../nbs/tests.ipynb 3
 def get_batch(batch):
@@ -82,6 +83,31 @@ def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
     assert len(_X) == len(X_list) * batch_size

 # %% ../nbs/tests.ipynb 8
+def test_shuffle_reproducible(cls, ds, batch_size: int, feats, labels):
+    """Test that the shuffle is reproducible"""
+    def _iter_dataloader(dataloader):
+        X_list, Y_list = [], []
+        for batch in dataloader:
+            x, y = get_batch(batch)
+            X_list.append(x)
+            Y_list.append(y)
+        return X_list, Y_list
+
+    # Test that the shuffle is reproducible
+    jdl.manual_seed(0)
+    dl_1 = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+    X_list_1, Y_list_1 = _iter_dataloader(dl_1)
+    dl_2 = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+    X_list_2, Y_list_2 = _iter_dataloader(dl_2)
+    assert jnp.array_equal(jnp.concatenate(X_list_1), jnp.concatenate(X_list_2))
+
+    # Test that the shuffle is different if the seed is different
+    jdl.manual_seed(1234)
+    dl_3 = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+    X_list_3, Y_list_3 = _iter_dataloader(dl_3)
+    assert not jnp.array_equal(jnp.concatenate(X_list_1), jnp.concatenate(X_list_3))
+
+# %% ../nbs/tests.ipynb 9
 def test_dataloader(cls, ds_type='jax', samples=1000, batch_size=12):
     feats = np.arange(samples).repeat(10).reshape(samples, 10)
     labels = np.arange(samples).reshape(samples, 1)
@@ -102,3 +128,4 @@ def test_dataloader(cls, ds_type='jax', samples=1000, batch_size=12):
     test_no_shuffle_drop_last(cls, ds, batch_size, feats, labels)
     test_shuffle(cls, ds, batch_size, feats, labels)
     test_shuffle_drop_last(cls, ds, batch_size, feats, labels)
+    test_shuffle_reproducible(cls, ds, batch_size, feats, labels)
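Because `test_dataloader` now calls `test_shuffle_reproducible`, any loader class passed to it is also checked for seed-reproducible shuffling. A hypothetical driver is sketched below; the `DataLoaderPytorch` class name and import path are assumptions for illustration and are not named in this hunk.

import jax_dataloader as jdl
from jax_dataloader.tests import test_dataloader
from jax_dataloader.loaders.torch import DataLoaderPytorch  # class name assumed

jdl.manual_seed(0)
# Runs the shuffle/no-shuffle/drop_last checks plus the new
# test_shuffle_reproducible on a small synthetic dataset.
test_dataloader(DataLoaderPytorch, ds_type='jax', samples=100, batch_size=12)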

jax_dataloader/utils.py

Lines changed: 12 additions & 7 deletions
@@ -7,8 +7,8 @@
 import collections

 # %% auto 0
-__all__ = ['Config', 'get_config', 'check_pytorch_installed', 'has_pytorch_tensor', 'check_hf_installed', 'check_tf_installed',
-           'asnumpy']
+__all__ = ['Config', 'get_config', 'manual_seed', 'check_pytorch_installed', 'has_pytorch_tensor', 'check_hf_installed',
+           'check_tf_installed', 'asnumpy']

 # %% ../nbs/utils.ipynb 6
 @dataclass
@@ -28,15 +28,20 @@ def default(cls) -> Config:
 def get_config() -> Config:
     return main_config

-# %% ../nbs/utils.ipynb 10
+# %% ../nbs/utils.ipynb 9
+def manual_seed(seed: int):
+    """Set the seed for the library"""
+    main_config.global_seed = seed
+
+# %% ../nbs/utils.ipynb 12
 def check_pytorch_installed():
     if torch_data is None:
         raise ModuleNotFoundError("`pytorch` library needs to be installed. "
                                   "Try `pip install torch`. Please refer to pytorch documentation for details: "
                                   "https://pytorch.org/get-started/.")


-# %% ../nbs/utils.ipynb 12
+# %% ../nbs/utils.ipynb 14
 def has_pytorch_tensor(batch) -> bool:
     if isinstance(batch[0], torch.Tensor):
         return True
@@ -46,21 +51,21 @@ def has_pytorch_tensor(batch) -> bool:
     else:
         return False

-# %% ../nbs/utils.ipynb 13
+# %% ../nbs/utils.ipynb 15
 def check_hf_installed():
     if hf_datasets is None:
         raise ModuleNotFoundError("`datasets` library needs to be installed. "
                                   "Try `pip install datasets`. Please refer to huggingface documentation for details: "
                                   "https://huggingface.co/docs/datasets/installation.html.")

-# %% ../nbs/utils.ipynb 15
+# %% ../nbs/utils.ipynb 17
 def check_tf_installed():
     if tf is None:
         raise ModuleNotFoundError("`tensorflow` library needs to be installed. "
                                   "Try `pip install tensorflow`. Please refer to tensorflow documentation for details: "
                                   "https://www.tensorflow.org/install/pip.")

-# %% ../nbs/utils.ipynb 18
+# %% ../nbs/utils.ipynb 20
 def asnumpy(x) -> np.ndarray:
     if isinstance(x, np.ndarray):
         return x
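The seed lives on the module-level `main_config` object that `get_config()` returns; `manual_seed` simply mutates its `global_seed` field. A self-contained sketch of that pattern follows; only `global_seed` is visible in this diff, so the other `Config` fields and the default value shown here are assumptions.

from __future__ import annotations
from dataclasses import dataclass

@dataclass
class Config:
    global_seed: int = 1234  # default value assumed for illustration

    @classmethod
    def default(cls) -> Config:
        return cls()

# One shared config instance for the whole library.
main_config = Config.default()

def get_config() -> Config:
    return main_config

def manual_seed(seed: int):
    """Set the seed for the library"""
    main_config.global_seed = seed

# Loaders later read get_config().global_seed when constructing their samplers.
manual_seed(11)
assert get_config().global_seed == 11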

nbs/loader.torch.ipynb

Lines changed: 10 additions & 14 deletions
@@ -27,25 +27,14 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2023-12-26 09:48:50.763338: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
-      "2023-12-26 09:48:50.763415: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
-      "2023-12-26 09:48:50.777782: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
-      "2023-12-26 09:48:52.348997: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "#| export\n",
     "from __future__ import print_function, division, annotations\n",
     "from jax_dataloader.imports import *\n",
     "from jax_dataloader.loaders import BaseDataLoader\n",
     "from jax_dataloader.datasets import Dataset, ArrayDataset, JAXDataset\n",
-    "from jax_dataloader.utils import check_pytorch_installed\n",
+    "from jax_dataloader.utils import check_pytorch_installed, get_config\n",
     "from jax_dataloader.tests import *\n",
     "from jax.tree_util import tree_map\n",
     "import warnings\n"
@@ -129,13 +118,20 @@
    "        super().__init__(dataset, batch_size, shuffle, drop_last)\n",
    "        check_pytorch_installed()\n",
    "        from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler\n",
+   "        import torch\n",
    "\n",
    "        if 'sampler' in kwargs:\n",
    "            warnings.warn(\"`sampler` is currently not supported. We will ignore it and use `shuffle` instead.\")\n",
    "            del kwargs['sampler']\n",
    "\n",
+   "        # convert to torch dataset\n",
    "        dataset = to_torch_dataset(dataset)\n",
-   "        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)\n",
+   "        # init batch sampler\n",
+   "        generator = torch.Generator().manual_seed(get_config().global_seed)\n",
+   "        if shuffle: \n",
+   "            sampler = RandomSampler(dataset, generator=generator)\n",
+   "        else: \n",
+   "            sampler = SequentialSampler(dataset)\n",
    "        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)\n",
    "\n",
    "        self.dataloader = torch_data.DataLoader(\n",

nbs/tests.ipynb

Lines changed: 36 additions & 2 deletions
@@ -32,7 +32,8 @@
     "#| export\n",
     "from __future__ import print_function, division, annotations\n",
     "from jax_dataloader.imports import *\n",
-    "from jax_dataloader.datasets import ArrayDataset"
+    "from jax_dataloader.datasets import ArrayDataset\n",
+    "import jax_dataloader as jdl"
    ]
   },
   {
@@ -143,6 +144,38 @@
     "    assert len(_X) == len(X_list) * batch_size"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "def test_shuffle_reproducible(cls, ds, batch_size: int, feats, labels):\n",
+    "    \"\"\"Test that the shuffle is reproducible\"\"\"\n",
+    "    def _iter_dataloader(dataloader):\n",
+    "        X_list, Y_list = [], []\n",
+    "        for batch in dataloader:\n",
+    "            x, y = get_batch(batch)\n",
+    "            X_list.append(x)\n",
+    "            Y_list.append(y)\n",
+    "        return X_list, Y_list\n",
+    "\n",
+    "    # Test that the shuffle is reproducible\n",
+    "    jdl.manual_seed(0)\n",
+    "    dl_1 = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n",
+    "    X_list_1, Y_list_1 = _iter_dataloader(dl_1)\n",
+    "    dl_2 = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n",
+    "    X_list_2, Y_list_2 = _iter_dataloader(dl_2)\n",
+    "    assert jnp.array_equal(jnp.concatenate(X_list_1), jnp.concatenate(X_list_2))\n",
+    "\n",
+    "    # Test that the shuffle is different if the seed is different\n",
+    "    jdl.manual_seed(1234)\n",
+    "    dl_3 = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n",
+    "    X_list_3, Y_list_3 = _iter_dataloader(dl_3)\n",
+    "    assert not jnp.array_equal(jnp.concatenate(X_list_1), jnp.concatenate(X_list_3))"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -169,7 +202,8 @@
    "    test_no_shuffle(cls, ds, batch_size, feats, labels)\n",
    "    test_no_shuffle_drop_last(cls, ds, batch_size, feats, labels)\n",
    "    test_shuffle(cls, ds, batch_size, feats, labels)\n",
-    "    test_shuffle_drop_last(cls, ds, batch_size, feats, labels)"
+    "    test_shuffle_drop_last(cls, ds, batch_size, feats, labels)\n",
+    "    test_shuffle_reproducible(cls, ds, batch_size, feats, labels)"
    ]
   },
   {

nbs/utils.ipynb

Lines changed: 22 additions & 0 deletions
@@ -99,6 +99,28 @@
     "    return main_config"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "def manual_seed(seed: int):\n",
+    "    \"\"\"Set the seed for the library\"\"\"\n",
+    "    main_config.global_seed = seed"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "manual_seed(11)\n",
+    "assert get_config().global_seed == 11"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
