@@ -18,6 +18,7 @@ def get_batch(batch):
18
18
# %% ../nbs/tests.ipynb 4
19
19
def test_no_shuffle (cls , ds , batch_size : int , feats , labels ):
20
20
dl = cls (ds , batch_size = batch_size , shuffle = False )
21
+ assert len (dl ) == len (feats ) // batch_size + 1
21
22
for _ in range (2 ):
22
23
X_list , Y_list = [], []
23
24
for batch in dl :
@@ -31,6 +32,7 @@ def test_no_shuffle(cls, ds, batch_size: int, feats, labels):
31
32
# %% ../nbs/tests.ipynb 5
32
33
def test_no_shuffle_drop_last (cls , ds , batch_size : int , feats , labels ):
33
34
dl = cls (ds , batch_size = batch_size , shuffle = False , drop_last = True )
35
+ assert len (dl ) == len (feats ) // batch_size
34
36
for _ in range (2 ):
35
37
X_list , Y_list = [], []
36
38
for batch in dl :
@@ -46,6 +48,7 @@ def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):
46
48
def test_shuffle (cls , ds , batch_size : int , feats , labels ):
47
49
dl = cls (ds , batch_size = batch_size , shuffle = True , drop_last = False )
48
50
last_X , last_Y = jnp .array ([]), jnp .array ([])
51
+ assert len (dl ) == len (feats ) // batch_size + 1
49
52
for _ in range (2 ):
50
53
X_list , Y_list = [], []
51
54
for batch in dl :
@@ -65,6 +68,7 @@ def test_shuffle(cls, ds, batch_size: int, feats, labels):
65
68
# %% ../nbs/tests.ipynb 7
66
69
def test_shuffle_drop_last (cls , ds , batch_size : int , feats , labels ):
67
70
dl = cls (ds , batch_size = batch_size , shuffle = True , drop_last = True )
71
+ assert len (dl ) == len (feats ) // batch_size
68
72
for _ in range (2 ):
69
73
X_list , Y_list = [], []
70
74
for batch in dl :
0 commit comments