@@ -128,6 +128,7 @@ def __init__(
128
128
sampler : "torch.utils.data.sampler.Sampler[K]" ,
129
129
fetch_fn : Callable [[K ], U ],
130
130
collate_fn : Callable [[list [U ]], V ],
131
+ transfer_fn : Callable [[V ], V ],
131
132
mp_ctx : mp .context .BaseContext ,
132
133
num_workers : int ,
133
134
timeout : float | None ,
@@ -139,6 +140,7 @@ def __init__(
139
140
self ._sampler = sampler
140
141
self ._fetch_fn = fetch_fn
141
142
self ._collate_fn = collate_fn
143
+ self ._transfer_fn = transfer_fn
142
144
self ._mp_ctx = mp_ctx
143
145
self ._num_workers = num_workers
144
146
self ._buffer_size = buffer_size
@@ -153,7 +155,7 @@ def _get_pipeline(self) -> tuple[ProcessPoolExecutor, Pipeline]:
153
155
executor = _get_executor (
154
156
self ._shmem .name , self ._collate_fn , self ._num_workers , self ._mp_ctx
155
157
)
156
- pipeline = (
158
+ builder = (
157
159
PipelineBuilder ()
158
160
.add_source (self ._sampler )
159
161
.pipe (
@@ -162,9 +164,14 @@ def _get_pipeline(self) -> tuple[ProcessPoolExecutor, Pipeline]:
162
164
output_order = self ._output_order ,
163
165
concurrency = self ._num_workers ,
164
166
)
165
- .add_sink (self ._buffer_size )
166
- .build (num_threads = 1 )
167
167
)
168
+ if self ._transfer_fn :
169
+ builder .pipe (
170
+ self ._transfer_fn ,
171
+ output_order = self ._output_order ,
172
+ )
173
+
174
+ pipeline = builder .add_sink (self ._buffer_size ).build (num_threads = 1 )
168
175
return executor , pipeline
169
176
170
177
def __iter__ (self ) -> Iterator [V ]:
@@ -231,7 +238,7 @@ def _resolve_sampler(
231
238
_collate_fn = collate_fn or default_collate
232
239
elif batch_size is not None :
233
240
_sampler = BatchSampler (
234
- sampler or _get_sampler (dataset , shuffle , generator ), # pyre-ignore: [6]
241
+ sampler or _get_sampler (dataset , shuffle , generator ),
235
242
batch_size ,
236
243
drop_last ,
237
244
)
@@ -281,11 +288,8 @@ def get_pytorch_dataloader(
281
288
if worker_init_fn is not None :
282
289
raise ValueError ("`worker_init_fn` is not supported." )
283
290
284
- if pin_memory :
285
- raise ValueError ("`pin_memory` is not supported (yet)." )
286
-
287
291
if pin_memory_device is not None :
288
- raise ValueError ("`pin_memory_device` is not supported (yet) ." )
292
+ raise ValueError ("`pin_memory_device` is not supported." )
289
293
290
294
if persistent_workers :
291
295
raise ValueError ("`persistent_workers` is not supported." )
@@ -309,6 +313,10 @@ def get_pytorch_dataloader(
309
313
generator ,
310
314
)
311
315
316
+ from torch .utils .data ._utils .pin_memory import pin_memory as pin_memory_fn
317
+
318
+ transfer_fn = pin_memory_fn if pin_memory else None
319
+
312
320
mp_ctx = (
313
321
multiprocessing_context
314
322
if isinstance (multiprocessing_context , mp .context .BaseContext )
@@ -321,8 +329,9 @@ def get_pytorch_dataloader(
321
329
dataset = dataset ,
322
330
shmem = shmem ,
323
331
sampler = _sampler ,
324
- fetch_fn = _fetch_fn , # pyre-ignore
332
+ fetch_fn = _fetch_fn ,
325
333
collate_fn = _collate_fn ,
334
+ transfer_fn = transfer_fn ,
326
335
mp_ctx = mp_ctx ,
327
336
num_workers = num_workers ,
328
337
timeout = timeout ,
0 commit comments