@@ -93,10 +93,10 @@ class CitySpaceDataset(Dataset):
93
93
94
94
def __init__ (self , img_dir , seg_dir , img_size , enable_data_aug = True , transform = None , cache_num = 0 ) -> None :
95
95
super (CitySpaceDataset , self ).__init__ (enable_data_aug = enable_data_aug , input_dimension = img_size )
96
- self .img_dir = img_dir
97
- self .seg_dir = seg_dir
96
+ self .img_dir = Path ( img_dir )
97
+ self .seg_dir = Path ( seg_dir )
98
98
self .trans = transform
99
- self .db_img , self . db_seg = self .make_db ()
99
+ self .filenames = self .make_db ()
100
100
self .imgs = None
101
101
if cache_num > 0 :
102
102
self .cache_num = cache_num if cache_num <= len (self ) else len (self ) # len(self)
@@ -110,8 +110,8 @@ def make_db(self):
110
110
assert Path (self .img_dir ).exists (), f"directory: { self .img_dir } is not exists!"
111
111
assert Path (self .seg_dir ).exists (), f"directory: { self .seg_dir } is not exists!"
112
112
113
- img_filepathes = [p for p in Path ( self .img_dir ) .iterdir () if p .suffix in ([".jpg" , ".png" , ".tiff" ])]
114
- seg_filepathes = [p for p in Path ( self .seg_dir ) .iterdir () if p .suffix in ([".jpg" , ".png" , ".tiff" ])]
113
+ img_filepathes = [p for p in self .img_dir .iterdir () if p .suffix in ([".jpg" , ".png" , ".tiff" ])]
114
+ seg_filepathes = [p for p in self .seg_dir .iterdir () if p .suffix in ([".jpg" , ".png" , ".tiff" ])]
115
115
assert len (img_filepathes ) == len (seg_filepathes ), f"len(img_filepathes): { len (img_filepathes )} , but len(seg_filenames): { len (seg_filepathes )} "
116
116
# (aachen , 000062 , 000019)
117
117
img_filepathes = sorted (img_filepathes , key = lambda x : (x .stem .split ("_" )[0 ], x .stem .split ("_" )[1 ], x .stem .split ("_" )[2 ]))
@@ -121,17 +121,22 @@ def make_db(self):
121
121
for i , p in enumerate (img_filepathes ):
122
122
img_filename = '_' .join (p .stem .split ("_" )[:- 1 ])
123
123
assert img_filename in seg_filenames , f"image filename: { img_filepathes [i ]} , can not found matched segmentation file."
124
- return img_filepathes , seg_filepathes
124
+ return seg_filenames
125
125
126
126
def __len__ (self ):
127
- return len (self .db_img )
127
+ return len (self .filenames )
128
128
129
129
def load_resized_data_pair (self , index ):
130
- img_p = self .db_img [index ]
131
- seg_p = self .db_seg [index ]
130
+ filename = self .filenames [index ]
131
+ img_p = self .img_dir / f"{ filename } _leftImg8bit.png"
132
+ assert img_p .exists (), f"{ img_p } is not exists!"
133
+ seg_p = self .seg_dir / f"{ filename } _gtFine_labelTrainIds.png"
134
+ assert seg_p .exists (), f"{ seg_p } is not exists!"
135
+
132
136
img_arr = cv2 .imread (str (img_p )) # (h, w, 3)
133
137
img_arr = cv2 .cvtColor (img_arr , cv2 .COLOR_BGR2RGB )
134
138
seg_arr = cv2 .imread (str (seg_p ), 0 )[:, :, None ] # (h, w, 1)
139
+ assert img_arr .shape [0 ] == seg_arr .shape [0 ] and img_arr .shape [1 ] == seg_arr .shape [1 ], f"img_arr's and seg_arr's shape should be the same, but img_arr.shape={ img_arr .shape [:2 ]} and seg_arr.shape={ seg_arr .shape [:2 ]} "
135
140
# In the Cityscapes dataset the background class has mask value 255; remap the background class mask to 0
136
141
bg_mask = seg_arr == 255
137
142
seg_arr += 1
@@ -193,10 +198,17 @@ def pull_item(self, index):
193
198
img_arr = data_pair [..., :3 ]
194
199
seg_arr = data_pair [..., - 1 :]
195
200
else :
196
- img_p = self .db_img [index ]
197
- seg_p = self .db_seg [index ]
201
+ filename = self .filenames [index ]
202
+ img_p = self .img_dir / f"{ filename } _leftImg8bit.png"
203
+ assert img_p .exists (), f"{ img_p } is not exists!"
204
+ seg_p = self .seg_dir / f"{ filename } _gtFine_labelTrainIds.png"
205
+ assert seg_p .exists (), f"{ seg_p } is not exists!"
206
+
198
207
img_arr = cv2 .imread (str (img_p )) # (h, w, 3)
208
+ img_arr = cv2 .cvtColor (img_arr , cv2 .COLOR_BGR2RGB )
199
209
seg_arr = cv2 .imread (str (seg_p ), 0 )[:, :, None ] # (h, w, 1)
210
+ assert img_arr .shape [0 ] == seg_arr .shape [0 ] and img_arr .shape [1 ] == seg_arr .shape [1 ], f"img_arr's and seg_arr's shape should be the same, but img_arr.shape={ img_arr .shape [:2 ]} and seg_arr.shape={ seg_arr .shape [:2 ]} "
211
+
200
212
# In the Cityscapes dataset the background class has mask value 255; remap the background class mask to 0
201
213
bg_mask = seg_arr == 255
202
214
seg_arr += 1
@@ -207,6 +219,7 @@ def pull_item(self, index):
207
219
208
220
@Dataset .aug_getitem
209
221
def __getitem__ (self , index ):
222
+ assert index < len (self ), f"index should less than { len (self )} , but got { index } "
210
223
img_arr , seg_arr = self .pull_item (index )
211
224
212
225
if self .enable_data_aug and self .trans is not None :
0 commit comments