-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDatasetsLoaderSplits.py
More file actions
251 lines (213 loc) · 9.61 KB
/
DatasetsLoaderSplits.py
File metadata and controls
251 lines (213 loc) · 9.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from pathlib import Path
from typing import Tuple, Dict, List, Optional, Union
import random
from torchvision import transforms
import torchvision.transforms.functional as TF
class PyTorchDatasetLoader(Dataset):
    """
    PyTorch Dataset for the Agriculture-Vision preprocessed dataset.

    Loads paired ``inputs/*.npy`` (RGBN images) and ``labels/*.npy``
    (multi-class masks) and yields, per sample, a normalized image tensor
    plus a single-channel binary mask and/or a class-presence vector,
    depending on ``outputs_type``.
    """

    def __init__(
        self,
        working_path: str,
        export_type: str = "RGBN",
        outputs_type: str = "both",
        augmentation: bool = False,
        shuffle: bool = False,
        test_size: Optional[float] = None,
        val_size: Optional[float] = None,
        split_type: str = "train",  # "train", "val", or "test"
        split_seed: Optional[int] = None
    ):
        """
        Args:
            working_path (str): Path to preprocessed dataset
            export_type (str): Type of input image ('RGBN', 'NDVI', or 'RGB')
            outputs_type (str): Type of output ('mask_only', 'class_only', or 'both')
            augmentation (bool): Whether to apply data augmentation
            shuffle (bool): Whether to shuffle indices before splitting
            test_size (float): Proportion of data to use for testing (0.0-1.0)
            val_size (float): Proportion of data to use for validation (0.0-1.0)
            split_type (str): Which split to return ('train', 'val', or 'test')
            split_seed (int, optional): Seed for the split shuffle. Pass the
                SAME seed when building the train/val/test instances so their
                index permutations agree. With shuffle=True and no seed, each
                instance shuffles independently and the splits can overlap
                (original behavior, kept for backward compatibility).

        Raises:
            ValueError: On mismatched input/label counts, unknown
                export_type/outputs_type/split_type, or invalid split sizes.
        """
        self.working_path = Path(working_path)
        self.export_type = export_type
        self.outputs_type = outputs_type
        self.augmentation = augmentation
        self.split_type = split_type

        # Discover sample files; sorting keeps input/label pairs aligned.
        self.input_files = sorted((self.working_path / "inputs").glob("*.npy"))
        self.label_files = sorted((self.working_path / "labels").glob("*.npy"))

        # Validate with explicit raises (asserts are stripped under `python -O`).
        if len(self.input_files) != len(self.label_files):
            raise ValueError("Mismatch in input and label files")
        if export_type not in ("RGBN", "NDVI", "RGB"):
            raise ValueError("Invalid export_type")
        if outputs_type not in ("mask_only", "class_only", "both"):
            raise ValueError("Invalid outputs_type")
        if split_type not in ("train", "val", "test"):
            raise ValueError("Invalid split_type")

        total_samples = len(self.input_files)
        indices = list(range(total_samples))
        if shuffle:
            if split_seed is not None:
                # Deterministic permutation shared by every instance built
                # with the same seed -> disjoint train/val/test splits.
                random.Random(split_seed).shuffle(indices)
            else:
                random.shuffle(indices)

        # Treat None as "no held-out data".
        test_size = test_size or 0.0
        val_size = val_size or 0.0
        if test_size + val_size >= 1.0:
            raise ValueError("test_size + val_size must be less than 1.0")

        test_count = int(total_samples * test_size)
        val_count = int(total_samples * val_size)
        train_count = total_samples - test_count - val_count

        # Layout of the (possibly shuffled) index list: [ test | val | train ]
        if split_type == "test":
            selected_indices = indices[:test_count]
        elif split_type == "val":
            selected_indices = indices[test_count:test_count + val_count]
        else:  # train
            selected_indices = indices[test_count + val_count:]

        # Keep only this split's files.
        self.input_files = [self.input_files[i] for i in selected_indices]
        self.label_files = [self.label_files[i] for i in selected_indices]
        self.num_samples = len(self.input_files)

        print(f"Split '{split_type}': {self.num_samples} samples "
              f"(Total: {total_samples}, Train: {train_count}, Val: {val_count}, Test: {test_count})")

    def __len__(self) -> int:
        """Number of samples in this split."""
        return self.num_samples

    def calculate_ndvi(self, rgbn_image: np.ndarray) -> np.ndarray:
        """Compute NDVI from an RGBN uint8 image, rescaled to uint8.

        Returns an (H, W, 1) uint8 array where 0 maps to NDVI = -1 and
        255 maps to NDVI = +1.
        """
        nir = rgbn_image[:, :, 3].astype(np.float32) / 255.0
        red = rgbn_image[:, :, 0].astype(np.float32) / 255.0
        # Epsilon avoids division by zero on fully dark pixels.
        ndvi = (nir - red) / (nir + red + 1e-8)  # ranges from -1 to +1
        # Scale [-1, 1] -> [0, 255].
        ndvi_scaled = ((ndvi + 1) * 127.5).astype(np.uint8)
        return np.expand_dims(ndvi_scaled, axis=-1)

    def process_input(self, rgbn_image: np.ndarray) -> np.ndarray:
        """Select the input representation according to ``export_type``."""
        if self.export_type == "RGBN":
            return rgbn_image
        elif self.export_type == "NDVI":
            return self.calculate_ndvi(rgbn_image)
        else:  # RGB
            return rgbn_image[:, :, :3]

    def create_binary_mask(self, multi_class_mask: np.ndarray) -> np.ndarray:
        """Collapse an (H, W, C) multi-class mask to an (H, W, 1) binary mask.

        A pixel is foreground (1.0) if ANY class channel is positive.
        """
        binary_mask = np.any(multi_class_mask > 0, axis=-1).astype(np.float32)
        return np.expand_dims(binary_mask, axis=-1)

    def augment_data(self, image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply random flips/rotation jointly to image and mask.

        No-op when ``self.augmentation`` is False. The same transform is
        applied to both tensors so they stay spatially aligned.
        """
        if not self.augmentation:
            return image, mask
        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        # Random vertical flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        # Random rotation (±15 degrees)
        if random.random() > 0.5:
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)
        return image, mask

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Load, preprocess, and return one sample.

        Returns (image, mask), (image, class_vector), or
        (image, {'segmentation_output': ..., 'classification_output': ...})
        depending on ``outputs_type``.
        """
        rgbn_image = np.load(self.input_files[idx])
        multi_class_mask = np.load(self.label_files[idx])

        processed_image = self.process_input(rgbn_image)
        # Binary foreground mask via OR across all class channels.
        binary_mask = self.create_binary_mask(multi_class_mask)

        # Per-class presence flags, plus a trailing "no class" flag that is
        # 1 only when no class is present anywhere in the patch.
        class_presence = (np.sum(multi_class_mask, axis=(0, 1)) > 0).astype(np.float32)
        no_class = float(class_presence.sum() == 0)
        class_presence = np.concatenate([class_presence, [no_class]], axis=0)

        # HWC -> CHW; torch.as_tensor replaces the legacy
        # torch.FloatTensor(ndarray) constructor.
        image_tensor = torch.as_tensor(processed_image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        mask_tensor = torch.as_tensor(binary_mask, dtype=torch.float32).permute(2, 0, 1)
        class_tensor = torch.as_tensor(class_presence, dtype=torch.float32)

        # augment_data is already a no-op when augmentation is disabled.
        image_tensor, mask_tensor = self.augment_data(image_tensor, mask_tensor)

        if self.outputs_type == "mask_only":
            return image_tensor, mask_tensor
        elif self.outputs_type == "class_only":
            return image_tensor, class_tensor
        else:  # both
            return image_tensor, {
                'segmentation_output': mask_tensor,
                'classification_output': class_tensor
            }
class DatasetLoaderWrapper:
    """
    Convenience wrapper that couples a PyTorchDatasetLoader with a torch
    DataLoader, preserving the call structure of the legacy loader API.
    """

    def __init__(
        self,
        working_path: str,
        batch_size: int = 8,
        export_type: str = "RGBN",
        outputs_type: str = "both",
        augmentation: bool = False,
        shuffle: bool = False,
        num_workers: int = 4,
        pin_memory: bool = True,
        test_size: Optional[float] = None,
        val_size: Optional[float] = None,
        split_type: str = "train"
    ):
        """
        Args:
            working_path (str): Path to preprocessed dataset
            batch_size (int): Number of samples per batch
            export_type (str): Type of input image ('RGBN', 'NDVI', or 'RGB')
            outputs_type (str): Type of output ('mask_only', 'class_only', or 'both')
            augmentation (bool): Whether to apply data augmentation
            shuffle (bool): Whether to shuffle the dataset
            num_workers (int): Number of worker processes for data loading
            pin_memory (bool): Whether to pin memory for faster GPU transfer
            test_size (float): Proportion of data to use for testing (0.0-1.0)
            val_size (float): Proportion of data to use for validation (0.0-1.0)
            split_type (str): Which split to return ('train', 'val', or 'test')
        """
        # Build the underlying Dataset first, then wrap it in a DataLoader.
        dataset = PyTorchDatasetLoader(
            working_path=working_path,
            export_type=export_type,
            outputs_type=outputs_type,
            augmentation=augmentation,
            shuffle=shuffle,
            test_size=test_size,
            val_size=val_size,
            split_type=split_type,
        )
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False,
        )
        self.dataset = dataset
        self.dataloader = loader
        self.batch_size = batch_size
        self.export_type = export_type
        self.outputs_type = outputs_type

    def __len__(self) -> int:
        """Number of samples (not batches) in the underlying dataset."""
        return len(self.dataset)

    def __iter__(self):
        """Yield batches from the wrapped DataLoader."""
        yield from self.dataloader

    def get_single_batch(self) -> Tuple[torch.Tensor, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Return the first batch (for quick smoke tests); raise if empty."""
        for first_batch in self.dataloader:
            return first_batch
        raise StopIteration("No data available")