diff --git a/aiutils/data/batch_creators.py b/aiutils/data/batch_creators.py index 46428be..1136a20 100644 --- a/aiutils/data/batch_creators.py +++ b/aiutils/data/batch_creators.py @@ -3,6 +3,7 @@ from multiprocessing import Process, Queue from builtins import range import itertools +from math import ceil def sequential(batch_size, num_samples, num_epochs=1, offset=0): @@ -35,11 +36,15 @@ def random(batch_size, num_samples, num_epochs=1, offset=0): Output: Yields a generator with random sampling with replacement """ - for epoch in range(num_epochs) if num_epochs > 0 else itertools.count(): - indices = np.random.permutation(num_samples) + offset - indices = indices.tolist() - for i in range(0, num_samples - batch_size + 1, batch_size): + if num_samples < batch_size: + indices = np.random.permutation(list(range(num_samples))*int(ceil(float(batch_size)/num_samples))) + offset + indices = indices.tolist() + else: + indices = np.random.permutation(num_samples) + offset + indices = indices.tolist() + + for i in range(0, len(indices) - batch_size + 1, batch_size): yield indices[i:i + batch_size]