From 4b4169c292493c31707d5fc4737574c90d3b184b Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 14 Jan 2025 17:59:48 -0500 Subject: [PATCH] Add initializer --- src/spdl/source/utils.py | 32 +++-- .../spdl_unittest/dataloader/iterator_test.py | 117 +++++++++++++++++- 2 files changed, 140 insertions(+), 9 deletions(-) diff --git a/src/spdl/source/utils.py b/src/spdl/source/utils.py index 0e1c87d1..9096e2fa 100644 --- a/src/spdl/source/utils.py +++ b/src/spdl/source/utils.py @@ -41,6 +41,7 @@ _MSG_PARENT_REQUEST_STOP = "PARENT_REQUEST_STOP" # Message from worker to the parent +_MSG_INITIALIZER_FAILED = "INITIALIZER_FAILED" _MSG_GENERATOR_FAILED = "GENERATOR_FAILED_TO_INITIALIZE" _MSG_ITERATION_FINISHED = "ITERATION_FINISHED" _MSG_DATA_QUEUE_FAILED = "DATA_QUEUE_FAILED" @@ -50,7 +51,15 @@ def _execute_iterator( msg_queue: mp.Queue, data_queue: mp.Queue, fn: Callable[[], Iterator[T]], + initializer: Callable[[], None], ) -> None: + if initializer is not None: + try: + initializer() + except Exception: + msg_queue.put(_MSG_INITIALIZER_FAILED) + raise + try: gen = iter(fn()) except Exception: @@ -85,8 +94,10 @@ def _execute_iterator( def iterate_in_subprocess( fn: Callable[[], Iterable[T]], - queue_size: int = 64, - mp_context: str = "forkserver", + *, + buffer_size: int = 3, + initializer: Callable[[], None] | None = None, + mp_context: str | None = None, timeout: float | None = None, daemon: bool = False, ) -> Iterator[T]: @@ -95,11 +106,13 @@ def iterate_in_subprocess( Args: fn: Function that returns an iterator. Use :py:func:`functools.partial` to pass arguments to the function. - queue_size: Maximum number of items to buffer in the queue. + buffer_size: Maximum number of items to buffer in the queue. + initializer: A function executed in the subprocess before iteration starts. mp_context: Context to use for multiprocessing. + If not specified, a default method is used. timeout: Timeout for inactivity. If the generator function does not yield any item for this amount of time, the process is terminated. - daemnon: Whether to run the process as a daemon. + daemon: Whether to run the process as a daemon. Use it only for debugging. Returns: Iterator over the results of the generator function. @@ -110,7 +123,7 @@ def iterate_in_subprocess( """ ctx = mp.get_context(mp_context) msg_q = ctx.Queue() - data_q: mp.Queue = ctx.Queue(maxsize=queue_size) + data_q: mp.Queue = ctx.Queue(maxsize=buffer_size) def _drain() -> Iterator[T]: while not data_q.empty(): @@ -118,7 +131,7 @@ def _drain() -> Iterator[T]: process = ctx.Process( target=_execute_iterator, - args=(msg_q, data_q, fn), + args=(msg_q, data_q, fn, initializer), daemon=daemon, ) process.start() @@ -135,6 +148,10 @@ def _drain() -> Iterator[T]: if msg == _MSG_ITERATION_FINISHED: return + if msg == _MSG_INITIALIZER_FAILED: + raise RuntimeError( + "The worker process quit because the initializer failed." + ) if msg == _MSG_GENERATOR_FAILED: raise RuntimeError( "The worker process quit because the generator failed." @@ -156,7 +173,8 @@ def _drain() -> Iterator[T]: if timeout is not None: if (elapsed := time.monotonic() - t0) > timeout: raise RuntimeError( - f"The worker process did not produce any data for {elapsed:.2f} seconds." + "The worker process did not produce any data for " + f"{elapsed:.2f} seconds." ) except (Exception, KeyboardInterrupt): diff --git a/tests/spdl_unittest/dataloader/iterator_test.py b/tests/spdl_unittest/dataloader/iterator_test.py index 576c147e..950def86 100644 --- a/tests/spdl_unittest/dataloader/iterator_test.py +++ b/tests/spdl_unittest/dataloader/iterator_test.py @@ -6,11 +6,16 @@ # pyre-unsafe -from collections.abc import Iterator +import os.path +import random +import tempfile +import time +from collections.abc import Iterable, Iterator +from functools import partial from unittest.mock import patch import pytest -from spdl.source.utils import MergeIterator, repeat_source +from spdl.source.utils import iterate_in_subprocess, MergeIterator, repeat_source def test_mergeiterator_ordered(): @@ -235,3 +240,111 @@ def __iter__(self) -> Iterator[int]: assert next(gen) == 0 assert next(gen) == 1 assert next(gen) == 2 + + +def iter_range(n: int) -> Iterable[int]: + yield from range(n) + + +def test_iterate_in_subprocess(): + """iterate_in_subprocess iterates""" + N = 10 + + src = iterate_in_subprocess(fn=partial(iter_range, n=N)) + assert list(src) == list(range(N)) + + +def initializer(path: str, val: str) -> None: + with open(path, "w") as f: + f.write(val) + + +def test_iterate_in_subprocess_initializer(): + """iterate_in_subprocess initializer is called before iteration starts""" + + N = 10 + val = str(random.random()) + with tempfile.TemporaryDirectory() as dir: + path = os.path.join(dir, "foo.txt") + + assert not os.path.exists(path) + src = iterate_in_subprocess( + fn=partial(iter_range, n=N), + initializer=partial(initializer, path=path, val=val), + buffer_size=1, + ) + assert not os.path.exists(path) + + assert next(src) == 0 + + assert os.path.exists(path) + + with open(path, "r") as f: + assert f.read() == val + + for i in range(1, N): + assert next(src) == i + + with pytest.raises(StopIteration): + next(src) + + +def iter_range_and_store(n: int, path: str) -> Iterable[int]: + yield 0 + for i in range(n): + yield i + with open(path, "w") as f: + f.write(str(i)) + + +def test_iterate_in_subprocess_buffer_size_1(): + """buffer_size=1 makes iterate_in_subprocess works sort-of interactively""" + + N = 10 + + with tempfile.TemporaryDirectory() as dir: + path = os.path.join(dir, "foo.txt") + + src = iterate_in_subprocess( + fn=partial(iter_range_and_store, n=N, path=path), + daemon=True, + buffer_size=1, + ) + assert src.send(None) == 0 + + for i in range(N): + time.sleep(0.1) + + with open(path, "r") as f: + assert int(f.read()) == i + + assert next(src) == i + + with pytest.raises(StopIteration): + next(src) + + +def test_iterate_in_subprocess_buffer_size_64(): + """big buffer_size makes iterate_in_subprocess processes data in one go""" + + N = 10 + + with tempfile.TemporaryDirectory() as dir: + path = os.path.join(dir, "foo.txt") + + src = iterate_in_subprocess( + fn=partial(iter_range_and_store, n=N, path=path), + daemon=True, + buffer_size=64, + ) + assert src.send(None) == 0 + + time.sleep(0.1) + for i in range(N): + with open(path, "r") as f: + assert int(f.read()) == 9 + + assert next(src) == i + + with pytest.raises(StopIteration): + next(src)