Commit ca1389a

Override num_workers from 0 to 1 in get_pytorch_dataloader

Differential Revision: D68357865
Pull Request resolved: #330

1 parent: 4fc0bc2

1 file changed: +8 −1 lines
src/spdl/dataloader/_pytorch_dataloader.py

Lines changed: 8 additions & 1 deletion
@@ -13,6 +13,7 @@
 import os
 import pickle
 import time
+import warnings
 from collections.abc import Callable, Iterable, Iterator
 from concurrent.futures import ProcessPoolExecutor
 from multiprocessing.shared_memory import SharedMemory
@@ -298,8 +299,14 @@ def get_pytorch_dataloader(
     if timeout is not None and timeout < 0:
         raise ValueError(f"`timeout` must be positive. Found: {timeout}.")

-    if num_workers < 1:
+    if num_workers < 0:
         raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}")
+    elif num_workers == 0:
+        warnings.warn(
+            "`num_workers` is 0. Setting `num_workers` to 1 for single process dataloading.",
+            stacklevel=2,
+        )
+        num_workers = 1

     buffer_size = prefetch_factor * num_workers

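For context, the snippet below sketches what the new behavior looks like from the caller's side. It is a minimal sketch, assuming `get_pytorch_dataloader` is importable from `spdl.dataloader` and accepts a map-style dataset as its first positional argument; both assumptions are inferred from the file path and function name in the diff, not stated by the commit.

```python
import warnings

from spdl.dataloader import get_pytorch_dataloader  # assumed public import path

# A plain list stands in for a map-style dataset (__getitem__/__len__);
# the dataset type SPDL actually expects may differ.
dataset = list(range(16))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Before this commit, num_workers=0 raised ValueError.
    # After it, a warning is emitted and the loader falls back to one worker.
    dataloader = get_pytorch_dataloader(dataset, num_workers=0)

assert any("num_workers" in str(w.message) for w in caught)
```

Negative values are still rejected with a ValueError; only the exact value 0 is promoted to 1, accompanied by a UserWarning.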
0 commit comments