21
21
Iterator ,
22
22
Sequence ,
23
23
)
24
- from typing import TypeVar
24
+ from typing import Any , TypeVar
25
25
26
26
from ._type import IterableWithShuffle
27
27
@@ -364,26 +364,7 @@ def __iter__(self) -> Iterator[T]:
364
364
################################################################################
365
365
366
366
367
- def repeat_source (
368
- src : Iterable [T ] | IterableWithShuffle [T ],
369
- epoch : int = 0 ,
370
- ) -> Iterator [T ]:
371
- """Convert an iterable into an infinite iterator with optional shuffling.
372
-
373
- Roughly equivalent to the following code snippet.
374
-
375
- .. code-block::
376
-
377
- while True:
378
- if hasattr(src, "shuffle"):
379
- src.shuffle(seed=epoch)
380
- yield from src
381
- epoch += 1
382
-
383
- Args:
384
- src: The source to repeat.
385
- epoch: The epoch number to start with.
386
- """
367
+ def _repeat (src : Iterable [T ] | IterableWithShuffle [T ], epoch : int ) -> Iterator [T ]:
387
368
while True :
388
369
_LG .info ("Starting source epoch %d." , epoch )
389
370
t0 = time .monotonic ()
@@ -405,3 +386,51 @@ def repeat_source(
405
386
qps ,
406
387
)
407
388
epoch += 1
389
+
390
+
391
+ class _RepeatIterator (Iterator [T ]):
392
+ def __init__ (
393
+ self ,
394
+ src : Iterable [T ] | IterableWithShuffle [T ],
395
+ epoch : int = 0 ,
396
+ ) -> None :
397
+ self .src = src
398
+ self .epoch = epoch
399
+ self ._iter : Iterator [T ] | None = None
400
+
401
+ def __iter__ (self ) -> Iterator [T ]:
402
+ return self
403
+
404
+ def __getstate__ (self ) -> dict [str , Any ]: # pyre-ignore: [11]
405
+ if self ._iter is not None :
406
+ raise ValueError ("Cannot pickle after iteration is started." )
407
+ return self .__dict__
408
+
409
+ def __next__ (self ) -> T :
410
+ if self ._iter is None :
411
+ self ._iter = _repeat (self .src , self .epoch )
412
+ return next (self ._iter )
413
+
414
+
415
+ def repeat_source (
416
+ src : Iterable [T ] | IterableWithShuffle [T ],
417
+ epoch : int = 0 ,
418
+ ) -> Iterator [T ]:
419
+ """Convert an iterable into an infinite iterator with optional shuffling.
420
+
421
+ Roughly equivalent to the following code snippet.
422
+
423
+ .. code-block::
424
+
425
+ while True:
426
+ if hasattr(src, "shuffle"):
427
+ src.shuffle(seed=epoch)
428
+ yield from src
429
+ epoch += 1
430
+
431
+ Args:
432
+ src: The source to repeat.
433
+ epoch: The epoch number to start with.
434
+ """
435
+ # Returning object so that it can be passed to a subprocess.
436
+ return _RepeatIterator (src , epoch )
0 commit comments