Skip to content

Commit

Permalink
Override num_workers from 0 to 1 in get_pytorch_dataloader (#330)
Browse files Browse the repository at this point in the history
Summary:

To support a similar interface to the pytorch dataloader, we set `num_workers` to 1 if the input `num_workers` is 0 in `get_pytorch_dataloader`

Reviewed By: moto-meta

Differential Revision: D68357865
  • Loading branch information
Victor Bourgin authored and facebook-github-bot committed Jan 22, 2025
1 parent 4fc0bc2 commit f7892fa
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/spdl/dataloader/_pytorch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import pickle
import time
import warnings
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.shared_memory import SharedMemory
Expand Down Expand Up @@ -298,8 +299,14 @@ def get_pytorch_dataloader(
if timeout is not None and timeout < 0:
raise ValueError(f"`timeout` must be positive. Found: {timeout}.")

if num_workers < 1:
if num_workers < 0:
raise ValueError(f"`num_workers` must be greater than 0. Found: {num_workers}")
elif num_workers == 0:
warnings.warn(
"`num_workers` is 0. Setting `num_workers` to 1 for single process dataloading.",
stacklevel=2,
)
num_workers = 1

buffer_size = prefetch_factor * num_workers

Expand Down

0 comments on commit f7892fa

Please sign in to comment.