Skip to content

Commit 95fc759

Browse files
committed
Update test cases for testing num of batches
1 parent a258074 commit 95fc759

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

jax_dataloader/tests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def get_batch(batch):
1818
# %% ../nbs/tests.ipynb 4
1919
def test_no_shuffle(cls, ds, batch_size: int, feats, labels):
2020
dl = cls(ds, batch_size=batch_size, shuffle=False)
21+
assert len(dl) == len(feats) // batch_size + 1
2122
for _ in range(2):
2223
X_list, Y_list = [], []
2324
for batch in dl:
@@ -31,6 +32,7 @@ def test_no_shuffle(cls, ds, batch_size: int, feats, labels):
3132
# %% ../nbs/tests.ipynb 5
3233
def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
3334
dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)
35+
assert len(dl) == len(feats) // batch_size
3436
for _ in range(2):
3537
X_list, Y_list = [], []
3638
for batch in dl:
@@ -46,6 +48,7 @@ def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
4648
def test_shuffle(cls, ds, batch_size: int, feats, labels):
4749
dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)
4850
last_X, last_Y = jnp.array([]), jnp.array([])
51+
assert len(dl) == len(feats) // batch_size + 1
4952
for _ in range(2):
5053
X_list, Y_list = [], []
5154
for batch in dl:
@@ -65,6 +68,7 @@ def test_shuffle(cls, ds, batch_size: int, feats, labels):
6568
# %% ../nbs/tests.ipynb 7
6669
def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
6770
dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)
71+
assert len(dl) == len(feats) // batch_size
6872
for _ in range(2):
6973
X_list, Y_list = [], []
7074
for batch in dl:

nbs/tests.ipynb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"#| exporti\n",
5959
"def test_no_shuffle(cls, ds, batch_size: int, feats, labels):\n",
6060
" dl = cls(ds, batch_size=batch_size, shuffle=False)\n",
61+
" assert len(dl) == len(feats) // batch_size + 1\n",
6162
" for _ in range(2):\n",
6263
" X_list, Y_list = [], []\n",
6364
" for batch in dl:\n",
@@ -78,6 +79,7 @@
7879
"#| exporti\n",
7980
"def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):\n",
8081
" dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)\n",
82+
" assert len(dl) == len(feats) // batch_size\n",
8183
" for _ in range(2):\n",
8284
" X_list, Y_list = [], []\n",
8385
" for batch in dl:\n",
@@ -100,6 +102,7 @@
100102
"def test_shuffle(cls, ds, batch_size: int, feats, labels):\n",
101103
" dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n",
102104
" last_X, last_Y = jnp.array([]), jnp.array([])\n",
105+
" assert len(dl) == len(feats) // batch_size + 1\n",
103106
" for _ in range(2):\n",
104107
" X_list, Y_list = [], []\n",
105108
" for batch in dl:\n",
@@ -126,6 +129,7 @@
126129
"#| exporti\n",
127130
"def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):\n",
128131
" dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)\n",
132+
" assert len(dl) == len(feats) // batch_size\n",
129133
" for _ in range(2):\n",
130134
" X_list, Y_list = [], []\n",
131135
" for batch in dl:\n",

0 commit comments

Comments
 (0)