import os
from copy import deepcopy
import cv2
import numpy as np
import torch
from torchvision import transforms
from albumentations.augmentations.transforms import Blur, HorizontalFlip, ElasticTransform, RandomScale, Resize, Rotate, RandomContrast
from albumentations.core.composition import Compose, OneOf
from albumentations.pytorch.transforms import ToTensor


class BreastMRIFusionDataset(torch.utils.data.Dataset):
    """Image + metadata fusion dataset backed by pre-saved .npy files.

    Each sample i (1-indexed on disk) consists of "<i>_x.npy" (image),
    "<i>_meta.npy" (metadata vector), and "<i>_y.npy" (label) in `fpath`.
    """

    def __init__(self, fpath, augment=False, n_TTA=0, n=None):
        self.fpath = fpath
        self.augment = augment
        self.n_TTA = n_TTA  # number of test-time-augmentation copies per sample (0 disables TTA)
        self.n = n  # optional fixed dataset length; if None, inferred from the files in fpath

        if self.augment:
            # Training-time augmentation pipeline (albumentations).
            self.transform = Compose([
                HorizontalFlip(p=0.5),
                OneOf([Blur(blur_limit=4, p=0.5), RandomContrast(p=0.5)], p=1),
                ElasticTransform(alpha_affine=10, sigma=5, border_mode=cv2.BORDER_CONSTANT, p=0.5),
                Compose([RandomScale(scale_limit=0.2, p=1), Resize(224, 224, p=1)], p=0.5),
                Rotate(20, border_mode=cv2.BORDER_CONSTANT, p=0.5),
                ToTensor()
            ])
        else:
            # Evaluation pipeline: tensor conversion only.
            self.transform = transforms.Compose([
                transforms.ToTensor()
            ])

        if self.n_TTA != 0:
            # Lighter augmentation pipeline applied repeatedly at test time.
            self.tta_transform = Compose([
                HorizontalFlip(p=0.5),
                OneOf([Blur(blur_limit=4, p=0.5), RandomContrast(p=0.5)], p=1),
                Rotate(20, border_mode=cv2.BORDER_CONSTANT, p=0.5),
                ToTensor()
            ])

        # Number of metadata features, inferred from the first sample.
        self.meta_features = np.load(os.path.join(self.fpath, "1_meta.npy")).shape[0]

    def __len__(self):
        if self.n is not None:
            return self.n
        else:
            return len([x for x in os.listdir(self.fpath) if "_x" in x])

    def __getitem__(self, idx):
        # Files on disk are 1-indexed.
        x = np.load(os.path.join(self.fpath, str(idx + 1) + "_x.npy"))
        meta = np.load(os.path.join(self.fpath, str(idx + 1) + "_meta.npy"))
        y = np.load(os.path.join(self.fpath, str(idx + 1) + "_y.npy"))

        if self.n_TTA != 0:
            # Stack n_TTA independently augmented copies along a new trailing axis.
            x = np.stack([self.tta_transform(image=x)["image"].float() for _ in range(self.n_TTA)], axis=-1)
        elif self.augment:
            x = self.transform(image=x)["image"].float()
        else:
            x = self.transform(x).float()

        meta = torch.Tensor(meta).float()
        y = torch.Tensor(y).float()
        return {"image": x, "metadata": meta, "label": y}


class FeatureImpDataset(torch.utils.data.Dataset):
    """Test-set dataset that preloads every sample into memory.

    `meta_test` can be modified in place (e.g., for permutation-style feature
    importance) while `orig_meta_test` preserves an untouched copy of the metadata.
    """

    def __init__(self, fpath):
        self.fpath = fpath
        self.n = len([x for x in os.listdir(self.fpath) if "_x" in x])
        self.transform = transforms.ToTensor()
        # Preload every image, metadata vector, and label (files are 1-indexed).
        self.x_test = np.array([np.load(os.path.join(self.fpath, str(i + 1) + "_x.npy")) for i in range(self.n)])
        self.orig_meta_test = np.array([np.load(os.path.join(self.fpath, str(i + 1) + "_meta.npy")) for i in range(self.n)])
        self.meta_test = deepcopy(self.orig_meta_test)
        self.y_test = np.array([np.load(os.path.join(self.fpath, str(i + 1) + "_y.npy")) for i in range(self.n)])
        self.meta_features = self.orig_meta_test.shape[1]

    def __len__(self):
        return self.n

    def __getitem__(self, idx, orig=False):
        # `orig` is accepted but not used here; metadata always comes from meta_test.
        x = self.x_test[idx]
        meta = self.meta_test[idx]
        y = self.y_test[idx]
        x = self.transform(x).float()
        meta = torch.Tensor(meta).float()
        y = torch.Tensor(y).float()
        return {"image": x, "metadata": meta, "label": y}