Skip to content

Commit

Permalink
Make data.Serial return a generator function instead of a generator so it can compose, add default arg for easy calling.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 323283358
  • Loading branch information
Lukasz Kaiser authored and copybara-github committed Jul 27, 2020
1 parent a2497cb commit 0f8d89e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
11 changes: 6 additions & 5 deletions trax/data/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,12 @@


def Serial(*fns): # pylint: disable=invalid-name
"""Creates an input pipeline by running all functions one after another."""
generator = None
for f in fastmath.tree_flatten(fns):
generator = f(generator)
return generator
"""Combines generator functions into one that runs them in turn."""
def composed_fns(generator=None):
for f in fastmath.tree_flatten(fns):
generator = f(generator)
return generator
return composed_fns


def Log(n_steps_per_example=1, only_shapes=True): # pylint: disable=invalid-name
Expand Down
13 changes: 11 additions & 2 deletions trax/data/inputs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,16 @@ def test_batch_data(self):
def test_serial(self):
dataset = lambda _: ((i, i+1) for i in range(10))
batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10))
batch = next(batches)
batch = next(batches())
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (10,))

def test_serial_composes(self):
"""Check that data.Serial works inside another data.Serial."""
dataset = lambda _: ((i, i+1) for i in range(10))
serial1 = data.Serial(dataset, data.Shuffle(3))
batches = data.Serial(serial1, data.Batch(10))
batch = next(batches())
self.assertLen(batch, 2)
self.assertEqual(batch[0].shape, (10,))

Expand All @@ -88,7 +97,7 @@ def test_serial_with_python(self):
lambda g: filter(lambda x: x[0] % 2 == 1, g),
data.Batch(2)
)
batch = next(batches)
batch = next(batches())
self.assertLen(batch, 2)
(xs, ys) = batch
# First tuple after filtering is (1, 3) = (1, 2+1).
Expand Down
3 changes: 2 additions & 1 deletion trax/data/tf_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def select_from(example):
dataset = dataset.map(select_from)
dataset = dataset.repeat()

def gen(unused_arg):
def gen(generator=None):
del generator
for example in fastmath.dataset_as_numpy(dataset):
yield example
return gen
Expand Down

0 comments on commit 0f8d89e

Please sign in to comment.