From f842322aa5d28674ca78ac75510c547d2fa330b3 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 15 Jan 2025 07:56:01 -0500 Subject: [PATCH] Make repeat_source work with subprocess (#325) --- src/spdl/source/utils.py | 71 +++++++++++++------ .../spdl_unittest/dataloader/iterator_test.py | 15 ++++ 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/src/spdl/source/utils.py b/src/spdl/source/utils.py index 9096e2fa..805ca4bc 100644 --- a/src/spdl/source/utils.py +++ b/src/spdl/source/utils.py @@ -21,7 +21,7 @@ Iterator, Sequence, ) -from typing import TypeVar +from typing import Any, TypeVar from ._type import IterableWithShuffle @@ -364,26 +364,7 @@ def __iter__(self) -> Iterator[T]: ################################################################################ -def repeat_source( - src: Iterable[T] | IterableWithShuffle[T], - epoch: int = 0, -) -> Iterator[T]: - """Convert an iterable into an infinite iterator with optional shuffling. - - Roughly equivalent to the following code snippet. - - .. code-block:: - - while True: - if hasattr(src, "shuffle"): - src.shuffle(seed=epoch) - yield from src - epoch += 1 - - Args: - src: The source to repeat. - epoch: The epoch number to start with. - """ +def _repeat(src: Iterable[T] | IterableWithShuffle[T], epoch: int) -> Iterator[T]: while True: _LG.info("Starting source epoch %d.", epoch) t0 = time.monotonic() @@ -405,3 +386,51 @@ def repeat_source( qps, ) epoch += 1 + + +class _RepeatIterator(Iterator[T]): + def __init__( + self, + src: Iterable[T] | IterableWithShuffle[T], + epoch: int = 0, + ) -> None: + self.src = src + self.epoch = epoch + self._iter: Iterator[T] | None = None + + def __iter__(self) -> Iterator[T]: + return self + + def __getstate__(self) -> dict[str, Any]: # pyre-ignore: [11] + if self._iter is not None: + raise ValueError("Cannot pickle after iteration is started.") + return self.__dict__ + + def __next__(self) -> T: + if self._iter is None: + self._iter = _repeat(self.src, self.epoch) + return next(self._iter) + + +def repeat_source( + src: Iterable[T] | IterableWithShuffle[T], + epoch: int = 0, +) -> Iterator[T]: + """Convert an iterable into an infinite iterator with optional shuffling. + + Roughly equivalent to the following code snippet. + + .. code-block:: + + while True: + if hasattr(src, "shuffle"): + src.shuffle(seed=epoch) + yield from src + epoch += 1 + + Args: + src: The source to repeat. + epoch: The epoch number to start with. + """ + # Returning object so that it can be passed to a subprocess. + return _RepeatIterator(src, epoch) diff --git a/tests/spdl_unittest/dataloader/iterator_test.py b/tests/spdl_unittest/dataloader/iterator_test.py index 950def86..f44f6dd3 100644 --- a/tests/spdl_unittest/dataloader/iterator_test.py +++ b/tests/spdl_unittest/dataloader/iterator_test.py @@ -7,6 +7,7 @@ # pyre-unsafe import os.path +import pickle import random import tempfile import time @@ -242,6 +243,20 @@ def __iter__(self) -> Iterator[int]: assert next(gen) == 2 +def test_repeat_source_picklable(): + """repeat_source is picklable.""" + + src = list(range(10)) + src = repeat_source(src) + + serialized = pickle.dumps(src) + src2 = pickle.loads(serialized) + + for _ in range(3): + for i in range(10): + assert next(src) == next(src2) == i + + def iter_range(n: int) -> Iterable[int]: yield from range(n)