-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_loader.py
76 lines (60 loc) · 2.13 KB
/
dataset_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets
from torchvision import transforms
# Module-level 2-tuple of image dimensions.
# NOTE(review): presumably the (height, width) of the CIFAR-10 inputs; it is
# not referenced by data_loader in this file — confirm which importers use it.
image_size = (32, 32)
def data_loader(data_dir,
                batch_size,
                random_seed=42,
                valid_size=0.1,
                shuffle=True,
                test=False,
                num_of_workers=24):
    """Build CIFAR-10 ``DataLoader`` objects.

    Args:
        data_dir: Root directory passed to ``datasets.CIFAR10``; the dataset
            is downloaded there if it is not already present.
        batch_size: Number of samples per batch for every returned loader.
        random_seed: Seed given to ``np.random.seed`` before the index
            permutation that defines the train/validation split. Only used
            when ``shuffle`` is true and ``test`` is false.
        valid_size: Fraction of the training set held out for validation;
            must lie in ``[0, 1]``. Ignored when ``test`` is true.
        shuffle: In test mode, forwarded to the test ``DataLoader``. In
            train mode, controls whether the indices are permuted before
            being split (otherwise the first ``split`` indices in dataset
            order form the validation set).
        test: If true, return a single loader over the test split.
        num_of_workers: ``num_workers`` for every returned loader.

    Returns:
        A single test ``DataLoader`` when ``test`` is true, otherwise a
        ``(train_loader, valid_loader)`` tuple.

    Raises:
        ValueError: If ``test`` is false and ``valid_size`` is outside
            ``[0, 1]``.

    Note:
        When ``shuffle`` is true in train mode this reseeds NumPy's *global*
        random state as a side effect.
    """
    # Per-channel normalisation constants (presumably the CIFAR-10
    # training-set mean/std -- confirm against the training recipe).
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # Deterministic pipeline for evaluation data (validation and test).
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    # Augmenting pipeline for training data only.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    if test:
        test_dataset = datasets.CIFAR10(
            root=data_dir, train=False,
            download=True, transform=eval_transform,
        )
        return torch.utils.data.DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_of_workers,
        )

    # Checked after the test branch because valid_size is unused there.
    if not 0 <= valid_size <= 1:
        raise ValueError(
            f"valid_size must be in the range [0, 1], got {valid_size!r}"
        )

    # Two views of the same training split: one augmented for training and
    # one un-augmented, so validation metrics are not perturbed by random
    # crops and flips.
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=eval_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # Seeding makes the split reproducible across runs.
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # The two index sets are disjoint, so no sample is in both loaders.
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_of_workers,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_of_workers,
    )
    return (train_loader, valid_loader)