-
Notifications
You must be signed in to change notification settings - Fork 3
/
batch.py
33 lines (27 loc) · 879 Bytes
/
batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
def _collate(samples):
    """Stack a list of per-sample dicts into one batch of NumPy arrays.

    Defined at module level rather than nested inside ``setup_dataloader``
    so that DataLoader worker processes can pickle it. Under the ``spawn``
    multiprocessing start method (the default on Windows and macOS), a
    nested function fails with "Can't pickle local object".

    Args:
        samples: list of dicts, each holding a tensor under "pixel_values"
            and one under "input_ids". All tensors under the same key must
            share a shape (``torch.stack`` requirement). They are assumed to
            be CPU tensors that do not require grad, since ``Tensor.numpy``
            rejects anything else.

    Returns:
        dict with:
            "pixel_values": float32 ndarray with a new leading batch axis.
            "input_ids": ndarray (original dtype) with a new leading batch axis.
    """
    # TODO: replace torch.stack with jax.numpy.stack
    # (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html)
    pixel_values = (
        torch.stack([sample["pixel_values"] for sample in samples])
        .to(memory_format=torch.contiguous_format)
        .float()  # cast to float32 before handing off as NumPy
        .numpy()
    )
    input_ids = (
        torch.stack([sample["input_ids"] for sample in samples])
        .to(memory_format=torch.contiguous_format)
        .numpy()
    )
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
    }


def setup_dataloader(dataset, batch_size, *, shuffle=True, num_workers=4, drop_last=True):
    """Build a DataLoader whose batches are dicts of NumPy arrays.

    Each batch is produced by ``_collate``: "pixel_values" as a float32
    array and "input_ids" as an integer/original-dtype array. It is
    presumably meant to feed a non-PyTorch (e.g. JAX) training loop.

    Args:
        dataset: map-style dataset whose items are dicts with
            "pixel_values" and "input_ids" tensors.
        batch_size: number of samples per batch.
        shuffle: reshuffle the data every epoch (default True).
        num_workers: number of worker processes for loading (default 4;
            0 loads in the main process).
        drop_last: drop the final incomplete batch (default True). Note that
            with ``batch_size > len(dataset)`` this yields no batches at all.

    Returns:
        torch.utils.data.DataLoader over ``dataset``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=shuffle,
        collate_fn=_collate,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=drop_last,
    )