Skip to content

Commit 8173797

Browse files
authored
Add pin_memory support (#326)
1 parent f842322 commit 8173797

File tree

1 file changed: +18 additions, −9 deletions

src/spdl/dataloader/_pytorch_dataloader.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,7 @@ def __init__(
128128
sampler: "torch.utils.data.sampler.Sampler[K]",
129129
fetch_fn: Callable[[K], U],
130130
collate_fn: Callable[[list[U]], V],
131+
transfer_fn: Callable[[V], V],
131132
mp_ctx: mp.context.BaseContext,
132133
num_workers: int,
133134
timeout: float | None,
@@ -139,6 +140,7 @@ def __init__(
139140
self._sampler = sampler
140141
self._fetch_fn = fetch_fn
141142
self._collate_fn = collate_fn
143+
self._transfer_fn = transfer_fn
142144
self._mp_ctx = mp_ctx
143145
self._num_workers = num_workers
144146
self._buffer_size = buffer_size
@@ -153,7 +155,7 @@ def _get_pipeline(self) -> tuple[ProcessPoolExecutor, Pipeline]:
153155
executor = _get_executor(
154156
self._shmem.name, self._collate_fn, self._num_workers, self._mp_ctx
155157
)
156-
pipeline = (
158+
builder = (
157159
PipelineBuilder()
158160
.add_source(self._sampler)
159161
.pipe(
@@ -162,9 +164,14 @@ def _get_pipeline(self) -> tuple[ProcessPoolExecutor, Pipeline]:
162164
output_order=self._output_order,
163165
concurrency=self._num_workers,
164166
)
165-
.add_sink(self._buffer_size)
166-
.build(num_threads=1)
167167
)
168+
if self._transfer_fn:
169+
builder.pipe(
170+
self._transfer_fn,
171+
output_order=self._output_order,
172+
)
173+
174+
pipeline = builder.add_sink(self._buffer_size).build(num_threads=1)
168175
return executor, pipeline
169176

170177
def __iter__(self) -> Iterator[V]:
@@ -231,7 +238,7 @@ def _resolve_sampler(
231238
_collate_fn = collate_fn or default_collate
232239
elif batch_size is not None:
233240
_sampler = BatchSampler(
234-
sampler or _get_sampler(dataset, shuffle, generator), # pyre-ignore: [6]
241+
sampler or _get_sampler(dataset, shuffle, generator),
235242
batch_size,
236243
drop_last,
237244
)
@@ -281,11 +288,8 @@ def get_pytorch_dataloader(
281288
if worker_init_fn is not None:
282289
raise ValueError("`worker_init_fn` is not supported.")
283290

284-
if pin_memory:
285-
raise ValueError("`pin_memory` is not supported (yet).")
286-
287291
if pin_memory_device is not None:
288-
raise ValueError("`pin_memory_device` is not supported (yet).")
292+
raise ValueError("`pin_memory_device` is not supported.")
289293

290294
if persistent_workers:
291295
raise ValueError("`persistent_workers` is not supported.")
@@ -309,6 +313,10 @@ def get_pytorch_dataloader(
309313
generator,
310314
)
311315

316+
from torch.utils.data._utils.pin_memory import pin_memory as pin_memory_fn
317+
318+
transfer_fn = pin_memory_fn if pin_memory else None
319+
312320
mp_ctx = (
313321
multiprocessing_context
314322
if isinstance(multiprocessing_context, mp.context.BaseContext)
@@ -321,8 +329,9 @@ def get_pytorch_dataloader(
321329
dataset=dataset,
322330
shmem=shmem,
323331
sampler=_sampler,
324-
fetch_fn=_fetch_fn, # pyre-ignore
332+
fetch_fn=_fetch_fn,
325333
collate_fn=_collate_fn,
334+
transfer_fn=transfer_fn,
326335
mp_ctx=mp_ctx,
327336
num_workers=num_workers,
328337
timeout=timeout,

0 commit comments

Comments (0)