-
Notifications
You must be signed in to change notification settings - Fork 100
/
segdataset.py
118 lines (109 loc) · 5.38 KB
/
segdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Author: Manpreet Singh Minhas
Contact: msminhas at uwaterloo ca
"""
from pathlib import Path
from typing import Any, Callable, Optional
import numpy as np
from PIL import Image
from torchvision.datasets.vision import VisionDataset
class SegmentationDataset(VisionDataset):
    """A PyTorch dataset for image segmentation tasks.

    Images and masks are read from two folders under ``root`` and paired by
    their position in the sorted directory listings: the i-th image (sorted
    by path) is matched with the i-th mask. The dataset is compatible with
    torchvision transforms; the ``transforms`` callable is applied to both
    the image and the mask.
    """

    # User-facing color mode names mapped to PIL ``Image.convert`` modes.
    _PIL_MODES = {"rgb": "RGB", "grayscale": "L"}

    def __init__(self,
                 root: str,
                 image_folder: str,
                 mask_folder: str,
                 transforms: Optional[Callable] = None,
                 seed: Optional[int] = None,
                 fraction: Optional[float] = None,
                 subset: Optional[str] = None,
                 image_color_mode: str = "rgb",
                 mask_color_mode: str = "grayscale") -> None:
        """
        Args:
            root (str): Root directory path.
            image_folder (str): Name of the folder that contains the images in the root directory.
            mask_folder (str): Name of the folder that contains the masks in the root directory.
            transforms (Optional[Callable], optional): A function/transform that takes in
                a sample and returns a transformed version.
                E.g, ``transforms.ToTensor`` for images. It is called separately on the
                image and on the mask, so a random transform draws independent
                parameters for each. Defaults to None.
            seed (int, optional): Seed for shuffling before the train/test split, for
                reproducible results. Only used when ``fraction`` is given. The global
                NumPy RNG is not modified. Defaults to None (unseeded shuffle).
            fraction (float, optional): A float value from 0 to 1 which specifies the
                validation (Test) split fraction. Train receives the first
                ceil(n * (1 - fraction)) shuffled pairs and Test the rest. If not given
                (or 0), the whole folder is used without shuffling. Defaults to None.
            subset (str, optional): 'Train' or 'Test' to select the appropriate set.
                Required when ``fraction`` is given. Defaults to None.
            image_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'rgb'.
            mask_color_mode (str, optional): 'rgb' or 'grayscale'. Defaults to 'grayscale'.
        Raises:
            OSError: If image folder doesn't exist in root.
            OSError: If mask folder doesn't exist in root.
            ValueError: If image_color_mode or mask_color_mode is not 'rgb' or 'grayscale'.
            ValueError: If fraction is given and subset is not either 'Train' or 'Test'.
            ValueError: If fraction is given and is not within (0, 1].
            ValueError: If the image and mask folders contain a different number of entries.
        """
        super().__init__(root, transforms)
        image_folder_path = Path(self.root) / image_folder
        mask_folder_path = Path(self.root) / mask_folder
        if not image_folder_path.exists():
            raise OSError(f"{image_folder_path} does not exist.")
        if not mask_folder_path.exists():
            raise OSError(f"{mask_folder_path} does not exist.")
        self._check_color_mode(image_color_mode)
        self._check_color_mode(mask_color_mode)
        self.image_color_mode = image_color_mode
        self.mask_color_mode = mask_color_mode

        if fraction:
            if subset not in ["Train", "Test"]:
                raise ValueError(
                    f"{subset} is not a valid input. Acceptable values are Train and Test."
                )
            # Negative values, values above 1 and NaN would otherwise produce
            # nonsensical slice bounds below.
            if not 0 < fraction <= 1:
                raise ValueError(
                    f"fraction must be a value between 0 and 1, got {fraction}."
                )

        image_paths = sorted(image_folder_path.glob("*"))
        mask_paths = sorted(mask_folder_path.glob("*"))
        # Pairing is purely positional, so differing counts mean images and
        # masks would be silently mismatched (or fail on indexing later).
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"{image_folder_path} contains {len(image_paths)} entries but "
                f"{mask_folder_path} contains {len(mask_paths)}; every image "
                f"needs exactly one mask."
            )

        if not fraction:
            self.image_names = image_paths
            self.mask_names = mask_paths
            return

        self.fraction = fraction
        self.image_list = np.array(image_paths)
        self.mask_list = np.array(mask_paths)
        indices = np.arange(len(self.image_list))
        if seed is not None:
            # A private RandomState seeded with `seed` produces the same
            # permutation as np.random.seed(seed) + np.random.shuffle, so
            # existing splits are reproduced without touching the global RNG.
            np.random.RandomState(seed).shuffle(indices)
        else:
            np.random.shuffle(indices)
        # The same permutation is applied to both arrays to keep pairs aligned.
        self.image_list = self.image_list[indices]
        self.mask_list = self.mask_list[indices]

        split = int(np.ceil(len(self.image_list) * (1 - self.fraction)))
        if subset == "Train":
            self.image_names = self.image_list[:split]
            self.mask_names = self.mask_list[:split]
        else:
            self.image_names = self.image_list[split:]
            self.mask_names = self.mask_list[split:]

    @classmethod
    def _check_color_mode(cls, color_mode: str) -> None:
        """Raise ValueError unless ``color_mode`` is 'rgb' or 'grayscale'."""
        if color_mode not in cls._PIL_MODES:
            raise ValueError(
                f"{color_mode} is an invalid choice. Please enter from rgb grayscale."
            )

    @classmethod
    def _load(cls, path, color_mode: str) -> Image.Image:
        """Open the image at ``path`` and return it converted to ``color_mode``."""
        with open(path, "rb") as file:
            # convert() reads the pixel data and returns a new image, so the
            # result stays usable after the file handle is closed.
            return Image.open(file).convert(cls._PIL_MODES[color_mode])

    def __len__(self) -> int:
        """Return the number of image/mask pairs in this (sub)set."""
        return len(self.image_names)

    def __getitem__(self, index: int) -> Any:
        """Return ``{"image": ..., "mask": ...}`` for the pair at ``index``.

        Both are PIL images converted to the configured color modes, then
        passed through ``transforms`` (if one was given).
        """
        image = self._load(self.image_names[index], self.image_color_mode)
        mask = self._load(self.mask_names[index], self.mask_color_mode)
        sample = {"image": image, "mask": mask}
        if self.transforms is not None:
            sample["image"] = self.transforms(sample["image"])
            sample["mask"] = self.transforms(sample["mask"])
        return sample