-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_loader.py
119 lines (91 loc) · 4.04 KB
/
dataset_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys
import numpy as np
from math import pi, cos
import torch
import torchvision
import torch.nn as nn
from torch import allclose
from datetime import datetime
import torch.nn.functional as tf
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.testing import assert_allclose
from torchvision import datasets, transforms
from tqdm import tqdm
import kornia
from kornia import augmentation as K
import kornia.augmentation.functional as F
import kornia.augmentation.random_generator as rg
from torchvision.transforms import functional as tvF
_MEAN = [0.5, 0.5, 0.5]
_STD = [0.2, 0.2, 0.2]
class InitalTransformation:
    """CPU-side preprocessing: convert an image to a tensor, then normalize it.

    The datasets built in this module pass an instance of this class as their
    ``transform``. No random augmentation happens here; the kornia
    pipelines defined elsewhere in this module exist for that.
    """

    def __init__(self):
        # ToTensor first, so that Normalize receives a tensor.
        pipeline_steps = [
            T.ToTensor(),
            T.Normalize(_MEAN, _STD),
        ]
        self.transform = T.Compose(pipeline_steps)

    def __call__(self, x):
        """Apply the composed ToTensor + Normalize pipeline to ``x``."""
        return self.transform(x)
def gpu_transformer(image_size, s=.2):
    """Build two identically configured kornia augmentation pipelines.

    Each pipeline applies, in order:
    RandomResizedCrop(image_size, scale=(0.5, 1.0)), RandomHorizontalFlip(p=0.5),
    ColorJitter with factors (0.8*s, 0.8*s, 0.8*s, 0.2*s) at p=0.3, and
    RandomGrayscale(p=0.05). Kornia's positional order for ColorJitter is
    brightness, contrast, saturation, hue.

    Args:
        image_size: Output size forwarded to RandomResizedCrop.
        s: Colour-jitter strength multiplier.

    Returns:
        ``(train_transform, test_transform)``: two distinct ``nn.Sequential``
        instances with the same configuration.
    """
    # NOTE(review): ``test_transform`` is fully random, exactly like
    # ``train_transform``. Presumably it produces a second augmented view
    # rather than being a deterministic eval transform; confirm with callers.
    def _build_pipeline():
        jitter = 0.8 * s
        return nn.Sequential(
            kornia.augmentation.RandomResizedCrop(image_size, scale=(0.5, 1.0)),
            kornia.augmentation.RandomHorizontalFlip(p=0.5),
            kornia.augmentation.ColorJitter(jitter, jitter, jitter, 0.2 * s, p=0.3),
            kornia.augmentation.RandomGrayscale(p=0.05),
        )

    # Evaluated left to right: the train pipeline is constructed first.
    return _build_pipeline(), _build_pipeline()
def get_clf_train_test_transform(image_size, s=.2):
    """Build the crop-and-flip kornia pipelines for the classifier stage.

    Each pipeline is RandomResizedCrop(image_size, scale=(0.5, 1.0)) followed
    by RandomHorizontalFlip(p=0.5). No colour jitter and no normalization are
    applied here.

    Args:
        image_size: Output size forwarded to RandomResizedCrop.
        s: Accepted for signature parity with ``gpu_transformer``; unused.

    Returns:
        ``(train_transform, test_transform)``: two distinct ``nn.Sequential``
        instances with the same configuration.
    """
    # NOTE(review): the test pipeline is random (random crop + flip). Evaluation
    # usually uses a deterministic transform; confirm this is intended.
    def _crop_flip():
        return nn.Sequential(
            kornia.augmentation.RandomResizedCrop(image_size, scale=(0.5, 1.0)),
            kornia.augmentation.RandomHorizontalFlip(p=0.5),
        )

    train_transform = _crop_flip()
    test_transform = _crop_flip()
    return train_transform, test_transform
def get_train_test_dataloaders(dataset="stl10", data_dir="./dataset", batch_size=64, num_workers=4, download=True):
    """Return ``(train_loader, test_loader)`` DataLoaders over STL10.

    The train loader reads STL10's ``"train+unlabeled"`` split and the test
    loader reads its ``"test"`` split. Both loaders wrap each image in
    ``InitalTransformation`` and use ``shuffle=True``.

    Args:
        dataset: Dataset name. Only ``"stl10"`` (case-insensitive) is
            supported. Any other value used to be silently ignored; it now
            triggers a ``UserWarning``. STL10 is still loaded, so existing
            callers keep working.
        data_dir: Root directory forwarded to ``torchvision.datasets.STL10``.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes for both loaders.
        download: Forwarded to ``torchvision.datasets.STL10``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    import warnings

    # The name argument was never honoured. Surface a mismatch instead of
    # silently handing back a different dataset than the caller asked for.
    if str(dataset).lower() != "stl10":
        warnings.warn(
            f"get_train_test_dataloaders only supports 'stl10'; got {dataset!r}. "
            "Loading STL10 anyway.",
            UserWarning,
            stacklevel=2,
        )

    def _loader(split):
        # Each dataset gets its own InitalTransformation instance, as before.
        return torch.utils.data.DataLoader(
            dataset=torchvision.datasets.STL10(
                data_dir, split=split, transform=InitalTransformation(), download=download
            ),
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    train_loader = _loader("train+unlabeled")
    test_loader = _loader("test")
    return train_loader, test_loader
def get_train_mem_test_dataloaders(dataset="cifar10", data_dir="./dataset", batch_size=16, num_workers=4, download=True):
    """Return ``(train_loader, memory_loader, test_loader)`` over CIFAR-10.

    - ``train_loader``: CIFAR-10 training split, shuffled.
    - ``memory_loader``: CIFAR-10 training split, not shuffled. Presumably
      this is the labelled reference set (e.g. a kNN memory bank) that test
      images are compared against; confirm with callers.
    - ``test_loader``: CIFAR-10 test split, shuffled.

    All three wrap images in ``InitalTransformation``.

    Args:
        dataset: Dataset name. Only ``"cifar10"`` (case-insensitive) is
            supported. Any other value used to be silently ignored; it now
            triggers a ``UserWarning``. CIFAR-10 is still loaded.
        data_dir: Root directory forwarded to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size for all three loaders.
        num_workers: Worker processes for all three loaders.
        download: Forwarded to ``torchvision.datasets.CIFAR10``.

    Returns:
        Tuple ``(train_loader, memory_loader, test_loader)``.
    """
    import warnings

    if str(dataset).lower() != "cifar10":
        warnings.warn(
            f"get_train_mem_test_dataloaders only supports 'cifar10'; got {dataset!r}. "
            "Loading CIFAR-10 anyway.",
            UserWarning,
            stacklevel=2,
        )

    def _loader(train, shuffle):
        # Each dataset gets its own InitalTransformation instance, as before.
        return torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                data_dir, train=train, transform=InitalTransformation(), download=download
            ),
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    train_loader = _loader(train=True, shuffle=True)
    # Bug fix: the memory set was previously drawn from the TEST split
    # (train=False), i.e. the very images served by test_loader. If it is used
    # as the reference set for evaluating test_loader, every test image would
    # find itself, so the memory set must come from the training split.
    memory_loader = _loader(train=True, shuffle=False)
    test_loader = _loader(train=False, shuffle=True)
    return train_loader, memory_loader, test_loader