Skip to content

Commit f842322

Browse files
authored
Make repeat_source work with subprocess (#325)
1 parent 9002aaf commit f842322

File tree

2 files changed

+65
-21
lines changed

2 files changed

+65
-21
lines changed

src/spdl/source/utils.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Iterator,
2222
Sequence,
2323
)
24-
from typing import TypeVar
24+
from typing import Any, TypeVar
2525

2626
from ._type import IterableWithShuffle
2727

@@ -364,26 +364,7 @@ def __iter__(self) -> Iterator[T]:
364364
################################################################################
365365

366366

367-
def repeat_source(
368-
src: Iterable[T] | IterableWithShuffle[T],
369-
epoch: int = 0,
370-
) -> Iterator[T]:
371-
"""Convert an iterable into an infinite iterator with optional shuffling.
372-
373-
Roughly equivalent to the following code snippet.
374-
375-
.. code-block::
376-
377-
while True:
378-
if hasattr(src, "shuffle"):
379-
src.shuffle(seed=epoch)
380-
yield from src
381-
epoch += 1
382-
383-
Args:
384-
src: The source to repeat.
385-
epoch: The epoch number to start with.
386-
"""
367+
def _repeat(src: Iterable[T] | IterableWithShuffle[T], epoch: int) -> Iterator[T]:
387368
while True:
388369
_LG.info("Starting source epoch %d.", epoch)
389370
t0 = time.monotonic()
@@ -405,3 +386,51 @@ def repeat_source(
405386
qps,
406387
)
407388
epoch += 1
389+
390+
391+
class _RepeatIterator(Iterator[T]):
    """Picklable iterator that lazily wraps the ``_repeat`` generator.

    The underlying generator is created on the first call to ``__next__``,
    so an instance that has not been iterated yet holds only plain
    attributes (``src``, ``epoch`` and an unset ``_iter``) and can be
    serialized. Once iteration has started, pickling is refused.

    Args:
        src: The source to repeat.
        epoch: The epoch number to start with.
    """

    def __init__(
        self,
        src: Iterable[T] | IterableWithShuffle[T],
        epoch: int = 0,
    ) -> None:
        self.src = src
        self.epoch = epoch
        # Created lazily in ``__next__``; ``None`` means "not started yet".
        self._iter: Iterator[T] | None = None

    def __iter__(self) -> Iterator[T]:
        """Return this object; it is its own iterator."""
        return self

    def __getstate__(self) -> dict[str, Any]:  # pyre-ignore: [11]
        """Return the instance state for pickling.

        Raises:
            ValueError: If iteration has already started.
        """
        if self._iter is None:
            return self.__dict__
        raise ValueError("Cannot pickle after iteration is started.")

    def __next__(self) -> T:
        """Return the next item, creating the generator on first use."""
        active = self._iter
        if active is None:
            active = self._iter = _repeat(self.src, self.epoch)
        return next(active)
413+
414+
415+
def repeat_source(
    src: Iterable[T] | IterableWithShuffle[T],
    epoch: int = 0,
) -> Iterator[T]:
    """Turn an iterable into an endless iterator, shuffling when supported.

    The behavior is roughly that of the following snippet.

    .. code-block::

       while True:
           if hasattr(src, "shuffle"):
               src.shuffle(seed=epoch)
           yield from src
           epoch += 1

    Args:
        src: The source to repeat.
        epoch: The epoch number to start with.

    Returns:
        An iterator object wrapping ``src``. It can be pickled as long as
        iteration has not started.
    """
    # An object (rather than a generator) is returned so that the result
    # can be passed to a subprocess.
    repeater = _RepeatIterator(src, epoch)
    return repeater

tests/spdl_unittest/dataloader/iterator_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-unsafe
88

99
import os.path
10+
import pickle
1011
import random
1112
import tempfile
1213
import time
@@ -242,6 +243,20 @@ def __iter__(self) -> Iterator[int]:
242243
assert next(gen) == 2
243244

244245

246+
def test_repeat_source_picklable():
    """A fresh ``repeat_source`` iterator survives a pickle round-trip."""

    original = repeat_source(list(range(10)))

    # Round-trip through pickle before any item has been consumed.
    restored = pickle.loads(pickle.dumps(original))

    # Both iterators must produce the same sequence across three epochs.
    for _epoch in range(3):
        for expected in range(10):
            assert next(original) == next(restored) == expected
258+
259+
245260
def iter_range(n: int) -> Iterable[int]:
    """Yield the integers ``0`` through ``n - 1`` in ascending order."""
    for value in range(n):
        yield value
247262

0 commit comments

Comments
 (0)