diff --git a/src/spdl/dataloader/_pytorch_dataloader.py b/src/spdl/dataloader/_pytorch_dataloader.py index 2ea74eef..6ade3893 100644 --- a/src/spdl/dataloader/_pytorch_dataloader.py +++ b/src/spdl/dataloader/_pytorch_dataloader.py @@ -13,6 +13,7 @@ import os import pickle import time +import warnings from collections.abc import Callable, Iterable, Iterator from concurrent.futures import ProcessPoolExecutor from multiprocessing.shared_memory import SharedMemory @@ -298,8 +299,14 @@ def get_pytorch_dataloader( if timeout is not None and timeout < 0: raise ValueError(f"`timeout` must be positive. Found: {timeout}.") - if num_workers < 1: + if num_workers < 0: raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}") + elif num_workers == 0: + warnings.warn( + "`num_workers` is 0. Setting `num_workers` to 1 for single process dataloading.", + stacklevel=2, + ) + num_workers = 1 buffer_size = prefetch_factor * num_workers