Skip to content

Commit

Permalink
Add initializer to iterate_in_subprocess (#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok authored Jan 15, 2025
1 parent ffd6701 commit 9002aaf
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 9 deletions.
32 changes: 25 additions & 7 deletions src/spdl/source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_MSG_PARENT_REQUEST_STOP = "PARENT_REQUEST_STOP"

# Message from worker to the parent
_MSG_INITIALIZER_FAILED = "INITIALIZER_FAILED"
_MSG_GENERATOR_FAILED = "GENERATOR_FAILED_TO_INITIALIZE"
_MSG_ITERATION_FINISHED = "ITERATION_FINISHED"
_MSG_DATA_QUEUE_FAILED = "DATA_QUEUE_FAILED"
Expand All @@ -50,7 +51,15 @@ def _execute_iterator(
msg_queue: mp.Queue,
data_queue: mp.Queue,
fn: Callable[[], Iterator[T]],
initializer: Callable[[], None],
) -> None:
if initializer is not None:
try:
initializer()
except Exception:
msg_queue.put(_MSG_INITIALIZER_FAILED)
raise

try:
gen = iter(fn())
except Exception:
Expand Down Expand Up @@ -85,8 +94,10 @@ def _execute_iterator(

def iterate_in_subprocess(
fn: Callable[[], Iterable[T]],
queue_size: int = 64,
mp_context: str = "forkserver",
*,
buffer_size: int = 3,
initializer: Callable[[], None] | None = None,
mp_context: str | None = None,
timeout: float | None = None,
daemon: bool = False,
) -> Iterator[T]:
Expand All @@ -95,11 +106,13 @@ def iterate_in_subprocess(
Args:
fn: Function that returns an iterator. Use :py:func:`functools.partial` to
pass arguments to the function.
queue_size: Maximum number of items to buffer in the queue.
buffer_size: Maximum number of items to buffer in the queue.
initializer: A function executed in the subprocess before iteration starts.
mp_context: Context to use for multiprocessing.
If not specified, a default method is used.
timeout: Timeout for inactivity. If the generator function does not yield
any item for this amount of time, the process is terminated.
daemnon: Whether to run the process as a daemon.
daemon: Whether to run the process as a daemon. Use it only for debugging.
Returns:
Iterator over the results of the generator function.
Expand All @@ -110,15 +123,15 @@ def iterate_in_subprocess(
"""
ctx = mp.get_context(mp_context)
msg_q = ctx.Queue()
data_q: mp.Queue = ctx.Queue(maxsize=queue_size)
data_q: mp.Queue = ctx.Queue(maxsize=buffer_size)

def _drain() -> Iterator[T]:
while not data_q.empty():
yield data_q.get_nowait()

process = ctx.Process(
target=_execute_iterator,
args=(msg_q, data_q, fn),
args=(msg_q, data_q, fn, initializer),
daemon=daemon,
)
process.start()
Expand All @@ -135,6 +148,10 @@ def _drain() -> Iterator[T]:

if msg == _MSG_ITERATION_FINISHED:
return
if msg == _MSG_INITIALIZER_FAILED:
raise RuntimeError(
"The worker process quit because the initializer failed."
)
if msg == _MSG_GENERATOR_FAILED:
raise RuntimeError(
"The worker process quit because the generator failed."
Expand All @@ -156,7 +173,8 @@ def _drain() -> Iterator[T]:
if timeout is not None:
if (elapsed := time.monotonic() - t0) > timeout:
raise RuntimeError(
f"The worker process did not produce any data for {elapsed:.2f} seconds."
"The worker process did not produce any data for "
f"{elapsed:.2f} seconds."
)

except (Exception, KeyboardInterrupt):
Expand Down
117 changes: 115 additions & 2 deletions tests/spdl_unittest/dataloader/iterator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@

# pyre-unsafe

from collections.abc import Iterator
import os.path
import random
import tempfile
import time
from collections.abc import Iterable, Iterator
from functools import partial
from unittest.mock import patch

import pytest
from spdl.source.utils import MergeIterator, repeat_source
from spdl.source.utils import iterate_in_subprocess, MergeIterator, repeat_source


def test_mergeiterator_ordered():
Expand Down Expand Up @@ -235,3 +240,111 @@ def __iter__(self) -> Iterator[int]:
assert next(gen) == 0
assert next(gen) == 1
assert next(gen) == 2


def iter_range(n: int) -> Iterable[int]:
    """Yield the integers ``0`` through ``n - 1`` in ascending order.

    Defined at module level and used with :py:func:`functools.partial` as the
    ``fn`` argument of ``iterate_in_subprocess`` in the tests below.
    """
    numbers = range(n)
    yield from numbers


def test_iterate_in_subprocess():
    """Items produced through iterate_in_subprocess arrive complete and in order."""
    num_items = 10
    expected = list(range(num_items))

    produce = partial(iter_range, n=num_items)
    received = list(iterate_in_subprocess(fn=produce))

    assert received == expected


def initializer(path: str, val: str) -> None:
    """Write ``val`` to the file at ``path``, replacing any existing content.

    Used as the ``initializer`` of ``iterate_in_subprocess`` in the tests; the
    written file serves as evidence that the initializer has been executed.
    """
    with open(path, "w") as fh:
        fh.write(val)


def test_iterate_in_subprocess_initializer():
    """The initializer of iterate_in_subprocess runs before the first item arrives.

    The initializer writes a random token to a file. The file must be absent
    until the first item is requested, and must hold the token afterwards.
    """
    num_items = 10
    token = str(random.random())

    with tempfile.TemporaryDirectory() as tmp_dir:
        marker = os.path.join(tmp_dir, "foo.txt")
        assert not os.path.exists(marker)

        src = iterate_in_subprocess(
            fn=partial(iter_range, n=num_items),
            initializer=partial(initializer, path=marker, val=token),
            buffer_size=1,
        )
        # Merely creating the iterator must not have run the initializer.
        assert not os.path.exists(marker)

        # Pulling the first item is what makes the initializer's file appear.
        assert next(src) == 0
        assert os.path.exists(marker)

        with open(marker, "r") as fh:
            assert fh.read() == token

        # The remaining items still arrive in order ...
        for expected in range(1, num_items):
            assert next(src) == expected

        # ... and the iterator is exhausted afterwards.
        with pytest.raises(StopIteration):
            next(src)


def iter_range_and_store(n: int, path: str) -> Iterable[int]:
    """Yield ``0`` and then ``0 .. n - 1``, recording progress in a file.

    Each value of the range is written to ``path`` (overwriting the previous
    content) only after it has been yielded and the generator is resumed, so
    the file shows how far the generator has advanced.
    """
    yield 0
    for value in range(n):
        yield value
        # Reached only once the generator is resumed after yielding ``value``.
        with open(path, "w") as fh:
            fh.write(str(value))


def test_iterate_in_subprocess_buffer_size_1():
    """With buffer_size=1, iterate_in_subprocess advances roughly in lockstep.

    The progress file written by the source must show exactly the value that
    is about to be received, i.e. the source has not run further ahead.
    """
    num_items = 10

    with tempfile.TemporaryDirectory() as tmp_dir:
        progress_file = os.path.join(tmp_dir, "foo.txt")

        src = iterate_in_subprocess(
            fn=partial(iter_range_and_store, n=num_items, path=progress_file),
            daemon=True,
            buffer_size=1,
        )
        # Consume the leading 0 that the source yields before its loop.
        assert src.send(None) == 0

        for expected in range(num_items):
            # Give the source time to advance as far as the buffer allows.
            time.sleep(0.1)

            with open(progress_file, "r") as fh:
                assert int(fh.read()) == expected

            assert next(src) == expected

        with pytest.raises(StopIteration):
            next(src)


def test_iterate_in_subprocess_buffer_size_64():
    """With a big buffer_size, iterate_in_subprocess processes the data in one go.

    After a short wait the progress file must already hold the final value,
    even though the items have not been consumed yet.
    """
    num_items = 10
    last_value = num_items - 1

    with tempfile.TemporaryDirectory() as tmp_dir:
        progress_file = os.path.join(tmp_dir, "foo.txt")

        src = iterate_in_subprocess(
            fn=partial(iter_range_and_store, n=num_items, path=progress_file),
            daemon=True,
            buffer_size=64,
        )
        # Consume the leading 0 that the source yields before its loop.
        assert src.send(None) == 0

        # Give the source time to run through all of its items.
        time.sleep(0.1)
        for expected in range(num_items):
            # The file shows the last value regardless of how many items
            # have been received so far.
            with open(progress_file, "r") as fh:
                assert int(fh.read()) == last_value

            assert next(src) == expected

        with pytest.raises(StopIteration):
            next(src)

0 comments on commit 9002aaf

Please sign in to comment.