From a073eda63fde8ffa53a3b4b887e2f608f21d4d45 Mon Sep 17 00:00:00 2001 From: Jason Date: Thu, 13 Apr 2017 22:13:19 -0500 Subject: [PATCH] allow batch creator to sample with replacement for big batches on small datasets --- aiutils/data/batch_creators.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/aiutils/data/batch_creators.py b/aiutils/data/batch_creators.py index 46428be..1136a20 100644 --- a/aiutils/data/batch_creators.py +++ b/aiutils/data/batch_creators.py @@ -3,6 +3,7 @@ from multiprocessing import Process, Queue from builtins import range import itertools +from math import ceil def sequential(batch_size, num_samples, num_epochs=1, offset=0): @@ -35,11 +36,15 @@ def random(batch_size, num_samples, num_epochs=1, offset=0): Output: Yields a generator with random sampling with replacement """ - for epoch in range(num_epochs) if num_epochs > 0 else itertools.count(): - indices = np.random.permutation(num_samples) + offset - indices = indices.tolist() - for i in range(0, num_samples - batch_size + 1, batch_size): + if num_samples < batch_size: + indices = np.random.permutation(list(range(num_samples))*int(ceil(float(batch_size)/num_samples))) + offset + indices = indices.tolist() + else: + indices = np.random.permutation(num_samples) + offset + indices = indices.tolist() + + for i in range(0, len(indices) - batch_size + 1, batch_size): yield indices[i:i + batch_size]