
Commit b6c7473

make sure we get the right (img, lab) pair when loading the dataset
1 parent 03b6fc0 commit b6c7473

File tree

4 files changed (+41, -27 lines)


config/train_dist.yaml

+9-8
@@ -10,7 +10,7 @@ train_hyp:
   test_img_dir: "./data/testimage" # for test
   val_img_dir: "../../Dataset/Segmentation/cityscapes/image/val/" # validation image dir
   val_seg_dir: "../../Dataset/Segmentation/cityscapes/label/val/" # validation label dir
-  cache_num: 500000
+  cache_num: 0
   input_img_size: # size of the images fed to the network
   - 448
   - 448
@@ -21,32 +21,33 @@ train_hyp:
   num_workers: 0 # parameter of the PyTorch DataLoader
   total_epoch: 1000
   device: "gpu" # whether to train on GPU ['gpu' or 'cpu']
-  accu_batch_size: 64 # gradient accumulation batch size
+  accu_batch_size: 48 # gradient accumulation batch size
   do_ema: true # whether to maintain an exponential moving average model
   use_tta_when_val: false # whether to use TTA during validation
   mutil_scale_training: false # whether to use multi-scale training
   enable_tensorboard: true
   enable_data_aug: true
-  random_seed: 3047
+  random_seed: 7
   fp16: false
   inference_every: 5 # validate every N epochs
   show_tbar_every: 5 # show live training status every N steps
   save_ckpt_every: 5 # save a checkpoint every N epochs
   calculate_metric_every: 5 # compute IoU every N epochs
+  log_postfix: 'sgd_relu_onecycle'

 optimizer_hyp:
   optimizer_type: 'sgd' # 'sgd' or 'adamw' or 'adam'
-  scheduler_type: 'cosine' # 'onecycle' or 'cosine' or 'linear'
+  scheduler_type: 'onecycle' # 'onecycle' or 'cosine' or 'linear'
   basic_lr_per_img: 0.000625 # 0.01 / 16
   weight_decay: 0.0
-  optimizer_momentum: 0.9
+  optimizer_momentum: 0.98
   eps: 0.00000001

 warm_up:
   do_warmup: true # whether to enable warmup training
   warmup_epoch: 3
-  warmup_bias_lr: 0.1
-  warmup_momentum: 0.8
+  warmup_bias_lr: 0.2
+  warmup_momentum: 0.95

 data_aug_hyp:
   data_aug_saturation_p: 0.1
@@ -65,7 +66,7 @@ data_aug_hyp:
   data_aug_fliplr_p: 0.5
   data_aug_flipud_p: 0.0
   data_aug_fill_value: 114
-  data_aug_cutout_p: 0.05
+  data_aug_cutout_p: 0.0
   data_aug_brightness_p: 0.1
   data_aug_cutout_iou_thr: 0.3 # if a randomly generated cutout mask overlaps any target bbox with IoU above this threshold, take action to avoid it (the default action is to discard that mask)
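The comment on basic_lr_per_img ("0.01 / 16") hints at the common linear lr-scaling convention: the initial learning rate is the per-image rate times the total number of images per optimizer step. A minimal sketch of that arithmetic, assuming that convention (the batch-size numbers below are hypothetical, not from this repo):

    # initial lr scales linearly with the total batch size
    basic_lr_per_img = 0.000625      # from the config: 0.01 / 16
    batch_size_per_gpu = 8           # hypothetical
    world_size = 2                   # hypothetical number of GPUs
    init_lr = basic_lr_per_img * batch_size_per_gpu * world_size
    print(init_lr)                   # 0.01 for a total batch size of 16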

data/dataloader.py

+24-11
@@ -93,10 +93,10 @@ class CitySpaceDataset(Dataset):

     def __init__(self, img_dir, seg_dir, img_size, enable_data_aug=True, transform=None, cache_num=0) -> None:
         super(CitySpaceDataset, self).__init__(enable_data_aug=enable_data_aug, input_dimension=img_size)
-        self.img_dir = img_dir
-        self.seg_dir = seg_dir
+        self.img_dir = Path(img_dir)
+        self.seg_dir = Path(seg_dir)
         self.trans = transform
-        self.db_img, self.db_seg = self.make_db()
+        self.filenames = self.make_db()
         self.imgs = None
         if cache_num > 0:
             self.cache_num = cache_num if cache_num <= len(self) else len(self)  # len(self)
@@ -110,8 +110,8 @@ def make_db(self):
         assert Path(self.img_dir).exists(), f"directory: {self.img_dir} does not exist!"
         assert Path(self.seg_dir).exists(), f"directory: {self.seg_dir} does not exist!"

-        img_filepathes = [p for p in Path(self.img_dir).iterdir() if p.suffix in ([".jpg", ".png", ".tiff"])]
-        seg_filepathes = [p for p in Path(self.seg_dir).iterdir() if p.suffix in ([".jpg", ".png", ".tiff"])]
+        img_filepathes = [p for p in self.img_dir.iterdir() if p.suffix in ([".jpg", ".png", ".tiff"])]
+        seg_filepathes = [p for p in self.seg_dir.iterdir() if p.suffix in ([".jpg", ".png", ".tiff"])]
         assert len(img_filepathes) == len(seg_filepathes), f"len(img_filepathes): {len(img_filepathes)}, but len(seg_filepathes): {len(seg_filepathes)}"
         # (aachen, 000062, 000019)
         img_filepathes = sorted(img_filepathes, key=lambda x: (x.stem.split("_")[0], x.stem.split("_")[1], x.stem.split("_")[2]))
@@ -121,17 +121,22 @@ def make_db(self):
         for i, p in enumerate(img_filepathes):
             img_filename = '_'.join(p.stem.split("_")[:-1])
             assert img_filename in seg_filenames, f"image filename: {img_filepathes[i]}, cannot find a matched segmentation file."
-        return img_filepathes, seg_filepathes
+        return seg_filenames

     def __len__(self):
-        return len(self.db_img)
+        return len(self.filenames)

     def load_resized_data_pair(self, index):
-        img_p = self.db_img[index]
-        seg_p = self.db_seg[index]
+        filename = self.filenames[index]
+        img_p = self.img_dir / f"{filename}_leftImg8bit.png"
+        assert img_p.exists(), f"{img_p} does not exist!"
+        seg_p = self.seg_dir / f"{filename}_gtFine_labelTrainIds.png"
+        assert seg_p.exists(), f"{seg_p} does not exist!"

         img_arr = cv2.imread(str(img_p))  # (h, w, 3)
         img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
         seg_arr = cv2.imread(str(seg_p), 0)[:, :, None]  # (h, w, 1)
+        assert img_arr.shape[0] == seg_arr.shape[0] and img_arr.shape[1] == seg_arr.shape[1], f"img_arr's and seg_arr's shapes should match, but img_arr.shape={img_arr.shape[:2]} and seg_arr.shape={seg_arr.shape[:2]}"
         # the background class in the Cityscapes dataset has mask value 255; remap the background to 0
         bg_mask = seg_arr == 255
         seg_arr += 1
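The context lines above shift every trainId up by one and record where the Cityscapes ignore value 255 sits; presumably the next line (not shown in this hunk) zeroes those pixels out. A tiny sketch of that remap, under that assumption:

    import numpy as np

    seg_arr = np.array([[3, 0, 255]], dtype=np.int64)  # toy trainIds
    bg_mask = seg_arr == 255
    seg_arr += 1               # classes 0..18 become 1..19
    seg_arr[bg_mask] = 0       # assumption: the unshown follow-up line does this
    print(seg_arr)             # [[4 1 0]]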
@@ -193,10 +198,17 @@ def pull_item(self, index):
             img_arr = data_pair[..., :3]
             seg_arr = data_pair[..., -1:]
         else:
-            img_p = self.db_img[index]
-            seg_p = self.db_seg[index]
+            filename = self.filenames[index]
+            img_p = self.img_dir / f"{filename}_leftImg8bit.png"
+            assert img_p.exists(), f"{img_p} does not exist!"
+            seg_p = self.seg_dir / f"{filename}_gtFine_labelTrainIds.png"
+            assert seg_p.exists(), f"{seg_p} does not exist!"

             img_arr = cv2.imread(str(img_p))  # (h, w, 3)
+            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB)
             seg_arr = cv2.imread(str(seg_p), 0)[:, :, None]  # (h, w, 1)
+            assert img_arr.shape[0] == seg_arr.shape[0] and img_arr.shape[1] == seg_arr.shape[1], f"img_arr's and seg_arr's shapes should match, but img_arr.shape={img_arr.shape[:2]} and seg_arr.shape={seg_arr.shape[:2]}"

         # the background class in the Cityscapes dataset has mask value 255; remap the background to 0
         bg_mask = seg_arr == 255
         seg_arr += 1
@@ -207,6 +219,7 @@ def pull_item(self, index):

     @Dataset.aug_getitem
     def __getitem__(self, index):
+        assert index < len(self), f"index should be less than {len(self)}, but got {index}"
         img_arr, seg_arr = self.pull_item(index)

         if self.enable_data_aug and self.trans is not None:
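This is the heart of the commit: before, make_db returned two independently sorted path lists and pairing relied on index alignment, so one stray file could silently shift every subsequent (img, label) pair. Now a single Cityscapes filename stem drives both paths. A self-contained sketch of the idea (resolve_pair is an illustrative helper, not a function in this repo):

    from pathlib import Path

    def resolve_pair(img_dir: Path, seg_dir: Path, stem: str):
        # stem is a Cityscapes (city, sequence, frame) triple, e.g. "aachen_000062_000019"
        img_p = img_dir / f"{stem}_leftImg8bit.png"
        seg_p = seg_dir / f"{stem}_gtFine_labelTrainIds.png"
        assert img_p.exists() and seg_p.exists(), f"missing file for stem: {stem}"
        return img_p, seg_p

Because both paths are derived from the same stem, a missing or misnamed file now fails loudly at the offending index instead of corrupting the whole pairing.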

nets/usquarenet_experiment.py

+2-2
@@ -53,9 +53,9 @@ def __init__(self, in_channel, num_class, kernel=3, stride=1, padding=1, dilatio
         super(ConvBnAct, self).__init__()
         self.conv = nn.Conv2d(in_channel, num_class, kernel, stride, padding=padding, dilation=dilation, bias=bias)
         self.bn = nn.BatchNorm2d(num_class)
-        # self.act = nn.SiLU(inplace=True) if act else nn.Identity()
+        self.act = nn.SiLU(inplace=True) if act else nn.Identity()
         # self.act = nn.ReLU(inplace=True) if act else nn.Identity()
-        self.act = nn.LeakyReLU(negative_slope=0.01, inplace=True) if act else nn.Identity()
+        # self.act = nn.LeakyReLU(negative_slope=0.01, inplace=True) if act else nn.Identity()

     def forward(self, x):
         x = self.conv(x)
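After this toggle the block is Conv -> BN -> SiLU. For reference, a minimal runnable version of the block as it stands after the change (a sketch: the forward body past x = self.conv(x) is not shown in the diff, so the conv -> bn -> act order is an assumption, albeit the standard one):

    import torch
    import torch.nn as nn

    class ConvBnAct(nn.Module):
        def __init__(self, in_channel, num_class, kernel=3, stride=1,
                     padding=1, dilation=1, bias=False, act=True):
            super().__init__()
            self.conv = nn.Conv2d(in_channel, num_class, kernel, stride,
                                  padding=padding, dilation=dilation, bias=bias)
            self.bn = nn.BatchNorm2d(num_class)
            # this commit switches the activation back from LeakyReLU to SiLU
            self.act = nn.SiLU(inplace=True) if act else nn.Identity()

        def forward(self, x):
            return self.act(self.bn(self.conv(x)))  # assumed conv -> bn -> act order

    print(ConvBnAct(3, 16)(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 16, 64, 64])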

train_ddp.py

+6-6
@@ -82,7 +82,7 @@ def __init__(self, hyp):

         # config warmup step
         if self.hyp['do_warmup']:
-            self.hyp['warmup_steps'] = max(self.hyp.get('warmup_epoch', 3) * len(self.traindataloader), 3000)
+            self.hyp['warmup_steps'] = max(self.hyp.get('warmup_epoch', 3) * len(self.traindataloader), 1000)

     def load_dataset(self, is_training):
         if is_training:
@@ -114,7 +114,7 @@ def _init_logger(self, model):
         logger = logging.getLogger(f"UPerNet_Rank_{self.rank}")
         formated_config = print_config(self.hyp)  # record training parameters in log.txt
         logger.setLevel(logging.INFO)
-        txt_log_path = str(self.cwd / 'log' / f'log_rank_{self.rank}' / f'log_{self.model.__class__.__name__}_{datetime.now().strftime("%Y%m%d-%H:%M:%S")}.txt')
+        txt_log_path = str(self.cwd / 'log' / f'log_rank_{self.rank}' / f'log_{self.model.__class__.__name__}_{datetime.now().strftime("%Y%m%d-%H:%M:%S")}_{self.hyp["log_postfix"]}.txt')
         maybe_mkdir(Path(txt_log_path).parent)
         handler = logging.FileHandler(txt_log_path)
         handler.setLevel(logging.INFO)
@@ -188,13 +188,13 @@ def _init_bias(self):

     def _init_scheduler(self, optimizer, trainloader):
         if self.hyp['scheduler_type'].lower() == "onecycle":  # onecycle lr scheduler
-            scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, epochs=self.hyp['total_epoch'], steps_per_epoch=len(trainloader), three_phase=True)
+            scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, epochs=self.hyp['total_epoch'], steps_per_epoch=len(trainloader), three_phase=True)
         elif self.hyp['scheduler_type'].lower() == 'linear':  # linear lr scheduler
-            max_ds_rate = 0.0001
+            max_ds_rate = 0.01
             linear_lr = lambda epoch: (1 - epoch / (self.hyp['total_epoch'] - 1)) * (1. - max_ds_rate) + max_ds_rate  # the larger max_ds_rate is, the slower the lr decays and the higher the lr left at the end of training
             scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_lr)
         else:  # cosine lr scheduler
-            max_ds_rate = 0.0001  # the minimum lr over the whole run equals max_ds_rate * init_lr
+            max_ds_rate = 0.01  # the minimum lr over the whole run equals max_ds_rate * init_lr
             cosin_lr = lambda epoch: ((1 + math.cos(epoch * math.pi / self.hyp['total_epoch'])) / 2) * (1. - max_ds_rate) + max_ds_rate  # cosine
             scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=cosin_lr)
         return scheduler
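To see what the max_ds_rate change does: the lambdas return a multiplier on the initial lr, decaying from 1.0 down to max_ds_rate. A quick standalone check of the cosine variant, using the values from the diff:

    import math

    total_epoch = 1000
    max_ds_rate = 0.01  # was 0.0001 before this commit
    cosin_lr = lambda epoch: ((1 + math.cos(epoch * math.pi / total_epoch)) / 2) * (1. - max_ds_rate) + max_ds_rate

    print(cosin_lr(0))            # 1.0  -> full initial lr at the start
    print(cosin_lr(total_epoch))  # 0.01 -> the lr floor is now 1% of init_lr instead of 0.01%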
@@ -287,7 +287,7 @@ def step(self):
         with amp.autocast(enabled=self.use_cuda):
             preds = self.model(img)
             loss_dict = self.loss_fcn(preds, gt_seg)
-            # loss_dict['total_loss'] /= self.accumulate
+            loss_dict['total_loss'] /= self.accumulate
             loss_dict['total_loss'] *= get_world_size()

         iter_end_time = time.time()
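Re-enabling the division matters for gradient accumulation: when a large virtual batch is split into self.accumulate micro-batches, each micro-batch loss must be scaled by 1/accumulate so the gradients summed by repeated backward() calls match one full-batch step. A minimal self-contained sketch of the pattern (the tiny model and the accumulate value are stand-ins, not this repo's objects):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)                          # stand-in for the real network
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fcn = nn.MSELoss()
    accumulate = 6                                   # hypothetical micro-batch count

    optimizer.zero_grad()
    for _ in range(accumulate):
        x, y = torch.randn(8, 4), torch.randn(8, 2)  # one micro-batch
        loss = loss_fcn(model(x), y) / accumulate    # the division this commit re-enables
        loss.backward()                              # gradients sum across micro-batches
    optimizer.step()                                 # one update for the whole virtual batch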
