forked from marrlab/percollFFT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpercollDataLoaderMultiClass.py
55 lines (45 loc) · 1.93 KB
/
percollDataLoaderMultiClass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import pickle
import pickletools

import cv2
import numpy as np

from fouriersignal import fourierSignal
"""Data Loader for training data"""
# Data loader
class percollDataLoaderMultiClass:
    """Indexable dataset over pickled Percoll fold files for multi-class training.

    Each pickled fold is a sequence of dict-like entries carrying an ``'Img'``
    image and a ``'label'`` that is compared against the class ids
    ``0 .. numClasses-1``. Indexing returns ``(image, label, fourierColors)``.

    Parameters
    ----------
    indexFold : int
        Fold number; selects ``fold{Train,Test}<indexFold>.pkl`` under ``dataRoot``.
    train : bool
        When ``augmented`` is True, load the train fold if True, else the test fold.
        Ignored when ``augmented`` is False (see the note in ``__init__``).
    augmented : bool
        If False, keep only every ``augmentationStride``-th entry of the test fold.
    dataRoot : str or None
        Directory holding the fold pickles; ``None`` uses ``DEFAULT_DATA_ROOT``.
    numClasses : int
        Length of the one-hot label vector.
    imageSize : int
        Target side length; images whose spatial size differs are resized to
        ``imageSize x imageSize``.
    augmentationStride : int
        Subsampling stride used when ``augmented`` is False; must be >= 1.
    """

    DEFAULT_DATA_ROOT = "/storage/scratch/users/ario.sadafi/percoll-new"

    def __init__(self, indexFold, train, augmented=True, dataRoot=None,
                 numClasses=4, imageSize=500, augmentationStride=6):
        if augmentationStride < 1:
            raise ValueError("augmentationStride must be >= 1")
        if dataRoot is None:
            dataRoot = self.DEFAULT_DATA_ROOT
        self.numClasses = numClasses
        self.imageSize = imageSize

        if augmented:
            split = "Train" if train else "Test"
            self.imageList = self._loadPickle(dataRoot, split, indexFold)
        else:
            # NOTE(review): this branch always reads the *test* fold, even when
            # train=True, so a train=True / augmented=False loader serves test
            # data. Kept for compatibility -- confirm this is intended.
            imList = self._loadPickle(dataRoot, "Test", indexFold)
            # Presumably each original image is followed by
            # (augmentationStride - 1) augmented copies, so this keeps only the
            # originals -- TODO confirm against the code that wrote the pickles.
            self.imageList = [item for i, item in enumerate(imList)
                              if i % augmentationStride == 0]

    @staticmethod
    def _loadPickle(dataRoot, split, indexFold):
        """Load and return ``fold<split><indexFold>.pkl`` from *dataRoot*.

        Security: ``pickle.load`` can execute arbitrary code, so *dataRoot*
        must only point at trusted files.
        """
        path = os.path.join(dataRoot, "fold" + split + str(indexFold) + ".pkl")
        with open(path, "rb") as f:
            return pickle.load(f)

    @staticmethod
    def _normalize(rawImage):
        """Return *rawImage* as a float64 array scaled so its maximum is 1.

        An all-zero image is returned unchanged (all zeros). Scaling it would
        compute 1/0 = inf and then 0 * inf = NaN.
        """
        image = np.array(rawImage, dtype="float64")
        peak = image.max()
        if peak != 0:
            image *= 1.0 / peak
        return image

    @staticmethod
    def _oneHot(label, numClasses):
        """Return a uint8 one-hot vector of length *numClasses* for *label*."""
        return np.array([label == c for c in range(numClasses)], dtype=np.uint8)

    def __len__(self):
        return len(self.imageList)

    def __getitem__(self, index):
        """Return ``(image, oneHotLabel, fourierColors)`` for entry *index*.

        The image is treated as channel-first ``(C, H, W)``. ``shape[1]`` and
        ``shape[2]`` are compared to ``imageSize``, and each channel plane is
        resized as a 2-D array.
        """
        entry = self.imageList[index]
        image = self._normalize(entry['Img'])
        if image.shape[1] != self.imageSize or image.shape[2] != self.imageSize:
            # cv2.resize takes dsize as (width, height); the target is square.
            size = (self.imageSize, self.imageSize)
            # Resize every channel plane, whatever the channel count.
            image = np.array([cv2.resize(channel, size) for channel in image])
        label = self._oneHot(entry['label'], self.numClasses)
        # fourierSignal returns a 5-tuple; only its 4th element is used here.
        _, _, _, fourierColors, _ = fourierSignal(image)
        return image, label, fourierColors
if __name__ == "__main__":
    # Manual smoke test: load augmented training fold 0, dump one sample, then
    # print the image shape of every sample. The loader has no __iter__, so the
    # for-loop uses the sequence protocol (__getitem__ until IndexError).
    loader = percollDataLoaderMultiClass(0, train=True, augmented=True)
    sample = loader[4]
    print(sample)
    for imageArray, *_ in loader:
        print(imageArray.shape)