-
Notifications
You must be signed in to change notification settings - Fork 1
/
cp_latent_data.py
158 lines (138 loc) · 5.73 KB
/
cp_latent_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Tuple, List, Union
import torch as pt
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, random_split
from cp_latent_dataset import QuackLatentData, QuackLatentDataset
class QuackLatentCollator:
    """
    Callable collate function that turns a list of latent-encoded items into a batch.

    The step chosen at construction selects the output form: ``'predict'`` pairs the
    stacked encodings with the items' metadata, while ``'train'``, ``'val'`` and ``'test'``
    pair them with a tensor of labels.
    """

    def __init__(self, step: str) -> None:
        """
        Parameters
        ----------
        step: str
            One of 'train', 'val', 'test' or 'predict'.

        Raises
        ------
        ValueError
            If ``step`` is not one of the four recognised values.
        """
        super().__init__()
        recognised = frozenset(('train', 'val', 'test', 'predict'))
        if step not in recognised:
            raise ValueError('Improper transform step.')
        self.__step = step

    def __call__(self, batch: List[dict], *args, **kwargs) -> Union[Tuple[pt.Tensor, pt.Tensor], Tuple[pt.Tensor, List[dict]]]:
        """
        Collates ``batch`` according to the configured step. Extra positional and keyword
        arguments are accepted and ignored.

        Parameters
        ----------
        batch: List[dict]
            The items to collate.

        Returns
        -------
        Union[Tuple[pt.Tensor, pt.Tensor], Tuple[pt.Tensor, List[dict]]]
            (data, metadata) for the 'predict' step, otherwise (data, labels).
        """
        collate = self.__collate_predict if self.__step == 'predict' else self.__collate_labels
        return collate(batch)

    def __collate_labels(self, batch: List[dict]) -> Tuple[pt.Tensor, pt.Tensor]:
        """
        Builds a labelled batch from B items. Each item's ``'encoded'`` numpy array becomes
        a tensor, and its label is 1.0 when ``item['metadata']['censored']`` equals 1 and
        0.0 for any other value.

        Parameters
        ----------
        batch: List[dict]
            Items carrying an 'encoded' numpy array and a 'metadata' mapping with a 'censored' entry.

        Returns
        -------
        Tuple[pt.Tensor, pt.Tensor]
            The stacked encodings, shaped (B, ...) with the trailing dimensions of a single
            item's array, and the float labels, shaped (B, 1).
        """
        encodings = []
        targets = []
        for entry in batch:
            encodings.append(pt.from_numpy(entry['encoded']))
            flag = 1 if entry['metadata']['censored'] == 1 else 0
            targets.append(pt.tensor([flag]).to(pt.float))
        return pt.stack(encodings), pt.stack(targets)

    def __collate_predict(self, batch: List[dict]) -> Tuple[pt.Tensor, List[dict]]:
        """
        Builds a prediction batch from B items. Each item's ``'encoded'`` numpy array becomes
        a tensor and its ``'metadata'`` is passed through unchanged.

        Parameters
        ----------
        batch: List[dict]
            Items carrying an 'encoded' numpy array and a 'metadata' entry.

        Returns
        -------
        Tuple[pt.Tensor, List[dict]]
            The stacked encodings, shaped (B, ...), and the metadata of each item in batch order.
        """
        encodings = []
        details = []
        for entry in batch:
            encodings.append(pt.from_numpy(entry['encoded']))
            details.append(entry['metadata'])
        return pt.stack(encodings), details
class QuackLatentDataModule(pl.LightningDataModule):
    """
    Lightning data module that serves a ``QuackLatentDataset`` for every stage.

    The dataset found at ``data_dir`` is randomly split into train (the remainder,
    about 70%), test (20%) and validation (10%) subsets. The prediction dataloader
    serves the complete, unsplit dataset.
    """

    def __init__(self, data_dir: str, batch_size: int = 64, workers: int = 0):
        """
        Parameters
        ----------
        data_dir: str
            Location passed to ``QuackLatentDataset``.
        batch_size: int
            Number of items per batch in every dataloader.
        workers: int
            Number of dataloader worker processes; 0 loads data in the main process.
        """
        # The base class must be initialised before Lightning can use the module.
        super().__init__()
        self.__batch_size = batch_size
        self.__workers = workers
        dataset = QuackLatentDataset(data_dir)
        # Prediction runs over the whole dataset rather than a split.
        self.__predict_data = dataset
        print(f'Source dataset ready with {len(dataset)} items.')
        # Reserve 20% of the data as test data.
        test_reserve = round(len(dataset) * 0.2)
        # Reserve 10% of the data as validation data.
        val_reserve = round(len(dataset) * 0.1)
        # Training gets whatever is left so the three lengths always sum to len(dataset).
        train_count = len(dataset) - test_reserve - val_reserve
        self.__train_data, self.__test_data, self.__val_data = random_split(
            dataset, [train_count, test_reserve, val_reserve]
        )
        print(f'Training dataset randomly split with {len(self.__train_data)} items.')
        print(f'Test dataset randomly split with {len(self.__test_data)} items.')
        print(f'Validation dataset randomly split with {len(self.__val_data)} items.')
        print(f'Prediction dataset ready with {len(self.__predict_data)} items.')

    def __build_loader(self, data, step: str, shuffle: bool = False) -> DataLoader:
        """
        Constructs a dataloader over ``data`` using a ``QuackLatentCollator`` configured for ``step``.

        Parameters
        ----------
        data
            The dataset or subset to iterate.
        step: str
            Collator step: 'train', 'val', 'test' or 'predict'.
        shuffle: bool
            Whether the loader reshuffles the data every epoch.

        Returns
        -------
        torch.utils.data.dataloader.DataLoader
        """
        return DataLoader(
            data,
            batch_size=self.__batch_size,
            collate_fn=QuackLatentCollator(step=step),
            shuffle=shuffle,
            num_workers=self.__workers,
            # DataLoader raises ValueError for persistent_workers=True when num_workers
            # is 0, so workers are kept alive only when there are workers to keep.
            persistent_workers=self.__workers > 0
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """
        Constructs and returns the shuffled train dataloader using a ``QuackLatentCollator``
        object configured for training.

        Returns
        -------
        torch.utils.data.dataloader.DataLoader
        """
        return self.__build_loader(self.__train_data, 'train', shuffle=True)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """
        Constructs and returns the test dataloader using a ``QuackLatentCollator`` object
        configured for testing.

        Returns
        -------
        torch.utils.data.dataloader.DataLoader
        """
        return self.__build_loader(self.__test_data, 'test')

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """
        Constructs and returns the validation dataloader using a ``QuackLatentCollator`` object
        configured for validation.

        Returns
        -------
        torch.utils.data.dataloader.DataLoader
        """
        return self.__build_loader(self.__val_data, 'val')

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """
        Constructs and returns the prediction dataloader, covering the entire dataset, using a
        ``QuackLatentCollator`` object configured for prediction.

        Returns
        -------
        torch.utils.data.dataloader.DataLoader
        """
        return self.__build_loader(self.__predict_data, 'predict')