Skip to content

Commit

Permalink
Make repeat_source work with subprocess (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 15, 2025
1 parent 9002aaf commit f842322
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
71 changes: 50 additions & 21 deletions src/spdl/source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Iterator,
Sequence,
)
from typing import TypeVar
from typing import Any, TypeVar

from ._type import IterableWithShuffle

Expand Down Expand Up @@ -364,26 +364,7 @@ def __iter__(self) -> Iterator[T]:
################################################################################


def repeat_source(
src: Iterable[T] | IterableWithShuffle[T],
epoch: int = 0,
) -> Iterator[T]:
"""Convert an iterable into an infinite iterator with optional shuffling.
Roughly equivalent to the following code snippet.
.. code-block::
while True:
if hasattr(src, "shuffle"):
src.shuffle(seed=epoch)
yield from src
epoch += 1
Args:
src: The source to repeat.
epoch: The epoch number to start with.
"""
def _repeat(src: Iterable[T] | IterableWithShuffle[T], epoch: int) -> Iterator[T]:
while True:
_LG.info("Starting source epoch %d.", epoch)
t0 = time.monotonic()
Expand All @@ -405,3 +386,51 @@ def repeat_source(
qps,
)
epoch += 1


class _RepeatIterator(Iterator[T]):
def __init__(
self,
src: Iterable[T] | IterableWithShuffle[T],
epoch: int = 0,
) -> None:
self.src = src
self.epoch = epoch
self._iter: Iterator[T] | None = None

def __iter__(self) -> Iterator[T]:
return self

def __getstate__(self) -> dict[str, Any]: # pyre-ignore: [11]
if self._iter is not None:
raise ValueError("Cannot pickle after iteration is started.")
return self.__dict__

def __next__(self) -> T:
if self._iter is None:
self._iter = _repeat(self.src, self.epoch)
return next(self._iter)


def repeat_source(
src: Iterable[T] | IterableWithShuffle[T],
epoch: int = 0,
) -> Iterator[T]:
"""Convert an iterable into an infinite iterator with optional shuffling.
Roughly equivalent to the following code snippet.
.. code-block::
while True:
if hasattr(src, "shuffle"):
src.shuffle(seed=epoch)
yield from src
epoch += 1
Args:
src: The source to repeat.
epoch: The epoch number to start with.
"""
# Returning object so that it can be passed to a subprocess.
return _RepeatIterator(src, epoch)
15 changes: 15 additions & 0 deletions tests/spdl_unittest/dataloader/iterator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-unsafe

import os.path
import pickle
import random
import tempfile
import time
Expand Down Expand Up @@ -242,6 +243,20 @@ def __iter__(self) -> Iterator[int]:
assert next(gen) == 2


def test_repeat_source_picklable():
"""repeat_source is picklable."""

src = list(range(10))
src = repeat_source(src)

serialized = pickle.dumps(src)
src2 = pickle.loads(serialized)

for _ in range(3):
for i in range(10):
assert next(src) == next(src2) == i


def iter_range(n: int) -> Iterable[int]:
yield from range(n)

Expand Down

0 comments on commit f842322

Please sign in to comment.